1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19 module thrift.transport.websocket;
20 
import std.algorithm;
import std.algorithm.searching;
import std.base64;
import std.bitmanip;
import std.conv;
import std.digest.sha;
import std.stdio;
import std.string;
import std.uni;
import thrift.base : VERSION;
import thrift.transport.base;
import thrift.transport.http;
33 
34 /**
35  * WebSocket server transport.
36  */
/**
 * WebSocket server transport.
 *
 * The first read performs the HTTP Upgrade handshake. Afterwards, incoming
 * frames are unmasked and their payload is handed to the caller, and every
 * flush sends the buffered output as one unmasked, final frame.
 *
 * Params:
 *   binary = If true, outgoing data is sent in Binary frames, otherwise in
 *     Text frames.
 */
final class TServerWebSocketTransport(bool binary) : THttpTransport {
  /**
   * Constructs a new instance.
   *
   * Params:
   *   transport = The underlying transport used for the actual I/O.
   */
  this(TTransport transport) {
    super(transport);
    transport_ = transport;
  }

  /**
   * Reads up to buf.length bytes of WebSocket payload data.
   *
   * Returns: The number of bytes copied into buf, or 0 if the handshake
   *   failed or the connection was closed.
   */
  override size_t read(ubyte[] buf) {
    // If we do not have a good handshake, the client will attempt one.
    if (!handshakeComplete) {
      resetHandshake();
      super.read(buf);
      // If we did not get everything we expected, the handshake failed
      // and we need to send a 400 response back.
      if (!handshakeComplete) {
        sendBadRequest();
        return 0;
      }
      // Otherwise, send back the 101 response.
      super.flush();
    }

    // If the buffer is empty, read a new frame off the wire. readFrame()
    // only returns true after storing a non-empty payload in readBuffer_.
    if (readBuffer_.empty) {
      if (!readFrame()) {
        return 0;
      }
    }

    auto size = min(readBuffer_.length, buf.length);
    buf[0 .. size] = readBuffer_[0 .. size];
    readBuffer_ = readBuffer_[size .. $];
    return size;
  }

  /// Buffers buf; nothing is sent until flush() is called.
  override void write(in ubyte[] buf) {
    writeBuffer_ ~= buf;
  }

  /// Sends all buffered output as a single final data frame.
  override void flush() {
    if (writeBuffer_.empty) {
      return;
    }

    // Properly reset the write buffer even if some of the protocol
    // operations go wrong.
    scope (exit) {
      writeBuffer_.length = 0;
      writeBuffer_.assumeSafeAppend();
    }

    writeFrameHeader(binary ? Opcode.Binary : Opcode.Text, writeBuffer_.length);
    transport_.write(writeBuffer_);
    transport_.flush();
  }

protected:
  /// The 101 Switching Protocols response sent once the handshake succeeded.
  override string getHeader(size_t dataLength) {
    return "HTTP/1.1 101 Switching Protocols\r\n" ~
      "Server: Thrift/" ~ VERSION ~ "\r\n" ~
      "Upgrade: websocket\r\n" ~
      "Connection: Upgrade\r\n" ~
      "Sec-WebSocket-Accept: " ~ acceptKey_ ~ "\r\n" ~
      "\r\n";
  }

  /**
   * Records the handshake-relevant request headers.
   *
   * Header names are compared case-insensitively and in full (not by prefix),
   * so e.g. "Upgrade-Insecure-Requests" does not affect the "Upgrade" check.
   */
  override void parseHeader(const(ubyte)[] header) {
    auto split = findSplit(header, [':']);
    if (split[1].empty) {
      // No colon found.
      return;
    }

    auto name = strip(cast(const(char)[])split[0]);
    auto value = strip(cast(const(char)[])split[2]);

    if (sicmp(name, "upgrade") == 0) {
      upgrade_ = sicmp(value, "websocket") == 0;
    } else if (sicmp(name, "connection") == 0) {
      // The Connection header may list several tokens ("keep-alive, Upgrade").
      connection_ = canFind(value.toLower, "upgrade");
    } else if (sicmp(name, "sec-websocket-key") == 0) {
      auto hash = sha1Of(value ~ WEBSOCKET_GUID);
      acceptKey_ = Base64.encode(hash);
      secWebSocketKey_ = true;
    } else if (sicmp(name, "sec-websocket-version") == 0) {
      secWebSocketVersion_ = sicmp(value, "13") == 0;
    }
  }

