/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
module thrift.protocol.compact;

import std.array : uninitializedArray;
import std.typetuple : allSatisfy, TypeTuple;
import thrift.protocol.base;
import thrift.transport.base;
import thrift.internal.endian;

/**
 * D implementation of the Compact protocol.
 *
 * See THRIFT-110 for a protocol description. This implementation is based on
 * the C++ one.
 *
 * Integers are written as zigzag-encoded varints, field ids are delta-encoded
 * against the previously written/read field of the same struct where
 * possible, and boolean field values are folded into the field header's type
 * nibble.
 */
final class TCompactProtocol(Transport = TTransport) if (
  isTTransport!Transport
) : TProtocol {
  /**
   * Constructs a new instance.
   *
   * Params:
   *   trans = The transport to use.
   *   containerSizeLimit = If positive, the container size is limited to the
   *     given number of items.
   *   stringSizeLimit = If positive, the string length is limited to the
   *     given number of bytes.
   */
  this(Transport trans, int containerSizeLimit = 0, int stringSizeLimit = 0) {
    trans_ = trans;
    this.containerSizeLimit = containerSizeLimit;
    this.stringSizeLimit = stringSizeLimit;
  }

  /// The transport this protocol reads from and writes to.
  Transport transport() @property {
    return trans_;
  }

  /**
   * Discards all per-struct/per-field state (field id stack, last field id,
   * pending boolean field header and cached boolean value flag).
   */
  void reset() {
    lastFieldId_ = 0;
    fieldIdStack_ = null;
    booleanField_ = TField.init;
    hasBoolValue_ = false;
  }

  /**
   * If positive, limits the number of items of deserialized containers to the
   * given amount.
   *
   * This is useful to avoid allocating excessive amounts of memory when broken
   * data is received. If the limit is exceeded, a SIZE_LIMIT-type
   * TProtocolException is thrown.
   *
   * Defaults to zero (no limit).
   */
  int containerSizeLimit;

  /**
   * If positive, limits the length of deserialized strings/binary data to the
   * given number of bytes.
   *
   * This is useful to avoid allocating excessive amounts of memory when broken
   * data is received. If the limit is exceeded, a SIZE_LIMIT-type
   * TProtocolException is thrown.
   *
   * Defaults to zero (no limit).
   */
  int stringSizeLimit;

  /*
   * Writing methods.
   */

  /**
   * Writes a boolean. If a boolean field header is pending (see
   * writeFieldBegin), the value is encoded into that header's type nibble;
   * otherwise a single byte is written.
   */
  void writeBool(bool b) {
    if (booleanField_.name !is null) {
      // we haven't written the field header yet
      writeFieldBeginInternal(booleanField_,
        b ? CType.BOOLEAN_TRUE : CType.BOOLEAN_FALSE);
      booleanField_.name = null;
    } else {
      // we're not part of a field, so just write the value
      writeByte(b ? CType.BOOLEAN_TRUE : CType.BOOLEAN_FALSE);
    }
  }

  /// Writes a single raw byte.
  void writeByte(byte b) {
    trans_.write((cast(ubyte*)&b)[0..1]);
  }

  /// Writes a 16 bit integer as a zigzag varint.
  void writeI16(short i16) {
    writeVarint32(i32ToZigzag(i16));
  }

  /// Writes a 32 bit integer as a zigzag varint.
  void writeI32(int i32) {
    writeVarint32(i32ToZigzag(i32));
  }

  /// Writes a 64 bit integer as a zigzag varint.
  void writeI64(long i64) {
    writeVarint64(i64ToZigzag(i64));
  }

  /// Writes the 8 bytes of the double's bit pattern after hostToLe().
  void writeDouble(double dub) {
    ulong bits = hostToLe(*cast(ulong*)(&dub));
    trans_.write((cast(ubyte*)&bits)[0 .. 8]);
  }

  /// Writes a string as length-prefixed binary data.
  void writeString(string str) {
    writeBinary(cast(ubyte[])str);
  }

  /// Writes the buffer length as a varint, followed by the raw bytes.
  void writeBinary(ubyte[] buf) {
    assert(buf.length <= int.max);
    writeVarint32(cast(int)buf.length);
    trans_.write(buf);
  }

  /**
   * Writes the message header: protocol id byte, combined version/type byte,
   * sequence id varint and the method name.
   */
  void writeMessageBegin(TMessage msg) {
    writeByte(cast(byte)PROTOCOL_ID);
    writeByte(cast(byte)((VERSION_N & VERSION_MASK) |
      ((cast(int)msg.type << TYPE_SHIFT_AMOUNT) & TYPE_MASK)));
    writeVarint32(msg.seqid);
    writeString(msg.name);
  }
  void writeMessageEnd() {}

  /// Saves the enclosing struct's last field id and starts a fresh delta base.
  void writeStructBegin(TStruct tstruct) {
    fieldIdStack_ ~= lastFieldId_;
    lastFieldId_ = 0;
  }

  /// Restores the enclosing struct's last field id.
  void writeStructEnd() {
    lastFieldId_ = fieldIdStack_[$ - 1];
    fieldIdStack_ = fieldIdStack_[0 .. $ - 1];
    fieldIdStack_.assumeSafeAppend();
  }

  /**
   * Writes a field header. For boolean fields the header is only remembered
   * here and emitted by the following writeBool() call, because the value is
   * part of the header.
   */
  void writeFieldBegin(TField field) {
    if (field.type == TType.BOOL) {
      booleanField_.name = field.name;
      booleanField_.type = field.type;
      booleanField_.id = field.id;
    } else {
      writeFieldBeginInternal(field);
    }
  }
  void writeFieldEnd() {}

  /// Writes the struct stop marker.
  void writeFieldStop() {
    writeByte(TType.STOP);
  }

  void writeListBegin(TList list) {
    writeCollectionBegin(list.elemType, list.size);
  }
  void writeListEnd() {}

  /**
   * Writes a map header. An empty map is a single zero byte; otherwise the
   * size varint is followed by one byte holding key type (high nibble) and
   * value type (low nibble).
   */
  void writeMapBegin(TMap map) {
    if (map.size == 0) {
      writeByte(0);
    } else {
      assert(map.size <= int.max);
      writeVarint32(cast(int)map.size);
      writeByte(cast(byte)(toCType(map.keyType) << 4 | toCType(map.valueType)));
    }
  }
  void writeMapEnd() {}

  void writeSetBegin(TSet set) {
    writeCollectionBegin(set.elemType, set.size);
  }
  void writeSetEnd() {}


  /*
   * Reading methods.
   */

  /**
   * Reads a boolean. If the preceding readFieldBegin() already decoded the
   * value from the field header, that cached value is returned (once);
   * otherwise a byte is read from the transport.
   */
  bool readBool() {
    if (hasBoolValue_ == true) {
      hasBoolValue_ = false;
      return boolValue_;
    }

    return readByte() == CType.BOOLEAN_TRUE;
  }

  /// Reads a single raw byte.
  byte readByte() {
    ubyte[1] b = void;
    trans_.readAll(b);
    return cast(byte)b[0];
  }

  /// Reads a zigzag varint and truncates it to 16 bits.
  short readI16() {
    return cast(short)zigzagToI32(readVarint32());
  }

  /// Reads a zigzag varint as a 32 bit integer.
  int readI32() {
    return zigzagToI32(readVarint32());
  }

  /// Reads a zigzag varint as a 64 bit integer.
  long readI64() {
    return zigzagToI64(readVarint64());
  }

  /// Reads 8 bytes, applies leToHost() and reinterprets them as a double.
  double readDouble() {
    IntBuf!long b = void;
    trans_.readAll(b.bytes);
    b.value = leToHost(b.value);
    return *cast(double*)(&b.value);
  }

  /// Reads length-prefixed binary data and returns it as a string.
  string readString() {
    return cast(string)readBinary();
  }

  /**
   * Reads a varint length followed by that many bytes.
   *
   * Returns: The data read, or null for a zero length.
   * Throws: TProtocolException (via checkSize) if the length is negative or
   *   exceeds stringSizeLimit.
   */
  ubyte[] readBinary() {
    auto size = readVarint32();
    checkSize(size, stringSizeLimit);

    if (size == 0) {
      return null;
    }

    auto buf = uninitializedArray!(ubyte[])(size);
    trans_.readAll(buf);
    return buf;
  }

  /**
   * Reads and validates the message header.
   *
   * Throws: TProtocolException of type BAD_VERSION if the protocol id or the
   *   version does not match.
   */
  TMessage readMessageBegin() {
    TMessage msg = void;

    auto protocolId = readByte();
    if (protocolId != cast(byte)PROTOCOL_ID) {
      throw new TProtocolException("Bad protocol identifier",
        TProtocolException.Type.BAD_VERSION);
    }

    auto versionAndType = readByte();
    auto ver = versionAndType & VERSION_MASK;
    if (ver != VERSION_N) {
      throw new TProtocolException("Bad protocol version",
        TProtocolException.Type.BAD_VERSION);
    }

    msg.type = cast(TMessageType)((versionAndType >> TYPE_SHIFT_AMOUNT) & TYPE_BITS);
    msg.seqid = readVarint32();
    msg.name = readString();

    return msg;
  }
  void readMessageEnd() {}

  /// Saves the enclosing struct's last field id and starts a fresh delta base.
  TStruct readStructBegin() {
    fieldIdStack_ ~= lastFieldId_;
    lastFieldId_ = 0;
    return TStruct();
  }

  /// Restores the enclosing struct's last field id.
  void readStructEnd() {
    lastFieldId_ = fieldIdStack_[$ - 1];
    fieldIdStack_ = fieldIdStack_[0 .. $ - 1];
    // Same as in writeStructEnd(): allow the next push to reuse the freed
    // slot in place instead of reallocating the stack.
    fieldIdStack_.assumeSafeAppend();
  }

  /**
   * Reads a field header.
   *
   * The low nibble of the header byte is the wire type, the high nibble an
   * optional field id delta; if the delta is zero, the id follows as an i16.
   * For boolean fields the value is taken from the wire type and cached for
   * the following readBool() call.
   */
  TField readFieldBegin() {
    TField f = void;
    f.name = null;

    auto bite = readByte();
    auto type = cast(CType)(bite & 0x0f);

    if (type == CType.STOP) {
      // Struct stop byte, nothing more to do.
      f.id = 0;
      f.type = TType.STOP;
      return f;
    }

    // Mask off the 4 MSB of the type header, which could contain a field id
    // delta.
    auto modifier = cast(short)((bite & 0xf0) >> 4);
    if (modifier > 0) {
      f.id = cast(short)(lastFieldId_ + modifier);
    } else {
      // Delta encoding not used, just read the id as usual.
      f.id = readI16();
    }
    f.type = getTType(type);

    if (type == CType.BOOLEAN_TRUE || type == CType.BOOLEAN_FALSE) {
      // For boolean fields, the value is encoded in the type - keep it around
      // for the readBool() call.
      hasBoolValue_ = true;
      boolValue_ = (type == CType.BOOLEAN_TRUE ? true : false);
    }

    lastFieldId_ = f.id;
    return f;
  }
  void readFieldEnd() {}

  /**
   * Reads a list header: high nibble is the size (0xf meaning "size follows
   * as a varint"), low nibble the element wire type.
   */
  TList readListBegin() {
    auto sizeAndType = readByte();

    auto lsize = (sizeAndType >> 4) & 0xf;
    if (lsize == 0xf) {
      lsize = readVarint32();
    }
    checkSize(lsize, containerSizeLimit);

    TList l = void;
    l.elemType = getTType(cast(CType)(sizeAndType & 0x0f));
    l.size = cast(size_t)lsize;

    return l;
  }
  void readListEnd() {}

  /**
   * Reads a map header: a size varint, followed (only for non-empty maps) by
   * a byte holding key type (high nibble) and value type (low nibble).
   */
  TMap readMapBegin() {
    TMap m = void;

    auto size = readVarint32();
    ubyte kvType;
    if (size != 0) {
      kvType = readByte();
    }
    checkSize(size, containerSizeLimit);

    m.size = size;
    m.keyType = getTType(cast(CType)(kvType >> 4));
    m.valueType = getTType(cast(CType)(kvType & 0xf));

    return m;
  }
  void readMapEnd() {}

  /// Reads a set header; the encoding is the same as for lists.
  TSet readSetBegin() {
    auto sizeAndType = readByte();

    auto lsize = (sizeAndType >> 4) & 0xf;
    if (lsize == 0xf) {
      lsize = readVarint32();
    }
    checkSize(lsize, containerSizeLimit);

    TSet s = void;
    s.elemType = getTType(cast(CType)(sizeAndType & 0xf));
    s.size = cast(size_t)lsize;

    return s;
  }
  void readSetEnd() {}

private:
  /*
   * Writes a field header, using typeOverride as wire type unless it is -1
   * (used by writeBool() to encode the value in the header).
   */
  void writeFieldBeginInternal(TField field, byte typeOverride = -1) {
    // If there's a type override, use that.
    auto typeToWrite = (typeOverride == -1 ? toCType(field.type) : typeOverride);

    // check if we can use delta encoding for the field id
    if (field.id > lastFieldId_ && (field.id - lastFieldId_) <= 15) {
      // write them together
      writeByte(cast(byte)((field.id - lastFieldId_) << 4 | typeToWrite));
    } else {
      // write them separate
      writeByte(cast(byte)typeToWrite);
      writeI16(field.id);
    }

    lastFieldId_ = field.id;
  }


  /*
   * Writes a list/set header. Sizes up to 14 are packed into the high nibble
   * of the type byte; larger sizes use 0xf there and follow as a varint.
   */
  void writeCollectionBegin(TType elemType, size_t size) {
    if (size <= 14) {
      writeByte(cast(byte)(size << 4 | toCType(elemType)));
    } else {
      assert(size <= int.max);
      writeByte(cast(byte)(0xf0 | toCType(elemType)));
      writeVarint32(cast(int)size);
    }
  }

  /*
   * Write an i32 as a varint. Results in 1-5 bytes on the wire.
   */
  void writeVarint32(uint n) {
    ubyte[5] buf = void;
    ubyte wsize;

    while (true) {
      if ((n & ~0x7F) == 0) {
        buf[wsize++] = cast(ubyte)n;
        break;
      } else {
        buf[wsize++] = cast(ubyte)((n & 0x7F) | 0x80);
        n >>= 7;
      }
    }

    trans_.write(buf[0 .. wsize]);
  }

  /*
   * Write an i64 as a varint. Results in 1-10 bytes on the wire.
   */
  void writeVarint64(ulong n) {
    ubyte[10] buf = void;
    ubyte wsize;

    while (true) {
      if ((n & ~0x7FL) == 0) {
        buf[wsize++] = cast(ubyte)n;
        break;
      } else {
        buf[wsize++] = cast(ubyte)((n & 0x7F) | 0x80);
        n >>= 7;
      }
    }

    trans_.write(buf[0 .. wsize]);
  }

  /*
   * Convert l into a zigzag long. This allows negative numbers to be
   * represented compactly as a varint.
   */
  ulong i64ToZigzag(long l) {
    return (l << 1) ^ (l >> 63);
  }

  /*
   * Convert n into a zigzag int. This allows negative numbers to be
   * represented compactly as a varint.
   */
  uint i32ToZigzag(int n) {
    return (n << 1) ^ (n >> 31);
  }

  /*
   * Maps a TType to its compact wire type. TType.BOOL maps to BOOLEAN_TRUE,
   * which is what container headers carry for boolean elements.
   */
  CType toCType(TType type) {
    final switch (type) {
      case TType.STOP:
        return CType.STOP;
      case TType.BOOL:
        return CType.BOOLEAN_TRUE;
      case TType.BYTE:
        return CType.BYTE;
      case TType.DOUBLE:
        return CType.DOUBLE;
      case TType.I16:
        return CType.I16;
      case TType.I32:
        return CType.I32;
      case TType.I64:
        return CType.I64;
      case TType.STRING:
        return CType.BINARY;
      case TType.STRUCT:
        return CType.STRUCT;
      case TType.MAP:
        return CType.MAP;
      case TType.SET:
        return CType.SET;
      case TType.LIST:
        return CType.LIST;
      case TType.VOID:
        assert(false, "Invalid type passed.");
    }
  }

  /*
   * Reads a varint and truncates it to 32 bits.
   */
  int readVarint32() {
    return cast(int)readVarint64();
  }

  /*
   * Reads a varint of up to 10 bytes. Tries to borrow the bytes from the
   * transport first and falls back to reading byte by byte.
   *
   * Throws TProtocolException of type INVALID_DATA if the continuation bit is
   * still set on the tenth byte.
   */
  long readVarint64() {
    ulong val;
    ubyte shift;
    ubyte[10] buf = void; // 64 bits / (7 bits/byte) = 10 bytes.
    auto bufSize = buf.sizeof;
    auto borrowed = trans_.borrow(buf.ptr, bufSize);

    ubyte rsize;

    if (borrowed) {
      // Fast path.
      while (true) {
        auto bite = borrowed[rsize];
        rsize++;
        val |= cast(ulong)(bite & 0x7f) << shift;
        shift += 7;
        if (!(bite & 0x80)) {
          trans_.consume(rsize);
          return val;
        }
        // Have to check for invalid data so we don't crash.
        if (rsize == buf.sizeof) {
          // (message, type) argument order, like the other throw sites in
          // this file, so the text ends up as the exception message.
          throw new TProtocolException("Variable-length int over 10 bytes.",
            TProtocolException.Type.INVALID_DATA);
        }
      }
    } else {
      // Slow path.
      while (true) {
        ubyte[1] bite;
        trans_.readAll(bite);
        ++rsize;

        val |= cast(ulong)(bite[0] & 0x7f) << shift;
        shift += 7;
        if (!(bite[0] & 0x80)) {
          return val;
        }

        // Might as well check for invalid data on the slow path too.
        if (rsize >= buf.sizeof) {
          throw new TProtocolException("Variable-length int over 10 bytes.",
            TProtocolException.Type.INVALID_DATA);
        }
      }
    }
  }

  /*
   * Convert from zigzag int to int.
   */
  int zigzagToI32(uint n) {
    return (n >> 1) ^ -(n & 1);
  }

  /*
   * Convert from zigzag long to long.
   */
  long zigzagToI64(ulong n) {
    return (n >> 1) ^ -(n & 1);
  }

  /*
   * Maps a compact wire type to the corresponding TType.
   *
   * The argument is a nibble taken straight from the wire, so it may hold a
   * value that is not a CType member; this is reported as INVALID_DATA
   * instead of running into the final switch with an unlisted value.
   */
  TType getTType(CType type) {
    if (type > CType.max) {
      throw new TProtocolException("Unknown compact protocol wire type",
        TProtocolException.Type.INVALID_DATA);
    }

    final switch (type) {
      case CType.STOP:
        return TType.STOP;
      case CType.BOOLEAN_FALSE:
        return TType.BOOL;
      case CType.BOOLEAN_TRUE:
        return TType.BOOL;
      case CType.BYTE:
        return TType.BYTE;
      case CType.I16:
        return TType.I16;
      case CType.I32:
        return TType.I32;
      case CType.I64:
        return TType.I64;
      case CType.DOUBLE:
        return TType.DOUBLE;
      case CType.BINARY:
        return TType.STRING;
      case CType.LIST:
        return TType.LIST;
      case CType.SET:
        return TType.SET;
      case CType.MAP:
        return TType.MAP;
      case CType.STRUCT:
        return TType.STRUCT;
    }
  }

  /*
   * Throws a NEGATIVE_SIZE TProtocolException for negative sizes and a
   * SIZE_LIMIT one if limit is positive and exceeded.
   */
  void checkSize(int size, int limit) {
    if (size < 0) {
      throw new TProtocolException(TProtocolException.Type.NEGATIVE_SIZE);
    } else if (limit > 0 && size > limit) {
      throw new TProtocolException(TProtocolException.Type.SIZE_LIMIT);
    }
  }

  // Layout of the two message header bytes.
  enum PROTOCOL_ID = 0x82;
  enum VERSION_N = 1;
  enum VERSION_MASK = 0b0001_1111;
  enum TYPE_MASK = 0b1110_0000;
  enum TYPE_BITS = 0b0000_0111;
  enum TYPE_SHIFT_AMOUNT = 5;

  // Probably need to implement a better stack at some point.
  short[] fieldIdStack_;
  short lastFieldId_;

  // Header of a boolean field whose value has not been written yet; a
  // non-null name marks it as pending.
  TField booleanField_;

  // Boolean value decoded from the last field header, if any.
  bool hasBoolValue_;
  bool boolValue_;

  Transport trans_;
}

