Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/microsoft/vscode.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorConnor Peet <connor@peet.io>2022-11-13 01:40:34 +0300
committerGitHub <noreply@github.com>2022-11-13 01:40:34 +0300
commitaa60836c16abb03839962354138aed04a9344968 (patch)
treee1386246117ff69efb5e10cc22eda82733f27460
parent3b81dd8d48797b9d9caf20ff3a089df90008b0d2 (diff)
parent0f9c86ec840195600b1c50adf6738bb7b0d1d7c2 (diff)
Merge pull request #166139 from microsoft/connor4312/cli-server-message-compression
cli: enable server message compression
-rw-r--r--cli/src/constants.rs8
-rw-r--r--cli/src/tunnels.rs1
-rw-r--r--cli/src/tunnels/control_server.rs72
-rw-r--r--cli/src/tunnels/protocol.rs3
-rw-r--r--cli/src/tunnels/server_bridge_unix.rs31
-rw-r--r--cli/src/tunnels/server_bridge_windows.rs32
-rw-r--r--cli/src/tunnels/socket_signal.rs244
7 files changed, 318 insertions, 73 deletions
diff --git a/cli/src/constants.rs b/cli/src/constants.rs
index 1541a6a43a4..ab412fb860f 100644
--- a/cli/src/constants.rs
+++ b/cli/src/constants.rs
@@ -10,7 +10,13 @@ use lazy_static::lazy_static;
use crate::options::Quality;
pub const CONTROL_PORT: u16 = 31545;
-pub const PROTOCOL_VERSION: u32 = 1;
+
+/// Protocol version sent to clients. This can be used to indiciate new or
+/// changed capabilities that clients may wish to leverage.
+/// 1 - Initial protocol version
+/// 2 - Addition of `serve.compressed` property to control whether servermsg's
+/// are compressed bidirectionally.
+pub const PROTOCOL_VERSION: u32 = 2;
pub const VSCODE_CLI_VERSION: Option<&'static str> = option_env!("VSCODE_CLI_VERSION");
pub const VSCODE_CLI_AI_KEY: Option<&'static str> = option_env!("VSCODE_CLI_AI_KEY");
diff --git a/cli/src/tunnels.rs b/cli/src/tunnels.rs
index d94e47addf3..638432b040d 100644
--- a/cli/src/tunnels.rs
+++ b/cli/src/tunnels.rs
@@ -8,6 +8,7 @@ pub mod dev_tunnels;
pub mod legal;
pub mod paths;
+mod socket_signal;
mod control_server;
mod name_generator;
mod port_forwarder;
diff --git a/cli/src/tunnels/control_server.rs b/cli/src/tunnels/control_server.rs
index c04257f715f..5d62cd57d65 100644
--- a/cli/src/tunnels/control_server.rs
+++ b/cli/src/tunnels/control_server.rs
@@ -7,6 +7,8 @@ use crate::constants::{CONTROL_PORT, PROTOCOL_VERSION, VSCODE_CLI_VERSION};
use crate::log;
use crate::self_update::SelfUpdate;
use crate::state::LauncherPaths;
+use crate::tunnels::protocol::HttpRequestParams;
+use crate::tunnels::socket_signal::CloseReason;
use crate::update_service::{Platform, UpdateService};
use crate::util::errors::{
wrap, AnyError, MismatchedLaunchModeError, NoAttachedServerError, ServerWriteError,
@@ -18,7 +20,6 @@ use crate::util::io::SilentCopyProgress;
use crate::util::sync::{new_barrier, Barrier};
use opentelemetry::trace::SpanKind;
use opentelemetry::KeyValue;
-use serde::Serialize;
use std::collections::HashMap;
use std::convert::Infallible;
use std::env;
@@ -38,12 +39,12 @@ use super::paths::prune_stopped_servers;
use super::port_forwarder::{PortForwarding, PortForwardingProcessor};
use super::protocol::{
CallServerHttpParams, CallServerHttpResult, ClientRequestMethod, EmptyResult, ErrorResponse,
- ForwardParams, ForwardResult, GetHostnameResponse, HttpRequestParams, RefServerMessageParams,
- ResponseError, ServeParams, ServerLog, ServerMessageParams, ServerRequestMethod,
- SuccessResponse, ToClientRequest, ToServerRequest, UnforwardParams, UpdateParams, UpdateResult,
- VersionParams,
+ ForwardParams, ForwardResult, GetHostnameResponse, ResponseError, ServeParams, ServerLog,
+ ServerMessageParams, ServerRequestMethod, SuccessResponse, ToClientRequest, ToServerRequest,
+ UnforwardParams, UpdateParams, UpdateResult, VersionParams,
};
-use super::server_bridge::{get_socket_rw_stream, FromServerMessage, ServerBridge};
+use super::server_bridge::{get_socket_rw_stream, ServerBridge};
+use super::socket_signal::{ClientMessageDecoder, ServerMessageSink, SocketSignal};
type ServerBridgeList = Option<Vec<(u16, ServerBridge)>>;
type ServerBridgeListLock = Arc<Mutex<ServerBridgeList>>;
@@ -122,39 +123,6 @@ enum ServerSignal {
Respawn,
}
-struct CloseReason(String);
-
-enum SocketSignal {
- /// Signals bytes to send to the socket.
- Send(Vec<u8>),
- /// Closes the socket (e.g. as a result of an error)
- CloseWith(CloseReason),
- /// Disposes ServerBridge corresponding to an ID
- CloseServerBridge(u16),
-}
-
-impl SocketSignal {
- fn from_message<T>(msg: &T) -> Self
- where
- T: Serialize + ?Sized,
- {
- SocketSignal::Send(rmp_serde::to_vec_named(msg).unwrap())
- }
-}
-
-impl FromServerMessage for SocketSignal {
- fn from_server_message(i: u16, body: &[u8]) -> Self {
- SocketSignal::from_message(&ToClientRequest {
- id: None,
- params: ClientRequestMethod::servermsg(RefServerMessageParams { i, body }),
- })
- }
-
- fn from_closed_server_bridge(i: u16) -> Self {
- SocketSignal::CloseServerBridge(i)
- }
-}
-
pub struct ServerTermination {
/// Whether the server should be respawned in a new binary (see ServerSignal.Respawn).
pub respawn: bool,
@@ -719,7 +687,15 @@ async fn handle_serve(
}
};
- attach_server_bridge(&log, server, socket_tx, server_bridges, params.socket_id).await?;
+ attach_server_bridge(
+ &log,
+ server,
+ socket_tx,
+ server_bridges,
+ params.socket_id,
+ params.compress,
+ )
+ .await?;
Ok(EmptyResult {})
}
@@ -729,8 +705,22 @@ async fn attach_server_bridge(
socket_tx: mpsc::Sender<SocketSignal>,
server_bridges: ServerBridgeListLock,
socket_id: u16,
+ compress: bool,
) -> Result<u16, AnyError> {
- let attached_fut = ServerBridge::new(&code_server.socket, socket_id, &socket_tx).await;
+ let (server_messages, decoder) = if compress {
+ (
+ ServerMessageSink::new_compressed(socket_tx),
+ ClientMessageDecoder::new_compressed(),
+ )
+ } else {
+ (
+ ServerMessageSink::new_plain(socket_tx),
+ ClientMessageDecoder::new_plain(),
+ )
+ };
+
+ let attached_fut =
+ ServerBridge::new(&code_server.socket, socket_id, server_messages, decoder).await;
match attached_fut {
Ok(a) => {
diff --git a/cli/src/tunnels/protocol.rs b/cli/src/tunnels/protocol.rs
index 649b15fed39..5f8d904cbcd 100644
--- a/cli/src/tunnels/protocol.rs
+++ b/cli/src/tunnels/protocol.rs
@@ -91,6 +91,9 @@ pub struct ServeParams {
pub extensions: Vec<String>,
#[serde(default)]
pub use_local_download: bool,
+ /// If true, the client and server should gzip servermsg's sent in either direction.
+ #[serde(default)]
+ pub compress: bool,
}
#[derive(Deserialize, Serialize, Debug)]
diff --git a/cli/src/tunnels/server_bridge_unix.rs b/cli/src/tunnels/server_bridge_unix.rs
index f584ddfddc9..9f06223ccbb 100644
--- a/cli/src/tunnels/server_bridge_unix.rs
+++ b/cli/src/tunnels/server_bridge_unix.rs
@@ -7,18 +7,15 @@ use std::path::Path;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{unix::OwnedWriteHalf, UnixStream},
- sync::mpsc::Sender,
};
use crate::util::errors::{wrap, AnyError};
+use super::socket_signal::{ClientMessageDecoder, ServerMessageSink};
+
pub struct ServerBridge {
write: OwnedWriteHalf,
-}
-
-pub trait FromServerMessage {
- fn from_server_message(index: u16, message: &[u8]) -> Self;
- fn from_closed_server_bridge(i: u16) -> Self;
+ decoder: ClientMessageDecoder,
}
pub async fn get_socket_rw_stream(path: &Path) -> Result<UnixStream, AnyError> {
@@ -38,25 +35,26 @@ pub async fn get_socket_rw_stream(path: &Path) -> Result<UnixStream, AnyError> {
const BUFFER_SIZE: usize = 65536;
impl ServerBridge {
- pub async fn new<T>(path: &Path, index: u16, target: &Sender<T>) -> Result<Self, AnyError>
- where
- T: 'static + FromServerMessage + Send,
- {
+ pub async fn new(
+ path: &Path,
+ index: u16,
+ mut target: ServerMessageSink,
+ decoder: ClientMessageDecoder,
+ ) -> Result<Self, AnyError> {
let stream = get_socket_rw_stream(path).await?;
let (mut read, write) = stream.into_split();
- let tx = target.clone();
tokio::spawn(async move {
let mut read_buf = vec![0; BUFFER_SIZE];
loop {
match read.read(&mut read_buf).await {
Err(_) => return,
Ok(0) => {
- let _ = tx.send(T::from_closed_server_bridge(index)).await;
+ let _ = target.closed_server_bridge(index).await;
return; // EOF
}
Ok(s) => {
- let send = tx.send(T::from_server_message(index, &read_buf[..s])).await;
+ let send = target.server_message(index, &read_buf[..s]).await;
if send.is_err() {
return;
}
@@ -65,11 +63,14 @@ impl ServerBridge {
}
});
- Ok(ServerBridge { write })
+ Ok(ServerBridge { write, decoder })
}
pub async fn write(&mut self, b: Vec<u8>) -> std::io::Result<()> {
- self.write.write_all(&b).await?;
+ let dec = self.decoder.decode(&b)?;
+ if !dec.is_empty() {
+ self.write.write_all(dec).await?;
+ }
Ok(())
}
diff --git a/cli/src/tunnels/server_bridge_windows.rs b/cli/src/tunnels/server_bridge_windows.rs
index fb4b2b321f0..c7ac242fa6c 100644
--- a/cli/src/tunnels/server_bridge_windows.rs
+++ b/cli/src/tunnels/server_bridge_windows.rs
@@ -14,13 +14,11 @@ use tokio::{
use crate::util::errors::{wrap, AnyError};
+use super::socket_signal::{ClientMessageDecoder, ServerMessageSink};
+
pub struct ServerBridge {
write_tx: mpsc::Sender<Vec<u8>>,
-}
-
-pub trait FromServerMessage {
- fn from_server_message(index: u16, message: &[u8]) -> Self;
- fn from_closed_server_bridge(i: u16) -> Self;
+ decoder: ClientMessageDecoder,
}
const BUFFER_SIZE: usize = 65536;
@@ -49,13 +47,14 @@ pub async fn get_socket_rw_stream(path: &Path) -> Result<NamedPipeClient, AnyErr
}
impl ServerBridge {
- pub async fn new<T>(path: &Path, index: u16, target: &mpsc::Sender<T>) -> Result<Self, AnyError>
- where
- T: 'static + FromServerMessage + Send,
- {
+ pub async fn new(
+ path: &Path,
+ index: u16,
+ mut target: ServerMessageSink,
+ decoder: ClientMessageDecoder,
+ ) -> Result<Self, AnyError> {
let client = get_socket_rw_stream(path).await?;
let (write_tx, mut write_rx) = mpsc::channel(4);
- let read_tx = target.clone();
tokio::spawn(async move {
let mut read_buf = vec![0; BUFFER_SIZE];
let mut pending_recv: Option<Vec<u8>> = None;
@@ -89,9 +88,7 @@ impl ServerBridge {
match client.try_read(&mut read_buf) {
Ok(0) => return, // EOF
Ok(s) => {
- let send = read_tx
- .send(T::from_server_message(index, &read_buf[..s]))
- .await;
+ let send = target.server_message(index, &read_buf[..s]).await;
if send.is_err() {
return;
}
@@ -118,11 +115,14 @@ impl ServerBridge {
}
});
- Ok(ServerBridge { write_tx })
+ Ok(ServerBridge { write_tx, decoder })
}
- pub async fn write(&self, b: Vec<u8>) -> std::io::Result<()> {
- self.write_tx.send(b).await.ok();
+ pub async fn write(&mut self, b: Vec<u8>) -> std::io::Result<()> {
+ let dec = self.decoder.decode(&b)?;
+ if !dec.is_empty() {
+ self.write_tx.send(dec.to_vec()).await.ok();
+ }
Ok(())
}
diff --git a/cli/src/tunnels/socket_signal.rs b/cli/src/tunnels/socket_signal.rs
new file mode 100644
index 00000000000..95ed0bc3e0e
--- /dev/null
+++ b/cli/src/tunnels/socket_signal.rs
@@ -0,0 +1,244 @@
+/*---------------------------------------------------------------------------------------------
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License. See License.txt in the project root for license information.
+ *--------------------------------------------------------------------------------------------*/
+
+use serde::Serialize;
+use tokio::sync::mpsc;
+
+use super::protocol::{ClientRequestMethod, RefServerMessageParams, ToClientRequest};
+
+pub struct CloseReason(pub String);
+
+pub enum SocketSignal {
+ /// Signals bytes to send to the socket.
+ Send(Vec<u8>),
+ /// Closes the socket (e.g. as a result of an error)
+ CloseWith(CloseReason),
+ /// Disposes ServerBridge corresponding to an ID
+ CloseServerBridge(u16),
+}
+
+impl SocketSignal {
+ pub fn from_message<T>(msg: &T) -> Self
+ where
+ T: Serialize + ?Sized,
+ {
+ SocketSignal::Send(rmp_serde::to_vec_named(msg).unwrap())
+ }
+}
+
+/// Struct that handling sending or closing a connected server socket.
+pub struct ServerMessageSink {
+ tx: mpsc::Sender<SocketSignal>,
+ flate: Option<FlateStream<CompressFlateAlgorithm>>,
+}
+
+impl ServerMessageSink {
+ pub fn new_plain(tx: mpsc::Sender<SocketSignal>) -> Self {
+ Self { tx, flate: None }
+ }
+
+ pub fn new_compressed(tx: mpsc::Sender<SocketSignal>) -> Self {
+ Self {
+ tx,
+ flate: Some(FlateStream::new(CompressFlateAlgorithm(
+ flate2::Compress::new(flate2::Compression::new(2), false),
+ ))),
+ }
+ }
+
+ pub async fn server_message(
+ &mut self,
+ i: u16,
+ body: &[u8],
+ ) -> Result<(), mpsc::error::SendError<SocketSignal>> {
+ let msg = {
+ let body = self.get_server_msg_content(body);
+ SocketSignal::from_message(&ToClientRequest {
+ id: None,
+ params: ClientRequestMethod::servermsg(RefServerMessageParams { i, body }),
+ })
+ };
+
+ self.tx.send(msg).await
+ }
+
+ pub(crate) fn get_server_msg_content<'a: 'b, 'b>(&'a mut self, body: &'b [u8]) -> &'b [u8] {
+ if let Some(flate) = &mut self.flate {
+ if let Ok(compressed) = flate.process(body) {
+ return compressed;
+ }
+ }
+
+ body
+ }
+
+ #[allow(dead_code)]
+ pub async fn closed_server_bridge(
+ &mut self,
+ i: u16,
+ ) -> Result<(), mpsc::error::SendError<SocketSignal>> {
+ self.tx.send(SocketSignal::CloseServerBridge(i)).await
+ }
+}
+
+pub struct ClientMessageDecoder {
+ dec: Option<FlateStream<DecompressFlateAlgorithm>>,
+}
+
+impl ClientMessageDecoder {
+ pub fn new_plain() -> Self {
+ ClientMessageDecoder { dec: None }
+ }
+
+ pub fn new_compressed() -> Self {
+ ClientMessageDecoder {
+ dec: Some(FlateStream::new(DecompressFlateAlgorithm(
+ flate2::Decompress::new(false),
+ ))),
+ }
+ }
+
+ pub fn decode<'a: 'b, 'b>(&'a mut self, message: &'b [u8]) -> std::io::Result<&'b [u8]> {
+ match &mut self.dec {
+ Some(d) => d.process(message),
+ None => Ok(message),
+ }
+ }
+}
+
+trait FlateAlgorithm {
+ fn total_in(&self) -> u64;
+ fn total_out(&self) -> u64;
+ fn process(
+ &mut self,
+ contents: &[u8],
+ output: &mut [u8],
+ ) -> Result<flate2::Status, std::io::Error>;
+}
+
+struct DecompressFlateAlgorithm(flate2::Decompress);
+
+impl FlateAlgorithm for DecompressFlateAlgorithm {
+ fn total_in(&self) -> u64 {
+ self.0.total_in()
+ }
+
+ fn total_out(&self) -> u64 {
+ self.0.total_out()
+ }
+
+ fn process(
+ &mut self,
+ contents: &[u8],
+ output: &mut [u8],
+ ) -> Result<flate2::Status, std::io::Error> {
+ self.0
+ .decompress(contents, output, flate2::FlushDecompress::None)
+ .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
+ }
+}
+
+struct CompressFlateAlgorithm(flate2::Compress);
+
+impl FlateAlgorithm for CompressFlateAlgorithm {
+ fn total_in(&self) -> u64 {
+ self.0.total_in()
+ }
+
+ fn total_out(&self) -> u64 {
+ self.0.total_out()
+ }
+
+ fn process(
+ &mut self,
+ contents: &[u8],
+ output: &mut [u8],
+ ) -> Result<flate2::Status, std::io::Error> {
+ self.0
+ .compress(contents, output, flate2::FlushCompress::Sync)
+ .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
+ }
+}
+
+struct FlateStream<A>
+where
+ A: FlateAlgorithm,
+{
+ flate: A,
+ output: Vec<u8>,
+}
+
+impl<A> FlateStream<A>
+where
+ A: FlateAlgorithm,
+{
+ pub fn new(alg: A) -> Self {
+ Self {
+ flate: alg,
+ output: vec![0; 4096],
+ }
+ }
+
+ pub fn process(&mut self, contents: &[u8]) -> std::io::Result<&[u8]> {
+ let mut out_offset = 0;
+ let mut in_offset = 0;
+ loop {
+ let in_before = self.flate.total_in();
+ let out_before = self.flate.total_out();
+
+ match self
+ .flate
+ .process(&contents[in_offset..], &mut self.output[out_offset..])
+ {
+ Ok(flate2::Status::Ok | flate2::Status::BufError) => {
+ let processed_len = in_offset + (self.flate.total_in() - in_before) as usize;
+ let output_len = out_offset + (self.flate.total_out() - out_before) as usize;
+ if processed_len < contents.len() {
+ // If we filled the output buffer but there's more data to compress,
+ // extend the output buffer and keep compressing.
+ out_offset = output_len;
+ in_offset = processed_len;
+ if output_len == self.output.len() {
+ self.output.resize(self.output.len() * 2, 0);
+ }
+ continue;
+ }
+
+ return Ok(&self.output[..output_len]);
+ }
+ Ok(flate2::Status::StreamEnd) => {
+ return Err(std::io::Error::new(
+ std::io::ErrorKind::UnexpectedEof,
+ "unexpected stream end",
+ ))
+ }
+ Err(e) => return Err(e),
+ }
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ // Note this useful idiom: importing names from outer (for mod tests) scope.
+ use super::*;
+
+ #[test]
+ fn test_round_trips_compression() {
+ let (tx, _) = mpsc::channel(1);
+ let mut sink = ServerMessageSink::new_compressed(tx);
+ let mut decompress = ClientMessageDecoder::new_compressed();
+
+ // 3000 and 30000 test resizing the buffer
+ for msg_len in [3, 30, 300, 3000, 30000] {
+ let vals = (0..msg_len).map(|v| v as u8).collect::<Vec<u8>>();
+ let compressed = sink.get_server_msg_content(&vals);
+ assert_ne!(compressed, vals);
+ let decompressed = decompress.decode(compressed).unwrap();
+ assert_eq!(decompressed.len(), vals.len());
+ assert_eq!(decompressed, vals);
+ }
+ }
+}