Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJacob Vosmaer <jacob@gitlab.com>2021-09-15 16:35:34 +0300
committerJacob Vosmaer <jacob@gitlab.com>2021-09-30 18:27:04 +0300
commita12b72cf06b17416b36102f4b7b096df1405b896 (patch)
treef731823e851c269fb9c93a1028aa4689432f54fc /internal
parent098ee38438822b534723e76bad0bc3a5ba8baadb (diff)
sidechannel: add proxy middleware
This adds gRPC client interceptors that can proxy sidechannels. This is meant to be used in Praefect. Changelog: other
Diffstat (limited to 'internal')
-rw-r--r--internal/sidechannel/proxy.go126
-rw-r--r--internal/sidechannel/proxy_test.go193
-rw-r--r--internal/sidechannel/registry.go3
-rw-r--r--internal/sidechannel/sidechannel_test.go9
4 files changed, 324 insertions, 7 deletions
diff --git a/internal/sidechannel/proxy.go b/internal/sidechannel/proxy.go
new file mode 100644
index 000000000..2302d44eb
--- /dev/null
+++ b/internal/sidechannel/proxy.go
@@ -0,0 +1,126 @@
+package sidechannel
+
+import (
+ "context"
+ "fmt"
+ "io"
+
+ "gitlab.com/gitlab-org/gitaly/v14/internal/metadata"
+ "google.golang.org/grpc"
+ grpcMetadata "google.golang.org/grpc/metadata"
+)
+
+// NewUnaryProxy creates a gRPC client middleware that proxies sidechannels.
+func NewUnaryProxy(registry *Registry) grpc.UnaryClientInterceptor {
+ return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
+ if !hasSidechannelMetadata(ctx) {
+ return invoker(ctx, method, req, reply, cc, opts...)
+ }
+
+ ctx, waiter := RegisterSidechannel(ctx, registry, proxy(ctx))
+ defer waiter.Close()
+
+ if err := invoker(ctx, method, req, reply, cc, opts...); err != nil {
+ return err
+ }
+ if err := waiter.Close(); err != nil && err != ErrCallbackDidNotRun {
+ return fmt.Errorf("sidechannel: proxy callback: %w", err)
+ }
+ return nil
+ }
+}
+
+// NewStreamProxy creates a gRPC client middleware that proxies sidechannels.
+func NewStreamProxy(registry *Registry) grpc.StreamClientInterceptor {
+ return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
+ if !hasSidechannelMetadata(ctx) {
+ return streamer(ctx, desc, cc, method, opts...)
+ }
+
+ ctx, waiter := RegisterSidechannel(ctx, registry, proxy(ctx))
+ go func() {
+ <-ctx.Done()
+ // The Close() error is checked and bubbled up in
+ // streamWrapper.RecvMsg(). This call is just for cleanup.
+ _ = waiter.Close()
+ }()
+
+ cs, err := streamer(ctx, desc, cc, method, opts...)
+ if err != nil {
+ return nil, err
+ }
+
+ return &streamWrapper{ClientStream: cs, waiter: waiter}, nil
+ }
+}
+
+type streamWrapper struct {
+ grpc.ClientStream
+ waiter *Waiter
+}
+
+func (sw *streamWrapper) RecvMsg(m interface{}) error {
+ if err := sw.ClientStream.RecvMsg(m); err != io.EOF {
+ return err
+ }
+
+ if err := sw.waiter.Close(); err != nil && err != ErrCallbackDidNotRun {
+ return fmt.Errorf("sidechannel: proxy callback: %w", err)
+ }
+
+ return io.EOF
+}
+
+func hasSidechannelMetadata(ctx context.Context) bool {
+ md, ok := grpcMetadata.FromOutgoingContext(ctx)
+ return ok && len(md.Get(sidechannelMetadataKey)) > 0
+}
+
+func proxy(ctx context.Context) func(*ClientConn) error {
+ return func(upstream *ClientConn) error {
+ downstream, err := OpenSidechannel(metadata.OutgoingToIncoming(ctx))
+ if err != nil {
+ return err
+ }
+ defer downstream.Close()
+
+ const nStreams = 2
+ errC := make(chan error, nStreams)
+
+ go func() {
+ errC <- func() error {
+ if _, err := io.Copy(upstream, downstream); err != nil {
+ return err
+ }
+
+ // Downstream.Read() has returned EOF. That means we are done proxying
+ // the request body from downstream to upstream. Propagate this EOF to
+ // upstream by calling CloseWrite(). Use CloseWrite(), not Close(),
+ // because we still want to read the response body from upstream in the
+ // other goroutine.
+ return upstream.CloseWrite()
+ }()
+ }()
+
+ go func() {
+ errC <- func() error {
+ if _, err := io.Copy(downstream, upstream); err != nil {
+ return err
+ }
+
+ // Upstream is now closed for both reads and writes. Propagate this state
+ // to downstream. This also happens via defer, but this way we can log
+ // the Close error if there is one.
+ return downstream.Close()
+ }()
+ }()
+
+ for i := 0; i < nStreams; i++ {
+ if err := <-errC; err != nil {
+ return err
+ }
+ }
+
+ return nil
+ }
+}
diff --git a/internal/sidechannel/proxy_test.go b/internal/sidechannel/proxy_test.go
new file mode 100644
index 000000000..1501a02d3
--- /dev/null
+++ b/internal/sidechannel/proxy_test.go
@@ -0,0 +1,193 @@
+package sidechannel
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net"
+ "testing"
+
+ "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/backchannel"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/listenmux"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/metadata"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
+ "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials/insecure"
+ healthpb "google.golang.org/grpc/health/grpc_health_v1"
+)
+
+func testProxyServer(ctx context.Context) (err error) {
+ conn, err := OpenSidechannel(ctx)
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ buf, err := io.ReadAll(conn)
+ if err != nil {
+ return fmt.Errorf("server read: %w", err)
+ }
+ if string(buf) != "hello" {
+ return fmt.Errorf("server: unexpected request: %q", buf)
+ }
+
+ if _, err := io.WriteString(conn, "world"); err != nil {
+ return fmt.Errorf("server write: %w", err)
+ }
+ if err := conn.Close(); err != nil {
+ return fmt.Errorf("server close: %w", err)
+ }
+
+ return nil
+}
+
+func testProxyClient(conn *ClientConn) (err error) {
+ if _, err := io.WriteString(conn, "hello"); err != nil {
+ return fmt.Errorf("client write: %w", err)
+ }
+ if err := conn.CloseWrite(); err != nil {
+ return err
+ }
+
+ buf, err := io.ReadAll(conn)
+ if err != nil {
+ return fmt.Errorf("client read: %w", err)
+ }
+ if string(buf) != "world" {
+ return fmt.Errorf("client: unexpected response: %q", buf)
+ }
+
+ return nil
+}
+
+func TestUnaryProxy(t *testing.T) {
+ upstreamAddr := startServer(
+ t,
+ func(ctx context.Context, request *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) {
+ if err := testProxyServer(ctx); err != nil {
+ return nil, err
+ }
+ return &healthpb.HealthCheckResponse{}, nil
+ },
+ )
+
+ proxyAddr := startServer(
+ t,
+ func(ctx context.Context, request *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) {
+ conn, err := dialProxy(upstreamAddr)
+ if err != nil {
+ return nil, err
+ }
+ defer conn.Close()
+
+ ctxOut := metadata.IncomingToOutgoing(ctx)
+ return healthpb.NewHealthClient(conn).Check(ctxOut, request)
+ },
+ )
+
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ conn, registry := dial(t, proxyAddr)
+ require.NoError(t, call(ctx, conn, registry, testProxyClient))
+}
+
+func newLogger() *logrus.Entry { return logrus.NewEntry(logrus.New()) }
+
+func dialProxy(upstreamAddr string) (*grpc.ClientConn, error) {
+ registry := NewRegistry()
+ factory := func() backchannel.Server {
+ lm := listenmux.New(insecure.NewCredentials())
+ lm.Register(NewServerHandshaker(registry))
+ return grpc.NewServer(grpc.Creds(lm))
+ }
+
+ clientHandshaker := backchannel.NewClientHandshaker(newLogger(), factory)
+ dialOpts := []grpc.DialOption{
+ grpc.WithTransportCredentials(clientHandshaker.ClientHandshake(insecure.NewCredentials())),
+ grpc.WithUnaryInterceptor(NewUnaryProxy(registry)),
+ grpc.WithStreamInterceptor(NewStreamProxy(registry)),
+ }
+
+ return grpc.Dial(upstreamAddr, dialOpts...)
+}
+
+func TestStreamProxy(t *testing.T) {
+ upstreamAddr := startStreamServer(
+ t,
+ func(stream gitalypb.SSHService_SSHUploadPackServer) error {
+ return testProxyServer(stream.Context())
+ },
+ )
+
+ proxyAddr := startStreamServer(
+ t,
+ func(stream gitalypb.SSHService_SSHUploadPackServer) error {
+ conn, err := dialProxy(upstreamAddr)
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ ctxOut := metadata.IncomingToOutgoing(stream.Context())
+ client, err := gitalypb.NewSSHServiceClient(conn).SSHUploadPack(ctxOut)
+ if err != nil {
+ return err
+ }
+
+ if _, err := client.Recv(); err != io.EOF {
+ return fmt.Errorf("grpc proxy recv: %w", err)
+ }
+
+ return nil
+ },
+ )
+
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ conn, registry := dial(t, proxyAddr)
+ ctx, waiter := RegisterSidechannel(ctx, registry, testProxyClient)
+ defer waiter.Close()
+
+ client, err := gitalypb.NewSSHServiceClient(conn).SSHUploadPack(ctx)
+ require.NoError(t, err)
+
+ _, err = client.Recv()
+ require.Equal(t, io.EOF, err)
+
+ require.NoError(t, waiter.Close())
+}
+
+type mockSSHService struct {
+ sshUploadPackFunc func(gitalypb.SSHService_SSHUploadPackServer) error
+ gitalypb.UnimplementedSSHServiceServer
+}
+
+func (m mockSSHService) SSHUploadPack(stream gitalypb.SSHService_SSHUploadPackServer) error {
+ return m.sshUploadPackFunc(stream)
+}
+
+func startStreamServer(t *testing.T, handler func(gitalypb.SSHService_SSHUploadPackServer) error) string {
+ t.Helper()
+
+ lm := listenmux.New(insecure.NewCredentials())
+ lm.Register(backchannel.NewServerHandshaker(
+ newLogger(), backchannel.NewRegistry(), nil,
+ ))
+
+ srv := grpc.NewServer(grpc.Creds(lm))
+ gitalypb.RegisterSSHServiceServer(srv, &mockSSHService{
+ sshUploadPackFunc: handler,
+ })
+
+ ln, err := net.Listen("tcp", "localhost:0")
+ require.NoError(t, err)
+
+ t.Cleanup(srv.Stop)
+ go srv.Serve(ln)
+ return ln.Addr().String()
+}
diff --git a/internal/sidechannel/registry.go b/internal/sidechannel/registry.go
index 9a5145f1c..3148cb05d 100644
--- a/internal/sidechannel/registry.go
+++ b/internal/sidechannel/registry.go
@@ -105,6 +105,9 @@ func (s *Registry) waiting() int {
return len(s.waiters)
}
+// ErrCallbackDidNotRun indicates that a sidechannel callback was
+// de-registered without having run. This can happen if the server chose
+// not to use the sidechannel.
var ErrCallbackDidNotRun = errors.New("sidechannel: callback de-registered without having run")
func (w *Waiter) run() error {
diff --git a/internal/sidechannel/sidechannel_test.go b/internal/sidechannel/sidechannel_test.go
index 3ad925e61..bcbd0ce7c 100644
--- a/internal/sidechannel/sidechannel_test.go
+++ b/internal/sidechannel/sidechannel_test.go
@@ -9,7 +9,6 @@ import (
"sync"
"testing"
- "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitaly/v14/internal/backchannel"
"gitlab.com/gitlab-org/gitaly/v14/internal/listenmux"
@@ -145,10 +144,8 @@ func TestSidechannelConcurrency(t *testing.T) {
func startServer(t *testing.T, th testHandler, opts ...grpc.ServerOption) string {
t.Helper()
- logger := logrus.NewEntry(logrus.New())
-
lm := listenmux.New(insecure.NewCredentials())
- lm.Register(backchannel.NewServerHandshaker(logger, backchannel.NewRegistry(), nil))
+ lm.Register(backchannel.NewServerHandshaker(newLogger(), backchannel.NewRegistry(), nil))
opts = append(opts, grpc.Creds(lm))
@@ -169,15 +166,13 @@ func startServer(t *testing.T, th testHandler, opts ...grpc.ServerOption) string
func dial(t *testing.T, addr string) (*grpc.ClientConn, *Registry) {
registry := NewRegistry()
- logger := logrus.NewEntry(logrus.New())
-
factory := func() backchannel.Server {
lm := listenmux.New(insecure.NewCredentials())
lm.Register(NewServerHandshaker(registry))
return grpc.NewServer(grpc.Creds(lm))
}
- clientHandshaker := backchannel.NewClientHandshaker(logger, factory)
+ clientHandshaker := backchannel.NewClientHandshaker(newLogger(), factory)
dialOpt := grpc.WithTransportCredentials(clientHandshaker.ClientHandshake(insecure.NewCredentials()))
conn, err := grpc.Dial(addr, dialOpt)