1//===------ SimpleRemoteEPCUtils.cpp - Utils for Simple Remote EPC --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Message definitions and other utilities for SimpleRemoteEPC and
10// SimpleRemoteEPCServer.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/ExecutionEngine/Orc/Shared/SimpleRemoteEPCUtils.h"
15#include "llvm/Config/llvm-config.h" // for LLVM_ENABLE_THREADS
16#include "llvm/Support/Endian.h"
17
18#if !defined(_MSC_VER) && !defined(__MINGW32__)
19#include <unistd.h>
20#else
21#include <io.h>
22#endif
23#ifndef _WIN32
24#include <sys/socket.h>
25#endif
26
namespace {

/// Byte layout of the fixed-size header that precedes every message sent over
/// an FD transport. The header is four consecutive 64-bit little-endian
/// fields; each offset below is simply the field's index times 8 bytes.
///
///   field 0: total message size in bytes (header included)
///   field 1: opcode
///   field 2: sequence number
///   field 3: tag address
struct FDMsgHeader {
  static constexpr unsigned MsgSizeOffset = 0 * sizeof(uint64_t);
  static constexpr unsigned OpCOffset = 1 * sizeof(uint64_t);
  static constexpr unsigned SeqNoOffset = 2 * sizeof(uint64_t);
  static constexpr unsigned TagAddrOffset = 3 * sizeof(uint64_t);
  // One past the last field: the header's total size.
  static constexpr unsigned Size = 4 * sizeof(uint64_t);
};

} // namespace
38
39namespace llvm {
40namespace orc {
namespace SimpleRemoteEPCDefaultBootstrapSymbolNames {

// Default symbol names used by SimpleRemoteEPC (presumably registered as
// bootstrap symbols by the executor side -- confirm against the header).

// Name of the dispatch function symbol.
const char *DispatchFnName = "__llvm_orc_SimpleRemoteEPC_dispatch_fn";

// Name of the dispatch context / executor session object symbol.
const char *ExecutorSessionObjectName =
    "__llvm_orc_SimpleRemoteEPC_dispatch_ctx";

} // end namespace SimpleRemoteEPCDefaultBootstrapSymbolNames
48
49SimpleRemoteEPCTransportClient::~SimpleRemoteEPCTransportClient() = default;
50SimpleRemoteEPCTransport::~SimpleRemoteEPCTransport() = default;
51
52Expected<std::unique_ptr<FDSimpleRemoteEPCTransport>>
53FDSimpleRemoteEPCTransport::Create(SimpleRemoteEPCTransportClient &C, int InFD,
54 int OutFD) {
55#if LLVM_ENABLE_THREADS
56 if (InFD == -1)
57 return make_error<StringError>(Args: "Invalid input file descriptor " +
58 Twine(InFD),
59 Args: inconvertibleErrorCode());
60 if (OutFD == -1)
61 return make_error<StringError>(Args: "Invalid output file descriptor " +
62 Twine(OutFD),
63 Args: inconvertibleErrorCode());
64 std::unique_ptr<FDSimpleRemoteEPCTransport> FDT(
65 new FDSimpleRemoteEPCTransport(C, InFD, OutFD));
66 return std::move(FDT);
67#else
68 return make_error<StringError>("FD-based SimpleRemoteEPC transport requires "
69 "thread support, but llvm was built with "
70 "LLVM_ENABLE_THREADS=Off",
71 inconvertibleErrorCode());
72#endif
73}
74
75FDSimpleRemoteEPCTransport::~FDSimpleRemoteEPCTransport() {
76#if LLVM_ENABLE_THREADS
77 ListenerThread.join();
78#endif
79}
80
81Error FDSimpleRemoteEPCTransport::start() {
82#if LLVM_ENABLE_THREADS
83 ListenerThread = std::thread([this]() { listenLoop(); });
84 return Error::success();
85#endif
86 llvm_unreachable("Should not be called with LLVM_ENABLE_THREADS=Off");
87}
88
89Error FDSimpleRemoteEPCTransport::sendMessage(SimpleRemoteEPCOpcode OpC,
90 uint64_t SeqNo,
91 ExecutorAddr TagAddr,
92 ArrayRef<char> ArgBytes) {
93 char HeaderBuffer[FDMsgHeader::Size];
94
95 *((support::ulittle64_t *)(HeaderBuffer + FDMsgHeader::MsgSizeOffset)) =
96 FDMsgHeader::Size + ArgBytes.size();
97 *((support::ulittle64_t *)(HeaderBuffer + FDMsgHeader::OpCOffset)) =
98 static_cast<uint64_t>(OpC);
99 *((support::ulittle64_t *)(HeaderBuffer + FDMsgHeader::SeqNoOffset)) = SeqNo;
100 *((support::ulittle64_t *)(HeaderBuffer + FDMsgHeader::TagAddrOffset)) =
101 TagAddr.getValue();
102
103 std::lock_guard<std::mutex> Lock(M);
104 if (Disconnected)
105 return make_error<StringError>(Args: "FD-transport disconnected",
106 Args: inconvertibleErrorCode());
107 if (int ErrNo = writeBytes(Src: HeaderBuffer, Size: FDMsgHeader::Size))
108 return errorCodeToError(EC: std::error_code(ErrNo, std::generic_category()));
109 if (int ErrNo = writeBytes(Src: ArgBytes.data(), Size: ArgBytes.size()))
110 return errorCodeToError(EC: std::error_code(ErrNo, std::generic_category()));
111 return Error::success();
112}
113
114void FDSimpleRemoteEPCTransport::disconnect() {
115 if (Disconnected)
116 return; // Return if already disconnected.
117
118 Disconnected = true;
119 bool CloseOutFD = InFD != OutFD;
120
121#ifndef _WIN32
122 // We need to shutdown the socket to wake up (and terminate) any ongoing
123 // blocking read on this FD. If the FD is not a socket, shutdown will just
124 // complain through errno (instead of crashing).
125 // FIXME: what about Windows?
126 ::shutdown(fd: InFD, how: CloseOutFD ? SHUT_RD : SHUT_RDWR);
127#endif
128 // Close InFD.
129 while (close(fd: InFD) == -1) {
130 if (errno == EBADF)
131 break;
132 }
133
134 // Close OutFD.
135 if (CloseOutFD) {
136#ifndef _WIN32
137 // FIXME: what about Windows?
138 ::shutdown(fd: OutFD, SHUT_WR);
139#endif
140 while (close(fd: OutFD) == -1) {
141 if (errno == EBADF)
142 break;
143 }
144 }
145}
146
147static Error makeUnexpectedEOFError() {
148 return make_error<StringError>(Args: "Unexpected end-of-file",
149 Args: inconvertibleErrorCode());
150}
151
152Error FDSimpleRemoteEPCTransport::readBytes(char *Dst, size_t Size,
153 bool *IsEOF) {
154 assert((Size == 0 || Dst) && "Attempt to read into null.");
155 ssize_t Completed = 0;
156 while (Completed < static_cast<ssize_t>(Size)) {
157 ssize_t Read = ::read(fd: InFD, buf: Dst + Completed, nbytes: Size - Completed);
158 if (Read <= 0) {
159 auto ErrNo = errno;
160 if (Read == 0) {
161 if (Completed == 0 && IsEOF) {
162 *IsEOF = true;
163 return Error::success();
164 } else
165 return makeUnexpectedEOFError();
166 } else if (ErrNo == EAGAIN || ErrNo == EINTR)
167 continue;
168 else {
169 std::lock_guard<std::mutex> Lock(M);
170 if (Disconnected && IsEOF) { // disconnect called, pretend this is EOF.
171 *IsEOF = true;
172 return Error::success();
173 }
174 return errorCodeToError(
175 EC: std::error_code(ErrNo, std::generic_category()));
176 }
177 }
178 Completed += Read;
179 }
180 return Error::success();
181}
182
183int FDSimpleRemoteEPCTransport::writeBytes(const char *Src, size_t Size) {
184 assert((Size == 0 || Src) && "Attempt to append from null.");
185 ssize_t Completed = 0;
186 while (Completed < static_cast<ssize_t>(Size)) {
187 ssize_t Written = ::write(fd: OutFD, buf: Src + Completed, n: Size - Completed);
188 if (Written < 0) {
189 auto ErrNo = errno;
190 if (ErrNo == EAGAIN || ErrNo == EINTR)
191 continue;
192 else
193 return ErrNo;
194 }
195 Completed += Written;
196 }
197 return 0;
198}
199
200void FDSimpleRemoteEPCTransport::listenLoop() {
201 Error Err = Error::success();
202 do {
203
204 char HeaderBuffer[FDMsgHeader::Size];
205 // Read the header buffer.
206 {
207 bool IsEOF = false;
208 if (auto Err2 = readBytes(Dst: HeaderBuffer, Size: FDMsgHeader::Size, IsEOF: &IsEOF)) {
209 Err = joinErrors(E1: std::move(Err), E2: std::move(Err2));
210 break;
211 }
212 if (IsEOF)
213 break;
214 }
215
216 // Decode header buffer.
217 uint64_t MsgSize;
218 SimpleRemoteEPCOpcode OpC;
219 uint64_t SeqNo;
220 ExecutorAddr TagAddr;
221
222 MsgSize =
223 *((support::ulittle64_t *)(HeaderBuffer + FDMsgHeader::MsgSizeOffset));
224 OpC = static_cast<SimpleRemoteEPCOpcode>(static_cast<uint64_t>(
225 *((support::ulittle64_t *)(HeaderBuffer + FDMsgHeader::OpCOffset))));
226 SeqNo =
227 *((support::ulittle64_t *)(HeaderBuffer + FDMsgHeader::SeqNoOffset));
228 TagAddr.setValue(
229 *((support::ulittle64_t *)(HeaderBuffer + FDMsgHeader::TagAddrOffset)));
230
231 if (MsgSize < FDMsgHeader::Size) {
232 Err = joinErrors(E1: std::move(Err),
233 E2: make_error<StringError>(Args: "Message size too small",
234 Args: inconvertibleErrorCode()));
235 break;
236 }
237
238 // Read the argument bytes.
239 auto ArgBytes =
240 shared::WrapperFunctionBuffer::allocate(Size: MsgSize - FDMsgHeader::Size);
241 if (auto Err2 = readBytes(Dst: ArgBytes.data(), Size: ArgBytes.size())) {
242 Err = joinErrors(E1: std::move(Err), E2: std::move(Err2));
243 break;
244 }
245
246 if (auto Action =
247 C.handleMessage(OpC, SeqNo, TagAddr, ArgBytes: std::move(ArgBytes))) {
248 if (*Action == SimpleRemoteEPCTransportClient::EndSession)
249 break;
250 } else {
251 Err = joinErrors(E1: std::move(Err), E2: Action.takeError());
252 break;
253 }
254 } while (true);
255
256 // Attempt to close FDs, set Disconnected to true so that subsequent
257 // sendMessage calls fail.
258 disconnect();
259
260 // Call up to the client to handle the disconnection.
261 C.handleDisconnect(Err: std::move(Err));
262}
263
264} // end namespace orc
265} // end namespace llvm
266