1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 module thrift.protocol.compact;
20 
21 import std.array : uninitializedArray;
22 import std.typetuple : allSatisfy, TypeTuple;
23 import thrift.protocol.base;
24 import thrift.transport.base;
25 import thrift.internal.endian;
26 
27 /**
28  * D implementation of the Compact protocol.
29  *
30  * See THRIFT-110 for a protocol description. This implementation is based on
31  * the C++ one.
32  */
final class TCompactProtocol(Transport = TTransport) if (
  isTTransport!Transport
) : TProtocol {
  /**
   * Constructs a new instance.
   *
   * Params:
   *   trans = The transport to use.
   *   containerSizeLimit = If positive, the container size is limited to the
   *     given number of items.
   *   stringSizeLimit = If positive, the string length is limited to the
   *     given number of bytes.
   */
  this(Transport trans, int containerSizeLimit = 0, int stringSizeLimit = 0) {
    trans_ = trans;
    this.containerSizeLimit = containerSizeLimit;
    this.stringSizeLimit = stringSizeLimit;
  }

  /// Returns: The transport passed to the constructor.
  Transport transport() @property {
    return trans_;
  }

  /**
   * Discards the per-stream state: the field id bookkeeping used for delta
   * encoding, a bool field header still waiting for its value, and a bool
   * value decoded from a field header but not yet returned by readBool().
   */
  void reset() {
    lastFieldId_ = 0;
    fieldIdStack_ = null;
    booleanField_ = TField.init;
    booleanFieldPending_ = false;
    hasBoolValue_ = false;
  }

  /**
   * If positive, limits the number of items of deserialized containers to the
   * given amount.
   *
   * This is useful to avoid allocating excessive amounts of memory when broken
   * data is received. If the limit is exceeded, a SIZE_LIMIT-type
   * TProtocolException is thrown.
   *
   * Defaults to zero (no limit).
   */
  int containerSizeLimit;

  /**
   * If positive, limits the length of deserialized strings/binary data to the
   * given number of bytes.
   *
   * This is useful to avoid allocating excessive amounts of memory when broken
   * data is received. If the limit is exceeded, a SIZE_LIMIT-type
   * TProtocolException is thrown.
   *
   * Defaults to zero (no limit).
   */
  int stringSizeLimit;

  /*
   * Writing methods.
   */

  /**
   * Writes a bool.
   *
   * If writeFieldBegin() was just called for a BOOL field, the value is
   * encoded into that field's (deferred) header. Otherwise, e.g. for
   * container elements, it is written as a standalone byte.
   */
  void writeBool(bool b) {
    if (booleanFieldPending_) {
      // The field header has not been written yet; emit it now with the
      // value encoded as the header's type.
      writeFieldBeginInternal(booleanField_,
        b ? CType.BOOLEAN_TRUE : CType.BOOLEAN_FALSE);
      booleanFieldPending_ = false;
    } else {
      // We're not part of a field, so just write the value.
      writeByte(b ? CType.BOOLEAN_TRUE : CType.BOOLEAN_FALSE);
    }
  }

  /// Writes a single raw byte.
  void writeByte(byte b) {
    trans_.write((cast(ubyte*)&b)[0 .. 1]);
  }

  /// Writes a zigzag-encoded varint.
  void writeI16(short i16) {
    writeVarint32(i32ToZigzag(i16));
  }

  /// Writes a zigzag-encoded varint.
  void writeI32(int i32) {
    writeVarint32(i32ToZigzag(i32));
  }

  /// Writes a zigzag-encoded 64 bit varint.
  void writeI64(long i64) {
    writeVarint64(i64ToZigzag(i64));
  }

  /// Writes the 8 bytes of the double's bit pattern, converted via hostToLe.
  void writeDouble(double dub) {
    ulong bits = hostToLe(*cast(ulong*)(&dub));
    trans_.write((cast(ubyte*)&bits)[0 .. 8]);
  }

  /// Writes the string's bytes as binary data (varint length + bytes).
  void writeString(string str) {
    writeBinary(cast(ubyte[])str);
  }

  /// Writes a varint length prefix followed by the raw bytes.
  void writeBinary(ubyte[] buf) {
    assert(buf.length <= int.max);
    writeVarint32(cast(int)buf.length);
    trans_.write(buf);
  }

  /**
   * Writes the message header: protocol id byte, a byte combining version
   * (low bits) and message type (shifted by TYPE_SHIFT_AMOUNT), the sequence
   * id as a plain (non-zigzag) varint, and the method name.
   */
  void writeMessageBegin(TMessage msg) {
    writeByte(cast(byte)PROTOCOL_ID);
    writeByte(cast(byte)((VERSION_N & VERSION_MASK) |
                         ((cast(int)msg.type << TYPE_SHIFT_AMOUNT) & TYPE_MASK)));
    writeVarint32(msg.seqid);
    writeString(msg.name);
  }
  void writeMessageEnd() {}

  /// Saves the enclosing struct's last field id; field deltas restart at 0.
  void writeStructBegin(TStruct tstruct) {
    fieldIdStack_ ~= lastFieldId_;
    lastFieldId_ = 0;
  }

  /// Restores the enclosing struct's last field id.
  void writeStructEnd() {
    lastFieldId_ = fieldIdStack_[$ - 1];
    fieldIdStack_ = fieldIdStack_[0 .. $ - 1];
    fieldIdStack_.assumeSafeAppend();
  }

  /**
   * Writes a field header.
   *
   * For BOOL fields, the header is only recorded here and written by the
   * following writeBool() call, because the value is encoded in the header.
   */
  void writeFieldBegin(TField field) {
    if (field.type == TType.BOOL) {
      booleanField_.name = field.name;
      booleanField_.type = field.type;
      booleanField_.id = field.id;
      // Tracked separately rather than via name !is null, since field names
      // may legitimately be null (e.g. TFields returned by readFieldBegin()).
      booleanFieldPending_ = true;
    } else {
      writeFieldBeginInternal(field);
    }
  }
  void writeFieldEnd() {}

  /// Writes the struct stop marker.
  void writeFieldStop() {
    writeByte(TType.STOP);
  }

  void writeListBegin(TList list) {
    writeCollectionBegin(list.elemType, list.size);
  }
  void writeListEnd() {}

  /**
   * Writes a map header: a single zero byte for empty maps, otherwise the
   * size as varint followed by a byte holding the key type (high nibble) and
   * value type (low nibble).
   */
  void writeMapBegin(TMap map) {
    if (map.size == 0) {
      writeByte(0);
    } else {
      assert(map.size <= int.max);
      writeVarint32(cast(int)map.size);
      writeByte(cast(byte)(toCType(map.keyType) << 4 | toCType(map.valueType)));
    }
  }
  void writeMapEnd() {}

  void writeSetBegin(TSet set) {
    writeCollectionBegin(set.elemType, set.size);
  }
  void writeSetEnd() {}


  /*
   * Reading methods.
   */

  /**
   * Reads a bool. If the preceding readFieldBegin() decoded the value from a
   * BOOL field header, that value is returned without touching the transport.
   */
  bool readBool() {
    if (hasBoolValue_) {
      hasBoolValue_ = false;
      return boolValue_;
    }

    return readByte() == CType.BOOLEAN_TRUE;
  }

  /// Reads a single raw byte.
  byte readByte() {
    ubyte[1] b = void;
    trans_.readAll(b);
    return cast(byte)b[0];
  }

  short readI16() {
    return cast(short)zigzagToI32(readVarint32());
  }

  int readI32() {
    return zigzagToI32(readVarint32());
  }

  long readI64() {
    return zigzagToI64(readVarint64());
  }

  /// Reads 8 bytes, converts them via leToHost and reinterprets as double.
  double readDouble() {
    IntBuf!long b = void;
    trans_.readAll(b.bytes);
    b.value = leToHost(b.value);
    return *cast(double*)(&b.value);
  }

  string readString() {
    return cast(string)readBinary();
  }

  /**
   * Reads a varint length prefix and that many bytes.
   *
   * Returns: null for zero-length data.
   * Throws: TProtocolException for negative lengths or lengths above
   *   stringSizeLimit (if positive).
   */
  ubyte[] readBinary() {
    auto size = readVarint32();
    checkSize(size, stringSizeLimit);

    if (size == 0) {
      return null;
    }

    auto buf = uninitializedArray!(ubyte[])(size);
    trans_.readAll(buf);
    return buf;
  }

  /**
   * Reads a message header written by writeMessageBegin().
   *
   * Throws: TProtocolException (BAD_VERSION) if the protocol id or the
   *   version does not match.
   */
  TMessage readMessageBegin() {
    TMessage msg = void;

    auto protocolId = readByte();
    if (protocolId != cast(byte)PROTOCOL_ID) {
      throw new TProtocolException("Bad protocol identifier",
        TProtocolException.Type.BAD_VERSION);
    }

    auto versionAndType = readByte();
    auto ver = versionAndType & VERSION_MASK;
    if (ver != VERSION_N) {
      throw new TProtocolException("Bad protocol version",
        TProtocolException.Type.BAD_VERSION);
    }

    msg.type = cast(TMessageType)((versionAndType >> TYPE_SHIFT_AMOUNT) & TYPE_BITS);
    msg.seqid = readVarint32();
    msg.name = readString();

    return msg;
  }
  void readMessageEnd() {}

  /// Saves the enclosing struct's last field id; field deltas restart at 0.
  TStruct readStructBegin() {
    fieldIdStack_ ~= lastFieldId_;
    lastFieldId_ = 0;
    return TStruct();
  }

  /// Restores the enclosing struct's last field id.
  void readStructEnd() {
    lastFieldId_ = fieldIdStack_[$ - 1];
    fieldIdStack_ = fieldIdStack_[0 .. $ - 1];
    // Mirror writeStructEnd(): allow the next push to reuse the memory.
    fieldIdStack_.assumeSafeAppend();
  }

  /**
   * Reads a field header. The returned TField always has a null name.
   *
   * The header byte holds the type in the low nibble and, if non-zero, the
   * field id delta in the high nibble; a zero delta means the id follows as
   * an i16. For BOOL fields the value is part of the type and is kept for
   * the next readBool() call.
   */
  TField readFieldBegin() {
    TField f = void;
    f.name = null;

    auto bite = readByte();
    auto type = cast(CType)(bite & 0x0f);

    if (type == CType.STOP) {
      // Struct stop byte, nothing more to do.
      f.id = 0;
      f.type = TType.STOP;
      return f;
    }

    // Mask off the 4 MSB of the type header, which could contain a field id
    // delta.
    auto modifier = cast(short)((bite & 0xf0) >> 4);
    if (modifier > 0) {
      f.id = cast(short)(lastFieldId_ + modifier);
    } else {
      // Delta encoding not used, just read the id as usual.
      f.id = readI16();
    }
    f.type = getTType(type);

    if (type == CType.BOOLEAN_TRUE || type == CType.BOOLEAN_FALSE) {
      // For boolean fields, the value is encoded in the type – keep it around
      // for the readBool() call.
      hasBoolValue_ = true;
      boolValue_ = type == CType.BOOLEAN_TRUE;
    }

    lastFieldId_ = f.id;
    return f;
  }
  void readFieldEnd() {}

  TList readListBegin() {
    TList l = void;
    readCollectionBegin(l.elemType, l.size);
    return l;
  }
  void readListEnd() {}

  /**
   * Reads a map header written by writeMapBegin(). For empty maps no type
   * byte is present on the wire, and both types come out as TType.STOP.
   */
  TMap readMapBegin() {
    TMap m = void;

    auto size = readVarint32();
    ubyte kvType;
    if (size != 0) {
      kvType = readByte();
    }
    checkSize(size, containerSizeLimit);

    m.size = size;
    m.keyType = getTType(cast(CType)(kvType >> 4));
    m.valueType = getTType(cast(CType)(kvType & 0xf));

    return m;
  }
  void readMapEnd() {}

  TSet readSetBegin() {
    TSet s = void;
    readCollectionBegin(s.elemType, s.size);
    return s;
  }
  void readSetEnd() {}

private:
  /*
   * Writes a field header, using the delta encoding if the id is larger than
   * the previous one by at most 15. typeOverride (if not -1) replaces the
   * compact type derived from field.type; used for bool fields.
   */
  void writeFieldBeginInternal(TField field, byte typeOverride = -1) {
    // If there's a type override, use that.
    auto typeToWrite = (typeOverride == -1 ? toCType(field.type) : typeOverride);

    // check if we can use delta encoding for the field id
    if (field.id > lastFieldId_ && (field.id - lastFieldId_) <= 15) {
      // write them together
      writeByte(cast(byte)((field.id - lastFieldId_) << 4 | typeToWrite));
    } else {
      // write them separate
      writeByte(cast(byte)typeToWrite);
      writeI16(field.id);
    }

    lastFieldId_ = field.id;
  }

  /*
   * Writes a list/set header: sizes up to 14 share a byte with the element
   * type (size in the high nibble); larger sizes write 0xf there and follow
   * with the size as varint.
   */
  void writeCollectionBegin(TType elemType, size_t size) {
    if (size <= 14) {
      writeByte(cast(byte)(size << 4 | toCType(elemType)));
    } else {
      assert(size <= int.max);
      writeByte(cast(byte)(0xf0 | toCType(elemType)));
      writeVarint32(cast(int)size);
    }
  }

  /*
   * Reads a list/set header written by writeCollectionBegin() and enforces
   * containerSizeLimit.
   */
  void readCollectionBegin(out TType elemType, out size_t size) {
    auto sizeAndType = readByte();

    int lsize = (sizeAndType >> 4) & 0xf;
    if (lsize == 0xf) {
      lsize = readVarint32();
    }
    checkSize(lsize, containerSizeLimit);

    elemType = getTType(cast(CType)(sizeAndType & 0x0f));
    size = cast(size_t)lsize;
  }

  /*
   * Write an i32 as a varint (7 bits per byte, MSB set on all but the last
   * byte). Results in 1-5 bytes on the wire.
   */
  void writeVarint32(uint n) {
    ubyte[5] buf = void;
    ubyte wsize;

    while (true) {
      if ((n & ~0x7F) == 0) {
        buf[wsize++] = cast(ubyte)n;
        break;
      } else {
        buf[wsize++] = cast(ubyte)((n & 0x7F) | 0x80);
        n >>= 7;
      }
    }

    trans_.write(buf[0 .. wsize]);
  }

  /*
   * Write an i64 as a varint. Results in 1-10 bytes on the wire.
   */
  void writeVarint64(ulong n) {
    ubyte[10] buf = void;
    ubyte wsize;

    while (true) {
      if ((n & ~0x7FL) == 0) {
        buf[wsize++] = cast(ubyte)n;
        break;
      } else {
        buf[wsize++] = cast(ubyte)((n & 0x7F) | 0x80);
        n >>= 7;
      }
    }

    trans_.write(buf[0 .. wsize]);
  }

  /*
   * Convert l into a zigzag long. This allows negative numbers to be
   * represented compactly as a varint.
   */
  ulong i64ToZigzag(long l) {
    return (l << 1) ^ (l >> 63);
  }

  /*
   * Convert n into a zigzag int. This allows negative numbers to be
   * represented compactly as a varint.
   */
  uint i32ToZigzag(int n) {
    return (n << 1) ^ (n >> 31);
  }

  /*
   * Maps a Thrift type to its compact wire type. BOOL maps to BOOLEAN_TRUE
   * (used for container element types). VOID is a programming error.
   */
  CType toCType(TType type) {
    final switch (type) {
      case TType.STOP:
        return CType.STOP;
      case TType.BOOL:
        return CType.BOOLEAN_TRUE;
      case TType.BYTE:
        return CType.BYTE;
      case TType.DOUBLE:
        return CType.DOUBLE;
      case TType.I16:
        return CType.I16;
      case TType.I32:
        return CType.I32;
      case TType.I64:
        return CType.I64;
      case TType.STRING:
        return CType.BINARY;
      case TType.STRUCT:
        return CType.STRUCT;
      case TType.MAP:
        return CType.MAP;
      case TType.SET:
        return CType.SET;
      case TType.LIST:
        return CType.LIST;
      case TType.VOID:
        assert(false, "Invalid type passed.");
    }
  }

  /// Reads a varint and truncates it to 32 bits.
  int readVarint32() {
    return cast(int)readVarint64();
  }

  /*
   * Reads a varint of at most 10 bytes.
   *
   * Throws: TProtocolException (INVALID_DATA) if no terminating byte is
   *   found within 10 bytes.
   */
  long readVarint64() {
    ulong val;
    ubyte shift;
    ubyte[10] buf = void;  // 64 bits / (7 bits/byte) = 10 bytes.
    auto bufSize = buf.sizeof;
    auto borrowed = trans_.borrow(buf.ptr, bufSize);

    ubyte rsize;

    if (borrowed) {
      // Fast path: decode from the borrowed bytes, then consume exactly as
      // many as were used.
      while (true) {
        auto bite = borrowed[rsize];
        rsize++;
        val |= cast(ulong)(bite & 0x7f) << shift;
        shift += 7;
        if (!(bite & 0x80)) {
          trans_.consume(rsize);
          return val;
        }
        // Have to check for invalid data so we don't crash.
        if (rsize == buf.sizeof) {
          throw new TProtocolException("Variable-length int over 10 bytes.",
            TProtocolException.Type.INVALID_DATA);
        }
      }
    } else {
      // Slow path: read one byte at a time.
      while (true) {
        ubyte[1] bite;
        trans_.readAll(bite);
        ++rsize;

        val |= cast(ulong)(bite[0] & 0x7f) << shift;
        shift += 7;
        if (!(bite[0] & 0x80)) {
          return val;
        }

        // Might as well check for invalid data on the slow path too.
        if (rsize >= buf.sizeof) {
          throw new TProtocolException("Variable-length int over 10 bytes.",
            TProtocolException.Type.INVALID_DATA);
        }
      }
    }
  }

  /*
   * Convert from zigzag int to int.
   */
  int zigzagToI32(uint n) {
    return (n >> 1) ^ -(n & 1);
  }

  /*
   * Convert from zigzag long to long.
   */
  long zigzagToI64(ulong n) {
    return (n >> 1) ^ -(n & 1);
  }

  /*
   * Maps a compact wire type to the Thrift type.
   *
   * The argument comes from a 4-bit nibble read off the wire, so it may be
   * a value that is not a CType member. Such values are reported as a
   * protocol error instead of failing a final switch.
   */
  TType getTType(CType type) {
    switch (type) {
      case CType.STOP:
        return TType.STOP;
      case CType.BOOLEAN_FALSE, CType.BOOLEAN_TRUE:
        return TType.BOOL;
      case CType.BYTE:
        return TType.BYTE;
      case CType.I16:
        return TType.I16;
      case CType.I32:
        return TType.I32;
      case CType.I64:
        return TType.I64;
      case CType.DOUBLE:
        return TType.DOUBLE;
      case CType.BINARY:
        return TType.STRING;
      case CType.LIST:
        return TType.LIST;
      case CType.SET:
        return TType.SET;
      case CType.MAP:
        return TType.MAP;
      case CType.STRUCT:
        return TType.STRUCT;
      default:
        throw new TProtocolException("Unknown compact protocol type id.",
          TProtocolException.Type.INVALID_DATA);
    }
  }

  /*
   * Throws a TProtocolException if size is negative (NEGATIVE_SIZE) or, when
   * limit is positive, exceeds it (SIZE_LIMIT).
   */
  void checkSize(int size, int limit) {
    if (size < 0) {
      throw new TProtocolException(TProtocolException.Type.NEGATIVE_SIZE);
    } else if (limit > 0 && size > limit) {
      throw new TProtocolException(TProtocolException.Type.SIZE_LIMIT);
    }
  }

  enum PROTOCOL_ID = 0x82;
  enum VERSION_N = 1;
  enum VERSION_MASK = 0b0001_1111;
  enum TYPE_MASK = 0b1110_0000;
  enum TYPE_BITS = 0b0000_0111;
  enum TYPE_SHIFT_AMOUNT = 5;

  // Last field ids of the enclosing structs, pushed on struct begin and
  // popped on struct end so field id deltas are relative to the current
  // struct. Probably need to implement a better stack at some point.
  short[] fieldIdStack_;
  // Id of the last field written/read in the current struct (delta base).
  short lastFieldId_;

  // Header of a BOOL field whose writing is deferred until writeBool().
  TField booleanField_;
  // Whether booleanField_ holds a header that has not been written yet.
  bool booleanFieldPending_;

  // Set by readFieldBegin() for BOOL fields (value encoded in the header);
  // consumed by the next readBool().
  bool hasBoolValue_;
  bool boolValue_;

  Transport trans_;
}
606 
607 /**
608  * TCompactProtocol construction helper to avoid having to explicitly specify
609  * the transport type, i.e. to allow the constructor being called using IFTI
610  * (see $(LINK2 http://d.puremagic.com/issues/show_bug.cgi?id=6082, D Bugzilla
 * enhancement request 6082)).
612  */
TCompactProtocol!Transport tCompactProtocol(Transport)(Transport trans,
  int containerSizeLimit = 0, int stringSizeLimit = 0
) if (isTTransport!Transport)
{
  // Forward everything to the constructor; the only purpose of this
  // function is to let the compiler infer Transport from the argument.
  alias Protocol = TCompactProtocol!Transport;
  auto protocol = new Protocol(trans, containerSizeLimit, stringSizeLimit);
  return protocol;
}
620 
// Type identifiers used on the wire by the compact protocol. They are stored
// in 4-bit nibbles of field and collection header bytes, which the static
// assert below enforces.
private enum CType : ubyte {
  STOP          = 0,
  BOOLEAN_TRUE  = 1,
  BOOLEAN_FALSE = 2,
  BYTE          = 3,
  I16           = 4,
  I32           = 5,
  I64           = 6,
  DOUBLE        = 7,
  BINARY        = 8,
  LIST          = 9,
  SET           = 10,
  MAP           = 11,
  STRUCT        = 12,
}

