1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /**
21  * OpenSSL socket implementation, in large parts ported from C++.
22  */
23 module thrift.transport.ssl;
24 
25 import core.exception : onOutOfMemoryError;
26 import core.stdc.errno : errno, EINTR;
27 import core.sync.mutex : Mutex;
28 import core.memory : GC;
29 import core.stdc.config;
30 import core.stdc.stdlib : free, malloc;
31 import std.ascii : toUpper;
32 import std.array : empty, front, popFront;
33 import std.conv : emplace, to;
34 import std.exception : enforce;
35 import std.socket : Address, InternetAddress, Internet6Address, Socket;
36 import std..string : toStringz;
37 import deimos.openssl.err;
38 import deimos.openssl.rand;
39 import deimos.openssl.ssl;
40 import deimos.openssl.x509v3;
41 import thrift.base;
42 import thrift.internal.ssl;
43 import thrift.transport.base;
44 import thrift.transport.socket;
45 
46 /**
47  * SSL encrypted socket implementation using OpenSSL.
48  *
49  * Note:
50  * On Posix systems which do not have the BSD-specific SO_NOSIGPIPE flag, you
51  * might want to ignore the SIGPIPE signal, as OpenSSL might try to write to
52  * a closed socket if the peer disconnects abruptly:
53  * ---
54  * import core.stdc.signal;
55  * import core.sys.posix.signal;
56  * signal(SIGPIPE, SIG_IGN);
57  * ---
58  */
final class TSSLSocket : TSocket {
  /**
   * Creates an instance that wraps an already created, connected (!) socket.
   *
   * Params:
   *   context = The SSL socket context to use. A reference to it is stored so
   *     that it doesn't get cleaned up while the socket is used.
   *   socket = Already created, connected socket object.
   */
  this(TSSLContext context, Socket socket) {
    super(socket);
    context_ = context;
    serverSide_ = context.serverSide;
    accessManager_ = context.accessManager;
  }

  /**
   * Creates a new unconnected socket that will connect to the given host
   * on the given port.
   *
   * Params:
   *   context = The SSL socket context to use. A reference to it is stored so
   *     that it doesn't get cleaned up while the socket is used.
   *   host = Remote host.
   *   port = Remote port.
   */
  this(TSSLContext context, string host, ushort port) {
    super(host, port);
    context_ = context;
    serverSide_ = context.serverSide;
    accessManager_ = context.accessManager;
  }

  /**
   * Whether the SSL session is usable.
   *
   * True only if a handshaken SSL object exists, the underlying socket is
   * open, and the SSL shutdown has not been both sent and received. Note that
   * this is false before the first I/O call performs the handshake.
   */
  override bool isOpen() @property {
    if (ssl_ is null || !super.isOpen()) return false;

    auto shutdown = SSL_get_shutdown(ssl_);
    bool shutdownReceived = (shutdown & SSL_RECEIVED_SHUTDOWN) != 0;
    bool shutdownSent = (shutdown & SSL_SENT_SHUTDOWN) != 0;
    return !(shutdownReceived && shutdownSent);
  }

  /**
   * Returns whether at least one byte of application data can be read, as
   * reported by SSL_peek. Returns false if the session is not open.
   */
  override bool peek() {
    if (!isOpen) return false;
    checkHandshake();

    byte bt;
    auto rc = SSL_peek(ssl_, &bt, bt.sizeof);
    enforce(rc >= 0, getSSLException("SSL_peek"));

    if (rc == 0) {
      ERR_clear_error();
    }
    return (rc > 0);
  }

  /**
   * Connects the underlying socket. The SSL handshake is deferred until the
   * first I/O call. Only valid for client-side sockets.
   */
  override void open() {
    enforce(!serverSide_, "Cannot open a server-side SSL socket.");
    if (isOpen) return;
    super.open();
  }