  /**
   * Validates the request line; only GET is accepted.
   *
   * Throws: TTransportException (CORRUPTED_DATA) for a malformed line or an
   *   unsupported method.
   */
  override bool parseStatusLine(const(ubyte)[] status) {
    // Method SP Request-URI SP HTTP-Version CRLF.
    auto split = findSplit(status, [' ']);
    if (split[1].empty) {
      throw new TTransportException("Bad status: " ~ to!string(cast(const(char)[])status),
        TTransportException.Type.CORRUPTED_DATA);
    }

    // Skip extra spaces before the Request-URI; find yields an empty range
    // (rather than an invalid index) when nothing but spaces follows.
    auto uriVersion = find!(c => c != ' ')(split[2]);
    if (!canFind(uriVersion, ' ')) {
      throw new TTransportException("Bad status: " ~ to!string(cast(const(char)[])status),
        TTransportException.Type.CORRUPTED_DATA);
    }

    if (split[0] == "GET") {
      // GET method ok, looking for content.
      return true;
    }

    throw new TTransportException("Bad status (unsupported method): " ~
      to!string(cast(const(char)[])status), TTransportException.Type.CORRUPTED_DATA);
  }

private:
  /// True once all four required handshake headers were seen and valid.
  @property bool handshakeComplete() {
    return upgrade_ && connection_ && secWebSocketKey_ && secWebSocketVersion_;
  }

  /// Sends a Close frame carrying reason, then closes the transport.
  void failConnection(CloseCode reason) {
    ubyte[2] code = nativeToBigEndian!ushort(reason);
    writeFrameHeader(Opcode.Close, code.length);
    transport_.write(code[]);
    transport_.flush();
    transport_.close();
  }

  /// Answers a Ping by echoing its payload in a Pong frame.
  void pong(const(ubyte)[] payload) {
    writeFrameHeader(Opcode.Pong, payload.length);
    transport_.write(payload);
    transport_.flush();
  }

  /**
   * Replies to a peer-initiated Close (RFC 6455 section 5.5.1 requires a
   * Close in response) and closes the transport.
   */
  void respondToClose(const(ubyte)[] payload) {
    // Echo only the status code, if the peer sent one.
    auto echo = payload.length >= 2 ? payload[0 .. 2] : payload[0 .. 0];
    scope (exit) transport_.close();
    try {
      writeFrameHeader(Opcode.Close, echo.length);
      transport_.write(echo);
      transport_.flush();
    } catch (TTransportException) {
      // The peer may already have dropped the connection; it is being torn
      // down anyway, so failing to deliver the reply is not an error.
    }
  }

  /**
   * Fills buf completely, looping over short reads of the underlying
   * transport.
   *
   * Returns: false if the transport reported end of stream (a read of 0
   *   bytes) before buf was filled.
   */
  bool readExactly(ubyte[] buf) {
    size_t got = 0;
    while (got < buf.length) {
      auto n = transport_.read(buf[got .. $]);
      if (n == 0) {
        return false;
      }
      got += n;
    }
    return true;
  }