/**
 * TCompactProtocol construction helper to avoid having to explicitly specify
 * the transport type, i.e.
to allow the constructor being called using IFTI
 * (see $(LINK2 http://d.puremagic.com/issues/show_bug.cgi?id=6082, D Bugzilla
 * enhancement request 6082)).
 *
 * Params:
 *   trans = The transport the new protocol instance operates on.
 *   containerSizeLimit = Forwarded to the TCompactProtocol constructor.
 *   stringSizeLimit = Forwarded to the TCompactProtocol constructor.
 */
TCompactProtocol!Transport tCompactProtocol(Transport)(Transport trans,
  int containerSizeLimit = 0, int stringSizeLimit = 0
) if (isTTransport!Transport)
{
  auto protocol = new TCompactProtocol!Transport(
    trans, containerSizeLimit, stringSizeLimit);
  return protocol;
}

private {
  // Wire type ids of the compact protocol. They are stored in a nibble of
  // field/container header bytes, hence the static assert below.
  enum CType : ubyte {
    STOP = 0x0,
    BOOLEAN_TRUE = 0x1,
    BOOLEAN_FALSE = 0x2,
    BYTE = 0x3,
    I16 = 0x4,
    I32 = 0x5,
    I64 = 0x6,
    DOUBLE = 0x7,
    BINARY = 0x8,
    LIST = 0x9,
    SET = 0xa,
    MAP = 0xb,
    STRUCT = 0xc
  }
  static assert(CType.max <= 0xf,
    "Compact protocol wire type representation must fit into 4 bits.");
}