  /**
   * Shuts down the SSL session (if any), frees the SSL object, and closes
   * the underlying socket.
   *
   * Shutdown errors are logged, not thrown. The SSL object is released and
   * super.close() is called even when no handshake took place or the session
   * was already shut down by both sides.
   */
  override void close() {
    if (ssl_ !is null) {
      // Only attempt the close_notify exchange while the session is still
      // live; otherwise just release the SSL object below.
      if (isOpen) {
        // Two-step SSL shutdown.
        auto rc = SSL_shutdown(ssl_);
        if (rc == 0) {
          rc = SSL_shutdown(ssl_);
        }
        if (rc < 0) {
          // Do not throw an exception here as leaving the transport "open" will
          // probably produce only more errors, and the chance we can do
          // something about the error e.g. by retrying is very low.
          logError("Error shutting down SSL: %s", getSSLException());
        }
      }

      SSL_free(ssl_);
      ssl_ = null;
      // Same version guard as in TSSLContext.cleanupOpenSSL().
      static if (OPENSSL_VERSION_NUMBER < 0x1010000f) {
        ERR_remove_state(0);
      }
    }
    super.close();
  }

  /**
   * Reads up to buf.length decrypted bytes, performing the handshake first
   * if necessary.
   *
   * SSL_read is retried up to maxRecvRetries times when it fails with
   * SSL_ERROR_SYSCALL caused by EINTR. Any other failure, or running out of
   * retries, throws.
   *
   * Returns: The number of bytes read (0 as returned by SSL_read).
   */
  override size_t read(ubyte[] buf) {
    checkHandshake();

    // SSL_read takes an int length; larger buffers are filled partially.
    immutable len = buf.length > int.max ? int.max : cast(int)buf.length;
    foreach (_; 0 .. maxRecvRetries) {
      immutable bytes = SSL_read(ssl_, buf.ptr, len);
      if (bytes >= 0) return bytes;

      auto errnoCopy = errno;
      if (SSL_get_error(ssl_, bytes) == SSL_ERROR_SYSCALL) {
        if (ERR_get_error() == 0 && errnoCopy == EINTR) {
          // FIXME: Windows.
          continue;
        }
      }
      throw getSSLException("SSL_read");
    }
    // Every attempt was interrupted. Report the failure instead of returning
    // a negative count, which would turn into a huge size_t.
    throw getSSLException("SSL_read");
  }

  /**
   * Writes all of buf, performing the handshake first if necessary.
   *
   * Loops so that partial writes (SSL_MODE_ENABLE_PARTIAL_WRITE) and buffers
   * longer than int.max are handled; throws if SSL_write reports failure.
   */
  override void write(in ubyte[] buf) {
    checkHandshake();

    size_t written = 0;
    while (written < buf.length) {
      immutable remaining = buf.length - written;
      // SSL_write takes an int length; send oversized buffers in chunks.
      immutable chunk = remaining > int.max ? int.max : cast(int)remaining;
      immutable bytes = SSL_write(ssl_, buf.ptr + written, chunk);
      if (bytes <= 0) {
        throw getSSLException("SSL_write");
      }
      written += bytes;
    }
  }

  /**
   * Flushes the SSL write BIO, performing the handshake first if necessary.
   */
  override void flush() {
    checkHandshake();

    auto bio = SSL_get_wbio(ssl_);
    enforce(bio !is null, new TSSLException("SSL_get_wbio returned null"));

    auto rc = BIO_flush(bio);
    enforce(rc == 1, getSSLException("BIO_flush"));
  }

  /**
   * Whether to use client or server side SSL handshake protocol.
   */
  bool serverSide() @property const {
    return serverSide_;
  }

  /// Ditto
  void serverSide(bool value) @property {
    serverSide_ = value;
  }

