diff options
author | Jacob Vosmaer <jacob@gitlab.com> | 2022-01-12 19:55:43 +0300 |
---|---|---|
committer | Jacob Vosmaer <jacob@gitlab.com> | 2022-01-17 16:12:05 +0300 |
commit | 11fa95f708b9cb9775e934fb48db267505728a96 (patch) | |
tree | 33df77692d89c754eab203ef9840b04b617ca0a9 | |
parent | 4b4a2e960636d518c5c223f241d5780624740e2c (diff) |
sidechannel: proxy: allow early upstream return
Before this change, it was a proxy error if a sidechannel upstream
closed its connection without reading from downstream until EOF first.
But it turns out we need to support that use case.
This commit changes the proxy control flow so that it is not an error
if the upstream closes its connection before the downstream has been
fully consumed.
Changelog: fixed
-rw-r--r-- | internal/sidechannel/proxy.go | 51 | ||||
-rw-r--r-- | internal/sidechannel/proxy_test.go | 66 |
2 files changed, 74 insertions, 43 deletions
diff --git a/internal/sidechannel/proxy.go b/internal/sidechannel/proxy.go index 2302d44eb..db184ff88 100644 --- a/internal/sidechannel/proxy.go +++ b/internal/sidechannel/proxy.go @@ -84,43 +84,54 @@ func proxy(ctx context.Context) func(*ClientConn) error { } defer downstream.Close() - const nStreams = 2 - errC := make(chan error, nStreams) - + fromDownstream := make(chan error, 1) go func() { - errC <- func() error { + fromDownstream <- func() error { if _, err := io.Copy(upstream, downstream); err != nil { - return err + return fmt.Errorf("copy to upstream: %w", err) + } + + if err := upstream.CloseWrite(); err != nil { + return fmt.Errorf("closewrite upstream: %w", err) } - // Downstream.Read() has returned EOF. That means we are done proxying - // the request body from downstream to upstream. Propagate this EOF to - // upstream by calling CloseWrite(). Use CloseWrite(), not Close(), - // because we still want to read the response body from upstream in the - // other goroutine. - return upstream.CloseWrite() + return nil }() }() + fromUpstream := make(chan error, 1) go func() { - errC <- func() error { + fromUpstream <- func() error { if _, err := io.Copy(downstream, upstream); err != nil { - return err + return fmt.Errorf("copy to downstream: %w", err) } - // Upstream is now closed for both reads and writes. Propagate this state - // to downstream. This also happens via defer, but this way we can log - // the Close error if there is one. - return downstream.Close() + return nil }() }() - for i := 0; i < nStreams; i++ { - if err := <-errC; err != nil { - return err + waitForUpstream: + for { + select { + case err := <-fromUpstream: + if err != nil { + return err + } + + break waitForUpstream + case err := <-fromDownstream: + if err != nil { + return err + } + case <-ctx.Done(): + return ctx.Err() } } + if err := downstream.Close(); err != nil { + return fmt.Errorf("close downstream: %w", err) + } + return nil } } diff --git a/internal/sidechannel/proxy_test.go b/internal/sidechannel/proxy_test.go index 1501a02d3..103ecfd63 100644 --- a/internal/sidechannel/proxy_test.go +++ b/internal/sidechannel/proxy_test.go @@ -19,17 +19,25 @@ import ( healthpb "google.golang.org/grpc/health/grpc_health_v1" ) -func testProxyServer(ctx context.Context) (err error) { +func testProxyServer(ctx context.Context, expectEOF bool) (err error) { conn, err := OpenSidechannel(ctx) if err != nil { return err } defer conn.Close() - buf, err := io.ReadAll(conn) + var buf []byte + if expectEOF { + buf, err = io.ReadAll(conn) + } else { + buf = make([]byte, 5) + _, err = conn.Read(buf) + } + if err != nil { return fmt.Errorf("server read: %w", err) } + if string(buf) != "hello" { return fmt.Errorf("server: unexpected request: %q", buf) } @@ -44,30 +52,38 @@ func testProxyServer(ctx context.Context) (err error) { return nil } -func testProxyClient(conn *ClientConn) (err error) { - if _, err := io.WriteString(conn, "hello"); err != nil { - return fmt.Errorf("client write: %w", err) - } - if err := conn.CloseWrite(); err != nil { - return err - } +func testProxyClient(closeWrite bool) func(*ClientConn) error { + return func(conn *ClientConn) (err error) { + if _, err := io.WriteString(conn, "hello"); err != nil { + return fmt.Errorf("client write: %w", err) + } + if closeWrite { + if err := conn.CloseWrite(); err != nil { + return err + } + } - buf, err := io.ReadAll(conn) - if err != nil { - return fmt.Errorf("client read: %w", err) - } - if string(buf) != "world" { - return fmt.Errorf("client: unexpected response: %q", buf) - } + buf, err := io.ReadAll(conn) + if err != nil { + return fmt.Errorf("client read: %w", err) + } + if string(buf) != "world" { + return fmt.Errorf("client: unexpected response: %q", buf) + } - return nil + return nil + } } -func TestUnaryProxy(t *testing.T) { +func TestUnaryProxy(t *testing.T) { testUnaryProxy(t, true) } + +func TestUnaryProxy_withoutCloseWrite(t *testing.T) { testUnaryProxy(t, false) } + +func testUnaryProxy(t *testing.T, closeWrite bool) { upstreamAddr := startServer( t, func(ctx context.Context, request *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) { - if err := testProxyServer(ctx); err != nil { + if err := testProxyServer(ctx, closeWrite); err != nil { return nil, err } return &healthpb.HealthCheckResponse{}, nil @@ -92,7 +108,7 @@ func TestUnaryProxy(t *testing.T) { defer cancel() conn, registry := dial(t, proxyAddr) - require.NoError(t, call(ctx, conn, registry, testProxyClient)) + require.NoError(t, call(ctx, conn, registry, testProxyClient(closeWrite))) } func newLogger() *logrus.Entry { return logrus.NewEntry(logrus.New()) } @@ -115,11 +131,15 @@ func dialProxy(upstreamAddr string) (*grpc.ClientConn, error) { return grpc.Dial(upstreamAddr, dialOpts...) } -func TestStreamProxy(t *testing.T) { +func TestStreamProxy(t *testing.T) { testStreamProxy(t, true) } + +func TestStreamProxy_noCloseWrite(t *testing.T) { testStreamProxy(t, false) } + +func testStreamProxy(t *testing.T, closeWrite bool) { upstreamAddr := startStreamServer( t, func(stream gitalypb.SSHService_SSHUploadPackServer) error { - return testProxyServer(stream.Context()) + return testProxyServer(stream.Context(), closeWrite) }, ) @@ -150,7 +170,7 @@ func TestStreamProxy(t *testing.T) { defer cancel() conn, registry := dial(t, proxyAddr) - ctx, waiter := RegisterSidechannel(ctx, registry, testProxyClient) + ctx, waiter := RegisterSidechannel(ctx, registry, testProxyClient(closeWrite)) defer waiter.Close() client, err := gitalypb.NewSSHServiceClient(conn).SSHUploadPack(ctx) |