1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 module thrift.transport.socket;
20 
21 import core.stdc.errno: ECONNRESET;
22 import core.thread : Thread;
23 import core.time : dur, Duration;
24 import std.array : empty;
25 import std.conv : text, to;
26 import std.exception : enforce;
27 import std.socket;
28 import thrift.base;
29 import thrift.transport.base;
30 import thrift.internal.socket;
31 
/**
 * Shared groundwork for socket-based TTransport implementations, independent
 * of whether the actual I/O is carried out synchronously or asynchronously.
 *
 * Stores the connection target (host/port), the send/receive timeouts and the
 * wrapped std.socket object; subclasses implement the transport operations.
 */
abstract class TSocketBase : TBaseTransport {
  /**
   * Wraps a socket that has already been created and connected (!), and
   * applies the options from setSocketOpts() to it.
   *
   * Params:
   *   socket = Existing, connected socket object.
   */
  this(Socket socket) {
    this.socket_ = socket;
    this.setSocketOpts();
  }

  /**
   * Sets up an unconnected transport targeting the given host and port. No
   * socket object is created by this constructor.
   *
   * Params:
   *   host = Remote host.
   *   port = Remote port.
   */
  this(string host, ushort port) {
    this.port_ = port;
    this.host_ = host;
  }

  /**
   * Whether a socket object is currently held, i.e. the transport is
   * connected.
   */
  override bool isOpen() @property {
    return !(socket_ is null);
  }

  /**
   * Performs a single write to the socket, transferring as much of buf as one
   * OS call accepts.
   *
   * Params:
   *   buf = Data to write.
   *
   * Returns: The number of bytes actually written; never more than
   *   buf.length.
   */
  abstract size_t writeSome(in ubyte[] buf) out (written) {
    // The postcondition below is disabled: with DMD, buf.length reads as 0
    // inside this out contract (e.g. in the async_test_server) although the
    // method body sees the real length (27 there, matching the return value).
    version (none) assert(written <= buf.length, text("Implementation wrote " ~
      "more data than requested to?! (", written, " vs. ", buf.length, ")"));
  } body {
    // DMD requires a body once an abstract method carries a contract; this
    // body is never meant to be executed.
    assert(0, "DMD bug? – Why would contracts work for interfaces, but not " ~
      "for abstract methods? " ~
      "(Error: function […] in and out contracts require function body");
  }

  /**
   * The address of the peer this socket is actually connected to, queried
   * from the socket on first use and cached afterwards.
   *
   * Unlike the host and port properties, which hold the address used to set
   * up the connection and are not updated afterwards, this reflects the real
   * remote end.
   *
   * The socket must be open when calling this.
   *
   * Throws: TTransportException (NOT_OPEN) if the socket is closed.
   */
  Address getPeerAddress() {
    if (!isOpen) {
      throw new TTransportException("Cannot get peer host for " ~
        "closed socket.", TTransportException.Type.NOT_OPEN);
    }

    if (peerAddress_ is null) {
      peerAddress_ = socket_.remoteAddress();
      assert(peerAddress_);
    }
    return peerAddress_;
  }

  /**
   * Host the socket is connected to or will connect to; null when the object
   * was constructed from an already connected socket.
   */
  string host() const @property {
    return this.host_;
  }

  /**
   * Port the socket is connected to or will connect to; zero when the object
   * was constructed from an already connected socket.
   */
  ushort port() const @property {
    return this.port_;
  }

  /// The socket send timeout.
  Duration sendTimeout() const @property {
    return this.sendTimeout_;
  }

  /// Ditto
  void sendTimeout(Duration value) @property {
    this.sendTimeout_ = value;
  }

  /// The socket receiving timeout. Values smaller than 500 ms are not
  /// supported on Windows.
  Duration recvTimeout() const @property {
    return this.recvTimeout_;
  }

  /// Ditto
  void recvTimeout(Duration value) @property {
    this.recvTimeout_ = value;
  }

