diff options
author | eidheim <eidheim@gmail.com> | 2018-07-12 13:21:41 +0300 |
---|---|---|
committer | eidheim <eidheim@gmail.com> | 2018-07-12 13:21:48 +0300 |
commit | 40b19395793c43fae6c962b3ba236d6b154dc36b (patch) | |
tree | 4a364937d4f0b846feb4a990e1a86869735cfb52 /client_ws.hpp | |
parent | 611f17c7ceb9ae784cb5a7d5ddb1c2634b056560 (diff) |
Breaking change: SendStream renamed to OutMessage, and Message renamed to InMessage. Also added convencience functions for Connection::send that takes strings as arguments."
Diffstat (limited to 'client_ws.hpp')
-rw-r--r-- | client_ws.hpp | 192 |
1 files changed, 100 insertions, 92 deletions
diff --git a/client_ws.hpp b/client_ws.hpp index c8c569f..719ab14 100644 --- a/client_ws.hpp +++ b/client_ws.hpp @@ -38,7 +38,7 @@ namespace SimpleWeb { template <class socket_type> class SocketClientBase { public: - class Message : public std::istream { + class InMessage : public std::istream { friend class SocketClientBase<socket_type>; friend class Connection; @@ -63,20 +63,20 @@ namespace SimpleWeb { } private: - Message() noexcept : std::istream(&streambuf), length(0) {} - Message(unsigned char fin_rsv_opcode, std::size_t length) noexcept : std::istream(&streambuf), fin_rsv_opcode(fin_rsv_opcode), length(length) {} + InMessage() noexcept : std::istream(&streambuf), length(0) {} + InMessage(unsigned char fin_rsv_opcode, std::size_t length) noexcept : std::istream(&streambuf), fin_rsv_opcode(fin_rsv_opcode), length(length) {} std::size_t length; asio::streambuf streambuf; }; /// The buffer is consumed during send operations. - class SendStream : public std::iostream { + class OutMessage : public std::iostream { friend class SocketClientBase<socket_type>; asio::streambuf streambuf; public: - SendStream() noexcept : std::iostream(&streambuf) {} + OutMessage() noexcept : std::iostream(&streambuf) {} /// Returns the size of the buffer std::size_t size() const noexcept { @@ -117,8 +117,8 @@ namespace SimpleWeb { std::unique_ptr<socket_type> socket; // Socket must be unique_ptr since asio::ssl::stream<asio::ip::tcp::socket> is not movable std::mutex socket_close_mutex; - std::shared_ptr<Message> message; - std::shared_ptr<Message> fragmented_message; + std::shared_ptr<InMessage> in_message; + std::shared_ptr<InMessage> fragmented_in_message; long timeout_idle; std::unique_ptr<asio::steady_timer> timer; @@ -170,20 +170,20 @@ namespace SimpleWeb { asio::io_service::strand strand; - class SendData { + class OutData { public: - SendData(std::shared_ptr<SendStream> send_stream_, std::function<void(const error_code)> &&callback_) noexcept - : send_stream(std::move(send_stream_)), callback(std::move(callback_)) {} - std::shared_ptr<SendStream> send_stream; + OutData(std::shared_ptr<OutMessage> out_message_, std::function<void(const error_code)> &&callback_) noexcept + : out_message(std::move(out_message_)), callback(std::move(callback_)) {} + std::shared_ptr<OutMessage> out_message; std::function<void(const error_code)> callback; }; - std::list<SendData> send_queue; + std::list<OutData> send_queue; void send_from_queue() { auto self = this->shared_from_this(); strand.post([self]() { - asio::async_write(*self->socket, self->send_queue.begin()->send_stream->streambuf, self->strand.wrap([self](const error_code &ec, std::size_t /*bytes_transferred*/) { + asio::async_write(*self->socket, self->send_queue.begin()->out_message->streambuf, self->strand.wrap([self](const error_code &ec, std::size_t /*bytes_transferred*/) { auto lock = self->handler_runner->continue_lock(); if(!lock) return; @@ -197,9 +197,9 @@ namespace SimpleWeb { } else { // All handlers in the queue is called with ec: - for(auto &send_data : self->send_queue) { - if(send_data.callback) - send_data.callback(ec); + for(auto &out_data : self->send_queue) { + if(out_data.callback) + out_data.callback(ec); } self->send_queue.clear(); } @@ -220,9 +220,8 @@ namespace SimpleWeb { public: /// fin_rsv_opcode: 129=one fragment, text, 130=one fragment, binary, 136=close connection. - /// See http://tools.ietf.org/html/rfc6455#section-5.2 for more information - void send(const std::shared_ptr<SendStream> &send_stream, const std::function<void(const error_code &)> &callback = nullptr, - unsigned char fin_rsv_opcode = 129) { + /// See http://tools.ietf.org/html/rfc6455#section-5.2 for more information. + void send(const std::shared_ptr<OutMessage> &out_message, const std::function<void(const error_code &)> &callback = nullptr, unsigned char fin_rsv_opcode = 129) { cancel_timeout(); set_timeout(); @@ -233,58 +232,67 @@ namespace SimpleWeb { for(std::size_t c = 0; c < 4; c++) mask[c] = static_cast<unsigned char>(dist(rd)); - auto message_stream = std::make_shared<SendStream>(); + auto out_header_and_message = std::make_shared<OutMessage>(); - std::size_t length = send_stream->size(); + std::size_t length = out_message->size(); - message_stream->put(static_cast<char>(fin_rsv_opcode)); + out_header_and_message->put(static_cast<char>(fin_rsv_opcode)); // Masked (first length byte>=128) if(length >= 126) { std::size_t num_bytes; if(length > 0xffff) { num_bytes = 8; - message_stream->put(static_cast<char>(127 + 128)); + out_header_and_message->put(static_cast<char>(127 + 128)); } else { num_bytes = 2; - message_stream->put(static_cast<char>(126 + 128)); + out_header_and_message->put(static_cast<char>(126 + 128)); } for(std::size_t c = num_bytes - 1; c != static_cast<std::size_t>(-1); c--) - message_stream->put((static_cast<unsigned long long>(length) >> (8 * c)) % 256); + out_header_and_message->put((static_cast<unsigned long long>(length) >> (8 * c)) % 256); } else - message_stream->put(static_cast<char>(length + 128)); + out_header_and_message->put(static_cast<char>(length + 128)); for(std::size_t c = 0; c < 4; c++) - message_stream->put(static_cast<char>(mask[c])); + out_header_and_message->put(static_cast<char>(mask[c])); for(std::size_t c = 0; c < length; c++) - message_stream->put(send_stream->get() ^ mask[c % 4]); + out_header_and_message->put(out_message->get() ^ mask[c % 4]); auto self = this->shared_from_this(); - strand.post([self, message_stream, callback]() { - self->send_queue.emplace_back(message_stream, callback); + strand.post([self, out_header_and_message, callback]() { + self->send_queue.emplace_back(out_header_and_message, callback); if(self->send_queue.size() == 1) self->send_from_queue(); }); } + /// Convenience function for sending a string. + /// fin_rsv_opcode: 129=one fragment, text, 130=one fragment, binary, 136=close connection. + /// See http://tools.ietf.org/html/rfc6455#section-5.2 for more information. + void send(string_view out_message_str, const std::function<void(const error_code &)> &callback = nullptr, unsigned char fin_rsv_opcode = 129) { + auto out_message = std::make_shared<OutMessage>(); + out_message->write(out_message_str.data(), static_cast<std::streamsize>(out_message_str.size())); + send(out_message, callback, fin_rsv_opcode); + } + void send_close(int status, const std::string &reason = "", const std::function<void(const error_code &)> &callback = nullptr) { // Send close only once (in case close is initiated by client) if(closed) return; closed = true; - auto send_stream = std::make_shared<SendStream>(); + auto out_message = std::make_shared<OutMessage>(); - send_stream->put(status >> 8); - send_stream->put(status % 256); + out_message->put(status >> 8); + out_message->put(status % 256); - *send_stream << reason; + *out_message << reason; // fin_rsv_opcode=136: message close - send(send_stream, callback, 136); + send(out_message, callback, 136); } }; @@ -310,7 +318,7 @@ namespace SimpleWeb { Config config; std::function<void(std::shared_ptr<Connection>)> on_open; - std::function<void(std::shared_ptr<Connection>, std::shared_ptr<Message>)> on_message; + std::function<void(std::shared_ptr<Connection>, std::shared_ptr<InMessage>)> on_message; std::function<void(std::shared_ptr<Connection>, int, const std::string &)> on_close; std::function<void(std::shared_ptr<Connection>, const error_code &)> on_error; std::function<void(std::shared_ptr<Connection>)> on_ping; @@ -415,7 +423,7 @@ namespace SimpleWeb { request << header_field.first << ": " << header_field.second << "\r\n"; request << "\r\n"; - connection->message = std::shared_ptr<Message>(new Message()); + connection->in_message = std::shared_ptr<InMessage>(new InMessage()); connection->set_timeout(config.timeout_request); asio::async_write(*connection->socket, *write_buffer, [this, connection, write_buffer, nonce_base64](const error_code &ec, std::size_t /*bytes_transferred*/) { @@ -425,19 +433,19 @@ namespace SimpleWeb { return; if(!ec) { connection->set_timeout(this->config.timeout_request); - asio::async_read_until(*connection->socket, connection->message->streambuf, "\r\n\r\n", [this, connection, nonce_base64](const error_code &ec, std::size_t bytes_transferred) { + asio::async_read_until(*connection->socket, connection->in_message->streambuf, "\r\n\r\n", [this, connection, nonce_base64](const error_code &ec, std::size_t bytes_transferred) { connection->cancel_timeout(); auto lock = connection->handler_runner->continue_lock(); if(!lock) return; if(!ec) { - // connection->message->streambuf.size() is not necessarily the same as bytes_transferred, from Boost-docs: + // connection->in_message->streambuf.size() is not necessarily the same as bytes_transferred, from Boost-docs: // "After a successful async_read_until operation, the streambuf may contain additional data beyond the delimiter" // The chosen solution is to extract lines from the stream directly when parsing the header. What is left of the // streambuf (maybe some bytes of a message) is appended to in the next async_read-function - std::size_t num_additional_bytes = connection->message->streambuf.size() - bytes_transferred; + std::size_t num_additional_bytes = connection->in_message->streambuf.size() - bytes_transferred; - if(!ResponseMessage::parse(*connection->message, connection->http_version, connection->status_code, connection->header) || + if(!ResponseMessage::parse(*connection->in_message, connection->http_version, connection->status_code, connection->header) || connection->status_code.empty() || connection->status_code.compare(0, 4, "101 ") != 0) { this->connection_error(connection, make_error_code::make_error_code(errc::protocol_error)); return; @@ -462,21 +470,21 @@ namespace SimpleWeb { } void read_message(const std::shared_ptr<Connection> &connection, std::size_t num_additional_bytes) { - asio::async_read(*connection->socket, connection->message->streambuf, asio::transfer_exactly(num_additional_bytes > 2 ? 0 : 2 - num_additional_bytes), [this, connection](const error_code &ec, std::size_t bytes_transferred) { + asio::async_read(*connection->socket, connection->in_message->streambuf, asio::transfer_exactly(num_additional_bytes > 2 ? 0 : 2 - num_additional_bytes), [this, connection](const error_code &ec, std::size_t bytes_transferred) { auto lock = connection->handler_runner->continue_lock(); if(!lock) return; if(!ec) { - if(bytes_transferred == 0 && connection->message->streambuf.size() == 0) { // TODO: This might happen on server at least, might also happen here + if(bytes_transferred == 0 && connection->in_message->streambuf.size() == 0) { // TODO: This might happen on server at least, might also happen here this->read_message(connection, 0); return; } - std::size_t num_additional_bytes = connection->message->streambuf.size() - bytes_transferred; + std::size_t num_additional_bytes = connection->in_message->streambuf.size() - bytes_transferred; std::array<unsigned char, 2> first_bytes; - connection->message->read(reinterpret_cast<char *>(&first_bytes[0]), 2); + connection->in_message->read(reinterpret_cast<char *>(&first_bytes[0]), 2); - connection->message->fin_rsv_opcode = first_bytes[0]; + connection->in_message->fin_rsv_opcode = first_bytes[0]; // Close connection if masked message from server (protocol error) if(first_bytes[1] >= 128) { @@ -490,22 +498,22 @@ namespace SimpleWeb { if(length == 126) { // 2 next bytes is the size of content - asio::async_read(*connection->socket, connection->message->streambuf, asio::transfer_exactly(num_additional_bytes > 2 ? 0 : 2 - num_additional_bytes), [this, connection](const error_code &ec, std::size_t bytes_transferred) { + asio::async_read(*connection->socket, connection->in_message->streambuf, asio::transfer_exactly(num_additional_bytes > 2 ? 0 : 2 - num_additional_bytes), [this, connection](const error_code &ec, std::size_t bytes_transferred) { auto lock = connection->handler_runner->continue_lock(); if(!lock) return; if(!ec) { - std::size_t num_additional_bytes = connection->message->streambuf.size() - bytes_transferred; + std::size_t num_additional_bytes = connection->in_message->streambuf.size() - bytes_transferred; std::array<unsigned char, 2> length_bytes; - connection->message->read(reinterpret_cast<char *>(&length_bytes[0]), 2); + connection->in_message->read(reinterpret_cast<char *>(&length_bytes[0]), 2); std::size_t length = 0; std::size_t num_bytes = 2; for(std::size_t c = 0; c < num_bytes; c++) length += static_cast<std::size_t>(length_bytes[c]) << (8 * (num_bytes - 1 - c)); - connection->message->length = length; + connection->in_message->length = length; this->read_message_content(connection, num_additional_bytes); } else @@ -514,22 +522,22 @@ namespace SimpleWeb { } else if(length == 127) { // 8 next bytes is the size of content - asio::async_read(*connection->socket, connection->message->streambuf, asio::transfer_exactly(num_additional_bytes > 8 ? 0 : 8 - num_additional_bytes), [this, connection](const error_code &ec, std::size_t bytes_transferred) { + asio::async_read(*connection->socket, connection->in_message->streambuf, asio::transfer_exactly(num_additional_bytes > 8 ? 0 : 8 - num_additional_bytes), [this, connection](const error_code &ec, std::size_t bytes_transferred) { auto lock = connection->handler_runner->continue_lock(); if(!lock) return; if(!ec) { - std::size_t num_additional_bytes = connection->message->streambuf.size() - bytes_transferred; + std::size_t num_additional_bytes = connection->in_message->streambuf.size() - bytes_transferred; std::array<unsigned char, 8> length_bytes; - connection->message->read(reinterpret_cast<char *>(&length_bytes[0]), 8); + connection->in_message->read(reinterpret_cast<char *>(&length_bytes[0]), 8); std::size_t length = 0; std::size_t num_bytes = 8; for(std::size_t c = 0; c < num_bytes; c++) length += static_cast<std::size_t>(length_bytes[c]) << (8 * (num_bytes - 1 - c)); - connection->message->length = length; + connection->in_message->length = length; this->read_message_content(connection, num_additional_bytes); } else @@ -537,7 +545,7 @@ namespace SimpleWeb { }); } else { - connection->message->length = length; + connection->in_message->length = length; this->read_message_content(connection, num_additional_bytes); } } @@ -547,7 +555,7 @@ namespace SimpleWeb { } void read_message_content(const std::shared_ptr<Connection> &connection, std::size_t num_additional_bytes) { - if(connection->message->length + (connection->fragmented_message ? connection->fragmented_message->length : 0) > config.max_message_size) { + if(connection->in_message->length + (connection->fragmented_in_message ? connection->fragmented_in_message->length : 0) > config.max_message_size) { connection_error(connection, make_error_code::make_error_code(errc::message_size)); const int status = 1009; const std::string reason = "message too big"; @@ -555,57 +563,57 @@ namespace SimpleWeb { connection_close(connection, status, reason); return; } - asio::async_read(*connection->socket, connection->message->streambuf, asio::transfer_exactly(num_additional_bytes > connection->message->length ? 0 : connection->message->length - num_additional_bytes), [this, connection](const error_code &ec, std::size_t bytes_transferred) { + asio::async_read(*connection->socket, connection->in_message->streambuf, asio::transfer_exactly(num_additional_bytes > connection->in_message->length ? 0 : connection->in_message->length - num_additional_bytes), [this, connection](const error_code &ec, std::size_t bytes_transferred) { auto lock = connection->handler_runner->continue_lock(); if(!lock) return; if(!ec) { - std::size_t num_additional_bytes = connection->message->streambuf.size() - bytes_transferred; - std::shared_ptr<Message> next_message; + std::size_t num_additional_bytes = connection->in_message->streambuf.size() - bytes_transferred; + std::shared_ptr<InMessage> next_in_message; if(num_additional_bytes > 0) { // Extract bytes that are not extra bytes in buffer (only happen when several messages are sent in handshake response) - next_message = connection->message; - connection->message = std::shared_ptr<Message>(new Message(next_message->fin_rsv_opcode, next_message->length)); - std::ostream ostream(&connection->message->streambuf); - for(std::size_t c = 0; c < next_message->length; ++c) - ostream.put(next_message->get()); + next_in_message = connection->in_message; + connection->in_message = std::shared_ptr<InMessage>(new InMessage(next_in_message->fin_rsv_opcode, next_in_message->length)); + std::ostream ostream(&connection->in_message->streambuf); + for(std::size_t c = 0; c < next_in_message->length; ++c) + ostream.put(next_in_message->get()); } else - next_message = std::shared_ptr<Message>(new Message()); + next_in_message = std::shared_ptr<InMessage>(new InMessage()); // If connection close - if((connection->message->fin_rsv_opcode & 0x0f) == 8) { + if((connection->in_message->fin_rsv_opcode & 0x0f) == 8) { connection->cancel_timeout(); connection->set_timeout(); int status = 0; - if(connection->message->length >= 2) { - unsigned char byte1 = connection->message->get(); - unsigned char byte2 = connection->message->get(); + if(connection->in_message->length >= 2) { + unsigned char byte1 = connection->in_message->get(); + unsigned char byte2 = connection->in_message->get(); status = (static_cast<int>(byte1) << 8) + byte2; } - auto reason = connection->message->string(); + auto reason = connection->in_message->string(); connection->send_close(status, reason); this->connection_close(connection, status, reason); } // If ping - else if((connection->message->fin_rsv_opcode & 0x0f) == 9) { + else if((connection->in_message->fin_rsv_opcode & 0x0f) == 9) { connection->cancel_timeout(); connection->set_timeout(); // Send pong - auto empty_send_stream = std::make_shared<SendStream>(); - connection->send(empty_send_stream, nullptr, connection->message->fin_rsv_opcode + 1); + auto empty_out_message = std::make_shared<OutMessage>(); + connection->send(empty_out_message, nullptr, connection->in_message->fin_rsv_opcode + 1); if(this->on_ping) this->on_ping(connection); // Next message - connection->message = next_message; + connection->in_message = next_in_message; this->read_message(connection, num_additional_bytes); } // If pong - else if((connection->message->fin_rsv_opcode & 0x0f) == 10) { + else if((connection->in_message->fin_rsv_opcode & 0x0f) == 10) { connection->cancel_timeout(); connection->set_timeout(); @@ -613,23 +621,23 @@ namespace SimpleWeb { this->on_pong(connection); // Next message - connection->message = next_message; + connection->in_message = next_in_message; this->read_message(connection, num_additional_bytes); } // If fragmented message and not final fragment - else if((connection->message->fin_rsv_opcode & 0x80) == 0) { - if(!connection->fragmented_message) { - connection->fragmented_message = connection->message; - connection->fragmented_message->fin_rsv_opcode |= 0x80; + else if((connection->in_message->fin_rsv_opcode & 0x80) == 0) { + if(!connection->fragmented_in_message) { + connection->fragmented_in_message = connection->in_message; + connection->fragmented_in_message->fin_rsv_opcode |= 0x80; } else { - connection->fragmented_message->length += connection->message->length; - std::ostream ostream(&connection->fragmented_message->streambuf); - ostream << connection->message->rdbuf(); + connection->fragmented_in_message->length += connection->in_message->length; + std::ostream ostream(&connection->fragmented_in_message->streambuf); + ostream << connection->in_message->rdbuf(); } // Next message - connection->message = next_message; + connection->in_message = next_in_message; this->read_message(connection, num_additional_bytes); } else { @@ -637,21 +645,21 @@ namespace SimpleWeb { connection->set_timeout(); if(this->on_message) { - if(connection->fragmented_message) { - connection->fragmented_message->length += connection->message->length; - std::ostream ostream(&connection->fragmented_message->streambuf); - ostream << connection->message->rdbuf(); + if(connection->fragmented_in_message) { + connection->fragmented_in_message->length += connection->in_message->length; + std::ostream ostream(&connection->fragmented_in_message->streambuf); + ostream << connection->in_message->rdbuf(); - this->on_message(connection, connection->fragmented_message); + this->on_message(connection, connection->fragmented_in_message); } else - this->on_message(connection, connection->message); + this->on_message(connection, connection->in_message); } // Next message - connection->message = next_message; + connection->in_message = next_in_message; // Only reset fragmented_message for non-control frames (control frames can be in between a fragmented message) - connection->fragmented_message = nullptr; + connection->fragmented_in_message = nullptr; this->read_message(connection, num_additional_bytes); } } |