1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 module thrift.transport.socket;
20 
21 import core.thread : Thread;
22 import core.time : dur, Duration;
23 import std.array : empty;
24 import std.conv : text, to;
25 import std.exception : enforce;
26 import std.socket;
27 import thrift.base;
28 import thrift.transport.base;
29 import thrift.internal.socket;
30 
// Make the "connection reset by peer" error code available under the single
// name ECONNRESET on every supported platform; peek() and read() below compare
// the last socket error against it. On Windows the Winsock constant is
// aliased, on Posix the errno constant is imported directly. Any other
// platform is rejected at compile time.
version (Windows) {
  import core.sys.windows.winsock2 : WSAECONNRESET;
  enum ECONNRESET = WSAECONNRESET;
} else version (Posix) {
  import core.stdc.errno : ECONNRESET;
} else static assert(0, "Don't know ECONNRESET on this platform.");
37 
38 /**
39  * Common parts of a socket TTransport implementation, regardless of how the
40  * actual I/O is performed (sync/async).
41  */
abstract class TSocketBase : TBaseTransport {
  /**
   * Constructor that takes an already created, connected (!) socket.
   *
   * The options from setSocketOpts() are applied to the socket right away.
   * host and port are left at their defaults (null / zero) in this case.
   *
   * Params:
   *   socket = Already created, connected socket object.
   */
  this(Socket socket) {
    socket_ = socket;
    setSocketOpts();
  }

  /**
   * Creates a new unconnected socket that will connect to the given host
   * on the given port.
   *
   * No socket object is created here; this only records the target address.
   *
   * Params:
   *   host = Remote host.
   *   port = Remote port.
   */
  this(string host, ushort port) {
    host_ = host;
    port_ = port;
  }

  /**
   * Checks whether the socket is connected.
   *
   * This is a pure bookkeeping check: the transport counts as open exactly
   * while it holds a socket object.
   */
  override bool isOpen() @property {
    return socket_ !is null;
  }

  /**
   * Writes as much data to the socket as there can be in a single OS call.
   *
   * Params:
   *   buf = Data to write.
   *
   * Returns: The actual number of bytes written. Never more than buf.length.
   */
  abstract size_t writeSome(in ubyte[] buf) out (written) {
    // DMD @@BUG@@: Enabling this e.g. fails the contract in the
    // async_test_server, because buf.length evaluates to 0 here, even though
    // in the method body it correctly is 27 (equal to the return value).
    version (none) assert(written <= buf.length, text("Implementation wrote " ~
      "more data than requested to?! (", written, " vs. ", buf.length, ")"));
  } body {
    assert(0, "DMD bug? – Why would contracts work for interfaces, but not " ~
      "for abstract methods? " ~
      "(Error: function […] in and out contracts require function body");
  }

  /**
   * Returns the actual address of the peer the socket is connected to.
   *
   * In contrast, the host and port properties contain the address used to
   * establish the connection, and are not updated after the connection.
   *
   * The address is queried from the socket on the first call and cached in
   * peerAddress_ afterwards.
   *
   * The socket must be open when calling this.
   *
   * Throws: TTransportException (NOT_OPEN) if the socket is not open.
   */
  Address getPeerAddress() {
    enforce(isOpen, new TTransportException("Cannot get peer host for " ~
      "closed socket.", TTransportException.Type.NOT_OPEN));

    if (!peerAddress_) {
      peerAddress_ = socket_.remoteAddress();
      assert(peerAddress_);
    }

    return peerAddress_;
  }

  /**
   * The host the socket is connected to or will connect to. Null if an
   * already connected socket was used to construct the object.
   */
  string host() const @property {
    return host_;
  }

  /**
   * The port the socket is connected to or will connect to. Zero if an
   * already connected socket was used to construct the object.
   */
  ushort port() const @property {
    return port_;
  }

  /// The socket send timeout.
  Duration sendTimeout() const @property {
    return sendTimeout_;
  }

  /// Ditto
  void sendTimeout(Duration value) @property {
    sendTimeout_ = value;
  }

  /// The socket receiving timeout. Values smaller than 500 ms are not
  /// supported on Windows.
  Duration recvTimeout() const @property {
    return recvTimeout_;
  }

  /// Ditto
  void recvTimeout(Duration value) @property {
    recvTimeout_ = value;
  }