  /**
   * The access manager to use.
   */
  void accessManager(TAccessManager value) @property {
    accessManager_ = value;
  }

private:
  /**
   * Ensures the underlying socket is open and, on first use, performs the SSL
   * handshake (accept or connect, depending on serverSide_) followed by
   * authorization of the peer.
   *
   * ssl_ is only assigned once both the handshake and the authorization have
   * succeeded. On failure the half-initialized SSL object is freed. This way
   * a caller that retries after an authorization failure can never reach
   * application data with an unauthorized session.
   */
  void checkHandshake() {
    enforce(super.isOpen(), new TTransportException(
      TTransportException.Type.NOT_OPEN));

    if (ssl_ !is null) return;

    auto ssl = context_.createSSL();
    scope (failure) SSL_free(ssl);

    enforce(SSL_set_fd(ssl, cast(int)socketHandle) == 1,
      getSSLException("SSL_set_fd"));

    int rc;
    if (serverSide_) {
      rc = SSL_accept(ssl);
    } else {
      rc = SSL_connect(ssl);
    }
    enforce(rc > 0, getSSLException());

    auto peerAddress = getPeerAddress();
    authorize(ssl, accessManager_, peerAddress,
      (serverSide_ ? peerAddress.toHostNameString() : host));

    ssl_ = ssl;
  }

  bool serverSide_;
  SSL* ssl_;
  TSSLContext context_;
  TAccessManager accessManager_;
}
233 
234 /**
235  * Represents an OpenSSL context with certification settings, etc. and handles
236  * initialization/teardown.
237  *
238  * OpenSSL is initialized when the first instance of this class is created
239  * and shut down when the last one is destroyed (thread-safe).
240  */
class TSSLContext {
  /**
   * Creates a context backed by a new OpenSSL SSL_CTX.
   *
   * If no other instance is currently alive, OpenSSL is initialized first and
   * randomize() is called. SSLv3 (and, for OpenSSL < 1.1, SSLv2) is disabled
   * and SSL_MODE_AUTO_RETRY is enabled on the new context.
   */
  this() {
    initMutex_.lock();
    scope(exit) initMutex_.unlock();

    if (count_ == 0) {
      initializeOpenSSL();
      randomize();
    }
    count_++;

    static if (OPENSSL_VERSION_NUMBER >= 0x1010000f) { // OPENSSL_VERSION_AT_LEAST(1, 1)) {
      ctx_ = SSL_CTX_new(TLS_method());
    } else {
      ctx_ = SSL_CTX_new(SSLv23_method());
    }
    // Fail before touching the context: the option/mode setters below must
    // never be handed a null SSL_CTX.
    enforce(ctx_, getSSLException("SSL_CTX_new"));

    static if (OPENSSL_VERSION_NUMBER < 0x1010000f) {
      SSL_CTX_set_options(ctx_, SSL_OP_NO_SSLv2);
    }
    SSL_CTX_set_options(ctx_, SSL_OP_NO_SSLv3);   // THRIFT-3164
    SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
  }

  /**
   * Frees the SSL_CTX and cleans up OpenSSL once the last instance is gone.
   */
  ~this() {
    initMutex_.lock();
    scope(exit) initMutex_.unlock();

    if (ctx_ !is null) {
      SSL_CTX_free(ctx_);
      ctx_ = null;
    }

    count_--;
    if (count_ == 0) {
      cleanupOpenSSL();
    }
  }

  /**
   * Ciphers to be used in SSL handshake process.
   *
   * The string must be in the colon-delimited OpenSSL notation described in
   * ciphers(1), for example: "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH".
   *
   * Throws: An SSL exception if OpenSSL queued an error while parsing the
   *   list, or TSSLException if none of the ciphers are supported.
   */
  void ciphers(string enable) @property {
    // Discard stale errors left by unrelated calls on this thread, so that
    // ERR_peek_error() below only reflects SSL_CTX_set_cipher_list itself.
    ERR_clear_error();
    auto rc = SSL_CTX_set_cipher_list(ctx_, toStringz(enable));

    enforce(ERR_peek_error() == 0, getSSLException("SSL_CTX_set_cipher_list"));
    enforce(rc > 0, new TSSLException("None of specified ciphers are supported"));
  }

