/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
module thrift.transport.websocket;

import std.algorithm;
import std.algorithm.searching;
import std.base64;
import std.bitmanip;
import std.conv;
import std.digest.sha;
import std.stdio;
import std.string;
import std.uni;
import thrift.base : VERSION;
import thrift.transport.base;
import thrift.transport.http;

/**
 * WebSocket server transport.
 *
 * The first read() performs the HTTP upgrade handshake through the
 * THttpTransport base class. After that, read() hands out the payload of
 * incoming WebSocket frames and flush() sends everything that was passed to
 * write() as a single unmasked frame.
 *
 * Params:
 *   binary = If true, outgoing data frames use the Binary opcode, otherwise
 *            the Text opcode.
 */
final class TServerWebSocketTransport(bool binary) : THttpTransport {
  /**
   * Constructs a new instance.
   *
   * Params:
   *   transport = The underlying transport used for the actual I/O.
   */
  this(TTransport transport) {
    super(transport);
    transport_ = transport;
  }

  /**
   * Reads decoded frame payload into buf.
   *
   * If the handshake has not been completed yet, it is attempted first; on
   * failure a 400 response is sent, the underlying transport is closed and 0
   * is returned.
   *
   * Returns: The number of bytes copied into buf, or 0 if no data frame could
   *   be read (connection closed, truncated frame, failed handshake).
   */
  override size_t read(ubyte[] buf) {
    // If we do not have a good handshake, the client will attempt one.
    if (!handshakeComplete) {
      resetHandshake();
      super.read(buf);
      // If we did not get everything we expected, the handshake failed
      // and we need to send a 400 response back.
      if (!handshakeComplete) {
        sendBadRequest();
        return 0;
      }
      // Otherwise, send back the 101 response.
      super.flush();
    }

    // If the buffer is empty, read a new frame off the wire.
    if (readBuffer_.empty) {
      if (!readFrame()) {
        return 0;
      }
    }

    auto size = min(readBuffer_.length, buf.length);
    buf[0..size] = readBuffer_[0..size];
    readBuffer_ = readBuffer_[size..$];
    return size;
  }

  /**
   * Appends buf to the pending outgoing frame. Nothing is sent until flush().
   */
  override void write(in ubyte[] buf) {
    writeBuffer_ ~= buf;
  }

  /**
   * Sends all pending data as one final, unmasked frame. Does nothing if no
   * data is pending.
   */
  override void flush() {
    if (writeBuffer_.empty) {
      return;
    }

    // Properly reset the write buffer even if some of the protocol operations
    // go wrong.
    scope (exit) {
      writeBuffer_.length = 0;
      writeBuffer_.assumeSafeAppend();
    }

    writeFrameHeader(writeBuffer_.length);
    transport_.write(writeBuffer_);
    transport_.flush();
  }

protected:
  /**
   * Builds the "101 Switching Protocols" response that completes the
   * handshake. dataLength is not used.
   */
  override string getHeader(size_t dataLength) {
    return "HTTP/1.1 101 Switching Protocols\r\n" ~
      "Server: Thrift/" ~ VERSION ~ "\r\n" ~
      "Upgrade: websocket\r\n" ~
      "Connection: Upgrade\r\n" ~
      "Sec-WebSocket-Accept: " ~ acceptKey_ ~ "\r\n" ~
      "\r\n";
  }

  /**
   * Inspects a single request header line and records whether it satisfies
   * one of the four handshake requirements (Upgrade, Connection,
   * Sec-WebSocket-Key, Sec-WebSocket-Version). Other headers are ignored.
   */
  override void parseHeader(const(ubyte)[] header) {
    auto split = findSplit(header, [':']);
    if (split[1].empty) {
      // No colon found.
      return;
    }

    // Compare the complete field name (case-insensitively) rather than a
    // prefix, so that e.g. "Upgrade-Insecure-Requests" is not mistaken for
    // "Upgrade".
    auto name = cast(const(char)[])split[0];
    // Optional whitespace around the field value is not part of the value.
    auto value = strip(cast(const(char)[])split[2]);

    if (sicmp(name, "upgrade") == 0) {
      upgrade_ = sicmp(value, "websocket") == 0;
    } else if (sicmp(name, "connection") == 0) {
      connection_ = canFind(value.toLower, "upgrade");
    } else if (sicmp(name, "sec-websocket-key") == 0) {
      // The accept key is the Base64 encoded SHA-1 of the client key
      // concatenated with the fixed WebSocket GUID.
      auto hash = sha1Of(value ~ WEBSOCKET_GUID);
      acceptKey_ = Base64.encode(hash);
      secWebSocketKey_ = true;
    } else if (sicmp(name, "sec-websocket-version") == 0) {
      secWebSocketVersion_ = sicmp(value, "13") == 0;
    }
  }

