diff options
-rw-r--r-- | client_ws.hpp | 49 | ||||
-rw-r--r-- | client_wss.hpp | 17 | ||||
-rw-r--r-- | server_ws.hpp | 51 | ||||
-rw-r--r-- | server_wss.hpp | 8 | ||||
-rw-r--r-- | tests/parse_test.cpp | 4 |
5 files changed, 106 insertions, 23 deletions
diff --git a/client_ws.hpp b/client_ws.hpp index bb17beb..15593ec 100644 --- a/client_ws.hpp +++ b/client_ws.hpp @@ -65,7 +65,10 @@ namespace SimpleWeb { private: template <typename... Args> - Connection(Args &&... args) noexcept : socket(new socket_type(std::forward<Args>(args)...)), strand(socket->get_io_service()), closed(false) {} + Connection(std::shared_ptr<ScopeRunner> handler_runner, Args &&... args) noexcept + : handler_runner(std::move(handler_runner)), socket(new socket_type(std::forward<Args>(args)...)), strand(socket->get_io_service()), closed(false) {} + + std::shared_ptr<ScopeRunner> handler_runner; std::unique_ptr<socket_type> socket; // Socket must be unique_ptr since asio::ssl::stream<asio::ip::tcp::socket> is not movable std::mutex socket_close_mutex; @@ -95,6 +98,9 @@ namespace SimpleWeb { auto self = this->shared_from_this(); strand.post([self]() { asio::async_write(*self->socket, self->send_queue.begin()->send_stream->streambuf, self->strand.wrap([self](const error_code &ec, size_t /*bytes_transferred*/) { + auto lock = self->handler_runner->continue_lock(); + if(!lock) + return; auto send_queued = self->send_queue.begin(); if(send_queued->callback) send_queued->callback(ec); @@ -251,7 +257,10 @@ namespace SimpleWeb { io_service->stop(); } - virtual ~SocketClientBase() noexcept {} + virtual ~SocketClientBase() noexcept { + handler_runner->stop(); + stop(); + } /// If you have your own asio::io_service, store its pointer here before running start(). std::shared_ptr<asio::io_service> io_service; @@ -266,7 +275,9 @@ namespace SimpleWeb { std::shared_ptr<Connection> connection; std::mutex connection_mutex; - SocketClientBase(const std::string &host_port_path, unsigned short default_port) noexcept { + std::shared_ptr<ScopeRunner> handler_runner; + + SocketClientBase(const std::string &host_port_path, unsigned short default_port) noexcept : handler_runner(new ScopeRunner()) { size_t host_end = host_port_path.find(':'); size_t host_port_end = host_port_path.find('/'); if(host_end == std::string::npos) { @@ -320,8 +331,14 @@ namespace SimpleWeb { connection->message = std::shared_ptr<Message>(new Message()); asio::async_write(*connection->socket, *write_buffer, [this, connection, write_buffer, nonce_base64](const error_code &ec, size_t /*bytes_transferred*/) { + auto lock = connection->handler_runner->continue_lock(); + if(!lock) + return; if(!ec) { asio::async_read_until(*connection->socket, connection->message->streambuf, "\r\n\r\n", [this, connection, nonce_base64](const error_code &ec, size_t /*bytes_transferred*/) { + auto lock = connection->handler_runner->continue_lock(); + if(!lock) + return; if(!ec) { if(!ResponseMessage::parse(*connection->message, connection->http_version, connection->status_code, connection->header) || connection->status_code != "101 Web Socket Protocol Handshake") { @@ -351,6 +368,9 @@ namespace SimpleWeb { void read_message(const std::shared_ptr<Connection> &connection) { asio::async_read(*connection->socket, connection->message->streambuf, asio::transfer_exactly(2), [this, connection](const error_code &ec, size_t bytes_transferred) { + auto lock = connection->handler_runner->continue_lock(); + if(!lock) + return; if(!ec) { if(bytes_transferred == 0) { // TODO: This might happen on server at least, might also happen here this->read_message(connection); @@ -376,6 +396,9 @@ namespace SimpleWeb { if(length == 126) { // 2 next bytes is the size of content asio::async_read(*connection->socket, connection->message->streambuf, asio::transfer_exactly(2), [this, connection](const error_code &ec, size_t /*bytes_transferred*/) { + auto lock = connection->handler_runner->continue_lock(); + if(!lock) + return; if(!ec) { std::vector<unsigned char> length_bytes; length_bytes.resize(2); @@ -396,6 +419,9 @@ namespace SimpleWeb { else if(length == 127) { // 8 next bytes is the size of content asio::async_read(*connection->socket, connection->message->streambuf, asio::transfer_exactly(8), [this, connection](const error_code &ec, size_t /*bytes_transferred*/) { + auto lock = connection->handler_runner->continue_lock(); + if(!lock) + return; if(!ec) { std::vector<unsigned char> length_bytes; length_bytes.resize(8); @@ -425,6 +451,9 @@ namespace SimpleWeb { void read_message_content(const std::shared_ptr<Connection> &connection) { asio::async_read(*connection->socket, connection->message->streambuf, asio::transfer_exactly(connection->message->length), [this, connection](const error_code &ec, size_t /*bytes_transferred*/) { + auto lock = connection->handler_runner->continue_lock(); + if(!lock) + return; if(!ec) { // If connection close if((connection->message->fin_rsv_opcode & 0x0f) == 8) { @@ -474,14 +503,20 @@ namespace SimpleWeb { protected: void connect() override { + std::unique_lock<std::mutex> lock(connection_mutex); + auto connection = this->connection = std::shared_ptr<Connection>(new Connection(this->handler_runner, *io_service)); + lock.unlock(); asio::ip::tcp::resolver::query query(host, std::to_string(port)); auto resolver = std::make_shared<asio::ip::tcp::resolver>(*io_service); - resolver->async_resolve(query, [this, resolver](const error_code &ec, asio::ip::tcp::resolver::iterator it) { - std::unique_lock<std::mutex> lock(connection_mutex); - auto connection = this->connection = std::shared_ptr<Connection>(new Connection(*io_service)); - lock.unlock(); + resolver->async_resolve(query, [this, connection, resolver](const error_code &ec, asio::ip::tcp::resolver::iterator it) { + auto lock = connection->handler_runner->continue_lock(); + if(!lock) + return; if(!ec) { asio::async_connect(*connection->socket, it, [this, connection, resolver](const error_code &ec, asio::ip::tcp::resolver::iterator /*it*/) { + auto lock = connection->handler_runner->continue_lock(); + if(!lock) + return; if(!ec) { asio::ip::tcp::no_delay option(true); connection->socket->set_option(option); diff --git a/client_wss.hpp b/client_wss.hpp index dd7ea43..a6b057e 100644 --- a/client_wss.hpp +++ b/client_wss.hpp @@ -42,19 +42,28 @@ namespace SimpleWeb { asio::ssl::context context; void connect() override { + std::unique_lock<std::mutex> connection_lock(connection_mutex); + auto connection = this->connection = std::shared_ptr<Connection>(new Connection(this->handler_runner, *io_service, context)); + connection_lock.unlock(); asio::ip::tcp::resolver::query query(host, std::to_string(port)); auto resolver = std::make_shared<asio::ip::tcp::resolver>(*io_service); - resolver->async_resolve(query, [this, resolver](const error_code &ec, asio::ip::tcp::resolver::iterator it) { - std::unique_lock<std::mutex> lock(connection_mutex); - auto connection = this->connection = std::shared_ptr<Connection>(new Connection(*io_service, context)); - lock.unlock(); + resolver->async_resolve(query, [this, connection, resolver](const error_code &ec, asio::ip::tcp::resolver::iterator it) { + auto lock = connection->handler_runner->continue_lock(); + if(!lock) + return; if(!ec) { asio::async_connect(connection->socket->lowest_layer(), it, [this, connection, resolver](const error_code &ec, asio::ip::tcp::resolver::iterator /*it*/) { + auto lock = connection->handler_runner->continue_lock(); + if(!lock) + return; if(!ec) { asio::ip::tcp::no_delay option(true); connection->socket->lowest_layer().set_option(option); connection->socket->async_handshake(asio::ssl::stream_base::client, [this, connection](const error_code &ec) { + auto lock = connection->handler_runner->continue_lock(); + if(!lock) + return; if(!ec) handshake(connection); else if(on_error) diff --git a/server_ws.hpp b/server_ws.hpp index 39d4dbe..7be8629 100644 --- a/server_ws.hpp +++ b/server_ws.hpp @@ -76,7 +76,10 @@ namespace SimpleWeb { private: template <typename... Args> - Connection(long timeout_idle, Args &&... args) noexcept : socket(new socket_type(std::forward<Args>(args)...)), timeout_idle(timeout_idle), strand(socket->get_io_service()), closed(false) {} + Connection(std::shared_ptr<ScopeRunner> handler_runner, long timeout_idle, Args &&... args) noexcept + : handler_runner(std::move(handler_runner)), socket(new socket_type(std::forward<Args>(args)...)), timeout_idle(timeout_idle), strand(socket->get_io_service()), closed(false) {} + + std::shared_ptr<ScopeRunner> handler_runner; std::unique_ptr<socket_type> socket; // Socket must be unique_ptr since asio::ssl::stream<asio::ip::tcp::socket> is not movable std::mutex socket_close_mutex; @@ -166,8 +169,14 @@ namespace SimpleWeb { auto self = this->shared_from_this(); strand.post([self]() { asio::async_write(*self->socket, self->send_queue.begin()->header_stream->streambuf, self->strand.wrap([self](const error_code &ec, size_t /*bytes_transferred*/) { + auto lock = self->handler_runner->continue_lock(); + if(!lock) + return; if(!ec) { asio::async_write(*self->socket, self->send_queue.begin()->message_stream->streambuf, self->strand.wrap([self](const error_code &ec, size_t /*bytes_transferred*/) { + auto lock = self->handler_runner->continue_lock(); + if(!lock) + return; auto send_queued = self->send_queue.begin(); if(send_queued->callback) send_queued->callback(ec); @@ -298,7 +307,7 @@ namespace SimpleWeb { std::function<void(std::shared_ptr<Connection>, const error_code &)> on_error; std::unordered_set<std::shared_ptr<Connection>> get_connections() noexcept { - std::lock_guard<std::mutex> lock(connections_mutex); + std::unique_lock<std::mutex> lock(connections_mutex); auto copy = connections; return copy; } @@ -393,7 +402,7 @@ namespace SimpleWeb { acceptor->close(ec); for(auto &pair : endpoint) { - std::lock_guard<std::mutex> lock(pair.second.connections_mutex); + std::unique_lock<std::mutex> lock(pair.second.connections_mutex); for(auto &connection : pair.second.connections) connection->close(); pair.second.connections.clear(); @@ -409,7 +418,7 @@ namespace SimpleWeb { std::unordered_set<std::shared_ptr<Connection>> get_connections() noexcept { std::unordered_set<std::shared_ptr<Connection>> all_connections; for(auto &e : endpoint) { - std::lock_guard<std::mutex> lock(e.second.connections_mutex); + std::unique_lock<std::mutex> lock(e.second.connections_mutex); all_connections.insert(e.second.connections.begin(), e.second.connections.end()); } return all_connections; @@ -435,6 +444,7 @@ namespace SimpleWeb { * } */ void upgrade(const std::shared_ptr<Connection> &connection) { + connection->handler_runner = handler_runner; connection->timeout_idle = config.timeout_idle; write_handshake(connection); } @@ -448,7 +458,9 @@ namespace SimpleWeb { std::unique_ptr<asio::ip::tcp::acceptor> acceptor; std::vector<std::thread> threads; - SocketServerBase(unsigned short port) noexcept : config(port) {} + std::shared_ptr<ScopeRunner> handler_runner; + + SocketServerBase(unsigned short port) noexcept : config(port), handler_runner(new ScopeRunner()) {} virtual void accept() = 0; @@ -458,6 +470,9 @@ namespace SimpleWeb { connection->set_timeout(config.timeout_request); asio::async_read_until(*connection->socket, connection->read_buffer, "\r\n\r\n", [this, connection](const error_code &ec, size_t /*bytes_transferred*/) { connection->cancel_timeout(); + auto lock = connection->handler_runner->continue_lock(); + if(!lock) + return; if(!ec) { std::istream stream(&connection->read_buffer); if(RequestMessage::parse(stream, connection->method, connection->path, connection->query_string, connection->http_version, connection->header)) @@ -477,6 +492,9 @@ namespace SimpleWeb { connection->set_timeout(config.timeout_request); asio::async_write(*connection->socket, *write_buffer, [this, connection, write_buffer, ®ex_endpoint](const error_code &ec, size_t /*bytes_transferred*/) { connection->cancel_timeout(); + auto lock = connection->handler_runner->continue_lock(); + if(!lock) + return; if(!ec) { connection_open(connection, regex_endpoint.second); read_message(connection, regex_endpoint.second); @@ -492,6 +510,9 @@ namespace SimpleWeb { void read_message(const std::shared_ptr<Connection> &connection, Endpoint &endpoint) const { asio::async_read(*connection->socket, connection->read_buffer, asio::transfer_exactly(2), [this, connection, &endpoint](const error_code &ec, size_t bytes_transferred) { + auto lock = connection->handler_runner->continue_lock(); + if(!lock) + return; if(!ec) { if(bytes_transferred == 0) { // TODO: why does this happen sometimes? read_message(connection, endpoint); @@ -518,6 +539,9 @@ namespace SimpleWeb { if(length == 126) { // 2 next bytes is the size of content asio::async_read(*connection->socket, connection->read_buffer, asio::transfer_exactly(2), [this, connection, &endpoint, fin_rsv_opcode](const error_code &ec, size_t /*bytes_transferred*/) { + auto lock = connection->handler_runner->continue_lock(); + if(!lock) + return; if(!ec) { std::istream stream(&connection->read_buffer); @@ -539,6 +563,9 @@ namespace SimpleWeb { else if(length == 127) { // 8 next bytes is the size of content asio::async_read(*connection->socket, connection->read_buffer, asio::transfer_exactly(8), [this, connection, &endpoint, fin_rsv_opcode](const error_code &ec, size_t /*bytes_transferred*/) { + auto lock = connection->handler_runner->continue_lock(); + if(!lock) + return; if(!ec) { std::istream stream(&connection->read_buffer); @@ -567,6 +594,9 @@ namespace SimpleWeb { void read_message_content(const std::shared_ptr<Connection> &connection, size_t length, Endpoint &endpoint, unsigned char fin_rsv_opcode) const { asio::async_read(*connection->socket, connection->read_buffer, asio::transfer_exactly(4 + length), [this, connection, length, &endpoint, fin_rsv_opcode](const error_code &ec, size_t /*bytes_transferred*/) { + auto lock = connection->handler_runner->continue_lock(); + if(!lock) + return; if(!ec) { std::istream raw_message_data(&connection->read_buffer); @@ -625,7 +655,7 @@ namespace SimpleWeb { connection->set_timeout(); { - std::lock_guard<std::mutex> lock(endpoint.connections_mutex); + std::unique_lock<std::mutex> lock(endpoint.connections_mutex); endpoint.connections.insert(connection); } @@ -638,7 +668,7 @@ namespace SimpleWeb { connection->set_timeout(); { - std::lock_guard<std::mutex> lock(endpoint.connections_mutex); + std::unique_lock<std::mutex> lock(endpoint.connections_mutex); endpoint.connections.erase(connection); } @@ -651,7 +681,7 @@ namespace SimpleWeb { connection->set_timeout(); { - std::lock_guard<std::mutex> lock(endpoint.connections_mutex); + std::unique_lock<std::mutex> lock(endpoint.connections_mutex); endpoint.connections.erase(connection); } @@ -672,9 +702,12 @@ namespace SimpleWeb { protected: void accept() override { - std::shared_ptr<Connection> connection(new Connection(config.timeout_idle, *io_service)); + std::shared_ptr<Connection> connection(new Connection(handler_runner, config.timeout_idle, *io_service)); acceptor->async_accept(*connection->socket, [this, connection](const error_code &ec) { + auto lock = connection->handler_runner->continue_lock(); + if(!lock) + return; // Immediately start accepting a new connection (if io_service hasn't been stopped) if(ec != asio::error::operation_aborted) accept(); diff --git a/server_wss.hpp b/server_wss.hpp index 8a2e6e8..23d9463 100644 --- a/server_wss.hpp +++ b/server_wss.hpp @@ -49,9 +49,12 @@ namespace SimpleWeb { asio::ssl::context context; void accept() override { - std::shared_ptr<Connection> connection(new Connection(config.timeout_idle, *io_service, context)); + std::shared_ptr<Connection> connection(new Connection(handler_runner, config.timeout_idle, *io_service, context)); acceptor->async_accept(connection->socket->lowest_layer(), [this, connection](const error_code &ec) { + auto lock = connection->handler_runner->continue_lock(); + if(!lock) + return; // Immediately start accepting a new connection (if io_service hasn't been stopped) if(ec != asio::error::operation_aborted) accept(); @@ -62,6 +65,9 @@ namespace SimpleWeb { connection->set_timeout(config.timeout_request); connection->socket->async_handshake(asio::ssl::stream_base::server, [this, connection](const error_code &ec) { + auto lock = connection->handler_runner->continue_lock(); + if(!lock) + return; connection->cancel_timeout(); if(!ec) read_handshake(connection); diff --git a/tests/parse_test.cpp b/tests/parse_test.cpp index 6689773..9295ed1 100644 --- a/tests/parse_test.cpp +++ b/tests/parse_test.cpp @@ -13,7 +13,7 @@ public: void accept() {} void parse_request_test() { - std::shared_ptr<Connection> connection(new Connection(0, *io_service)); + std::shared_ptr<Connection> connection(new Connection(handler_runner, 0, *io_service)); ostream ss(&connection->read_buffer); ss << "GET /test/ HTTP/1.1\r\n"; @@ -82,7 +82,7 @@ public: } void parse_response_header_test() { - auto connection = std::shared_ptr<Connection>(new Connection(*io_service)); + auto connection = std::shared_ptr<Connection>(new Connection(handler_runner, *io_service)); connection->message = std::shared_ptr<Message>(new Message()); ostream stream(&connection->message->streambuf); |