  /**
   * Whether peer is required to present a valid certificate.
   *
   * If true, SSL_VERIFY_PEER, SSL_VERIFY_FAIL_IF_NO_PEER_CERT and
   * SSL_VERIFY_CLIENT_ONCE are set; otherwise SSL_VERIFY_NONE.
   */
  void authenticate(bool required) @property {
    int mode;
    if (required) {
      mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT |
        SSL_VERIFY_CLIENT_ONCE;
    } else {
      mode = SSL_VERIFY_NONE;
    }
    SSL_CTX_set_verify(ctx_, mode, null);
  }

  /**
   * Load server certificate.
   *
   * Params:
   *   path = Path to the certificate (chain) file.
   *   format = Certificate file format. Defaults to PEM, which is currently
   *     the only one supported.
   */
  void loadCertificate(string path, string format = "PEM") {
    enforce(path !is null && format !is null, new TTransportException(
      "loadCertificate: either <path> or <format> is null",
      TTransportException.Type.BAD_ARGS));

    if (format == "PEM") {
      enforce(SSL_CTX_use_certificate_chain_file(ctx_, toStringz(path)),
        getSSLException(
          `Could not load SSL server certificate from file "` ~ path ~ `"`
        )
      );
    } else {
      throw new TSSLException("Unsupported certificate format: " ~ format);
    }
  }

  /**
   * Load private key.
   *
   * Params:
   *   path = Path to the private key file.
   *   format = Private key file format. Defaults to PEM, which is currently
   *     the only one supported.
   */
  void loadPrivateKey(string path, string format = "PEM") {
    enforce(path !is null && format !is null, new TTransportException(
      "loadPrivateKey: either <path> or <format> is NULL",
      TTransportException.Type.BAD_ARGS));

    if (format == "PEM") {
      enforce(SSL_CTX_use_PrivateKey_file(ctx_, toStringz(path), SSL_FILETYPE_PEM),
        getSSLException(
          `Could not load SSL private key from file "` ~ path ~ `"`
        )
      );
    } else {
      throw new TSSLException("Unsupported private key format: " ~ format);
    }
  }

  /**
   * Load trusted certificates from specified file (in PEM format).
   *
   * Params:
   *   path = Path to the file containing the trusted certificates.
   */
  void loadTrustedCertificates(string path) {
    enforce(path !is null, new TTransportException(
      "loadTrustedCertificates: <path> is NULL",
      TTransportException.Type.BAD_ARGS));

    enforce(SSL_CTX_load_verify_locations(ctx_, toStringz(path), null),
      getSSLException(
        `Could not load SSL trusted certificate list from file "` ~ path ~ `"`
      )
    );
  }

  /**
   * Called during OpenSSL initialization to seed the OpenSSL entropy pool.
   *
   * Defaults to simply calling RAND_poll(), but it can be overridden if a
   * different, perhaps more secure implementation is desired.
   */
  void randomize() {
    RAND_poll();
  }

  /**
   * Whether to use client or server side SSL handshake protocol.
   */
  bool serverSide() @property const {
    return serverSide_;
  }

  /// Ditto
  void serverSide(bool value) @property {
    serverSide_ = value;
  }

  /**
   * The access manager to use.
   *
   * For client-side contexts without an explicitly set manager, a
   * TDefaultClientAccessManager is created on first access.
   */
  TAccessManager accessManager() @property {
    if (!serverSide_ && !accessManager_) {
      accessManager_ = new TDefaultClientAccessManager;
    }
    return accessManager_;
  }

  /// Ditto
  void accessManager(TAccessManager value) @property {
    accessManager_ = value;
  }