  /**
   * Validates the request line of the handshake.
   *
   * Returns: true for a well-formed GET request line.
   * Throws: TTransportException (CORRUPTED_DATA) if the line is malformed or
   *   uses a method other than GET.
   */
  override bool parseStatusLine(const(ubyte)[] status) {
    // Method SP Request-URI SP HTTP-Version CRLF.
    auto split = findSplit(status, [' ']);
    if (split[1].empty) {
      throw new TTransportException("Bad status: " ~ to!string(status),
        TTransportException.Type.CORRUPTED_DATA);
    }

    // Skip any additional spaces after the method. countUntil yields -1 if
    // nothing but spaces (or nothing at all) follows, which must not be used
    // as a slice index.
    auto skip = countUntil!"a != b"(split[2], ' ');
    if (skip < 0) {
      throw new TTransportException("Bad status: " ~ to!string(status),
        TTransportException.Type.CORRUPTED_DATA);
    }

    auto uriVersion = split[2][skip .. $];
    if (!canFind(uriVersion, ' ')) {
      throw new TTransportException("Bad status: " ~ to!string(status),
        TTransportException.Type.CORRUPTED_DATA);
    }

    if (split[0] == "GET") {
      // GET method ok, looking for content.
      return true;
    }

    throw new TTransportException("Bad status (unsupported method): " ~
      to!string(status), TTransportException.Type.CORRUPTED_DATA);
  }

private:
  /// True once all four handshake headers have been seen with valid values.
  @property bool handshakeComplete() {
    return upgrade_ && connection_ && secWebSocketKey_ && secWebSocketVersion_;
  }

  /**
   * Sends a Close frame carrying the given status code and closes the
   * underlying transport.
   */
  void failConnection(CloseCode reason) {
    ubyte[2] code = nativeToBigEndian!ushort(reason);
    // The frame length must describe the close code, not whatever happens to
    // be pending in the write buffer.
    writeFrameHeader(code.length, Opcode.Close);
    transport_.write(code[]);
    transport_.flush();
    transport_.close();
  }

  /**
   * Answers a Ping by sending a Pong frame that echoes the given payload.
   */
  void pong(const(ubyte)[] payload) {
    writeFrameHeader(payload.length, Opcode.Pong);
    if (payload.length > 0) {
      transport_.write(payload);
    }
    transport_.flush();
  }

  /**
   * Fills buf completely from the underlying transport, issuing as many
   * reads as necessary.
   *
   * Returns: false if the transport reported 0 bytes before buf was full.
   */
  bool readExactly(ubyte[] buf) {
    // NOTE(review): assumes transport_.read() may legitimately return fewer
    // bytes than requested and returns 0 at end of stream - confirm against
    // the TTransport contract.
    size_t filled = 0;
    while (filled < buf.length) {
      auto got = transport_.read(buf[filled .. $]);
      if (got == 0) {
        return false;
      }
      filled += got;
    }
    return true;
  }

