1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 module thrift.server.transport.socket;
20 
21 import core.thread : dur, Duration, Thread;
22 import core.stdc..string : strerror;
23 import std.array : empty;
24 import std.conv : text, to;
25 import std.exception : enforce;
26 import std.socket;
27 import thrift.base;
28 import thrift.internal.socket;
29 import thrift.server.transport.base;
30 import thrift.transport.base;
31 import thrift.transport.socket;
32 import thrift.util.awaitable;
33 import thrift.util.cancellation;
34 
// Short module-local name for the exception type thrown by this transport.
private alias STE = TServerTransportException;
36 
37 /**
38  * Server socket implementation of TServerTransport.
39  *
40  * Maps to std.socket listen()/accept(); only provides TCP/IP sockets (i.e. no
41  * Unix sockets) for now, because they are not supported in std.socket.
42  */
43 class TServerSocket : TServerTransport {
44   /**
45    * Constructs a new instance.
46    *
47    * Params:
48    *   port = The TCP port to listen at (host is always 0.0.0.0).
49    *   sendTimeout = The socket sending timeout.
50    *   recvTimout = The socket receiving timeout.
51    */
52   this(ushort port, Duration sendTimeout = dur!"hnsecs"(0),
53     Duration recvTimeout = dur!"hnsecs"(0))
54   {
55     port_ = port;
56     sendTimeout_ = sendTimeout;
57     recvTimeout_ = recvTimeout;
58 
59     cancellationNotifier_ = new TSocketNotifier;
60 
61     socketSet_ = new SocketSet;
62   }
63 
64   /// The port the server socket listens at.
65   ushort port() const @property {
66     return port_;
67   }
68 
69   /// The socket sending timeout, zero to block infinitely.
70   void sendTimeout(Duration sendTimeout) @property {
71     sendTimeout_ = sendTimeout;
72   }
73 
74   /// The socket receiving timeout, zero to block infinitely.
75   void recvTimeout(Duration recvTimeout) @property {
76     recvTimeout_ = recvTimeout;
77   }
78 
79   /// The maximum number of listening retries if it fails.
80   void retryLimit(ushort retryLimit) @property {
81     retryLimit_ = retryLimit;
82   }
83 
84   /// The delay between a listening attempt failing and retrying it.
85   void retryDelay(Duration retryDelay) @property {
86     retryDelay_ = retryDelay;
87   }
88 
89   /// The size of the TCP send buffer, in bytes.
90   void tcpSendBuffer(int tcpSendBuffer) @property {
91     tcpSendBuffer_ = tcpSendBuffer;
92   }
93 
94   /// The size of the TCP receiving buffer, in bytes.
95   void tcpRecvBuffer(int tcpRecvBuffer) @property {
96     tcpRecvBuffer_ = tcpRecvBuffer;
97   }
98 
99   /// Whether to listen on IPv6 only, if IPv6 support is detected
100   /// (default: false).
101   void ipv6Only(bool value) @property {
102     ipv6Only_ = value;
103   }
104 
105   override void listen() {
106     enforce(!isListening, new STE(STE.Type.ALREADY_LISTENING));
107 
108     serverSocket_ = makeSocketAndListen(port_, ACCEPT_BACKLOG, retryLimit_,
109       retryDelay_, tcpSendBuffer_, tcpRecvBuffer_, ipv6Only_);
110   }
111 
112   override void close() {
113     enforce(isListening, new STE(STE.Type.NOT_LISTENING));
114 
115     serverSocket_.shutdown(SocketShutdown.BOTH);
116     serverSocket_.close();
117     serverSocket_ = null;
118   }
119 
120   override bool isListening() @property {
121     return serverSocket_ !is null;
122   }
123 
124   /// Number of connections listen() backlogs.
125   enum ACCEPT_BACKLOG = 1024;
126 
127   override TTransport accept(TCancellation cancellation = null) {
128     enforce(isListening, new STE(STE.Type.NOT_LISTENING));
129 
130     if (cancellation) cancellationNotifier_.attach(cancellation.triggering);
131     scope (exit) if (cancellation) cancellationNotifier_.detach();
132 
133 
134     // Too many EINTRs is a fault condition and would need to be handled
135     // manually by our caller, but we can tolerate a certain number.
136     enum MAX_EINTRS = 10;
137     uint numEintrs;
138 
139     while (true) {
140       socketSet_.reset();
141       socketSet_.add(serverSocket_);
142       socketSet_.add(cancellationNotifier_.socket);
143 
144       auto ret = Socket.select(socketSet_, null, null);
145       enforce(ret != 0, new STE("Socket.select() returned 0.",
146         STE.Type.RESOURCE_FAILED));
147 
148       if (ret < 0) {
149         // Select itself failed, check if it was just due to an interrupted
150         // syscall.
151         if (getSocketErrno() == INTERRUPTED_ERRNO) {
152           if (numEintrs++ < MAX_EINTRS) {
153             continue;
154           } else {
155             throw new STE("Socket.select() was interrupted by a signal (EINTR) " ~
156               "more than " ~ to!string(MAX_EINTRS) ~ " times.",
157               STE.Type.RESOURCE_FAILED
158             );
159           }
160         }
161         throw new STE("Unknown error on Socket.select(): " ~
162           socketErrnoString(getSocketErrno()), STE.Type.RESOURCE_FAILED);
163       } else {
164         // Check for a ping on the interrupt socket.
165         if (socketSet_.isSet(cancellationNotifier_.socket)) {
166           cancellation.throwIfTriggered();
167         }
168 
169         // Check for the actual server socket having a connection waiting.
170         if (socketSet_.isSet(serverSocket_)) {
171           break;
172         }
173       }
174     }
175 
176     try {
177       auto client = createTSocket(serverSocket_.accept());
178       client.sendTimeout = sendTimeout_;
179       client.recvTimeout = recvTimeout_;
180       return client;
181     } catch (SocketException e) {
182       throw new STE("Unknown error on accepting: " ~ to!string(e),
183         STE.Type.RESOURCE_FAILED);
184     }
185   }
186 
187 protected:
188   /**
189    * Allows derived classes to create a different TSocket type.
190    */
191   TSocket createTSocket(Socket socket) {
192     return new TSocket(socket);
193   }
194 
195 private:
196   ushort port_;
197   Duration sendTimeout_;
198   Duration recvTimeout_;
199   ushort retryLimit_;
200   Duration retryDelay_;
201   uint tcpSendBuffer_;
202   uint tcpRecvBuffer_;
203   bool ipv6Only_;
204 
205   Socket serverSocket_;
206   TSocketNotifier cancellationNotifier_;
207 
208   // Keep socket set between accept() calls to avoid reallocating.
209   SocketSet socketSet_;
210 }
211 
/**
 * Creates a TCP socket bound to the wildcard address on the given port and
 * puts it into listening mode.
 *
 * If the resolver returns an IPv6 wildcard address, it is preferred over an
 * IPv4 one, because an IPv6 socket can accept IPv4 connections as well
 * (unless ipv6Only is set).
 *
 * Params:
 *   port = The TCP port to bind to.
 *   backlog = The connection backlog passed to listen().
 *   retryLimit = The maximum number of bind() attempts; 0 and 1 both mean a
 *     single attempt.
 *   retryDelay = The time to sleep between failed bind() attempts.
 *   tcpSendBuffer = The SO_SNDBUF size in bytes, 0 to keep the system default.
 *   tcpRecvBuffer = The SO_RCVBUF size in bytes, 0 to keep the system default.
 *   ipv6Only = The value for IPV6_V6ONLY; only applied to IPv6 sockets.
 *
 * Returns: The listening socket.
 *
 * Throws: TServerTransportException (type RESOURCE_FAILED) if resolving,
 *   creating, configuring or binding the socket fails; the SocketException
 *   thrown by Socket.listen() propagates unchanged. In all failure cases, a
 *   socket that was already created is closed again.
 */
