Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJacob Vosmaer <jacob@gitlab.com>2022-01-12 19:55:43 +0300
committerJacob Vosmaer <jacob@gitlab.com>2022-01-17 16:12:05 +0300
commit11fa95f708b9cb9775e934fb48db267505728a96 (patch)
tree33df77692d89c754eab203ef9840b04b617ca0a9
parent4b4a2e960636d518c5c223f241d5780624740e2c (diff)
sidechannel: proxy: allow early upstream return
Before this change, it was a proxy error if a sidechannel upstream closed its connection without reading from downstream until EOF first. But it turns out we need to support that use case. This commit changes the proxy control flow so that it is not an error if the upstream closes its connection before the downstream has been fully consumed. Changelog: fixed
-rw-r--r--internal/sidechannel/proxy.go51
-rw-r--r--internal/sidechannel/proxy_test.go66
2 files changed, 74 insertions, 43 deletions
diff --git a/internal/sidechannel/proxy.go b/internal/sidechannel/proxy.go
index 2302d44eb..db184ff88 100644
--- a/internal/sidechannel/proxy.go
+++ b/internal/sidechannel/proxy.go
@@ -84,43 +84,54 @@ func proxy(ctx context.Context) func(*ClientConn) error {
}
defer downstream.Close()
- const nStreams = 2
- errC := make(chan error, nStreams)
-
+ fromDownstream := make(chan error, 1)
go func() {
- errC <- func() error {
+ fromDownstream <- func() error {
if _, err := io.Copy(upstream, downstream); err != nil {
- return err
+ return fmt.Errorf("copy to upstream: %w", err)
+ }
+
+ if err := upstream.CloseWrite(); err != nil {
+ return fmt.Errorf("closewrite upstream: %w", err)
}
- // Downstream.Read() has returned EOF. That means we are done proxying
- // the request body from downstream to upstream. Propagate this EOF to
- // upstream by calling CloseWrite(). Use CloseWrite(), not Close(),
- // because we still want to read the response body from upstream in the
- // other goroutine.
- return upstream.CloseWrite()
+ return nil
}()
}()
+ fromUpstream := make(chan error, 1)
go func() {
- errC <- func() error {
+ fromUpstream <- func() error {
if _, err := io.Copy(downstream, upstream); err != nil {
- return err
+ return fmt.Errorf("copy to downstream: %w", err)
}
- // Upstream is now closed for both reads and writes. Propagate this state
- // to downstream. This also happens via defer, but this way we can log
- // the Close error if there is one.
- return downstream.Close()
+ return nil
}()
}()
- for i := 0; i < nStreams; i++ {
- if err := <-errC; err != nil {
- return err
+ waitForUpstream:
+ for {
+ select {
+ case err := <-fromUpstream:
+ if err != nil {
+ return err
+ }
+
+ break waitForUpstream
+ case err := <-fromDownstream:
+ if err != nil {
+ return err
+ }
+ case <-ctx.Done():
+ return ctx.Err()
}
}
+ if err := downstream.Close(); err != nil {
+ return fmt.Errorf("close downstream: %w", err)
+ }
+
return nil
}
}
diff --git a/internal/sidechannel/proxy_test.go b/internal/sidechannel/proxy_test.go
index 1501a02d3..103ecfd63 100644
--- a/internal/sidechannel/proxy_test.go
+++ b/internal/sidechannel/proxy_test.go
@@ -19,17 +19,25 @@ import (
healthpb "google.golang.org/grpc/health/grpc_health_v1"
)
-func testProxyServer(ctx context.Context) (err error) {
+func testProxyServer(ctx context.Context, expectEOF bool) (err error) {
conn, err := OpenSidechannel(ctx)
if err != nil {
return err
}
defer conn.Close()
- buf, err := io.ReadAll(conn)
+ var buf []byte
+ if expectEOF {
+ buf, err = io.ReadAll(conn)
+ } else {
+ buf = make([]byte, 5)
+ _, err = conn.Read(buf)
+ }
+
if err != nil {
return fmt.Errorf("server read: %w", err)
}
+
if string(buf) != "hello" {
return fmt.Errorf("server: unexpected request: %q", buf)
}
@@ -44,30 +52,38 @@ func testProxyServer(ctx context.Context) (err error) {
return nil
}
-func testProxyClient(conn *ClientConn) (err error) {
- if _, err := io.WriteString(conn, "hello"); err != nil {
- return fmt.Errorf("client write: %w", err)
- }
- if err := conn.CloseWrite(); err != nil {
- return err
- }
+func testProxyClient(closeWrite bool) func(*ClientConn) error {
+ return func(conn *ClientConn) (err error) {
+ if _, err := io.WriteString(conn, "hello"); err != nil {
+ return fmt.Errorf("client write: %w", err)
+ }
+ if closeWrite {
+ if err := conn.CloseWrite(); err != nil {
+ return err
+ }
+ }
- buf, err := io.ReadAll(conn)
- if err != nil {
- return fmt.Errorf("client read: %w", err)
- }
- if string(buf) != "world" {
- return fmt.Errorf("client: unexpected response: %q", buf)
- }
+ buf, err := io.ReadAll(conn)
+ if err != nil {
+ return fmt.Errorf("client read: %w", err)
+ }
+ if string(buf) != "world" {
+ return fmt.Errorf("client: unexpected response: %q", buf)
+ }
- return nil
+ return nil
+ }
}
-func TestUnaryProxy(t *testing.T) {
+func TestUnaryProxy(t *testing.T) { testUnaryProxy(t, true) }
+
+func TestUnaryProxy_withoutCloseWrite(t *testing.T) { testUnaryProxy(t, false) }
+
+func testUnaryProxy(t *testing.T, closeWrite bool) {
upstreamAddr := startServer(
t,
func(ctx context.Context, request *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) {
- if err := testProxyServer(ctx); err != nil {
+ if err := testProxyServer(ctx, closeWrite); err != nil {
return nil, err
}
return &healthpb.HealthCheckResponse{}, nil
@@ -92,7 +108,7 @@ func TestUnaryProxy(t *testing.T) {
defer cancel()
conn, registry := dial(t, proxyAddr)
- require.NoError(t, call(ctx, conn, registry, testProxyClient))
+ require.NoError(t, call(ctx, conn, registry, testProxyClient(closeWrite)))
}
func newLogger() *logrus.Entry { return logrus.NewEntry(logrus.New()) }
@@ -115,11 +131,15 @@ func dialProxy(upstreamAddr string) (*grpc.ClientConn, error) {
return grpc.Dial(upstreamAddr, dialOpts...)
}
-func TestStreamProxy(t *testing.T) {
+func TestStreamProxy(t *testing.T) { testStreamProxy(t, true) }
+
+func TestStreamProxy_noCloseWrite(t *testing.T) { testStreamProxy(t, false) }
+
+func testStreamProxy(t *testing.T, closeWrite bool) {
upstreamAddr := startStreamServer(
t,
func(stream gitalypb.SSHService_SSHUploadPackServer) error {
- return testProxyServer(stream.Context())
+ return testProxyServer(stream.Context(), closeWrite)
},
)
@@ -150,7 +170,7 @@ func TestStreamProxy(t *testing.T) {
defer cancel()
conn, registry := dial(t, proxyAddr)
- ctx, waiter := RegisterSidechannel(ctx, registry, testProxyClient)
+ ctx, waiter := RegisterSidechannel(ctx, registry, testProxyClient(closeWrite))
defer waiter.Close()
client, err := gitalypb.NewSSHServiceClient(conn).SSHUploadPack(ctx)