  /**
   * Reads frames off the wire until a data frame has been received.
   *
   * Ping frames are answered with a Pong, Pong frames are discarded, and a
   * Close frame closes the underlying transport. Only the payload of data
   * frames is ever stored in readBuffer_.
   *
   * Returns: true if readBuffer_ now holds the (possibly empty) payload of a
   *   data frame; false if the stream ended, the frame was too large for
   *   this platform, or the peer closed the connection.
   * Throws: TTransportException (CORRUPTED_DATA) on protocol violations,
   *   after a Close frame has been sent to the peer.
   */
  bool readFrame() {
    // Loop rather than recurse, so that a peer sending many control frames in
    // a row cannot grow the stack.
    while (true) {
      ubyte[8] headerBuffer;

      if (!readExactly(headerBuffer[0..2])) {
        return false;
      }
      // Since Thrift has its own message end marker and we read frame by
      // frame, it doesn't really matter if the frame is marked as FIN.
      // Capture it for debugging only.
      debug auto fin = (headerBuffer[0] & 0x80) != 0;

      // RSV1, RSV2, RSV3
      if ((headerBuffer[0] & 0x70) != 0) {
        failConnection(CloseCode.ProtocolError);
        throw new TTransportException("Reserved bits must be zeroes", TTransportException.Type.CORRUPTED_DATA);
      }

      Opcode opcode;
      try {
        opcode = to!Opcode(headerBuffer[0] & 0x0F);
      } catch (ConvException) {
        failConnection(CloseCode.ProtocolError);
        throw new TTransportException("Unknown opcode", TTransportException.Type.CORRUPTED_DATA);
      }

      // Mask
      if ((headerBuffer[1] & 0x80) == 0) {
        failConnection(CloseCode.ProtocolError);
        throw new TTransportException("Messages from the client must be masked", TTransportException.Type.CORRUPTED_DATA);
      }

      // Read the length
      ulong payloadLength = headerBuffer[1] & 0x7F;
      if (payloadLength == 126) {
        if (!readExactly(headerBuffer[0..2])) {
          return false;
        }
        payloadLength = bigEndianToNative!ushort(headerBuffer[0..2]);
      } else if (payloadLength == 127) {
        if (!readExactly(headerBuffer[])) {
          return false;
        }
        payloadLength = bigEndianToNative!ulong(headerBuffer);
        if ((payloadLength & 0x8000000000000000) != 0) {
          failConnection(CloseCode.ProtocolError);
          throw new TTransportException("The most significant bit of the payload length must be zero",
            TTransportException.Type.CORRUPTED_DATA);
        }
      }

      // Control frames (opcodes with bit 3 set) are limited to 125 bytes of
      // payload by RFC 6455, section 5.5.
      if ((opcode & 0x08) != 0 && payloadLength > 125) {
        failConnection(CloseCode.ProtocolError);
        throw new TTransportException("Control frame payload must not exceed 125 bytes",
          TTransportException.Type.CORRUPTED_DATA);
      }

      // size_t is smaller than a ulong on a 32-bit system
      static if (size_t.max < ulong.max) {
        if (payloadLength > size_t.max) {
          failConnection(CloseCode.MessageTooBig);
          return false;
        }
      }

      auto length = cast(size_t)payloadLength;

      // The masking key follows the length in every masked frame, including
      // frames with an empty payload, so it has to be consumed unconditionally
      // to stay in sync with the stream.
      ubyte[4] maskingKey;
      if (!readExactly(maskingKey[])) {
        return false;
      }

      // Read into a local buffer so that neither a truncated frame nor a
      // control frame ever leaks into readBuffer_.
      // NOTE(review): length is chosen by the peer and is not capped here.
      auto payload = new ubyte[](length);
      if (length > 0) {
        if (!readExactly(payload)) {
          return false;
        }

        // Unmask the data
        foreach (i, ref b; payload) {
          b ^= maskingKey[i % 4];
        }

        debug writef("FIN=%d, Opcode=%X, length=%d, payload=%s\n",
          fin,
          opcode,
          length,
          binary ? payload.toHexString() : cast(string)payload);
      }

      switch (opcode) {
        case Opcode.Close:
          debug {
            if (length >= 2) {
              ubyte[2] codeBytes = payload[0..2];
              CloseCode closeCode;
              try {
                closeCode = to!CloseCode(bigEndianToNative!ushort(codeBytes));
              } catch (ConvException) {
                closeCode = CloseCode.NoStatusCode;
              }

              string closeReason;
              if (length == 2) {
                closeReason = to!string(cast(CloseCode)closeCode);
              } else {
                closeReason = cast(string)payload[2..$];
              }

              writef("Connection closed: %d %s\n", closeCode, closeReason);
            }
          }
          transport_.close();
          return false;
        case Opcode.Ping:
          pong(payload);
          continue;
        case Opcode.Pong:
          // Unsolicited Pong: nothing to answer and not application data.
          continue;
        default:
          readBuffer_ = payload;
          return true;
      }
    }
  }