static assert(CType.max <= 0b1111,
  "Compact protocol wire type representation must fit into 4 bits.");
640 
unittest {
  import std.exception;
  import thrift.transport.memory;

  // Write a CALL header for method "foo" with sequence id 0 and compare the
  // raw bytes that end up in the memory buffer.
  auto mem = new TMemoryBuffer;
  auto proto = tCompactProtocol(mem);
  proto.writeMessageBegin(TMessage("foo", TMessageType.CALL, 0));

  enum ubyte[] expected = [
    130,             // Protocol id (0x82).
    33,              // Version 1 combined with the message type << 5.
    0,               // Sequence id.
    3, 102, 111, 111 // Method name: length 3, then "foo".
  ];
  auto written = new ubyte[expected.length];
  mem.readAll(written);
  enforce(written == expected);
}
659 
unittest {
  // Run the shared size-limit checks against the protocol instantiated with
  // its default transport type.
  import thrift.internal.test.protocol;

  alias Protocol = TCompactProtocol!();
  testContainerSizeLimit!Protocol();
  testStringSizeLimit!Protocol();
}
665 
666 /**
667  * TProtocolFactory creating a TCompactProtocol instance for passed in
668  * transports.
669  *
670  * The optional Transports template tuple parameter can be used to specify
671  * one or more TTransport implementations to specifically instantiate
672  * TCompactProtocol for. If the actual transport types encountered at
673  * runtime match one of the transports in the list, a specialized protocol
674  * instance is created. Otherwise, a generic TTransport version is used.
675  */
class TCompactProtocolFactory(Transports...) if (
  allSatisfy!(isTTransport, Transports)
) : TProtocolFactory {
  /**
   * Constructs a new factory.
   *
   * Params:
   *   containerSizeLimit = Passed on to every protocol created by
   *     getProtocol(); if positive, limits the number of items of
   *     deserialized containers.
   *   stringSizeLimit = Passed on to every protocol created by
   *     getProtocol(); if positive, limits the length of deserialized
   *     strings/binary data.
   */
  this(int containerSizeLimit = 0, int stringSizeLimit = 0) {
    // Previously both limits were hard-coded to 0, silently ignoring the
    // arguments.
    containerSizeLimit_ = containerSizeLimit;
    stringSizeLimit_ = stringSizeLimit;
  }

  /**
   * Creates a TCompactProtocol for trans.
   *
   * The transport types from the Transports list are tried in order; the
   * first one trans can be cast to is used. Otherwise the generic TTransport
   * instantiation is used. The factory's size limits are passed on.
   *
   * Throws: TProtocolException if trans is null.
   */
  TProtocol getProtocol(TTransport trans) const {
    foreach (Transport; TypeTuple!(Transports, TTransport)) {
      auto concreteTrans = cast(Transport)trans;
      if (concreteTrans) {
        return new TCompactProtocol!Transport(concreteTrans,
          containerSizeLimit_, stringSizeLimit_);
      }
    }
    throw new TProtocolException(
      "Passed null transport to TCompactProtocolFactory.");
  }

  /// Container size limit handed to created protocols (0 = no limit).
  int containerSizeLimit_;

  /// String size limit handed to created protocols (0 = no limit).
  int stringSizeLimit_;
}