diff options
author | Niall Douglas (s [underscore] sourceforge {at} nedprod [dot] com) <spamtrap@nedprod.com> | 2022-04-08 23:29:46 +0300 |
---|---|---|
committer | Niall Douglas (s [underscore] sourceforge {at} nedprod [dot] com) <spamtrap@nedprod.com> | 2022-04-13 14:09:28 +0300 |
commit | 5c6f7f3933de89fb4e4a9aa7df69da933a8f09aa (patch) | |
tree | 2e0af0bbe06983827c63961809c3c05fefdc1bf8 /test | |
parent | 366d4e210a6152026285bbabdf58e9ee95c896d3 (diff) |
tls_socket_handle: Get non-blocking working with OpenSSL backend.
Diffstat (limited to 'test')
-rw-r--r-- | test/tests/byte_socket_handle.cpp | 12 | ||||
-rw-r--r-- | test/tests/tls_socket_handle.cpp | 379 |
2 files changed, 238 insertions, 153 deletions
diff --git a/test/tests/byte_socket_handle.cpp b/test/tests/byte_socket_handle.cpp index 0ba8b2a2..785398fb 100644 --- a/test/tests/byte_socket_handle.cpp +++ b/test/tests/byte_socket_handle.cpp @@ -375,7 +375,7 @@ static inline void TestSocketResolve() static inline void TestBlockingSocketHandles() { namespace llfio = LLFIO_V2_NAMESPACE; - auto serversocket = llfio::listening_socket_handle::listening_socket(llfio::ip::family::v4, llfio::listening_socket_handle::mode::read).value(); + auto serversocket = llfio::listening_byte_socket_handle::listening_byte_socket(llfio::ip::family::v4, llfio::listening_byte_socket_handle::mode::read).value(); BOOST_REQUIRE(serversocket.is_valid()); BOOST_CHECK(serversocket.is_socket()); BOOST_CHECK(serversocket.is_readable()); @@ -438,7 +438,7 @@ static inline void TestBlockingSocketHandles() static inline void TestNonBlockingSocketHandles() { namespace llfio = LLFIO_V2_NAMESPACE; - auto serversocket = llfio::listening_socket_handle::listening_socket(llfio::ip::family::v4, llfio::listening_socket_handle::mode::read, + auto serversocket = llfio::listening_byte_socket_handle::listening_byte_socket(llfio::ip::family::v4, llfio::listening_byte_socket_handle::mode::read, llfio::byte_socket_handle::caching::all, llfio::byte_socket_handle::flag::multiplexable) .value(); BOOST_REQUIRE(serversocket.is_valid()); @@ -602,7 +602,7 @@ static inline void TestMultiplexedSocketHandles() } }; auto serversocket = - llfio::listening_socket_handle::listening_socket(llfio::ip::family::v4, llfio::listening_socket_handle::mode::write, + llfio::listening_byte_socket_handle::listening_byte_socket(llfio::ip::family::v4, llfio::listening_byte_socket_handle::mode::write, llfio::byte_socket_handle::caching::all, llfio::byte_socket_handle::flag::multiplexable) .value(); serversocket.bind(llfio::ip::address_v4::loopback()).value(); @@ -727,7 +727,7 @@ static inline void TestCoroutinedSocketHandles() } }; auto serversocket = - llfio::listening_socket_handle::listening_socket(llfio::ip::family::v4, llfio::listening_socket_handle::mode::write, llfio::byte_socket_handle::caching::all, + llfio::listening_byte_socket_handle::listening_socket(llfio::ip::family::v4, llfio::listening_byte_socket_handle::mode::write, llfio::byte_socket_handle::caching::all, llfio::byte_socket_handle::flag::multiplexable) .value(); serversocket.bind(llfio::ip::address_v4::loopback()).value(); @@ -824,12 +824,12 @@ static inline void TestPollingSocketHandles() { static constexpr size_t MAX_SOCKETS = 64; namespace llfio = LLFIO_V2_NAMESPACE; - std::vector<std::pair<llfio::listening_socket_handle, llfio::ip::address>> listening; + std::vector<std::pair<llfio::listening_byte_socket_handle, llfio::ip::address>> listening; std::vector<llfio::byte_socket_handle> sockets; std::vector<size_t> idxs; for(size_t n = 0; n < MAX_SOCKETS; n++) { - auto s = llfio::listening_socket_handle::listening_socket(llfio::ip::family::v4).value(); + auto s = llfio::listening_byte_socket_handle::listening_byte_socket(llfio::ip::family::v4).value(); s.bind(llfio::ip::address_v4::loopback()).value(); auto endpoint = s.local_endpoint().value(); if(endpoint.family() == llfio::ip::family::unknown && getenv("CI") != nullptr) diff --git a/test/tests/tls_socket_handle.cpp b/test/tests/tls_socket_handle.cpp index 591525da..cade8a91 100644 --- a/test/tests/tls_socket_handle.cpp +++ b/test/tests/tls_socket_handle.cpp @@ -42,8 +42,9 @@ static inline void TestBlockingTLSSocketHandles() BOOST_CHECK(serversocket->is_socket()); BOOST_CHECK(serversocket->is_readable()); BOOST_CHECK(serversocket->is_writable()); - // Disable server authentication + // Disable authentication serversocket->set_authentication_certificates_path({}).value(); + writer->set_authentication_certificates_path({}).value(); { auto desc = serversocket->algorithms_description(); std::cout << "Server socket will offer during handshake the ciphers (" << (1 + std::count(desc.begin(), desc.end(), ',')) << "): "; @@ -132,7 +133,7 @@ static inline void TestBlockingTLSSocketHandles() runtest(tls_socket_source->listening_socket(llfio::ip::family::v4).value(), tls_socket_source->connecting_socket(llfio::ip::family::v4).value()); std::cout << "\nWrapped TLS socket:\n" << std::endl; - auto rawserversocket = llfio::listening_socket_handle::listening_socket(llfio::ip::family::v4).value(); + auto rawserversocket = llfio::listening_byte_socket_handle::listening_byte_socket(llfio::ip::family::v4).value(); auto rawwriter = llfio::byte_socket_handle::byte_socket(llfio::ip::family::v4).value(); runtest(tls_socket_source->wrap(&rawserversocket).value(), tls_socket_source->wrap(&rawwriter).value()); } @@ -152,7 +153,7 @@ static inline void TestNonBlockingTLSSocketHandles() BOOST_CHECK(serversocket->is_socket()); BOOST_CHECK(serversocket->is_readable()); BOOST_CHECK(serversocket->is_writable()); - // Disable server authentication + // Disable authentication serversocket->set_authentication_certificates_path({}).value(); serversocket->bind(llfio::ip::address_v4::loopback()).value(); auto endpoint = serversocket->local_endpoint().value(); @@ -178,14 +179,38 @@ static inline void TestNonBlockingTLSSocketHandles() // Form the connection. llfio::tls_socket_handle_ptr writer = make_writer(); - writer->connect(endpoint).value(); + // Disable authentication + writer->set_authentication_certificates_path({}).value(); + { + auto r = writer->connect(endpoint, std::chrono::milliseconds(0)); + BOOST_REQUIRE(!r); + BOOST_REQUIRE(r.error() == llfio::errc::timed_out); // can't possibly connect immediately + } serversocket->read({reader}, std::chrono::seconds(1)).value(); std::cout << "Server socket sees incoming connection from " << reader.second << std::endl; + // The TLS negotiation needs to do various reading and writing until the connection goes up + bool connected = false; + for(size_t count=0;count<1024;count++) + { + auto r1 = writer->connect(endpoint, std::chrono::milliseconds(0)); + auto r2 = reader.first->read({{nullptr, 0}}, std::chrono::milliseconds(0)); + (void) r2; + if(r1) + { + connected = true; + break; + } + std::cout << "*" << std::flush; + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + std::cout << std::endl; + BOOST_REQUIRE(connected); llfio::byte buffer[64]; { // no data, so non-blocking read should time out auto read = reader.first->read(0, {{buffer, 64}}, std::chrono::milliseconds(0)); BOOST_REQUIRE(read.has_error()); + std::cout << read.error().message() << std::endl; BOOST_REQUIRE(read.error() == llfio::errc::timed_out); } { // no data, so blocking read should time out @@ -195,6 +220,7 @@ static inline void TestNonBlockingTLSSocketHandles() std::cout << "Blocking read did not return error, instead returned " << read.value() << std::endl; } BOOST_REQUIRE(read.has_error()); + std::cout << read.error().message() << std::endl; BOOST_REQUIRE(read.error() == llfio::errc::timed_out); } auto written = writer->write(0, {{(const llfio::byte *) "hello", 5}}).value(); @@ -204,7 +230,22 @@ static inline void TestNonBlockingTLSSocketHandles() auto read = reader.first->read(0, {{buffer, 64}}, std::chrono::milliseconds(1)); BOOST_REQUIRE(read.value() == 5); BOOST_CHECK(0 == memcmp(buffer, "hello", 5)); - writer->shutdown_and_close().value(); // must not block nor fail + // The TLS shutdown needs to do various reading and writing until the connection goes down + for(size_t count = 0; count < 1024; count++) + { + auto r1 = writer->shutdown_and_close(std::chrono::milliseconds(0)); + auto r2 = reader.first->read({{nullptr, 0}}, std::chrono::milliseconds(0)); + (void) r2; + if(r1) + { + connected = false; + break; + } + std::cout << "*" << std::flush; + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + std::cout << std::endl; + BOOST_REQUIRE(!connected); writer->close().value(); reader.first->close().value(); }; @@ -213,7 +254,7 @@ static inline void TestNonBlockingTLSSocketHandles() [&] { return tls_socket_source->multiplexable_connecting_socket(llfio::ip::family::v4).value(); }); std::cout << "\nWrapped TLS socket:\n" << std::endl; - auto rawserversocket = llfio::listening_socket_handle::multiplexable_listening_socket(llfio::ip::family::v4).value(); + auto rawserversocket = llfio::listening_byte_socket_handle::multiplexable_listening_byte_socket(llfio::ip::family::v4).value(); auto rawwriter = llfio::byte_socket_handle::multiplexable_byte_socket(llfio::ip::family::v4).value(); runtest(tls_socket_source->wrap(&rawserversocket).value(), [&] { return tls_socket_source->wrap(&rawwriter).value(); }); } @@ -234,10 +275,11 @@ Host: github.com std::cout << "\nNOTE: This platform has no TLS socket sources in its registry, skipping this test." << std::endl; return; } - auto test_host_ip = llfio::ip::resolve(test_host, "https", llfio::ip::family::any, {}, llfio::ip::resolve_flag::blocking).value()->get().value(); - std::cout << "The IP address of " << test_host << " is " << test_host_ip.front() << std::endl; + auto test_host_ip = llfio::ip::resolve(test_host, "https", llfio::ip::family::any, {}, llfio::ip::resolve_flag::blocking).value()->get().value().front(); + std::cout << "The IP address of " << test_host << " is " << test_host_ip << std::endl; + BOOST_REQUIRE(test_host_ip.is_v4() || test_host_ip.is_v6()); auto tls_socket_source = llfio::tls_socket_source_registry::default_source().instantiate().value(); - auto sock = tls_socket_source->multiplexable_connecting_socket(llfio::ip::family::any).value(); + auto sock = tls_socket_source->multiplexable_connecting_socket(test_host_ip.family()).value(); { auto r = sock->connect(test_host, 443, std::chrono::seconds(5)); if(!r) @@ -592,156 +634,198 @@ static inline void TestCoroutinedTLSSocketHandles() } #endif #endif +#endif static inline void TestPollingTLSSocketHandles() { static constexpr size_t MAX_SOCKETS = 64; namespace llfio = LLFIO_V2_NAMESPACE; - std::vector<std::pair<llfio::listening_socket_handle, llfio::ip::address>> listening; - std::vector<llfio::byte_socket_handle> sockets; - std::vector<size_t> idxs; - for(size_t n = 0; n < MAX_SOCKETS; n++) + if(llfio::tls_socket_source_registry::empty()) { - auto s = llfio::listening_socket_handle::listening_socket(llfio::ip::family::v4).value(); - s.bind(llfio::ip::address_v4::loopback()).value(); - auto endpoint = s.local_endpoint().value(); - if(endpoint.family() == llfio::ip::family::unknown && getenv("CI") != nullptr) - { - std::cout << "\nNOTE: Currently on CI and couldn't bind a listening socket to loopback, assuming it is CI host restrictions and skipping this test." - << std::endl; - return; - } - listening.emplace_back(std::move(s), endpoint); - sockets.push_back(llfio::byte_socket_handle::byte_socket(llfio::ip::family::v4).value()); - idxs.push_back(n); + std::cout << "\nNOTE: This platform has no TLS socket sources in its registry, skipping this test." << std::endl; + return; } - QUICKCPPLIB_NAMESPACE::algorithm::small_prng::random_shuffle(idxs.begin(), idxs.end()); - std::mutex lock; - std::atomic<size_t> currently_connecting{(size_t)-1}; - auto poll_listening_task = std::async(std::launch::async, [&] { - std::vector<llfio::pollable_handle *> handles; - std::vector<llfio::poll_what> what, out; - for(size_t n = 0; n < MAX_SOCKETS; n++) - { - handles.push_back(&listening[n].first); - what.push_back(llfio::poll_what::is_readable); - out.push_back(llfio::poll_what::none); - } - for(;;) - { - int ret = (int) llfio::poll(out, {handles}, what, std::chrono::seconds(30)).value(); - bool done = true; - for(size_t n = 0; n < MAX_SOCKETS; n++) - { - auto idx = idxs[n]; - if(handles[idx] != nullptr) - { - done = false; - if(out[idx] & llfio::poll_what::is_readable) - { - { - std::lock_guard<std::mutex> g(lock); - std::cout << "Poll listening sees readable (raw = " << (int) (uint8_t) out[idx] << ") on socket " << idx << ". Currently connecting is " - << currently_connecting << std::endl; - } - BOOST_CHECK(currently_connecting == idx); - std::pair<llfio::byte_socket_handle, llfio::ip::address> s; - listening[idx].first.read({s}).value(); - handles[idx] = nullptr; - ret--; - } - out[idx] = llfio::poll_what::none; - } - } - BOOST_CHECK(ret == 0); - if(done) - { - std::lock_guard<std::mutex> g(lock); - std::cout << "Poll listening task exits." << std::endl; - break; - } - } - }); - auto poll_connecting_task = std::async(std::launch::async, [&] { - std::vector<llfio::pollable_handle *> handles; - std::vector<llfio::poll_what> what, out; - for(size_t n = 0; n < MAX_SOCKETS; n++) - { - handles.push_back(&sockets[n]); - what.push_back(llfio::poll_what::is_writable); - out.push_back(llfio::poll_what::none); - } - for(;;) - { - int ret = (int) llfio::poll(out, {handles}, what, std::chrono::seconds(30)).value(); - bool done = true, saw_closed = false; - size_t remaining = MAX_SOCKETS; - for(size_t n = 0; n < MAX_SOCKETS; n++) - { - auto idx = idxs[n]; - if(handles[idx] != nullptr) - { - done = false; - // On Linux, a new socket not yet connected MAY appear as both writable and hanged up, - // so filter out the closed. - if(!(out[idx] & llfio::poll_what::is_closed) || (remaining == 1 && currently_connecting == idx)) - { - if(out[idx] & llfio::poll_what::is_writable) - { - { - std::lock_guard<std::mutex> g(lock); - std::cout << "Poll connect sees writable (raw = " << (int) (uint8_t) out[idx] << ") on socket " << idx << ". Currently connecting is " - << currently_connecting << std::endl; - } - BOOST_CHECK(currently_connecting == idx); - handles[idx] = nullptr; - ret--; - } - } - else - { - saw_closed = true; - } - out[idx] = llfio::poll_what::none; - } - else - { - remaining--; - } - } - if(!saw_closed) - { - BOOST_CHECK(ret == 0); - } - if(done) - { - std::lock_guard<std::mutex> g(lock); - std::cout << "Poll connect task exits." << std::endl; - break; - } - } - }); - auto connect_task = std::async(std::launch::async, [&] { + auto tls_socket_source = llfio::tls_socket_source_registry::default_source().instantiate().value(); + auto runtest = [](auto &&make_reader, auto &&make_writer) + { + std::vector<std::pair<llfio::listening_tls_socket_handle_ptr, llfio::ip::address>> listening; + std::vector<llfio::tls_socket_handle_ptr> sockets; + std::vector<size_t> idxs; for(size_t n = 0; n < MAX_SOCKETS; n++) { - auto idx = idxs[n]; + llfio::listening_tls_socket_handle_ptr s = make_reader(n); + // Disable authentication + s->set_authentication_certificates_path({}).value(); + s->bind(llfio::ip::address_v4::loopback()).value(); + auto endpoint = s->local_endpoint().value(); + if(endpoint.family() == llfio::ip::family::unknown && getenv("CI") != nullptr) { - std::lock_guard<std::mutex> g(lock); - std::cout << "Connecting " << idx << " ... " << std::endl; + std::cout << "\nNOTE: Currently on CI and couldn't bind a listening socket to loopback, assuming it is CI host restrictions and skipping this test." + << std::endl; + return; } - currently_connecting = idx; - sockets[idx].connect(listening[idx].second).value(); - std::this_thread::sleep_for(std::chrono::milliseconds(100)); + listening.emplace_back(std::move(s), endpoint); + sockets.push_back(make_writer(n)); + // Disable authentication + sockets.back()->set_authentication_certificates_path({}).value(); + idxs.push_back(n); } - std::this_thread::sleep_for(std::chrono::seconds(1)); - std::lock_guard<std::mutex> g(lock); - std::cout << "Connecting task exits." << std::endl; - }); - connect_task.get(); - poll_listening_task.get(); - poll_connecting_task.get(); + QUICKCPPLIB_NAMESPACE::algorithm::small_prng::random_shuffle(idxs.begin(), idxs.end()); + std::mutex lock; + std::atomic<size_t> currently_connecting{(size_t) -1}; + auto poll_listening_task = std::async(std::launch::async, + [&] + { + std::vector<llfio::pollable_handle *> handles; + std::vector<llfio::poll_what> what, out; + for(size_t n = 0; n < MAX_SOCKETS; n++) + { + handles.push_back(listening[n].first.get()); + what.push_back(llfio::poll_what::is_readable); + out.push_back(llfio::poll_what::none); + } + for(;;) + { + int ret = (int) llfio::poll(out, {handles}, what, std::chrono::seconds(30)).value(); + bool done = true; + for(size_t n = 0; n < MAX_SOCKETS; n++) + { + auto idx = idxs[n]; + if(handles[idx] != nullptr) + { + done = false; + if(out[idx] & llfio::poll_what::is_readable) + { + { + std::lock_guard<std::mutex> g(lock); + std::cout << "Poll listening sees readable (raw = " << (int) (uint8_t) out[idx] << ") on socket " << idx + << ". Currently connecting is " << currently_connecting << std::endl; + } + BOOST_CHECK(currently_connecting == idx); + std::pair<llfio::tls_socket_handle_ptr, llfio::ip::address> s; + listening[idx].first->read({s}).value(); + llfio::byte buf; + llfio::tls_socket_handle::buffer_type b(&buf, 1); + s.first->read({{b}}); + handles[idx] = nullptr; + ret--; + } + out[idx] = llfio::poll_what::none; + } + } + BOOST_CHECK(ret == 0); + if(done) + { + std::lock_guard<std::mutex> g(lock); + std::cout << "Poll listening task exits." << std::endl; + break; + } + } + }); + auto poll_connecting_task = std::async(std::launch::async, + [&] + { + std::vector<llfio::pollable_handle *> handles; + std::vector<llfio::poll_what> what, out; + for(size_t n = 0; n < MAX_SOCKETS; n++) + { + handles.push_back(sockets[n].get()); + what.push_back(llfio::poll_what::is_writable); + out.push_back(llfio::poll_what::none); + } + for(;;) + { + int ret = (int) llfio::poll(out, {handles}, what, std::chrono::seconds(30)).value(); + bool done = true, saw_closed = false; + size_t remaining = MAX_SOCKETS; + for(size_t n = 0; n < MAX_SOCKETS; n++) + { + auto idx = idxs[n]; + if(handles[idx] != nullptr) + { + done = false; + // On Linux, a new socket not yet connected MAY appear as both writable and hanged up, + // so filter out the closed. + if(!(out[idx] & llfio::poll_what::is_closed) || (remaining == 1 && currently_connecting == idx)) + { + if(out[idx] & llfio::poll_what::is_writable) + { + { + std::lock_guard<std::mutex> g(lock); + std::cout << "Poll connect sees writable (raw = " << (int) (uint8_t) out[idx] << ") on socket " << idx + << ". Currently connecting is " << currently_connecting << std::endl; + } + BOOST_CHECK(currently_connecting == idx); + handles[idx] = nullptr; + ret--; + } + } + else + { + saw_closed = true; + } + out[idx] = llfio::poll_what::none; + } + else + { + remaining--; + } + } + if(!saw_closed) + { + BOOST_CHECK(ret == 0); + } + if(done) + { + std::lock_guard<std::mutex> g(lock); + std::cout << "Poll connect task exits." << std::endl; + break; + } + } + }); + auto connect_task = std::async(std::launch::async, + [&] + { + for(size_t n = 0; n < MAX_SOCKETS; n++) + { + auto idx = idxs[n]; + { + std::lock_guard<std::mutex> g(lock); + std::cout << "Connecting " << idx << " ... " << std::endl; + } + currently_connecting = idx; + sockets[idx]->connect(listening[idx].second).value(); + llfio::byte buf(llfio::to_byte(0)); + llfio::tls_socket_handle::const_buffer_type b(&buf, 1); + sockets[idx]->write({{&b, 1}}); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + std::this_thread::sleep_for(std::chrono::seconds(1)); + std::lock_guard<std::mutex> g(lock); + std::cout << "Connecting task exits." << std::endl; + }); + connect_task.get(); + poll_listening_task.get(); + poll_connecting_task.get(); + }; + std::cout << "\nUnwrapped TLS socket:\n" << std::endl; + runtest([&](size_t) { return tls_socket_source->multiplexable_listening_socket(llfio::ip::family::v4).value(); }, + [&](size_t) { return tls_socket_source->multiplexable_connecting_socket(llfio::ip::family::v4).value(); }); + + std::cout << "\nWrapped TLS socket:\n" << std::endl; + std::vector<llfio::listening_byte_socket_handle> rawlisteners; + for(size_t n = 0; n < MAX_SOCKETS; n++) + { + rawlisteners.push_back(llfio::listening_byte_socket_handle::multiplexable_listening_byte_socket(llfio::ip::family::v4).value()); + } + std::vector<llfio::byte_socket_handle> rawwriters; + for(size_t n = 0; n < MAX_SOCKETS; n++) + { + rawwriters.push_back(llfio::byte_socket_handle::multiplexable_byte_socket(llfio::ip::family::v4).value()); + } + runtest([&](size_t idx) { return tls_socket_source->wrap(&rawlisteners[idx]).value(); }, + [&](size_t idx) { return tls_socket_source->wrap(&rawwriters[idx]).value(); }); } -#endif KERNELTEST_TEST_KERNEL(integration, llfio, tls_socket_handle, blocking, "Tests that blocking llfio::tls_byte_socket_handle works as expected", TestBlockingTLSSocketHandles()) @@ -759,5 +843,6 @@ KERNELTEST_TEST_KERNEL(integration, llfio, tls_socket_handle, coroutined, "Tests TestCoroutinedTLSSocketHandles()) #endif #endif -KERNELTEST_TEST_KERNEL(integration, llfio, tls_socket_handle, poll, "Tests that polling llfio::tls_byte_socket_handle works as expected", TestPollingTLSSocketHandles()) #endif +KERNELTEST_TEST_KERNEL(integration, llfio, tls_socket_handle, poll, "Tests that polling llfio::tls_byte_socket_handle works as expected", + TestPollingTLSSocketHandles()) |