  /**
   * Creates a new SSL object from this context.
   *
   * Throws: The exception built by getSSLException if SSL_new fails.
   */
  SSL* createSSL() out (result) {
    assert(result);
  } body {
    auto result = SSL_new(ctx_);
    enforce(result, getSSLException("SSL_new"));
    return result;
  }

protected:
  /**
   * Override this method for custom password callback. It may be called
   * multiple times at any time during a session as necessary.
   *
   * Params:
   *   size = Maximum length of password, including null byte.
   */
  string getPassword(int size) nothrow out(result) {
    assert(result.length < size);
  } body {
    return "";
  }

  /**
   * Makes OpenSSL obtain passwords through getPassword() instead of its
   * default password callback.
   */
  void overrideDefaultPasswordCallback() {
    SSL_CTX_set_default_passwd_cb(ctx_, &passwordCallback);
    SSL_CTX_set_default_passwd_cb_userdata(ctx_, cast(void*)this);
  }

  SSL_CTX* ctx_;

private:
  bool serverSide_;
  TAccessManager accessManager_;

  shared static this() {
    initMutex_ = new Mutex();
  }

  /// Global OpenSSL setup; caller must hold initMutex_.
  static void initializeOpenSSL() {
    if (initialized_) {
      return;
    }
    initialized_ = true;

    static if (OPENSSL_VERSION_NUMBER < 0x1010000f) { // OPENSSL_VERSION_BEFORE(1, 1)) {
      SSL_library_init();
      SSL_load_error_strings();

      // Pre-1.1 OpenSSL needs application-provided locks for thread safety.
      mutexes_ = new Mutex[CRYPTO_num_locks()];
      foreach (ref m; mutexes_) {
        m = new Mutex;
      }

      import thrift.internal.traits;
      // As per the OpenSSL threads manpage, this isn't needed on Windows.
      version (Posix) {
        CRYPTO_set_id_callback(assumeNothrow(&threadIdCallback));
      }
      CRYPTO_set_locking_callback(assumeNothrow(&lockingCallback));
      CRYPTO_set_dynlock_create_callback(assumeNothrow(&dynlockCreateCallback));
      CRYPTO_set_dynlock_lock_callback(assumeNothrow(&dynlockLockCallback));
      CRYPTO_set_dynlock_destroy_callback(assumeNothrow(&dynlockDestroyCallback));
    }
  }

  /// Global OpenSSL teardown; caller must hold initMutex_.
  static void cleanupOpenSSL() {
    if (!initialized_) return;

    initialized_ = false;
    static if (OPENSSL_VERSION_NUMBER < 0x1010000f) { // OPENSSL_VERSION_BEFORE(1, 1)) {
      // Unregister every callback installed by initializeOpenSSL().
      version (Posix) {
        CRYPTO_set_id_callback(null);
      }
      CRYPTO_set_locking_callback(null);
      CRYPTO_set_dynlock_create_callback(null);
      CRYPTO_set_dynlock_lock_callback(null);
      CRYPTO_set_dynlock_destroy_callback(null);
      CRYPTO_cleanup_all_ex_data();
      ERR_free_strings();
      ERR_remove_state(0);

      // OpenSSL no longer references the static locks; let the GC reclaim them.
      mutexes_ = null;
    }
  }