unittest {
  import std.exception;
  import thrift.transport.memory;

  // Serialize a message header into memory and compare it byte by byte with
  // the expected wire representation.
  auto memory = new TMemoryBuffer;
  auto protocol = tCompactProtocol(memory);
  protocol.writeMessageBegin(TMessage("foo", TMessageType.CALL, 0));

  auto written = new ubyte[7];
  memory.readAll(written);
  enforce(written == [
    130, // Protocol id.
    33, // Version/type byte.
    0, // Sequence id.
    3, 102, 111, 111 // Method name.
  ]);
}

unittest {
  // Run the shared size limit tests against the default instantiation.
  import thrift.internal.test.protocol;
  testContainerSizeLimit!(TCompactProtocol!())();
  testStringSizeLimit!(TCompactProtocol!())();
}

/**
 * TProtocolFactory creating a TCompactProtocol instance for passed in
 * transports.
 *
 * The optional Transports template tuple parameter can be used to specify
 * one or more TTransport implementations to specifically instantiate
 * TCompactProtocol for. If the actual transport types encountered at
 * runtime match one of the transports in the list, a specialized protocol
 * instance is created. Otherwise, a generic TTransport version is used.
*/
class TCompactProtocolFactory(Transports...) if (
  allSatisfy!(isTTransport, Transports)
) : TProtocolFactory {
  /**
   * Constructs a new factory.
   *
   * Params:
   *   containerSizeLimit = Container size limit passed to every protocol
   *     instance created by this factory.
   *   stringSizeLimit = String size limit passed to every protocol instance
   *     created by this factory.
   */
  this(int containerSizeLimit = 0, int stringSizeLimit = 0) {
    // Store the caller's limits (previously both were hard-coded to 0).
    containerSizeLimit_ = containerSizeLimit;
    stringSizeLimit_ = stringSizeLimit;
  }

  /**
   * Creates a TCompactProtocol for the given transport, configured with this
   * factory's size limits.
   *
   * The transport is cast to each type in Transports in order, then to
   * TTransport; the first successful cast selects the instantiation.
   *
   * Throws: TProtocolException if no cast succeeds, i.e. trans is null.
   */
  TProtocol getProtocol(TTransport trans) const {
    foreach (Transport; TypeTuple!(Transports, TTransport)) {
      auto concreteTrans = cast(Transport)trans;
      if (concreteTrans) {
        return new TCompactProtocol!Transport(concreteTrans,
          containerSizeLimit_, stringSizeLimit_);
      }
    }
    throw new TProtocolException(
      "Passed null transport to TCompactProtocolFactory.");
  }

  // Limits handed to the TCompactProtocol constructor in getProtocol().
  int containerSizeLimit_;
  int stringSizeLimit_;
}