  /**
   * Returns the OS handle of the underlying socket.
   *
   * Should not usually be used directly, but access to it can be necessary
   * to interface with C libraries.
   *
   * Throws: TTransportException (NOT_OPEN) if there currently is no
   *   underlying socket, i.e. the transport has not been opened yet or has
   *   been closed.
   */
  typeof(socket_.handle()) socketHandle() @property {
    // Without this check a closed transport would dereference a null socket
    // reference and crash instead of reporting a usable error. The field is
    // tested directly (rather than isOpen) so that the check guards exactly
    // the dereference below, whatever a subclass does in isOpen.
    enforce(socket_ !is null, new TTransportException("Cannot get socket " ~
      "handle for closed socket.", TTransportException.Type.NOT_OPEN));
    return socket_.handle();
  }

protected:
  /**
   * Sets the needed socket options.
   *
   * Turns SO_LINGER off and tries to enable TCP_NODELAY. Both are best-effort:
   * a SocketException is logged (linger) or ignored (no-delay), never
   * propagated. Must only be called while socket_ is set.
   */
  void setSocketOpts() {
    try {
      alias SocketOptionLevel.SOCKET lvlSock;
      Linger l;
      l.on = 0;
      l.time = 0;
      socket_.setOption(lvlSock, SocketOption.LINGER, l);
    } catch (SocketException e) {
      logError("Could not set socket option: %s", e);
    }

    // Just try to disable Nagle's algorithm – this will fail if we are passed
    // in a non-TCP socket via the Socket-accepting constructor.
    try {
      socket_.setOption(SocketOptionLevel.TCP, SocketOption.TCP_NODELAY, true);
    } catch (SocketException e) {}
  }

  /// Remote host.
  string host_;

  /// Remote port.
  ushort port_;

  /// Timeout for sending.
  Duration sendTimeout_;

  /// Timeout for receiving.
  Duration recvTimeout_;

  /// Cached peer address.
  Address peerAddress_;

  /// Cached peer host name.
  string peerHost_;

  /// Cached peer port.
  ushort peerPort_;

  /// Wrapped socket object.
  Socket socket_;
}
207 
208 /**
209  * Socket implementation of the TTransport interface.
210  *
211  * Due to the limitations of std.socket, currently only TCP/IP sockets are
212  * supported (i.e. Unix domain sockets are not).
213  */
class TSocket : TSocketBase {
  /// Wraps an already created, connected socket (see TSocketBase).
  this(Socket socket) {
    super(socket);
  }

  /// Records host and port to connect to on open() (see TSocketBase).
  this(string host, ushort port) {
    super(host, port);
  }

  /**
   * Connects the socket.
   *
   * Does nothing if the socket is already open. Otherwise the host is
   * resolved and every returned address is tried in order until one
   * connection attempt succeeds.
   *
   * Throws: TTransportException (NOT_OPEN) if host or port are not set, the
   *   host cannot be resolved, or all connection attempts failed. In the
   *   latter case, the individual errors are attached as a
   *   TCompoundOperationException. If anything is thrown, the transport is
   *   left closed.
   */
  override void open() {
    if (isOpen) return;

    enforce(!host_.empty, new TTransportException(
      "Cannot open socket to null host.", TTransportException.Type.NOT_OPEN));
    enforce(port_ != 0, new TTransportException(
      "Cannot open socket to port zero.", TTransportException.Type.NOT_OPEN));

    Address[] addrs;
    try {
      addrs = getAddress(host_, port_);
    } catch (SocketException e) {
      throw new TTransportException("Could not resolve given host string.",
        TTransportException.Type.NOT_OPEN, __FILE__, __LINE__, e);
    }

    // The peer of a previous connection must not be reported for the new one.
    peerAddress_ = null;

    Exception[] errors;
    foreach (addr; addrs) {
      try {
        socket_ = new TcpSocket(addr.addressFamily);
        // Whatever goes wrong from here on (a SocketException from connect(),
        // but also e.g. the TTransportException setTimeout() throws), do not
        // keep a half-initialized socket around: release the OS handle and
        // make sure isOpen does not report an unconnected socket as open.
        scope (failure) {
          socket_.close();
          socket_ = null;
        }
        setSocketOpts();
        socket_.connect(addr);
        break;
      } catch (SocketException e) {
        errors ~= e;
      }
    }
    if (errors.length == addrs.length) {
      socket_ = null;
      // Need to throw a TTransportException to abide the TTransport API.
      import std.algorithm, std.range;
      throw new TTransportException(
        text("Failed to connect to ", host_, ":", port_, "."),
        TTransportException.Type.NOT_OPEN,
        __FILE__, __LINE__,
        new TCompoundOperationException(
          text(
            "All addresses tried failed (",
            joiner(map!q{text(a[0], `: "`, a[1].msg, `"`)}(zip(addrs, errors)), ", "),
            ")."
          ),
          errors
        )
      );
    }
  }

