Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/ned14/llfio.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNiall Douglas <s_github@nedprod.com>2022-04-16 14:07:08 +0300
committerGitHub <noreply@github.com>2022-04-16 14:07:08 +0300
commitac897294b2e9a0383a754527efbbbcc30a582b21 (patch)
treebd01bc968f15944ea98a87fb8cb84b71506a2952
parente99466d0c63b7e73450340ed2ebb1380eafa198b (diff)
parentcc81a5f6a78b87fd48ff829980ba681e0637dba2 (diff)
Merge pull request #89 from ned14/networking
Networking
-rw-r--r--include/llfio/revision.hpp6
-rw-r--r--include/llfio/v2.0/byte_io_multiplexer.hpp22
-rw-r--r--include/llfio/v2.0/byte_socket_handle.hpp378
-rw-r--r--include/llfio/v2.0/detail/impl/posix/byte_io_handle.ipp6
-rw-r--r--include/llfio/v2.0/detail/impl/posix/byte_socket_handle.ipp13
-rw-r--r--include/llfio/v2.0/detail/impl/posix/file_handle.ipp2
-rw-r--r--include/llfio/v2.0/detail/impl/posix/utils.ipp7
-rw-r--r--include/llfio/v2.0/detail/impl/test/null_multiplexer.ipp2
-rw-r--r--include/llfio/v2.0/detail/impl/tls_socket_sources/openssl.ipp671
-rw-r--r--include/llfio/v2.0/detail/impl/windows/byte_io_handle.ipp19
-rw-r--r--include/llfio/v2.0/detail/impl/windows/byte_socket_handle.ipp42
-rw-r--r--include/llfio/v2.0/detail/impl/windows/test/iocp_multiplexer.ipp2
-rw-r--r--include/llfio/v2.0/tls_socket_handle.hpp56
-rw-r--r--test/tests/byte_socket_handle.cpp12
-rw-r--r--test/tests/tls_socket_handle.cpp550
15 files changed, 1068 insertions, 720 deletions
diff --git a/include/llfio/revision.hpp b/include/llfio/revision.hpp
index d0299231..aac0b15c 100644
--- a/include/llfio/revision.hpp
+++ b/include/llfio/revision.hpp
@@ -1,4 +1,4 @@
// Note the second line of this file must ALWAYS be the git SHA, third line ALWAYS the git SHA update time
-#define LLFIO_PREVIOUS_COMMIT_REF 870a8c572f9f0dd81a8ed8d63e3eecd2f395dbe6
-#define LLFIO_PREVIOUS_COMMIT_DATE "2022-04-15 20:14:41 +00:00"
-#define LLFIO_PREVIOUS_COMMIT_UNIQUE 870a8c57
+#define LLFIO_PREVIOUS_COMMIT_REF 5c6f7f3933de89fb4e4a9aa7df69da933a8f09aa
+#define LLFIO_PREVIOUS_COMMIT_DATE "2022-04-15 23:08:23 +00:00"
+#define LLFIO_PREVIOUS_COMMIT_UNIQUE 5c6f7f39
diff --git a/include/llfio/v2.0/byte_io_multiplexer.hpp b/include/llfio/v2.0/byte_io_multiplexer.hpp
index bf4d58cf..96fd9749 100644
--- a/include/llfio/v2.0/byte_io_multiplexer.hpp
+++ b/include/llfio/v2.0/byte_io_multiplexer.hpp
@@ -41,7 +41,7 @@ LLFIO_V2_NAMESPACE_EXPORT_BEGIN
class byte_io_handle;
class byte_socket_handle;
-class listening_socket_handle;
+class listening_byte_socket_handle;
namespace ip
{
class address;
@@ -497,13 +497,13 @@ public:
uint16_t file_handle : 1; //!< This i/o multiplexer can register plain kernel `file_handle`.
uint16_t pipe_handle : 1; //!< This i/o multiplexer can register plain kernel `pipe_handle`.
uint16_t byte_socket_handle : 1; //!< This i/o multiplexer can register plain kernel `byte_socket_handle`.
- uint16_t listening_socket_handle : 1; //!< This i/o multiplexer can register plain kernel `listening_socket_handle`.
+ uint16_t listening_byte_socket_handle : 1; //!< This i/o multiplexer can register plain kernel `listening_byte_socket_handle`.
constexpr kernel_t()
: file_handle(false)
, pipe_handle(false)
, byte_socket_handle(false)
- , listening_socket_handle(false)
+ , listening_byte_socket_handle(false)
{
}
} kernel;
@@ -530,11 +530,11 @@ public:
virtual result<uint8_t> do_byte_io_handle_register(byte_io_handle * /*unused*/) noexcept { return (uint8_t) 0; }
//! Implements `byte_io_handle` deregistration
virtual result<void> do_byte_io_handle_deregister(byte_io_handle * /*unused*/) noexcept { return success(); }
- //! Implements `listening_socket_handle` registration. The bottom two bits of the returned value are set into `_v.behaviour`'s `_multiplexer_state_bit0` and
+ //! Implements `listening_byte_socket_handle` registration. The bottom two bits of the returned value are set into `_v.behaviour`'s `_multiplexer_state_bit0` and
//! `_multiplexer_state_bit`
- virtual result<uint8_t> do_byte_io_handle_register(listening_socket_handle * /*unused*/) noexcept { return errc::operation_not_supported; }
- //! Implements `listening_socket_handle` deregistration
- virtual result<void> do_byte_io_handle_deregister(listening_socket_handle * /*unused*/) noexcept { return errc::operation_not_supported; }
+ virtual result<uint8_t> do_byte_io_handle_register(listening_byte_socket_handle * /*unused*/) noexcept { return errc::operation_not_supported; }
+ //! Implements `listening_byte_socket_handle` deregistration
+ virtual result<void> do_byte_io_handle_deregister(listening_byte_socket_handle * /*unused*/) noexcept { return errc::operation_not_supported; }
//! Implements `byte_io_handle::max_buffers()`
LLFIO_HEADERS_ONLY_VIRTUAL_SPEC size_t do_byte_io_handle_max_buffers(const byte_io_handle *h) const noexcept;
//! Implements `byte_io_handle::allocate_registered_buffer()`
@@ -1185,7 +1185,7 @@ public:
{
friend class byte_io_handle;
friend class byte_socket_handle;
- friend class listening_socket_handle;
+ friend class listening_byte_socket_handle;
static constexpr size_t _state_storage_bytes = _awaitable_size - sizeof(void *) - sizeof(io_operation_state *)
#if LLFIO_ENABLE_COROUTINES
- sizeof(coroutine_handle<>)
@@ -1399,10 +1399,10 @@ public:
}
/*! \brief Constructs either a `unsynchronised_io_operation_state` or a `synchronised_io_operation_state`
- for a `listening_socket_handle` read operation into the storage provided. The i/o is not initiated. The storage must
+ for a `listening_byte_socket_handle` read operation into the storage provided. The i/o is not initiated. The storage must
meet the requirements from `state_requirements()`.
*/
- virtual io_operation_state *construct(span<byte> storage, listening_socket_handle *_h, io_operation_state_visitor *_visitor, deadline d,
+ virtual io_operation_state *construct(span<byte> storage, listening_byte_socket_handle *_h, io_operation_state_visitor *_visitor, deadline d,
std::pair<byte_socket_handle, ip::address> & /*unused*/) noexcept
{
(void) storage;
@@ -1473,7 +1473,7 @@ public:
/*! \brief Combines `.construct()` with `.init_io_operation()` in a single call for improved efficiency.
*/
- virtual io_operation_state *construct_and_init_io_operation(span<byte> storage, listening_socket_handle *_h, io_operation_state_visitor *_visitor, deadline d,
+ virtual io_operation_state *construct_and_init_io_operation(span<byte> storage, listening_byte_socket_handle *_h, io_operation_state_visitor *_visitor, deadline d,
std::pair<byte_socket_handle, ip::address> &req) noexcept
{
io_operation_state *state = construct(storage, _h, _visitor, d, req);
diff --git a/include/llfio/v2.0/byte_socket_handle.hpp b/include/llfio/v2.0/byte_socket_handle.hpp
index 1776b863..3615c5b4 100644
--- a/include/llfio/v2.0/byte_socket_handle.hpp
+++ b/include/llfio/v2.0/byte_socket_handle.hpp
@@ -44,7 +44,7 @@ struct sockaddr_in6;
LLFIO_V2_NAMESPACE_EXPORT_BEGIN
class byte_socket_handle;
-class listening_socket_handle;
+class listening_byte_socket_handle;
namespace ip
{
class address;
@@ -85,7 +85,7 @@ namespace ip
class LLFIO_DECL address
{
friend class LLFIO_V2_NAMESPACE::byte_socket_handle;
- friend class LLFIO_V2_NAMESPACE::listening_socket_handle;
+ friend class LLFIO_V2_NAMESPACE::listening_byte_socket_handle;
friend LLFIO_HEADERS_ONLY_MEMFUNC_SPEC std::ostream &operator<<(std::ostream &s, const address &v);
protected:
@@ -363,7 +363,7 @@ no longer blocks. However it will then block in `read()` or `write()`,
unless its deadline is zero.
If you want to create a socket which awaits connections, you need
-to instance a `listening_socket_handle`. Reads from that handle yield
+to instance a `listening_byte_socket_handle`. Reads from that handle yield
new `byte_socket_handle` instances.
### `caching::safety_barriers`
@@ -599,8 +599,11 @@ public:
flag flags = flag::none) noexcept;
//! \brief Convenience function defaulting `flag::multiplexable` set.
LLFIO_MAKE_FREE_FUNCTION
- static LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<byte_socket_handle>
- multiplexable_byte_socket(ip::family family, mode _mode = mode::write, caching _caching = caching::all, flag flags = flag::multiplexable) noexcept;
+ static result<byte_socket_handle> multiplexable_byte_socket(ip::family family, mode _mode = mode::write, caching _caching = caching::all,
+ flag flags = flag::multiplexable) noexcept
+ {
+ return byte_socket(family, _mode, _caching, flags);
+ }
LLFIO_HEADERS_ONLY_VIRTUAL_SPEC ~byte_socket_handle() override
{
@@ -618,22 +621,36 @@ public:
LLFIO_HEADERS_ONLY_VIRTUAL_SPEC result<void> close() noexcept override;
/*! \brief Convenience function to shut down the outbound connection and wait for the other side to shut down our
- inbound connection by throwing away any bytes read, then closing the socket.
+ inbound connection by throwing away any bytes read, then closing the socket. Note that if the deadline passes
+ and we are still reading data, the socket is forced closed.
*/
result<void> shutdown_and_close(deadline d = {}) noexcept
{
LLFIO_DEADLINE_TO_SLEEP_INIT(d);
- OUTCOME_TRY(shutdown());
+ auto r = shutdown();
+ if(!r && r.assume_error() != errc::operation_in_progress)
+ {
+ OUTCOME_TRY(std::move(r));
+ }
byte buffer[4096];
for(;;)
{
deadline nd;
LLFIO_DEADLINE_TO_PARTIAL_DEADLINE(nd, d);
- OUTCOME_TRY(auto readed, read(0, {{buffer}}, nd));
- if(readed == 0)
+ auto r2 = read(0, {{buffer}}, nd);
+ if(r2 && r2.assume_value() == 0)
{
break;
}
+ if(!r2)
+ {
+ if(r2.assume_error() == errc::connection_aborted || r2.assume_error() == errc::connection_reset || r2.assume_error() == errc::not_connected ||
+ r2.assume_error() == errc::timed_out)
+ {
+ break;
+ }
+ OUTCOME_TRY(std::move(r2));
+ }
}
return close();
}
@@ -683,29 +700,23 @@ template <> struct construct<byte_socket_handle>
result<byte_socket_handle> operator()() const noexcept { return byte_socket_handle::byte_socket(family, _mode, _caching, flags); }
};
-/* \class listening_socket_handle_impl
-\brief A handle to a socket-like entity able to receive incoming connections.
+/*! \class listening_socket_handle_buffer_types_injector
+\brief Injects buffer types for a particular kind of listening socket read.
*/
-template <class SocketType> class listening_socket_handle_impl : public handle, public pollable_handle
+template <class Base, class SocketType> struct listening_socket_handle_buffer_types_injector : public Base
{
- template <class ST> friend class listening_socket_handle_impl;
- LLFIO_HEADERS_ONLY_VIRTUAL_SPEC const handle &_get_handle() const noexcept final { return *this; }
-
-protected:
- byte_io_multiplexer *_ctx{nullptr}; // +4 or +8 bytes
-
- template <class Impl>
- result<typename Impl::buffers_type> _underlying_read(typename Impl::template io_request<typename Impl::buffers_type> req, deadline d) noexcept
+ using Base::Base;
+ constexpr listening_socket_handle_buffer_types_injector(Base &&o) noexcept(std::is_nothrow_move_constructible<Base>::value)
+ : Base(static_cast<Base &&>(o))
{
- if(_v.behaviour & native_handle_type::disposition::is_pointer)
- {
- return reinterpret_cast<Impl *>(_v.ptr)->read(std::move(req), d);
- }
- auto *sock = static_cast<Impl *>(static_cast<handle *>(this));
- return (_ctx == nullptr) ? sock->Impl::_do_read(std::move(req), d) : sock->Impl::_do_multiplexer_read(std::move(req), d);
}
+ listening_socket_handle_buffer_types_injector() = default;
+ listening_socket_handle_buffer_types_injector(const listening_socket_handle_buffer_types_injector &) = default;
+ listening_socket_handle_buffer_types_injector(listening_socket_handle_buffer_types_injector &&) = default;
+ listening_socket_handle_buffer_types_injector &operator=(const listening_socket_handle_buffer_types_injector &) = default;
+ listening_socket_handle_buffer_types_injector &operator=(listening_socket_handle_buffer_types_injector &&) = default;
+ ~listening_socket_handle_buffer_types_injector() = default;
-public:
//! The buffer type used by this handle, which is a pair of `SocketType` and `ip::address`
using buffer_type = std::pair<SocketType, ip::address>;
//! The const buffer type used by this handle, which is a pair of `SocketType` and `ip::address`
@@ -809,21 +820,68 @@ public:
};
template <class T> using io_result = result<T>;
template <class T> using awaitable = byte_io_multiplexer::awaitable<T>;
+};
+
+/* \class listening_byte_socket_handle
+\brief A handle to a socket-like entity able to receive incoming connections.
+*/
+class LLFIO_DECL listening_byte_socket_handle : public listening_socket_handle_buffer_types_injector<handle, byte_socket_handle>, public pollable_handle
+{
+ using _base = listening_socket_handle_buffer_types_injector<handle, byte_socket_handle>;
+ LLFIO_HEADERS_ONLY_VIRTUAL_SPEC const handle &_get_handle() const noexcept final { return *this; }
+
+protected:
+ byte_io_multiplexer *_ctx{nullptr}; // +4 or +8 bytes
+
+ template <class Impl>
+ result<typename Impl::buffers_type> _underlying_read(typename Impl::template io_request<typename Impl::buffers_type> req, deadline d) noexcept
+ {
+ if(_v.behaviour & native_handle_type::disposition::is_pointer)
+ {
+ return reinterpret_cast<Impl *>(_v.ptr)->read(std::move(req), d);
+ }
+ auto *sock = static_cast<Impl *>(static_cast<handle *>(this));
+ return (_ctx == nullptr) ? sock->Impl::_do_read(std::move(req), d) : sock->Impl::_do_multiplexer_read(std::move(req), d);
+ }
+
+public:
// Used by byte_socket_source
virtual void _deleter() { delete this; }
protected:
- virtual result<buffers_type> _do_read(io_request<buffers_type> req, deadline d) noexcept = 0;
+ LLFIO_HEADERS_ONLY_VIRTUAL_SPEC result<buffers_type> _do_read(io_request<buffers_type> req, deadline d) noexcept;
- virtual io_result<buffers_type> _do_multiplexer_read(io_request<buffers_type> reqs, deadline d) noexcept = 0;
+ virtual io_result<buffers_type> _do_multiplexer_read(io_request<buffers_type> reqs, deadline d) noexcept
+ {
+ LLFIO_DEADLINE_TO_SLEEP_INIT(d);
+ const auto state_reqs = _ctx->io_state_requirements();
+ auto *storage = (byte *) alloca(state_reqs.first + state_reqs.second);
+ const auto diff = (uintptr_t) storage & (state_reqs.second - 1);
+ storage += state_reqs.second - diff;
+ auto *state = _ctx->construct_and_init_io_operation({storage, state_reqs.first}, this, nullptr, d, reqs.buffers.connected_socket());
+ if(state == nullptr)
+ {
+ return errc::resource_unavailable_try_again;
+ }
+ OUTCOME_TRY(_ctx->flush_inited_io_operations());
+ while(!is_finished(_ctx->check_io_operation(state)))
+ {
+ deadline nd;
+ LLFIO_DEADLINE_TO_PARTIAL_DEADLINE(nd, d);
+ OUTCOME_TRY(_ctx->check_for_any_completed_io(nd));
+ }
+ OUTCOME_TRY(std::move(*state).get_completed_read());
+ state->~io_operation_state();
+ return {std::move(reqs.buffers)};
+ }
-protected:
+public:
//! Default constructor
- constexpr listening_socket_handle_impl() {} // NOLINT
+ constexpr listening_byte_socket_handle() {} // NOLINT
//! Construct a handle from a supplied native handle
- constexpr listening_socket_handle_impl(native_handle_type h, flag flags, byte_io_multiplexer *ctx)
- : handle(std::move(h), flags)
+ constexpr listening_byte_socket_handle(native_handle_type h, flag flags, byte_io_multiplexer *ctx)
+ : _base(std::move(h), flags)
, _ctx(ctx)
{
#ifdef _WIN32
@@ -834,12 +892,12 @@ protected:
#endif
}
//! No copy construction (use clone())
- listening_socket_handle_impl(const listening_socket_handle_impl &) = delete;
+ listening_byte_socket_handle(const listening_byte_socket_handle &) = delete;
//! No copy assignment
- listening_socket_handle_impl &operator=(const listening_socket_handle_impl &) = delete;
+ listening_byte_socket_handle &operator=(const listening_byte_socket_handle &) = delete;
//! Implicit move construction of listening socket handle permitted
- constexpr listening_socket_handle_impl(listening_socket_handle_impl &&o) noexcept
- : handle(std::move(o))
+ constexpr listening_byte_socket_handle(listening_byte_socket_handle &&o) noexcept
+ : _base(std::move(o))
, _ctx(o._ctx)
{
#ifdef _WIN32
@@ -851,8 +909,8 @@ protected:
#endif
}
//! Explicit conversion from handle permitted
- explicit constexpr listening_socket_handle_impl(handle &&o, byte_io_multiplexer *ctx) noexcept
- : handle(std::move(o))
+ explicit constexpr listening_byte_socket_handle(handle &&o, byte_io_multiplexer *ctx) noexcept
+ : _base(std::move(o))
, _ctx(ctx)
{
#ifdef _WIN32
@@ -862,137 +920,8 @@ protected:
}
#endif
}
-
-public:
- /*! \brief The i/o multiplexer this handle will use to multiplex i/o. If this returns null,
- then this handle has not been registered with an i/o multiplexer yet.
- */
- byte_io_multiplexer *multiplexer() const noexcept { return _ctx; }
-
- /*! \brief Sets the i/o multiplexer this handle will use to implement `read()`, `write()` and `barrier()`.
-
- Note that this call deregisters this handle from any existing i/o multiplexer, and registers
- it with the new i/o multiplexer. You must therefore not call it if any i/o is currently
- outstanding on this handle. You should also be aware that multiple dynamic memory
- allocations and deallocations may occur, as well as multiple syscalls (i.e. this is
- an expensive call, try to do it from cold code).
-
- If the handle was not created as multiplexable, this call always fails.
-
- \mallocs Multiple dynamic memory allocations and deallocations.
- */
- virtual result<void> set_multiplexer(byte_io_multiplexer *c = this_thread::multiplexer()) noexcept = 0;
-
- //! Returns the IP family of this socket instance
- ip::family family() const noexcept { return (this->_v.behaviour & native_handle_type::disposition::is_alternate) ? ip::family::v6 : ip::family::v4; }
-
- //! Returns the local endpoint of this socket instance
- virtual result<ip::address> local_endpoint() const noexcept = 0;
-
- /*! \brief Binds a socket to a local endpoint and sets the socket to listen for new connections.
- \param addr The local endpoint to which to bind the socket.
- \param _creation Whether to apply `SO_REUSEADDR` before binding.
- \param backlog The maximum queue length of pending connections. `-1` chooses `SOMAXCONN`.
-
- You should set any socket options etc that you need on `native_handle()` before binding
- the socket to its local endpoint.
-
- \errors Any of the values `bind()` and `listen()` can return.
- */
- virtual result<void> bind(const ip::address &addr, creation _creation = creation::only_if_not_exist, int backlog = -1) noexcept = 0;
-
- /*! Read the contents of the listening socket for newly connected byte sockets.
-
- \return Returns the buffers filled, with its socket handle and address set to the newly connected socket.
- \param req A buffer to fill with a newly connected socket.
- \param d An optional deadline by which to time out.
-
- \errors Any of the errors which `accept()` or `WSAAccept()` might return.
- */
- LLFIO_MAKE_FREE_FUNCTION
- result<buffers_type> read(io_request<buffers_type> req, deadline d = {}) noexcept
- {
- return (_ctx == nullptr) ? _do_read(std::move(req), d) : _do_multiplexer_read(std::move(req), d);
- }
-
- /*! \brief A coroutinised equivalent to `.read()` which suspends the coroutine until
- a new incoming connection occurs. **Blocks execution** i.e is equivalent to `.read()` if no i/o multiplexer
- has been set on this handle!
-
- The awaitable returned is **eager** i.e. it immediately begins the i/o. If the i/o completes
- and finishes immediately, no coroutine suspension occurs.
- */
- LLFIO_MAKE_FREE_FUNCTION
- awaitable<io_result<buffers_type>> co_read(io_request<buffers_type> reqs, deadline d = {}) noexcept;
-#if 0 // TODO
- {
- if(_ctx == nullptr)
- {
- return awaitable<io_result<buffers_type>>(read(std::move(reqs), d));
- }
- awaitable<io_result<buffers_type>> ret;
- ret.set_state(_ctx->construct(ret._state_storage, this, nullptr, d, reqs.buffers.connected_socket()));
- return ret;
- }
-#endif
-};
-
-LLFIO_TEMPLATE(class T, class U)
-LLFIO_TREQUIRES(LLFIO_TEXPR(static_cast<T *>(static_cast<handle *>(nullptr))))
-T *socket_cast(listening_socket_handle_impl<U> *v)
-{
- return static_cast<T *>(static_cast<handle *>(v));
-}
-LLFIO_TEMPLATE(class T, class U)
-LLFIO_TREQUIRES(LLFIO_TEXPR(static_cast<const T *>(static_cast<const handle *>(nullptr))))
-const T *socket_cast(const listening_socket_handle_impl<U> *v)
-{
- return static_cast<const T *>(static_cast<const handle *>(v));
-}
-
-/* \class listening_socket_handle
-\brief A handle to a socket-like entity able to receive incoming connections.
-*/
-class LLFIO_DECL listening_socket_handle : public listening_socket_handle_impl<byte_socket_handle>
-{
- template <class ST> friend class LLFIO_V2_NAMESPACE::listening_socket_handle_impl;
- using _base = listening_socket_handle_impl<byte_socket_handle>;
-
-protected:
- LLFIO_HEADERS_ONLY_VIRTUAL_SPEC result<buffers_type> _do_read(io_request<buffers_type> req, deadline d) noexcept override;
-
- LLFIO_HEADERS_ONLY_VIRTUAL_SPEC io_result<buffers_type> _do_multiplexer_read(io_request<buffers_type> reqs, deadline d) noexcept override
- {
- LLFIO_DEADLINE_TO_SLEEP_INIT(d);
- const auto state_reqs = _ctx->io_state_requirements();
- auto *storage = (byte *) alloca(state_reqs.first + state_reqs.second);
- const auto diff = (uintptr_t) storage & (state_reqs.second - 1);
- storage += state_reqs.second - diff;
- auto *state = _ctx->construct_and_init_io_operation({storage, state_reqs.first}, this, nullptr, d, reqs.buffers.connected_socket());
- if(state == nullptr)
- {
- return errc::resource_unavailable_try_again;
- }
- OUTCOME_TRY(_ctx->flush_inited_io_operations());
- while(!is_finished(_ctx->check_io_operation(state)))
- {
- deadline nd;
- LLFIO_DEADLINE_TO_PARTIAL_DEADLINE(nd, d);
- OUTCOME_TRY(_ctx->check_for_any_completed_io(nd));
- }
- OUTCOME_TRY(std::move(*state).get_completed_read());
- state->~io_operation_state();
- return {std::move(reqs.buffers)};
- }
-
-public:
- constexpr listening_socket_handle() {}
- using _base::_base;
- listening_socket_handle(const listening_socket_handle &) = delete;
- listening_socket_handle(listening_socket_handle &&) = default;
- listening_socket_handle &operator=(const listening_socket_handle &) = delete;
- //! Move assignment of listening_socket_handle_impl permitted
- listening_socket_handle_impl &operator=(listening_socket_handle &&o) noexcept
+ //! Move assignment of listening_byte_socket_handle permitted
+ listening_byte_socket_handle &operator=(listening_byte_socket_handle &&o) noexcept
{
if(this == &o)
{
@@ -1004,25 +933,23 @@ public:
detail::unregister_socket_handle_instance(this);
}
#endif
- this->~listening_socket_handle();
- new(this) listening_socket_handle(std::move(o));
+ this->~listening_byte_socket_handle();
+ new(this) listening_byte_socket_handle(std::move(o));
return *this;
}
-
//! Swap with another instance
LLFIO_MAKE_FREE_FUNCTION
- void swap(listening_socket_handle &o) noexcept
+ void swap(listening_byte_socket_handle &o) noexcept
{
- listening_socket_handle temp(std::move(*this));
+ listening_byte_socket_handle temp(std::move(*this));
*this = std::move(o);
o = std::move(temp);
}
-
- virtual ~listening_socket_handle() override
+ virtual ~listening_byte_socket_handle() override
{
if(_v)
{
- (void) listening_socket_handle_impl::close();
+ (void) listening_byte_socket_handle::close();
}
}
virtual result<void> close() noexcept override
@@ -1049,7 +976,25 @@ public:
return ret;
}
- virtual result<void> set_multiplexer(byte_io_multiplexer *c = this_thread::multiplexer()) noexcept override
+public:
+ /*! \brief The i/o multiplexer this handle will use to multiplex i/o. If this returns null,
+ then this handle has not been registered with an i/o multiplexer yet.
+ */
+ byte_io_multiplexer *multiplexer() const noexcept { return _ctx; }
+
+ /*! \brief Sets the i/o multiplexer this handle will use to implement `read()`, `write()` and `barrier()`.
+
+ Note that this call deregisters this handle from any existing i/o multiplexer, and registers
+ it with the new i/o multiplexer. You must therefore not call it if any i/o is currently
+ outstanding on this handle. You should also be aware that multiple dynamic memory
+ allocations and deallocations may occur, as well as multiple syscalls (i.e. this is
+ an expensive call, try to do it from cold code).
+
+ If the handle was not created as multiplexable, this call always fails.
+
+ \mallocs Multiple dynamic memory allocations and deallocations.
+ */
+ virtual result<void> set_multiplexer(byte_io_multiplexer *c = this_thread::multiplexer()) noexcept
{
if(!is_multiplexable())
{
@@ -1081,11 +1026,6 @@ public:
return success();
}
- LLFIO_HEADERS_ONLY_VIRTUAL_SPEC result<ip::address> local_endpoint() const noexcept override;
-
- LLFIO_HEADERS_ONLY_VIRTUAL_SPEC result<void> bind(const ip::address &addr, creation _creation = creation::only_if_not_exist,
- int backlog = -1) noexcept override;
-
/*! Create a listening socket handle.
\param _family Which IP family to create the socket in.
\param _mode How to open the socket. If this is `mode::append`, the read side of the socket
@@ -1098,25 +1038,79 @@ public:
\errors Any of the values POSIX `socket()` or `WSASocket()` can return.
*/
LLFIO_MAKE_FREE_FUNCTION
- static LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<listening_socket_handle> listening_socket(ip::family _family, mode _mode = mode::write,
- caching _caching = caching::all, flag flags = flag::none) noexcept;
+ static LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<listening_byte_socket_handle>
+ listening_byte_socket(ip::family _family, mode _mode = mode::write, caching _caching = caching::all, flag flags = flag::none) noexcept;
//! \brief Convenience function defaulting `flag::multiplexable` set.
LLFIO_MAKE_FREE_FUNCTION
- static LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<listening_socket_handle>
- multiplexable_listening_socket(ip::family _family, mode _mode = mode::write, caching _caching = caching::all, flag flags = flag::multiplexable) noexcept
+ static result<listening_byte_socket_handle> multiplexable_listening_byte_socket(ip::family _family, mode _mode = mode::write, caching _caching = caching::all,
+ flag flags = flag::multiplexable) noexcept
+ {
+ return listening_byte_socket(_family, _mode, _caching, flags);
+ }
+
+
+ //! Returns the IP family of this socket instance
+ ip::family family() const noexcept { return (this->_v.behaviour & native_handle_type::disposition::is_alternate) ? ip::family::v6 : ip::family::v4; }
+
+ //! Returns the local endpoint of this socket instance
+ LLFIO_HEADERS_ONLY_VIRTUAL_SPEC result<ip::address> local_endpoint() const noexcept;
+
+ /*! \brief Binds a socket to a local endpoint and sets the socket to listen for new connections.
+ \param addr The local endpoint to which to bind the socket.
+ \param _creation Whether to apply `SO_REUSEADDR` before binding.
+ \param backlog The maximum queue length of pending connections. `-1` chooses `SOMAXCONN`.
+
+ You should set any socket options etc that you need on `native_handle()` before binding
+ the socket to its local endpoint.
+
+ \errors Any of the values `bind()` and `listen()` can return.
+ */
+ LLFIO_HEADERS_ONLY_VIRTUAL_SPEC result<void> bind(const ip::address &addr, creation _creation = creation::only_if_not_exist, int backlog = -1) noexcept;
+
+ /*! Read the contents of the listening socket for newly connected byte sockets.
+
+ \return Returns the buffers filled, with its socket handle and address set to the newly connected socket.
+ \param req A buffer to fill with a newly connected socket.
+ \param d An optional deadline by which to time out.
+
+ \errors Any of the errors which `accept()` or `WSAAccept()` might return.
+ */
+ LLFIO_MAKE_FREE_FUNCTION
+ result<buffers_type> read(io_request<buffers_type> req, deadline d = {}) noexcept
{
- return listening_socket(_family, _mode, _caching, flags);
+ return (_ctx == nullptr) ? _do_read(std::move(req), d) : _do_multiplexer_read(std::move(req), d);
}
+
+ /*! \brief A coroutinised equivalent to `.read()` which suspends the coroutine until
+ a new incoming connection occurs. **Blocks execution** i.e is equivalent to `.read()` if no i/o multiplexer
+ has been set on this handle!
+
+ The awaitable returned is **eager** i.e. it immediately begins the i/o. If the i/o completes
+ and finishes immediately, no coroutine suspension occurs.
+ */
+ LLFIO_MAKE_FREE_FUNCTION
+ awaitable<io_result<buffers_type>> co_read(io_request<buffers_type> reqs, deadline d = {}) noexcept;
+#if 0 // TODO
+ {
+ if(_ctx == nullptr)
+ {
+ return awaitable<io_result<buffers_type>>(read(std::move(reqs), d));
+ }
+ awaitable<io_result<buffers_type>> ret;
+ ret.set_state(_ctx->construct(ret._state_storage, this, nullptr, d, reqs.buffers.connected_socket()));
+ return ret;
+ }
+#endif
};
-//! \brief Constructor for `listening_socket_handle`
-template <> struct construct<listening_socket_handle>
+//! \brief Constructor for `listening_byte_socket_handle`
+template <> struct construct<listening_byte_socket_handle>
{
ip::family family;
byte_socket_handle::mode _mode = byte_socket_handle::mode::write;
byte_socket_handle::caching _caching = byte_socket_handle::caching::all;
byte_socket_handle::flag flags = byte_socket_handle::flag::none;
- result<listening_socket_handle> operator()() const noexcept { return listening_socket_handle::listening_socket(family, _mode, _caching, flags); }
+ result<listening_byte_socket_handle> operator()() const noexcept { return listening_byte_socket_handle::listening_byte_socket(family, _mode, _caching, flags); }
};
// BEGIN make_free_functions.py
diff --git a/include/llfio/v2.0/detail/impl/posix/byte_io_handle.ipp b/include/llfio/v2.0/detail/impl/posix/byte_io_handle.ipp
index 43e53cbb..026bded2 100644
--- a/include/llfio/v2.0/detail/impl/posix/byte_io_handle.ipp
+++ b/include/llfio/v2.0/detail/impl/posix/byte_io_handle.ipp
@@ -375,7 +375,8 @@ result<size_t> poll(span<poll_what> out, span<pollable_handle *> handles, span<c
{
if(handles[n] != nullptr)
{
- auto &h = handles[n]->_get_handle();
+ auto &h_ = handles[n]->_get_handle();
+ auto &h = h_.native_handle().is_third_party_pointer() ? *(handle *) h_.native_handle().ptr : h_;
if(h.is_kernel_handle())
{
fds[fdscount].fd = h.native_handle().fd;
@@ -431,7 +432,8 @@ result<size_t> poll(span<poll_what> out, span<pollable_handle *> handles, span<c
{
if(handles[n] != nullptr)
{
- auto &h = handles[n]->_get_handle();
+ auto &h_ = handles[n]->_get_handle();
+ auto &h = h_.native_handle().is_third_party_pointer() ? *(handle *) h_.native_handle().ptr : h_;
if(h.is_kernel_handle())
{
if(fds[fdscount].revents != 0)
diff --git a/include/llfio/v2.0/detail/impl/posix/byte_socket_handle.ipp b/include/llfio/v2.0/detail/impl/posix/byte_socket_handle.ipp
index f33d94ae..d0cd604f 100644
--- a/include/llfio/v2.0/detail/impl/posix/byte_socket_handle.ipp
+++ b/include/llfio/v2.0/detail/impl/posix/byte_socket_handle.ipp
@@ -339,7 +339,8 @@ namespace ip
#if LLFIO_IP_ADDRESS_RESOLVER_USE_ASYNC_GETADDRINFO
retcode = self->get();
#else
- retcode = self->task.get();
+ if(self->task.valid())
+ retcode = self->task.get();
#endif
break;
case 1: // deadline passed
@@ -690,7 +691,7 @@ LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<void> byte_socket_handle::close() noexcep
/*******************************************************************************************************************/
-LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<ip::address> listening_socket_handle::local_endpoint() const noexcept
+LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<ip::address> listening_byte_socket_handle::local_endpoint() const noexcept
{
LLFIO_LOG_FUNCTION_CALL(this);
ip::address ret;
@@ -702,7 +703,7 @@ LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<ip::address> listening_socket_handle::loc
return ret;
}
-LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<void> listening_socket_handle::bind(const ip::address &addr, creation _creation, int backlog) noexcept
+LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<void> listening_byte_socket_handle::bind(const ip::address &addr, creation _creation, int backlog) noexcept
{
LLFIO_LOG_FUNCTION_CALL(this);
if(_creation != creation::only_if_not_exist)
@@ -724,16 +725,16 @@ LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<void> listening_socket_handle::bind(const
return success();
}
-LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<listening_socket_handle> listening_socket_handle::listening_socket(ip::family family, mode _mode, caching _caching,
+LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<listening_byte_socket_handle> listening_byte_socket_handle::listening_byte_socket(ip::family family, mode _mode, caching _caching,
flag flags) noexcept
{
- result<listening_socket_handle> ret(listening_socket_handle(native_handle_type(), flags, nullptr));
+ result<listening_byte_socket_handle> ret(listening_byte_socket_handle(native_handle_type(), flags, nullptr));
native_handle_type &nativeh = ret.value()._v;
OUTCOME_TRY(detail::create_socket(nativeh, family, _mode, _caching, flags));
return ret;
}
-LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<listening_socket_handle::buffers_type> listening_socket_handle::_do_read(io_request<buffers_type> req,
+LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<listening_byte_socket_handle::buffers_type> listening_byte_socket_handle::_do_read(io_request<buffers_type> req,
deadline d) noexcept
{
LLFIO_LOG_FUNCTION_CALL(this);
diff --git a/include/llfio/v2.0/detail/impl/posix/file_handle.ipp b/include/llfio/v2.0/detail/impl/posix/file_handle.ipp
index e9e0e27b..e25c2e47 100644
--- a/include/llfio/v2.0/detail/impl/posix/file_handle.ipp
+++ b/include/llfio/v2.0/detail/impl/posix/file_handle.ipp
@@ -807,7 +807,7 @@ result<file_handle::extent_pair> file_handle::clone_extents_to(file_handle::exte
{
retry_clone:
bool done = false;
- const auto thisblock = std::min(blocksize, item.src.length - thisoffset);
+ const auto thisblock = (size_type) std::min(blocksize, item.src.length - thisoffset);
if(duplicate_extents && item.op == workitem::clone_extents)
{
off_t off_in = item.src.offset + thisoffset, off_out = item.src.offset + thisoffset + destoffsetdiff;
diff --git a/include/llfio/v2.0/detail/impl/posix/utils.ipp b/include/llfio/v2.0/detail/impl/posix/utils.ipp
index fe3aa5d2..72dcf774 100644
--- a/include/llfio/v2.0/detail/impl/posix/utils.ipp
+++ b/include/llfio/v2.0/detail/impl/posix/utils.ipp
@@ -24,6 +24,7 @@ Distributed under the Boost Software License, Version 1.0.
#include "../../../utils.hpp"
+#include <cinttypes> // for SCNu64
#include <mutex> // for lock_guard
#include <sys/mman.h>
@@ -636,13 +637,15 @@ namespace utils
static const uint64_t ts_multiplier = 1000000000ULL / sysconf(_SC_CLK_TCK);
OUTCOME_TRY(fill_buffer(buffer1, "/proc/self/stat"));
OUTCOME_TRY(fill_buffer(buffer2, "/proc/stat"));
- if(sscanf(buffer1.data(), "%*d %*s %*c %*d %*d %*d %*d %*d %*u %*u %*u %*u %*u %lu %lu", &ret.process_ns_in_user_mode, &ret.process_ns_in_kernel_mode) <
+ if(sscanf(buffer1.data(), "%*d %*s %*c %*d %*d %*d %*d %*d %*u %*u %*u %*u %*u %" SCNu64 " %" SCNu64, &ret.process_ns_in_user_mode,
+ &ret.process_ns_in_kernel_mode) <
2)
{
return errc::protocol_error;
}
uint64_t user_nice;
- if(sscanf(buffer2.data(), "cpu %lu %lu %lu %lu", &ret.system_ns_in_user_mode, &user_nice, &ret.system_ns_in_kernel_mode, &ret.system_ns_in_idle_mode) < 4)
+ if(sscanf(buffer2.data(), "cpu %" SCNu64 " %" SCNu64 " %" SCNu64 " %" SCNu64, &ret.system_ns_in_user_mode, &user_nice, &ret.system_ns_in_kernel_mode,
+ &ret.system_ns_in_idle_mode) < 4)
{
return errc::protocol_error;
}
diff --git a/include/llfio/v2.0/detail/impl/test/null_multiplexer.ipp b/include/llfio/v2.0/detail/impl/test/null_multiplexer.ipp
index f0f6b410..0c31d90a 100644
--- a/include/llfio/v2.0/detail/impl/test/null_multiplexer.ipp
+++ b/include/llfio/v2.0/detail/impl/test/null_multiplexer.ipp
@@ -182,7 +182,7 @@ namespace test
ret.multiplexes.kernel.file_handle = true;
ret.multiplexes.kernel.pipe_handle = true;
ret.multiplexes.kernel.byte_socket_handle = true;
- ret.multiplexes.kernel.listening_socket_handle = false;
+ ret.multiplexes.kernel.listening_byte_socket_handle = false;
return ret;
}();
return v;
diff --git a/include/llfio/v2.0/detail/impl/tls_socket_sources/openssl.ipp b/include/llfio/v2.0/detail/impl/tls_socket_sources/openssl.ipp
index e9a86f98..ffb72949 100644
--- a/include/llfio/v2.0/detail/impl/tls_socket_sources/openssl.ipp
+++ b/include/llfio/v2.0/detail/impl/tls_socket_sources/openssl.ipp
@@ -1,5 +1,5 @@
/* A TLS socket source based on OpenSSL
-(C) 2021-2021 Niall Douglas <http://www.nedproductions.biz/> (20 commits)
+(C) 2021-2022 Niall Douglas <http://www.nedproductions.biz/> (20 commits)
File Created: Dec 2021
@@ -29,11 +29,17 @@ Distributed under the Boost Software License, Version 1.0.
#include <openssl/crypto.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
+#include <openssl/x509.h>
#if LLFIO_OPENSSL_ENABLE_DEBUG_PRINTING
#include <iostream>
#endif
+#ifdef _WIN32
+#include <cryptuiapi.h>
+#pragma comment(lib, "cryptui.lib")
+#endif
+
LLFIO_V2_NAMESPACE_BEGIN
namespace detail
@@ -101,7 +107,7 @@ namespace detail
PSK-AES128-CCM
*/
static const char *openssl_unverified_cipher_list = "CHACHA20:HIGH:aNULL:!EXPORT:!LOW:!eNULL:!SSLv2:!SSLv3:!TLSv1.0:!CAMELLIA:!ARIA:!SHA:!kRSA:@SECLEVEL=0";
- //static const char *openssl_unverified_cipher_list = "CHACHA20:HIGH:!aNULL:!EXPORT:!LOW:!eNULL:!SSLv2:!SSLv3:!TLSv1.0:!CAMELLIA:!ARIA:!SHA:!kRSA";
+ // static const char *openssl_unverified_cipher_list = "CHACHA20:HIGH:!aNULL:!EXPORT:!LOW:!eNULL:!SSLv2:!SSLv3:!TLSv1.0:!CAMELLIA:!ARIA:!SHA:!kRSA";
/* This is the list my OpenSSL v1.1 offers in HELLO for this spec (48):
TLS_AES_256_GCM_SHA384
@@ -246,6 +252,64 @@ namespace detail
constexpr openssl_error_domain openssl_error_domain_inst;
inline constexpr const openssl_error_domain &openssl_error_domain::get() { return openssl_error_domain_inst; }
using openssl_code = OUTCOME_V2_NAMESPACE::experimental::status_code<openssl_error_domain>;
+
+ struct x509_error_domain final : public OUTCOME_V2_NAMESPACE::experimental::status_code_domain
+ {
+ using value_type = int;
+
+ constexpr x509_error_domain()
+ : OUTCOME_V2_NAMESPACE::experimental::status_code_domain("{bea47e79-6787-6009-7ac7-ba8616575312}")
+ {
+ }
+
+ virtual string_ref name() const noexcept override { return string_ref("x509"); }
+ virtual payload_info_t payload_info() const noexcept override
+ {
+ return {sizeof(value_type), sizeof(status_code_domain *) + sizeof(value_type),
+ (alignof(value_type) > alignof(status_code_domain *)) ? alignof(value_type) : alignof(status_code_domain *)};
+ }
+ static inline constexpr const x509_error_domain &get();
+
+ virtual bool _do_failure(const OUTCOME_V2_NAMESPACE::experimental::status_code<void> &code) const noexcept override
+ {
+ auto &c = static_cast<const OUTCOME_V2_NAMESPACE::experimental::status_code<x509_error_domain> &>(code);
+ return c.value() != 0;
+ }
+ virtual bool _do_equivalent(const OUTCOME_V2_NAMESPACE::experimental::status_code<void> &code1,
+ const OUTCOME_V2_NAMESPACE::experimental::status_code<void> &code2) const noexcept override
+ {
+ assert(code1.domain() == *this); // NOLINT
+ const auto &c1 = static_cast<const OUTCOME_V2_NAMESPACE::experimental::status_code<x509_error_domain> &>(code1); // NOLINT
+ if(code2.domain() == *this)
+ {
+ const auto &c2 = static_cast<const OUTCOME_V2_NAMESPACE::experimental::status_code<x509_error_domain> &>(code2); // NOLINT
+ return c1.value() == c2.value();
+ }
+ return false;
+ }
+ virtual OUTCOME_V2_NAMESPACE::experimental::generic_code
+ _generic_code(const OUTCOME_V2_NAMESPACE::experimental::status_code<void> &) const noexcept override
+ {
+ return errc::unknown;
+ }
+ virtual string_ref _do_message(const OUTCOME_V2_NAMESPACE::experimental::status_code<void> &code) const noexcept override
+ {
+ auto &c = static_cast<const OUTCOME_V2_NAMESPACE::experimental::status_code<x509_error_domain> &>(code);
+ if(c.value() == 0)
+ {
+ return string_ref("not an error");
+ }
+ return string_ref(X509_verify_cert_error_string(c.value()));
+ }
+ SYSTEM_ERROR2_NORETURN virtual void _do_throw_exception(const OUTCOME_V2_NAMESPACE::experimental::status_code<void> &code) const override
+ {
+ auto &c = static_cast<const OUTCOME_V2_NAMESPACE::experimental::status_code<x509_error_domain> &>(code);
+ throw OUTCOME_V2_NAMESPACE::experimental::status_error<x509_error_domain>(c);
+ }
+ };
+ constexpr x509_error_domain x509_error_domain_inst;
+ inline constexpr const x509_error_domain &x509_error_domain::get() { return x509_error_domain_inst; }
+ using x509_code = OUTCOME_V2_NAMESPACE::experimental::status_code<x509_error_domain>;
} // namespace detail
template <class T> inline result<void> openssl_error(T *inst, unsigned long errcode = ERR_get_error())
{
@@ -284,6 +348,16 @@ inline result<void> openssl_error(std::nullptr_t, unsigned long errcode = ERR_ge
assert(ret.failure());
return ret;
}
+inline result<void> x509_error(int errcode)
+{
+ detail::x509_code ret(errcode);
+#if LLFIO_OPENSSL_ENABLE_DEBUG_PRINTING
+ std::lock_guard<std::mutex> g(detail::openssl_printing_lock);
+ std::cerr << "X509 error: " << ret.message().c_str() << std::endl;
+#endif
+ assert(ret.failure());
+ return ret;
+}
#else
namespace detail
{
@@ -321,6 +395,17 @@ namespace detail
return {buffer};
}
};
+ struct x509_error_category final : public std::error_category
+ {
+ virtual const char *name() const noexcept override { return "x509"; }
+ virtual std::error_condition default_error_condition(int code) const noexcept override { return {code, *this}; }
+ virtual std::string message(int code) const override
+ {
+ if(code == 0)
+ return "not an error";
+ return X509_verify_cert_error_string(code);
+ }
+ };
} // namespace detail
template <class T> inline result<void> openssl_error(T *inst, unsigned long errcode = ERR_get_error())
{
@@ -333,7 +418,12 @@ template <class T> inline result<void> openssl_error(T *inst, unsigned long errc
return (ERR_GET_REASON(errcode) == 2) ? std::move(inst->_write_error) : std::move(inst->_read_error);
}
static detail::openssl_error_category cat;
- return error_info(std::error_code((int) errcode, cat));
+ error_info ret(std::error_code((int) errcode, cat));
+#if LLFIO_OPENSSL_ENABLE_DEBUG_PRINTING
+ std::lock_guard<std::mutex> g(detail::openssl_printing_lock);
+ std::cout << "ERROR: " << ret.message() << std::endl;
+#endif
+ return ret;
}
inline result<void> openssl_error(std::nullptr_t, unsigned long errcode = ERR_get_error())
{
@@ -344,6 +434,11 @@ inline result<void> openssl_error(std::nullptr_t, unsigned long errcode = ERR_ge
static detail::openssl_error_category cat;
return error_info(std::error_code((int) errcode, cat));
}
+inline result<void> x509_error(int errcode)
+{
+ static detail::x509_error_category cat;
+ return error_info(std::error_code((int) errcode, cat));
+}
#endif
namespace detail
@@ -396,8 +491,14 @@ namespace detail
static struct openssl_default_ctxs_t
{
SSL_CTX *unverified{nullptr}, *verified{nullptr};
+ X509_STORE *certstore{nullptr};
~openssl_default_ctxs_t()
{
+ if(certstore != nullptr)
+ {
+ X509_STORE_free(certstore);
+ certstore = nullptr;
+ }
if(verified != nullptr)
{
SSL_CTX_free(verified);
@@ -417,7 +518,48 @@ namespace detail
QUICKCPPLIB_NAMESPACE::configurable_spinlock::lock_guard<QUICKCPPLIB_NAMESPACE::configurable_spinlock::spinlock<unsigned>> g(lock);
if(verified == nullptr)
{
- auto make_ctx = [](bool verify_peer) -> result<SSL_CTX *>
+#ifdef _WIN32
+ // Create an OpenSSL certificate store made up of the certs from the Windows certificate store
+ HCERTSTORE winstore = CertOpenSystemStoreW(NULL, L"ROOT");
+ if(!winstore)
+ {
+ return win32_error();
+ }
+ auto unwinstore = make_scope_exit([&]() noexcept { CertCloseStore(winstore, 0); });
+ certstore = X509_STORE_new();
+ if(!certstore)
+ {
+ return openssl_error(nullptr);
+ }
+ PCCERT_CONTEXT context = nullptr;
+ auto uncontext = make_scope_exit(
+ [&]() noexcept
+ {
+ if(context != nullptr)
+ {
+ CertFreeCertificateContext(context);
+ }
+ });
+ while((context = CertEnumCertificatesInStore(winstore, context)) != nullptr)
+ {
+ const unsigned char *in = (const unsigned char *) context->pbCertEncoded;
+ X509 *x509 = d2i_X509(nullptr, &in, context->cbCertEncoded);
+ if(!x509)
+ {
+ return openssl_error(nullptr);
+ }
+ auto unx509 = make_scope_exit([&]() noexcept { X509_free(x509); });
+ //{
+ // X509_NAME_print_ex_fp(stdout, X509_get_issuer_name(x509), 3, 0);
+ // printf("\n");
+ //}
+ if(X509_STORE_add_cert(certstore, x509) <= 0)
+ {
+ return openssl_error(nullptr);
+ }
+ }
+#endif
+ auto make_ctx = [certstore = certstore](bool verify_peer) -> result<SSL_CTX *>
{
SSL_CTX *_ctx = SSL_CTX_new(TLS_method());
if(_ctx == nullptr)
@@ -434,13 +576,17 @@ namespace detail
}
else
{
+ if(certstore != nullptr)
+ {
+ SSL_CTX_set1_cert_store(_ctx, certstore);
+ }
if(!SSL_CTX_set_cipher_list(_ctx, openssl_verified_cipher_list))
{
return openssl_error(nullptr).as_failure();
}
SSL_CTX_set_verify(_ctx, SSL_VERIFY_PEER, nullptr);
SSL_CTX_set_verify_depth(_ctx, 4);
- if(!SSL_CTX_set_default_verify_paths(_ctx))
+ if(SSL_CTX_set_default_verify_paths(_ctx) <= 0)
{
return openssl_error(nullptr).as_failure();
}
@@ -471,16 +617,33 @@ class openssl_socket_handle final : public tls_socket_handle
optional<filesystem::path> _authentication_certificates_path;
std::string _connect_hostname_port;
+ /* We use a registered buffer from the underlying transport for reads only, but not writes.
+ The reason why not is that from my best reading of the implementation source code,
+ OpenSSL doesn't seem to allow fixed size write buffers (as according to the documentation
+ for BIO_set_mem_buf), plus OpenSSL seems to treat failure to allocate memory as an abort
+ situation rather than a retry situation, so there seems to me that is no way of backpressuring
+ a fixed size buffer in OpenSSL.
+
+ Furthermore, the documentation for BUF_MEM suggests that the buffer must be a single
+ contiguous region of memory, so we can't use a list of multiple registered
+ buffers either. If your TLS implementation were better designed, it should be possible to
+ use registered buffers for both reads and writes, and that then reduces CPU cache loading
+ on a very busy server.
+
+ Be aware that due to this design flaw in OpenSSL, it is possible to buffer writes to
+ infinity i.e. it never returns a partial write, it always accepts fully every write. We
+ therefore have to emulate backpressure below.
+ */
std::mutex _lock;
std::unique_lock<std::mutex> _lock_holder{_lock, std::defer_lock};
- uint16_t _read_buffer_source_idx{0}, _write_buffer_source_idx{0};
- uint16_t _read_buffer_sink_idx{0}, _write_buffer_sink_idx{0};
+ uint16_t _read_buffer_source_idx{0}, _read_buffer_sink_idx{0};
+ bool _write_socket_full{false};
+ uint8_t _still_connecting{0};
deadline _read_deadline, _write_deadline;
result<void> _read_error{success()}, _write_error{success()};
std::chrono::steady_clock::time_point _read_deadline_began_steady, _write_deadline_began_steady;
- byte_socket_handle::registered_buffer_type _read_buffers[BUFFERS_COUNT]{}, _write_buffers[BUFFERS_COUNT]{};
+ byte_socket_handle::registered_buffer_type _read_buffers[BUFFERS_COUNT]{};
byte_socket_handle::buffer_type _read_buffers_valid[BUFFERS_COUNT]{};
- byte_socket_handle::const_buffer_type _write_buffers_valid[BUFFERS_COUNT]{};
// Front of the queue
std::pair<byte_socket_handle::registered_buffer_type *, byte_socket_handle::buffer_type *> _toread_source() noexcept
@@ -499,23 +662,6 @@ class openssl_socket_handle final : public tls_socket_handle
}
return {&_read_buffers[idx], &_read_buffers_valid[idx]};
}
- // Front of the queue
- std::pair<byte_socket_handle::registered_buffer_type *, byte_socket_handle::const_buffer_type *> _towrite_source() noexcept
- {
- return {&_write_buffers[_write_buffer_source_idx % BUFFERS_COUNT], &_write_buffers_valid[_write_buffer_source_idx % BUFFERS_COUNT]};
- }
- bool _towrite_source_empty() const noexcept { return _write_buffers_valid[_write_buffer_source_idx % BUFFERS_COUNT].empty(); }
- // Back of the queue. Can return "full"
- std::pair<byte_socket_handle::registered_buffer_type *, byte_socket_handle::const_buffer_type *> _towrite_sink() noexcept
- {
- const auto idx = _write_buffer_sink_idx % BUFFERS_COUNT;
- if(idx == (_write_buffer_source_idx % BUFFERS_COUNT) &&
- (_write_buffers_valid[idx].data() + _write_buffers_valid[idx].size()) == (_write_buffers[idx]->data() + _write_buffers[idx]->size()))
- {
- return {nullptr, nullptr}; // full
- }
- return {&_write_buffers[idx], &_write_buffers_valid[idx]};
- }
#undef LLFIO_OPENSSL_DISPATCH
#define LLFIO_OPENSSL_DISPATCH(functp, functt, ...) \
@@ -524,6 +670,11 @@ class openssl_socket_handle final : public tls_socket_handle
protected:
virtual size_t _do_max_buffers() const noexcept override { return 1; /* There is no atomicity of scatter gather i/o at all! */ }
+ virtual result<registered_buffer_type> _do_allocate_registered_buffer(size_t &bytes) noexcept override
+ {
+ LLFIO_LOG_FUNCTION_CALL(this);
+ return LLFIO_OPENSSL_DISPATCH(allocate_registered_buffer, _do_allocate_registered_buffer, (bytes));
+ }
virtual io_result<buffers_type> _do_read(io_request<buffers_type> reqs, deadline d) noexcept override
{
LLFIO_LOG_FUNCTION_CALL(this);
@@ -532,10 +683,21 @@ protected:
return errc::not_supported;
}
_lock_holder.lock();
- auto unlock = make_scope_exit([this]() noexcept { _lock_holder.unlock(); });
+ auto unlock = make_scope_exit(
+ [this]() noexcept
+ {
+ if(_lock_holder.owns_lock())
+ {
+ _lock_holder.unlock();
+ }
+ });
+ if(!(_v.behaviour & native_handle_type::disposition::_is_connected) || _still_connecting > 0)
+ {
+ return errc::not_connected;
+ }
+ LLFIO_DEADLINE_TO_SLEEP_INIT(d);
if(d)
{
- LLFIO_DEADLINE_TO_SLEEP_INIT(d);
_read_deadline_began_steady = began_steady;
_read_deadline = d;
}
@@ -546,7 +708,9 @@ protected:
for(size_t n = 0; n < reqs.buffers.size(); n++)
{
size_t read = 0;
- auto res = BIO_read_ex(_ssl_bio, reqs.buffers[n].data(), reqs.buffers[n].size(), &read);
+ // OpenSSL early outs if buf is ever null
+ byte dummy{}, *buf = reqs.buffers[n].empty() ? &dummy : reqs.buffers[n].data();
+ auto res = BIO_read_ex(_ssl_bio, buf, reqs.buffers[n].size(), &read);
if(res <= 0)
{
auto errcode = ERR_get_error();
@@ -560,11 +724,16 @@ protected:
{
return errc::operation_would_block;
}
- return openssl_error(this).as_failure();
+ return openssl_error(this, errcode).as_failure();
}
if(read < reqs.buffers[n].size())
{
reqs.buffers[n] = {reqs.buffers[n].data(), read};
+ if(n == 0 && read == 0)
+ {
+ reqs.buffers = {reqs.buffers.data(), n};
+ return std::move(reqs.buffers);
+ }
reqs.buffers = {reqs.buffers.data(), n + 1};
break;
}
@@ -579,7 +748,18 @@ protected:
return errc::not_supported;
}
_lock_holder.lock();
- auto unlock = make_scope_exit([this]() noexcept { _lock_holder.unlock(); });
+ auto unlock = make_scope_exit(
+ [this]() noexcept
+ {
+ if(_lock_holder.owns_lock())
+ {
+ _lock_holder.unlock();
+ }
+ });
+ if(!(_v.behaviour & native_handle_type::disposition::_is_connected) || _still_connecting > 0)
+ {
+ return errc::not_connected;
+ }
LLFIO_DEADLINE_TO_SLEEP_INIT(d);
if(d)
{
@@ -590,6 +770,29 @@ protected:
{
_write_deadline = {};
}
+ // OpenSSL will accept new writes forever, so we need to emulate write backpressure
+ if(_write_socket_full)
+ {
+ // Write nothing new to OpenSSL, should cause _bwrite() to get called which will check
+ // if the socket's write buffers have drained
+ size_t written = 0;
+ auto res = BIO_write_ex(_ssl_bio, nullptr, 0, &written);
+ if(res <= 0)
+ {
+ auto errcode = ERR_get_error();
+ if(BIO_should_retry(_ssl_bio))
+ {
+ return errc::operation_would_block;
+ }
+ return openssl_error(this, errcode).as_failure();
+ }
+ if(_write_socket_full)
+ {
+ // Return no buffers written
+ reqs.buffers = {reqs.buffers.data(), size_t(0)};
+ return std::move(reqs.buffers);
+ }
+ }
for(size_t n = 0; n < reqs.buffers.size(); n++)
{
size_t written = 0;
@@ -607,11 +810,16 @@ protected:
{
return errc::operation_would_block;
}
- return openssl_error(this).as_failure();
+ return openssl_error(this, errcode).as_failure();
}
if(written < reqs.buffers[n].size())
{
reqs.buffers[n] = {reqs.buffers[n].data(), written};
+ if(n == 0 && written == 0)
+ {
+ reqs.buffers = {reqs.buffers.data(), n};
+ return std::move(reqs.buffers);
+ }
reqs.buffers = {reqs.buffers.data(), n + 1};
break;
}
@@ -633,47 +841,64 @@ protected:
return errc::not_supported;
}
_lock_holder.lock();
- auto unlock = make_scope_exit([this]() noexcept { _lock_holder.unlock(); });
- LLFIO_DEADLINE_TO_SLEEP_INIT(d);
+ auto unlock = make_scope_exit(
+ [this]() noexcept
+ {
+ if(_lock_holder.owns_lock())
+ {
+ _lock_holder.unlock();
+ }
+ });
if(_ssl_bio == nullptr)
{
- OUTCOME_TRY(_init(true, _authentication_certificates_path && !_authentication_certificates_path->empty()));
+ OUTCOME_TRY(_init(true, _authentication_certificates_path));
}
- if(!(_v.behaviour & native_handle_type::disposition::_is_connected))
+ if(!(_v.behaviour & native_handle_type::disposition::_is_connected) || _still_connecting > 0)
{
+ LLFIO_DEADLINE_TO_SLEEP_INIT(d);
OUTCOME_TRY(LLFIO_OPENSSL_DISPATCH(connect, _do_connect, (addr, d)));
- if(d)
- {
- _read_deadline_began_steady = began_steady;
- _write_deadline_began_steady = began_steady;
- _read_deadline = d;
- _write_deadline = d;
- }
- else
- {
- _read_deadline = {};
- _write_deadline = {};
- }
- auto res = BIO_do_connect(_ssl_bio);
- if(res != 1)
+ for(;;)
{
- if(BIO_should_retry(_ssl_bio))
+ deadline nd;
+ LLFIO_DEADLINE_TO_PARTIAL_DEADLINE(nd, d);
+ if(nd)
{
- return errc::operation_in_progress;
+ _read_deadline_began_steady = began_steady;
+ _write_deadline_began_steady = began_steady;
+ _read_deadline = d;
+ _write_deadline = d;
}
- return openssl_error(this).as_failure();
- }
-
- res = BIO_do_handshake(_ssl_bio);
- if(res != 1)
- {
- if(BIO_should_retry(_ssl_bio))
+ else
{
- return errc::operation_in_progress;
+ _read_deadline = {};
+ _write_deadline = {};
}
- return openssl_error(this).as_failure();
+ if(_still_connecting < 1)
+ {
+ auto res = BIO_do_connect(_ssl_bio);
+ _still_connecting = 1;
+ if(res != 1)
+ {
+ if(BIO_should_retry(_ssl_bio))
+ {
+ return errc::operation_in_progress;
+ }
+ return openssl_error(this).as_failure();
+ }
+ }
+ {
+ auto res = BIO_do_handshake(_ssl_bio);
+ if(res != 1)
+ {
+ if(BIO_should_retry(_ssl_bio))
+ {
+ return errc::operation_in_progress;
+ }
+ return openssl_error(this).as_failure();
+ }
+ }
+ break;
}
-
if(!_connect_hostname_port.empty())
{
SSL *ssl{nullptr};
@@ -691,13 +916,14 @@ protected:
{
return openssl_error(this).as_failure();
}
- res = SSL_get_verify_result(ssl);
+ auto res = SSL_get_verify_result(ssl);
if(X509_V_OK != res)
{
- return openssl_error(this).as_failure();
+ return x509_error(res).as_failure();
}
}
_v.behaviour |= native_handle_type::disposition::_is_connected;
+ _still_connecting = 0;
}
return success();
}
@@ -714,9 +940,11 @@ public:
this->_v.behaviour = (sock->native_handle().behaviour & ~(native_handle_type::disposition::kernel_handle)) | native_handle_type::disposition::is_pointer;
}
- result<void> _init(bool is_client, bool verify_peer) noexcept
+ result<void> _init(bool is_client, const optional<filesystem::path> &certpath) noexcept
{
LLFIO_LOG_FUNCTION_CALL(this);
+ const bool verify_peer =
+ (is_client && (!certpath.has_value() || (certpath.has_value() && !certpath->empty()))) || (!is_client && (!certpath.has_value() || !certpath->empty()));
OUTCOME_TRY(detail::openssl_default_ctxs.init());
assert(_ssl_bio == nullptr);
_ssl_bio = BIO_new_ssl(verify_peer ? detail::openssl_default_ctxs.verified : detail::openssl_default_ctxs.unverified, is_client);
@@ -724,6 +952,23 @@ public:
{
return openssl_error(this).as_failure();
}
+ if(certpath.has_value() && !certpath->empty())
+ {
+ SSL *ssl{nullptr};
+ BIO_get_ssl(_ssl_bio, &ssl);
+ if(ssl == nullptr)
+ {
+ return openssl_error(this).as_failure();
+ }
+ if(SSL_use_certificate_file(ssl, certpath->string().c_str(), SSL_FILETYPE_PEM) <= 0)
+ {
+ return openssl_error(this).as_failure();
+ }
+ if(SSL_use_PrivateKey_file(ssl, certpath->string().c_str(), SSL_FILETYPE_PEM) <= 0)
+ {
+ return openssl_error(this).as_failure();
+ }
+ }
_self_bio = BIO_new(detail::openssl_custom_bio.method);
if(_self_bio == nullptr)
{
@@ -740,7 +985,14 @@ public:
if(kind == shutdown_write || kind == shutdown_both)
{
_lock_holder.lock();
- auto unlock = make_scope_exit([this]() noexcept { _lock_holder.unlock(); });
+ auto unlock = make_scope_exit(
+ [this]() noexcept
+ {
+ if(_lock_holder.owns_lock())
+ {
+ _lock_holder.unlock();
+ }
+ });
SSL *ssl{nullptr};
BIO_get_ssl(_ssl_bio, &ssl);
if(ssl == nullptr)
@@ -761,7 +1013,12 @@ public:
{
return errc::operation_in_progress;
}
- return openssl_error(this).as_failure();
+ if(e == SSL_ERROR_SSL)
+ {
+ // Already shut down?
+ break;
+ }
+ return openssl_error(this, e).as_failure();
}
// Shutdown is in progress
if(this->is_nonblocking())
@@ -792,27 +1049,34 @@ public:
{
LLFIO_LOG_FUNCTION_CALL(this);
_lock_holder.lock();
- auto unlock = make_scope_exit([this]() noexcept { _lock_holder.unlock(); });
+ auto unlock = make_scope_exit(
+ [this]() noexcept
+ {
+ if(_lock_holder.owns_lock())
+ {
+ _lock_holder.unlock();
+ }
+ });
OUTCOME_TRY(_flush_towrite({}));
+ if(_ssl_bio != nullptr)
+ {
+ BIO_free_all(_ssl_bio); // also frees _self_bio
+ _ssl_bio = nullptr;
+ }
if(_v.behaviour & native_handle_type::disposition::is_pointer)
{
tls_socket_handle::release();
}
else
{
+ _lock_holder.unlock();
OUTCOME_TRY(tls_socket_handle::close());
- }
- if(_ssl_bio != nullptr)
- {
- BIO_free_all(_ssl_bio); // also frees _self_bio
- _ssl_bio = nullptr;
+ _lock_holder.lock();
}
for(size_t n = 0; n < BUFFERS_COUNT; n++)
{
_read_buffers_valid[n] = {};
- _write_buffers_valid[n] = {};
_read_buffers[n].reset();
- _write_buffers[n].reset();
}
return success();
}
@@ -881,8 +1145,15 @@ public:
{
LLFIO_LOG_FUNCTION_CALL(this);
_lock_holder.lock();
- auto unlock = make_scope_exit([this]() noexcept { _lock_holder.unlock(); });
- if(!_toread_source_empty() || !_towrite_source_empty())
+ auto unlock = make_scope_exit(
+ [this]() noexcept
+ {
+ if(_lock_holder.owns_lock())
+ {
+ _lock_holder.unlock();
+ }
+ });
+ if(!_toread_source_empty())
{
return errc::device_or_resource_busy;
}
@@ -893,9 +1164,6 @@ public:
auto _bytes = bytes;
OUTCOME_TRY(_read_buffers[n], LLFIO_OPENSSL_DISPATCH(allocate_registered_buffer, allocate_registered_buffer, (_bytes)));
_read_buffers_valid[n] = {_read_buffers[n]->data(), 0};
- _bytes = bytes;
- OUTCOME_TRY(_write_buffers[n], LLFIO_OPENSSL_DISPATCH(allocate_registered_buffer, allocate_registered_buffer, (_bytes)));
- _write_buffers_valid[n] = {_write_buffers[n]->data(), 0};
}
}
return success();
@@ -934,28 +1202,28 @@ public:
_connect_hostname_port.assign(host.data(), host.size());
_connect_hostname_port.push_back(':');
_connect_hostname_port.append(std::to_string(port));
- auto res = BIO_set_conn_hostname(_ssl_bio, _connect_hostname_port.c_str());
- if(res != 1)
- {
- return openssl_error(this).as_failure();
- }
if(_ctx == nullptr)
{
- OUTCOME_TRY(_init(true, true));
+ OUTCOME_TRY(_init(true, _authentication_certificates_path));
}
+ auto res = BIO_set_conn_hostname(_ssl_bio, _connect_hostname_port.c_str());
+ /* if(res != 1)
+ {
+ return openssl_error(this).as_failure();
+ }*/
SSL *ssl{nullptr};
BIO_get_ssl(_ssl_bio, &ssl);
if(ssl == nullptr)
{
return openssl_error(this).as_failure();
}
- auto hostname = _connect_hostname_port.substr(0, _connect_hostname_port.rfind(':'));
+ std::string hostname(host);
res = SSL_set_tlsext_host_name(ssl, hostname.c_str());
if(res != 1)
{
return openssl_error(this).as_failure();
}
- return _connect_hostname_port;
+ return string_view(_connect_hostname_port).substr(host.size() + 1);
}
catch(...)
{
@@ -1025,17 +1293,15 @@ public:
}
}
_lock_holder.unlock();
+ assert(!requires_aligned_io());
+ assert(_v.is_valid());
auto r = LLFIO_OPENSSL_DISPATCH(read, _do_read, (*s.first, {{&b, 1}, 0}, nd));
_lock_holder.lock();
if(!r)
{
- if(r.error() == errc::operation_would_block || r.error() == errc::resource_unavailable_try_again || r.error() == errc::timed_out)
+ // Return an error if we never read any data, otherwise sink the error
+ if(*read > 0)
{
- if(*read == 0)
- {
- BIO_set_retry_read(bio);
- return 0;
- }
return 1;
}
_read_error = std::move(r).as_failure();
@@ -1058,7 +1324,7 @@ public:
#if LLFIO_OPENSSL_ENABLE_DEBUG_PRINTING
std::lock_guard<std::mutex> g(detail::openssl_printing_lock);
auto s = _toread_source();
- std::cout << "_bread(" << (void *) buffer << ", " << bytes << ") returns " << ret << " with " << *read << " bytes read and " << s.second->size()
+ std::cout << this << " _bread(" << (void *) buffer << ", " << bytes << ") returns " << ret << " with " << *read << " bytes read and " << s.second->size()
<< " remaining in source buffer." << std::endl;
#endif
return ret;
@@ -1073,135 +1339,33 @@ public:
assert(_lock_holder.owns_lock());
*written = 0;
BIO_clear_retry_flags(bio);
-#if 0
- // Write any existing buffers first
- if(!_towrite_source_empty())
- {
- auto s = _towrite_source();
- auto b = *s.second;
- auto &began_steady = _write_deadline_began_steady;
- deadline nd;
- if(this->is_nonblocking())
- {
- LLFIO_DEADLINE_TO_PARTIAL_DEADLINE(nd, _write_deadline);
- }
- _lock_holder.unlock();
- auto r = LLFIO_OPENSSL_DISPATCH(write, _do_write, (*s.first, {{&b, 1}, 0}, nd));
- _lock_holder.lock();
- if(!r)
- {
- if(r.error() == errc::operation_would_block || r.error() == errc::resource_unavailable_try_again || r.error() == errc::timed_out)
- {
- if(*written == 0)
- {
- BIO_set_retry_write(bio);
- return 0;
- }
- return 1;
- }
- _write_error = std::move(r).as_failure();
- LLFIO_OPENSSL_SET_RESULT_ERROR(2);
- return 0;
- }
- *s.second = {s.second->data() + b.size(), s.second->size() - b.size()};
- if(s.second->empty())
- {
- *s.second = {(*s.first)->data(), 0};
- _write_buffer_source_idx++;
- }
- else
- {
- break;
- }
- auto r2 = [&]() -> result<void>
- {
- LLFIO_DEADLINE_TO_TIMEOUT_LOOP(_write_deadline);
- return success();
- }();
- if(!r2)
- {
- if(*written == 0)
- {
- BIO_set_retry_write(bio);
- return 0;
- }
- return 1;
- }
- }
-#endif
- // Are existing buffers now empty and we can write this directly?
- if(_towrite_source_empty())
+ auto &began_steady = _write_deadline_began_steady;
+ deadline nd;
+ if(this->is_nonblocking())
{
- auto &began_steady = _write_deadline_began_steady;
- deadline nd;
- if(this->is_nonblocking())
- {
- LLFIO_DEADLINE_TO_PARTIAL_DEADLINE(nd, _write_deadline);
- }
- _lock_holder.unlock();
- const_buffer_type b((const byte *) buffer, bytes);
- auto r = LLFIO_OPENSSL_DISPATCH(write, _do_write, ({{&b, 1}, 0}, nd));
- _lock_holder.lock();
- if(!r)
- {
- if(r.error() == errc::operation_would_block || r.error() == errc::resource_unavailable_try_again || r.error() == errc::timed_out)
- {
- if(*written == 0)
- {
- BIO_set_retry_write(bio);
- return 0;
- }
- return 1;
- }
- _write_error = std::move(r).as_failure();
- LLFIO_OPENSSL_SET_RESULT_ERROR(2);
- return 0;
- }
- buffer += b.size();
- bytes -= b.size();
- *written += b.size();
- if(0 == bytes)
- {
- return 1;
- }
+ LLFIO_DEADLINE_TO_PARTIAL_DEADLINE(nd, _write_deadline);
}
-#if 0
- // Append into write buffers
- while(bytes > 0)
+ _lock_holder.unlock();
+ assert(!requires_aligned_io());
+ assert(_v.is_valid());
+ const_buffer_type b((const byte *) buffer, bytes);
+ auto r = LLFIO_OPENSSL_DISPATCH(write, _do_write, ({{&b, 1}, 0}, nd));
+ _lock_holder.lock();
+ if(!r)
{
- auto s = _towrite_sink();
- // Are we full?
- if(s.second == nullptr)
- {
- if(*written == 0)
- {
- BIO_set_retry_write(bio);
- return 0;
- }
- return 1;
- }
- auto remaining = (size_t) (((*s.first)->data() + (*s.first)->size()) - (s.second->data() + s.second->size()));
- auto tocopy = std::min(remaining, bytes);
- memcpy((byte *) s.second->data() + s.second->size(), buffer, tocopy);
- buffer += tocopy;
- bytes -= tocopy;
- *written += tocopy;
- *s.second = {s.second->data(), s.second->size() + tocopy};
- if(remaining == tocopy)
- {
- _write_buffer_sink_idx++;
- }
- if(0 == bytes)
- {
- return 1;
- }
+ _write_error = std::move(r).as_failure();
+ LLFIO_OPENSSL_SET_RESULT_ERROR(2);
+ return 0;
}
-#endif
+ _write_socket_full = (b.size() == 0);
+ buffer += b.size();
+ bytes -= b.size();
+ *written += b.size();
return 1;
}();
#if LLFIO_OPENSSL_ENABLE_DEBUG_PRINTING
std::lock_guard<std::mutex> g(detail::openssl_printing_lock);
- std::cout << "_bwrite(" << (void *) buffer << ", " << bytes << ") returns " << ret << " with " << *written << " bytes written." << std::endl;
+ std::cout << this << "_bwrite(" << (void *) buffer << ", " << bytes << ") returns " << ret << " with " << *written << " bytes written." << std::endl;
#endif
return ret;
}
@@ -1214,52 +1378,11 @@ public:
assert(_lock_holder.owns_lock());
if(_ssl_bio != nullptr)
{
- auto *m = LLFIO_OPENSSL_DISPATCH(multiplexer, multiplexer, ());
- do
+ auto res = BIO_flush(_ssl_bio);
+ if(res <= 0 && !BIO_should_retry(_ssl_bio))
{
- auto res = BIO_flush(_ssl_bio);
- if(res <= 0 && !BIO_should_retry(_ssl_bio))
- {
- return openssl_error(this).as_failure();
- }
- if(!_towrite_source_empty())
- {
- if(m != nullptr)
- {
- deadline nd;
- LLFIO_DEADLINE_TO_PARTIAL_DEADLINE(nd, d);
- OUTCOME_TRY(m->check_for_any_completed_io(nd));
- }
- if(this->is_nonblocking())
- {
- _write_deadline = std::chrono::seconds(0);
- }
- else
- {
- _write_deadline = {};
- }
- size_t written = 0;
- res = _bwrite(_self_bio, nullptr, 0, &written);
- if(res <= 0 && !BIO_should_retry(_ssl_bio))
- {
- return openssl_error(this).as_failure();
- }
- if(this->is_nonblocking())
- {
- _read_deadline = std::chrono::seconds(0);
- }
- else
- {
- _read_deadline = {};
- }
- size_t read = 0;
- res = _bread(_self_bio, nullptr, 0, &read);
- if(res <= 0 && !BIO_should_retry(_ssl_bio))
- {
- return openssl_error(this).as_failure();
- }
- }
- } while(!_towrite_source_empty());
+ return openssl_error(this).as_failure();
+ }
}
return success();
}
@@ -1288,12 +1411,12 @@ class listening_openssl_socket_handle final : public listening_tls_socket_handle
#undef LLFIO_OPENSSL_DISPATCH
/* *this has a vptr whose functions point into this class, so what
- we need is to bind the function implementation listening_socket_handle
+ we need is to bind the function implementation listening_byte_socket_handle
using its vptr and call it.
*/
#define LLFIO_OPENSSL_DISPATCH(functp, functt, ...) \
- ((_v.behaviour & native_handle_type::disposition::is_pointer) ? (reinterpret_cast<listening_socket_handle *>(_v.ptr)->functp) __VA_ARGS__ : \
- (socket_cast<listening_socket_handle>(this)->listening_socket_handle::functt) __VA_ARGS__)
+ ((_v.behaviour & native_handle_type::disposition::is_pointer) ? (reinterpret_cast<listening_byte_socket_handle *>(_v.ptr)->functp) __VA_ARGS__ : \
+ (listening_byte_socket_handle::functt) __VA_ARGS__)
protected:
virtual result<buffers_type> _do_read(io_request<buffers_type> req, deadline d) noexcept override
@@ -1303,8 +1426,8 @@ protected:
{
return errc::not_supported;
}
- listening_socket_handle::buffer_type b;
- OUTCOME_TRY(auto &&read, _underlying_read<listening_socket_handle>({b}, d));
+ listening_byte_socket_handle::buffer_type b;
+ OUTCOME_TRY(auto &&read, _underlying_read<listening_byte_socket_handle>({b}, d));
auto *p = new(std::nothrow) openssl_socket_handle(std::move(read.connected_socket().first));
if(p == nullptr)
{
@@ -1312,7 +1435,7 @@ protected:
}
req.buffers.connected_socket() = {tls_socket_handle_ptr(p), read.connected_socket().second};
OUTCOME_TRY(p->set_registered_buffer_chunk_size(_registered_buffer_chunk_size));
- OUTCOME_TRY(p->_init(false, !_authentication_certificates_path || !_authentication_certificates_path->empty()));
+ OUTCOME_TRY(p->_init(false, _authentication_certificates_path));
return {std::move(req.buffers)};
}
@@ -1323,8 +1446,8 @@ protected:
{
return errc::not_supported;
}
- listening_socket_handle::buffer_type b;
- OUTCOME_TRY(auto &&read, _underlying_read<listening_socket_handle>({b}, d));
+ listening_byte_socket_handle::buffer_type b;
+ OUTCOME_TRY(auto &&read, _underlying_read<listening_byte_socket_handle>({b}, d));
auto *p = new(std::nothrow) openssl_socket_handle(std::move(read.connected_socket().first));
if(p == nullptr)
{
@@ -1332,16 +1455,16 @@ protected:
}
req.buffers.connected_socket() = {tls_socket_handle_ptr(p), read.connected_socket().second};
OUTCOME_TRY(p->set_registered_buffer_chunk_size(_registered_buffer_chunk_size));
- OUTCOME_TRY(p->_init(false, !_authentication_certificates_path || !_authentication_certificates_path->empty()));
+ OUTCOME_TRY(p->_init(false, _authentication_certificates_path));
return {std::move(req.buffers)};
}
public:
- explicit listening_openssl_socket_handle(listening_socket_handle &&sock)
+ explicit listening_openssl_socket_handle(listening_byte_socket_handle &&sock)
: listening_tls_socket_handle(std::move(sock))
{
}
- explicit listening_openssl_socket_handle(listening_socket_handle *sock)
+ explicit listening_openssl_socket_handle(listening_byte_socket_handle *sock)
: listening_tls_socket_handle(handle(), nullptr)
{
this->_v.ptr = sock;
@@ -1482,7 +1605,7 @@ static struct openssl_socket_source_registration_t
virtual result<listening_tls_socket_handle_ptr> listening_socket(ip::family family, byte_socket_handle::mode _mode, byte_socket_handle::caching _caching,
byte_socket_handle::flag flags) noexcept override
{
- OUTCOME_TRY(auto &&sock, listening_socket_handle::listening_socket(family, _mode, _caching, flags));
+ OUTCOME_TRY(auto &&sock, listening_byte_socket_handle::listening_byte_socket(family, _mode, _caching, flags));
auto *p = new(std::nothrow) listening_openssl_socket_handle(std::move(sock));
if(p == nullptr)
{
@@ -1504,7 +1627,7 @@ static struct openssl_socket_source_registration_t
return {std::move(ret)};
}
- virtual result<listening_tls_socket_handle_ptr> wrap(listening_socket_handle *listening) noexcept override
+ virtual result<listening_tls_socket_handle_ptr> wrap(listening_byte_socket_handle *listening) noexcept override
{
auto *p = new(std::nothrow) listening_openssl_socket_handle(listening);
if(p == nullptr)
diff --git a/include/llfio/v2.0/detail/impl/windows/byte_io_handle.ipp b/include/llfio/v2.0/detail/impl/windows/byte_io_handle.ipp
index 888249a7..d4307c6e 100644
--- a/include/llfio/v2.0/detail/impl/windows/byte_io_handle.ipp
+++ b/include/llfio/v2.0/detail/impl/windows/byte_io_handle.ipp
@@ -91,7 +91,9 @@ inline bool do_read_write(byte_io_handle::io_result<BuffersType> &ret, Syscall &
EIOSB &ol = *ol_it++;
ol.Status = -1;
}
- auto cancel_io = make_scope_exit([&]() noexcept {
+ auto cancel_io = make_scope_exit(
+ [&]() noexcept
+ {
if(nativeh.is_nonblocking())
{
if(ol_it != ols.begin() + 1)
@@ -147,7 +149,8 @@ inline bool do_read_write(byte_io_handle::io_result<BuffersType> &ret, Syscall &
if(STATUS_TIMEOUT == ntwait(nativeh.h, ol, nd))
{
// ntwait cancels the i/o, undoer will cancel all the other i/o
- auto r = [&]() -> result<void> {
+ auto r = [&]() -> result<void>
+ {
LLFIO_WIN_DEADLINE_TO_TIMEOUT_LOOP(d);
return success();
}();
@@ -233,7 +236,8 @@ retry:
if(STATUS_TIMEOUT == ntwait(nativeh.h, ol, nd))
{
// ntwait cancels the i/o, undoer will cancel all the other i/o
- auto r = [&]() -> result<void> {
+ auto r = [&]() -> result<void>
+ {
LLFIO_DEADLINE_TO_TIMEOUT_LOOP(d);
return success();
}();
@@ -285,7 +289,8 @@ retry:
ret = win32_error(WSAGetLastError());
return true;
}
- auto r = [&]() -> result<void> {
+ auto r = [&]() -> result<void>
+ {
LLFIO_DEADLINE_TO_TIMEOUT_LOOP(d);
return success();
}();
@@ -351,7 +356,8 @@ result<size_t> poll(span<poll_what> out, span<pollable_handle *> handles, span<c
{
if(handles[n] != nullptr)
{
- auto &h = handles[n]->_get_handle();
+ auto &h_ = handles[n]->_get_handle();
+ auto &h = h_.native_handle().is_third_party_pointer() ? *(handle *) h_.native_handle().ptr : h_;
if(h.is_kernel_handle())
{
fds[fdscount].fd = h.native_handle().sock;
@@ -416,7 +422,8 @@ result<size_t> poll(span<poll_what> out, span<pollable_handle *> handles, span<c
{
if(handles[n] != nullptr)
{
- auto &h = handles[n]->_get_handle();
+ auto &h_ = handles[n]->_get_handle();
+ auto &h = h_.native_handle().is_third_party_pointer() ? *(handle *) h_.native_handle().ptr : h_;
if(h.is_kernel_handle())
{
if(fds[fdscount].revents != 0)
diff --git a/include/llfio/v2.0/detail/impl/windows/byte_socket_handle.ipp b/include/llfio/v2.0/detail/impl/windows/byte_socket_handle.ipp
index e0eb3f7a..ca73bdad 100644
--- a/include/llfio/v2.0/detail/impl/windows/byte_socket_handle.ipp
+++ b/include/llfio/v2.0/detail/impl/windows/byte_socket_handle.ipp
@@ -88,6 +88,7 @@ namespace ip
OVERLAPPED ol;
::ADDRINFOEXW *res{nullptr};
HANDLE ophandle{nullptr};
+ bool done{false};
resolver_impl() { clear(); }
@@ -107,12 +108,13 @@ namespace ip
res = nullptr;
}
ophandle = nullptr;
+ done = false;
}
// Returns 0 for not ready yet, -1 for already processed, +1 for just processed
int check(DWORD millis)
{
- if(0 != WaitForSingleObject(ol.hEvent, millis))
+ if(!done && 0 != WaitForSingleObject(ol.hEvent, millis))
{
return 0;
}
@@ -120,6 +122,7 @@ namespace ip
{
return -1;
}
+ done = true;
auto unaddrinfo = make_scope_exit(
[&]() noexcept
{
@@ -188,7 +191,7 @@ namespace ip
void resolver_deleter::operator()(resolver *_p) const
{
auto *p = static_cast<resolver_impl *>(_p);
- if(0 != WaitForSingleObject(p->ol.hEvent, 0))
+ if(!p->done && 0 != WaitForSingleObject(p->ol.hEvent, 0))
{
GetAddrInfoExCancel(&p->ophandle);
WaitForSingleObject(p->ol.hEvent, INFINITE);
@@ -220,7 +223,7 @@ namespace ip
bool resolver::incomplete() const noexcept
{
auto *self = static_cast<const detail::resolver_impl *>(this);
- return 0 != WaitForSingleObject(self->ol.hEvent, 0);
+ return !self->done && 0 != WaitForSingleObject(self->ol.hEvent, 0);
}
result<span<address>> resolver::get() noexcept
{
@@ -326,9 +329,15 @@ namespace ip
_timeout.tv_usec = (long) (diff % 1000000);
}
timeout = &_timeout;
+ // Can't combine blocking and timeouts
+ flags &= ~resolve_flag::blocking;
}
p->hints.ai_socktype = SOCK_STREAM;
- p->hints.ai_flags = AI_ADDRCONFIG | AI_V4MAPPED;
+ p->hints.ai_flags = AI_ADDRCONFIG;
+ if(_family != family::v4)
+ {
+ p->hints.ai_flags |= AI_V4MAPPED;
+ }
if(flags & resolve_flag::passive)
{
p->hints.ai_flags |= AI_PASSIVE;
@@ -343,24 +352,29 @@ namespace ip
return ntkernel_error(ntstat);
}
std::wstring ret;
- ret.resize(written);
+ ret.resize(written / sizeof(wchar_t));
written = 0;
// Do the conversion UTF-8 to UTF-16
- ntstat = RtlUTF8ToUnicodeN(const_cast<wchar_t *>(ret.data()), static_cast<ULONG>(ret.size()), &written, str.c_str(), static_cast<ULONG>(str.size()));
+ ntstat = RtlUTF8ToUnicodeN(const_cast<wchar_t *>(ret.data()), static_cast<ULONG>(ret.size() * sizeof(wchar_t)), &written, str.c_str(),
+ static_cast<ULONG>(str.size()));
if(ntstat < 0)
{
return ntkernel_error(ntstat);
}
- ret.resize(written);
+ ret.resize(written / sizeof(wchar_t));
return ret;
};
OUTCOME_TRY(auto &&_name, to_wstring(p->name));
OUTCOME_TRY(auto &&_service, to_wstring(p->service));
auto errcode = GetAddrInfoExW(_name.c_str(), _service.c_str(), NS_ALL, nullptr, &p->hints, &p->res, timeout,
(flags & resolve_flag::blocking) ? nullptr : &p->ol, nullptr, (flags & resolve_flag::blocking) ? nullptr : &p->ophandle);
- if(NO_ERROR != errcode && WSA_IO_PENDING != errcode)
+ if(NO_ERROR == errcode)
+ {
+ p->done = true;
+ }
+ else if(WSA_IO_PENDING != errcode)
{
- SetEvent(p->ol.hEvent);
+ p->done = true;
return win32_error(errcode);
}
p->check(0);
@@ -601,7 +615,7 @@ LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<void> byte_socket_handle::close() noexcep
/*******************************************************************************************************************/
-LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<ip::address> listening_socket_handle::local_endpoint() const noexcept
+LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<ip::address> listening_byte_socket_handle::local_endpoint() const noexcept
{
LLFIO_LOG_FUNCTION_CALL(this);
ip::address ret;
@@ -613,7 +627,7 @@ LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<ip::address> listening_socket_handle::loc
return ret;
}
-LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<void> listening_socket_handle::bind(const ip::address &addr, creation _creation, int backlog) noexcept
+LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<void> listening_byte_socket_handle::bind(const ip::address &addr, creation _creation, int backlog) noexcept
{
LLFIO_LOG_FUNCTION_CALL(this);
if(_creation != creation::only_if_not_exist)
@@ -635,16 +649,16 @@ LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<void> listening_socket_handle::bind(const
return success();
}
-LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<listening_socket_handle> listening_socket_handle::listening_socket(ip::family family, mode _mode, caching _caching,
+LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<listening_byte_socket_handle> listening_byte_socket_handle::listening_byte_socket(ip::family family, mode _mode, caching _caching,
flag flags) noexcept
{
- result<listening_socket_handle> ret(listening_socket_handle(native_handle_type(), flags, nullptr));
+ result<listening_byte_socket_handle> ret(listening_byte_socket_handle(native_handle_type(), flags, nullptr));
native_handle_type &nativeh = ret.value()._v;
OUTCOME_TRY(detail::create_socket(&ret.value(), nativeh, family, _mode, _caching, flags));
return ret;
}
-LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<listening_socket_handle::buffers_type> listening_socket_handle::_do_read(io_request<buffers_type> req,
+LLFIO_HEADERS_ONLY_MEMFUNC_SPEC result<listening_byte_socket_handle::buffers_type> listening_byte_socket_handle::_do_read(io_request<buffers_type> req,
deadline d) noexcept
{
LLFIO_LOG_FUNCTION_CALL(this);
diff --git a/include/llfio/v2.0/detail/impl/windows/test/iocp_multiplexer.ipp b/include/llfio/v2.0/detail/impl/windows/test/iocp_multiplexer.ipp
index ae8f5c94..ea05dd89 100644
--- a/include/llfio/v2.0/detail/impl/windows/test/iocp_multiplexer.ipp
+++ b/include/llfio/v2.0/detail/impl/windows/test/iocp_multiplexer.ipp
@@ -107,7 +107,7 @@ namespace test
ret.multiplexes.kernel.file_handle = true;
ret.multiplexes.kernel.pipe_handle = true;
ret.multiplexes.kernel.byte_socket_handle = true;
- ret.multiplexes.kernel.listening_socket_handle = false;
+ ret.multiplexes.kernel.listening_byte_socket_handle = false;
return ret;
}();
return v;
diff --git a/include/llfio/v2.0/tls_socket_handle.hpp b/include/llfio/v2.0/tls_socket_handle.hpp
index c5121e2e..474b9aec 100644
--- a/include/llfio/v2.0/tls_socket_handle.hpp
+++ b/include/llfio/v2.0/tls_socket_handle.hpp
@@ -147,7 +147,7 @@ public:
{
deadline nd;
LLFIO_DEADLINE_TO_PARTIAL_DEADLINE(nd, d);
- lasterror = byte_socket_handle::connect(address, nd);
+ lasterror = this->connect(address, nd);
if(lasterror)
{
return lasterror;
@@ -196,19 +196,30 @@ using tls_socket_handle_ptr = std::unique_ptr<tls_socket_handle, detail::tls_soc
As you cannot create one of these on your own, one generally acquires one of these
from a `tls_socket_source`.
*/
-class LLFIO_DECL listening_tls_socket_handle : public listening_socket_handle_impl<tls_socket_handle_ptr>
+class LLFIO_DECL listening_tls_socket_handle : public listening_socket_handle_buffer_types_injector<listening_byte_socket_handle, tls_socket_handle_ptr>
{
- using _base = listening_socket_handle_impl<tls_socket_handle_ptr>;
+ using _base = listening_socket_handle_buffer_types_injector<listening_byte_socket_handle, tls_socket_handle_ptr>;
protected:
constexpr listening_tls_socket_handle() {}
- explicit listening_tls_socket_handle(listening_socket_handle &&sock)
+ explicit listening_tls_socket_handle(listening_byte_socket_handle &&sock)
: _base(sock.release(), sock.flags(), sock.multiplexer())
{
this->_v.behaviour |= native_handle_type::disposition::tls_socket;
}
using _base::_base;
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Woverloaded-virtual"
+#endif
+ LLFIO_HEADERS_ONLY_VIRTUAL_SPEC result<buffers_type> _do_read(io_request<buffers_type> req, deadline d) noexcept = 0;
+
+ LLFIO_HEADERS_ONLY_VIRTUAL_SPEC io_result<buffers_type> _do_multiplexer_read(io_request<buffers_type> reqs, deadline d) noexcept = 0;
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif
+
public:
/*! \brief Returns an implementation defined string describing the algorithms
to be chosen during connection. Can be an empty string if the implementation
@@ -240,6 +251,41 @@ public:
defined identifier.
*/
virtual result<void> set_authentication_certificates_path(path_view identifier) noexcept = 0;
+
+ /*! Read the contents of the listening socket for newly connected byte sockets.
+
+ \return Returns the buffers filled, with its socket handle and address set to the newly connected socket.
+ \param req A buffer to fill with a newly connected socket.
+ \param d An optional deadline by which to time out.
+
+ \errors Any of the errors which `accept()` or `WSAAccept()` might return.
+ */
+ LLFIO_MAKE_FREE_FUNCTION
+ result<buffers_type> read(io_request<buffers_type> req, deadline d = {}) noexcept
+ {
+ return (_ctx == nullptr) ? _do_read(std::move(req), d) : _do_multiplexer_read(std::move(req), d);
+ }
+
+ /*! \brief A coroutinised equivalent to `.read()` which suspends the coroutine until
+ a new incoming connection occurs. **Blocks execution** i.e is equivalent to `.read()` if no i/o multiplexer
+ has been set on this handle!
+
+ The awaitable returned is **eager** i.e. it immediately begins the i/o. If the i/o completes
+ and finishes immediately, no coroutine suspension occurs.
+ */
+ LLFIO_MAKE_FREE_FUNCTION
+ awaitable<io_result<buffers_type>> co_read(io_request<buffers_type> reqs, deadline d = {}) noexcept;
+#if 0 // TODO
+ {
+ if(_ctx == nullptr)
+ {
+ return awaitable<io_result<buffers_type>>(read(std::move(reqs), d));
+ }
+ awaitable<io_result<buffers_type>> ret;
+ ret.set_state(_ctx->construct(ret._state_storage, this, nullptr, d, reqs.buffers.connected_socket()));
+ return ret;
+ }
+#endif
};
namespace detail
@@ -348,7 +394,7 @@ public:
//! Returns a pointer to a new `listening_tls_socket_handle` instance, which will wrap `listening`. `listening` must NOT change address until the
//! `listening_tls_socket_handle` is closed.
- virtual result<listening_tls_socket_handle_ptr> wrap(listening_socket_handle *listening) noexcept = 0;
+ virtual result<listening_tls_socket_handle_ptr> wrap(listening_byte_socket_handle *listening) noexcept = 0;
};
namespace detail
diff --git a/test/tests/byte_socket_handle.cpp b/test/tests/byte_socket_handle.cpp
index 0ba8b2a2..785398fb 100644
--- a/test/tests/byte_socket_handle.cpp
+++ b/test/tests/byte_socket_handle.cpp
@@ -375,7 +375,7 @@ static inline void TestSocketResolve()
static inline void TestBlockingSocketHandles()
{
namespace llfio = LLFIO_V2_NAMESPACE;
- auto serversocket = llfio::listening_socket_handle::listening_socket(llfio::ip::family::v4, llfio::listening_socket_handle::mode::read).value();
+ auto serversocket = llfio::listening_byte_socket_handle::listening_byte_socket(llfio::ip::family::v4, llfio::listening_byte_socket_handle::mode::read).value();
BOOST_REQUIRE(serversocket.is_valid());
BOOST_CHECK(serversocket.is_socket());
BOOST_CHECK(serversocket.is_readable());
@@ -438,7 +438,7 @@ static inline void TestBlockingSocketHandles()
static inline void TestNonBlockingSocketHandles()
{
namespace llfio = LLFIO_V2_NAMESPACE;
- auto serversocket = llfio::listening_socket_handle::listening_socket(llfio::ip::family::v4, llfio::listening_socket_handle::mode::read,
+ auto serversocket = llfio::listening_byte_socket_handle::listening_byte_socket(llfio::ip::family::v4, llfio::listening_byte_socket_handle::mode::read,
llfio::byte_socket_handle::caching::all, llfio::byte_socket_handle::flag::multiplexable)
.value();
BOOST_REQUIRE(serversocket.is_valid());
@@ -602,7 +602,7 @@ static inline void TestMultiplexedSocketHandles()
}
};
auto serversocket =
- llfio::listening_socket_handle::listening_socket(llfio::ip::family::v4, llfio::listening_socket_handle::mode::write,
+ llfio::listening_byte_socket_handle::listening_byte_socket(llfio::ip::family::v4, llfio::listening_byte_socket_handle::mode::write,
llfio::byte_socket_handle::caching::all, llfio::byte_socket_handle::flag::multiplexable)
.value();
serversocket.bind(llfio::ip::address_v4::loopback()).value();
@@ -727,7 +727,7 @@ static inline void TestCoroutinedSocketHandles()
}
};
auto serversocket =
- llfio::listening_socket_handle::listening_socket(llfio::ip::family::v4, llfio::listening_socket_handle::mode::write, llfio::byte_socket_handle::caching::all,
+ llfio::listening_byte_socket_handle::listening_socket(llfio::ip::family::v4, llfio::listening_byte_socket_handle::mode::write, llfio::byte_socket_handle::caching::all,
llfio::byte_socket_handle::flag::multiplexable)
.value();
serversocket.bind(llfio::ip::address_v4::loopback()).value();
@@ -824,12 +824,12 @@ static inline void TestPollingSocketHandles()
{
static constexpr size_t MAX_SOCKETS = 64;
namespace llfio = LLFIO_V2_NAMESPACE;
- std::vector<std::pair<llfio::listening_socket_handle, llfio::ip::address>> listening;
+ std::vector<std::pair<llfio::listening_byte_socket_handle, llfio::ip::address>> listening;
std::vector<llfio::byte_socket_handle> sockets;
std::vector<size_t> idxs;
for(size_t n = 0; n < MAX_SOCKETS; n++)
{
- auto s = llfio::listening_socket_handle::listening_socket(llfio::ip::family::v4).value();
+ auto s = llfio::listening_byte_socket_handle::listening_byte_socket(llfio::ip::family::v4).value();
s.bind(llfio::ip::address_v4::loopback()).value();
auto endpoint = s.local_endpoint().value();
if(endpoint.family() == llfio::ip::family::unknown && getenv("CI") != nullptr)
diff --git a/test/tests/tls_socket_handle.cpp b/test/tests/tls_socket_handle.cpp
index ee78320c..cade8a91 100644
--- a/test/tests/tls_socket_handle.cpp
+++ b/test/tests/tls_socket_handle.cpp
@@ -42,8 +42,9 @@ static inline void TestBlockingTLSSocketHandles()
BOOST_CHECK(serversocket->is_socket());
BOOST_CHECK(serversocket->is_readable());
BOOST_CHECK(serversocket->is_writable());
- // Disable server authentication
+ // Disable authentication
serversocket->set_authentication_certificates_path({}).value();
+ writer->set_authentication_certificates_path({}).value();
{
auto desc = serversocket->algorithms_description();
std::cout << "Server socket will offer during handshake the ciphers (" << (1 + std::count(desc.begin(), desc.end(), ',')) << "): ";
@@ -89,7 +90,7 @@ static inline void TestBlockingTLSSocketHandles()
g.unlock();
serversocket->close().value();
llfio::byte buffer[64];
- auto read = s.first->read(0, {{buffer, 64}}).value();
+ auto read = s.first->read({{buffer, 64}}).value();
g.lock();
std::cout << "\nThe inbound server socket negotiated the cipher " << s.first->algorithms_description() << std::endl;
BOOST_REQUIRE(read == 5);
@@ -123,7 +124,7 @@ static inline void TestBlockingTLSSocketHandles()
BOOST_CHECK(writer->is_writable());
std::cout << "\nThe connecting socket negotiated the cipher " << writer->algorithms_description() << std::endl;
g.unlock();
- auto written = writer->write(0, {{(const llfio::byte *) "hello", 5}}).value();
+ auto written = writer->write({{(const llfio::byte *) "hello", 5}}).value();
BOOST_REQUIRE(written == 5);
writer->shutdown_and_close().value();
readerthread.get();
@@ -132,79 +133,190 @@ static inline void TestBlockingTLSSocketHandles()
runtest(tls_socket_source->listening_socket(llfio::ip::family::v4).value(), tls_socket_source->connecting_socket(llfio::ip::family::v4).value());
std::cout << "\nWrapped TLS socket:\n" << std::endl;
- auto rawserversocket = llfio::listening_socket_handle::listening_socket(llfio::ip::family::v4).value();
+ auto rawserversocket = llfio::listening_byte_socket_handle::listening_byte_socket(llfio::ip::family::v4).value();
auto rawwriter = llfio::byte_socket_handle::byte_socket(llfio::ip::family::v4).value();
runtest(tls_socket_source->wrap(&rawserversocket).value(), tls_socket_source->wrap(&rawwriter).value());
}
-#if 0
static inline void TestNonBlockingTLSSocketHandles()
{
namespace llfio = LLFIO_V2_NAMESPACE;
- auto serversocket = llfio::listening_socket_handle::listening_socket(llfio::ip::family::v4, llfio::listening_socket_handle::mode::read,
- llfio::byte_socket_handle::caching::all, llfio::byte_socket_handle::flag::multiplexable)
- .value();
- BOOST_REQUIRE(serversocket.is_valid());
- BOOST_CHECK(serversocket.is_socket());
- BOOST_CHECK(serversocket.is_readable());
- BOOST_CHECK(!serversocket.is_writable());
- serversocket.bind(llfio::ip::address_v4::loopback()).value();
- auto endpoint = serversocket.local_endpoint().value();
- std::cout << "Server socket is listening on " << endpoint << std::endl;
- if(endpoint.family() == llfio::ip::family::unknown && getenv("CI") != nullptr)
+ if(llfio::tls_socket_source_registry::empty())
{
- std::cout << "\nNOTE: Currently on CI and couldn't bind a listening socket to loopback, assuming it is CI host restrictions and skipping this test."
- << std::endl;
+ std::cout << "\nNOTE: This platform has no TLS socket sources in its registry, skipping this test." << std::endl;
return;
}
+ auto tls_socket_source = llfio::tls_socket_source_registry::default_source().instantiate().value();
+ auto runtest = [](llfio::listening_tls_socket_handle_ptr serversocket, auto &&make_writer)
+ {
+ BOOST_REQUIRE(serversocket->is_valid());
+ BOOST_CHECK(serversocket->is_socket());
+ BOOST_CHECK(serversocket->is_readable());
+ BOOST_CHECK(serversocket->is_writable());
+ // Disable authentication
+ serversocket->set_authentication_certificates_path({}).value();
+ serversocket->bind(llfio::ip::address_v4::loopback()).value();
+ auto endpoint = serversocket->local_endpoint().value();
+ std::cout << "Server socket is listening on " << endpoint << std::endl;
+ if(endpoint.family() == llfio::ip::family::unknown && getenv("CI") != nullptr)
+ {
+ std::cout << "\nNOTE: Currently on CI and couldn't bind a listening socket to loopback, assuming it is CI host restrictions and skipping this test."
+ << std::endl;
+ return;
+ }
- std::pair<llfio::byte_socket_handle, llfio::ip::address> reader;
- { // no incoming, so non-blocking read should time out
- auto read = serversocket.read({reader}, std::chrono::milliseconds(0));
- BOOST_REQUIRE(read.has_error());
- BOOST_REQUIRE(read.error() == llfio::errc::timed_out);
- }
- { // no incoming, so blocking read should time out
- auto read = serversocket.read({reader}, std::chrono::seconds(1));
- BOOST_REQUIRE(read.has_error());
- BOOST_REQUIRE(read.error() == llfio::errc::timed_out);
- }
+ std::pair<llfio::tls_socket_handle_ptr, llfio::ip::address> reader;
+ { // no incoming, so non-blocking read should time out
+ auto read = serversocket->read({reader}, std::chrono::milliseconds(0));
+ BOOST_REQUIRE(read.has_error());
+ BOOST_REQUIRE(read.error() == llfio::errc::timed_out);
+ }
+ { // no incoming, so blocking read should time out
+ auto read = serversocket->read({reader}, std::chrono::seconds(1));
+ BOOST_REQUIRE(read.has_error());
+ BOOST_REQUIRE(read.error() == llfio::errc::timed_out);
+ }
- // Form the connection.
- auto writer = llfio::byte_socket_handle::byte_socket(llfio::ip::family::v4, llfio::byte_socket_handle::mode::append,
- llfio::byte_socket_handle::caching::reads, llfio::byte_socket_handle::flag::multiplexable)
- .value();
- writer.connect(endpoint).value();
- serversocket.read({reader}, std::chrono::seconds(1)).value();
- std::cout << "Server socket sees incoming connection from " << reader.second << std::endl;
+ // Form the connection.
+ llfio::tls_socket_handle_ptr writer = make_writer();
+ // Disable authentication
+ writer->set_authentication_certificates_path({}).value();
+ {
+ auto r = writer->connect(endpoint, std::chrono::milliseconds(0));
+ BOOST_REQUIRE(!r);
+ BOOST_REQUIRE(r.error() == llfio::errc::timed_out); // can't possibly connect immediately
+ }
+ serversocket->read({reader}, std::chrono::seconds(1)).value();
+ std::cout << "Server socket sees incoming connection from " << reader.second << std::endl;
+ // The TLS negotiation needs to do various reading and writing until the connection goes up
+ bool connected = false;
+ for(size_t count=0;count<1024;count++)
+ {
+ auto r1 = writer->connect(endpoint, std::chrono::milliseconds(0));
+ auto r2 = reader.first->read({{nullptr, 0}}, std::chrono::milliseconds(0));
+ (void) r2;
+ if(r1)
+ {
+ connected = true;
+ break;
+ }
+ std::cout << "*" << std::flush;
+ std::this_thread::sleep_for(std::chrono::milliseconds(1));
+ }
+ std::cout << std::endl;
+ BOOST_REQUIRE(connected);
- llfio::byte buffer[64];
- { // no data, so non-blocking read should time out
- auto read = reader.first.read(0, {{buffer, 64}}, std::chrono::milliseconds(0));
- BOOST_REQUIRE(read.has_error());
- BOOST_REQUIRE(read.error() == llfio::errc::timed_out);
+ llfio::byte buffer[64];
+ { // no data, so non-blocking read should time out
+ auto read = reader.first->read(0, {{buffer, 64}}, std::chrono::milliseconds(0));
+ BOOST_REQUIRE(read.has_error());
+ std::cout << read.error().message() << std::endl;
+ BOOST_REQUIRE(read.error() == llfio::errc::timed_out);
+ }
+ { // no data, so blocking read should time out
+ auto read = reader.first->read(0, {{buffer, 64}}, std::chrono::seconds(1));
+ if(!read.has_error())
+ {
+ std::cout << "Blocking read did not return error, instead returned " << read.value() << std::endl;
+ }
+ BOOST_REQUIRE(read.has_error());
+ std::cout << read.error().message() << std::endl;
+ BOOST_REQUIRE(read.error() == llfio::errc::timed_out);
+ }
+ auto written = writer->write(0, {{(const llfio::byte *) "hello", 5}}).value();
+ BOOST_REQUIRE(written == 5);
+ // writer.shutdown_and_close().value(); // would block until socket drained by reader
+ // writer.close().value(); // would cause all further reads to fail due to socket broken
+ auto read = reader.first->read(0, {{buffer, 64}}, std::chrono::milliseconds(1));
+ BOOST_REQUIRE(read.value() == 5);
+ BOOST_CHECK(0 == memcmp(buffer, "hello", 5));
+ // The TLS shutdown needs to do various reading and writing until the connection goes down
+ for(size_t count = 0; count < 1024; count++)
+ {
+ auto r1 = writer->shutdown_and_close(std::chrono::milliseconds(0));
+ auto r2 = reader.first->read({{nullptr, 0}}, std::chrono::milliseconds(0));
+ (void) r2;
+ if(r1)
+ {
+ connected = false;
+ break;
+ }
+ std::cout << "*" << std::flush;
+ std::this_thread::sleep_for(std::chrono::milliseconds(1));
+ }
+ std::cout << std::endl;
+ BOOST_REQUIRE(!connected);
+ writer->close().value();
+ reader.first->close().value();
+ };
+ std::cout << "\nUnwrapped TLS socket:\n" << std::endl;
+ runtest(tls_socket_source->multiplexable_listening_socket(llfio::ip::family::v4).value(),
+ [&] { return tls_socket_source->multiplexable_connecting_socket(llfio::ip::family::v4).value(); });
+
+ std::cout << "\nWrapped TLS socket:\n" << std::endl;
+ auto rawserversocket = llfio::listening_byte_socket_handle::multiplexable_listening_byte_socket(llfio::ip::family::v4).value();
+ auto rawwriter = llfio::byte_socket_handle::multiplexable_byte_socket(llfio::ip::family::v4).value();
+ runtest(tls_socket_source->wrap(&rawserversocket).value(), [&] { return tls_socket_source->wrap(&rawwriter).value(); });
+}
+
+/* This test makes the assumption that the host OS is able to validate github.com's
+TLS certificate.
+*/
+static inline void TestAuthenticatingTLSSocketHandles()
+{
+ static constexpr const char *test_host = "github.com";
+ static constexpr const char *get_request = R"(GET / HTTP/1.0
+Host: github.com
+
+)";
+ namespace llfio = LLFIO_V2_NAMESPACE;
+ if(llfio::tls_socket_source_registry::empty())
+ {
+ std::cout << "\nNOTE: This platform has no TLS socket sources in its registry, skipping this test." << std::endl;
+ return;
+ }
+ auto test_host_ip = llfio::ip::resolve(test_host, "https", llfio::ip::family::any, {}, llfio::ip::resolve_flag::blocking).value()->get().value().front();
+ std::cout << "The IP address of " << test_host << " is " << test_host_ip << std::endl;
+ BOOST_REQUIRE(test_host_ip.is_v4() || test_host_ip.is_v6());
+ auto tls_socket_source = llfio::tls_socket_source_registry::default_source().instantiate().value();
+ auto sock = tls_socket_source->multiplexable_connecting_socket(test_host_ip.family()).value();
+ {
+ auto r = sock->connect(test_host, 443, std::chrono::seconds(5));
+ if(!r)
+ {
+ if(r.error() == llfio::errc::timed_out || r.error() == llfio::errc::host_unreachable || r.error() == llfio::errc::network_unreachable)
+ {
+ std::cout << "\nNOTE: Failed to connect to " << test_host
+ << " within five seconds, assuming there is no internet connection and skipping this test. Error was: " << r.error().message() << std::endl;
+ return;
+ }
+ r.value();
+ }
}
- { // no data, so blocking read should time out
- auto read = reader.first.read(0, {{buffer, 64}}, std::chrono::seconds(1));
- if(!read.has_error())
+ // Get the front page
+ std::cout << "\nThe socket which connected to " << test_host << " negotiated the cipher " << sock->algorithms_description() << std::endl;
+ auto written = sock->write({{(const llfio::byte *) get_request, strlen(get_request)}}).value();
+ BOOST_REQUIRE(written == strlen(get_request));
+ // Fetch the front page. The connection will close once all data is sent.
+ std::vector<char> buffer(4096);
+ size_t offset = 0;
+ for(size_t readed = 0; (readed = sock->read({{(llfio::byte *) buffer.data() + offset, buffer.size() - offset}}, std::chrono::seconds(3)).value()) > 0;)
+ {
+ offset += readed;
+ if(buffer.size() - offset < 1024)
{
- std::cout << "Blocking read did not return error, instead returned " << read.value() << std::endl;
+ buffer.resize(buffer.size() + 4096);
}
- BOOST_REQUIRE(read.has_error());
- BOOST_REQUIRE(read.error() == llfio::errc::timed_out);
}
- auto written = writer.write(0, {{(const llfio::byte *) "hello", 5}}).value();
- BOOST_REQUIRE(written == 5);
- // writer.shutdown_and_close().value(); // would block until socket drained by reader
- // writer.close().value(); // would cause all further reads to fail due to socket broken
- auto read = reader.first.read(0, {{buffer, 64}}, std::chrono::milliseconds(1));
- BOOST_REQUIRE(read.value() == 5);
- BOOST_CHECK(0 == memcmp(buffer, "hello", 5));
- writer.shutdown_and_close().value(); // must not block nor fail
- writer.close().value();
- reader.first.close().value();
+ buffer.resize(offset);
+ std::cout << "\nRead from " << test_host << " " << offset << " bytes. The first 1024 bytes are:\n\n"
+ << llfio::string_view(buffer.data(), offset).substr(0, 1024) << "\n"
+ << std::endl;
+ // Make sure this doesn't hang because the socket is closed
+ sock->shutdown_and_close().value();
}
+#if 0
#if LLFIO_ENABLE_TEST_IO_MULTIPLEXERS
static inline void TestMultiplexedTLSSocketHandles()
{
@@ -522,162 +634,207 @@ static inline void TestCoroutinedTLSSocketHandles()
}
#endif
#endif
+#endif
static inline void TestPollingTLSSocketHandles()
{
static constexpr size_t MAX_SOCKETS = 64;
namespace llfio = LLFIO_V2_NAMESPACE;
- std::vector<std::pair<llfio::listening_socket_handle, llfio::ip::address>> listening;
- std::vector<llfio::byte_socket_handle> sockets;
- std::vector<size_t> idxs;
- for(size_t n = 0; n < MAX_SOCKETS; n++)
+ if(llfio::tls_socket_source_registry::empty())
{
- auto s = llfio::listening_socket_handle::listening_socket(llfio::ip::family::v4).value();
- s.bind(llfio::ip::address_v4::loopback()).value();
- auto endpoint = s.local_endpoint().value();
- if(endpoint.family() == llfio::ip::family::unknown && getenv("CI") != nullptr)
- {
- std::cout << "\nNOTE: Currently on CI and couldn't bind a listening socket to loopback, assuming it is CI host restrictions and skipping this test."
- << std::endl;
- return;
- }
- listening.emplace_back(std::move(s), endpoint);
- sockets.push_back(llfio::byte_socket_handle::byte_socket(llfio::ip::family::v4).value());
- idxs.push_back(n);
+ std::cout << "\nNOTE: This platform has no TLS socket sources in its registry, skipping this test." << std::endl;
+ return;
}
- QUICKCPPLIB_NAMESPACE::algorithm::small_prng::random_shuffle(idxs.begin(), idxs.end());
- std::mutex lock;
- std::atomic<size_t> currently_connecting{(size_t)-1};
- auto poll_listening_task = std::async(std::launch::async, [&] {
- std::vector<llfio::pollable_handle *> handles;
- std::vector<llfio::poll_what> what, out;
- for(size_t n = 0; n < MAX_SOCKETS; n++)
- {
- handles.push_back(&listening[n].first);
- what.push_back(llfio::poll_what::is_readable);
- out.push_back(llfio::poll_what::none);
- }
- for(;;)
- {
- int ret = (int) llfio::poll(out, {handles}, what, std::chrono::seconds(30)).value();
- bool done = true;
- for(size_t n = 0; n < MAX_SOCKETS; n++)
- {
- auto idx = idxs[n];
- if(handles[idx] != nullptr)
- {
- done = false;
- if(out[idx] & llfio::poll_what::is_readable)
- {
- {
- std::lock_guard<std::mutex> g(lock);
- std::cout << "Poll listening sees readable (raw = " << (int) (uint8_t) out[idx] << ") on socket " << idx << ". Currently connecting is "
- << currently_connecting << std::endl;
- }
- BOOST_CHECK(currently_connecting == idx);
- std::pair<llfio::byte_socket_handle, llfio::ip::address> s;
- listening[idx].first.read({s}).value();
- handles[idx] = nullptr;
- ret--;
- }
- out[idx] = llfio::poll_what::none;
- }
- }
- BOOST_CHECK(ret == 0);
- if(done)
- {
- std::lock_guard<std::mutex> g(lock);
- std::cout << "Poll listening task exits." << std::endl;
- break;
- }
- }
- });
- auto poll_connecting_task = std::async(std::launch::async, [&] {
- std::vector<llfio::pollable_handle *> handles;
- std::vector<llfio::poll_what> what, out;
- for(size_t n = 0; n < MAX_SOCKETS; n++)
- {
- handles.push_back(&sockets[n]);
- what.push_back(llfio::poll_what::is_writable);
- out.push_back(llfio::poll_what::none);
- }
- for(;;)
- {
- int ret = (int) llfio::poll(out, {handles}, what, std::chrono::seconds(30)).value();
- bool done = true, saw_closed = false;
- size_t remaining = MAX_SOCKETS;
- for(size_t n = 0; n < MAX_SOCKETS; n++)
- {
- auto idx = idxs[n];
- if(handles[idx] != nullptr)
- {
- done = false;
- // On Linux, a new socket not yet connected MAY appear as both writable and hanged up,
- // so filter out the closed.
- if(!(out[idx] & llfio::poll_what::is_closed) || (remaining == 1 && currently_connecting == idx))
- {
- if(out[idx] & llfio::poll_what::is_writable)
- {
- {
- std::lock_guard<std::mutex> g(lock);
- std::cout << "Poll connect sees writable (raw = " << (int) (uint8_t) out[idx] << ") on socket " << idx << ". Currently connecting is "
- << currently_connecting << std::endl;
- }
- BOOST_CHECK(currently_connecting == idx);
- handles[idx] = nullptr;
- ret--;
- }
- }
- else
- {
- saw_closed = true;
- }
- out[idx] = llfio::poll_what::none;
- }
- else
- {
- remaining--;
- }
- }
- if(!saw_closed)
- {
- BOOST_CHECK(ret == 0);
- }
- if(done)
- {
- std::lock_guard<std::mutex> g(lock);
- std::cout << "Poll connect task exits." << std::endl;
- break;
- }
- }
- });
- auto connect_task = std::async(std::launch::async, [&] {
+ auto tls_socket_source = llfio::tls_socket_source_registry::default_source().instantiate().value();
+ auto runtest = [](auto &&make_reader, auto &&make_writer)
+ {
+ std::vector<std::pair<llfio::listening_tls_socket_handle_ptr, llfio::ip::address>> listening;
+ std::vector<llfio::tls_socket_handle_ptr> sockets;
+ std::vector<size_t> idxs;
for(size_t n = 0; n < MAX_SOCKETS; n++)
{
- auto idx = idxs[n];
+ llfio::listening_tls_socket_handle_ptr s = make_reader(n);
+ // Disable authentication
+ s->set_authentication_certificates_path({}).value();
+ s->bind(llfio::ip::address_v4::loopback()).value();
+ auto endpoint = s->local_endpoint().value();
+ if(endpoint.family() == llfio::ip::family::unknown && getenv("CI") != nullptr)
{
- std::lock_guard<std::mutex> g(lock);
- std::cout << "Connecting " << idx << " ... " << std::endl;
+ std::cout << "\nNOTE: Currently on CI and couldn't bind a listening socket to loopback, assuming it is CI host restrictions and skipping this test."
+ << std::endl;
+ return;
}
- currently_connecting = idx;
- sockets[idx].connect(listening[idx].second).value();
- std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ listening.emplace_back(std::move(s), endpoint);
+ sockets.push_back(make_writer(n));
+ // Disable authentication
+ sockets.back()->set_authentication_certificates_path({}).value();
+ idxs.push_back(n);
}
- std::this_thread::sleep_for(std::chrono::seconds(1));
- std::lock_guard<std::mutex> g(lock);
- std::cout << "Connecting task exits." << std::endl;
- });
- connect_task.get();
- poll_listening_task.get();
- poll_connecting_task.get();
+ QUICKCPPLIB_NAMESPACE::algorithm::small_prng::random_shuffle(idxs.begin(), idxs.end());
+ std::mutex lock;
+ std::atomic<size_t> currently_connecting{(size_t) -1};
+ auto poll_listening_task = std::async(std::launch::async,
+ [&]
+ {
+ std::vector<llfio::pollable_handle *> handles;
+ std::vector<llfio::poll_what> what, out;
+ for(size_t n = 0; n < MAX_SOCKETS; n++)
+ {
+ handles.push_back(listening[n].first.get());
+ what.push_back(llfio::poll_what::is_readable);
+ out.push_back(llfio::poll_what::none);
+ }
+ for(;;)
+ {
+ int ret = (int) llfio::poll(out, {handles}, what, std::chrono::seconds(30)).value();
+ bool done = true;
+ for(size_t n = 0; n < MAX_SOCKETS; n++)
+ {
+ auto idx = idxs[n];
+ if(handles[idx] != nullptr)
+ {
+ done = false;
+ if(out[idx] & llfio::poll_what::is_readable)
+ {
+ {
+ std::lock_guard<std::mutex> g(lock);
+ std::cout << "Poll listening sees readable (raw = " << (int) (uint8_t) out[idx] << ") on socket " << idx
+ << ". Currently connecting is " << currently_connecting << std::endl;
+ }
+ BOOST_CHECK(currently_connecting == idx);
+ std::pair<llfio::tls_socket_handle_ptr, llfio::ip::address> s;
+ listening[idx].first->read({s}).value();
+ llfio::byte buf;
+ llfio::tls_socket_handle::buffer_type b(&buf, 1);
+ s.first->read({{b}});
+ handles[idx] = nullptr;
+ ret--;
+ }
+ out[idx] = llfio::poll_what::none;
+ }
+ }
+ BOOST_CHECK(ret == 0);
+ if(done)
+ {
+ std::lock_guard<std::mutex> g(lock);
+ std::cout << "Poll listening task exits." << std::endl;
+ break;
+ }
+ }
+ });
+ auto poll_connecting_task = std::async(std::launch::async,
+ [&]
+ {
+ std::vector<llfio::pollable_handle *> handles;
+ std::vector<llfio::poll_what> what, out;
+ for(size_t n = 0; n < MAX_SOCKETS; n++)
+ {
+ handles.push_back(sockets[n].get());
+ what.push_back(llfio::poll_what::is_writable);
+ out.push_back(llfio::poll_what::none);
+ }
+ for(;;)
+ {
+ int ret = (int) llfio::poll(out, {handles}, what, std::chrono::seconds(30)).value();
+ bool done = true, saw_closed = false;
+ size_t remaining = MAX_SOCKETS;
+ for(size_t n = 0; n < MAX_SOCKETS; n++)
+ {
+ auto idx = idxs[n];
+ if(handles[idx] != nullptr)
+ {
+ done = false;
+ // On Linux, a new socket not yet connected MAY appear as both writable and hanged up,
+ // so filter out the closed.
+ if(!(out[idx] & llfio::poll_what::is_closed) || (remaining == 1 && currently_connecting == idx))
+ {
+ if(out[idx] & llfio::poll_what::is_writable)
+ {
+ {
+ std::lock_guard<std::mutex> g(lock);
+ std::cout << "Poll connect sees writable (raw = " << (int) (uint8_t) out[idx] << ") on socket " << idx
+ << ". Currently connecting is " << currently_connecting << std::endl;
+ }
+ BOOST_CHECK(currently_connecting == idx);
+ handles[idx] = nullptr;
+ ret--;
+ }
+ }
+ else
+ {
+ saw_closed = true;
+ }
+ out[idx] = llfio::poll_what::none;
+ }
+ else
+ {
+ remaining--;
+ }
+ }
+ if(!saw_closed)
+ {
+ BOOST_CHECK(ret == 0);
+ }
+ if(done)
+ {
+ std::lock_guard<std::mutex> g(lock);
+ std::cout << "Poll connect task exits." << std::endl;
+ break;
+ }
+ }
+ });
+ auto connect_task = std::async(std::launch::async,
+ [&]
+ {
+ for(size_t n = 0; n < MAX_SOCKETS; n++)
+ {
+ auto idx = idxs[n];
+ {
+ std::lock_guard<std::mutex> g(lock);
+ std::cout << "Connecting " << idx << " ... " << std::endl;
+ }
+ currently_connecting = idx;
+ sockets[idx]->connect(listening[idx].second).value();
+ llfio::byte buf(llfio::to_byte(0));
+ llfio::tls_socket_handle::const_buffer_type b(&buf, 1);
+ sockets[idx]->write({{&b, 1}});
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ }
+ std::this_thread::sleep_for(std::chrono::seconds(1));
+ std::lock_guard<std::mutex> g(lock);
+ std::cout << "Connecting task exits." << std::endl;
+ });
+ connect_task.get();
+ poll_listening_task.get();
+ poll_connecting_task.get();
+ };
+ std::cout << "\nUnwrapped TLS socket:\n" << std::endl;
+ runtest([&](size_t) { return tls_socket_source->multiplexable_listening_socket(llfio::ip::family::v4).value(); },
+ [&](size_t) { return tls_socket_source->multiplexable_connecting_socket(llfio::ip::family::v4).value(); });
+
+ std::cout << "\nWrapped TLS socket:\n" << std::endl;
+ std::vector<llfio::listening_byte_socket_handle> rawlisteners;
+ for(size_t n = 0; n < MAX_SOCKETS; n++)
+ {
+ rawlisteners.push_back(llfio::listening_byte_socket_handle::multiplexable_listening_byte_socket(llfio::ip::family::v4).value());
+ }
+ std::vector<llfio::byte_socket_handle> rawwriters;
+ for(size_t n = 0; n < MAX_SOCKETS; n++)
+ {
+ rawwriters.push_back(llfio::byte_socket_handle::multiplexable_byte_socket(llfio::ip::family::v4).value());
+ }
+ runtest([&](size_t idx) { return tls_socket_source->wrap(&rawlisteners[idx]).value(); },
+ [&](size_t idx) { return tls_socket_source->wrap(&rawwriters[idx]).value(); });
}
-#endif
KERNELTEST_TEST_KERNEL(integration, llfio, tls_socket_handle, blocking, "Tests that blocking llfio::tls_byte_socket_handle works as expected",
TestBlockingTLSSocketHandles())
-#if 0
KERNELTEST_TEST_KERNEL(integration, llfio, tls_socket_handle, nonblocking, "Tests that nonblocking llfio::tls_byte_socket_handle works as expected",
TestNonBlockingTLSSocketHandles())
+KERNELTEST_TEST_KERNEL(integration, llfio, tls_socket_handle, authenticating,
+ "Tests that connecting to an authenticating server using llfio::tls_byte_socket_handle works as expected",
+ TestAuthenticatingTLSSocketHandles())
+#if 0
#if LLFIO_ENABLE_TEST_IO_MULTIPLEXERS
KERNELTEST_TEST_KERNEL(integration, llfio, tls_socket_handle, multiplexed, "Tests that multiplexed llfio::tls_byte_socket_handle works as expected",
TestMultiplexedTLSSocketHandles())
@@ -686,5 +843,6 @@ KERNELTEST_TEST_KERNEL(integration, llfio, tls_socket_handle, coroutined, "Tests
TestCoroutinedTLSSocketHandles())
#endif
#endif
-KERNELTEST_TEST_KERNEL(integration, llfio, tls_socket_handle, poll, "Tests that polling llfio::tls_byte_socket_handle works as expected", TestPollingTLSSocketHandles())
#endif
+KERNELTEST_TEST_KERNEL(integration, llfio, tls_socket_handle, poll, "Tests that polling llfio::tls_byte_socket_handle works as expected",
+ TestPollingTLSSocketHandles())