1//===-- wrapper_function_utils.h - Utilities for wrapper funcs --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file is a part of the ORC runtime support library.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H
14#define ORC_RT_WRAPPER_FUNCTION_UTILS_H
15
16#include "orc_rt/c_api.h"
17#include "common.h"
18#include "error.h"
19#include "executor_address.h"
20#include "simple_packed_serialization.h"
21#include <type_traits>
22
23namespace __orc_rt {
24
25/// C++ wrapper function result: Same as CWrapperFunctionResult but
26/// auto-releases memory.
27class WrapperFunctionResult {
28public:
29 /// Create a default WrapperFunctionResult.
30 WrapperFunctionResult() { orc_rt_CWrapperFunctionResultInit(R: &R); }
31
32 /// Create a WrapperFunctionResult from a CWrapperFunctionResult. This
33 /// instance takes ownership of the result object and will automatically
34 /// call dispose on the result upon destruction.
35 WrapperFunctionResult(orc_rt_CWrapperFunctionResult R) : R(R) {}
36
37 WrapperFunctionResult(const WrapperFunctionResult &) = delete;
38 WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete;
39
40 WrapperFunctionResult(WrapperFunctionResult &&Other) {
41 orc_rt_CWrapperFunctionResultInit(R: &R);
42 std::swap(a&: R, b&: Other.R);
43 }
44
45 WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) {
46 orc_rt_CWrapperFunctionResult Tmp;
47 orc_rt_CWrapperFunctionResultInit(R: &Tmp);
48 std::swap(a&: Tmp, b&: Other.R);
49 std::swap(a&: R, b&: Tmp);
50 return *this;
51 }
52
53 ~WrapperFunctionResult() { orc_rt_DisposeCWrapperFunctionResult(R: &R); }
54
55 /// Relinquish ownership of and return the
56 /// orc_rt_CWrapperFunctionResult.
57 orc_rt_CWrapperFunctionResult release() {
58 orc_rt_CWrapperFunctionResult Tmp;
59 orc_rt_CWrapperFunctionResultInit(R: &Tmp);
60 std::swap(a&: R, b&: Tmp);
61 return Tmp;
62 }
63
64 /// Get a pointer to the data contained in this instance.
65 char *data() { return orc_rt_CWrapperFunctionResultData(R: &R); }
66
67 /// Returns the size of the data contained in this instance.
68 size_t size() const { return orc_rt_CWrapperFunctionResultSize(R: &R); }
69
70 /// Returns true if this value is equivalent to a default-constructed
71 /// WrapperFunctionResult.
72 bool empty() const { return orc_rt_CWrapperFunctionResultEmpty(R: &R); }
73
74 /// Create a WrapperFunctionResult with the given size and return a pointer
75 /// to the underlying memory.
76 static WrapperFunctionResult allocate(size_t Size) {
77 WrapperFunctionResult R;
78 R.R = orc_rt_CWrapperFunctionResultAllocate(Size);
79 return R;
80 }
81
82 /// Copy from the given char range.
83 static WrapperFunctionResult copyFrom(const char *Source, size_t Size) {
84 return orc_rt_CreateCWrapperFunctionResultFromRange(Data: Source, Size);
85 }
86
87 /// Copy from the given null-terminated string (includes the null-terminator).
88 static WrapperFunctionResult copyFrom(const char *Source) {
89 return orc_rt_CreateCWrapperFunctionResultFromString(Source);
90 }
91
92 /// Copy from the given std::string (includes the null terminator).
93 static WrapperFunctionResult copyFrom(const std::string &Source) {
94 return copyFrom(Source: Source.c_str());
95 }
96
97 /// Create an out-of-band error by copying the given string.
98 static WrapperFunctionResult createOutOfBandError(const char *Msg) {
99 return orc_rt_CreateCWrapperFunctionResultFromOutOfBandError(ErrMsg: Msg);
100 }
101
102 /// Create an out-of-band error by copying the given string.
103 static WrapperFunctionResult createOutOfBandError(const std::string &Msg) {
104 return createOutOfBandError(Msg: Msg.c_str());
105 }
106
107 template <typename SPSArgListT, typename... ArgTs>
108 static WrapperFunctionResult fromSPSArgs(const ArgTs &...Args) {
109 auto Result = allocate(Size: SPSArgListT::size(Args...));
110 SPSOutputBuffer OB(Result.data(), Result.size());
111 if (!SPSArgListT::serialize(OB, Args...))
112 return createOutOfBandError(
113 Msg: "Error serializing arguments to blob in call");
114 return Result;
115 }
116
117 /// If this value is an out-of-band error then this returns the error message,
118 /// otherwise returns nullptr.
119 const char *getOutOfBandError() const {
120 return orc_rt_CWrapperFunctionResultGetOutOfBandError(R: &R);
121 }
122
123private:
124 orc_rt_CWrapperFunctionResult R;
125};
126
127namespace detail {
128
/// Invokes a handler on the elements of an argument tuple chosen by an index
/// sequence, returning exactly what the handler returns (decltype(auto) keeps
/// reference returns as references).
template <typename RetT> class WrapperFunctionHandlerCaller {
public:
  template <typename HandlerT, typename ArgTupleT, std::size_t... Idx>
  static decltype(auto) call(HandlerT &&Handler, ArgTupleT &ArgTuple,
                             std::index_sequence<Idx...>) {
    return std::forward<HandlerT>(Handler)(std::get<Idx>(ArgTuple)...);
  }
};
137
138template <> class WrapperFunctionHandlerCaller<void> {
139public:
140 template <typename HandlerT, typename ArgTupleT, std::size_t... I>
141 static SPSEmpty call(HandlerT &&H, ArgTupleT &Args,
142 std::index_sequence<I...>) {
143 std::forward<HandlerT>(H)(std::get<I>(Args)...);
144 return SPSEmpty();
145 }
146};
147
/// Fallback for callable objects (lambdas, functors): defer to the
/// specialization matching the type of the object's call operator, which is
/// where the return and argument types get extracted.
template <typename WrapperFunctionImplT,
          template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionHandlerHelper
    : public WrapperFunctionHandlerHelper<
          decltype(&std::remove_reference<
                   WrapperFunctionImplT>::type::operator()),
          ResultSerializer, SPSTagTs...> {};
154
155template <typename RetT, typename... ArgTs,
156 template <typename> class ResultSerializer, typename... SPSTagTs>
157class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
158 SPSTagTs...> {
159public:
160 using ArgTuple = std::tuple<std::decay_t<ArgTs>...>;
161 using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>;
162
163 template <typename HandlerT>
164 static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData,
165 size_t ArgSize) {
166 ArgTuple Args;
167 if (!deserialize(ArgData, ArgSize, Args, ArgIndices{}))
168 return WrapperFunctionResult::createOutOfBandError(
169 Msg: "Could not deserialize arguments for wrapper function call");
170
171 auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call(
172 std::forward<HandlerT>(H), Args, ArgIndices{});
173
174 return ResultSerializer<decltype(HandlerResult)>::serialize(
175 std::move(HandlerResult));
176 }
177
178private:
179 template <std::size_t... I>
180 static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args,
181 std::index_sequence<I...>) {
182 SPSInputBuffer IB(ArgData, ArgSize);
183 return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);
184 }
185};
186
187// Map function pointers to function types.
188template <typename RetT, typename... ArgTs,
189 template <typename> class ResultSerializer, typename... SPSTagTs>
190class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
191 SPSTagTs...>
192 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
193 SPSTagTs...> {};
194
195// Map non-const member function types to function types.
196template <typename ClassT, typename RetT, typename... ArgTs,
197 template <typename> class ResultSerializer, typename... SPSTagTs>
198class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer,
199 SPSTagTs...>
200 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
201 SPSTagTs...> {};
202
203// Map const member function types to function types.
204template <typename ClassT, typename RetT, typename... ArgTs,
205 template <typename> class ResultSerializer, typename... SPSTagTs>
206class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
207 ResultSerializer, SPSTagTs...>
208 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
209 SPSTagTs...> {};
210
211template <typename SPSRetTagT, typename RetT> class ResultSerializer {
212public:
213 static WrapperFunctionResult serialize(RetT Result) {
214 return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(Result);
215 }
216};
217
218template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {
219public:
220 static WrapperFunctionResult serialize(Error Err) {
221 return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(
222 toSPSSerializable(Err: std::move(Err)));
223 }
224};
225
226template <typename SPSRetTagT, typename T>
227class ResultSerializer<SPSRetTagT, Expected<T>> {
228public:
229 static WrapperFunctionResult serialize(Expected<T> E) {
230 return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(
231 toSPSSerializable(std::move(E)));
232 }
233};
234
235template <typename SPSRetTagT, typename RetT> class ResultDeserializer {
236public:
237 static void makeSafe(RetT &Result) {}
238
239 static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) {
240 SPSInputBuffer IB(ArgData, ArgSize);
241 if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result))
242 return make_error<StringError>(
243 Args: "Error deserializing return value from blob in call");
244 return Error::success();
245 }
246};
247
248template <> class ResultDeserializer<SPSError, Error> {
249public:
250 static void makeSafe(Error &Err) { cantFail(Err: std::move(Err)); }
251
252 static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) {
253 SPSInputBuffer IB(ArgData, ArgSize);
254 SPSSerializableError BSE;
255 if (!SPSArgList<SPSError>::deserialize(IB, Arg&: BSE))
256 return make_error<StringError>(
257 Args: "Error deserializing return value from blob in call");
258 Err = fromSPSSerializable(BSE: std::move(BSE));
259 return Error::success();
260 }
261};
262
263template <typename SPSTagT, typename T>
264class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> {
265public:
266 static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); }
267
268 static Error deserialize(Expected<T> &E, const char *ArgData,
269 size_t ArgSize) {
270 SPSInputBuffer IB(ArgData, ArgSize);
271 SPSSerializableExpected<T> BSE;
272 if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE))
273 return make_error<StringError>(
274 Args: "Error deserializing return value from blob in call");
275 E = fromSPSSerializable(std::move(BSE));
276 return Error::success();
277 }
278};
279
280} // end namespace detail
281
282template <typename SPSSignature> class WrapperFunction;
283
284template <typename SPSRetTagT, typename... SPSTagTs>
285class WrapperFunction<SPSRetTagT(SPSTagTs...)> {
286private:
287 template <typename RetT>
288 using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>;
289
290public:
291 template <typename RetT, typename... ArgTs>
292 static Error call(const void *FnTag, RetT &Result, const ArgTs &...Args) {
293
294 // RetT might be an Error or Expected value. Set the checked flag now:
295 // we don't want the user to have to check the unused result if this
296 // operation fails.
297 detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result);
298
299 // Since the functions cannot be zero/unresolved on Windows, the following
300 // reference taking would always be non-zero, thus generating a compiler
301 // warning otherwise.
302#if !defined(_WIN32)
303 if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch_ctx))
304 return make_error<StringError>(Args: "__orc_rt_jit_dispatch_ctx not set");
305 if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch))
306 return make_error<StringError>(Args: "__orc_rt_jit_dispatch not set");
307#endif
308 auto ArgBuffer =
309 WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSTagTs...>>(Args...);
310 if (const char *ErrMsg = ArgBuffer.getOutOfBandError())
311 return make_error<StringError>(Args&: ErrMsg);
312
313 WrapperFunctionResult ResultBuffer = __orc_rt_jit_dispatch(
314 &__orc_rt_jit_dispatch_ctx, FnTag, ArgBuffer.data(), ArgBuffer.size());
315 if (auto ErrMsg = ResultBuffer.getOutOfBandError())
316 return make_error<StringError>(Args&: ErrMsg);
317
318 return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
319 Result, ResultBuffer.data(), ResultBuffer.size());
320 }
321
322 template <typename HandlerT>
323 static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize,
324 HandlerT &&Handler) {
325 using WFHH =
326 detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>,
327 ResultSerializer, SPSTagTs...>;
328 return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize);
329 }
330
331private:
332 template <typename T> static const T &makeSerializable(const T &Value) {
333 return Value;
334 }
335
336 static detail::SPSSerializableError makeSerializable(Error Err) {
337 return detail::toSPSSerializable(Err: std::move(Err));
338 }
339
340 template <typename T>
341 static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) {
342 return detail::toSPSSerializable(std::move(E));
343 }
344};
345
346template <typename... SPSTagTs>
347class WrapperFunction<void(SPSTagTs...)>
348 : private WrapperFunction<SPSEmpty(SPSTagTs...)> {
349public:
350 template <typename... ArgTs>
351 static Error call(const void *FnTag, const ArgTs &...Args) {
352 SPSEmpty BE;
353 return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(FnTag, BE, Args...);
354 }
355
356 using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle;
357};
358
359/// A function object that takes an ExecutorAddr as its first argument,
360/// casts that address to a ClassT*, then calls the given method on that
361/// pointer passing in the remaining function arguments. This utility
362/// removes some of the boilerplate from writing wrappers for method calls.
363///
364/// @code{.cpp}
365/// class MyClass {
366/// public:
367/// void myMethod(uint32_t, bool) { ... }
368/// };
369///
370/// // SPS Method signature -- note MyClass object address as first argument.
371/// using SPSMyMethodWrapperSignature =
372/// SPSTuple<SPSExecutorAddr, uint32_t, bool>;
373///
374/// WrapperFunctionResult
375/// myMethodCallWrapper(const char *ArgData, size_t ArgSize) {
376/// return WrapperFunction<SPSMyMethodWrapperSignature>::handle(
377/// ArgData, ArgSize, makeMethodWrapperHandler(&MyClass::myMethod));
378/// }
379/// @endcode
380///
381template <typename RetT, typename ClassT, typename... ArgTs>
382class MethodWrapperHandler {
383public:
384 using MethodT = RetT (ClassT::*)(ArgTs...);
385 MethodWrapperHandler(MethodT M) : M(M) {}
386 RetT operator()(ExecutorAddr ObjAddr, ArgTs &...Args) {
387 return (ObjAddr.toPtr<ClassT *>()->*M)(std::forward<ArgTs>(Args)...);
388 }
389
390private:
391 MethodT M;
392};
393
394/// Create a MethodWrapperHandler object from the given method pointer.
395template <typename RetT, typename ClassT, typename... ArgTs>
396MethodWrapperHandler<RetT, ClassT, ArgTs...>
397makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) {
398 return MethodWrapperHandler<RetT, ClassT, ArgTs...>(Method);
399}
400
401/// Represents a call to a wrapper function.
402class WrapperFunctionCall {
403public:
404 // FIXME: Switch to a SmallVector<char, 24> once ORC runtime has a
405 // smallvector.
406 using ArgDataBufferType = std::vector<char>;
407
408 /// Create a WrapperFunctionCall using the given SPS serializer to serialize
409 /// the arguments.
410 template <typename SPSSerializer, typename... ArgTs>
411 static Expected<WrapperFunctionCall> Create(ExecutorAddr FnAddr,
412 const ArgTs &...Args) {
413 ArgDataBufferType ArgData;
414 ArgData.resize(SPSSerializer::size(Args...));
415 SPSOutputBuffer OB(ArgData.empty() ? nullptr : ArgData.data(),
416 ArgData.size());
417 if (SPSSerializer::serialize(OB, Args...))
418 return WrapperFunctionCall(FnAddr, std::move(ArgData));
419 return make_error<StringError>(Args: "Cannot serialize arguments for "
420 "AllocActionCall");
421 }
422
423 WrapperFunctionCall() = default;
424
425 /// Create a WrapperFunctionCall from a target function and arg buffer.
426 WrapperFunctionCall(ExecutorAddr FnAddr, ArgDataBufferType ArgData)
427 : FnAddr(FnAddr), ArgData(std::move(ArgData)) {}
428
429 /// Returns the address to be called.
430 const ExecutorAddr &getCallee() const { return FnAddr; }
431
432 /// Returns the argument data.
433 const ArgDataBufferType &getArgData() const { return ArgData; }
434
435 /// WrapperFunctionCalls convert to true if the callee is non-null.
436 explicit operator bool() const { return !!FnAddr; }
437
438 /// Run call returning raw WrapperFunctionResult.
439 WrapperFunctionResult run() const {
440 using FnTy =
441 orc_rt_CWrapperFunctionResult(const char *ArgData, size_t ArgSize);
442 return WrapperFunctionResult(
443 FnAddr.toPtr<FnTy *>()(ArgData.data(), ArgData.size()));
444 }
445
446 /// Run call and deserialize result using SPS.
447 template <typename SPSRetT, typename RetT>
448 std::enable_if_t<!std::is_same<SPSRetT, void>::value, Error>
449 runWithSPSRet(RetT &RetVal) const {
450 auto WFR = run();
451 if (const char *ErrMsg = WFR.getOutOfBandError())
452 return make_error<StringError>(Args&: ErrMsg);
453 SPSInputBuffer IB(WFR.data(), WFR.size());
454 if (!SPSSerializationTraits<SPSRetT, RetT>::deserialize(IB, RetVal))
455 return make_error<StringError>(Args: "Could not deserialize result from "
456 "serialized wrapper function call");
457 return Error::success();
458 }
459
460 /// Overload for SPS functions returning void.
461 template <typename SPSRetT>
462 std::enable_if_t<std::is_same<SPSRetT, void>::value, Error>
463 runWithSPSRet() const {
464 SPSEmpty E;
465 return runWithSPSRet<SPSEmpty>(RetVal&: E);
466 }
467
468 /// Run call and deserialize an SPSError result. SPSError returns and
469 /// deserialization failures are merged into the returned error.
470 Error runWithSPSRetErrorMerged() const {
471 detail::SPSSerializableError RetErr;
472 if (auto Err = runWithSPSRet<SPSError>(RetVal&: RetErr))
473 return Err;
474 return detail::fromSPSSerializable(BSE: std::move(RetErr));
475 }
476
477private:
478 ExecutorAddr FnAddr;
479 std::vector<char> ArgData;
480};
481
/// SPS tag for WrapperFunctionCall: the callee address followed by the
/// argument data as a sequence of chars.
using SPSWrapperFunctionCall = SPSTuple<SPSExecutorAddr, SPSSequence<char>>;
483
484template <>
485class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> {
486public:
487 static size_t size(const WrapperFunctionCall &WFC) {
488 return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::size(
489 Arg: WFC.getCallee(), Args: WFC.getArgData());
490 }
491
492 static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) {
493 return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::serialize(
494 OB, Arg: WFC.getCallee(), Args: WFC.getArgData());
495 }
496
497 static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) {
498 ExecutorAddr FnAddr;
499 WrapperFunctionCall::ArgDataBufferType ArgData;
500 if (!SPSWrapperFunctionCall::AsArgList::deserialize(IB, Arg&: FnAddr, Args&: ArgData))
501 return false;
502 WFC = WrapperFunctionCall(FnAddr, std::move(ArgData));
503 return true;
504 }
505};
506
507} // end namespace __orc_rt
508
509#endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H
510