diff options
Diffstat (limited to 'include/llfio/v2.0/detail/impl/tls_socket_sources/openssl.ipp')
-rw-r--r-- | include/llfio/v2.0/detail/impl/tls_socket_sources/openssl.ipp | 671 |
1 files changed, 397 insertions, 274 deletions
diff --git a/include/llfio/v2.0/detail/impl/tls_socket_sources/openssl.ipp b/include/llfio/v2.0/detail/impl/tls_socket_sources/openssl.ipp index e9a86f98..ffb72949 100644 --- a/include/llfio/v2.0/detail/impl/tls_socket_sources/openssl.ipp +++ b/include/llfio/v2.0/detail/impl/tls_socket_sources/openssl.ipp @@ -1,5 +1,5 @@ /* A TLS socket source based on OpenSSL -(C) 2021-2021 Niall Douglas <http://www.nedproductions.biz/> (20 commits) +(C) 2021-2022 Niall Douglas <http://www.nedproductions.biz/> (20 commits) File Created: Dec 2021 @@ -29,11 +29,17 @@ Distributed under the Boost Software License, Version 1.0. #include <openssl/crypto.h> #include <openssl/err.h> #include <openssl/ssl.h> +#include <openssl/x509.h> #if LLFIO_OPENSSL_ENABLE_DEBUG_PRINTING #include <iostream> #endif +#ifdef _WIN32 +#include <cryptuiapi.h> +#pragma comment(lib, "cryptui.lib") +#endif + LLFIO_V2_NAMESPACE_BEGIN namespace detail @@ -101,7 +107,7 @@ namespace detail PSK-AES128-CCM */ static const char *openssl_unverified_cipher_list = "CHACHA20:HIGH:aNULL:!EXPORT:!LOW:!eNULL:!SSLv2:!SSLv3:!TLSv1.0:!CAMELLIA:!ARIA:!SHA:!kRSA:@SECLEVEL=0"; - //static const char *openssl_unverified_cipher_list = "CHACHA20:HIGH:!aNULL:!EXPORT:!LOW:!eNULL:!SSLv2:!SSLv3:!TLSv1.0:!CAMELLIA:!ARIA:!SHA:!kRSA"; + // static const char *openssl_unverified_cipher_list = "CHACHA20:HIGH:!aNULL:!EXPORT:!LOW:!eNULL:!SSLv2:!SSLv3:!TLSv1.0:!CAMELLIA:!ARIA:!SHA:!kRSA"; /* This is the list my OpenSSL v1.1 offers in HELLO for this spec (48): TLS_AES_256_GCM_SHA384 @@ -246,6 +252,64 @@ namespace detail constexpr openssl_error_domain openssl_error_domain_inst; inline constexpr const openssl_error_domain &openssl_error_domain::get() { return openssl_error_domain_inst; } using openssl_code = OUTCOME_V2_NAMESPACE::experimental::status_code<openssl_error_domain>; + + struct x509_error_domain final : public OUTCOME_V2_NAMESPACE::experimental::status_code_domain + { + using value_type = int; + + constexpr x509_error_domain() + : OUTCOME_V2_NAMESPACE::experimental::status_code_domain("{bea47e79-6787-6009-7ac7-ba8616575312}") + { + } + + virtual string_ref name() const noexcept override { return string_ref("x509"); } + virtual payload_info_t payload_info() const noexcept override + { + return {sizeof(value_type), sizeof(status_code_domain *) + sizeof(value_type), + (alignof(value_type) > alignof(status_code_domain *)) ? alignof(value_type) : alignof(status_code_domain *)}; + } + static inline constexpr const x509_error_domain &get(); + + virtual bool _do_failure(const OUTCOME_V2_NAMESPACE::experimental::status_code<void> &code) const noexcept override + { + auto &c = static_cast<const OUTCOME_V2_NAMESPACE::experimental::status_code<x509_error_domain> &>(code); + return c.value() != 0; + } + virtual bool _do_equivalent(const OUTCOME_V2_NAMESPACE::experimental::status_code<void> &code1, + const OUTCOME_V2_NAMESPACE::experimental::status_code<void> &code2) const noexcept override + { + assert(code1.domain() == *this); // NOLINT + const auto &c1 = static_cast<const OUTCOME_V2_NAMESPACE::experimental::status_code<x509_error_domain> &>(code1); // NOLINT + if(code2.domain() == *this) + { + const auto &c2 = static_cast<const OUTCOME_V2_NAMESPACE::experimental::status_code<x509_error_domain> &>(code2); // NOLINT + return c1.value() == c2.value(); + } + return false; + } + virtual OUTCOME_V2_NAMESPACE::experimental::generic_code + _generic_code(const OUTCOME_V2_NAMESPACE::experimental::status_code<void> &) const noexcept override + { + return errc::unknown; + } + virtual string_ref _do_message(const OUTCOME_V2_NAMESPACE::experimental::status_code<void> &code) const noexcept override + { + auto &c = static_cast<const OUTCOME_V2_NAMESPACE::experimental::status_code<x509_error_domain> &>(code); + if(c.value() == 0) + { + return string_ref("not an error"); + } + return string_ref(X509_verify_cert_error_string(c.value())); + } + SYSTEM_ERROR2_NORETURN virtual void _do_throw_exception(const OUTCOME_V2_NAMESPACE::experimental::status_code<void> &code) const override + { + auto &c = static_cast<const OUTCOME_V2_NAMESPACE::experimental::status_code<x509_error_domain> &>(code); + throw OUTCOME_V2_NAMESPACE::experimental::status_error<x509_error_domain>(c); + } + }; + constexpr x509_error_domain x509_error_domain_inst; + inline constexpr const x509_error_domain &x509_error_domain::get() { return x509_error_domain_inst; } + using x509_code = OUTCOME_V2_NAMESPACE::experimental::status_code<x509_error_domain>; } // namespace detail template <class T> inline result<void> openssl_error(T *inst, unsigned long errcode = ERR_get_error()) { @@ -284,6 +348,16 @@ inline result<void> openssl_error(std::nullptr_t, unsigned long errcode = ERR_ge assert(ret.failure()); return ret; } +inline result<void> x509_error(int errcode) +{ + detail::x509_code ret(errcode); +#if LLFIO_OPENSSL_ENABLE_DEBUG_PRINTING + std::lock_guard<std::mutex> g(detail::openssl_printing_lock); + std::cerr << "X509 error: " << ret.message().c_str() << std::endl; +#endif + assert(ret.failure()); + return ret; +} #else namespace detail { @@ -321,6 +395,17 @@ namespace detail return {buffer}; } }; + struct x509_error_category final : public std::error_category + { + virtual const char *name() const noexcept override { return "x509"; } + virtual std::error_condition default_error_condition(int code) const noexcept override { return {code, *this}; } + virtual std::string message(int code) const override + { + if(code == 0) + return "not an error"; + return X509_verify_cert_error_string(code); + } + }; } // namespace detail template <class T> inline result<void> openssl_error(T *inst, unsigned long errcode = ERR_get_error()) { @@ -333,7 +418,12 @@ template <class T> inline result<void> openssl_error(T *inst, unsigned long errc return (ERR_GET_REASON(errcode) == 2) ? std::move(inst->_write_error) : std::move(inst->_read_error); } static detail::openssl_error_category cat; - return error_info(std::error_code((int) errcode, cat)); + error_info ret(std::error_code((int) errcode, cat)); +#if LLFIO_OPENSSL_ENABLE_DEBUG_PRINTING + std::lock_guard<std::mutex> g(detail::openssl_printing_lock); + std::cout << "ERROR: " << ret.message() << std::endl; +#endif + return ret; } inline result<void> openssl_error(std::nullptr_t, unsigned long errcode = ERR_get_error()) { @@ -344,6 +434,11 @@ inline result<void> openssl_error(std::nullptr_t, unsigned long errcode = ERR_ge static detail::openssl_error_category cat; return error_info(std::error_code((int) errcode, cat)); } +inline result<void> x509_error(int errcode) +{ + static detail::x509_error_category cat; + return error_info(std::error_code((int) errcode, cat)); +} #endif namespace detail @@ -396,8 +491,14 @@ namespace detail static struct openssl_default_ctxs_t { SSL_CTX *unverified{nullptr}, *verified{nullptr}; + X509_STORE *certstore{nullptr}; ~openssl_default_ctxs_t() { + if(certstore != nullptr) + { + X509_STORE_free(certstore); + certstore = nullptr; + } if(verified != nullptr) { SSL_CTX_free(verified); @@ -417,7 +518,48 @@ namespace detail QUICKCPPLIB_NAMESPACE::configurable_spinlock::lock_guard<QUICKCPPLIB_NAMESPACE::configurable_spinlock::spinlock<unsigned>> g(lock); if(verified == nullptr) { - auto make_ctx = [](bool verify_peer) -> result<SSL_CTX *> +#ifdef _WIN32 + // Create an OpenSSL certificate store made up of the certs from the Windows certificate store + HCERTSTORE winstore = CertOpenSystemStoreW(NULL, L"ROOT"); + if(!winstore) + { + return win32_error(); + } + auto unwinstore = make_scope_exit([&]() noexcept { CertCloseStore(winstore, 0); }); + certstore = X509_STORE_new(); + if(!certstore) + { + return openssl_error(nullptr); + } + PCCERT_CONTEXT context = nullptr; + auto uncontext = make_scope_exit( + [&]() noexcept + { + if(context != nullptr) + { + CertFreeCertificateContext(context); + } + }); + while((context = CertEnumCertificatesInStore(winstore, context)) != nullptr) + { + const unsigned char *in = (const unsigned char *) context->pbCertEncoded; + X509 *x509 = d2i_X509(nullptr, &in, context->cbCertEncoded); + if(!x509) + { + return openssl_error(nullptr); + } + auto unx509 = make_scope_exit([&]() noexcept { X509_free(x509); }); + //{ + // X509_NAME_print_ex_fp(stdout, X509_get_issuer_name(x509), 3, 0); + // printf("\n"); + //} + if(X509_STORE_add_cert(certstore, x509) <= 0) + { + return openssl_error(nullptr); + } + } +#endif + auto make_ctx = [certstore = certstore](bool verify_peer) -> result<SSL_CTX *> { SSL_CTX *_ctx = SSL_CTX_new(TLS_method()); if(_ctx == nullptr) @@ -434,13 +576,17 @@ namespace detail } else { + if(certstore != nullptr) + { + SSL_CTX_set1_cert_store(_ctx, certstore); + } if(!SSL_CTX_set_cipher_list(_ctx, openssl_verified_cipher_list)) { return openssl_error(nullptr).as_failure(); } SSL_CTX_set_verify(_ctx, SSL_VERIFY_PEER, nullptr); SSL_CTX_set_verify_depth(_ctx, 4); - if(!SSL_CTX_set_default_verify_paths(_ctx)) + if(SSL_CTX_set_default_verify_paths(_ctx) <= 0) { return openssl_error(nullptr).as_failure(); } @@ -471,16 +617,33 @@ class openssl_socket_handle final : public tls_socket_handle optional<filesystem::path> _authentication_certificates_path; std::string _connect_hostname_port; + /* We use a registered buffer from the underlying transport for reads only, but not writes. + The reason why not is that from my best reading of the implementation source code, + OpenSSL doesn't seem to allow fixed size write buffers (as according to the documentation + for BIO_set_mem_buf), plus OpenSSL seems to treat failure to allocate memory as an abort + situation rather than a retry situation, so there seems to me that is no way of backpressuring + a fixed size buffer in OpenSSL. + + Furthermore, the documentation for BUF_MEM suggests that the buffer must be a single + contiguous region of memory, so we can't use a list of multiple registered + buffers either. If your TLS implementation were better designed, it should be possible to + use registered buffers for both reads and writes, and that then reduces CPU cache loading + on a very busy server. + + Be aware that due to this design flaw in OpenSSL, it is possible to buffer writes to + infinity i.e. it never returns a partial write, it always accepts fully every write. We + therefore have to emulate backpressure below. + */ std::mutex _lock; std::unique_lock<std::mutex> _lock_holder{_lock, std::defer_lock}; - uint16_t _read_buffer_source_idx{0}, _write_buffer_source_idx{0}; - uint16_t _read_buffer_sink_idx{0}, _write_buffer_sink_idx{0}; + uint16_t _read_buffer_source_idx{0}, _read_buffer_sink_idx{0}; + bool _write_socket_full{false}; + uint8_t _still_connecting{0}; deadline _read_deadline, _write_deadline; result<void> _read_error{success()}, _write_error{success()}; std::chrono::steady_clock::time_point _read_deadline_began_steady, _write_deadline_began_steady; - byte_socket_handle::registered_buffer_type _read_buffers[BUFFERS_COUNT]{}, _write_buffers[BUFFERS_COUNT]{}; + byte_socket_handle::registered_buffer_type _read_buffers[BUFFERS_COUNT]{}; byte_socket_handle::buffer_type _read_buffers_valid[BUFFERS_COUNT]{}; - byte_socket_handle::const_buffer_type _write_buffers_valid[BUFFERS_COUNT]{}; // Front of the queue std::pair<byte_socket_handle::registered_buffer_type *, byte_socket_handle::buffer_type *> _toread_source() noexcept @@ -499,23 +662,6 @@ class openssl_socket_handle final : public tls_socket_handle } return {&_read_buffers[idx], &_read_buffers_valid[idx]}; } - // Front of the queue - std::pair<byte_socket_handle::registered_buffer_type *, byte_socket_handle::const_buffer_type *> _towrite_source() noexcept - { - return {&_write_buffers[_write_buffer_source_idx % BUFFERS_COUNT], &_write_buffers_valid[_write_buffer_source_idx % BUFFERS_COUNT]}; - } - bool _towrite_source_empty() const noexcept { return _write_buffers_valid[_write_buffer_source_idx % BUFFERS_COUNT].empty(); } - // Back of the queue. Can return "full" - std::pair<byte_socket_handle::registered_buffer_type *, byte_socket_handle::const_buffer_type *> _towrite_sink() noexcept - { - const auto idx = _write_buffer_sink_idx % BUFFERS_COUNT; - if(idx == (_write_buffer_source_idx % BUFFERS_COUNT) && - (_write_buffers_valid[idx].data() + _write_buffers_valid[idx].size()) == (_write_buffers[idx]->data() + _write_buffers[idx]->size())) - { - return {nullptr, nullptr}; // full - } - return {&_write_buffers[idx], &_write_buffers_valid[idx]}; - } #undef LLFIO_OPENSSL_DISPATCH #define LLFIO_OPENSSL_DISPATCH(functp, functt, ...) \ @@ -524,6 +670,11 @@ class openssl_socket_handle final : public tls_socket_handle protected: virtual size_t _do_max_buffers() const noexcept override { return 1; /* There is no atomicity of scatter gather i/o at all! */ } + virtual result<registered_buffer_type> _do_allocate_registered_buffer(size_t &bytes) noexcept override + { + LLFIO_LOG_FUNCTION_CALL(this); + return LLFIO_OPENSSL_DISPATCH(allocate_registered_buffer, _do_allocate_registered_buffer, (bytes)); + } virtual io_result<buffers_type> _do_read(io_request<buffers_type> reqs, deadline d) noexcept override { LLFIO_LOG_FUNCTION_CALL(this); @@ -532,10 +683,21 @@ protected: return errc::not_supported; } _lock_holder.lock(); - auto unlock = make_scope_exit([this]() noexcept { _lock_holder.unlock(); }); + auto unlock = make_scope_exit( + [this]() noexcept + { + if(_lock_holder.owns_lock()) + { + _lock_holder.unlock(); + } + }); + if(!(_v.behaviour & native_handle_type::disposition::_is_connected) || _still_connecting > 0) + { + return errc::not_connected; + } + LLFIO_DEADLINE_TO_SLEEP_INIT(d); if(d) { - LLFIO_DEADLINE_TO_SLEEP_INIT(d); _read_deadline_began_steady = began_steady; _read_deadline = d; } @@ -546,7 +708,9 @@ protected: for(size_t n = 0; n < reqs.buffers.size(); n++) { size_t read = 0; - auto res = BIO_read_ex(_ssl_bio, reqs.buffers[n].data(), reqs.buffers[n].size(), &read); + // OpenSSL early outs if buf is ever null + byte dummy{}, *buf = reqs.buffers[n].empty() ? &dummy : reqs.buffers[n].data(); + auto res = BIO_read_ex(_ssl_bio, buf, reqs.buffers[n].size(), &read); if(res <= 0) { auto errcode = ERR_get_error(); @@ -560,11 +724,16 @@ protected: { return errc::operation_would_block; } - return openssl_error(this).as_failure(); + return openssl_error(this, errcode).as_failure(); } if(read < reqs.buffers[n].size()) { reqs.buffers[n] = {reqs.buffers[n].data(), read}; + if(n == 0 && read == 0) + { + reqs.buffers = {reqs.buffers.data(), n}; + return std::move(reqs.buffers); + } reqs.buffers = {reqs.buffers.data(), n + 1}; break; } @@ -579,7 +748,18 @@ protected: return errc::not_supported; } _lock_holder.lock(); - auto unlock = make_scope_exit([this]() noexcept { _lock_holder.unlock(); }); + auto unlock = make_scope_exit( + [this]() noexcept + { + if(_lock_holder.owns_lock()) + { + _lock_holder.unlock(); + } + }); + if(!(_v.behaviour & native_handle_type::disposition::_is_connected) || _still_connecting > 0) + { + return errc::not_connected; + } LLFIO_DEADLINE_TO_SLEEP_INIT(d); if(d) { @@ -590,6 +770,29 @@ protected: { _write_deadline = {}; } + // OpenSSL will accept new writes forever, so we need to emulate write backpressure + if(_write_socket_full) + { + // Write nothing new to OpenSSL, should cause _bwrite() to get called which will check + // if the socket's write buffers have drained + size_t written = 0; + auto res = BIO_write_ex(_ssl_bio, nullptr, 0, &written); + if(res <= 0) + { + auto errcode = ERR_get_error(); + if(BIO_should_retry(_ssl_bio)) + { + return errc::operation_would_block; + } + return openssl_error(this, errcode).as_failure(); + } + if(_write_socket_full) + { + // Return no buffers written + reqs.buffers = {reqs.buffers.data(), size_t(0)}; + return std::move(reqs.buffers); + } + } for(size_t n = 0; n < reqs.buffers.size(); n++) { size_t written = 0; @@ -607,11 +810,16 @@ protected: { return errc::operation_would_block; } - return openssl_error(this).as_failure(); + return openssl_error(this, errcode).as_failure(); } if(written < reqs.buffers[n].size()) { reqs.buffers[n] = {reqs.buffers[n].data(), written}; + if(n == 0 && written == 0) + { + reqs.buffers = {reqs.buffers.data(), n}; + return std::move(reqs.buffers); + } reqs.buffers = {reqs.buffers.data(), n + 1}; break; } @@ -633,47 +841,64 @@ protected: return errc::not_supported; } _lock_holder.lock(); - auto unlock = make_scope_exit([this]() noexcept { _lock_holder.unlock(); }); - LLFIO_DEADLINE_TO_SLEEP_INIT(d); + auto unlock = make_scope_exit( + [this]() noexcept + { + if(_lock_holder.owns_lock()) + { + _lock_holder.unlock(); + } + }); if(_ssl_bio == nullptr) { - OUTCOME_TRY(_init(true, _authentication_certificates_path && !_authentication_certificates_path->empty())); + OUTCOME_TRY(_init(true, _authentication_certificates_path)); } - if(!(_v.behaviour & native_handle_type::disposition::_is_connected)) + if(!(_v.behaviour & native_handle_type::disposition::_is_connected) || _still_connecting > 0) { + LLFIO_DEADLINE_TO_SLEEP_INIT(d); OUTCOME_TRY(LLFIO_OPENSSL_DISPATCH(connect, _do_connect, (addr, d))); - if(d) - { - _read_deadline_began_steady = began_steady; - _write_deadline_began_steady = began_steady; - _read_deadline = d; - _write_deadline = d; - } - else - { - _read_deadline = {}; - _write_deadline = {}; - } - auto res = BIO_do_connect(_ssl_bio); - if(res != 1) + for(;;) { - if(BIO_should_retry(_ssl_bio)) + deadline nd; + LLFIO_DEADLINE_TO_PARTIAL_DEADLINE(nd, d); + if(nd) { - return errc::operation_in_progress; + _read_deadline_began_steady = began_steady; + _write_deadline_began_steady = began_steady; + _read_deadline = d; + _write_deadline = d; } - return openssl_error(this).as_failure(); - } - - res = BIO_do_handshake(_ssl_bio); - if(res != 1) - { - if(BIO_should_retry(_ssl_bio)) + else { - return errc::operation_in_progress; + _read_deadline = {}; + _write_deadline = {}; } - return openssl_error(this).as_failure(); + if(_still_connecting < 1) + { + auto res = BIO_do_connect(_ssl_bio); + _still_connecting = 1; + if(res != 1) + { + if(BIO_should_retry(_ssl_bio)) + { + return errc::operation_in_progress; + } + return openssl_error(this).as_failure(); + } + } + { + auto res = BIO_do_handshake(_ssl_bio); + if(res != 1) + { + if(BIO_should_retry(_ssl_bio)) + { + return errc::operation_in_progress; + } + return openssl_error(this).as_failure(); + } + } + break; } - if(!_connect_hostname_port.empty()) { SSL *ssl{nullptr}; @@ -691,13 +916,14 @@ protected: { return openssl_error(this).as_failure(); } - res = SSL_get_verify_result(ssl); + auto res = SSL_get_verify_result(ssl); if(X509_V_OK != res) { - return openssl_error(this).as_failure(); + return x509_error(res).as_failure(); } } _v.behaviour |= native_handle_type::disposition::_is_connected; + _still_connecting = 0; } return success(); } @@ -714,9 +940,11 @@ public: this->_v.behaviour = (sock->native_handle().behaviour & ~(native_handle_type::disposition::kernel_handle)) | native_handle_type::disposition::is_pointer; } - result<void> _init(bool is_client, bool verify_peer) noexcept + result<void> _init(bool is_client, const optional<filesystem::path> &certpath) noexcept { LLFIO_LOG_FUNCTION_CALL(this); + const bool verify_peer = + (is_client && (!certpath.has_value() || (certpath.has_value() && !certpath->empty()))) || (!is_client && (!certpath.has_value() || !certpath->empty())); OUTCOME_TRY(detail::openssl_default_ctxs.init()); assert(_ssl_bio == nullptr); _ssl_bio = BIO_new_ssl(verify_peer ? detail::openssl_default_ctxs.verified : detail::openssl_default_ctxs.unverified, is_client); @@ -724,6 +952,23 @@ public: { return openssl_error(this).as_failure(); } + if(certpath.has_value() && !certpath->empty()) + { + SSL *ssl{nullptr}; + BIO_get_ssl(_ssl_bio, &ssl); + if(ssl == nullptr) + { + return openssl_error(this).as_failure(); + } + if(SSL_use_certificate_file(ssl, certpath->string().c_str(), SSL_FILETYPE_PEM) <= 0) + { + return openssl_error(this).as_failure(); + } + if(SSL_use_PrivateKey_file(ssl, certpath->string().c_str(), SSL_FILETYPE_PEM) <= 0) + { + return openssl_error(this).as_failure(); + } + } _self_bio = BIO_new(detail::openssl_custom_bio.method); if(_self_bio == nullptr) { @@ -740,7 +985,14 @@ public: if(kind == shutdown_write || kind == shutdown_both) { _lock_holder.lock(); - auto unlock = make_scope_exit([this]() noexcept { _lock_holder.unlock(); }); + auto unlock = make_scope_exit( + [this]() noexcept + { + if(_lock_holder.owns_lock()) + { + _lock_holder.unlock(); + } + }); SSL *ssl{nullptr}; BIO_get_ssl(_ssl_bio, &ssl); if(ssl == nullptr) @@ -761,7 +1013,12 @@ public: { return errc::operation_in_progress; } - return openssl_error(this).as_failure(); + if(e == SSL_ERROR_SSL) + { + // Already shut down? + break; + } + return openssl_error(this, e).as_failure(); } // Shutdown is in progress if(this->is_nonblocking()) @@ -792,27 +1049,34 @@ public: { LLFIO_LOG_FUNCTION_CALL(this); _lock_holder.lock(); - auto unlock = make_scope_exit([this]() noexcept { _lock_holder.unlock(); }); + auto unlock = make_scope_exit( + [this]() noexcept + { + if(_lock_holder.owns_lock()) + { + _lock_holder.unlock(); + } + }); OUTCOME_TRY(_flush_towrite({})); + if(_ssl_bio != nullptr) + { + BIO_free_all(_ssl_bio); // also frees _self_bio + _ssl_bio = nullptr; + } if(_v.behaviour & native_handle_type::disposition::is_pointer) { tls_socket_handle::release(); } else { + _lock_holder.unlock(); OUTCOME_TRY(tls_socket_handle::close()); - } - if(_ssl_bio != nullptr) - { - BIO_free_all(_ssl_bio); // also frees _self_bio - _ssl_bio = nullptr; + _lock_holder.lock(); } for(size_t n = 0; n < BUFFERS_COUNT; n++) { _read_buffers_valid[n] = {}; - _write_buffers_valid[n] = {}; _read_buffers[n].reset(); - _write_buffers[n].reset(); } return success(); } @@ -881,8 +1145,15 @@ public: { LLFIO_LOG_FUNCTION_CALL(this); _lock_holder.lock(); - auto unlock = make_scope_exit([this]() noexcept { _lock_holder.unlock(); }); - if(!_toread_source_empty() || !_towrite_source_empty()) + auto unlock = make_scope_exit( + [this]() noexcept + { + if(_lock_holder.owns_lock()) + { + _lock_holder.unlock(); + } + }); + if(!_toread_source_empty()) { return errc::device_or_resource_busy; } @@ -893,9 +1164,6 @@ public: auto _bytes = bytes; OUTCOME_TRY(_read_buffers[n], LLFIO_OPENSSL_DISPATCH(allocate_registered_buffer, allocate_registered_buffer, (_bytes))); _read_buffers_valid[n] = {_read_buffers[n]->data(), 0}; - _bytes = bytes; - OUTCOME_TRY(_write_buffers[n], LLFIO_OPENSSL_DISPATCH(allocate_registered_buffer, allocate_registered_buffer, (_bytes))); - _write_buffers_valid[n] = {_write_buffers[n]->data(), 0}; } } return success(); @@ -934,28 +1202,28 @@ public: _connect_hostname_port.assign(host.data(), host.size()); _connect_hostname_port.push_back(':'); _connect_hostname_port.append(std::to_string(port)); - auto res = BIO_set_conn_hostname(_ssl_bio, _connect_hostname_port.c_str()); - if(res != 1) - { - return openssl_error(this).as_failure(); - } if(_ctx == nullptr) { - OUTCOME_TRY(_init(true, true)); + OUTCOME_TRY(_init(true, _authentication_certificates_path)); } + auto res = BIO_set_conn_hostname(_ssl_bio, _connect_hostname_port.c_str()); + /* if(res != 1) + { + return openssl_error(this).as_failure(); + }*/ SSL *ssl{nullptr}; BIO_get_ssl(_ssl_bio, &ssl); if(ssl == nullptr) { return openssl_error(this).as_failure(); } - auto hostname = _connect_hostname_port.substr(0, _connect_hostname_port.rfind(':')); + std::string hostname(host); res = SSL_set_tlsext_host_name(ssl, hostname.c_str()); if(res != 1) { return openssl_error(this).as_failure(); } - return _connect_hostname_port; + return string_view(_connect_hostname_port).substr(host.size() + 1); } catch(...) { @@ -1025,17 +1293,15 @@ public: } } _lock_holder.unlock(); + assert(!requires_aligned_io()); + assert(_v.is_valid()); auto r = LLFIO_OPENSSL_DISPATCH(read, _do_read, (*s.first, {{&b, 1}, 0}, nd)); _lock_holder.lock(); if(!r) { - if(r.error() == errc::operation_would_block || r.error() == errc::resource_unavailable_try_again || r.error() == errc::timed_out) + // Return an error if we never read any data, otherwise sink the error + if(*read > 0) { - if(*read == 0) - { - BIO_set_retry_read(bio); - return 0; - } return 1; } _read_error = std::move(r).as_failure(); @@ -1058,7 +1324,7 @@ public: #if LLFIO_OPENSSL_ENABLE_DEBUG_PRINTING std::lock_guard<std::mutex> g(detail::openssl_printing_lock); auto s = _toread_source(); - std::cout << "_bread(" << (void *) buffer << ", " << bytes << ") returns " << ret << " with " << *read << " bytes read and " << s.second->size() + std::cout << this << " _bread(" << (void *) buffer << ", " << bytes << ") returns " << ret << " with " << *read << " bytes read and " << s.second->size() << " remaining in source buffer." << std::endl; #endif return ret; @@ -1073,135 +1339,33 @@ public: assert(_lock_holder.owns_lock()); *written = 0; BIO_clear_retry_flags(bio); -#if 0 - // Write any existing buffers first - if(!_towrite_source_empty()) - { - auto s = _towrite_source(); - auto b = *s.second; - auto &began_steady = _write_deadline_began_steady; - deadline nd; - if(this->is_nonblocking()) - { - LLFIO_DEADLINE_TO_PARTIAL_DEADLINE(nd, _write_deadline); - } - _lock_holder.unlock(); - auto r = LLFIO_OPENSSL_DISPATCH(write, _do_write, (*s.first, {{&b, 1}, 0}, nd)); - _lock_holder.lock(); - if(!r) - { - if(r.error() == errc::operation_would_block || r.error() == errc::resource_unavailable_try_again || r.error() == errc::timed_out) - { - if(*written == 0) - { - BIO_set_retry_write(bio); - return 0; - } - return 1; - } - _write_error = std::move(r).as_failure(); - LLFIO_OPENSSL_SET_RESULT_ERROR(2); - return 0; - } - *s.second = {s.second->data() + b.size(), s.second->size() - b.size()}; - if(s.second->empty()) - { - *s.second = {(*s.first)->data(), 0}; - _write_buffer_source_idx++; - } - else - { - break; - } - auto r2 = [&]() -> result<void> - { - LLFIO_DEADLINE_TO_TIMEOUT_LOOP(_write_deadline); - return success(); - }(); - if(!r2) - { - if(*written == 0) - { - BIO_set_retry_write(bio); - return 0; - } - return 1; - } - } -#endif - // Are existing buffers now empty and we can write this directly? - if(_towrite_source_empty()) + auto &began_steady = _write_deadline_began_steady; + deadline nd; + if(this->is_nonblocking()) { - auto &began_steady = _write_deadline_began_steady; - deadline nd; - if(this->is_nonblocking()) - { - LLFIO_DEADLINE_TO_PARTIAL_DEADLINE(nd, _write_deadline); - } - _lock_holder.unlock(); - const_buffer_type b((const byte *) buffer, bytes); - auto r = LLFIO_OPENSSL_DISPATCH(write, _do_write, ({{&b, 1}, 0}, nd)); - _lock_holder.lock(); - if(!r) - { - if(r.error() == errc::operation_would_block || r.error() == errc::resource_unavailable_try_again || r.error() == errc::timed_out) - { - if(*written == 0) - { - BIO_set_retry_write(bio); - return 0; - } - return 1; - } - _write_error = std::move(r).as_failure(); - LLFIO_OPENSSL_SET_RESULT_ERROR(2); - return 0; - } - buffer += b.size(); - bytes -= b.size(); - *written += b.size(); - if(0 == bytes) - { - return 1; - } + LLFIO_DEADLINE_TO_PARTIAL_DEADLINE(nd, _write_deadline); } -#if 0 - // Append into write buffers - while(bytes > 0) + _lock_holder.unlock(); + assert(!requires_aligned_io()); + assert(_v.is_valid()); + const_buffer_type b((const byte *) buffer, bytes); + auto r = LLFIO_OPENSSL_DISPATCH(write, _do_write, ({{&b, 1}, 0}, nd)); + _lock_holder.lock(); + if(!r) { - auto s = _towrite_sink(); - // Are we full? - if(s.second == nullptr) - { - if(*written == 0) - { - BIO_set_retry_write(bio); - return 0; - } - return 1; - } - auto remaining = (size_t) (((*s.first)->data() + (*s.first)->size()) - (s.second->data() + s.second->size())); - auto tocopy = std::min(remaining, bytes); - memcpy((byte *) s.second->data() + s.second->size(), buffer, tocopy); - buffer += tocopy; - bytes -= tocopy; - *written += tocopy; - *s.second = {s.second->data(), s.second->size() + tocopy}; - if(remaining == tocopy) - { - _write_buffer_sink_idx++; - } - if(0 == bytes) - { - return 1; - } + _write_error = std::move(r).as_failure(); + LLFIO_OPENSSL_SET_RESULT_ERROR(2); + return 0; } -#endif + _write_socket_full = (b.size() == 0); + buffer += b.size(); + bytes -= b.size(); + *written += b.size(); return 1; }(); #if LLFIO_OPENSSL_ENABLE_DEBUG_PRINTING std::lock_guard<std::mutex> g(detail::openssl_printing_lock); - std::cout << "_bwrite(" << (void *) buffer << ", " << bytes << ") returns " << ret << " with " << *written << " bytes written." << std::endl; + std::cout << this << "_bwrite(" << (void *) buffer << ", " << bytes << ") returns " << ret << " with " << *written << " bytes written." << std::endl; #endif return ret; } @@ -1214,52 +1378,11 @@ public: assert(_lock_holder.owns_lock()); if(_ssl_bio != nullptr) { - auto *m = LLFIO_OPENSSL_DISPATCH(multiplexer, multiplexer, ()); - do + auto res = BIO_flush(_ssl_bio); + if(res <= 0 && !BIO_should_retry(_ssl_bio)) { - auto res = BIO_flush(_ssl_bio); - if(res <= 0 && !BIO_should_retry(_ssl_bio)) - { - return openssl_error(this).as_failure(); - } - if(!_towrite_source_empty()) - { - if(m != nullptr) - { - deadline nd; - LLFIO_DEADLINE_TO_PARTIAL_DEADLINE(nd, d); - OUTCOME_TRY(m->check_for_any_completed_io(nd)); - } - if(this->is_nonblocking()) - { - _write_deadline = std::chrono::seconds(0); - } - else - { - _write_deadline = {}; - } - size_t written = 0; - res = _bwrite(_self_bio, nullptr, 0, &written); - if(res <= 0 && !BIO_should_retry(_ssl_bio)) - { - return openssl_error(this).as_failure(); - } - if(this->is_nonblocking()) - { - _read_deadline = std::chrono::seconds(0); - } - else - { - _read_deadline = {}; - } - size_t read = 0; - res = _bread(_self_bio, nullptr, 0, &read); - if(res <= 0 && !BIO_should_retry(_ssl_bio)) - { - return openssl_error(this).as_failure(); - } - } - } while(!_towrite_source_empty()); + return openssl_error(this).as_failure(); + } } return success(); } @@ -1288,12 +1411,12 @@ class listening_openssl_socket_handle final : public listening_tls_socket_handle #undef LLFIO_OPENSSL_DISPATCH /* *this has a vptr whose functions point into this class, so what - we need is to bind the function implementation listening_socket_handle + we need is to bind the function implementation listening_byte_socket_handle using its vptr and call it. */ #define LLFIO_OPENSSL_DISPATCH(functp, functt, ...) \ - ((_v.behaviour & native_handle_type::disposition::is_pointer) ? (reinterpret_cast<listening_socket_handle *>(_v.ptr)->functp) __VA_ARGS__ : \ - (socket_cast<listening_socket_handle>(this)->listening_socket_handle::functt) __VA_ARGS__) + ((_v.behaviour & native_handle_type::disposition::is_pointer) ? (reinterpret_cast<listening_byte_socket_handle *>(_v.ptr)->functp) __VA_ARGS__ : \ + (listening_byte_socket_handle::functt) __VA_ARGS__) protected: virtual result<buffers_type> _do_read(io_request<buffers_type> req, deadline d) noexcept override @@ -1303,8 +1426,8 @@ protected: { return errc::not_supported; } - listening_socket_handle::buffer_type b; - OUTCOME_TRY(auto &&read, _underlying_read<listening_socket_handle>({b}, d)); + listening_byte_socket_handle::buffer_type b; + OUTCOME_TRY(auto &&read, _underlying_read<listening_byte_socket_handle>({b}, d)); auto *p = new(std::nothrow) openssl_socket_handle(std::move(read.connected_socket().first)); if(p == nullptr) { @@ -1312,7 +1435,7 @@ protected: } req.buffers.connected_socket() = {tls_socket_handle_ptr(p), read.connected_socket().second}; OUTCOME_TRY(p->set_registered_buffer_chunk_size(_registered_buffer_chunk_size)); - OUTCOME_TRY(p->_init(false, !_authentication_certificates_path || !_authentication_certificates_path->empty())); + OUTCOME_TRY(p->_init(false, _authentication_certificates_path)); return {std::move(req.buffers)}; } @@ -1323,8 +1446,8 @@ protected: { return errc::not_supported; } - listening_socket_handle::buffer_type b; - OUTCOME_TRY(auto &&read, _underlying_read<listening_socket_handle>({b}, d)); + listening_byte_socket_handle::buffer_type b; + OUTCOME_TRY(auto &&read, _underlying_read<listening_byte_socket_handle>({b}, d)); auto *p = new(std::nothrow) openssl_socket_handle(std::move(read.connected_socket().first)); if(p == nullptr) { @@ -1332,16 +1455,16 @@ protected: } req.buffers.connected_socket() = {tls_socket_handle_ptr(p), read.connected_socket().second}; OUTCOME_TRY(p->set_registered_buffer_chunk_size(_registered_buffer_chunk_size)); - OUTCOME_TRY(p->_init(false, !_authentication_certificates_path || !_authentication_certificates_path->empty())); + OUTCOME_TRY(p->_init(false, _authentication_certificates_path)); return {std::move(req.buffers)}; } public: - explicit listening_openssl_socket_handle(listening_socket_handle &&sock) + explicit listening_openssl_socket_handle(listening_byte_socket_handle &&sock) : listening_tls_socket_handle(std::move(sock)) { } - explicit listening_openssl_socket_handle(listening_socket_handle *sock) + explicit listening_openssl_socket_handle(listening_byte_socket_handle *sock) : listening_tls_socket_handle(handle(), nullptr) { this->_v.ptr = sock; @@ -1482,7 +1605,7 @@ static struct openssl_socket_source_registration_t virtual result<listening_tls_socket_handle_ptr> listening_socket(ip::family family, byte_socket_handle::mode _mode, byte_socket_handle::caching _caching, byte_socket_handle::flag flags) noexcept override { - OUTCOME_TRY(auto &&sock, listening_socket_handle::listening_socket(family, _mode, _caching, flags)); + OUTCOME_TRY(auto &&sock, listening_byte_socket_handle::listening_byte_socket(family, _mode, _caching, flags)); auto *p = new(std::nothrow) listening_openssl_socket_handle(std::move(sock)); if(p == nullptr) { @@ -1504,7 +1627,7 @@ static struct openssl_socket_source_registration_t return {std::move(ret)}; } - virtual result<listening_tls_socket_handle_ptr> wrap(listening_socket_handle *listening) noexcept override + virtual result<listening_tls_socket_handle_ptr> wrap(listening_byte_socket_handle *listening) noexcept override { auto *p = new(std::nothrow) listening_openssl_socket_handle(listening); if(p == nullptr) |