1#include "llvm/ExecutionEngine/Orc/ReOptimizeLayer.h"
2#include "llvm/ExecutionEngine/Orc/Mangling.h"
3
4using namespace llvm;
5using namespace orc;
6
7bool ReOptimizeLayer::ReOptMaterializationUnitState::tryStartReoptimize() {
8 std::unique_lock<std::mutex> Lock(Mutex);
9 if (Reoptimizing)
10 return false;
11
12 Reoptimizing = true;
13 return true;
14}
15
16void ReOptimizeLayer::ReOptMaterializationUnitState::reoptimizeSucceeded() {
17 std::unique_lock<std::mutex> Lock(Mutex);
18 assert(Reoptimizing && "Tried to mark unstarted reoptimization as done");
19 Reoptimizing = false;
20 CurVersion++;
21}
22
23void ReOptimizeLayer::ReOptMaterializationUnitState::reoptimizeFailed() {
24 std::unique_lock<std::mutex> Lock(Mutex);
25 assert(Reoptimizing && "Tried to mark unstarted reoptimization as done");
26 Reoptimizing = false;
27}
28
29static void orc_rt_lite_reoptimize_helper(
30 shared::CWrapperFunctionBuffer (*JITDispatch)(void *Ctx, void *Tag,
31 const char *Data,
32 size_t Size),
33 void *JITDispatchCtx, void *Tag, uint64_t MUID, uint32_t CurVersion) {
34 // Serialize the arguments into a WrapperFunctionBuffer and call dispatch.
35 using SPSArgs = shared::SPSArgList<uint64_t, uint32_t>;
36 auto ArgBytes =
37 shared::WrapperFunctionBuffer::allocate(Size: SPSArgs::size(Arg: MUID, Args: CurVersion));
38 shared::SPSOutputBuffer OB(ArgBytes.data(), ArgBytes.size());
39 if (!SPSArgs::serialize(OB, Arg: MUID, Args: CurVersion)) {
40 errs()
41 << "Reoptimization error: could not serialize reoptimization arguments";
42 abort();
43 }
44 shared::WrapperFunctionBuffer Buf{
45 JITDispatch(JITDispatchCtx, Tag, ArgBytes.data(), ArgBytes.size())};
46
47 if (const char *ErrMsg = Buf.getOutOfBandError()) {
48 errs() << "Reoptimization error: " << ErrMsg << "\naborting.\n";
49 abort();
50 }
51}
52
53Error ReOptimizeLayer::addOrcRTLiteSupport(JITDylib &PlatformJD,
54 const DataLayout &DL) {
55 auto Ctx = std::make_unique<LLVMContext>();
56 auto Mod = std::make_unique<Module>(args: "orc-rt-lite-reoptimize.ll", args&: *Ctx);
57 Mod->setDataLayout(DL);
58
59 IRBuilder<> Builder(*Ctx);
60
61 // Create basic types portably
62 Type *VoidTy = Type::getVoidTy(C&: *Ctx);
63 Type *Int8Ty = Type::getInt8Ty(C&: *Ctx);
64 Type *Int32Ty = Type::getInt32Ty(C&: *Ctx);
65 Type *Int64Ty = Type::getInt64Ty(C&: *Ctx);
66 Type *VoidPtrTy = PointerType::getUnqual(C&: *Ctx);
67
68 // Helper function type: void (void*, void*, void*, uint64_t, uint32_t)
69 FunctionType *HelperFnTy = FunctionType::get(
70 Result: VoidTy, Params: {VoidPtrTy, VoidPtrTy, VoidPtrTy, Int64Ty, Int32Ty}, isVarArg: false);
71
72 // Define ReoptimizeTag with initializer = 0
73 GlobalVariable *ReoptimizeTag = new GlobalVariable(
74 *Mod, Int8Ty, false, GlobalValue::ExternalLinkage,
75 ConstantInt::get(Ty: Int8Ty, V: 0), "__orc_rt_reoptimize_tag");
76
77 // Define orc_rt_lite_reoptimize function: void (uint64_t, uint32_t)
78 FunctionType *ReOptimizeFnTy =
79 FunctionType::get(Result: VoidTy, Params: {Int64Ty, Int32Ty}, isVarArg: false);
80
81 Function *ReOptimizeFn =
82 Function::Create(Ty: ReOptimizeFnTy, Linkage: Function::ExternalLinkage,
83 N: "__orc_rt_reoptimize", M: Mod.get());
84
85 // Set parameter names
86 auto ArgIt = ReOptimizeFn->arg_begin();
87 Value *MUID = &*ArgIt++;
88 MUID->setName("MUID");
89 Value *CurVersion = &*ArgIt;
90 CurVersion->setName("CurVersion");
91
92 // Build function body
93 BasicBlock *Entry = BasicBlock::Create(Context&: *Ctx, Name: "entry", Parent: ReOptimizeFn);
94 Builder.SetInsertPoint(Entry);
95
96 // Create absolute address constants
97 auto &JDI = PlatformJD.getExecutionSession()
98 .getExecutorProcessControl()
99 .getJITDispatchInfo();
100
101 Type *IntPtrTy = DL.getIntPtrType(C&: *Ctx);
102 Constant *JITDispatchPtr = ConstantExpr::getIntToPtr(
103 C: ConstantInt::get(Ty: IntPtrTy, V: JDI.JITDispatchFunction.getValue()),
104 Ty: VoidPtrTy);
105 Constant *JITDispatchCtxPtr = ConstantExpr::getIntToPtr(
106 C: ConstantInt::get(Ty: IntPtrTy, V: JDI.JITDispatchContext.getValue()), Ty: VoidPtrTy);
107 Constant *HelperFnAddr = ConstantExpr::getIntToPtr(
108 C: ConstantInt::get(Ty: IntPtrTy, V: reinterpret_cast<uintptr_t>(
109 &orc_rt_lite_reoptimize_helper)),
110 Ty: PointerType::getUnqual(C&: *Ctx));
111
112 // Cast ReoptimizeTag to void*
113 Value *ReoptimizeTagPtr = Builder.CreatePointerCast(V: ReoptimizeTag, DestTy: VoidPtrTy);
114
115 // Call the helper function
116 Builder.CreateCall(
117 FTy: HelperFnTy, Callee: HelperFnAddr,
118 Args: {JITDispatchPtr, JITDispatchCtxPtr, ReoptimizeTagPtr, MUID, CurVersion});
119
120 // Return void
121 Builder.CreateRetVoid();
122
123 return BaseLayer.add(JD&: PlatformJD,
124 TSM: ThreadSafeModule(std::move(Mod), std::move(Ctx)));
125}
126
127Error ReOptimizeLayer::registerRuntimeFunctions(JITDylib &PlatformJD) {
128 ExecutionSession::JITDispatchHandlerAssociationMap WFs;
129 using ReoptimizeSPSSig = shared::SPSError(uint64_t, uint32_t);
130 WFs[Mangle("__orc_rt_reoptimize_tag")] =
131 ES.wrapAsyncWithSPS<ReoptimizeSPSSig>(Instance: this,
132 Method: &ReOptimizeLayer::rt_reoptimize);
133 return ES.registerJITDispatchHandlers(JD&: PlatformJD, WFs: std::move(WFs));
134}
135
136void ReOptimizeLayer::emit(std::unique_ptr<MaterializationResponsibility> R,
137 ThreadSafeModule TSM) {
138 auto &JD = R->getTargetJITDylib();
139
140 bool HasNonCallable = false;
141 for (auto &KV : R->getSymbols()) {
142 auto &Flags = KV.second;
143 if (!Flags.isCallable())
144 HasNonCallable = true;
145 }
146
147 if (HasNonCallable) {
148 BaseLayer.emit(R: std::move(R), TSM: std::move(TSM));
149 return;
150 }
151
152 auto &MUState = createMaterializationUnitState(TSM);
153
154 if (auto Err = R->withResourceKeyDo(F: [&](ResourceKey Key) {
155 registerMaterializationUnitResource(Key, State&: MUState);
156 })) {
157 ES.reportError(Err: std::move(Err));
158 R->failMaterialization();
159 return;
160 }
161
162 if (auto Err =
163 ProfilerFunc(*this, MUState.getID(), MUState.getCurVersion(), TSM)) {
164 ES.reportError(Err: std::move(Err));
165 R->failMaterialization();
166 return;
167 }
168
169 auto InitialDests =
170 emitMUImplSymbols(MUState, Version: MUState.getCurVersion(), JD, TSM: std::move(TSM));
171 if (!InitialDests) {
172 ES.reportError(Err: InitialDests.takeError());
173 R->failMaterialization();
174 return;
175 }
176
177 RSManager.emitRedirectableSymbols(MR: std::move(R), InitialDests: std::move(*InitialDests));
178}
179
180Error ReOptimizeLayer::reoptimizeIfCallFrequent(ReOptimizeLayer &Parent,
181 ReOptMaterializationUnitID MUID,
182 unsigned CurVersion,
183 ThreadSafeModule &TSM) {
184 return TSM.withModuleDo(F: [&](Module &M) -> Error {
185 Type *I64Ty = Type::getInt64Ty(C&: M.getContext());
186 GlobalVariable *Counter = new GlobalVariable(
187 M, I64Ty, false, GlobalValue::InternalLinkage,
188 Constant::getNullValue(Ty: I64Ty), "__orc_reopt_counter");
189 for (auto &F : M) {
190 if (F.isDeclaration())
191 continue;
192 auto &BB = F.getEntryBlock();
193 auto *IP = &*BB.getFirstInsertionPt();
194 IRBuilder<> IRB(IP);
195 Value *Threshold = ConstantInt::get(Ty: I64Ty, V: CallCountThreshold, IsSigned: true);
196 Value *Cnt = IRB.CreateLoad(Ty: I64Ty, Ptr: Counter);
197 // Use EQ to prevent further reoptimize calls.
198 Value *Cmp = IRB.CreateICmpEQ(LHS: Cnt, RHS: Threshold);
199 Value *Added = IRB.CreateAdd(LHS: Cnt, RHS: ConstantInt::get(Ty: I64Ty, V: 1));
200 (void)IRB.CreateStore(Val: Added, Ptr: Counter);
201 Instruction *SplitTerminator = SplitBlockAndInsertIfThen(Cond: Cmp, SplitBefore: IP, Unreachable: false);
202 createReoptimizeCall(M, IP&: *SplitTerminator, MUID, CurVersion);
203 }
204 return Error::success();
205 });
206}
207
208Expected<SymbolMap>
209ReOptimizeLayer::emitMUImplSymbols(ReOptMaterializationUnitState &MUState,
210 uint32_t Version, JITDylib &JD,
211 ThreadSafeModule TSM) {
212 DenseMap<SymbolStringPtr, SymbolStringPtr> RenamedMap;
213 cantFail(Err: TSM.withModuleDo(F: [&](Module &M) -> Error {
214 MangleAndInterner Mangle(ES, M.getDataLayout());
215 for (auto &F : M)
216 if (!F.isDeclaration()) {
217 std::string NewName =
218 (F.getName() + ".__def__." + Twine(Version)).str();
219 RenamedMap[Mangle(F.getName())] = Mangle(NewName);
220 F.setName(NewName);
221 }
222 return Error::success();
223 }));
224
225 auto RT = JD.createResourceTracker();
226 if (auto Err =
227 JD.define(MU: std::make_unique<BasicIRLayerMaterializationUnit>(
228 args&: BaseLayer, args: *getManglingOptions(), args: std::move(TSM)),
229 RT))
230 return Err;
231 MUState.setResourceTracker(RT);
232
233 SymbolLookupSet LookupSymbols;
234 for (auto [K, V] : RenamedMap)
235 LookupSymbols.add(Name: V);
236
237 auto ImplSymbols =
238 ES.lookup(SearchOrder: {{&JD, JITDylibLookupFlags::MatchAllSymbols}}, Symbols: LookupSymbols,
239 K: LookupKind::Static, RequiredState: SymbolState::Resolved);
240 if (auto Err = ImplSymbols.takeError())
241 return Err;
242
243 SymbolMap Result;
244 for (auto [K, V] : RenamedMap)
245 Result[K] = (*ImplSymbols)[V];
246
247 return Result;
248}
249
250void ReOptimizeLayer::rt_reoptimize(SendErrorFn SendResult,
251 ReOptMaterializationUnitID MUID,
252 uint32_t CurVersion) {
253 auto &MUState = getMaterializationUnitState(MUID);
254 if (CurVersion < MUState.getCurVersion() || !MUState.tryStartReoptimize()) {
255 SendResult(Error::success());
256 return;
257 }
258
259 ThreadSafeModule TSM = cloneToNewContext(TSMW: MUState.getThreadSafeModule());
260 auto OldRT = MUState.getResourceTracker();
261 auto &JD = OldRT->getJITDylib();
262
263 if (auto Err = ReOptFunc(*this, MUID, CurVersion + 1, OldRT, TSM)) {
264 ES.reportError(Err: std::move(Err));
265 MUState.reoptimizeFailed();
266 SendResult(Error::success());
267 return;
268 }
269
270 auto SymbolDests =
271 emitMUImplSymbols(MUState, Version: CurVersion + 1, JD, TSM: std::move(TSM));
272 if (!SymbolDests) {
273 ES.reportError(Err: SymbolDests.takeError());
274 MUState.reoptimizeFailed();
275 SendResult(Error::success());
276 return;
277 }
278
279 if (auto Err = RSManager.redirect(JD, NewDests: std::move(*SymbolDests))) {
280 ES.reportError(Err: std::move(Err));
281 MUState.reoptimizeFailed();
282 SendResult(Error::success());
283 return;
284 }
285
286 MUState.reoptimizeSucceeded();
287 SendResult(Error::success());
288}
289
290void ReOptimizeLayer::createReoptimizeCall(Module &M, Instruction &IP,
291 ReOptMaterializationUnitID MUID,
292 uint32_t CurVersion) {
293 Type *MUIDTy = IntegerType::get(C&: M.getContext(), NumBits: 64);
294 Type *VersionTy = IntegerType::get(C&: M.getContext(), NumBits: 32);
295 Function *ReoptimizeFunc = M.getFunction(Name: "__orc_rt_reoptimize");
296 if (!ReoptimizeFunc) {
297 std::vector<Type *> ArgTys = {MUIDTy, VersionTy};
298 FunctionType *FuncTy =
299 FunctionType::get(Result: Type::getVoidTy(C&: M.getContext()), Params: ArgTys, isVarArg: false);
300 ReoptimizeFunc = Function::Create(Ty: FuncTy, Linkage: GlobalValue::ExternalLinkage,
301 N: "__orc_rt_reoptimize", M: &M);
302 }
303 Constant *MUIDArg = ConstantInt::get(Ty: MUIDTy, V: MUID, IsSigned: false);
304 Constant *CurVersionArg = ConstantInt::get(Ty: VersionTy, V: CurVersion, IsSigned: false);
305 IRBuilder<> IRB(&IP);
306 (void)IRB.CreateCall(Callee: ReoptimizeFunc, Args: {MUIDArg, CurVersionArg});
307}
308
309ReOptimizeLayer::ReOptMaterializationUnitState &
310ReOptimizeLayer::createMaterializationUnitState(const ThreadSafeModule &TSM) {
311 std::unique_lock<std::mutex> Lock(Mutex);
312 ReOptMaterializationUnitID MUID = NextID;
313 MUStates.emplace(args&: MUID,
314 args: ReOptMaterializationUnitState(MUID, cloneToNewContext(TSMW: TSM)));
315 ++NextID;
316 return MUStates.at(k: MUID);
317}
318
319ReOptimizeLayer::ReOptMaterializationUnitState &
320ReOptimizeLayer::getMaterializationUnitState(ReOptMaterializationUnitID MUID) {
321 std::unique_lock<std::mutex> Lock(Mutex);
322 return MUStates.at(k: MUID);
323}
324
325void ReOptimizeLayer::registerMaterializationUnitResource(
326 ResourceKey Key, ReOptMaterializationUnitState &State) {
327 std::unique_lock<std::mutex> Lock(Mutex);
328 MUResources[Key].insert(V: State.getID());
329}
330
331Error ReOptimizeLayer::handleRemoveResources(JITDylib &JD, ResourceKey K) {
332 std::unique_lock<std::mutex> Lock(Mutex);
333 for (auto MUID : MUResources[K])
334 MUStates.erase(x: MUID);
335
336 MUResources.erase(Val: K);
337 return Error::success();
338}
339
340void ReOptimizeLayer::handleTransferResources(JITDylib &JD, ResourceKey DstK,
341 ResourceKey SrcK) {
342 std::unique_lock<std::mutex> Lock(Mutex);
343 MUResources[DstK].insert_range(R&: MUResources[SrcK]);
344 MUResources.erase(Val: SrcK);
345}
346