  /**
   * The OS-level handle of the wrapped socket.
   *
   * Normally not needed, but it can be required when interfacing with C
   * libraries. NOTE(review): socket_ is dereferenced without a check, so this
   * presumably must only be called while the socket is open.
   */
  typeof(socket_.handle()) socketHandle() @property {
    return this.socket_.handle();
  }

protected:
  /**
   * Applies the socket options this transport relies on: SO_LINGER switched
   * off and, where possible, TCP_NODELAY switched on.
   */
  void setSocketOpts() {
    // A failure to configure SO_LINGER is logged but not fatal.
    try {
      Linger noLinger;
      noLinger.on = 0;
      noLinger.time = 0;
      socket_.setOption(SocketOptionLevel.SOCKET, SocketOption.LINGER,
        noLinger);
    } catch (SocketException e) {
      logError("Could not set socket option: %s", e);
    }

    // Disabling Nagle's algorithm is best effort only: it fails for non-TCP
    // sockets handed to the Socket-accepting constructor, and that failure is
    // deliberately ignored.
    try {
      socket_.setOption(SocketOptionLevel.TCP, SocketOption.TCP_NODELAY, true);
    } catch (SocketException e) {}
  }

  /// Remote host the transport targets (null if built from a socket).
  string host_;

  /// Remote port the transport targets (0 if built from a socket).
  ushort port_;

  /// Send timeout as last set through the sendTimeout property.
  Duration sendTimeout_;

  /// Receive timeout as last set through the recvTimeout property.
  Duration recvTimeout_;

  /// Peer address cached by getPeerAddress().
  Address peerAddress_;

  /// Cached peer host name.
  string peerHost_;

  /// Cached peer port.
  ushort peerPort_;

  /// The wrapped socket object; null while the transport is closed.
  Socket socket_;
}
201 
/**
 * Socket implementation of the TTransport interface.
 *
 * Due to the limitations of std.socket, currently only TCP/IP sockets are
 * supported (i.e. Unix domain sockets are not).
 */
class TSocket : TSocketBase {
  ///
  this(Socket socket) {
    super(socket);
  }

  ///
  this(string host, ushort port) {
    super(host, port);
  }

  // Overriding only the timeout setters below would hide the inherited
  // getters for TSocket references (D overload-set hiding rules), so the base
  // class overload sets are re-exposed explicitly.
  alias TSocketBase.sendTimeout sendTimeout;
  alias TSocketBase.recvTimeout recvTimeout; /// Ditto

  /**
   * Connects the socket.
   *
   * The host is resolved, and every resulting address is tried in turn until
   * one connection succeeds. Does nothing if the socket is already open.
   *
   * Throws: TTransportException (NOT_OPEN) if no host or a zero port is set,
   *   if the host cannot be resolved, or if connecting to all resolved
   *   addresses failed (the individual errors are chained as a
   *   TCompoundOperationException).
   */
  override void open() {
    if (isOpen) return;

    enforce(!host_.empty, new TTransportException(
      "Cannot open socket to null host.", TTransportException.Type.NOT_OPEN));
    enforce(port_ != 0, new TTransportException(
      "Cannot open socket to port zero.", TTransportException.Type.NOT_OPEN));

    Address[] addrs;
    try {
      addrs = getAddress(host_, port_);
    } catch (SocketException e) {
      throw new TTransportException("Could not resolve given host string.",
        TTransportException.Type.NOT_OPEN, __FILE__, __LINE__, e);
    }

    Exception[] errors;
    foreach (addr; addrs) {
      try {
        socket_ = new TcpSocket(addr.addressFamily);
        // If configuring or connecting fails with any exception (including a
        // TTransportException from setSocketOpts()), release the OS socket
        // instead of leaking it, and reset socket_ so that isOpen does not
        // report a half-set-up, unconnected socket as open.
        scope (failure) {
          socket_.close();
          socket_ = null;
        }
        setSocketOpts();
        socket_.connect(addr);
        break;
      } catch (SocketException e) {
        errors ~= e;
      }
    }

    // socket_ is only left non-null by a successful connect().
    if (socket_ is null) {
      // Need to throw a TTransportException to abide the TTransport API.
      import std.algorithm : joiner, map;
      import std.range : zip;
      throw new TTransportException(
        text("Failed to connect to ", host_, ":", port_, "."),
        TTransportException.Type.NOT_OPEN,
        __FILE__, __LINE__,
        new TCompoundOperationException(
          text(
            "All addresses tried failed (",
            joiner(map!(p => text(p[0], `: "`, p[1].msg, `"`))(
              zip(addrs, errors)), ", "),
            ")."
          ),
          errors
        )
      );
    }
  }