  static extern(C) {
    version (Posix) {
      import core.sys.posix.pthread : pthread_self;
      c_ulong threadIdCallback() {
        return cast(c_ulong)pthread_self();
      }
    }

    void lockingCallback(int mode, int n, const(char)* file, int line) {
      if (mode & CRYPTO_LOCK) {
        mutexes_[n].lock();
      } else {
        mutexes_[n].unlock();
      }
    }

    CRYPTO_dynlock_value* dynlockCreateCallback(const(char)* file, int line) {
      enum size =  __traits(classInstanceSize, Mutex);
      // The Mutex lives in malloc'd memory (freed in dynlockDestroyCallback),
      // registered with the GC so its internal references are scanned.
      auto ptr = malloc(size);
      if (ptr is null) onOutOfMemoryError();
      auto mem = ptr[0 .. size];
      GC.addRange(ptr, size);
      auto mutex = emplace!Mutex(mem);
      return cast(CRYPTO_dynlock_value*)mutex;
    }

    void dynlockLockCallback(int mode, CRYPTO_dynlock_value* l,
      const(char)* file, int line)
    {
      if (l is null) return;
      if (mode & CRYPTO_LOCK) {
        (cast(Mutex)l).lock();
      } else {
        (cast(Mutex)l).unlock();
      }
    }

    void dynlockDestroyCallback(CRYPTO_dynlock_value* l,
      const(char)* file, int line)
    {
      GC.removeRange(l);
      destroy(cast(Mutex)l);
      free(l);
    }

    /**
     * OpenSSL password callback forwarding to TSSLContext.getPassword().
     *
     * Copies at most `size` characters into `password` and returns the number
     * copied. No NUL terminator is written. NOTE(review): this relies on
     * OpenSSL using the returned length; confirm against the pem_password_cb
     * documentation.
     */
    int passwordCallback(char* password, int size, int, void* data) nothrow {
      if (password is null || size <= 0) return 0;

      auto context = cast(TSSLContext) data;
      auto userPassword = context.getPassword(size);
      // getPassword's contract asks for fewer than `size` characters; clamp
      // anyway so an override cannot overflow OpenSSL's buffer.
      immutable limit = cast(size_t)size;
      immutable len = userPassword.length > limit ? limit : userPassword.length;
      password[0 .. len] = userPassword[0 .. len];
      return cast(int)len;
    }
  }

  static __gshared bool initialized_;
  static __gshared Mutex initMutex_;
  static __gshared Mutex[] mutexes_;
  static __gshared uint count_;
}
552 
553 /**
554  * Decides whether a remote host is legitimate or not.
555  *
556  * It is usually set at a TSSLContext, which then passes it to all the created
557  * TSSLSockets.
558  */
class TAccessManager {
  /// Outcome of a single access check.
  enum Decision {
    DENY = -1, /// Deny access.
    SKIP =  0, /// Cannot decide, move on to next check (deny if last).
    ALLOW = 1  /// Allow access.
  }

  /**
   * Determines whether a peer should be granted access or not based on its
   * IP address.
   *
   * Called once after the SSL handshake completes successfully and before the
   * peer certificate is examined.
   *
   * If a valid decision (ALLOW or DENY) is returned, the peer certificate
   * will not be verified.
   *
   * Params:
   *   address = The actual address from the socket connection.
   *
   * Returns: Decision.DENY in this base implementation; override it to
   *   implement an actual policy.
   */
  Decision verify(Address address) {
    return Decision.DENY;
  }

  /**
   * Determines whether a peer should be granted access or not based on a
   * name from its certificate.
   *
   * Called every time a DNS subjectAltName/common name is extracted from the
   * peer's certificate.
   *
   * Params:
   *   host = The actual host name string from the socket connection.
   *   certHost = A host name string from the certificate.
   *
   * Returns: Decision.DENY in this base implementation.
   */
  Decision verify(string host, const(char)[] certHost) {
    return Decision.DENY;
  }

  /**
   * Determines whether a peer should be granted access or not based on an IP
   * address from its certificate.
   *
   * Called every time an IP subjectAltName is extracted from the peer's
   * certificate.
   *
   * Params:
   *   address = The actual address from the socket connection.
   *   certAddress = The raw IP address bytes from the certificate.
   *
   * Returns: Decision.DENY in this base implementation.
   */
  Decision verify(Address address, ubyte[] certAddress) {
    return Decision.DENY;
  }
}
611 
/**
 * Default access manager implementation, which just checks the host name
 * or IP address of the connection against the certificate.
 */