  /**
   * Closes the socket.
   *
   * Does nothing if the socket is not open. The cached peer address is
   * dropped as well, so that it is looked up again after a later open().
   */
  override void close() {
    if (!isOpen) return;

    socket_.close();
    socket_ = null;
    peerAddress_ = null;
  }

  /**
   * Checks whether there is data to be read, by receiving a single byte with
   * the PEEK flag set (i.e. without consuming it).
   *
   * Returns: true if at least one byte could be peeked at; false if the
   *   socket is not open or the receive call returned zero bytes.
   *
   * Throws: TTransportException (UNKNOWN) if the receive call failed.
   */
  override bool peek() {
    if (!isOpen) return false;

    ubyte buf;
    auto r = socket_.receive((&buf)[0 .. 1], SocketFlags.PEEK);
    if (r == -1) {
      auto lastErrno = getSocketErrno();
      // connresetOnPeerShutdown comes from thrift.internal.socket; where it is
      // set, a connection reset is treated as the peer having closed the
      // connection rather than as an error.
      static if (connresetOnPeerShutdown) {
        if (lastErrno == ECONNRESET) {
          close();
          return false;
        }
      }
      throw new TTransportException("Peeking into socket failed: " ~
        socketErrnoString(lastErrno), TTransportException.Type.UNKNOWN);
    }
    return (r > 0);
  }

  /**
   * Receives data from the socket into the passed buffer.
   *
   * If the receive call is interrupted, it is retried up to maxRecvRetries
   * times.
   *
   * Params:
   *   buf = Buffer to store the received data in.
   *
   * Returns: The number of bytes received, at most buf.length.
   *
   * Throws: TTransportException of type NOT_OPEN if the socket is not open
   *   or had to be closed because of the error, TIMED_OUT if the receive
   *   timeout expired, UNKNOWN on any other failure.
   */
  override size_t read(ubyte[] buf) {
    enforce(isOpen, new TTransportException(
      "Cannot read if socket is not open.", TTransportException.Type.NOT_OPEN));

    typeof(getSocketErrno()) lastErrno;
    // Deliberately wider than maxRecvRetries_: a ushort counter would wrap
    // around and never terminate the loop for maxRecvRetries == ushort.max.
    uint tries;
    while (tries++ <= maxRecvRetries_) {
      auto r = socket_.receive(cast(void[])buf);

      // If recv went fine, immediately return.
      if (r >= 0) return r;

      // Something went wrong, find out how to handle it.
      lastErrno = getSocketErrno();

      if (lastErrno == INTERRUPTED_ERRNO) {
        // If the syscall was interrupted, just try again.
        continue;
      }

      static if (connresetOnPeerShutdown) {
        // Where this flag (from thrift.internal.socket) is set, a connection
        // reset is reported like a regular end of stream.
        if (lastErrno == ECONNRESET) {
          return 0;
        }
      }

      // Not an error which is handled in a special way, just leave the loop.
      break;
    }

    if (isSocketCloseErrno(lastErrno)) {
      close();
      throw new TTransportException("Receiving failed, closing socket: " ~
        socketErrnoString(lastErrno), TTransportException.Type.NOT_OPEN);
    } else if (lastErrno == TIMEOUT_ERRNO) {
      throw new TTransportException(TTransportException.Type.TIMED_OUT);
    } else {
      throw new TTransportException("Receiving from socket failed: " ~
        socketErrnoString(lastErrno), TTransportException.Type.UNKNOWN);
    }
  }

  /**
   * Writes the whole buffer to the socket, calling writeSome() as often as
   * needed.
   *
   * Params:
   *   buf = Data to write.
   *
   * Throws: TTransportException of type TIMED_OUT if writeSome() could not
   *   write any data, plus everything writeSome() throws.
   */
  override void write(in ubyte[] buf) {
    size_t sent;
    while (sent < buf.length) {
      auto b = writeSome(buf[sent .. $]);
      if (b == 0) {
        // This should only happen if the timeout set with SO_SNDTIMEO expired.
        throw new TTransportException("send() timeout expired.",
          TTransportException.Type.TIMED_OUT);
      }
      sent += b;
    }
    assert(sent == buf.length);
  }