  /**
   * Reads frames until a data frame with a non-empty payload arrives, and
   * stores that payload in readBuffer_.
   *
   * Ping frames are answered and Pong frames ignored; a Close frame is
   * answered and ends the connection.
   *
   * Returns: true if readBuffer_ now holds data; false on end of stream or
   *   close.
   * Throws: TTransportException (CORRUPTED_DATA) on protocol violations,
   *   after a Close frame with ProtocolError has been sent.
   */
  bool readFrame() {
    while (true) {
      ubyte[8] headerBuffer;

      if (!readExactly(headerBuffer[0 .. 2])) {
        return false;
      }
      // Thrift has its own message end marker and we read frame by frame,
      // so FIN does not matter for data frames. It is checked for control
      // frames, which must not be fragmented.
      immutable fin = (headerBuffer[0] & 0x80) != 0;

      // RSV1, RSV2, RSV3
      if ((headerBuffer[0] & 0x70) != 0) {
        failConnection(CloseCode.ProtocolError);
        throw new TTransportException("Reserved bits must be zeroes", TTransportException.Type.CORRUPTED_DATA);
      }

      Opcode opcode;
      try {
        opcode = to!Opcode(headerBuffer[0] & 0x0F);
      } catch (ConvException) {
        failConnection(CloseCode.ProtocolError);
        throw new TTransportException("Unknown opcode", TTransportException.Type.CORRUPTED_DATA);
      }

      // Mask
      if ((headerBuffer[1] & 0x80) == 0) {
        failConnection(CloseCode.ProtocolError);
        throw new TTransportException("Messages from the client must be masked", TTransportException.Type.CORRUPTED_DATA);
      }

      // Read the length
      ulong payloadLength = headerBuffer[1] & 0x7F;
      if (payloadLength == 126) {
        if (!readExactly(headerBuffer[0 .. 2])) {
          return false;
        }
        payloadLength = bigEndianToNative!ushort(headerBuffer[0 .. 2]);
      } else if (payloadLength == 127) {
        if (!readExactly(headerBuffer[])) {
          return false;
        }
        payloadLength = bigEndianToNative!ulong(headerBuffer);
        if ((payloadLength & 0x8000000000000000) != 0) {
          failConnection(CloseCode.ProtocolError);
          throw new TTransportException("The most significant bit of the payload length must be zero",
            TTransportException.Type.CORRUPTED_DATA);
        }
      }

      // Control frames (opcode 0x8 and above) must be final and carry at most
      // 125 bytes of payload.
      immutable isControl = (opcode & 0x08) != 0;
      if (isControl && (!fin || payloadLength > 125)) {
        failConnection(CloseCode.ProtocolError);
        throw new TTransportException("Control frames must not be fragmented and must carry at most 125 bytes",
          TTransportException.Type.CORRUPTED_DATA);
      }

      // size_t is smaller than a ulong on a 32-bit system
      static if (size_t.max < ulong.max) {
        if (payloadLength > size_t.max) {
          failConnection(CloseCode.MessageTooBig);
          return false;
        }
      }

      auto length = cast(size_t)payloadLength;

      // The masking key is present whenever the MASK bit is set (enforced
      // above), even for frames with an empty payload.
      ubyte[4] mask;
      if (!readExactly(mask[])) {
        return false;
      }

      auto payload = new ubyte[](length);
      if (!readExactly(payload)) {
        return false;
      }

      // Unmask the data
      foreach (i, ref b; payload) {
        b ^= mask[i % 4];
      }

      debug writef("FIN=%d, Opcode=%X, length=%d, payload=%s\n",
          fin,
          opcode,
          length,
          binary ? payload.toHexString() : cast(string)payload);

      switch (opcode) {
        case Opcode.Close:
          debug {
            if (length >= 2) {
              CloseCode closeCode;
              try {
                closeCode = to!CloseCode(bigEndianToNative!ushort(payload[0 .. 2]));
              } catch (ConvException) {
                closeCode = CloseCode.NoStatusCode;
              }

              string closeReason;
              if (length == 2) {
                closeReason = to!string(closeCode);
              } else {
                closeReason = cast(string)payload[2 .. $];
              }

              writef("Connection closed: %d %s\n", closeCode, closeReason);
            }
          }
          respondToClose(payload);
          return false;
        case Opcode.Ping:
          pong(payload);
          continue;
        case Opcode.Pong:
          // Unsolicited Pong (e.g. a heartbeat); no response is expected.
          continue;
        default:
          // An empty data frame carries nothing for Thrift; returning it
          // would make read() report 0 bytes, which signals end of stream.
          if (payload.empty) {
            continue;
          }
          readBuffer_ = payload;
          return true;
      }
    }
  }