class TDefaultClientAccessManager : TAccessManager {
  /// Never decides on the connection address alone; always returns SKIP.
  override Decision verify(Address address) {
    return Decision.SKIP;
  }

  /**
   * Returns ALLOW if certHost (which may contain "*" wildcards) matches the
   * host name of the connection, SKIP otherwise or if either name is empty.
   */
  override Decision verify(string host, const(char)[] certHost) {
    if (host.empty || certHost.empty) {
      return Decision.SKIP;
    }
    return (matchName(host, certHost) ? Decision.ALLOW : Decision.SKIP);
  }

  /**
   * Returns ALLOW if the IP address from the certificate equals the address
   * of the connection, SKIP otherwise.
   *
   * certAddress is presumably the raw iPAddress subjectAltName value (4 bytes
   * for IPv4, 16 for IPv6). RFC 5280 stores it in network byte order.
   */
  override Decision verify(Address address, ubyte[] certAddress) {
    bool match;
    if (certAddress.length == 4) {
      if (auto ia = cast(InternetAddress)address) {
        // InternetAddress.addr() yields the IPv4 address as a host-order
        // integer (not a pointer), so serialize it to network byte order
        // before comparing it with the certificate bytes.
        immutable uint a = ia.addr();
        immutable ubyte[4] networkOrder = [
          cast(ubyte)(a >>> 24), cast(ubyte)(a >>> 16),
          cast(ubyte)(a >>> 8), cast(ubyte)a,
        ];
        match = (networkOrder[] == certAddress[]);
      }
    } else if (certAddress.length == 16) {
      if (auto ia = cast(Internet6Address)address) {
        match = (ia.addr() == certAddress[]);
      }
    }
    return (match ? Decision.ALLOW : Decision.SKIP);
  }
}
642 
private {
  /**
   * Checks a host name against a (possibly wildcarded) certificate name.
   *
   * Characters are compared ASCII case-insensitively. A "*" in the pattern
   * swallows host characters up to, but not including, the next '.' (or the
   * end of the host). It therefore stands for at most one domain component,
   * which may be empty.
   *
   * Params:
   *   host = Host name to match, typically the SSL remote peer.
   *   pattern = Host name pattern, typically from the SSL certificate.
   *
   * Returns: true if host matches pattern, false otherwise.
   */
  bool matchName(const(char)[] host, const(char)[] pattern) {
    for (; !host.empty && !pattern.empty; ) {
      immutable p = pattern.front;
      if (toUpper(p) == toUpper(host.front)) {
        host.popFront();
        pattern.popFront();
        continue;
      }
      if (p != '*') break;

      // Wildcard: consume the rest of the current host label.
      while (!host.empty && host.front != '.') host.popFront();
      pattern.popFront();
    }
    return host.empty && pattern.empty;
  }

  unittest {
    static struct Case { string host, pattern; bool expected; }
    foreach (c; [
      Case("thrift.apache.org", "*.apache.org", true),
      Case("thrift.apache.org", "apache.org", false),
      Case("thrift.apache.org", "thrift.*.*", true),
      Case("", "", true),
      Case("", "*", false),
    ]) {
      enforce(matchName(c.host, c.pattern) == c.expected);
    }
  }
}
679 
680 /**
681  * SSL-level exception.
682  */
class TSSLException : TTransportException {
  /**
   * Creates an SSL-level transport exception. Its transport error type is
   * always INTERNAL_ERROR.
   *
   * Params:
   *   msg = Human-readable description of the failure.
   *   file = Source file of the throw site (defaults to the caller's).
   *   line = Source line of the throw site (defaults to the caller's).
   *   next = Optional chained exception.
   */
  this(string msg, string file = __FILE__, size_t line = __LINE__, Throwable next = null) {
    enum errorType = TTransportException.Type.INTERNAL_ERROR;
    super(msg, errorType, file, line, next);
  }
}