Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/Simple-WebSocket-Server.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--client_ws.hpp49
-rw-r--r--client_wss.hpp17
-rw-r--r--server_ws.hpp51
-rw-r--r--server_wss.hpp8
-rw-r--r--tests/parse_test.cpp4
5 files changed, 106 insertions, 23 deletions
diff --git a/client_ws.hpp b/client_ws.hpp
index bb17beb..15593ec 100644
--- a/client_ws.hpp
+++ b/client_ws.hpp
@@ -65,7 +65,10 @@ namespace SimpleWeb {
private:
template <typename... Args>
- Connection(Args &&... args) noexcept : socket(new socket_type(std::forward<Args>(args)...)), strand(socket->get_io_service()), closed(false) {}
+ Connection(std::shared_ptr<ScopeRunner> handler_runner, Args &&... args) noexcept
+ : handler_runner(std::move(handler_runner)), socket(new socket_type(std::forward<Args>(args)...)), strand(socket->get_io_service()), closed(false) {}
+
+ std::shared_ptr<ScopeRunner> handler_runner;
std::unique_ptr<socket_type> socket; // Socket must be unique_ptr since asio::ssl::stream<asio::ip::tcp::socket> is not movable
std::mutex socket_close_mutex;
@@ -95,6 +98,9 @@ namespace SimpleWeb {
auto self = this->shared_from_this();
strand.post([self]() {
asio::async_write(*self->socket, self->send_queue.begin()->send_stream->streambuf, self->strand.wrap([self](const error_code &ec, size_t /*bytes_transferred*/) {
+ auto lock = self->handler_runner->continue_lock();
+ if(!lock)
+ return;
auto send_queued = self->send_queue.begin();
if(send_queued->callback)
send_queued->callback(ec);
@@ -251,7 +257,10 @@ namespace SimpleWeb {
io_service->stop();
}
- virtual ~SocketClientBase() noexcept {}
+ virtual ~SocketClientBase() noexcept {
+ handler_runner->stop();
+ stop();
+ }
/// If you have your own asio::io_service, store its pointer here before running start().
std::shared_ptr<asio::io_service> io_service;
@@ -266,7 +275,9 @@ namespace SimpleWeb {
std::shared_ptr<Connection> connection;
std::mutex connection_mutex;
- SocketClientBase(const std::string &host_port_path, unsigned short default_port) noexcept {
+ std::shared_ptr<ScopeRunner> handler_runner;
+
+ SocketClientBase(const std::string &host_port_path, unsigned short default_port) noexcept : handler_runner(new ScopeRunner()) {
size_t host_end = host_port_path.find(':');
size_t host_port_end = host_port_path.find('/');
if(host_end == std::string::npos) {
@@ -320,8 +331,14 @@ namespace SimpleWeb {
connection->message = std::shared_ptr<Message>(new Message());
asio::async_write(*connection->socket, *write_buffer, [this, connection, write_buffer, nonce_base64](const error_code &ec, size_t /*bytes_transferred*/) {
+ auto lock = connection->handler_runner->continue_lock();
+ if(!lock)
+ return;
if(!ec) {
asio::async_read_until(*connection->socket, connection->message->streambuf, "\r\n\r\n", [this, connection, nonce_base64](const error_code &ec, size_t /*bytes_transferred*/) {
+ auto lock = connection->handler_runner->continue_lock();
+ if(!lock)
+ return;
if(!ec) {
if(!ResponseMessage::parse(*connection->message, connection->http_version, connection->status_code, connection->header) ||
connection->status_code != "101 Web Socket Protocol Handshake") {
@@ -351,6 +368,9 @@ namespace SimpleWeb {
void read_message(const std::shared_ptr<Connection> &connection) {
asio::async_read(*connection->socket, connection->message->streambuf, asio::transfer_exactly(2), [this, connection](const error_code &ec, size_t bytes_transferred) {
+ auto lock = connection->handler_runner->continue_lock();
+ if(!lock)
+ return;
if(!ec) {
if(bytes_transferred == 0) { // TODO: This might happen on server at least, might also happen here
this->read_message(connection);
@@ -376,6 +396,9 @@ namespace SimpleWeb {
if(length == 126) {
// 2 next bytes is the size of content
asio::async_read(*connection->socket, connection->message->streambuf, asio::transfer_exactly(2), [this, connection](const error_code &ec, size_t /*bytes_transferred*/) {
+ auto lock = connection->handler_runner->continue_lock();
+ if(!lock)
+ return;
if(!ec) {
std::vector<unsigned char> length_bytes;
length_bytes.resize(2);
@@ -396,6 +419,9 @@ namespace SimpleWeb {
else if(length == 127) {
// 8 next bytes is the size of content
asio::async_read(*connection->socket, connection->message->streambuf, asio::transfer_exactly(8), [this, connection](const error_code &ec, size_t /*bytes_transferred*/) {
+ auto lock = connection->handler_runner->continue_lock();
+ if(!lock)
+ return;
if(!ec) {
std::vector<unsigned char> length_bytes;
length_bytes.resize(8);
@@ -425,6 +451,9 @@ namespace SimpleWeb {
void read_message_content(const std::shared_ptr<Connection> &connection) {
asio::async_read(*connection->socket, connection->message->streambuf, asio::transfer_exactly(connection->message->length), [this, connection](const error_code &ec, size_t /*bytes_transferred*/) {
+ auto lock = connection->handler_runner->continue_lock();
+ if(!lock)
+ return;
if(!ec) {
// If connection close
if((connection->message->fin_rsv_opcode & 0x0f) == 8) {
@@ -474,14 +503,20 @@ namespace SimpleWeb {
protected:
void connect() override {
+ std::unique_lock<std::mutex> lock(connection_mutex);
+ auto connection = this->connection = std::shared_ptr<Connection>(new Connection(this->handler_runner, *io_service));
+ lock.unlock();
asio::ip::tcp::resolver::query query(host, std::to_string(port));
auto resolver = std::make_shared<asio::ip::tcp::resolver>(*io_service);
- resolver->async_resolve(query, [this, resolver](const error_code &ec, asio::ip::tcp::resolver::iterator it) {
- std::unique_lock<std::mutex> lock(connection_mutex);
- auto connection = this->connection = std::shared_ptr<Connection>(new Connection(*io_service));
- lock.unlock();
+ resolver->async_resolve(query, [this, connection, resolver](const error_code &ec, asio::ip::tcp::resolver::iterator it) {
+ auto lock = connection->handler_runner->continue_lock();
+ if(!lock)
+ return;
if(!ec) {
asio::async_connect(*connection->socket, it, [this, connection, resolver](const error_code &ec, asio::ip::tcp::resolver::iterator /*it*/) {
+ auto lock = connection->handler_runner->continue_lock();
+ if(!lock)
+ return;
if(!ec) {
asio::ip::tcp::no_delay option(true);
connection->socket->set_option(option);
diff --git a/client_wss.hpp b/client_wss.hpp
index dd7ea43..a6b057e 100644
--- a/client_wss.hpp
+++ b/client_wss.hpp
@@ -42,19 +42,28 @@ namespace SimpleWeb {
asio::ssl::context context;
void connect() override {
+ std::unique_lock<std::mutex> connection_lock(connection_mutex);
+ auto connection = this->connection = std::shared_ptr<Connection>(new Connection(this->handler_runner, *io_service, context));
+ connection_lock.unlock();
asio::ip::tcp::resolver::query query(host, std::to_string(port));
auto resolver = std::make_shared<asio::ip::tcp::resolver>(*io_service);
- resolver->async_resolve(query, [this, resolver](const error_code &ec, asio::ip::tcp::resolver::iterator it) {
- std::unique_lock<std::mutex> lock(connection_mutex);
- auto connection = this->connection = std::shared_ptr<Connection>(new Connection(*io_service, context));
- lock.unlock();
+ resolver->async_resolve(query, [this, connection, resolver](const error_code &ec, asio::ip::tcp::resolver::iterator it) {
+ auto lock = connection->handler_runner->continue_lock();
+ if(!lock)
+ return;
if(!ec) {
asio::async_connect(connection->socket->lowest_layer(), it, [this, connection, resolver](const error_code &ec, asio::ip::tcp::resolver::iterator /*it*/) {
+ auto lock = connection->handler_runner->continue_lock();
+ if(!lock)
+ return;
if(!ec) {
asio::ip::tcp::no_delay option(true);
connection->socket->lowest_layer().set_option(option);
connection->socket->async_handshake(asio::ssl::stream_base::client, [this, connection](const error_code &ec) {
+ auto lock = connection->handler_runner->continue_lock();
+ if(!lock)
+ return;
if(!ec)
handshake(connection);
else if(on_error)
diff --git a/server_ws.hpp b/server_ws.hpp
index 39d4dbe..7be8629 100644
--- a/server_ws.hpp
+++ b/server_ws.hpp
@@ -76,7 +76,10 @@ namespace SimpleWeb {
private:
template <typename... Args>
- Connection(long timeout_idle, Args &&... args) noexcept : socket(new socket_type(std::forward<Args>(args)...)), timeout_idle(timeout_idle), strand(socket->get_io_service()), closed(false) {}
+ Connection(std::shared_ptr<ScopeRunner> handler_runner, long timeout_idle, Args &&... args) noexcept
+ : handler_runner(std::move(handler_runner)), socket(new socket_type(std::forward<Args>(args)...)), timeout_idle(timeout_idle), strand(socket->get_io_service()), closed(false) {}
+
+ std::shared_ptr<ScopeRunner> handler_runner;
std::unique_ptr<socket_type> socket; // Socket must be unique_ptr since asio::ssl::stream<asio::ip::tcp::socket> is not movable
std::mutex socket_close_mutex;
@@ -166,8 +169,14 @@ namespace SimpleWeb {
auto self = this->shared_from_this();
strand.post([self]() {
asio::async_write(*self->socket, self->send_queue.begin()->header_stream->streambuf, self->strand.wrap([self](const error_code &ec, size_t /*bytes_transferred*/) {
+ auto lock = self->handler_runner->continue_lock();
+ if(!lock)
+ return;
if(!ec) {
asio::async_write(*self->socket, self->send_queue.begin()->message_stream->streambuf, self->strand.wrap([self](const error_code &ec, size_t /*bytes_transferred*/) {
+ auto lock = self->handler_runner->continue_lock();
+ if(!lock)
+ return;
auto send_queued = self->send_queue.begin();
if(send_queued->callback)
send_queued->callback(ec);
@@ -298,7 +307,7 @@ namespace SimpleWeb {
std::function<void(std::shared_ptr<Connection>, const error_code &)> on_error;
std::unordered_set<std::shared_ptr<Connection>> get_connections() noexcept {
- std::lock_guard<std::mutex> lock(connections_mutex);
+ std::unique_lock<std::mutex> lock(connections_mutex);
auto copy = connections;
return copy;
}
@@ -393,7 +402,7 @@ namespace SimpleWeb {
acceptor->close(ec);
for(auto &pair : endpoint) {
- std::lock_guard<std::mutex> lock(pair.second.connections_mutex);
+ std::unique_lock<std::mutex> lock(pair.second.connections_mutex);
for(auto &connection : pair.second.connections)
connection->close();
pair.second.connections.clear();
@@ -409,7 +418,7 @@ namespace SimpleWeb {
std::unordered_set<std::shared_ptr<Connection>> get_connections() noexcept {
std::unordered_set<std::shared_ptr<Connection>> all_connections;
for(auto &e : endpoint) {
- std::lock_guard<std::mutex> lock(e.second.connections_mutex);
+ std::unique_lock<std::mutex> lock(e.second.connections_mutex);
all_connections.insert(e.second.connections.begin(), e.second.connections.end());
}
return all_connections;
@@ -435,6 +444,7 @@ namespace SimpleWeb {
* }
*/
void upgrade(const std::shared_ptr<Connection> &connection) {
+ connection->handler_runner = handler_runner;
connection->timeout_idle = config.timeout_idle;
write_handshake(connection);
}
@@ -448,7 +458,9 @@ namespace SimpleWeb {
std::unique_ptr<asio::ip::tcp::acceptor> acceptor;
std::vector<std::thread> threads;
- SocketServerBase(unsigned short port) noexcept : config(port) {}
+ std::shared_ptr<ScopeRunner> handler_runner;
+
+ SocketServerBase(unsigned short port) noexcept : config(port), handler_runner(new ScopeRunner()) {}
virtual void accept() = 0;
@@ -458,6 +470,9 @@ namespace SimpleWeb {
connection->set_timeout(config.timeout_request);
asio::async_read_until(*connection->socket, connection->read_buffer, "\r\n\r\n", [this, connection](const error_code &ec, size_t /*bytes_transferred*/) {
connection->cancel_timeout();
+ auto lock = connection->handler_runner->continue_lock();
+ if(!lock)
+ return;
if(!ec) {
std::istream stream(&connection->read_buffer);
if(RequestMessage::parse(stream, connection->method, connection->path, connection->query_string, connection->http_version, connection->header))
@@ -477,6 +492,9 @@ namespace SimpleWeb {
connection->set_timeout(config.timeout_request);
asio::async_write(*connection->socket, *write_buffer, [this, connection, write_buffer, &regex_endpoint](const error_code &ec, size_t /*bytes_transferred*/) {
connection->cancel_timeout();
+ auto lock = connection->handler_runner->continue_lock();
+ if(!lock)
+ return;
if(!ec) {
connection_open(connection, regex_endpoint.second);
read_message(connection, regex_endpoint.second);
@@ -492,6 +510,9 @@ namespace SimpleWeb {
void read_message(const std::shared_ptr<Connection> &connection, Endpoint &endpoint) const {
asio::async_read(*connection->socket, connection->read_buffer, asio::transfer_exactly(2), [this, connection, &endpoint](const error_code &ec, size_t bytes_transferred) {
+ auto lock = connection->handler_runner->continue_lock();
+ if(!lock)
+ return;
if(!ec) {
if(bytes_transferred == 0) { // TODO: why does this happen sometimes?
read_message(connection, endpoint);
@@ -518,6 +539,9 @@ namespace SimpleWeb {
if(length == 126) {
// 2 next bytes is the size of content
asio::async_read(*connection->socket, connection->read_buffer, asio::transfer_exactly(2), [this, connection, &endpoint, fin_rsv_opcode](const error_code &ec, size_t /*bytes_transferred*/) {
+ auto lock = connection->handler_runner->continue_lock();
+ if(!lock)
+ return;
if(!ec) {
std::istream stream(&connection->read_buffer);
@@ -539,6 +563,9 @@ namespace SimpleWeb {
else if(length == 127) {
// 8 next bytes is the size of content
asio::async_read(*connection->socket, connection->read_buffer, asio::transfer_exactly(8), [this, connection, &endpoint, fin_rsv_opcode](const error_code &ec, size_t /*bytes_transferred*/) {
+ auto lock = connection->handler_runner->continue_lock();
+ if(!lock)
+ return;
if(!ec) {
std::istream stream(&connection->read_buffer);
@@ -567,6 +594,9 @@ namespace SimpleWeb {
void read_message_content(const std::shared_ptr<Connection> &connection, size_t length, Endpoint &endpoint, unsigned char fin_rsv_opcode) const {
asio::async_read(*connection->socket, connection->read_buffer, asio::transfer_exactly(4 + length), [this, connection, length, &endpoint, fin_rsv_opcode](const error_code &ec, size_t /*bytes_transferred*/) {
+ auto lock = connection->handler_runner->continue_lock();
+ if(!lock)
+ return;
if(!ec) {
std::istream raw_message_data(&connection->read_buffer);
@@ -625,7 +655,7 @@ namespace SimpleWeb {
connection->set_timeout();
{
- std::lock_guard<std::mutex> lock(endpoint.connections_mutex);
+ std::unique_lock<std::mutex> lock(endpoint.connections_mutex);
endpoint.connections.insert(connection);
}
@@ -638,7 +668,7 @@ namespace SimpleWeb {
connection->set_timeout();
{
- std::lock_guard<std::mutex> lock(endpoint.connections_mutex);
+ std::unique_lock<std::mutex> lock(endpoint.connections_mutex);
endpoint.connections.erase(connection);
}
@@ -651,7 +681,7 @@ namespace SimpleWeb {
connection->set_timeout();
{
- std::lock_guard<std::mutex> lock(endpoint.connections_mutex);
+ std::unique_lock<std::mutex> lock(endpoint.connections_mutex);
endpoint.connections.erase(connection);
}
@@ -672,9 +702,12 @@ namespace SimpleWeb {
protected:
void accept() override {
- std::shared_ptr<Connection> connection(new Connection(config.timeout_idle, *io_service));
+ std::shared_ptr<Connection> connection(new Connection(handler_runner, config.timeout_idle, *io_service));
acceptor->async_accept(*connection->socket, [this, connection](const error_code &ec) {
+ auto lock = connection->handler_runner->continue_lock();
+ if(!lock)
+ return;
// Immediately start accepting a new connection (if io_service hasn't been stopped)
if(ec != asio::error::operation_aborted)
accept();
diff --git a/server_wss.hpp b/server_wss.hpp
index 8a2e6e8..23d9463 100644
--- a/server_wss.hpp
+++ b/server_wss.hpp
@@ -49,9 +49,12 @@ namespace SimpleWeb {
asio::ssl::context context;
void accept() override {
- std::shared_ptr<Connection> connection(new Connection(config.timeout_idle, *io_service, context));
+ std::shared_ptr<Connection> connection(new Connection(handler_runner, config.timeout_idle, *io_service, context));
acceptor->async_accept(connection->socket->lowest_layer(), [this, connection](const error_code &ec) {
+ auto lock = connection->handler_runner->continue_lock();
+ if(!lock)
+ return;
// Immediately start accepting a new connection (if io_service hasn't been stopped)
if(ec != asio::error::operation_aborted)
accept();
@@ -62,6 +65,9 @@ namespace SimpleWeb {
connection->set_timeout(config.timeout_request);
connection->socket->async_handshake(asio::ssl::stream_base::server, [this, connection](const error_code &ec) {
+ auto lock = connection->handler_runner->continue_lock();
+ if(!lock)
+ return;
connection->cancel_timeout();
if(!ec)
read_handshake(connection);
diff --git a/tests/parse_test.cpp b/tests/parse_test.cpp
index 6689773..9295ed1 100644
--- a/tests/parse_test.cpp
+++ b/tests/parse_test.cpp
@@ -13,7 +13,7 @@ public:
void accept() {}
void parse_request_test() {
- std::shared_ptr<Connection> connection(new Connection(0, *io_service));
+ std::shared_ptr<Connection> connection(new Connection(handler_runner, 0, *io_service));
ostream ss(&connection->read_buffer);
ss << "GET /test/ HTTP/1.1\r\n";
@@ -82,7 +82,7 @@ public:
}
void parse_response_header_test() {
- auto connection = std::shared_ptr<Connection>(new Connection(*io_service));
+ auto connection = std::shared_ptr<Connection>(new Connection(handler_runner, *io_service));
connection->message = std::shared_ptr<Message>(new Message());
ostream stream(&connection->message->streambuf);