  /**
   * Writes as much data to the socket as possible with a single send call.
   *
   * Params:
   *   buf = Data to write.
   *
   * Returns: The number of bytes written, never more than buf.length. Zero
   *   if buf is empty or the send call would have blocked.
   *
   * Throws: TTransportException of type NOT_OPEN if the socket is not open
   *   or had to be closed because of the error, UNKNOWN on other failures.
   */
  override size_t writeSome(in ubyte[] buf) {
    enforce(isOpen, new TTransportException(
      "Cannot write if socket is not open.", TTransportException.Type.NOT_OPEN));

    // Nothing to send: do not mistake the resulting zero-byte send for a
    // failure further down.
    if (buf.length == 0) return 0;

    auto r = socket_.send(buf);

    // Everything went well, just return the number of bytes written.
    if (r > 0) return r;

    // Handle error conditions.
    if (r < 0) {
      auto lastErrno = getSocketErrno();

      if (lastErrno == WOULD_BLOCK_ERRNO) {
        // Not an exceptional error per se – even with blocking sockets,
        // EAGAIN apparently is returned sometimes on out-of-resource
        // conditions (see the C++ implementation for details). Also, this
        // allows using TSocket with non-blocking sockets e.g. in
        // TNonblockingServer.
        return 0;
      }

      auto type = TTransportException.Type.UNKNOWN;
      if (isSocketCloseErrno(lastErrno)) {
        type = TTransportException.Type.NOT_OPEN;
        close();
      }

      throw new TTransportException("Sending to socket failed: " ~
        socketErrnoString(lastErrno), type);
    }

    // send() should never return 0 for a non-empty buffer.
    throw new TTransportException("Sending to socket failed (0 bytes written).",
      TTransportException.Type.UNKNOWN);
  }

  /// Stores the send timeout and, if a socket exists, applies it right away.
  override void sendTimeout(Duration value) @property {
    super.sendTimeout(value);
    setTimeout(SocketOption.SNDTIMEO, value);
  }

  /// Stores the receive timeout and, if a socket exists, applies it right
  /// away.
  override void recvTimeout(Duration value) @property {
    super.recvTimeout(value);
    setTimeout(SocketOption.RCVTIMEO, value);
  }

  /**
   * Maximum number of retries for receiving from socket on read() in case
   * the receive call was interrupted (EINTR).
   */
  ushort maxRecvRetries() @property const {
    return maxRecvRetries_;
  }

  /// Ditto
  void maxRecvRetries(ushort value) @property {
    maxRecvRetries_ = value;
  }

  /// Ditto
  enum DEFAULT_MAX_RECV_RETRIES = 5;

protected:
  /**
   * Applies the base class socket options plus the currently configured send
   * and receive timeouts to the socket.
   */
  override void setSocketOpts() {
    super.setSocketOpts();
    setTimeout(SocketOption.SNDTIMEO, sendTimeout_);
    setTimeout(SocketOption.RCVTIMEO, recvTimeout_);
  }

  /**
   * Sets a timeout option on the socket, if there currently is one.
   *
   * Params:
   *   type = SocketOption.SNDTIMEO or SocketOption.RCVTIMEO.
   *   value = The timeout to set.
   *
   * Throws: TTransportException (UNKNOWN) if setting the option failed.
   */
  void setTimeout(SocketOption type, Duration value) {
    assert(type == SocketOption.SNDTIMEO || type == SocketOption.RCVTIMEO);
    // version (Windows) rather than version (Win32): the latter is not set
    // on 64-bit Windows, where the same limitation is warned about.
    version (Windows) {
      if (value > dur!"hnsecs"(0) && value < dur!"msecs"(500)) {
        logError(
          "Socket %s timeout of %s ms might be raised to 500 ms on Windows.",
          (type == SocketOption.SNDTIMEO) ? "send" : "receive",
          value.total!"msecs"
        );
      }
    }

    if (socket_) {
      try {
        socket_.setOption(SocketOptionLevel.SOCKET, type, value);
      } catch (SocketException e) {
        throw new TTransportException(
          "Could not set timeout.",
          TTransportException.Type.UNKNOWN,
          __FILE__,
          __LINE__,
          e
        );
      }
    }
  }

  /// Maximum number of recv() retries.
  ushort maxRecvRetries_  = DEFAULT_MAX_RECV_RETRIES;
}