  /// Clears all handshake state so that a new handshake can be evaluated.
  void resetHandshake() {
    connection_ = false;
    secWebSocketKey_ = false;
    secWebSocketVersion_ = false;
    upgrade_ = false;
  }

  /// Sends a "400 Bad Request" response and closes the underlying transport.
  void sendBadRequest() {
    auto header = "HTTP/1.1 400 Bad Request\r\n" ~
      "Server: Thrift/" ~ VERSION ~ "\r\n" ~
      "\r\n";
    transport_.write(cast(const(ubyte[]))header);
    transport_.flush();
    transport_.close();
  }

  /**
   * Writes the header of a final (FIN) unmasked frame to the underlying
   * transport. The caller is responsible for writing the payload and
   * flushing.
   *
   * Params:
   *   payloadLength = Number of payload bytes that will follow the header.
   *   opcode = Frame opcode. Opcode.Continuation (the default) selects the
   *            data opcode of this transport: Binary or Text, depending on
   *            the binary template parameter.
   */
  void writeFrameHeader(size_t payloadLength, Opcode opcode = Opcode.Continuation) {
    // 2 bytes minimum; 2 or 8 more for the extended length encodings.
    // The server does not mask the response, so there is no masking key.
    size_t headerSize;
    if (payloadLength < 126) {
      headerSize = 2;
    } else if (payloadLength < 65536) {
      headerSize = 4;
    } else {
      headerSize = 10;
    }

    ubyte[] header = new ubyte[headerSize];
    if (opcode == Opcode.Continuation) {
      header[0] = binary ? Opcode.Binary : Opcode.Text;
    } else {
      header[0] = opcode;
    }
    // FIN bit: every frame we send is a complete message.
    header[0] |= 0x80;
    if (payloadLength < 126) {
      header[1] = cast(ubyte)payloadLength;
    } else if (payloadLength < 65536) {
      header[1] = 126;
      header[2..4] = nativeToBigEndian(cast(ushort)payloadLength);
    } else {
      header[1] = 127;
      header[2..10] = nativeToBigEndian(cast(ulong)payloadLength);
    }

    transport_.write(header);
  }

  /// GUID appended to the client key when computing Sec-WebSocket-Accept.
  enum WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

  TTransport transport_;

  string acceptKey_;
  bool connection_;
  bool secWebSocketKey_;
  bool secWebSocketVersion_;
  bool upgrade_;
  // Unread payload of the most recently received data frame.
  ubyte[] readBuffer_;
  // Data passed to write() since the last flush().
  ubyte[] writeBuffer_;
}

/**
 * Transport factory that wraps transports in a TServerWebSocketTransport.
 */
class TServerWebSocketTransportFactory(bool binary) : TTransportFactory {
  override TTransport getTransport(TTransport trans) {
    return new TServerWebSocketTransport!binary(trans);
  }
}

alias TServerBinaryWebSocketTransportFactory = TServerWebSocketTransportFactory!true;
alias TServerTextWebSocketTransportFactory = TServerWebSocketTransportFactory!false;

/// Status codes carried in the payload of a Close frame.
private enum CloseCode : ushort {
  NormalClosure = 1000,
  GoingAway = 1001,
  ProtocolError = 1002,
  UnsupportedDataType = 1003,
  NoStatusCode = 1005,
  AbnormalClosure = 1006,
  InvalidData = 1007,
  PolicyViolation = 1008,
  MessageTooBig = 1009,
  ExtensionExpected = 1010,
  UnexpectedError = 1011,
  NotSecure = 1015
}

/// Frame opcodes (low four bits of the first header byte).
private enum Opcode : ubyte {
  Continuation = 0x0,
  Text = 0x1,
  Binary = 0x2,
  Close = 0x8,
  Ping = 0x9,
  Pong = 0xA
}