diff options
author | Jacob Vosmaer <jacob@gitlab.com> | 2021-09-15 16:35:34 +0300 |
---|---|---|
committer | Jacob Vosmaer <jacob@gitlab.com> | 2021-09-30 18:27:04 +0300 |
commit | a12b72cf06b17416b36102f4b7b096df1405b896 (patch) | |
tree | f731823e851c269fb9c93a1028aa4689432f54fc /internal | |
parent | 098ee38438822b534723e76bad0bc3a5ba8baadb (diff) |
sidechannel: add proxy middleware
This adds gRPC client interceptors that can proxy sidechannels. This
is meant to be used in Praefect.
Changelog: other
Diffstat (limited to 'internal')
-rw-r--r-- | internal/sidechannel/proxy.go | 126 | ||||
-rw-r--r-- | internal/sidechannel/proxy_test.go | 193 | ||||
-rw-r--r-- | internal/sidechannel/registry.go | 3 | ||||
-rw-r--r-- | internal/sidechannel/sidechannel_test.go | 9 |
4 files changed, 324 insertions, 7 deletions
diff --git a/internal/sidechannel/proxy.go b/internal/sidechannel/proxy.go new file mode 100644 index 000000000..2302d44eb --- /dev/null +++ b/internal/sidechannel/proxy.go @@ -0,0 +1,126 @@ +package sidechannel + +import ( + "context" + "fmt" + "io" + + "gitlab.com/gitlab-org/gitaly/v14/internal/metadata" + "google.golang.org/grpc" + grpcMetadata "google.golang.org/grpc/metadata" +) + +// NewUnaryProxy creates a gRPC client middleware that proxies sidechannels. +func NewUnaryProxy(registry *Registry) grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + if !hasSidechannelMetadata(ctx) { + return invoker(ctx, method, req, reply, cc, opts...) + } + + ctx, waiter := RegisterSidechannel(ctx, registry, proxy(ctx)) + defer waiter.Close() + + if err := invoker(ctx, method, req, reply, cc, opts...); err != nil { + return err + } + if err := waiter.Close(); err != nil && err != ErrCallbackDidNotRun { + return fmt.Errorf("sidechannel: proxy callback: %w", err) + } + return nil + } +} + +// NewStreamProxy creates a gRPC client middleware that proxies sidechannels. +func NewStreamProxy(registry *Registry) grpc.StreamClientInterceptor { + return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + if !hasSidechannelMetadata(ctx) { + return streamer(ctx, desc, cc, method, opts...) + } + + ctx, waiter := RegisterSidechannel(ctx, registry, proxy(ctx)) + go func() { + <-ctx.Done() + // The Close() error is checked and bubbled up in + // streamWrapper.RecvMsg(). This call is just for cleanup. + _ = waiter.Close() + }() + + cs, err := streamer(ctx, desc, cc, method, opts...) + if err != nil { + return nil, err + } + + return &streamWrapper{ClientStream: cs, waiter: waiter}, nil + } +} + +type streamWrapper struct { + grpc.ClientStream + waiter *Waiter +} + +func (sw *streamWrapper) RecvMsg(m interface{}) error { + if err := sw.ClientStream.RecvMsg(m); err != io.EOF { + return err + } + + if err := sw.waiter.Close(); err != nil && err != ErrCallbackDidNotRun { + return fmt.Errorf("sidechannel: proxy callback: %w", err) + } + + return io.EOF +} + +func hasSidechannelMetadata(ctx context.Context) bool { + md, ok := grpcMetadata.FromOutgoingContext(ctx) + return ok && len(md.Get(sidechannelMetadataKey)) > 0 +} + +func proxy(ctx context.Context) func(*ClientConn) error { + return func(upstream *ClientConn) error { + downstream, err := OpenSidechannel(metadata.OutgoingToIncoming(ctx)) + if err != nil { + return err + } + defer downstream.Close() + + const nStreams = 2 + errC := make(chan error, nStreams) + + go func() { + errC <- func() error { + if _, err := io.Copy(upstream, downstream); err != nil { + return err + } + + // Downstream.Read() has returned EOF. That means we are done proxying + // the request body from downstream to upstream. Propagate this EOF to + // upstream by calling CloseWrite(). Use CloseWrite(), not Close(), + // because we still want to read the response body from upstream in the + // other goroutine. + return upstream.CloseWrite() + }() + }() + + go func() { + errC <- func() error { + if _, err := io.Copy(downstream, upstream); err != nil { + return err + } + + // Upstream is now closed for both reads and writes. Propagate this state + // to downstream. This also happens via defer, but this way we can log + // the Close error if there is one. + return downstream.Close() + }() + }() + + for i := 0; i < nStreams; i++ { + if err := <-errC; err != nil { + return err + } + } + + return nil + } +} diff --git a/internal/sidechannel/proxy_test.go b/internal/sidechannel/proxy_test.go new file mode 100644 index 000000000..1501a02d3 --- /dev/null +++ b/internal/sidechannel/proxy_test.go @@ -0,0 +1,193 @@ +package sidechannel + +import ( + "context" + "fmt" + "io" + "net" + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitaly/v14/internal/backchannel" + "gitlab.com/gitlab-org/gitaly/v14/internal/listenmux" + "gitlab.com/gitlab-org/gitaly/v14/internal/metadata" + "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper" + "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + healthpb "google.golang.org/grpc/health/grpc_health_v1" +) + +func testProxyServer(ctx context.Context) (err error) { + conn, err := OpenSidechannel(ctx) + if err != nil { + return err + } + defer conn.Close() + + buf, err := io.ReadAll(conn) + if err != nil { + return fmt.Errorf("server read: %w", err) + } + if string(buf) != "hello" { + return fmt.Errorf("server: unexpected request: %q", buf) + } + + if _, err := io.WriteString(conn, "world"); err != nil { + return fmt.Errorf("server write: %w", err) + } + if err := conn.Close(); err != nil { + return fmt.Errorf("server close: %w", err) + } + + return nil +} + +func testProxyClient(conn *ClientConn) (err error) { + if _, err := io.WriteString(conn, "hello"); err != nil { + return fmt.Errorf("client write: %w", err) + } + if err := conn.CloseWrite(); err != nil { + return err + } + + buf, err := io.ReadAll(conn) + if err != nil { + return fmt.Errorf("client read: %w", err) + } + if string(buf) != "world" { + return fmt.Errorf("client: unexpected response: %q", buf) + } + + return nil +} + +func TestUnaryProxy(t *testing.T) { + upstreamAddr := startServer( + t, + func(ctx context.Context, request *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) { + if err := testProxyServer(ctx); err != nil { + return nil, err + } + return &healthpb.HealthCheckResponse{}, nil + }, + ) + + proxyAddr := startServer( + t, + func(ctx context.Context, request *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) { + conn, err := dialProxy(upstreamAddr) + if err != nil { + return nil, err + } + defer conn.Close() + + ctxOut := metadata.IncomingToOutgoing(ctx) + return healthpb.NewHealthClient(conn).Check(ctxOut, request) + }, + ) + + ctx, cancel := testhelper.Context() + defer cancel() + + conn, registry := dial(t, proxyAddr) + require.NoError(t, call(ctx, conn, registry, testProxyClient)) +} + +func newLogger() *logrus.Entry { return logrus.NewEntry(logrus.New()) } + +func dialProxy(upstreamAddr string) (*grpc.ClientConn, error) { + registry := NewRegistry() + factory := func() backchannel.Server { + lm := listenmux.New(insecure.NewCredentials()) + lm.Register(NewServerHandshaker(registry)) + return grpc.NewServer(grpc.Creds(lm)) + } + + clientHandshaker := backchannel.NewClientHandshaker(newLogger(), factory) + dialOpts := []grpc.DialOption{ + grpc.WithTransportCredentials(clientHandshaker.ClientHandshake(insecure.NewCredentials())), + grpc.WithUnaryInterceptor(NewUnaryProxy(registry)), + grpc.WithStreamInterceptor(NewStreamProxy(registry)), + } + + return grpc.Dial(upstreamAddr, dialOpts...) +} + +func TestStreamProxy(t *testing.T) { + upstreamAddr := startStreamServer( + t, + func(stream gitalypb.SSHService_SSHUploadPackServer) error { + return testProxyServer(stream.Context()) + }, + ) + + proxyAddr := startStreamServer( + t, + func(stream gitalypb.SSHService_SSHUploadPackServer) error { + conn, err := dialProxy(upstreamAddr) + if err != nil { + return err + } + defer conn.Close() + + ctxOut := metadata.IncomingToOutgoing(stream.Context()) + client, err := gitalypb.NewSSHServiceClient(conn).SSHUploadPack(ctxOut) + if err != nil { + return err + } + + if _, err := client.Recv(); err != io.EOF { + return fmt.Errorf("grpc proxy recv: %w", err) + } + + return nil + }, + ) + + ctx, cancel := testhelper.Context() + defer cancel() + + conn, registry := dial(t, proxyAddr) + ctx, waiter := RegisterSidechannel(ctx, registry, testProxyClient) + defer waiter.Close() + + client, err := gitalypb.NewSSHServiceClient(conn).SSHUploadPack(ctx) + require.NoError(t, err) + + _, err = client.Recv() + require.Equal(t, io.EOF, err) + + require.NoError(t, waiter.Close()) +} + +type mockSSHService struct { + sshUploadPackFunc func(gitalypb.SSHService_SSHUploadPackServer) error + gitalypb.UnimplementedSSHServiceServer +} + +func (m mockSSHService) SSHUploadPack(stream gitalypb.SSHService_SSHUploadPackServer) error { + return m.sshUploadPackFunc(stream) +} + +func startStreamServer(t *testing.T, handler func(gitalypb.SSHService_SSHUploadPackServer) error) string { + t.Helper() + + lm := listenmux.New(insecure.NewCredentials()) + lm.Register(backchannel.NewServerHandshaker( + newLogger(), backchannel.NewRegistry(), nil, + )) + + srv := grpc.NewServer(grpc.Creds(lm)) + gitalypb.RegisterSSHServiceServer(srv, &mockSSHService{ + sshUploadPackFunc: handler, + }) + + ln, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + + t.Cleanup(srv.Stop) + go srv.Serve(ln) + return ln.Addr().String() +} diff --git a/internal/sidechannel/registry.go b/internal/sidechannel/registry.go index 9a5145f1c..3148cb05d 100644 --- a/internal/sidechannel/registry.go +++ b/internal/sidechannel/registry.go @@ -105,6 +105,9 @@ func (s *Registry) waiting() int { return len(s.waiters) } +// ErrCallbackDidNotRun indicates that a sidechannel callback was +// de-registered without having run. This can happen if the server chose +// not to use the sidechannel. var ErrCallbackDidNotRun = errors.New("sidechannel: callback de-registered without having run") func (w *Waiter) run() error { diff --git a/internal/sidechannel/sidechannel_test.go b/internal/sidechannel/sidechannel_test.go index 3ad925e61..bcbd0ce7c 100644 --- a/internal/sidechannel/sidechannel_test.go +++ b/internal/sidechannel/sidechannel_test.go @@ -9,7 +9,6 @@ import ( "sync" "testing" - "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitaly/v14/internal/backchannel" "gitlab.com/gitlab-org/gitaly/v14/internal/listenmux" @@ -145,10 +144,8 @@ func TestSidechannelConcurrency(t *testing.T) { func startServer(t *testing.T, th testHandler, opts ...grpc.ServerOption) string { t.Helper() - logger := logrus.NewEntry(logrus.New()) - lm := listenmux.New(insecure.NewCredentials()) - lm.Register(backchannel.NewServerHandshaker(logger, backchannel.NewRegistry(), nil)) + lm.Register(backchannel.NewServerHandshaker(newLogger(), backchannel.NewRegistry(), nil)) opts = append(opts, grpc.Creds(lm)) @@ -169,15 +166,13 @@ func startServer(t *testing.T, th testHandler, opts ...grpc.ServerOption) string func dial(t *testing.T, addr string) (*grpc.ClientConn, *Registry) { registry := NewRegistry() - logger := logrus.NewEntry(logrus.New()) - factory := func() backchannel.Server { lm := listenmux.New(insecure.NewCredentials()) lm.Register(NewServerHandshaker(registry)) return grpc.NewServer(grpc.Creds(lm)) } - clientHandshaker := backchannel.NewClientHandshaker(logger, factory) + clientHandshaker := backchannel.NewClientHandshaker(newLogger(), factory) dialOpt := grpc.WithTransportCredentials(clientHandshaker.ClientHandshake(insecure.NewCredentials())) conn, err := grpc.Dial(addr, dialOpt) |