  /// Clears all handshake flags before a new upgrade request is parsed.
  void resetHandshake() {
    connection_ = false;
    secWebSocketKey_ = false;
    secWebSocketVersion_ = false;
    upgrade_ = false;
  }

  /// Rejects a failed handshake with 400 Bad Request and closes the transport.
  void sendBadRequest() {
    auto header = "HTTP/1.1 400 Bad Request\r\n" ~
      "Server: Thrift/" ~ VERSION ~ "\r\n" ~
      "\r\n";
    transport_.write(cast(const(ubyte[]))header);
    transport_.flush();
    transport_.close();
  }

  /**
   * Writes a final (FIN set), unmasked frame header.
   *
   * Params:
   *   opcode = The frame opcode.
   *   payloadLength = Length of the payload the caller writes next.
   */
  void writeFrameHeader(Opcode opcode, size_t payloadLength) {
    ubyte[10] header;
    size_t headerSize;

    header[0] = cast(ubyte)(0x80 | opcode);
    // The server does not mask its frames, so the MASK bit stays clear.
    if (payloadLength < 126) {
      header[1] = cast(ubyte)payloadLength;
      headerSize = 2;
    } else if (payloadLength < 65536) {
      header[1] = 126;
      header[2 .. 4] = nativeToBigEndian(cast(ushort)payloadLength);
      headerSize = 4;
    } else {
      header[1] = 127;
      header[2 .. 10] = nativeToBigEndian(cast(ulong)payloadLength);
      headerSize = 10;
    }

    transport_.write(header[0 .. headerSize]);
  }

  enum WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

  TTransport transport_;

  string acceptKey_;
  bool connection_;
  bool secWebSocketKey_;
  bool secWebSocketVersion_;
  bool upgrade_;
  ubyte[] readBuffer_;
  ubyte[] writeBuffer_;
}
356 
/**
 * Transport factory that wraps each transport it is given in a
 * TServerWebSocketTransport.
 *
 * Params:
 *   binary = Forwarded to TServerWebSocketTransport: Binary (true) or
 *     Text (false) frames for outgoing data.
 */
class TServerWebSocketTransportFactory(bool binary) : TTransportFactory {
  /// Returns trans wrapped in a new WebSocket server transport.
  override TTransport getTransport(TTransport trans) {
    alias Wrapper = TServerWebSocketTransport!binary;
    TTransport wrapped = new Wrapper(trans);
    return wrapped;
  }
}

/// Factory whose transports send Binary frames.
alias TServerBinaryWebSocketTransportFactory = TServerWebSocketTransportFactory!true;
/// Factory whose transports send Text frames.
alias TServerTextWebSocketTransportFactory = TServerWebSocketTransportFactory!false;
365 
/// Status codes carried in WebSocket Close frames (RFC 6455 section 7.4.1).
private enum CloseCode : ushort {
  NormalClosure = 1000,
  GoingAway,            // 1001
  ProtocolError,        // 1002
  UnsupportedDataType,  // 1003
  NoStatusCode = 1005,
  AbnormalClosure,      // 1006
  InvalidData,          // 1007
  PolicyViolation,      // 1008
  MessageTooBig,        // 1009
  ExtensionExpected,    // 1010
  UnexpectedError,      // 1011
  NotSecure = 1015
}
380 
/// WebSocket frame opcodes (RFC 6455 section 5.2).
private enum Opcode : ubyte {
  Continuation = 0x0,
  Text,    // 0x1
  Binary,  // 0x2
  Close = 0x8,
  Ping,    // 0x9
  Pong     // 0xA
}