From ebfa4b0c3cf1d3cb87ba6e9a8d625eedf7b2ab47 Mon Sep 17 00:00:00 2001 From: Connor Peet Date: Fri, 11 Nov 2022 11:48:16 -0800 Subject: cli: enable server message compression This is the CLI side of enabling compression of servermsg's sent over the socket. It is feature-detected by the CLI sending protocolVersion=2. If present, the consumer can request compression by passing `compress:true` when setting up the server. In this mode, servermsg's are an inflate/deflate stream. Not a ton of code here, but was lots of fun tweaking to get it right :) Fixes https://github.com/microsoft/vscode/issues/163688 --- cli/src/constants.rs | 8 +- cli/src/tunnels.rs | 1 + cli/src/tunnels/control_server.rs | 72 ++++----- cli/src/tunnels/protocol.rs | 3 + cli/src/tunnels/server_bridge_unix.rs | 31 ++-- cli/src/tunnels/server_bridge_windows.rs | 30 ++-- cli/src/tunnels/socket_signal.rs | 244 +++++++++++++++++++++++++++++++ 7 files changed, 317 insertions(+), 72 deletions(-) create mode 100644 cli/src/tunnels/socket_signal.rs diff --git a/cli/src/constants.rs b/cli/src/constants.rs index 1541a6a43a4..ab412fb860f 100644 --- a/cli/src/constants.rs +++ b/cli/src/constants.rs @@ -10,7 +10,13 @@ use lazy_static::lazy_static; use crate::options::Quality; pub const CONTROL_PORT: u16 = 31545; -pub const PROTOCOL_VERSION: u32 = 1; + +/// Protocol version sent to clients. This can be used to indiciate new or +/// changed capabilities that clients may wish to leverage. +/// 1 - Initial protocol version +/// 2 - Addition of `serve.compressed` property to control whether servermsg's +/// are compressed bidirectionally. +pub const PROTOCOL_VERSION: u32 = 2; pub const VSCODE_CLI_VERSION: Option<&'static str> = option_env!("VSCODE_CLI_VERSION"); pub const VSCODE_CLI_AI_KEY: Option<&'static str> = option_env!("VSCODE_CLI_AI_KEY"); diff --git a/cli/src/tunnels.rs b/cli/src/tunnels.rs index d94e47addf3..638432b040d 100644 --- a/cli/src/tunnels.rs +++ b/cli/src/tunnels.rs @@ -8,6 +8,7 @@ pub mod dev_tunnels; pub mod legal; pub mod paths; +mod socket_signal; mod control_server; mod name_generator; mod port_forwarder; diff --git a/cli/src/tunnels/control_server.rs b/cli/src/tunnels/control_server.rs index c04257f715f..5d62cd57d65 100644 --- a/cli/src/tunnels/control_server.rs +++ b/cli/src/tunnels/control_server.rs @@ -7,6 +7,8 @@ use crate::constants::{CONTROL_PORT, PROTOCOL_VERSION, VSCODE_CLI_VERSION}; use crate::log; use crate::self_update::SelfUpdate; use crate::state::LauncherPaths; +use crate::tunnels::protocol::HttpRequestParams; +use crate::tunnels::socket_signal::CloseReason; use crate::update_service::{Platform, UpdateService}; use crate::util::errors::{ wrap, AnyError, MismatchedLaunchModeError, NoAttachedServerError, ServerWriteError, @@ -18,7 +20,6 @@ use crate::util::io::SilentCopyProgress; use crate::util::sync::{new_barrier, Barrier}; use opentelemetry::trace::SpanKind; use opentelemetry::KeyValue; -use serde::Serialize; use std::collections::HashMap; use std::convert::Infallible; use std::env; @@ -38,12 +39,12 @@ use super::paths::prune_stopped_servers; use super::port_forwarder::{PortForwarding, PortForwardingProcessor}; use super::protocol::{ CallServerHttpParams, CallServerHttpResult, ClientRequestMethod, EmptyResult, ErrorResponse, - ForwardParams, ForwardResult, GetHostnameResponse, HttpRequestParams, RefServerMessageParams, - ResponseError, ServeParams, ServerLog, ServerMessageParams, ServerRequestMethod, - SuccessResponse, ToClientRequest, ToServerRequest, UnforwardParams, UpdateParams, UpdateResult, - VersionParams, + ForwardParams, ForwardResult, GetHostnameResponse, ResponseError, ServeParams, ServerLog, + ServerMessageParams, ServerRequestMethod, SuccessResponse, ToClientRequest, ToServerRequest, + UnforwardParams, UpdateParams, UpdateResult, VersionParams, }; -use super::server_bridge::{get_socket_rw_stream, FromServerMessage, ServerBridge}; +use super::server_bridge::{get_socket_rw_stream, ServerBridge}; +use super::socket_signal::{ClientMessageDecoder, ServerMessageSink, SocketSignal}; type ServerBridgeList = Option>; type ServerBridgeListLock = Arc>; @@ -122,39 +123,6 @@ enum ServerSignal { Respawn, } -struct CloseReason(String); - -enum SocketSignal { - /// Signals bytes to send to the socket. - Send(Vec), - /// Closes the socket (e.g. as a result of an error) - CloseWith(CloseReason), - /// Disposes ServerBridge corresponding to an ID - CloseServerBridge(u16), -} - -impl SocketSignal { - fn from_message(msg: &T) -> Self - where - T: Serialize + ?Sized, - { - SocketSignal::Send(rmp_serde::to_vec_named(msg).unwrap()) - } -} - -impl FromServerMessage for SocketSignal { - fn from_server_message(i: u16, body: &[u8]) -> Self { - SocketSignal::from_message(&ToClientRequest { - id: None, - params: ClientRequestMethod::servermsg(RefServerMessageParams { i, body }), - }) - } - - fn from_closed_server_bridge(i: u16) -> Self { - SocketSignal::CloseServerBridge(i) - } -} - pub struct ServerTermination { /// Whether the server should be respawned in a new binary (see ServerSignal.Respawn). pub respawn: bool, @@ -719,7 +687,15 @@ async fn handle_serve( } }; - attach_server_bridge(&log, server, socket_tx, server_bridges, params.socket_id).await?; + attach_server_bridge( + &log, + server, + socket_tx, + server_bridges, + params.socket_id, + params.compress, + ) + .await?; Ok(EmptyResult {}) } @@ -729,8 +705,22 @@ async fn attach_server_bridge( socket_tx: mpsc::Sender, server_bridges: ServerBridgeListLock, socket_id: u16, + compress: bool, ) -> Result { - let attached_fut = ServerBridge::new(&code_server.socket, socket_id, &socket_tx).await; + let (server_messages, decoder) = if compress { + ( + ServerMessageSink::new_compressed(socket_tx), + ClientMessageDecoder::new_compressed(), + ) + } else { + ( + ServerMessageSink::new_plain(socket_tx), + ClientMessageDecoder::new_plain(), + ) + }; + + let attached_fut = + ServerBridge::new(&code_server.socket, socket_id, server_messages, decoder).await; match attached_fut { Ok(a) => { diff --git a/cli/src/tunnels/protocol.rs b/cli/src/tunnels/protocol.rs index 649b15fed39..5f8d904cbcd 100644 --- a/cli/src/tunnels/protocol.rs +++ b/cli/src/tunnels/protocol.rs @@ -91,6 +91,9 @@ pub struct ServeParams { pub extensions: Vec, #[serde(default)] pub use_local_download: bool, + /// If true, the client and server should gzip servermsg's sent in either direction. + #[serde(default)] + pub compress: bool, } #[derive(Deserialize, Serialize, Debug)] diff --git a/cli/src/tunnels/server_bridge_unix.rs b/cli/src/tunnels/server_bridge_unix.rs index f584ddfddc9..9f06223ccbb 100644 --- a/cli/src/tunnels/server_bridge_unix.rs +++ b/cli/src/tunnels/server_bridge_unix.rs @@ -7,18 +7,15 @@ use std::path::Path; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::{unix::OwnedWriteHalf, UnixStream}, - sync::mpsc::Sender, }; use crate::util::errors::{wrap, AnyError}; +use super::socket_signal::{ClientMessageDecoder, ServerMessageSink}; + pub struct ServerBridge { write: OwnedWriteHalf, -} - -pub trait FromServerMessage { - fn from_server_message(index: u16, message: &[u8]) -> Self; - fn from_closed_server_bridge(i: u16) -> Self; + decoder: ClientMessageDecoder, } pub async fn get_socket_rw_stream(path: &Path) -> Result { @@ -38,25 +35,26 @@ pub async fn get_socket_rw_stream(path: &Path) -> Result { const BUFFER_SIZE: usize = 65536; impl ServerBridge { - pub async fn new(path: &Path, index: u16, target: &Sender) -> Result - where - T: 'static + FromServerMessage + Send, - { + pub async fn new( + path: &Path, + index: u16, + mut target: ServerMessageSink, + decoder: ClientMessageDecoder, + ) -> Result { let stream = get_socket_rw_stream(path).await?; let (mut read, write) = stream.into_split(); - let tx = target.clone(); tokio::spawn(async move { let mut read_buf = vec![0; BUFFER_SIZE]; loop { match read.read(&mut read_buf).await { Err(_) => return, Ok(0) => { - let _ = tx.send(T::from_closed_server_bridge(index)).await; + let _ = target.closed_server_bridge(index).await; return; // EOF } Ok(s) => { - let send = tx.send(T::from_server_message(index, &read_buf[..s])).await; + let send = target.server_message(index, &read_buf[..s]).await; if send.is_err() { return; } @@ -65,11 +63,14 @@ impl ServerBridge { } }); - Ok(ServerBridge { write }) + Ok(ServerBridge { write, decoder }) } pub async fn write(&mut self, b: Vec) -> std::io::Result<()> { - self.write.write_all(&b).await?; + let dec = self.decoder.decode(&b)?; + if !dec.is_empty() { + self.write.write_all(dec).await?; + } Ok(()) } diff --git a/cli/src/tunnels/server_bridge_windows.rs b/cli/src/tunnels/server_bridge_windows.rs index fb4b2b321f0..3a4d911d53e 100644 --- a/cli/src/tunnels/server_bridge_windows.rs +++ b/cli/src/tunnels/server_bridge_windows.rs @@ -14,13 +14,11 @@ use tokio::{ use crate::util::errors::{wrap, AnyError}; +use super::socket_signal::{ClientMessageDecoder, ServerMessageSink}; + pub struct ServerBridge { write_tx: mpsc::Sender>, -} - -pub trait FromServerMessage { - fn from_server_message(index: u16, message: &[u8]) -> Self; - fn from_closed_server_bridge(i: u16) -> Self; + decoder: ClientMessageDecoder, } const BUFFER_SIZE: usize = 65536; @@ -49,13 +47,14 @@ pub async fn get_socket_rw_stream(path: &Path) -> Result(path: &Path, index: u16, target: &mpsc::Sender) -> Result - where - T: 'static + FromServerMessage + Send, - { + pub async fn new( + path: &Path, + index: u16, + mut target: ServerMessageSink, + decoder: ClientMessageDecoder, + ) -> Result { let client = get_socket_rw_stream(path).await?; let (write_tx, mut write_rx) = mpsc::channel(4); - let read_tx = target.clone(); tokio::spawn(async move { let mut read_buf = vec![0; BUFFER_SIZE]; let mut pending_recv: Option> = None; @@ -89,9 +88,7 @@ impl ServerBridge { match client.try_read(&mut read_buf) { Ok(0) => return, // EOF Ok(s) => { - let send = read_tx - .send(T::from_server_message(index, &read_buf[..s])) - .await; + let send = target.server_message(index, &read_buf[..s]).await; if send.is_err() { return; } @@ -118,11 +115,14 @@ impl ServerBridge { } }); - Ok(ServerBridge { write_tx }) + Ok(ServerBridge { write_tx, decoder }) } pub async fn write(&self, b: Vec) -> std::io::Result<()> { - self.write_tx.send(b).await.ok(); + let dec = self.decoder.decode(&b)?; + if !dec.is_empty() { + self.write_tx.send(dec.to_vec()).await.ok(); + } Ok(()) } diff --git a/cli/src/tunnels/socket_signal.rs b/cli/src/tunnels/socket_signal.rs new file mode 100644 index 00000000000..8a0d2a4eaf6 --- /dev/null +++ b/cli/src/tunnels/socket_signal.rs @@ -0,0 +1,244 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +use serde::Serialize; +use tokio::sync::mpsc; + +use super::protocol::{ClientRequestMethod, RefServerMessageParams, ToClientRequest}; + +pub struct CloseReason(pub String); + +pub enum SocketSignal { + /// Signals bytes to send to the socket. + Send(Vec), + /// Closes the socket (e.g. as a result of an error) + CloseWith(CloseReason), + /// Disposes ServerBridge corresponding to an ID + CloseServerBridge(u16), +} + +impl SocketSignal { + pub fn from_message(msg: &T) -> Self + where + T: Serialize + ?Sized, + { + SocketSignal::Send(rmp_serde::to_vec_named(msg).unwrap()) + } +} + +/// Struct that handling sending or closing a connected server socket. +pub struct ServerMessageSink { + tx: mpsc::Sender, + flate: Option>, +} + +impl ServerMessageSink { + pub fn new_plain(tx: mpsc::Sender) -> Self { + Self { tx, flate: None } + } + + pub fn new_compressed(tx: mpsc::Sender) -> Self { + Self { + tx, + flate: Some(FlateStream::new(CompressFlateAlgorithm( + flate2::Compress::new(flate2::Compression::new(2), false), + ))), + } + } + + pub async fn server_message( + &mut self, + i: u16, + body: &[u8], + ) -> Result<(), mpsc::error::SendError> { + let msg = { + let body = self.get_server_msg_content(body); + SocketSignal::from_message(&ToClientRequest { + id: None, + params: ClientRequestMethod::servermsg(RefServerMessageParams { i, body }), + }) + }; + + self.tx.send(msg).await + } + + pub(crate) fn get_server_msg_content<'a: 'b, 'b>(&'a mut self, body: &'b [u8]) -> &'b [u8] { + if let Some(flate) = &mut self.flate { + if let Ok(compressed) = flate.process(body) { + return compressed; + } + } + + body + } + + pub async fn closed_server_bridge( + &mut self, + i: u16, + ) -> Result<(), mpsc::error::SendError> { + self.tx.send(SocketSignal::CloseServerBridge(i)).await + } +} + +pub struct ClientMessageDecoder { + dec: Option>, +} + +impl ClientMessageDecoder { + pub fn new_plain() -> Self { + ClientMessageDecoder { dec: None } + } + + pub fn new_compressed() -> Self { + ClientMessageDecoder { + dec: Some(FlateStream::new(DecompressFlateAlgorithm( + flate2::Decompress::new(false), + ))), + } + } + + #[allow(dead_code)] + pub fn decode<'a: 'b, 'b>(&'a mut self, message: &'b [u8]) -> std::io::Result<&'b [u8]> { + match &mut self.dec { + Some(d) => d.process(message), + None => Ok(message), + } + } +} + +trait FlateAlgorithm { + fn total_in(&self) -> u64; + fn total_out(&self) -> u64; + fn process( + &mut self, + contents: &[u8], + output: &mut [u8], + ) -> Result; +} + +struct DecompressFlateAlgorithm(flate2::Decompress); + +impl FlateAlgorithm for DecompressFlateAlgorithm { + fn total_in(&self) -> u64 { + self.0.total_in() + } + + fn total_out(&self) -> u64 { + self.0.total_out() + } + + fn process( + &mut self, + contents: &[u8], + output: &mut [u8], + ) -> Result { + self.0 + .decompress(contents, output, flate2::FlushDecompress::None) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e)) + } +} + +struct CompressFlateAlgorithm(flate2::Compress); + +impl FlateAlgorithm for CompressFlateAlgorithm { + fn total_in(&self) -> u64 { + self.0.total_in() + } + + fn total_out(&self) -> u64 { + self.0.total_out() + } + + fn process( + &mut self, + contents: &[u8], + output: &mut [u8], + ) -> Result { + self.0 + .compress(contents, output, flate2::FlushCompress::Sync) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e)) + } +} + +struct FlateStream +where + A: FlateAlgorithm, +{ + flate: A, + output: Vec, +} + +impl FlateStream +where + A: FlateAlgorithm, +{ + pub fn new(alg: A) -> Self { + Self { + flate: alg, + output: vec![0; 4096], + } + } + + pub fn process(&mut self, contents: &[u8]) -> std::io::Result<&[u8]> { + let mut out_offset = 0; + let mut in_offset = 0; + loop { + let in_before = self.flate.total_in(); + let out_before = self.flate.total_out(); + + match self + .flate + .process(&contents[in_offset..], &mut self.output[out_offset..]) + { + Ok(flate2::Status::Ok | flate2::Status::BufError) => { + let processed_len = in_offset + (self.flate.total_in() - in_before) as usize; + let output_len = out_offset + (self.flate.total_out() - out_before) as usize; + if processed_len < contents.len() { + // If we filled the output buffer but there's more data to compress, + // extend the output buffer and keep compressing. + out_offset = output_len; + in_offset = processed_len; + if output_len == self.output.len() { + self.output.resize(self.output.len() * 2, 0); + } + continue; + } + + return Ok(&self.output[..output_len]); + } + Ok(flate2::Status::StreamEnd) => { + return Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "unexpected stream end", + )) + } + Err(e) => return Err(e), + } + } + } +} + +#[cfg(test)] +mod tests { + // Note this useful idiom: importing names from outer (for mod tests) scope. + use super::*; + + #[test] + fn test_round_trips_compression() { + let (tx, _) = mpsc::channel(1); + let mut sink = ServerMessageSink::new_compressed(tx); + let mut decompress = ClientMessageDecoder::new_compressed(); + + // 3000 and 30000 test resizing the buffer + for msg_len in [3, 30, 300, 3000, 30000] { + let vals = (0..msg_len).map(|v| v as u8).collect::>(); + let compressed = sink.get_server_msg_content(&vals); + assert_ne!(compressed, vals); + let decompressed = decompress.decode(compressed).unwrap(); + assert_eq!(decompressed.len(), vals.len()); + assert_eq!(decompressed, vals); + } + } +} -- cgit v1.2.3 From 0f9c86ec840195600b1c50adf6738bb7b0d1d7c2 Mon Sep 17 00:00:00 2001 From: Connor Peet Date: Fri, 11 Nov 2022 13:16:20 -0800 Subject: fixup! windows compilation --- cli/src/tunnels/server_bridge_windows.rs | 2 +- cli/src/tunnels/socket_signal.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cli/src/tunnels/server_bridge_windows.rs b/cli/src/tunnels/server_bridge_windows.rs index 3a4d911d53e..c7ac242fa6c 100644 --- a/cli/src/tunnels/server_bridge_windows.rs +++ b/cli/src/tunnels/server_bridge_windows.rs @@ -118,7 +118,7 @@ impl ServerBridge { Ok(ServerBridge { write_tx, decoder }) } - pub async fn write(&self, b: Vec) -> std::io::Result<()> { + pub async fn write(&mut self, b: Vec) -> std::io::Result<()> { let dec = self.decoder.decode(&b)?; if !dec.is_empty() { self.write_tx.send(dec.to_vec()).await.ok(); diff --git a/cli/src/tunnels/socket_signal.rs b/cli/src/tunnels/socket_signal.rs index 8a0d2a4eaf6..95ed0bc3e0e 100644 --- a/cli/src/tunnels/socket_signal.rs +++ b/cli/src/tunnels/socket_signal.rs @@ -74,6 +74,7 @@ impl ServerMessageSink { body } + #[allow(dead_code)] pub async fn closed_server_bridge( &mut self, i: u16, @@ -99,7 +100,6 @@ impl ClientMessageDecoder { } } - #[allow(dead_code)] pub fn decode<'a: 'b, 'b>(&'a mut self, message: &'b [u8]) -> std::io::Result<&'b [u8]> { match &mut self.dec { Some(d) => d.process(message), -- cgit v1.2.3