diff options
author | David Benjamin <davidben@google.com> | 2016-02-11 20:02:01 +0300 |
---|---|---|
committer | Adam Langley <agl@google.com> | 2016-02-18 01:05:29 +0300 |
commit | de9423821705341a2be0ae98e8aab8c3a06887a3 (patch) | |
tree | c816280f3273c55ba1d637a5ac907b4f245a9432 /ssl | |
parent | d3a49953d884d25eee0af77030ec205ef0e0a6a1 (diff) |
Fix SSL_get_{read,write}_sequence.
I switched up the endianness. Add some tests to make sure those work right.
Also tweak the DTLS semantics. SSL_get_read_sequence should return the highest
sequence number received so far. Include the epoch number in both so we don't
need a second API for it.
Change-Id: I9901a1665b41224c46fadb7ce0b0881dcb466bcc
Reviewed-on: https://boringssl-review.googlesource.com/7141
Reviewed-by: Adam Langley <agl@google.com>
Diffstat (limited to 'ssl')
-rw-r--r-- | ssl/ssl_lib.c | 20 | ||||
-rw-r--r-- | ssl/ssl_test.cc | 158 |
2 files changed, 172 insertions, 6 deletions
diff --git a/ssl/ssl_lib.c b/ssl/ssl_lib.c index 542dc17e..0a3c8f71 100644 --- a/ssl/ssl_lib.c +++ b/ssl/ssl_lib.c @@ -2545,19 +2545,29 @@ int SSL_get_ivs(const SSL *ssl, const uint8_t **out_read_iv, } static uint64_t be_to_u64(const uint8_t in[8]) { - return (((uint64_t)in[7]) << 56) | (((uint64_t)in[6]) << 48) | - (((uint64_t)in[5]) << 40) | (((uint64_t)in[4]) << 32) | - (((uint64_t)in[3]) << 24) | (((uint64_t)in[2]) << 16) | - (((uint64_t)in[1]) << 8) | ((uint64_t)in[0]); + return (((uint64_t)in[0]) << 56) | (((uint64_t)in[1]) << 48) | + (((uint64_t)in[2]) << 40) | (((uint64_t)in[3]) << 32) | + (((uint64_t)in[4]) << 24) | (((uint64_t)in[5]) << 16) | + (((uint64_t)in[6]) << 8) | ((uint64_t)in[7]); } uint64_t SSL_get_read_sequence(const SSL *ssl) { /* TODO(davidben): Internally represent sequence numbers as uint64_t. */ + if (SSL_IS_DTLS(ssl)) { + /* max_seq_num already includes the epoch. */ + assert(ssl->d1->r_epoch == (ssl->d1->bitmap.max_seq_num >> 48)); + return ssl->d1->bitmap.max_seq_num; + } return be_to_u64(ssl->s3->read_sequence); } uint64_t SSL_get_write_sequence(const SSL *ssl) { - return be_to_u64(ssl->s3->write_sequence); + uint64_t ret = be_to_u64(ssl->s3->write_sequence); + if (SSL_IS_DTLS(ssl)) { + assert((ret >> 48) == 0); + ret |= ((uint64_t)ssl->d1->w_epoch) << 48; + } + return ret; } uint8_t SSL_get_server_key_exchange_hash(const SSL *ssl) { diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc index f91542dc..ed1d3c7d 100644 --- a/ssl/ssl_test.cc +++ b/ssl/ssl_test.cc @@ -25,7 +25,9 @@ #include <openssl/bio.h> #include <openssl/crypto.h> #include <openssl/err.h> +#include <openssl/pem.h> #include <openssl/ssl.h> +#include <openssl/x509.h> #include "test/scoped_types.h" #include "../crypto/test/test_util.h" @@ -978,6 +980,158 @@ static bool TestInternalSessionCache() { return true; } +static uint16_t EpochFromSequence(uint64_t seq) { + return static_cast<uint16_t>(seq >> 48); +} + +static ScopedX509 GetTestCertificate() { + static const char kCertPEM[] = + "-----BEGIN CERTIFICATE-----\n" + "MIICWDCCAcGgAwIBAgIJAPuwTC6rEJsMMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV\n" + "BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX\n" + "aWRnaXRzIFB0eSBMdGQwHhcNMTQwNDIzMjA1MDQwWhcNMTcwNDIyMjA1MDQwWjBF\n" + "MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50\n" + "ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB\n" + "gQDYK8imMuRi/03z0K1Zi0WnvfFHvwlYeyK9Na6XJYaUoIDAtB92kWdGMdAQhLci\n" + "HnAjkXLI6W15OoV3gA/ElRZ1xUpxTMhjP6PyY5wqT5r6y8FxbiiFKKAnHmUcrgfV\n" + "W28tQ+0rkLGMryRtrukXOgXBv7gcrmU7G1jC2a7WqmeI8QIDAQABo1AwTjAdBgNV\n" + "HQ4EFgQUi3XVrMsIvg4fZbf6Vr5sp3Xaha8wHwYDVR0jBBgwFoAUi3XVrMsIvg4f\n" + "Zbf6Vr5sp3Xaha8wDAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQUFAAOBgQA76Hht\n" + "ldY9avcTGSwbwoiuIqv0jTL1fHFnzy3RHMLDh+Lpvolc5DSrSJHCP5WuK0eeJXhr\n" + "T5oQpHL9z/cCDLAKCKRa4uV0fhEdOWBqyR9p8y5jJtye72t6CuFUV5iqcpF4BH4f\n" + "j2VNHwsSrJwkD4QUGlUtH7vwnQmyCFxZMmWAJg==\n" + "-----END CERTIFICATE-----\n"; + ScopedBIO bio( + BIO_new_mem_buf(const_cast<char *>(kCertPEM), strlen(kCertPEM))); + return ScopedX509(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr)); +} + +static ScopedEVP_PKEY GetTestKey() { + static const char kKeyPEM[] = + "-----BEGIN RSA PRIVATE KEY-----\n" + "MIICXgIBAAKBgQDYK8imMuRi/03z0K1Zi0WnvfFHvwlYeyK9Na6XJYaUoIDAtB92\n" + "kWdGMdAQhLciHnAjkXLI6W15OoV3gA/ElRZ1xUpxTMhjP6PyY5wqT5r6y8FxbiiF\n" + "KKAnHmUcrgfVW28tQ+0rkLGMryRtrukXOgXBv7gcrmU7G1jC2a7WqmeI8QIDAQAB\n" + "AoGBAIBy09Fd4DOq/Ijp8HeKuCMKTHqTW1xGHshLQ6jwVV2vWZIn9aIgmDsvkjCe\n" + "i6ssZvnbjVcwzSoByhjN8ZCf/i15HECWDFFh6gt0P5z0MnChwzZmvatV/FXCT0j+\n" + "WmGNB/gkehKjGXLLcjTb6dRYVJSCZhVuOLLcbWIV10gggJQBAkEA8S8sGe4ezyyZ\n" + "m4e9r95g6s43kPqtj5rewTsUxt+2n4eVodD+ZUlCULWVNAFLkYRTBCASlSrm9Xhj\n" + "QpmWAHJUkQJBAOVzQdFUaewLtdOJoPCtpYoY1zd22eae8TQEmpGOR11L6kbxLQsk\n" + "aMly/DOnOaa82tqAGTdqDEZgSNmCeKKknmECQAvpnY8GUOVAubGR6c+W90iBuQLj\n" + "LtFp/9ihd2w/PoDwrHZaoUYVcT4VSfJQog/k7kjE4MYXYWL8eEKg3WTWQNECQQDk\n" + "104Wi91Umd1PzF0ijd2jXOERJU1wEKe6XLkYYNHWQAe5l4J4MWj9OdxFXAxIuuR/\n" + "tfDwbqkta4xcux67//khAkEAvvRXLHTaa6VFzTaiiO8SaFsHV3lQyXOtMrBpB5jd\n" + "moZWgjHvB2W9Ckn7sDqsPB+U2tyX0joDdQEyuiMECDY8oQ==\n" + "-----END RSA PRIVATE KEY-----\n"; + ScopedBIO bio(BIO_new_mem_buf(const_cast<char *>(kKeyPEM), strlen(kKeyPEM))); + return ScopedEVP_PKEY( + PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr)); +} + +static bool TestSequenceNumber(bool dtls) { + ScopedSSL_CTX client_ctx(SSL_CTX_new(dtls ? DTLS_method() : TLS_method())); + ScopedSSL_CTX server_ctx(SSL_CTX_new(dtls ? DTLS_method() : TLS_method())); + if (!client_ctx || !server_ctx) { + return false; + } + + ScopedX509 cert = GetTestCertificate(); + ScopedEVP_PKEY key = GetTestKey(); + if (!cert || !key || + !SSL_CTX_use_certificate(server_ctx.get(), cert.get()) || + !SSL_CTX_use_PrivateKey(server_ctx.get(), key.get())) { + return false; + } + + // Create a client and server connected to each other. + ScopedSSL client(SSL_new(client_ctx.get())), server(SSL_new(server_ctx.get())); + if (!client || !server) { + return false; + } + SSL_set_connect_state(client.get()); + SSL_set_accept_state(server.get()); + + BIO *bio1, *bio2; + if (!BIO_new_bio_pair(&bio1, 0, &bio2, 0)) { + return false; + } + // SSL_set_bio takes ownership. + SSL_set_bio(client.get(), bio1, bio1); + SSL_set_bio(server.get(), bio2, bio2); + + // Drive both their handshakes to completion. + for (;;) { + int client_ret = SSL_do_handshake(client.get()); + int client_err = SSL_get_error(client.get(), client_ret); + if (client_err != SSL_ERROR_NONE && + client_err != SSL_ERROR_WANT_READ && + client_err != SSL_ERROR_WANT_WRITE) { + fprintf(stderr, "Client error: %d\n", client_err); + return false; + } + + int server_ret = SSL_do_handshake(server.get()); + int server_err = SSL_get_error(server.get(), server_ret); + if (server_err != SSL_ERROR_NONE && + server_err != SSL_ERROR_WANT_READ && + server_err != SSL_ERROR_WANT_WRITE) { + fprintf(stderr, "Server error: %d\n", server_err); + return false; + } + + if (client_ret == 1 && server_ret == 1) { + break; + } + } + + uint64_t client_read_seq = SSL_get_read_sequence(client.get()); + uint64_t client_write_seq = SSL_get_write_sequence(client.get()); + uint64_t server_read_seq = SSL_get_read_sequence(server.get()); + uint64_t server_write_seq = SSL_get_write_sequence(server.get()); + + if (dtls) { + // Both client and server must be at epoch 1. + if (EpochFromSequence(client_read_seq) != 1 || + EpochFromSequence(client_write_seq) != 1 || + EpochFromSequence(server_read_seq) != 1 || + EpochFromSequence(server_write_seq) != 1) { + fprintf(stderr, "Bad epochs.\n"); + return false; + } + + // The next record to be written should exceed the largest received. + if (client_write_seq <= server_read_seq || + server_write_seq <= client_read_seq) { + fprintf(stderr, "Inconsistent sequence numbers.\n"); + return false; + } + } else { + // The next record to be written should equal the next to be received. + if (client_write_seq != server_read_seq || + server_write_seq != client_write_seq) { + fprintf(stderr, "Inconsistent sequence numbers.\n"); + return false; + } + } + + // Send a record from client to server. + uint8_t byte = 0; + if (SSL_write(client.get(), &byte, 1) != 1 || + SSL_read(server.get(), &byte, 1) != 1) { + fprintf(stderr, "Could not send byte.\n"); + return false; + } + + // The client write and server read sequence numbers should have incremented. + if (client_write_seq + 1 != SSL_get_write_sequence(client.get()) || + server_read_seq + 1 != SSL_get_read_sequence(server.get())) { + fprintf(stderr, "Sequence numbers did not increment.\n");\ + return false; + } + + return true; +} + int main() { CRYPTO_library_init(); @@ -999,7 +1153,9 @@ int main() { !TestCipherGetRFCName() || !TestPaddingExtension() || !TestClientCAList() || - !TestInternalSessionCache()) { + !TestInternalSessionCache() || + !TestSequenceNumber(false /* TLS */) || + !TestSequenceNumber(true /* DTLS */)) { ERR_print_errors_fp(stderr); return 1; } |