//===-- wrapper_function_utils.h - Utilities for wrapper funcs --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file is a part of the ORC runtime support library.
//
//===----------------------------------------------------------------------===//
12 | |
13 | #ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H |
14 | #define ORC_RT_WRAPPER_FUNCTION_UTILS_H |
15 | |
#include "error.h"
#include "executor_address.h"
#include "orc_rt/c_api.h"
#include "simple_packed_serialization.h"
#include <cstddef>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
21 | |
22 | namespace orc_rt { |
23 | |
24 | /// C++ wrapper function result: Same as orc_rt_WrapperFunctionResult but |
25 | /// auto-releases memory. |
26 | class WrapperFunctionResult { |
27 | public: |
28 | /// Create a default WrapperFunctionResult. |
29 | WrapperFunctionResult() { orc_rt_WrapperFunctionResultInit(R: &R); } |
30 | |
31 | /// Create a WrapperFunctionResult from a WrapperFunctionResult. This |
32 | /// instance takes ownership of the result object and will automatically |
33 | /// call dispose on the result upon destruction. |
34 | WrapperFunctionResult(orc_rt_WrapperFunctionResult R) : R(R) {} |
35 | |
36 | WrapperFunctionResult(const WrapperFunctionResult &) = delete; |
37 | WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete; |
38 | |
39 | WrapperFunctionResult(WrapperFunctionResult &&Other) { |
40 | orc_rt_WrapperFunctionResultInit(R: &R); |
41 | std::swap(a&: R, b&: Other.R); |
42 | } |
43 | |
44 | WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) { |
45 | orc_rt_WrapperFunctionResult Tmp; |
46 | orc_rt_WrapperFunctionResultInit(R: &Tmp); |
47 | std::swap(a&: Tmp, b&: Other.R); |
48 | std::swap(a&: R, b&: Tmp); |
49 | return *this; |
50 | } |
51 | |
52 | ~WrapperFunctionResult() { orc_rt_DisposeWrapperFunctionResult(R: &R); } |
53 | |
54 | /// Relinquish ownership of and return the |
55 | /// orc_rt_WrapperFunctionResult. |
56 | orc_rt_WrapperFunctionResult release() { |
57 | orc_rt_WrapperFunctionResult Tmp; |
58 | orc_rt_WrapperFunctionResultInit(R: &Tmp); |
59 | std::swap(a&: R, b&: Tmp); |
60 | return Tmp; |
61 | } |
62 | |
63 | /// Get a pointer to the data contained in this instance. |
64 | char *data() { return orc_rt_WrapperFunctionResultData(R: &R); } |
65 | |
66 | /// Returns the size of the data contained in this instance. |
67 | size_t size() const { return orc_rt_WrapperFunctionResultSize(R: &R); } |
68 | |
69 | /// Returns true if this value is equivalent to a default-constructed |
70 | /// WrapperFunctionResult. |
71 | bool empty() const { return orc_rt_WrapperFunctionResultEmpty(R: &R); } |
72 | |
73 | /// Create a WrapperFunctionResult with the given size and return a pointer |
74 | /// to the underlying memory. |
75 | static WrapperFunctionResult allocate(size_t Size) { |
76 | WrapperFunctionResult R; |
77 | R.R = orc_rt_WrapperFunctionResultAllocate(Size); |
78 | return R; |
79 | } |
80 | |
81 | /// Copy from the given char range. |
82 | static WrapperFunctionResult copyFrom(const char *Source, size_t Size) { |
83 | return orc_rt_CreateWrapperFunctionResultFromRange(Data: Source, Size); |
84 | } |
85 | |
86 | /// Copy from the given null-terminated string (includes the null-terminator). |
87 | static WrapperFunctionResult copyFrom(const char *Source) { |
88 | return orc_rt_CreateWrapperFunctionResultFromString(Source); |
89 | } |
90 | |
91 | /// Copy from the given std::string (includes the null terminator). |
92 | static WrapperFunctionResult copyFrom(const std::string &Source) { |
93 | return copyFrom(Source: Source.c_str()); |
94 | } |
95 | |
96 | /// Create an out-of-band error by copying the given string. |
97 | static WrapperFunctionResult createOutOfBandError(const char *Msg) { |
98 | return orc_rt_CreateWrapperFunctionResultFromOutOfBandError(ErrMsg: Msg); |
99 | } |
100 | |
101 | /// Create an out-of-band error by copying the given string. |
102 | static WrapperFunctionResult createOutOfBandError(const std::string &Msg) { |
103 | return createOutOfBandError(Msg: Msg.c_str()); |
104 | } |
105 | |
106 | template <typename SPSArgListT, typename... ArgTs> |
107 | static WrapperFunctionResult fromSPSArgs(const ArgTs &...Args) { |
108 | auto Result = allocate(Size: SPSArgListT::size(Args...)); |
109 | SPSOutputBuffer OB(Result.data(), Result.size()); |
110 | if (!SPSArgListT::serialize(OB, Args...)) |
111 | return createOutOfBandError( |
112 | Msg: "Error serializing arguments to blob in call" ); |
113 | return Result; |
114 | } |
115 | |
116 | /// If this value is an out-of-band error then this returns the error message, |
117 | /// otherwise returns nullptr. |
118 | const char *getOutOfBandError() const { |
119 | return orc_rt_WrapperFunctionResultGetOutOfBandError(R: &R); |
120 | } |
121 | |
122 | private: |
123 | orc_rt_WrapperFunctionResult R; |
124 | }; |
125 | |
126 | namespace detail { |
127 | |
128 | template <typename RetT> class WrapperFunctionHandlerCaller { |
129 | public: |
130 | template <typename HandlerT, typename ArgTupleT, std::size_t... I> |
131 | static decltype(auto) call(HandlerT &&H, ArgTupleT &Args, |
132 | std::index_sequence<I...>) { |
133 | return std::forward<HandlerT>(H)(std::get<I>(Args)...); |
134 | } |
135 | }; |
136 | |
137 | template <> class WrapperFunctionHandlerCaller<void> { |
138 | public: |
139 | template <typename HandlerT, typename ArgTupleT, std::size_t... I> |
140 | static SPSEmpty call(HandlerT &&H, ArgTupleT &Args, |
141 | std::index_sequence<I...>) { |
142 | std::forward<HandlerT>(H)(std::get<I>(Args)...); |
143 | return SPSEmpty(); |
144 | } |
145 | }; |
146 | |
147 | template <typename WrapperFunctionImplT, |
148 | template <typename> class ResultSerializer, typename... SPSTagTs> |
149 | class WrapperFunctionHandlerHelper |
150 | : public WrapperFunctionHandlerHelper< |
151 | decltype(&std::remove_reference_t<WrapperFunctionImplT>::operator()), |
152 | ResultSerializer, SPSTagTs...> {}; |
153 | |
154 | template <typename RetT, typename... ArgTs, |
155 | template <typename> class ResultSerializer, typename... SPSTagTs> |
156 | class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, |
157 | SPSTagTs...> { |
158 | public: |
159 | using ArgTuple = std::tuple<std::decay_t<ArgTs>...>; |
160 | using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>; |
161 | |
162 | template <typename HandlerT> |
163 | static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData, |
164 | size_t ArgSize) { |
165 | ArgTuple Args; |
166 | if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) |
167 | return WrapperFunctionResult::createOutOfBandError( |
168 | Msg: "Could not deserialize arguments for wrapper function call" ); |
169 | |
170 | auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call( |
171 | std::forward<HandlerT>(H), Args, ArgIndices{}); |
172 | |
173 | return ResultSerializer<decltype(HandlerResult)>::serialize( |
174 | std::move(HandlerResult)); |
175 | } |
176 | |
177 | private: |
178 | template <std::size_t... I> |
179 | static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args, |
180 | std::index_sequence<I...>) { |
181 | SPSInputBuffer IB(ArgData, ArgSize); |
182 | return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...); |
183 | } |
184 | }; |
185 | |
186 | // Map function pointers to function types. |
187 | template <typename RetT, typename... ArgTs, |
188 | template <typename> class ResultSerializer, typename... SPSTagTs> |
189 | class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer, |
190 | SPSTagTs...> |
191 | : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, |
192 | SPSTagTs...> {}; |
193 | |
194 | // Map non-const member function types to function types. |
195 | template <typename ClassT, typename RetT, typename... ArgTs, |
196 | template <typename> class ResultSerializer, typename... SPSTagTs> |
197 | class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer, |
198 | SPSTagTs...> |
199 | : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, |
200 | SPSTagTs...> {}; |
201 | |
202 | // Map const member function types to function types. |
203 | template <typename ClassT, typename RetT, typename... ArgTs, |
204 | template <typename> class ResultSerializer, typename... SPSTagTs> |
205 | class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const, |
206 | ResultSerializer, SPSTagTs...> |
207 | : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, |
208 | SPSTagTs...> {}; |
209 | |
210 | template <typename SPSRetTagT, typename RetT> class ResultSerializer { |
211 | public: |
212 | static WrapperFunctionResult serialize(RetT Result) { |
213 | return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(Result); |
214 | } |
215 | }; |
216 | |
217 | template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> { |
218 | public: |
219 | static WrapperFunctionResult serialize(Error Err) { |
220 | return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>( |
221 | toSPSSerializable(Err: std::move(t&: Err))); |
222 | } |
223 | }; |
224 | |
225 | template <typename SPSRetTagT, typename T> |
226 | class ResultSerializer<SPSRetTagT, Expected<T>> { |
227 | public: |
228 | static WrapperFunctionResult serialize(Expected<T> E) { |
229 | return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>( |
230 | toSPSSerializable(std::move(E))); |
231 | } |
232 | }; |
233 | |
234 | template <typename SPSRetTagT, typename RetT> class ResultDeserializer { |
235 | public: |
236 | static void makeSafe(RetT &Result) {} |
237 | |
238 | static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) { |
239 | SPSInputBuffer IB(ArgData, ArgSize); |
240 | if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result)) |
241 | return make_error<StringError>( |
242 | Args: "Error deserializing return value from blob in call" ); |
243 | return Error::success(); |
244 | } |
245 | }; |
246 | |
247 | template <> class ResultDeserializer<SPSError, Error> { |
248 | public: |
249 | static void makeSafe(Error &Err) { cantFail(Err: std::move(t&: Err)); } |
250 | |
251 | static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) { |
252 | SPSInputBuffer IB(ArgData, ArgSize); |
253 | SPSSerializableError BSE; |
254 | if (!SPSArgList<SPSError>::deserialize(IB, Arg&: BSE)) |
255 | return make_error<StringError>( |
256 | Args: "Error deserializing return value from blob in call" ); |
257 | Err = fromSPSSerializable(BSE: std::move(t&: BSE)); |
258 | return Error::success(); |
259 | } |
260 | }; |
261 | |
262 | template <typename SPSTagT, typename T> |
263 | class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> { |
264 | public: |
265 | static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); } |
266 | |
267 | static Error deserialize(Expected<T> &E, const char *ArgData, |
268 | size_t ArgSize) { |
269 | SPSInputBuffer IB(ArgData, ArgSize); |
270 | SPSSerializableExpected<T> BSE; |
271 | if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE)) |
272 | return make_error<StringError>( |
273 | Args: "Error deserializing return value from blob in call" ); |
274 | E = fromSPSSerializable(std::move(BSE)); |
275 | return Error::success(); |
276 | } |
277 | }; |
278 | |
279 | } // end namespace detail |
280 | |
281 | template <typename SPSSignature> class WrapperFunction; |
282 | |
283 | template <typename SPSRetTagT, typename... SPSTagTs> |
284 | class WrapperFunction<SPSRetTagT(SPSTagTs...)> { |
285 | private: |
286 | template <typename RetT> |
287 | using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>; |
288 | |
289 | public: |
290 | template <typename DispatchFn, typename RetT, typename... ArgTs> |
291 | static Error call(DispatchFn &&Dispatch, RetT &Result, const ArgTs &...Args) { |
292 | |
293 | // RetT might be an Error or Expected value. Set the checked flag now: |
294 | // we don't want the user to have to check the unused result if this |
295 | // operation fails. |
296 | detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result); |
297 | |
298 | auto ArgBuffer = |
299 | WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSTagTs...>>(Args...); |
300 | if (const char *ErrMsg = ArgBuffer.getOutOfBandError()) |
301 | return make_error<StringError>(Args&: ErrMsg); |
302 | |
303 | WrapperFunctionResult ResultBuffer = |
304 | Dispatch(ArgBuffer.data(), ArgBuffer.size()); |
305 | |
306 | if (auto ErrMsg = ResultBuffer.getOutOfBandError()) |
307 | return make_error<StringError>(Args&: ErrMsg); |
308 | |
309 | return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize( |
310 | Result, ResultBuffer.data(), ResultBuffer.size()); |
311 | } |
312 | |
313 | template <typename HandlerT> |
314 | static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize, |
315 | HandlerT &&Handler) { |
316 | using WFHH = |
317 | detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>, |
318 | ResultSerializer, SPSTagTs...>; |
319 | return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize); |
320 | } |
321 | |
322 | private: |
323 | template <typename T> static const T &makeSerializable(const T &Value) { |
324 | return Value; |
325 | } |
326 | |
327 | static detail::SPSSerializableError makeSerializable(Error Err) { |
328 | return detail::toSPSSerializable(Err: std::move(t&: Err)); |
329 | } |
330 | |
331 | template <typename T> |
332 | static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) { |
333 | return detail::toSPSSerializable(std::move(E)); |
334 | } |
335 | }; |
336 | |
337 | template <typename... SPSTagTs> |
338 | class WrapperFunction<void(SPSTagTs...)> |
339 | : private WrapperFunction<SPSEmpty(SPSTagTs...)> { |
340 | public: |
341 | template <typename DispatchFn, typename... ArgTs> |
342 | static Error call(DispatchFn &&Dispatch, const ArgTs &...Args) { |
343 | SPSEmpty BE; |
344 | return WrapperFunction<SPSEmpty(SPSTagTs...)>::call( |
345 | std::forward<DispatchFn>(Dispatch), BE, Args...); |
346 | } |
347 | |
348 | using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle; |
349 | }; |
350 | |
351 | /// A function object that takes an ExecutorAddr as its first argument, |
352 | /// casts that address to a ClassT*, then calls the given method on that |
353 | /// pointer passing in the remaining function arguments. This utility |
354 | /// removes some of the boilerplate from writing wrappers for method calls. |
355 | /// |
356 | /// @code{.cpp} |
357 | /// class MyClass { |
358 | /// public: |
359 | /// void myMethod(uint32_t, bool) { ... } |
360 | /// }; |
361 | /// |
362 | /// // SPS Method signature -- note MyClass object address as first argument. |
363 | /// using SPSMyMethodWrapperSignature = |
364 | /// SPSTuple<SPSExecutorAddr, uint32_t, bool>; |
365 | /// |
366 | /// WrapperFunctionResult |
367 | /// myMethodCallWrapper(const char *ArgData, size_t ArgSize) { |
368 | /// return WrapperFunction<SPSMyMethodWrapperSignature>::handle( |
369 | /// ArgData, ArgSize, makeMethodWrapperHandler(&MyClass::myMethod)); |
370 | /// } |
371 | /// @endcode |
372 | /// |
373 | template <typename RetT, typename ClassT, typename... ArgTs> |
374 | class MethodWrapperHandler { |
375 | public: |
376 | using MethodT = RetT (ClassT::*)(ArgTs...); |
377 | MethodWrapperHandler(MethodT M) : M(M) {} |
378 | RetT operator()(ExecutorAddr ObjAddr, ArgTs &...Args) { |
379 | return (ObjAddr.toPtr<ClassT *>()->*M)(std::forward<ArgTs>(Args)...); |
380 | } |
381 | |
382 | private: |
383 | MethodT M; |
384 | }; |
385 | |
386 | /// Create a MethodWrapperHandler object from the given method pointer. |
387 | template <typename RetT, typename ClassT, typename... ArgTs> |
388 | MethodWrapperHandler<RetT, ClassT, ArgTs...> |
389 | makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) { |
390 | return MethodWrapperHandler<RetT, ClassT, ArgTs...>(Method); |
391 | } |
392 | |
393 | /// Represents a call to a wrapper function. |
394 | class WrapperFunctionCall { |
395 | public: |
396 | // FIXME: Switch to a SmallVector<char, 24> once ORC runtime has a |
397 | // smallvector. |
398 | using ArgDataBufferType = std::vector<char>; |
399 | |
400 | /// Create a WrapperFunctionCall using the given SPS serializer to serialize |
401 | /// the arguments. |
402 | template <typename SPSSerializer, typename... ArgTs> |
403 | static Expected<WrapperFunctionCall> Create(ExecutorAddr FnAddr, |
404 | const ArgTs &...Args) { |
405 | ArgDataBufferType ArgData; |
406 | ArgData.resize(SPSSerializer::size(Args...)); |
407 | SPSOutputBuffer OB(ArgData.empty() ? nullptr : ArgData.data(), |
408 | ArgData.size()); |
409 | if (SPSSerializer::serialize(OB, Args...)) |
410 | return WrapperFunctionCall(FnAddr, std::move(t&: ArgData)); |
411 | return make_error<StringError>(Args: "Cannot serialize arguments for " |
412 | "AllocActionCall" ); |
413 | } |
414 | |
415 | WrapperFunctionCall() = default; |
416 | |
417 | /// Create a WrapperFunctionCall from a target function and arg buffer. |
418 | WrapperFunctionCall(ExecutorAddr FnAddr, ArgDataBufferType ArgData) |
419 | : FnAddr(FnAddr), ArgData(std::move(t&: ArgData)) {} |
420 | |
421 | /// Returns the address to be called. |
422 | const ExecutorAddr &getCallee() const { return FnAddr; } |
423 | |
424 | /// Returns the argument data. |
425 | const ArgDataBufferType &getArgData() const { return ArgData; } |
426 | |
427 | /// WrapperFunctionCalls convert to true if the callee is non-null. |
428 | explicit operator bool() const { return !!FnAddr; } |
429 | |
430 | /// Run call returning raw WrapperFunctionResult. |
431 | WrapperFunctionResult run() const { |
432 | using FnTy = |
433 | orc_rt_WrapperFunctionResult(const char *ArgData, size_t ArgSize); |
434 | return WrapperFunctionResult( |
435 | FnAddr.toPtr<FnTy *>()(ArgData.data(), ArgData.size())); |
436 | } |
437 | |
438 | /// Run call and deserialize result using SPS. |
439 | template <typename SPSRetT, typename RetT> |
440 | std::enable_if_t<!std::is_same<SPSRetT, void>::value, Error> |
441 | runWithSPSRet(RetT &RetVal) const { |
442 | auto WFR = run(); |
443 | if (const char *ErrMsg = WFR.getOutOfBandError()) |
444 | return make_error<StringError>(Args&: ErrMsg); |
445 | SPSInputBuffer IB(WFR.data(), WFR.size()); |
446 | if (!SPSSerializationTraits<SPSRetT, RetT>::deserialize(IB, RetVal)) |
447 | return make_error<StringError>(Args: "Could not deserialize result from " |
448 | "serialized wrapper function call" ); |
449 | return Error::success(); |
450 | } |
451 | |
452 | /// Overload for SPS functions returning void. |
453 | template <typename SPSRetT> |
454 | std::enable_if_t<std::is_same<SPSRetT, void>::value, Error> |
455 | runWithSPSRet() const { |
456 | SPSEmpty E; |
457 | return runWithSPSRet<SPSEmpty>(RetVal&: E); |
458 | } |
459 | |
460 | /// Run call and deserialize an SPSError result. SPSError returns and |
461 | /// deserialization failures are merged into the returned error. |
462 | Error runWithSPSRetErrorMerged() const { |
463 | detail::SPSSerializableError RetErr; |
464 | if (auto Err = runWithSPSRet<SPSError>(RetVal&: RetErr)) |
465 | return Err; |
466 | return detail::fromSPSSerializable(BSE: std::move(t&: RetErr)); |
467 | } |
468 | |
469 | private: |
470 | ExecutorAddr FnAddr; |
471 | std::vector<char> ArgData; |
472 | }; |
473 | |
474 | using SPSWrapperFunctionCall = SPSTuple<SPSExecutorAddr, SPSSequence<char>>; |
475 | |
476 | template <> |
477 | class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> { |
478 | public: |
479 | static size_t size(const WrapperFunctionCall &WFC) { |
480 | return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::size( |
481 | Arg: WFC.getCallee(), Args: WFC.getArgData()); |
482 | } |
483 | |
484 | static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) { |
485 | return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::serialize( |
486 | OB, Arg: WFC.getCallee(), Args: WFC.getArgData()); |
487 | } |
488 | |
489 | static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) { |
490 | ExecutorAddr FnAddr; |
491 | WrapperFunctionCall::ArgDataBufferType ArgData; |
492 | if (!SPSWrapperFunctionCall::AsArgList::deserialize(IB, Arg&: FnAddr, Args&: ArgData)) |
493 | return false; |
494 | WFC = WrapperFunctionCall(FnAddr, std::move(t&: ArgData)); |
495 | return true; |
496 | } |
497 | }; |
498 | |
499 | } // namespace orc_rt |
500 | |
501 | #endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H |
502 | |