1 | //===-- llvm/Support/raw_socket_stream.cpp - Socket streams --*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file contains raw_ostream implementations for streams to communicate |
10 | // via UNIX sockets |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "llvm/Support/raw_socket_stream.h" |
15 | #include "llvm/Config/config.h" |
16 | #include "llvm/Support/Error.h" |
17 | #include "llvm/Support/FileSystem.h" |
18 | |
19 | #include <atomic> |
20 | #include <fcntl.h> |
21 | #include <functional> |
22 | #include <thread> |
23 | |
24 | #ifndef _WIN32 |
25 | #include <poll.h> |
26 | #include <sys/socket.h> |
27 | #include <sys/un.h> |
28 | #else |
29 | #include "llvm/Support/Windows/WindowsSupport.h" |
30 | // winsock2.h must be included before afunix.h. Briefly turn off clang-format to |
31 | // avoid error. |
32 | // clang-format off |
33 | #include <winsock2.h> |
34 | #include <afunix.h> |
35 | // clang-format on |
36 | #include <io.h> |
37 | #endif // _WIN32 |
38 | |
39 | #if defined(HAVE_UNISTD_H) |
40 | #include <unistd.h> |
41 | #endif |
42 | |
43 | using namespace llvm; |
44 | |
#ifdef _WIN32
// Requests Winsock version 2.2 for as long as this object is alive. A failure
// to start Winsock is treated as unrecoverable (report_fatal_error).
WSABalancer::WSABalancer() {
  WSADATA WsaData;
  ::memset(&WsaData, 0, sizeof(WsaData));
  const int StartupResult = WSAStartup(MAKEWORD(2, 2), &WsaData);
  if (StartupResult != 0)
    llvm::report_fatal_error("WSAStartup failed");
}

