| 1 | //===-- wrapper_function_utils.h - Utilities for wrapper funcs --*- C++ -*-===// | 
|---|
| 2 | // | 
|---|
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | 
|---|
| 4 | // See https://llvm.org/LICENSE.txt for license information. | 
|---|
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|---|
| 6 | // | 
|---|
| 7 | //===----------------------------------------------------------------------===// | 
|---|
| 8 | // | 
|---|
| 9 | // This file is a part of the ORC runtime support library. | 
|---|
| 10 | // | 
|---|
| 11 | //===----------------------------------------------------------------------===// | 
|---|
| 12 |  | 
|---|
| 13 | #ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H | 
|---|
| 14 | #define ORC_RT_WRAPPER_FUNCTION_UTILS_H | 
|---|
| 15 |  | 
|---|
| 16 | #include "error.h" | 
|---|
| 17 | #include "executor_address.h" | 
|---|
| 18 | #include "orc_rt/c_api.h" | 
|---|
| 19 | #include "simple_packed_serialization.h" | 
|---|
| 20 | #include <type_traits> | 
|---|
| 21 |  | 
|---|
| 22 | namespace orc_rt { | 
|---|
| 23 |  | 
|---|
| 24 | /// C++ wrapper function result: Same as orc_rt_WrapperFunctionResult but | 
|---|
| 25 | /// auto-releases memory. | 
|---|
| 26 | class WrapperFunctionResult { | 
|---|
| 27 | public: | 
|---|
| 28 | /// Create a default WrapperFunctionResult. | 
|---|
| 29 | WrapperFunctionResult() { orc_rt_WrapperFunctionResultInit(R: &R); } | 
|---|
| 30 |  | 
|---|
| 31 | /// Create a WrapperFunctionResult from a WrapperFunctionResult. This | 
|---|
| 32 | /// instance takes ownership of the result object and will automatically | 
|---|
| 33 | /// call dispose on the result upon destruction. | 
|---|
| 34 | WrapperFunctionResult(orc_rt_WrapperFunctionResult R) : R(R) {} | 
|---|
| 35 |  | 
|---|
| 36 | WrapperFunctionResult(const WrapperFunctionResult &) = delete; | 
|---|
| 37 | WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete; | 
|---|
| 38 |  | 
|---|
| 39 | WrapperFunctionResult(WrapperFunctionResult &&Other) { | 
|---|
| 40 | orc_rt_WrapperFunctionResultInit(R: &R); | 
|---|
| 41 | std::swap(a&: R, b&: Other.R); | 
|---|
| 42 | } | 
|---|
| 43 |  | 
|---|
| 44 | WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) { | 
|---|
| 45 | orc_rt_WrapperFunctionResult Tmp; | 
|---|
| 46 | orc_rt_WrapperFunctionResultInit(R: &Tmp); | 
|---|
| 47 | std::swap(a&: Tmp, b&: Other.R); | 
|---|
| 48 | std::swap(a&: R, b&: Tmp); | 
|---|
| 49 | return *this; | 
|---|
| 50 | } | 
|---|
| 51 |  | 
|---|
| 52 | ~WrapperFunctionResult() { orc_rt_DisposeWrapperFunctionResult(R: &R); } | 
|---|
| 53 |  | 
|---|
| 54 | /// Relinquish ownership of and return the | 
|---|
| 55 | /// orc_rt_WrapperFunctionResult. | 
|---|
| 56 | orc_rt_WrapperFunctionResult release() { | 
|---|
| 57 | orc_rt_WrapperFunctionResult Tmp; | 
|---|
| 58 | orc_rt_WrapperFunctionResultInit(R: &Tmp); | 
|---|
| 59 | std::swap(a&: R, b&: Tmp); | 
|---|
| 60 | return Tmp; | 
|---|
| 61 | } | 
|---|
| 62 |  | 
|---|
| 63 | /// Get a pointer to the data contained in this instance. | 
|---|
| 64 | char *data() { return orc_rt_WrapperFunctionResultData(R: &R); } | 
|---|
| 65 |  | 
|---|
| 66 | /// Returns the size of the data contained in this instance. | 
|---|
| 67 | size_t size() const { return orc_rt_WrapperFunctionResultSize(R: &R); } | 
|---|
| 68 |  | 
|---|
| 69 | /// Returns true if this value is equivalent to a default-constructed | 
|---|
| 70 | /// WrapperFunctionResult. | 
|---|
| 71 | bool empty() const { return orc_rt_WrapperFunctionResultEmpty(R: &R); } | 
|---|
| 72 |  | 
|---|
| 73 | /// Create a WrapperFunctionResult with the given size and return a pointer | 
|---|
| 74 | /// to the underlying memory. | 
|---|
| 75 | static WrapperFunctionResult allocate(size_t Size) { | 
|---|
| 76 | WrapperFunctionResult R; | 
|---|
| 77 | R.R = orc_rt_WrapperFunctionResultAllocate(Size); | 
|---|
| 78 | return R; | 
|---|
| 79 | } | 
|---|
| 80 |  | 
|---|
| 81 | /// Copy from the given char range. | 
|---|
| 82 | static WrapperFunctionResult copyFrom(const char *Source, size_t Size) { | 
|---|
| 83 | return orc_rt_CreateWrapperFunctionResultFromRange(Data: Source, Size); | 
|---|
| 84 | } | 
|---|
| 85 |  | 
|---|
| 86 | /// Copy from the given null-terminated string (includes the null-terminator). | 
|---|
| 87 | static WrapperFunctionResult copyFrom(const char *Source) { | 
|---|
| 88 | return orc_rt_CreateWrapperFunctionResultFromString(Source); | 
|---|
| 89 | } | 
|---|
| 90 |  | 
|---|
| 91 | /// Copy from the given std::string (includes the null terminator). | 
|---|
| 92 | static WrapperFunctionResult copyFrom(const std::string &Source) { | 
|---|
| 93 | return copyFrom(Source: Source.c_str()); | 
|---|
| 94 | } | 
|---|
| 95 |  | 
|---|
| 96 | /// Create an out-of-band error by copying the given string. | 
|---|
| 97 | static WrapperFunctionResult createOutOfBandError(const char *Msg) { | 
|---|
| 98 | return orc_rt_CreateWrapperFunctionResultFromOutOfBandError(ErrMsg: Msg); | 
|---|
| 99 | } | 
|---|
| 100 |  | 
|---|
| 101 | /// Create an out-of-band error by copying the given string. | 
|---|
| 102 | static WrapperFunctionResult createOutOfBandError(const std::string &Msg) { | 
|---|
| 103 | return createOutOfBandError(Msg: Msg.c_str()); | 
|---|
| 104 | } | 
|---|
| 105 |  | 
|---|
| 106 | template <typename SPSArgListT, typename... ArgTs> | 
|---|
| 107 | static WrapperFunctionResult fromSPSArgs(const ArgTs &...Args) { | 
|---|
| 108 | auto Result = allocate(Size: SPSArgListT::size(Args...)); | 
|---|
| 109 | SPSOutputBuffer OB(Result.data(), Result.size()); | 
|---|
| 110 | if (!SPSArgListT::serialize(OB, Args...)) | 
|---|
| 111 | return createOutOfBandError( | 
|---|
| 112 | Msg: "Error serializing arguments to blob in call"); | 
|---|
| 113 | return Result; | 
|---|
| 114 | } | 
|---|
| 115 |  | 
|---|
| 116 | /// If this value is an out-of-band error then this returns the error message, | 
|---|
| 117 | /// otherwise returns nullptr. | 
|---|
| 118 | const char *getOutOfBandError() const { | 
|---|
| 119 | return orc_rt_WrapperFunctionResultGetOutOfBandError(R: &R); | 
|---|
| 120 | } | 
|---|
| 121 |  | 
|---|
| 122 | private: | 
|---|
| 123 | orc_rt_WrapperFunctionResult R; | 
|---|
| 124 | }; | 
|---|
| 125 |  | 
|---|
| 126 | namespace detail { | 
|---|
| 127 |  | 
|---|
| 128 | template <typename RetT> class WrapperFunctionHandlerCaller { | 
|---|
| 129 | public: | 
|---|
| 130 | template <typename HandlerT, typename ArgTupleT, std::size_t... I> | 
|---|
| 131 | static decltype(auto) call(HandlerT &&H, ArgTupleT &Args, | 
|---|
| 132 | std::index_sequence<I...>) { | 
|---|
| 133 | return std::forward<HandlerT>(H)(std::get<I>(Args)...); | 
|---|
| 134 | } | 
|---|
| 135 | }; | 
|---|
| 136 |  | 
|---|
| 137 | template <> class WrapperFunctionHandlerCaller<void> { | 
|---|
| 138 | public: | 
|---|
| 139 | template <typename HandlerT, typename ArgTupleT, std::size_t... I> | 
|---|
| 140 | static SPSEmpty call(HandlerT &&H, ArgTupleT &Args, | 
|---|
| 141 | std::index_sequence<I...>) { | 
|---|
| 142 | std::forward<HandlerT>(H)(std::get<I>(Args)...); | 
|---|
| 143 | return SPSEmpty(); | 
|---|
| 144 | } | 
|---|
| 145 | }; | 
|---|
| 146 |  | 
|---|
| 147 | template <typename WrapperFunctionImplT, | 
|---|
| 148 | template <typename> class ResultSerializer, typename... SPSTagTs> | 
|---|
| 149 | class WrapperFunctionHandlerHelper | 
|---|
| 150 | : public WrapperFunctionHandlerHelper< | 
|---|
| 151 | decltype(&std::remove_reference_t<WrapperFunctionImplT>::operator()), | 
|---|
| 152 | ResultSerializer, SPSTagTs...> {}; | 
|---|
| 153 |  | 
|---|
| 154 | template <typename RetT, typename... ArgTs, | 
|---|
| 155 | template <typename> class ResultSerializer, typename... SPSTagTs> | 
|---|
| 156 | class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, | 
|---|
| 157 | SPSTagTs...> { | 
|---|
| 158 | public: | 
|---|
| 159 | using ArgTuple = std::tuple<std::decay_t<ArgTs>...>; | 
|---|
| 160 | using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>; | 
|---|
| 161 |  | 
|---|
| 162 | template <typename HandlerT> | 
|---|
| 163 | static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData, | 
|---|
| 164 | size_t ArgSize) { | 
|---|
| 165 | ArgTuple Args; | 
|---|
| 166 | if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) | 
|---|
| 167 | return WrapperFunctionResult::createOutOfBandError( | 
|---|
| 168 | Msg: "Could not deserialize arguments for wrapper function call"); | 
|---|
| 169 |  | 
|---|
| 170 | auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call( | 
|---|
| 171 | std::forward<HandlerT>(H), Args, ArgIndices{}); | 
|---|
| 172 |  | 
|---|
| 173 | return ResultSerializer<decltype(HandlerResult)>::serialize( | 
|---|
| 174 | std::move(HandlerResult)); | 
|---|
| 175 | } | 
|---|
| 176 |  | 
|---|
| 177 | private: | 
|---|
| 178 | template <std::size_t... I> | 
|---|
| 179 | static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args, | 
|---|
| 180 | std::index_sequence<I...>) { | 
|---|
| 181 | SPSInputBuffer IB(ArgData, ArgSize); | 
|---|
| 182 | return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...); | 
|---|
| 183 | } | 
|---|
| 184 | }; | 
|---|
| 185 |  | 
|---|
| 186 | // Map function pointers to function types. | 
|---|
| 187 | template <typename RetT, typename... ArgTs, | 
|---|
| 188 | template <typename> class ResultSerializer, typename... SPSTagTs> | 
|---|
| 189 | class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer, | 
|---|
| 190 | SPSTagTs...> | 
|---|
| 191 | : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, | 
|---|
| 192 | SPSTagTs...> {}; | 
|---|
| 193 |  | 
|---|
| 194 | // Map non-const member function types to function types. | 
|---|
| 195 | template <typename ClassT, typename RetT, typename... ArgTs, | 
|---|
| 196 | template <typename> class ResultSerializer, typename... SPSTagTs> | 
|---|
| 197 | class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer, | 
|---|
| 198 | SPSTagTs...> | 
|---|
| 199 | : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, | 
|---|
| 200 | SPSTagTs...> {}; | 
|---|
| 201 |  | 
|---|
| 202 | // Map const member function types to function types. | 
|---|
| 203 | template <typename ClassT, typename RetT, typename... ArgTs, | 
|---|
| 204 | template <typename> class ResultSerializer, typename... SPSTagTs> | 
|---|
| 205 | class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const, | 
|---|
| 206 | ResultSerializer, SPSTagTs...> | 
|---|
| 207 | : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, | 
|---|
| 208 | SPSTagTs...> {}; | 
|---|
| 209 |  | 
|---|
| 210 | template <typename SPSRetTagT, typename RetT> class ResultSerializer { | 
|---|
| 211 | public: | 
|---|
| 212 | static WrapperFunctionResult serialize(RetT Result) { | 
|---|
| 213 | return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(Result); | 
|---|
| 214 | } | 
|---|
| 215 | }; | 
|---|
| 216 |  | 
|---|
| 217 | template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> { | 
|---|
| 218 | public: | 
|---|
| 219 | static WrapperFunctionResult serialize(Error Err) { | 
|---|
| 220 | return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>( | 
|---|
| 221 | toSPSSerializable(Err: std::move(t&: Err))); | 
|---|
| 222 | } | 
|---|
| 223 | }; | 
|---|
| 224 |  | 
|---|
| 225 | template <typename SPSRetTagT, typename T> | 
|---|
| 226 | class ResultSerializer<SPSRetTagT, Expected<T>> { | 
|---|
| 227 | public: | 
|---|
| 228 | static WrapperFunctionResult serialize(Expected<T> E) { | 
|---|
| 229 | return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>( | 
|---|
| 230 | toSPSSerializable(std::move(E))); | 
|---|
| 231 | } | 
|---|
| 232 | }; | 
|---|
| 233 |  | 
|---|
| 234 | template <typename SPSRetTagT, typename RetT> class ResultDeserializer { | 
|---|
| 235 | public: | 
|---|
| 236 | static void makeSafe(RetT &Result) {} | 
|---|
| 237 |  | 
|---|
| 238 | static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) { | 
|---|
| 239 | SPSInputBuffer IB(ArgData, ArgSize); | 
|---|
| 240 | if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result)) | 
|---|
| 241 | return make_error<StringError>( | 
|---|
| 242 | Args: "Error deserializing return value from blob in call"); | 
|---|
| 243 | return Error::success(); | 
|---|
| 244 | } | 
|---|
| 245 | }; | 
|---|
| 246 |  | 
|---|
| 247 | template <> class ResultDeserializer<SPSError, Error> { | 
|---|
| 248 | public: | 
|---|
| 249 | static void makeSafe(Error &Err) { cantFail(Err: std::move(t&: Err)); } | 
|---|
| 250 |  | 
|---|
| 251 | static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) { | 
|---|
| 252 | SPSInputBuffer IB(ArgData, ArgSize); | 
|---|
| 253 | SPSSerializableError BSE; | 
|---|
| 254 | if (!SPSArgList<SPSError>::deserialize(IB, Arg&: BSE)) | 
|---|
| 255 | return make_error<StringError>( | 
|---|
| 256 | Args: "Error deserializing return value from blob in call"); | 
|---|
| 257 | Err = fromSPSSerializable(BSE: std::move(t&: BSE)); | 
|---|
| 258 | return Error::success(); | 
|---|
| 259 | } | 
|---|
| 260 | }; | 
|---|
| 261 |  | 
|---|
| 262 | template <typename SPSTagT, typename T> | 
|---|
| 263 | class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> { | 
|---|
| 264 | public: | 
|---|
| 265 | static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); } | 
|---|
| 266 |  | 
|---|
| 267 | static Error deserialize(Expected<T> &E, const char *ArgData, | 
|---|
| 268 | size_t ArgSize) { | 
|---|
| 269 | SPSInputBuffer IB(ArgData, ArgSize); | 
|---|
| 270 | SPSSerializableExpected<T> BSE; | 
|---|
| 271 | if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE)) | 
|---|
| 272 | return make_error<StringError>( | 
|---|
| 273 | Args: "Error deserializing return value from blob in call"); | 
|---|
| 274 | E = fromSPSSerializable(std::move(BSE)); | 
|---|
| 275 | return Error::success(); | 
|---|
| 276 | } | 
|---|
| 277 | }; | 
|---|
| 278 |  | 
|---|
| 279 | } // end namespace detail | 
|---|
| 280 |  | 
|---|
| 281 | template <typename SPSSignature> class WrapperFunction; | 
|---|
| 282 |  | 
|---|
| 283 | template <typename SPSRetTagT, typename... SPSTagTs> | 
|---|
| 284 | class WrapperFunction<SPSRetTagT(SPSTagTs...)> { | 
|---|
| 285 | private: | 
|---|
| 286 | template <typename RetT> | 
|---|
| 287 | using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>; | 
|---|
| 288 |  | 
|---|
| 289 | public: | 
|---|
| 290 | template <typename DispatchFn, typename RetT, typename... ArgTs> | 
|---|
| 291 | static Error call(DispatchFn &&Dispatch, RetT &Result, const ArgTs &...Args) { | 
|---|
| 292 |  | 
|---|
| 293 | // RetT might be an Error or Expected value. Set the checked flag now: | 
|---|
| 294 | // we don't want the user to have to check the unused result if this | 
|---|
| 295 | // operation fails. | 
|---|
| 296 | detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result); | 
|---|
| 297 |  | 
|---|
| 298 | auto ArgBuffer = | 
|---|
| 299 | WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSTagTs...>>(Args...); | 
|---|
| 300 | if (const char *ErrMsg = ArgBuffer.getOutOfBandError()) | 
|---|
| 301 | return make_error<StringError>(Args&: ErrMsg); | 
|---|
| 302 |  | 
|---|
| 303 | WrapperFunctionResult ResultBuffer = | 
|---|
| 304 | Dispatch(ArgBuffer.data(), ArgBuffer.size()); | 
|---|
| 305 |  | 
|---|
| 306 | if (auto ErrMsg = ResultBuffer.getOutOfBandError()) | 
|---|
| 307 | return make_error<StringError>(Args&: ErrMsg); | 
|---|
| 308 |  | 
|---|
| 309 | return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize( | 
|---|
| 310 | Result, ResultBuffer.data(), ResultBuffer.size()); | 
|---|
| 311 | } | 
|---|
| 312 |  | 
|---|
| 313 | template <typename HandlerT> | 
|---|
| 314 | static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize, | 
|---|
| 315 | HandlerT &&Handler) { | 
|---|
| 316 | using WFHH = | 
|---|
| 317 | detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>, | 
|---|
| 318 | ResultSerializer, SPSTagTs...>; | 
|---|
| 319 | return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize); | 
|---|
| 320 | } | 
|---|
| 321 |  | 
|---|
| 322 | private: | 
|---|
| 323 | template <typename T> static const T &makeSerializable(const T &Value) { | 
|---|
| 324 | return Value; | 
|---|
| 325 | } | 
|---|
| 326 |  | 
|---|
| 327 | static detail::SPSSerializableError makeSerializable(Error Err) { | 
|---|
| 328 | return detail::toSPSSerializable(Err: std::move(t&: Err)); | 
|---|
| 329 | } | 
|---|
| 330 |  | 
|---|
| 331 | template <typename T> | 
|---|
| 332 | static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) { | 
|---|
| 333 | return detail::toSPSSerializable(std::move(E)); | 
|---|
| 334 | } | 
|---|
| 335 | }; | 
|---|
| 336 |  | 
|---|
| 337 | template <typename... SPSTagTs> | 
|---|
| 338 | class WrapperFunction<void(SPSTagTs...)> | 
|---|
| 339 | : private WrapperFunction<SPSEmpty(SPSTagTs...)> { | 
|---|
| 340 | public: | 
|---|
| 341 | template <typename DispatchFn, typename... ArgTs> | 
|---|
| 342 | static Error call(DispatchFn &&Dispatch, const ArgTs &...Args) { | 
|---|
| 343 | SPSEmpty BE; | 
|---|
| 344 | return WrapperFunction<SPSEmpty(SPSTagTs...)>::call( | 
|---|
| 345 | std::forward<DispatchFn>(Dispatch), BE, Args...); | 
|---|
| 346 | } | 
|---|
| 347 |  | 
|---|
| 348 | using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle; | 
|---|
| 349 | }; | 
|---|
| 350 |  | 
|---|
| 351 | /// A function object that takes an ExecutorAddr as its first argument, | 
|---|
| 352 | /// casts that address to a ClassT*, then calls the given method on that | 
|---|
| 353 | /// pointer passing in the remaining function arguments. This utility | 
|---|
| 354 | /// removes some of the boilerplate from writing wrappers for method calls. | 
|---|
| 355 | /// | 
|---|
| 356 | ///   @code{.cpp} | 
|---|
| 357 | ///   class MyClass { | 
|---|
| 358 | ///   public: | 
|---|
| 359 | ///     void myMethod(uint32_t, bool) { ... } | 
|---|
| 360 | ///   }; | 
|---|
| 361 | /// | 
|---|
| 362 | ///   // SPS Method signature -- note MyClass object address as first argument. | 
|---|
| 363 | ///   using SPSMyMethodWrapperSignature = | 
|---|
| 364 | ///     SPSTuple<SPSExecutorAddr, uint32_t, bool>; | 
|---|
| 365 | /// | 
|---|
| 366 | ///   WrapperFunctionResult | 
|---|
| 367 | ///   myMethodCallWrapper(const char *ArgData, size_t ArgSize) { | 
|---|
| 368 | ///     return WrapperFunction<SPSMyMethodWrapperSignature>::handle( | 
|---|
| 369 | ///        ArgData, ArgSize, makeMethodWrapperHandler(&MyClass::myMethod)); | 
|---|
| 370 | ///   } | 
|---|
| 371 | ///   @endcode | 
|---|
| 372 | /// | 
|---|
| 373 | template <typename RetT, typename ClassT, typename... ArgTs> | 
|---|
| 374 | class MethodWrapperHandler { | 
|---|
| 375 | public: | 
|---|
| 376 | using MethodT = RetT (ClassT::*)(ArgTs...); | 
|---|
| 377 | MethodWrapperHandler(MethodT M) : M(M) {} | 
|---|
| 378 | RetT operator()(ExecutorAddr ObjAddr, ArgTs &...Args) { | 
|---|
| 379 | return (ObjAddr.toPtr<ClassT *>()->*M)(std::forward<ArgTs>(Args)...); | 
|---|
| 380 | } | 
|---|
| 381 |  | 
|---|
| 382 | private: | 
|---|
| 383 | MethodT M; | 
|---|
| 384 | }; | 
|---|
| 385 |  | 
|---|
| 386 | /// Create a MethodWrapperHandler object from the given method pointer. | 
|---|
| 387 | template <typename RetT, typename ClassT, typename... ArgTs> | 
|---|
| 388 | MethodWrapperHandler<RetT, ClassT, ArgTs...> | 
|---|
| 389 | makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) { | 
|---|
| 390 | return MethodWrapperHandler<RetT, ClassT, ArgTs...>(Method); | 
|---|
| 391 | } | 
|---|
| 392 |  | 
|---|
| 393 | /// Represents a call to a wrapper function. | 
|---|
| 394 | class WrapperFunctionCall { | 
|---|
| 395 | public: | 
|---|
| 396 | // FIXME: Switch to a SmallVector<char, 24> once ORC runtime has a | 
|---|
| 397 | // smallvector. | 
|---|
| 398 | using ArgDataBufferType = std::vector<char>; | 
|---|
| 399 |  | 
|---|
| 400 | /// Create a WrapperFunctionCall using the given SPS serializer to serialize | 
|---|
| 401 | /// the arguments. | 
|---|
| 402 | template <typename SPSSerializer, typename... ArgTs> | 
|---|
| 403 | static Expected<WrapperFunctionCall> Create(ExecutorAddr FnAddr, | 
|---|
| 404 | const ArgTs &...Args) { | 
|---|
| 405 | ArgDataBufferType ArgData; | 
|---|
| 406 | ArgData.resize(SPSSerializer::size(Args...)); | 
|---|
| 407 | SPSOutputBuffer OB(ArgData.empty() ? nullptr : ArgData.data(), | 
|---|
| 408 | ArgData.size()); | 
|---|
| 409 | if (SPSSerializer::serialize(OB, Args...)) | 
|---|
| 410 | return WrapperFunctionCall(FnAddr, std::move(t&: ArgData)); | 
|---|
| 411 | return make_error<StringError>(Args: "Cannot serialize arguments for " | 
|---|
| 412 | "AllocActionCall"); | 
|---|
| 413 | } | 
|---|
| 414 |  | 
|---|
| 415 | WrapperFunctionCall() = default; | 
|---|
| 416 |  | 
|---|
| 417 | /// Create a WrapperFunctionCall from a target function and arg buffer. | 
|---|
| 418 | WrapperFunctionCall(ExecutorAddr FnAddr, ArgDataBufferType ArgData) | 
|---|
| 419 | : FnAddr(FnAddr), ArgData(std::move(t&: ArgData)) {} | 
|---|
| 420 |  | 
|---|
| 421 | /// Returns the address to be called. | 
|---|
| 422 | const ExecutorAddr &getCallee() const { return FnAddr; } | 
|---|
| 423 |  | 
|---|
| 424 | /// Returns the argument data. | 
|---|
| 425 | const ArgDataBufferType &getArgData() const { return ArgData; } | 
|---|
| 426 |  | 
|---|
| 427 | /// WrapperFunctionCalls convert to true if the callee is non-null. | 
|---|
| 428 | explicit operator bool() const { return !!FnAddr; } | 
|---|
| 429 |  | 
|---|
| 430 | /// Run call returning raw WrapperFunctionResult. | 
|---|
| 431 | WrapperFunctionResult run() const { | 
|---|
| 432 | using FnTy = | 
|---|
| 433 | orc_rt_WrapperFunctionResult(const char *ArgData, size_t ArgSize); | 
|---|
| 434 | return WrapperFunctionResult( | 
|---|
| 435 | FnAddr.toPtr<FnTy *>()(ArgData.data(), ArgData.size())); | 
|---|
| 436 | } | 
|---|
| 437 |  | 
|---|
| 438 | /// Run call and deserialize result using SPS. | 
|---|
| 439 | template <typename SPSRetT, typename RetT> | 
|---|
| 440 | std::enable_if_t<!std::is_same<SPSRetT, void>::value, Error> | 
|---|
| 441 | runWithSPSRet(RetT &RetVal) const { | 
|---|
| 442 | auto WFR = run(); | 
|---|
| 443 | if (const char *ErrMsg = WFR.getOutOfBandError()) | 
|---|
| 444 | return make_error<StringError>(Args&: ErrMsg); | 
|---|
| 445 | SPSInputBuffer IB(WFR.data(), WFR.size()); | 
|---|
| 446 | if (!SPSSerializationTraits<SPSRetT, RetT>::deserialize(IB, RetVal)) | 
|---|
| 447 | return make_error<StringError>(Args: "Could not deserialize result from " | 
|---|
| 448 | "serialized wrapper function call"); | 
|---|
| 449 | return Error::success(); | 
|---|
| 450 | } | 
|---|
| 451 |  | 
|---|
| 452 | /// Overload for SPS functions returning void. | 
|---|
| 453 | template <typename SPSRetT> | 
|---|
| 454 | std::enable_if_t<std::is_same<SPSRetT, void>::value, Error> | 
|---|
| 455 | runWithSPSRet() const { | 
|---|
| 456 | SPSEmpty E; | 
|---|
| 457 | return runWithSPSRet<SPSEmpty>(RetVal&: E); | 
|---|
| 458 | } | 
|---|
| 459 |  | 
|---|
| 460 | /// Run call and deserialize an SPSError result. SPSError returns and | 
|---|
| 461 | /// deserialization failures are merged into the returned error. | 
|---|
| 462 | Error runWithSPSRetErrorMerged() const { | 
|---|
| 463 | detail::SPSSerializableError RetErr; | 
|---|
| 464 | if (auto Err = runWithSPSRet<SPSError>(RetVal&: RetErr)) | 
|---|
| 465 | return Err; | 
|---|
| 466 | return detail::fromSPSSerializable(BSE: std::move(t&: RetErr)); | 
|---|
| 467 | } | 
|---|
| 468 |  | 
|---|
| 469 | private: | 
|---|
| 470 | ExecutorAddr FnAddr; | 
|---|
| 471 | std::vector<char> ArgData; | 
|---|
| 472 | }; | 
|---|
| 473 |  | 
|---|
| 474 | using SPSWrapperFunctionCall = SPSTuple<SPSExecutorAddr, SPSSequence<char>>; | 
|---|
| 475 |  | 
|---|
| 476 | template <> | 
|---|
| 477 | class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> { | 
|---|
| 478 | public: | 
|---|
| 479 | static size_t size(const WrapperFunctionCall &WFC) { | 
|---|
| 480 | return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::size( | 
|---|
| 481 | Arg: WFC.getCallee(), Args: WFC.getArgData()); | 
|---|
| 482 | } | 
|---|
| 483 |  | 
|---|
| 484 | static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) { | 
|---|
| 485 | return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::serialize( | 
|---|
| 486 | OB, Arg: WFC.getCallee(), Args: WFC.getArgData()); | 
|---|
| 487 | } | 
|---|
| 488 |  | 
|---|
| 489 | static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) { | 
|---|
| 490 | ExecutorAddr FnAddr; | 
|---|
| 491 | WrapperFunctionCall::ArgDataBufferType ArgData; | 
|---|
| 492 | if (!SPSWrapperFunctionCall::AsArgList::deserialize(IB, Arg&: FnAddr, Args&: ArgData)) | 
|---|
| 493 | return false; | 
|---|
| 494 | WFC = WrapperFunctionCall(FnAddr, std::move(t&: ArgData)); | 
|---|
| 495 | return true; | 
|---|
| 496 | } | 
|---|
| 497 | }; | 
|---|
| 498 |  | 
|---|
| 499 | } // namespace orc_rt | 
|---|
| 500 |  | 
|---|
| 501 | #endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H | 
|---|
| 502 |  | 
|---|