Socket makeSocketAndListen(ushort port, int backlog, ushort retryLimit,
  Duration retryDelay, uint tcpSendBuffer = 0, uint tcpRecvBuffer = 0,
  bool ipv6Only = false
) {
  Address localAddr;
  try {
    // null represents the wildcard address.
    auto addrInfos = getAddressInfo(null, to!string(port),
      AddressInfoFlags.PASSIVE, SocketType.STREAM, ProtocolType.TCP);
    // Without this, an empty result would leave localAddr null and crash
    // below; the Exception is wrapped into an STE by the catch clause.
    enforce(!addrInfos.empty, "getAddressInfo() returned no addresses.");
    foreach (i, ai; addrInfos) {
      // Prefer to bind to IPv6 addresses, because then IPv4 is listened to as
      // well, but not the other way round.
      if (ai.family == AddressFamily.INET6 || i == (addrInfos.length - 1)) {
        localAddr = ai.address;
        break;
      }
    }
  } catch (Exception e) {
    throw new STE("Could not determine local address to listen on.",
      STE.Type.RESOURCE_FAILED, __FILE__, __LINE__, e);
  }

  Socket socket;
  try {
    socket = new Socket(localAddr.addressFamily, SocketType.STREAM,
      ProtocolType.TCP);
  } catch (SocketException e) {
    throw new STE("Could not create accepting socket: " ~ to!string(e),
      STE.Type.RESOURCE_FAILED);
  }
  // Don't leak the socket if any of the following steps fails.
  scope (failure) socket.close();

  // IPV6_V6ONLY is an IPv6-level option; setting it on an IPv4 socket always
  // fails, so only try it when we actually listen on an IPv6 address.
  if (localAddr.addressFamily == AddressFamily.INET6) {
    try {
      socket.setOption(SocketOptionLevel.IPV6, SocketOption.IPV6_V6ONLY, ipv6Only);
    } catch (SocketException e) {
      // This is somewhat expected on older systems (e.g. pre-Vista Windows),
      // which do not support the IPV6_V6ONLY flag yet. Racy flag just to avoid
      // log spew in unit tests.
      shared static warned = false;
      if (!warned) {
        logError("Could not set IPV6_V6ONLY socket option: %s", e);
        warned = true;
      }
    }
  }

  alias SocketOptionLevel.SOCKET lvlSock;

  // Prevent 2 maximum segment lifetime delay on accept.
  try {
    socket.setOption(lvlSock, SocketOption.REUSEADDR, true);
  } catch (SocketException e) {
    throw new STE("Could not set REUSEADDR socket option: " ~ to!string(e),
      STE.Type.RESOURCE_FAILED);
  }

  // Set TCP buffer sizes.
  if (tcpSendBuffer > 0) {
    try {
      socket.setOption(lvlSock, SocketOption.SNDBUF, tcpSendBuffer);
    } catch (SocketException e) {
      throw new STE("Could not set socket send buffer size: " ~ to!string(e),
        STE.Type.RESOURCE_FAILED);
    }
  }

  if (tcpRecvBuffer > 0) {
    try {
      socket.setOption(lvlSock, SocketOption.RCVBUF, tcpRecvBuffer);
    } catch (SocketException e) {
      throw new STE("Could not set socket receive buffer size: " ~ to!string(e),
        STE.Type.RESOURCE_FAILED);
    }
  }

  // Turn linger off to avoid blocking on socket close.
  try {
    Linger l;
    l.on = 0;
    l.time = 0;
    socket.setOption(lvlSock, SocketOption.LINGER, l);
  } catch (SocketException e) {
    throw new STE("Could not disable socket linger: " ~ to!string(e),
      STE.Type.RESOURCE_FAILED);
  }

  // Set TCP_NODELAY.
  try {
    socket.setOption(SocketOptionLevel.TCP, SocketOption.TCP_NODELAY, true);
  } catch (SocketException e) {
    throw new STE("Could not disable Nagle's algorithm: " ~ to!string(e),
      STE.Type.RESOURCE_FAILED);
  }

  ushort retries;
  // Kept so the final error can carry the cause of the last failed attempt.
  SocketException lastBindError;
  while (true) {
    try {
      socket.bind(localAddr);
      break;
    } catch (SocketException e) {
      lastBindError = e;
    }

    // Only reached if bind() failed; a successful bind() breaks out above.
    retries++;
    if (retries < retryLimit) {
      Thread.sleep(retryDelay);
    } else {
      throw new STE(text("Could not bind to address: ", localAddr),
        STE.Type.RESOURCE_FAILED, __FILE__, __LINE__, lastBindError);
    }
  }

  socket.listen(backlog);
  return socket;
}
324 
unittest {
  // Test that a blocking accept() is aborted with TCancelledException once the
  // TCancellation passed to it is triggered from another thread.
  {
    auto sock = new TServerSocket(0);
    sock.listen();
    scope (exit) sock.close();

    auto cancellation = new TCancellationOrigin;

    // NOTE(review): this thread is never joined; it only sleeps briefly and
    // then triggers the cancellation.
    auto intThread = new Thread({
      // Sleep for a bit until the socket is accepting.
      Thread.sleep(dur!"msecs"(50));
      cancellation.trigger();
    });
    intThread.start();

    import std.exception;
    assertThrown!TCancelledException(sock.accept(cancellation));
  }

  // Test that the recvTimeout passed to TServerSocket is applied to accepted
  // client sockets: a server-side read() with no incoming data must fail with
  // TTransportException.Type.TIMED_OUT.
  {
    // NOTE(review): fixed port; listen() fails if it is already in use.
    immutable port = 11122;
    auto timeout = dur!"msecs"(500);
    auto serverSock = new TServerSocket(port, timeout, timeout);
    serverSock.listen();
    scope (exit) serverSock.close();

    auto clientSock = new TSocket("127.0.0.1", port);
    clientSock.open();
    scope (exit) clientSock.close();

    // Set by recvThread only if its read() timed out.
    shared bool hasTimedOut;
    auto recvThread = new Thread({
      auto sock = serverSock.accept();
      ubyte[1] data;
      try {
        sock.read(data);
      } catch (TTransportException e) {
        if (e.type == TTransportException.Type.TIMED_OUT) {
          hasTimedOut = true;
        } else {
          import std.stdio;
          stderr.writeln(e);
        }
      }
    });
    // Daemon thread, so it does not keep the program alive if read() never
    // returns.
    recvThread.isDaemon = true;
    recvThread.start();

    // Wait for the timeout, with a little bit of spare time.
    // NOTE(review): timing-based; on a heavily loaded machine 50 ms of slack
    // may not be enough, making this check flaky.
    Thread.sleep(timeout + dur!"msecs"(50));
    enforce(hasTimedOut,
      "Client socket receive() blocked for longer than recvTimeout.");
  }
}