  /**
   * Closes the socket. Does nothing if it is not open.
   */
  override void close() {
    if (!isOpen) return;

    socket_.close();
    socket_ = null;
    // The cached peer address belongs to the connection just closed; drop it
    // so getPeerAddress() does not report a stale peer after a reopen.
    peerAddress_ = null;
  }

  /**
   * Checks, without consuming any data, whether at least one byte can be
   * received.
   *
   * Returns: false if the socket is not open or the receive call returned 0;
   *   true if data is available.
   *
   * Throws: TTransportException (UNKNOWN) if peeking fails. If
   *   connresetOnPeerShutdown is set, ECONNRESET instead closes the socket
   *   and yields false.
   */
  override bool peek() {
    if (!isOpen) return false;

    ubyte buf;
    auto r = socket_.receive((&buf)[0 .. 1], SocketFlags.PEEK);
    if (r == -1) {
      auto lastErrno = getSocketErrno();
      static if (connresetOnPeerShutdown) {
        if (lastErrno == ECONNRESET) {
          close();
          return false;
        }
      }
      throw new TTransportException("Peeking into socket failed: " ~
        socketErrnoString(lastErrno), TTransportException.Type.UNKNOWN);
    }
    return (r > 0);
  }

  /**
   * Receives up to buf.length bytes into buf.
   *
   * Interrupted receive calls are retried up to maxRecvRetries times.
   *
   * Returns: The number of bytes received (0 on an orderly peer shutdown).
   *
   * Throws: TTransportException: NOT_OPEN if the socket is not open or the
   *   error indicates a closed connection (the socket is closed then),
   *   TIMED_OUT on a receive timeout, UNKNOWN otherwise.
   */
  override size_t read(ubyte[] buf) {
    enforce(isOpen, new TTransportException(
      "Cannot read if socket is not open.", TTransportException.Type.NOT_OPEN));

    typeof(getSocketErrno()) lastErrno;
    // One initial attempt plus up to maxRecvRetries_ retries. The bound is
    // computed in int, so maxRecvRetries_ == ushort.max cannot make a
    // ushort counter wrap around and loop forever.
    foreach (attempt; 0 .. maxRecvRetries_ + 1) {
      auto r = socket_.receive(cast(void[])buf);

      // If recv went fine, immediately return.
      if (r >= 0) return r;

      // Something went wrong, find out how to handle it.
      lastErrno = getSocketErrno();

      if (lastErrno == INTERRUPTED_ERRNO) {
        // If the syscall was interrupted, just try again.
        continue;
      }

      static if (connresetOnPeerShutdown) {
        // Some platforms report a peer shutdown as ECONNRESET; treat it as
        // end of stream.
        if (lastErrno == ECONNRESET) {
          return 0;
        }
      }

      // Not an error which is handled in a special way, just leave the loop.
      break;
    }

    if (isSocketCloseErrno(lastErrno)) {
      close();
      throw new TTransportException("Receiving failed, closing socket: " ~
        socketErrnoString(lastErrno), TTransportException.Type.NOT_OPEN);
    } else if (lastErrno == TIMEOUT_ERRNO) {
      throw new TTransportException(TTransportException.Type.TIMED_OUT);
    } else {
      throw new TTransportException("Receiving from socket failed: " ~
        socketErrnoString(lastErrno), TTransportException.Type.UNKNOWN);
    }
  }

  /**
   * Writes all of buf, calling writeSome() repeatedly as needed.
   *
   * Throws: TTransportException (TIMED_OUT) if writeSome() makes no progress,
   *   plus anything writeSome() throws.
   */
  override void write(in ubyte[] buf) {
    size_t sent;
    while (sent < buf.length) {
      auto b = writeSome(buf[sent .. $]);
      if (b == 0) {
        // This should only happen if the timeout set with SO_SNDTIMEO expired.
        throw new TTransportException("send() timeout expired.",
          TTransportException.Type.TIMED_OUT);
      }
      sent += b;
    }
    assert(sent == buf.length);
  }