// Balances the WSAStartup call made by the constructor.
WSABalancer::~WSABalancer() { WSACleanup(); }
#endif // _WIN32
56 | |
57 | static std::error_code getLastSocketErrorCode() { |
58 | #ifdef _WIN32 |
59 | return std::error_code(::WSAGetLastError(), std::system_category()); |
60 | #else |
61 | return errnoAsErrorCode(); |
62 | #endif |
63 | } |
64 | |
65 | static sockaddr_un setSocketAddr(StringRef SocketPath) { |
66 | struct sockaddr_un Addr; |
67 | memset(s: &Addr, c: 0, n: sizeof(Addr)); |
68 | Addr.sun_family = AF_UNIX; |
69 | strncpy(dest: Addr.sun_path, src: SocketPath.str().c_str(), n: sizeof(Addr.sun_path) - 1); |
70 | return Addr; |
71 | } |
72 | |
73 | static Expected<int> getSocketFD(StringRef SocketPath) { |
74 | #ifdef _WIN32 |
75 | SOCKET Socket = socket(AF_UNIX, SOCK_STREAM, 0); |
76 | if (Socket == INVALID_SOCKET) { |
77 | #else |
78 | int Socket = socket(AF_UNIX, SOCK_STREAM, protocol: 0); |
79 | if (Socket == -1) { |
80 | #endif // _WIN32 |
81 | return llvm::make_error<StringError>(Args: getLastSocketErrorCode(), |
82 | Args: "Create socket failed" ); |
83 | } |
84 | |
85 | struct sockaddr_un Addr = setSocketAddr(SocketPath); |
86 | if (::connect(fd: Socket, addr: (struct sockaddr *)&Addr, len: sizeof(Addr)) == -1) |
87 | return llvm::make_error<StringError>(Args: getLastSocketErrorCode(), |
88 | Args: "Connect socket failed" ); |
89 | |
90 | #ifdef _WIN32 |
91 | return _open_osfhandle(Socket, 0); |
92 | #else |
93 | return Socket; |
94 | #endif // _WIN32 |
95 | } |
96 | |
97 | ListeningSocket::ListeningSocket(int SocketFD, StringRef SocketPath, |
98 | int PipeFD[2]) |
99 | : FD(SocketFD), SocketPath(SocketPath), PipeFD{PipeFD[0], PipeFD[1]} {} |
100 | |
101 | ListeningSocket::ListeningSocket(ListeningSocket &&LS) |
102 | : FD(LS.FD.load()), SocketPath(LS.SocketPath), |
103 | PipeFD{LS.PipeFD[0], LS.PipeFD[1]} { |
104 | |
105 | LS.FD = -1; |
106 | LS.SocketPath.clear(); |
107 | LS.PipeFD[0] = -1; |
108 | LS.PipeFD[1] = -1; |
109 | } |
110 | |
111 | Expected<ListeningSocket> ListeningSocket::createUnix(StringRef SocketPath, |
112 | int MaxBacklog) { |
113 | |
114 | // Handle instances where the target socket address already exists and |
115 | // differentiate between a preexisting file with and without a bound socket |
116 | // |
117 | // ::bind will return std::errc:address_in_use if a file at the socket address |
118 | // already exists (e.g., the file was not properly unlinked due to a crash) |
119 | // even if another socket has not yet binded to that address |
120 | if (llvm::sys::fs::exists(Path: SocketPath)) { |
121 | Expected<int> MaybeFD = getSocketFD(SocketPath); |
122 | if (!MaybeFD) { |
123 | |
124 | // Regardless of the error, notify the caller that a file already exists |
125 | // at the desired socket address and that there is no bound socket at that |
126 | // address. The file must be removed before ::bind can use the address |
127 | consumeError(Err: MaybeFD.takeError()); |
128 | return llvm::make_error<StringError>( |
129 | Args: std::make_error_code(e: std::errc::file_exists), |
130 | Args: "Socket address unavailable" ); |
131 | } |
132 | ::close(fd: std::move(*MaybeFD)); |
133 | |
134 | // Notify caller that the provided socket address already has a bound socket |
135 | return llvm::make_error<StringError>( |
136 | Args: std::make_error_code(e: std::errc::address_in_use), |
137 | Args: "Socket address unavailable" ); |
138 | } |
139 | |
140 | #ifdef _WIN32 |
141 | WSABalancer _; |
142 | SOCKET Socket = socket(AF_UNIX, SOCK_STREAM, 0); |
143 | if (Socket == INVALID_SOCKET) |
144 | #else |
145 | int Socket = socket(AF_UNIX, SOCK_STREAM, protocol: 0); |
146 | if (Socket == -1) |
147 | #endif |
148 | return llvm::make_error<StringError>(Args: getLastSocketErrorCode(), |
149 | Args: "socket create failed" ); |
150 | |
151 | struct sockaddr_un Addr = setSocketAddr(SocketPath); |
152 | if (::bind(fd: Socket, addr: (struct sockaddr *)&Addr, len: sizeof(Addr)) == -1) { |
153 | // Grab error code from call to ::bind before calling ::close |
154 | std::error_code EC = getLastSocketErrorCode(); |
155 | ::close(fd: Socket); |
156 | return llvm::make_error<StringError>(Args&: EC, Args: "Bind error" ); |
157 | } |
158 | |
159 | // Mark socket as passive so incoming connections can be accepted |
160 | if (::listen(fd: Socket, n: MaxBacklog) == -1) |
161 | return llvm::make_error<StringError>(Args: getLastSocketErrorCode(), |
162 | Args: "Listen error" ); |
163 | |
164 | int PipeFD[2]; |
165 | #ifdef _WIN32 |
166 | // Reserve 1 byte for the pipe and use default textmode |
167 | if (::_pipe(PipeFD, 1, 0) == -1) |
168 | #else |
169 | if (::pipe(pipedes: PipeFD) == -1) |
170 | #endif // _WIN32 |
171 | return llvm::make_error<StringError>(Args: getLastSocketErrorCode(), |
172 | Args: "pipe failed" ); |
173 | |
174 | #ifdef _WIN32 |
175 | return ListeningSocket{_open_osfhandle(Socket, 0), SocketPath, PipeFD}; |
176 | #else |
177 | return ListeningSocket{Socket, SocketPath, PipeFD}; |
178 | #endif // _WIN32 |
179 | } |
180 | |
181 | // If a file descriptor being monitored by ::poll is closed by another thread, |
182 | // the result is unspecified. In the case ::poll does not unblock and return, |
183 | // when ActiveFD is closed, you can provide another file descriptor via CancelFD |
184 | // that when written to will cause poll to return. Typically CancelFD is the |
185 | // read end of a unidirectional pipe. |
186 | // |
187 | // Timeout should be -1 to block indefinitly |
188 | // |
189 | // getActiveFD is a callback to handle ActiveFD's of std::atomic<int> and int |
190 | static std::error_code |
191 | manageTimeout(const std::chrono::milliseconds &Timeout, |
192 | const std::function<int()> &getActiveFD, |
193 | const std::optional<int> &CancelFD = std::nullopt) { |
194 | struct pollfd FD[2]; |
195 | FD[0].events = POLLIN; |
196 | #ifdef _WIN32 |
197 | SOCKET WinServerSock = _get_osfhandle(getActiveFD()); |
198 | FD[0].fd = WinServerSock; |
199 | #else |
200 | FD[0].fd = getActiveFD(); |
201 | #endif |
202 | uint8_t FDCount = 1; |
203 | if (CancelFD.has_value()) { |
204 | FD[1].events = POLLIN; |
205 | FD[1].fd = CancelFD.value(); |
206 | FDCount++; |
207 | } |
208 | |
209 | // Keep track of how much time has passed in case ::poll or WSAPoll are |
210 | // interupted by a signal and need to be recalled |
211 | auto Start = std::chrono::steady_clock::now(); |
212 | auto RemainingTimeout = Timeout; |
213 | int PollStatus = 0; |
214 | do { |
215 | // If Timeout is -1 then poll should block and RemainingTimeout does not |
216 | // need to be recalculated |
217 | if (PollStatus != 0 && Timeout != std::chrono::milliseconds(-1)) { |
218 | auto TotalElapsedTime = |
219 | std::chrono::duration_cast<std::chrono::milliseconds>( |
220 | d: std::chrono::steady_clock::now() - Start); |
221 | |
222 | if (TotalElapsedTime >= Timeout) |
223 | return std::make_error_code(e: std::errc::operation_would_block); |
224 | |
225 | RemainingTimeout = Timeout - TotalElapsedTime; |
226 | } |
227 | #ifdef _WIN32 |
228 | PollStatus = WSAPoll(FD, FDCount, RemainingTimeout.count()); |
229 | } while (PollStatus == SOCKET_ERROR && |
230 | getLastSocketErrorCode() == std::errc::interrupted); |
231 | #else |
232 | PollStatus = ::poll(fds: FD, nfds: FDCount, timeout: RemainingTimeout.count()); |
233 | } while (PollStatus == -1 && |
234 | getLastSocketErrorCode() == std::errc::interrupted); |
235 | #endif |
236 | |
237 | // If ActiveFD equals -1 or CancelFD has data to be read then the operation |
238 | // has been canceled by another thread |
239 | if (getActiveFD() == -1 || (CancelFD.has_value() && FD[1].revents & POLLIN)) |
240 | return std::make_error_code(e: std::errc::operation_canceled); |
241 | #if _WIN32 |
242 | if (PollStatus == SOCKET_ERROR) |
243 | #else |
244 | if (PollStatus == -1) |
245 | #endif |
246 | return getLastSocketErrorCode(); |
247 | if (PollStatus == 0) |
248 | return std::make_error_code(e: std::errc::timed_out); |
249 | if (FD[0].revents & POLLNVAL) |
250 | return std::make_error_code(e: std::errc::bad_file_descriptor); |
251 | return std::error_code(); |
252 | } |
253 | |
254 | Expected<std::unique_ptr<raw_socket_stream>> |
255 | ListeningSocket::accept(const std::chrono::milliseconds &Timeout) { |
256 | auto getActiveFD = [this]() -> int { return FD; }; |
257 | std::error_code TimeoutErr = manageTimeout(Timeout, getActiveFD, CancelFD: PipeFD[0]); |
258 | if (TimeoutErr) |
259 | return llvm::make_error<StringError>(Args&: TimeoutErr, Args: "Timeout error" ); |
260 | |
261 | int AcceptFD; |
262 | #ifdef _WIN32 |
263 | SOCKET WinAcceptSock = ::accept(_get_osfhandle(FD), NULL, NULL); |
264 | AcceptFD = _open_osfhandle(WinAcceptSock, 0); |
265 | #else |
266 | AcceptFD = ::accept(fd: FD, NULL, NULL); |
267 | #endif |
268 | |
269 | if (AcceptFD == -1) |
270 | return llvm::make_error<StringError>(Args: getLastSocketErrorCode(), |
271 | Args: "Socket accept failed" ); |
272 | return std::make_unique<raw_socket_stream>(args&: AcceptFD); |
273 | } |
274 | |
275 | void ListeningSocket::shutdown() { |
276 | int ObservedFD = FD.load(); |
277 | |
278 | if (ObservedFD == -1) |
279 | return; |
280 | |
281 | // If FD equals ObservedFD set FD to -1; If FD doesn't equal ObservedFD then |
282 | // another thread is responsible for shutdown so return |
283 | if (!FD.compare_exchange_strong(i1&: ObservedFD, i2: -1)) |
284 | return; |
285 | |
286 | ::close(fd: ObservedFD); |
287 | ::unlink(name: SocketPath.c_str()); |
288 | |
289 | // Ensure ::poll returns if shutdown is called by a separate thread |
290 | char Byte = 'A'; |
291 | ssize_t written = ::write(fd: PipeFD[1], buf: &Byte, n: 1); |
292 | |
293 | // Ignore any write() error |
294 | (void)written; |
295 | } |
296 | |
297 | ListeningSocket::~ListeningSocket() { |
298 | shutdown(); |
299 | |
300 | // Close the pipe's FDs in the destructor instead of within |
301 | // ListeningSocket::shutdown to avoid unnecessary synchronization issues that |
302 | // would occur as PipeFD's values would have to be changed to -1 |
303 | // |
304 | // The move constructor sets PipeFD to -1 |
305 | if (PipeFD[0] != -1) |
306 | ::close(fd: PipeFD[0]); |
307 | if (PipeFD[1] != -1) |
308 | ::close(fd: PipeFD[1]); |
309 | } |
310 | |
311 | //===----------------------------------------------------------------------===// |
312 | // raw_socket_stream |
313 | //===----------------------------------------------------------------------===// |
314 | |
315 | raw_socket_stream::raw_socket_stream(int SocketFD) |
316 | : raw_fd_stream(SocketFD, true) {} |
317 | |
318 | raw_socket_stream::~raw_socket_stream() {} |
319 | |
320 | Expected<std::unique_ptr<raw_socket_stream>> |
321 | raw_socket_stream::createConnectedUnix(StringRef SocketPath) { |
322 | #ifdef _WIN32 |
323 | WSABalancer _; |
324 | #endif // _WIN32 |
325 | Expected<int> FD = getSocketFD(SocketPath); |
326 | if (!FD) |
327 | return FD.takeError(); |
328 | return std::make_unique<raw_socket_stream>(args&: *FD); |
329 | } |
330 | |
331 | ssize_t raw_socket_stream::read(char *Ptr, size_t Size, |
332 | const std::chrono::milliseconds &Timeout) { |
333 | auto getActiveFD = [this]() -> int { return this->get_fd(); }; |
334 | std::error_code Err = manageTimeout(Timeout, getActiveFD); |
335 | // Mimic raw_fd_stream::read error handling behavior |
336 | if (Err) { |
337 | raw_fd_stream::error_detected(EC: Err); |
338 | return -1; |
339 | } |
340 | return raw_fd_stream::read(Ptr, Size); |
341 | } |
342 | |