diff options
author | John Cai <jcai@gitlab.com> | 2021-10-05 19:53:29 +0300 |
---|---|---|
committer | John Cai <jcai@gitlab.com> | 2021-10-05 19:53:29 +0300 |
commit | 28f93e63a76c9e049dcf332f3546de86aac49181 (patch) | |
tree | 7dca3f005c3f96be4ec452930cda73fef1f3d2f5 | |
parent | c2cf30c9b0f88f03632ea9b12d4dff19f7f4850a (diff) | |
parent | e64a6c218dd5720789c82dd9deff3bc0f4212416 (diff) |
Merge branch 'jv-sidechannel-client' into 'master'
client: add sidechannel support
Closes gitlab-com/gl-infra/scalability#1303
See merge request gitlab-org/gitaly!3900
-rw-r--r-- | client/dial.go | 10 | ||||
-rw-r--r-- | client/dial_test.go | 153 | ||||
-rw-r--r-- | client/sidechannel.go | 90 | ||||
-rw-r--r-- | internal/gitaly/service/smarthttp/upload_pack_test.go | 12 | ||||
-rw-r--r-- | internal/sidechannel/sidechannel.go | 18 | ||||
-rw-r--r-- | internal/sidechannel/sidechannel_test.go | 8 |
6 files changed, 252 insertions, 39 deletions
diff --git a/client/dial.go b/client/dial.go index 77020bf6b..352003d09 100644 --- a/client/dial.go +++ b/client/dial.go @@ -4,6 +4,7 @@ import ( "context" "gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/client" + "gitlab.com/gitlab-org/gitaly/v14/internal/sidechannel" "google.golang.org/grpc" healthpb "google.golang.org/grpc/health/grpc_health_v1" ) @@ -30,6 +31,15 @@ func Dial(rawAddress string, connOpts []grpc.DialOption) (*grpc.ClientConn, erro return DialContext(context.Background(), rawAddress, connOpts) } +// DialSidechannel configures the dialer to establish a Gitaly +// backchannel connection instead of a regular gRPC connection. It also +// injects sr as a sidechannel registry, so that Gitaly can establish +// sidechannels back to the client. +func DialSidechannel(ctx context.Context, rawAddress string, sr *SidechannelRegistry, connOpts []grpc.DialOption) (*grpc.ClientConn, error) { + clientHandshaker := sidechannel.NewClientHandshaker(sr.logger, sr.registry) + return client.Dial(ctx, rawAddress, connOpts, clientHandshaker) +} + // FailOnNonTempDialError helps to identify if remote listener is ready to accept new connections. func FailOnNonTempDialError() []grpc.DialOption { return []grpc.DialOption{ diff --git a/client/dial_test.go b/client/dial_test.go index 5e448c78d..ca62bcfc2 100644 --- a/client/dial_test.go +++ b/client/dial_test.go @@ -12,6 +12,7 @@ import ( "testing" "github.com/opentracing/opentracing-go" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/uber/jaeger-client-go" @@ -26,6 +27,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" healthpb "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/status" ) @@ -37,7 +39,11 @@ func TestDial(t *testing.T) { t.Log("WARNING. Proxy configuration detected from environment settings. This test failure may be related to proxy configuration. Please process with caution") } - stop, connectionMap := startListeners(t) + stop, connectionMap := startListeners(t, func(creds credentials.TransportCredentials) *grpc.Server { + srv := grpc.NewServer(grpc.Creds(creds)) + healthpb.RegisterHealthServer(srv, &healthServer{}) + return srv + }) defer stop() unixSocketAbsPath := connectionMap["unix"] @@ -147,6 +153,110 @@ func TestDial(t *testing.T) { } } +func TestDialSidechannel(t *testing.T) { + if emitProxyWarning() { + t.Log("WARNING. Proxy configuration detected from environment settings. This test failure may be related to proxy configuration. Please process with caution") + } + + stop, connectionMap := startListeners(t, func(creds credentials.TransportCredentials) *grpc.Server { + return grpc.NewServer(TestSidechannelServer(newLogger(t), creds, func( + _ interface{}, + stream grpc.ServerStream, + sidechannelConn io.ReadWriteCloser, + ) error { + if method, ok := grpc.Method(stream.Context()); !ok || method != "/grpc.health.v1.Health/Check" { + return fmt.Errorf("unexpected method: %s", method) + } + + var req healthpb.HealthCheckRequest + if err := stream.RecvMsg(&req); err != nil { + return fmt.Errorf("recv msg: %w", err) + } + + if _, err := io.Copy(sidechannelConn, sidechannelConn); err != nil { + return fmt.Errorf("copy: %w", err) + } + + if err := stream.SendMsg(&healthpb.HealthCheckResponse{}); err != nil { + return fmt.Errorf("send msg: %w", err) + } + + return nil + })...) + }) + defer stop() + + unixSocketAbsPath := connectionMap["unix"] + + tempDir := testhelper.TempDir(t) + + unixSocketPath := filepath.Join(tempDir, "gitaly.socket") + require.NoError(t, os.Symlink(unixSocketAbsPath, unixSocketPath)) + + registry := NewSidechannelRegistry(newLogger(t)) + + tests := []struct { + name string + rawAddress string + envSSLCertFile string + dialOpts []grpc.DialOption + }{ + { + name: "tcp sidechannel", + rawAddress: "tcp://localhost:" + connectionMap["tcp"], // "tcp://localhost:1234" + }, + { + name: "tls sidechannel", + rawAddress: "tls://localhost:" + connectionMap["tls"], // "tls://localhost:1234" + envSSLCertFile: "./testdata/gitalycert.pem", + }, + { + name: "unix sidechannel", + rawAddress: "unix:" + unixSocketAbsPath, // "unix:/tmp/temp-socket" + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.envSSLCertFile != "" { + defer testhelper.ModifyEnvironment(t, gitalyx509.SSLCertFile, tt.envSSLCertFile)() + } + + ctx, cancel := testhelper.Context() + defer cancel() + + conn, err := DialSidechannel(ctx, tt.rawAddress, registry, tt.dialOpts) + require.NoError(t, err) + defer conn.Close() + + ctx, scw := registry.Register(ctx, func(conn SidechannelConn) error { + const message = "hello world" + if _, err := io.WriteString(conn, message); err != nil { + return err + } + if err := conn.CloseWrite(); err != nil { + return err + } + buf, err := io.ReadAll(conn) + if err != nil { + return err + } + if string(buf) != message { + return fmt.Errorf("expected %q, got %q", message, buf) + } + + return nil + }) + defer scw.Close() + + req := &healthpb.HealthCheckRequest{Service: "test sidechannel"} + _, err = healthpb.NewHealthClient(conn).Check(ctx, req) + require.NoError(t, err) + require.NoError(t, scw.Close()) + }) + } +} + type testSvc struct { proxytestdata.UnimplementedTestServiceServer PingMethod func(context.Context, *proxytestdata.PingRequest) (*proxytestdata.PingResponse, error) @@ -414,26 +524,23 @@ func TestDial_Tracing(t *testing.T) { } // healthServer provide a basic GRPC health service endpoint for testing purposes -type healthServer struct{} - -func (*healthServer) Check(context.Context, *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) { - return &healthpb.HealthCheckResponse{Status: healthpb.HealthCheckResponse_SERVING}, nil +type healthServer struct { + healthpb.UnimplementedHealthServer } -func (*healthServer) Watch(*healthpb.HealthCheckRequest, healthpb.Health_WatchServer) error { - return status.Errorf(codes.Unimplemented, "Not implemented") +func (*healthServer) Check(context.Context, *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) { + return &healthpb.HealthCheckResponse{}, nil } // startTCPListener will start a insecure TCP listener on a random unused port -func startTCPListener(t testing.TB) (func(), string) { +func startTCPListener(t testing.TB, factory func(credentials.TransportCredentials) *grpc.Server) (func(), string) { listener, err := net.Listen("tcp", "localhost:0") require.NoError(t, err) tcpPort := listener.Addr().(*net.TCPAddr).Port address := fmt.Sprintf("%d", tcpPort) - grpcServer := grpc.NewServer() - healthpb.RegisterHealthServer(grpcServer, &healthServer{}) + grpcServer := factory(insecure.NewCredentials()) go grpcServer.Serve(listener) return func() { @@ -442,14 +549,13 @@ func startTCPListener(t testing.TB) (func(), string) { } // startUnixListener will start a unix socket listener using a temporary file -func startUnixListener(t testing.TB) (func(), string) { +func startUnixListener(t testing.TB, factory func(credentials.TransportCredentials) *grpc.Server) (func(), string) { serverSocketPath := testhelper.GetTemporaryGitalySocketFileName(t) listener, err := net.Listen("unix", serverSocketPath) require.NoError(t, err) - grpcServer := grpc.NewServer() - healthpb.RegisterHealthServer(grpcServer, &healthServer{}) + grpcServer := factory(insecure.NewCredentials()) go grpcServer.Serve(listener) return func() { @@ -459,7 +565,7 @@ func startUnixListener(t testing.TB) (func(), string) { // startTLSListener will start a secure TLS listener on a random unused port //go:generate openssl req -newkey rsa:4096 -new -nodes -x509 -days 3650 -out testdata/gitalycert.pem -keyout testdata/gitalykey.pem -subj "/C=US/ST=California/L=San Francisco/O=GitLab/OU=GitLab-Shell/CN=localhost" -addext "subjectAltName = IP:127.0.0.1, DNS:localhost" -func startTLSListener(t testing.TB) (func(), string) { +func startTLSListener(t testing.TB, factory func(credentials.TransportCredentials) *grpc.Server) (func(), string) { listener, err := net.Listen("tcp", "localhost:0") require.NoError(t, err) @@ -469,11 +575,12 @@ func startTLSListener(t testing.TB) (func(), string) { cert, err := tls.LoadX509KeyPair("testdata/gitalycert.pem", "testdata/gitalykey.pem") require.NoError(t, err) - grpcServer := grpc.NewServer(grpc.Creds(credentials.NewTLS(&tls.Config{ - Certificates: []tls.Certificate{cert}, - MinVersion: tls.VersionTLS12, - }))) - healthpb.RegisterHealthServer(grpcServer, &healthServer{}) + grpcServer := factory( + credentials.NewTLS(&tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS12, + }), + ) go grpcServer.Serve(listener) return func() { @@ -481,18 +588,18 @@ func startTLSListener(t testing.TB) (func(), string) { }, address } -var listeners = map[string]func(testing.TB) (func(), string){ +var listeners = map[string]func(testing.TB, func(credentials.TransportCredentials) *grpc.Server) (func(), string){ "tcp": startTCPListener, "unix": startUnixListener, "tls": startTLSListener, } // startListeners will start all the different listeners used in this test -func startListeners(t testing.TB) (func(), map[string]string) { +func startListeners(t testing.TB, factory func(credentials.TransportCredentials) *grpc.Server) (func(), map[string]string) { var closers []func() connectionMap := map[string]string{} for k, v := range listeners { - closer, address := v(t) + closer, address := v(t, factory) closers = append(closers, closer) connectionMap[k] = address } @@ -532,3 +639,5 @@ func TestHealthCheckDialer(t *testing.T) { require.NoError(t, err) require.NoError(t, cc.Close()) } + +func newLogger(t testing.TB) *logrus.Entry { return logrus.NewEntry(testhelper.NewTestLogger(t)) } diff --git a/client/sidechannel.go b/client/sidechannel.go new file mode 100644 index 000000000..26729c138 --- /dev/null +++ b/client/sidechannel.go @@ -0,0 +1,90 @@ +package client + +import ( + "context" + "io" + + "github.com/sirupsen/logrus" + "gitlab.com/gitlab-org/gitaly/v14/internal/backchannel" + "gitlab.com/gitlab-org/gitaly/v14/internal/listenmux" + "gitlab.com/gitlab-org/gitaly/v14/internal/sidechannel" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" +) + +// SidechannelRegistry associates sidechannel callbacks with outbound +// gRPC calls. +type SidechannelRegistry struct { + registry *sidechannel.Registry + logger *logrus.Entry +} + +// NewSidechannelRegistry returns a new registry. +func NewSidechannelRegistry(logger *logrus.Entry) *SidechannelRegistry { + return &SidechannelRegistry{ + registry: sidechannel.NewRegistry(), + logger: logger, + } +} + +// Register registers a callback. It adds metadata to ctx and returns the +// new context. The caller must use the new context for the gRPC call. +// Caller must Close() the returned SidechannelWaiter to prevent resource +// leaks. +func (sr *SidechannelRegistry) Register( + ctx context.Context, + callback func(SidechannelConn) error, +) (context.Context, *SidechannelWaiter) { + ctx, waiter := sidechannel.RegisterSidechannel( + ctx, + sr.registry, + func(cc *sidechannel.ClientConn) error { return callback(cc) }, + ) + return ctx, &SidechannelWaiter{waiter: waiter} +} + +// SidechannelWaiter represents a pending sidechannel and its callback. +type SidechannelWaiter struct{ waiter *sidechannel.Waiter } + +// Close de-registers the sidechannel callback. If the callback is still +// running, Close blocks until it is done and returns the error return +// value of the callback. If the callback has not been called yet, Close +// returns an error immediately. +func (w *SidechannelWaiter) Close() error { return w.waiter.Close() } + +// SidechannelConn allows a client to read and write bytes with less +// overhead than doing so via gRPC messages. +type SidechannelConn interface { + io.ReadWriter + + // CloseWrite tells the server we won't write any more data. We can still + // read data from the server after CloseWrite(). A typical use case is in + // an RPC where the byte stream has a request/response pattern: the + // client then uses CloseWrite() to signal the end of the request body. + // When the client calls CloseWrite(), the server receives EOF. + CloseWrite() error +} + +// TestSidechannelServer allows downstream consumers of this package to +// create mock sidechannel gRPC servers. +func TestSidechannelServer( + logger *logrus.Entry, + creds credentials.TransportCredentials, + handler func(interface{}, grpc.ServerStream, io.ReadWriteCloser) error, +) []grpc.ServerOption { + lm := listenmux.New(creds) + lm.Register(backchannel.NewServerHandshaker(logger, backchannel.NewRegistry(), nil)) + + return []grpc.ServerOption{ + grpc.Creds(lm), + grpc.UnknownServiceHandler(func(srv interface{}, stream grpc.ServerStream) error { + conn, err := sidechannel.OpenSidechannel(stream.Context()) + if err != nil { + return err + } + defer conn.Close() + + return handler(srv, stream, conn) + }), + } +} diff --git a/internal/gitaly/service/smarthttp/upload_pack_test.go b/internal/gitaly/service/smarthttp/upload_pack_test.go index 506b1e34f..cd4fbbd3a 100644 --- a/internal/gitaly/service/smarthttp/upload_pack_test.go +++ b/internal/gitaly/service/smarthttp/upload_pack_test.go @@ -11,16 +11,13 @@ import ( "github.com/prometheus/client_golang/prometheus" promtest "github.com/prometheus/client_golang/prometheus/testutil" - "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" gitalyauth "gitlab.com/gitlab-org/gitaly/v14/auth" - "gitlab.com/gitlab-org/gitaly/v14/internal/backchannel" "gitlab.com/gitlab-org/gitaly/v14/internal/git" "gitlab.com/gitlab-org/gitaly/v14/internal/git/gittest" "gitlab.com/gitlab-org/gitaly/v14/internal/git/pktline" "gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/config" - "gitlab.com/gitlab-org/gitaly/v14/internal/listenmux" "gitlab.com/gitlab-org/gitaly/v14/internal/sidechannel" "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper" "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper/testcfg" @@ -471,14 +468,9 @@ func makePostUploadPackRequest(ctx context.Context, t *testing.T, serverSocketPa } func dialSmartHTTPServerWithSidechannel(t *testing.T, serverSocketPath, token string, registry *sidechannel.Registry) *grpc.ClientConn { - logger := logrus.NewEntry(logrus.New()) + t.Helper() - factory := func() backchannel.Server { - lm := listenmux.New(insecure.NewCredentials()) - lm.Register(sidechannel.NewServerHandshaker(registry)) - return grpc.NewServer(grpc.Creds(lm)) - } - clientHandshaker := backchannel.NewClientHandshaker(logger, factory) + clientHandshaker := sidechannel.NewClientHandshaker(testhelper.DiscardTestEntry(t), registry) connOpts := []grpc.DialOption{ grpc.WithTransportCredentials(clientHandshaker.ClientHandshake(insecure.NewCredentials())), grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2(token)), diff --git a/internal/sidechannel/sidechannel.go b/internal/sidechannel/sidechannel.go index 4886603e7..74c7d0ab8 100644 --- a/internal/sidechannel/sidechannel.go +++ b/internal/sidechannel/sidechannel.go @@ -9,8 +9,13 @@ import ( "strconv" "time" + "github.com/sirupsen/logrus" "gitlab.com/gitlab-org/gitaly/v14/internal/backchannel" + "gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/client" + "gitlab.com/gitlab-org/gitaly/v14/internal/listenmux" + "google.golang.org/grpc" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" ) @@ -124,3 +129,16 @@ func (s *ServerHandshaker) Handshake(conn net.Conn, authInfo credentials.AuthInf func NewServerHandshaker(registry *Registry) *ServerHandshaker { return &ServerHandshaker{registry: registry} } + +// NewClientHandshaker is used to enable sidechannel support on outbound +// gRPC connections. +func NewClientHandshaker(logger *logrus.Entry, registry *Registry) client.Handshaker { + return backchannel.NewClientHandshaker( + logger, + func() backchannel.Server { + lm := listenmux.New(insecure.NewCredentials()) + lm.Register(NewServerHandshaker(registry)) + return grpc.NewServer(grpc.Creds(lm)) + }, + ) +} diff --git a/internal/sidechannel/sidechannel_test.go b/internal/sidechannel/sidechannel_test.go index bcbd0ce7c..285765b51 100644 --- a/internal/sidechannel/sidechannel_test.go +++ b/internal/sidechannel/sidechannel_test.go @@ -166,13 +166,7 @@ func startServer(t *testing.T, th testHandler, opts ...grpc.ServerOption) string func dial(t *testing.T, addr string) (*grpc.ClientConn, *Registry) { registry := NewRegistry() - factory := func() backchannel.Server { - lm := listenmux.New(insecure.NewCredentials()) - lm.Register(NewServerHandshaker(registry)) - return grpc.NewServer(grpc.Creds(lm)) - } - - clientHandshaker := backchannel.NewClientHandshaker(newLogger(), factory) + clientHandshaker := NewClientHandshaker(newLogger(), registry) dialOpt := grpc.WithTransportCredentials(clientHandshaker.ClientHandshake(insecure.NewCredentials())) conn, err := grpc.Dial(addr, dialOpt) |