  /**
   * Performs a single send() of buf.
   *
   * Returns: The number of bytes sent, or 0 if the call would have blocked.
   *
   * Throws: TTransportException: NOT_OPEN if the socket is not open or the
   *   error indicates a closed connection (the socket is closed then),
   *   UNKNOWN for other failures.
   */
  override size_t writeSome(in ubyte[] buf) {
    enforce(isOpen, new TTransportException(
      "Cannot write if socket is not open.", TTransportException.Type.NOT_OPEN));

    auto r = socket_.send(buf);

    // Everything went well, just return the number of bytes written.
    if (r > 0) return r;

    // Handle error conditions.
    if (r < 0) {
      auto lastErrno = getSocketErrno();

      if (lastErrno == WOULD_BLOCK_ERRNO) {
        // Not an exceptional error per se – even with blocking sockets,
        // EAGAIN apparently is returned sometimes on out-of-resource
        // conditions (see the C++ implementation for details). Also, this
        // allows using TSocket with non-blocking sockets e.g. in
        // TNonblockingServer.
        return 0;
      }

      auto type = TTransportException.Type.UNKNOWN;
      if (isSocketCloseErrno(lastErrno)) {
        type = TTransportException.Type.NOT_OPEN;
        close();
      }

      throw new TTransportException("Sending to socket failed: " ~
        socketErrnoString(lastErrno), type);
    }

    // send() should never return 0.
    throw new TTransportException("Sending to socket failed (0 bytes written).",
      TTransportException.Type.UNKNOWN);
  }

  /// Stores the send timeout and applies it to the socket (SO_SNDTIMEO) if
  /// one is open.
  override void sendTimeout(Duration value) @property {
    super.sendTimeout(value);
    setTimeout(SocketOption.SNDTIMEO, value);
  }

  /// Stores the receive timeout and applies it to the socket (SO_RCVTIMEO) if
  /// one is open.
  override void recvTimeout(Duration value) @property {
    super.recvTimeout(value);
    setTimeout(SocketOption.RCVTIMEO, value);
  }

  /**
   * Maximum number of retries for receiving from socket on read() when the
   * receive call is interrupted (INTERRUPTED_ERRNO).
   */
  ushort maxRecvRetries() @property const {
    return maxRecvRetries_;
  }

  /// Ditto
  void maxRecvRetries(ushort value) @property {
    maxRecvRetries_ = value;
  }

  /// Ditto
  enum DEFAULT_MAX_RECV_RETRIES = 5;

protected:
  /// Applies the base socket options plus the configured send/receive
  /// timeouts.
  override void setSocketOpts() {
    super.setSocketOpts();
    setTimeout(SocketOption.SNDTIMEO, sendTimeout_);
    setTimeout(SocketOption.RCVTIMEO, recvTimeout_);
  }

  /**
   * Applies a send (SNDTIMEO) or receive (RCVTIMEO) timeout to the socket,
   * if one is open.
   *
   * Throws: TTransportException (UNKNOWN) if the option cannot be set.
   */
  void setTimeout(SocketOption type, Duration value) {
    assert(type == SocketOption.SNDTIMEO || type == SocketOption.RCVTIMEO);
    // `Windows` covers both 32- and 64-bit Windows; `Win32` alone would skip
    // this warning on 64-bit builds.
    version (Windows) {
      if (value > dur!"hnsecs"(0) && value < dur!"msecs"(500)) {
        logError(
          "Socket %s timeout of %s ms might be raised to 500 ms on Windows.",
          (type == SocketOption.SNDTIMEO) ? "send" : "receive",
          value.total!"msecs"
        );
      }
    }

    if (socket_) {
      try {
        socket_.setOption(SocketOptionLevel.SOCKET, type, value);
      } catch (SocketException e) {
        throw new TTransportException(
          "Could not set timeout.",
          TTransportException.Type.UNKNOWN,
          __FILE__,
          __LINE__,
          e
        );
      }
    }
  }

  /// Maximum number of recv() retries.
  ushort maxRecvRetries_  = DEFAULT_MAX_RECV_RETRIES;
}