1//===-- wrapper_function_utils.h - Utilities for wrapper funcs --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file is a part of the ORC runtime support library.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H
14#define ORC_RT_WRAPPER_FUNCTION_UTILS_H
15
16#include "error.h"
17#include "executor_address.h"
18#include "orc_rt/c_api.h"
19#include "simple_packed_serialization.h"
20#include <type_traits>
21
22namespace orc_rt {
23
24/// C++ wrapper function result: Same as orc_rt_WrapperFunctionResult but
25/// auto-releases memory.
26class WrapperFunctionResult {
27public:
28 /// Create a default WrapperFunctionResult.
29 WrapperFunctionResult() { orc_rt_WrapperFunctionResultInit(R: &R); }
30
31 /// Create a WrapperFunctionResult from a WrapperFunctionResult. This
32 /// instance takes ownership of the result object and will automatically
33 /// call dispose on the result upon destruction.
34 WrapperFunctionResult(orc_rt_WrapperFunctionResult R) : R(R) {}
35
36 WrapperFunctionResult(const WrapperFunctionResult &) = delete;
37 WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete;
38
39 WrapperFunctionResult(WrapperFunctionResult &&Other) {
40 orc_rt_WrapperFunctionResultInit(R: &R);
41 std::swap(a&: R, b&: Other.R);
42 }
43
44 WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) {
45 orc_rt_WrapperFunctionResult Tmp;
46 orc_rt_WrapperFunctionResultInit(R: &Tmp);
47 std::swap(a&: Tmp, b&: Other.R);
48 std::swap(a&: R, b&: Tmp);
49 return *this;
50 }
51
52 ~WrapperFunctionResult() { orc_rt_DisposeWrapperFunctionResult(R: &R); }
53
54 /// Relinquish ownership of and return the
55 /// orc_rt_WrapperFunctionResult.
56 orc_rt_WrapperFunctionResult release() {
57 orc_rt_WrapperFunctionResult Tmp;
58 orc_rt_WrapperFunctionResultInit(R: &Tmp);
59 std::swap(a&: R, b&: Tmp);
60 return Tmp;
61 }
62
63 /// Get a pointer to the data contained in this instance.
64 char *data() { return orc_rt_WrapperFunctionResultData(R: &R); }
65
66 /// Returns the size of the data contained in this instance.
67 size_t size() const { return orc_rt_WrapperFunctionResultSize(R: &R); }
68
69 /// Returns true if this value is equivalent to a default-constructed
70 /// WrapperFunctionResult.
71 bool empty() const { return orc_rt_WrapperFunctionResultEmpty(R: &R); }
72
73 /// Create a WrapperFunctionResult with the given size and return a pointer
74 /// to the underlying memory.
75 static WrapperFunctionResult allocate(size_t Size) {
76 WrapperFunctionResult R;
77 R.R = orc_rt_WrapperFunctionResultAllocate(Size);
78 return R;
79 }
80
81 /// Copy from the given char range.
82 static WrapperFunctionResult copyFrom(const char *Source, size_t Size) {
83 return orc_rt_CreateWrapperFunctionResultFromRange(Data: Source, Size);
84 }
85
86 /// Copy from the given null-terminated string (includes the null-terminator).
87 static WrapperFunctionResult copyFrom(const char *Source) {
88 return orc_rt_CreateWrapperFunctionResultFromString(Source);
89 }
90
91 /// Copy from the given std::string (includes the null terminator).
92 static WrapperFunctionResult copyFrom(const std::string &Source) {
93 return copyFrom(Source: Source.c_str());
94 }
95
96 /// Create an out-of-band error by copying the given string.
97 static WrapperFunctionResult createOutOfBandError(const char *Msg) {
98 return orc_rt_CreateWrapperFunctionResultFromOutOfBandError(ErrMsg: Msg);
99 }
100
101 /// Create an out-of-band error by copying the given string.
102 static WrapperFunctionResult createOutOfBandError(const std::string &Msg) {
103 return createOutOfBandError(Msg: Msg.c_str());
104 }
105
106 template <typename SPSArgListT, typename... ArgTs>
107 static WrapperFunctionResult fromSPSArgs(const ArgTs &...Args) {
108 auto Result = allocate(Size: SPSArgListT::size(Args...));
109 SPSOutputBuffer OB(Result.data(), Result.size());
110 if (!SPSArgListT::serialize(OB, Args...))
111 return createOutOfBandError(
112 Msg: "Error serializing arguments to blob in call");
113 return Result;
114 }
115
116 /// If this value is an out-of-band error then this returns the error message,
117 /// otherwise returns nullptr.
118 const char *getOutOfBandError() const {
119 return orc_rt_WrapperFunctionResultGetOutOfBandError(R: &R);
120 }
121
122private:
123 orc_rt_WrapperFunctionResult R;
124};
125
126namespace detail {
127
/// Invokes a handler with the unpacked elements of an argument tuple.
///
/// The primary template passes the handler's return value straight back to
/// the caller; see the void specialization for handlers that return nothing.
template <typename RetT> class WrapperFunctionHandlerCaller {
public:
  /// Call Handler(std::get<Is>(ArgTuple)...). The return type is deduced with
  /// decltype(auto), so a reference returned by the handler is preserved.
  template <typename HandlerT, typename ArgTupleT, std::size_t... Is>
  static decltype(auto) call(HandlerT &&Handler, ArgTupleT &ArgTuple,
                             std::index_sequence<Is...>) {
    return std::forward<HandlerT>(Handler)(std::get<Is>(ArgTuple)...);
  }
};
136
137template <> class WrapperFunctionHandlerCaller<void> {
138public:
139 template <typename HandlerT, typename ArgTupleT, std::size_t... I>
140 static SPSEmpty call(HandlerT &&H, ArgTupleT &Args,
141 std::index_sequence<I...>) {
142 std::forward<HandlerT>(H)(std::get<I>(Args)...);
143 return SPSEmpty();
144 }
145};
146
/// Primary template, used for callable objects such as lambdas: takes the
/// address of the object's call operator and delegates to the specialization
/// for that member-function-pointer type. Callables whose operator() is
/// overloaded or templated cannot be handled this way.
template <typename CallableT,
          template <typename> class ResultSerializerT, typename... SPSTagTs>
class WrapperFunctionHandlerHelper
    : public WrapperFunctionHandlerHelper<
          decltype(&std::remove_reference_t<CallableT>::operator()),
          ResultSerializerT, SPSTagTs...> {};
153
154template <typename RetT, typename... ArgTs,
155 template <typename> class ResultSerializer, typename... SPSTagTs>
156class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
157 SPSTagTs...> {
158public:
159 using ArgTuple = std::tuple<std::decay_t<ArgTs>...>;
160 using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>;
161
162 template <typename HandlerT>
163 static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData,
164 size_t ArgSize) {
165 ArgTuple Args;
166 if (!deserialize(ArgData, ArgSize, Args, ArgIndices{}))
167 return WrapperFunctionResult::createOutOfBandError(
168 Msg: "Could not deserialize arguments for wrapper function call");
169
170 auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call(
171 std::forward<HandlerT>(H), Args, ArgIndices{});
172
173 return ResultSerializer<decltype(HandlerResult)>::serialize(
174 std::move(HandlerResult));
175 }
176
177private:
178 template <std::size_t... I>
179 static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args,
180 std::index_sequence<I...>) {
181 SPSInputBuffer IB(ArgData, ArgSize);
182 return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);
183 }
184};
185
186// Map function pointers to function types.
187template <typename RetT, typename... ArgTs,
188 template <typename> class ResultSerializer, typename... SPSTagTs>
189class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
190 SPSTagTs...>
191 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
192 SPSTagTs...> {};
193
194// Map non-const member function types to function types.
195template <typename ClassT, typename RetT, typename... ArgTs,
196 template <typename> class ResultSerializer, typename... SPSTagTs>
197class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer,
198 SPSTagTs...>
199 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
200 SPSTagTs...> {};
201
202// Map const member function types to function types.
203template <typename ClassT, typename RetT, typename... ArgTs,
204 template <typename> class ResultSerializer, typename... SPSTagTs>
205class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
206 ResultSerializer, SPSTagTs...>
207 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
208 SPSTagTs...> {};
209
210template <typename SPSRetTagT, typename RetT> class ResultSerializer {
211public:
212 static WrapperFunctionResult serialize(RetT Result) {
213 return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(Result);
214 }
215};
216
217template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {
218public:
219 static WrapperFunctionResult serialize(Error Err) {
220 return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(
221 toSPSSerializable(Err: std::move(t&: Err)));
222 }
223};
224
225template <typename SPSRetTagT, typename T>
226class ResultSerializer<SPSRetTagT, Expected<T>> {
227public:
228 static WrapperFunctionResult serialize(Expected<T> E) {
229 return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(
230 toSPSSerializable(std::move(E)));
231 }
232};
233
234template <typename SPSRetTagT, typename RetT> class ResultDeserializer {
235public:
236 static void makeSafe(RetT &Result) {}
237
238 static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) {
239 SPSInputBuffer IB(ArgData, ArgSize);
240 if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result))
241 return make_error<StringError>(
242 Args: "Error deserializing return value from blob in call");
243 return Error::success();
244 }
245};
246
247template <> class ResultDeserializer<SPSError, Error> {
248public:
249 static void makeSafe(Error &Err) { cantFail(Err: std::move(t&: Err)); }
250
251 static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) {
252 SPSInputBuffer IB(ArgData, ArgSize);
253 SPSSerializableError BSE;
254 if (!SPSArgList<SPSError>::deserialize(IB, Arg&: BSE))
255 return make_error<StringError>(
256 Args: "Error deserializing return value from blob in call");
257 Err = fromSPSSerializable(BSE: std::move(t&: BSE));
258 return Error::success();
259 }
260};
261
262template <typename SPSTagT, typename T>
263class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> {
264public:
265 static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); }
266
267 static Error deserialize(Expected<T> &E, const char *ArgData,
268 size_t ArgSize) {
269 SPSInputBuffer IB(ArgData, ArgSize);
270 SPSSerializableExpected<T> BSE;
271 if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE))
272 return make_error<StringError>(
273 Args: "Error deserializing return value from blob in call");
274 E = fromSPSSerializable(std::move(BSE));
275 return Error::success();
276 }
277};
278
279} // end namespace detail
280
281template <typename SPSSignature> class WrapperFunction;
282
283template <typename SPSRetTagT, typename... SPSTagTs>
284class WrapperFunction<SPSRetTagT(SPSTagTs...)> {
285private:
286 template <typename RetT>
287 using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>;
288
289public:
290 template <typename DispatchFn, typename RetT, typename... ArgTs>
291 static Error call(DispatchFn &&Dispatch, RetT &Result, const ArgTs &...Args) {
292
293 // RetT might be an Error or Expected value. Set the checked flag now:
294 // we don't want the user to have to check the unused result if this
295 // operation fails.
296 detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result);
297
298 auto ArgBuffer =
299 WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSTagTs...>>(Args...);
300 if (const char *ErrMsg = ArgBuffer.getOutOfBandError())
301 return make_error<StringError>(Args&: ErrMsg);
302
303 WrapperFunctionResult ResultBuffer =
304 Dispatch(ArgBuffer.data(), ArgBuffer.size());
305
306 if (auto ErrMsg = ResultBuffer.getOutOfBandError())
307 return make_error<StringError>(Args&: ErrMsg);
308
309 return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
310 Result, ResultBuffer.data(), ResultBuffer.size());
311 }
312
313 template <typename HandlerT>
314 static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize,
315 HandlerT &&Handler) {
316 using WFHH =
317 detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>,
318 ResultSerializer, SPSTagTs...>;
319 return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize);
320 }
321
322private:
323 template <typename T> static const T &makeSerializable(const T &Value) {
324 return Value;
325 }
326
327 static detail::SPSSerializableError makeSerializable(Error Err) {
328 return detail::toSPSSerializable(Err: std::move(t&: Err));
329 }
330
331 template <typename T>
332 static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) {
333 return detail::toSPSSerializable(std::move(E));
334 }
335};
336
337template <typename... SPSTagTs>
338class WrapperFunction<void(SPSTagTs...)>
339 : private WrapperFunction<SPSEmpty(SPSTagTs...)> {
340public:
341 template <typename DispatchFn, typename... ArgTs>
342 static Error call(DispatchFn &&Dispatch, const ArgTs &...Args) {
343 SPSEmpty BE;
344 return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(
345 std::forward<DispatchFn>(Dispatch), BE, Args...);
346 }
347
348 using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle;
349};
350
351/// A function object that takes an ExecutorAddr as its first argument,
352/// casts that address to a ClassT*, then calls the given method on that
353/// pointer passing in the remaining function arguments. This utility
354/// removes some of the boilerplate from writing wrappers for method calls.
355///
356/// @code{.cpp}
357/// class MyClass {
358/// public:
359/// void myMethod(uint32_t, bool) { ... }
360/// };
361///
362/// // SPS Method signature -- note MyClass object address as first argument.
363/// using SPSMyMethodWrapperSignature =
364/// SPSTuple<SPSExecutorAddr, uint32_t, bool>;
365///
366/// WrapperFunctionResult
367/// myMethodCallWrapper(const char *ArgData, size_t ArgSize) {
368/// return WrapperFunction<SPSMyMethodWrapperSignature>::handle(
369/// ArgData, ArgSize, makeMethodWrapperHandler(&MyClass::myMethod));
370/// }
371/// @endcode
372///
373template <typename RetT, typename ClassT, typename... ArgTs>
374class MethodWrapperHandler {
375public:
376 using MethodT = RetT (ClassT::*)(ArgTs...);
377 MethodWrapperHandler(MethodT M) : M(M) {}
378 RetT operator()(ExecutorAddr ObjAddr, ArgTs &...Args) {
379 return (ObjAddr.toPtr<ClassT *>()->*M)(std::forward<ArgTs>(Args)...);
380 }
381
382private:
383 MethodT M;
384};
385
386/// Create a MethodWrapperHandler object from the given method pointer.
387template <typename RetT, typename ClassT, typename... ArgTs>
388MethodWrapperHandler<RetT, ClassT, ArgTs...>
389makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) {
390 return MethodWrapperHandler<RetT, ClassT, ArgTs...>(Method);
391}
392
393/// Represents a call to a wrapper function.
394class WrapperFunctionCall {
395public:
396 // FIXME: Switch to a SmallVector<char, 24> once ORC runtime has a
397 // smallvector.
398 using ArgDataBufferType = std::vector<char>;
399
400 /// Create a WrapperFunctionCall using the given SPS serializer to serialize
401 /// the arguments.
402 template <typename SPSSerializer, typename... ArgTs>
403 static Expected<WrapperFunctionCall> Create(ExecutorAddr FnAddr,
404 const ArgTs &...Args) {
405 ArgDataBufferType ArgData;
406 ArgData.resize(SPSSerializer::size(Args...));
407 SPSOutputBuffer OB(ArgData.empty() ? nullptr : ArgData.data(),
408 ArgData.size());
409 if (SPSSerializer::serialize(OB, Args...))
410 return WrapperFunctionCall(FnAddr, std::move(t&: ArgData));
411 return make_error<StringError>(Args: "Cannot serialize arguments for "
412 "AllocActionCall");
413 }
414
415 WrapperFunctionCall() = default;
416
417 /// Create a WrapperFunctionCall from a target function and arg buffer.
418 WrapperFunctionCall(ExecutorAddr FnAddr, ArgDataBufferType ArgData)
419 : FnAddr(FnAddr), ArgData(std::move(t&: ArgData)) {}
420
421 /// Returns the address to be called.
422 const ExecutorAddr &getCallee() const { return FnAddr; }
423
424 /// Returns the argument data.
425 const ArgDataBufferType &getArgData() const { return ArgData; }
426
427 /// WrapperFunctionCalls convert to true if the callee is non-null.
428 explicit operator bool() const { return !!FnAddr; }
429
430 /// Run call returning raw WrapperFunctionResult.
431 WrapperFunctionResult run() const {
432 using FnTy =
433 orc_rt_WrapperFunctionResult(const char *ArgData, size_t ArgSize);
434 return WrapperFunctionResult(
435 FnAddr.toPtr<FnTy *>()(ArgData.data(), ArgData.size()));
436 }
437
438 /// Run call and deserialize result using SPS.
439 template <typename SPSRetT, typename RetT>
440 std::enable_if_t<!std::is_same<SPSRetT, void>::value, Error>
441 runWithSPSRet(RetT &RetVal) const {
442 auto WFR = run();
443 if (const char *ErrMsg = WFR.getOutOfBandError())
444 return make_error<StringError>(Args&: ErrMsg);
445 SPSInputBuffer IB(WFR.data(), WFR.size());
446 if (!SPSSerializationTraits<SPSRetT, RetT>::deserialize(IB, RetVal))
447 return make_error<StringError>(Args: "Could not deserialize result from "
448 "serialized wrapper function call");
449 return Error::success();
450 }
451
452 /// Overload for SPS functions returning void.
453 template <typename SPSRetT>
454 std::enable_if_t<std::is_same<SPSRetT, void>::value, Error>
455 runWithSPSRet() const {
456 SPSEmpty E;
457 return runWithSPSRet<SPSEmpty>(RetVal&: E);
458 }
459
460 /// Run call and deserialize an SPSError result. SPSError returns and
461 /// deserialization failures are merged into the returned error.
462 Error runWithSPSRetErrorMerged() const {
463 detail::SPSSerializableError RetErr;
464 if (auto Err = runWithSPSRet<SPSError>(RetVal&: RetErr))
465 return Err;
466 return detail::fromSPSSerializable(BSE: std::move(t&: RetErr));
467 }
468
469private:
470 ExecutorAddr FnAddr;
471 std::vector<char> ArgData;
472};
473
474using SPSWrapperFunctionCall = SPSTuple<SPSExecutorAddr, SPSSequence<char>>;
475
476template <>
477class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> {
478public:
479 static size_t size(const WrapperFunctionCall &WFC) {
480 return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::size(
481 Arg: WFC.getCallee(), Args: WFC.getArgData());
482 }
483
484 static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) {
485 return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::serialize(
486 OB, Arg: WFC.getCallee(), Args: WFC.getArgData());
487 }
488
489 static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) {
490 ExecutorAddr FnAddr;
491 WrapperFunctionCall::ArgDataBufferType ArgData;
492 if (!SPSWrapperFunctionCall::AsArgList::deserialize(IB, Arg&: FnAddr, Args&: ArgData))
493 return false;
494 WFC = WrapperFunctionCall(FnAddr, std::move(t&: ArgData));
495 return true;
496 }
497};
498
499} // namespace orc_rt
500
501#endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H
502