diff options
-rw-r--r-- | client/dial.go | 19 | ||||
-rw-r--r-- | client/dial_test.go | 104 | ||||
-rw-r--r-- | cmd/gitaly/main.go | 5 | ||||
-rw-r--r-- | internal/gitaly/service/conflicts/server.go | 5 | ||||
-rw-r--r-- | internal/gitaly/service/remote/server.go | 5 | ||||
-rw-r--r-- | internal/gitaly/service/repository/server.go | 9 | ||||
-rw-r--r-- | internal/gitaly/service/smarthttp/receive_pack_test.go | 3 |
7 files changed, 105 insertions, 45 deletions
diff --git a/client/dial.go b/client/dial.go index 60be802cf..341f1cf27 100644 --- a/client/dial.go +++ b/client/dial.go @@ -12,6 +12,7 @@ import ( grpctracing "gitlab.com/gitlab-org/labkit/tracing/grpc" "google.golang.org/grpc" "google.golang.org/grpc/credentials" + healthpb "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/keepalive" ) @@ -130,3 +131,21 @@ func FailOnNonTempDialError() []grpc.DialOption { grpc.FailOnNonTempDialError(true), } } + +// HealthCheckDialer uses provided dialer as an actual dialer, but issues a health check request to the remote +// to verify the connection was set properly and could be used with no issues. +func HealthCheckDialer(base Dialer) Dialer { + return func(ctx context.Context, address string, dialOptions []grpc.DialOption) (*grpc.ClientConn, error) { + cc, err := base(ctx, address, dialOptions) + if err != nil { + return nil, err + } + + if _, err := healthpb.NewHealthClient(cc).Check(ctx, &healthpb.HealthCheckRequest{}); err != nil { + cc.Close() + return nil, err + } + + return cc, nil + } +} diff --git a/client/dial_test.go b/client/dial_test.go index fe496f9d6..9951c7da2 100644 --- a/client/dial_test.go +++ b/client/dial_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/uber/jaeger-client-go" + gitalyauth "gitlab.com/gitlab-org/gitaly/auth" proxytestdata "gitlab.com/gitlab-org/gitaly/internal/praefect/grpc-proxy/testdata" "gitlab.com/gitlab-org/gitaly/internal/testhelper" gitaly_x509 "gitlab.com/gitlab-org/gitaly/internal/x509" @@ -48,64 +49,71 @@ func TestDial(t *testing.T) { require.NoError(t, os.Symlink(unixSocketAbsPath, unixSocketPath)) tests := []struct { - name string - rawAddress string - envSSLCertFile string - dialOpts []grpc.DialOption - expectFailure bool + name string + rawAddress string + envSSLCertFile string + dialOpts []grpc.DialOption + expectDialFailure bool + expectHealthFailure bool }{ { - name: "tcp localhost with prefix", - rawAddress: "tcp://localhost:" + connectionMap["tcp"], // "tcp://localhost:1234" - expectFailure: false, + name: "tcp localhost with prefix", + rawAddress: "tcp://localhost:" + connectionMap["tcp"], // "tcp://localhost:1234" + expectDialFailure: false, + expectHealthFailure: false, }, { - name: "tls localhost", - rawAddress: "tls://localhost:" + connectionMap["tls"], // "tls://localhost:1234" - envSSLCertFile: "./testdata/gitalycert.pem", - expectFailure: false, + name: "tls localhost", + rawAddress: "tls://localhost:" + connectionMap["tls"], // "tls://localhost:1234" + envSSLCertFile: "./testdata/gitalycert.pem", + expectDialFailure: false, + expectHealthFailure: false, }, { - name: "unix absolute", - rawAddress: "unix:" + unixSocketAbsPath, // "unix:/tmp/temp-socket" - expectFailure: false, + name: "unix absolute", + rawAddress: "unix:" + unixSocketAbsPath, // "unix:/tmp/temp-socket" + expectDialFailure: false, + expectHealthFailure: false, }, { - name: "unix relative", - rawAddress: "unix:" + unixSocketPath, // "unix:../../tmp/temp-socket" - expectFailure: false, + name: "unix relative", + rawAddress: "unix:" + unixSocketPath, // "unix:../../tmp/temp-socket" + expectDialFailure: false, + expectHealthFailure: false, }, { - name: "unix absolute does not exist", - rawAddress: "unix:" + unixSocketAbsPath + ".does_not_exist", // "unix:/tmp/temp-socket.does_not_exist" - expectFailure: true, + name: "unix absolute does not exist", + rawAddress: "unix:" + unixSocketAbsPath + ".does_not_exist", // "unix:/tmp/temp-socket.does_not_exist" + expectDialFailure: false, + expectHealthFailure: true, }, { - name: "unix relative does not exist", - rawAddress: "unix:" + unixSocketPath + ".does_not_exist", // "unix:../../tmp/temp-socket.does_not_exist" - expectFailure: true, + name: "unix relative does not exist", + rawAddress: "unix:" + unixSocketPath + ".does_not_exist", // "unix:../../tmp/temp-socket.does_not_exist" + expectDialFailure: false, + expectHealthFailure: true, }, { // Gitaly does not support connections that do not have a scheme. - name: "tcp localhost no prefix", - rawAddress: "localhost:" + connectionMap["tcp"], // "localhost:1234" - expectFailure: true, + name: "tcp localhost no prefix", + rawAddress: "localhost:" + connectionMap["tcp"], // "localhost:1234" + expectDialFailure: true, }, { - name: "invalid", - rawAddress: ".", - expectFailure: true, + name: "invalid", + rawAddress: ".", + expectDialFailure: true, }, { - name: "empty", - rawAddress: "", - expectFailure: true, + name: "empty", + rawAddress: "", + expectDialFailure: true, }, { - name: "dial fail if there is no listener on address", - rawAddress: "tcp://invalid.address", - dialOpts: FailOnNonTempDialError(), - expectFailure: true, + name: "dial fail if there is no listener on address", + rawAddress: "tcp://invalid.address", + dialOpts: FailOnNonTempDialError(), + expectDialFailure: true, }, } @@ -123,15 +131,18 @@ func TestDial(t *testing.T) { defer cancel() conn, err := Dial(tt.rawAddress, tt.dialOpts) - if tt.expectFailure { + if tt.expectDialFailure { require.Error(t, err) return } - require.NoError(t, err) defer conn.Close() _, err = healthpb.NewHealthClient(conn).Check(ctx, &healthpb.HealthCheckRequest{}) + if tt.expectHealthFailure { + require.Error(t, err) + return + } require.NoError(t, err) }) } @@ -459,3 +470,18 @@ func emitProxyWarning() bool { } return false } + +func TestHealthCheckDialer(t *testing.T) { + _, addr, cleanup := runServer(t, "token") + defer cleanup() + + ctx, cancel := testhelper.Context() + defer cancel() + + _, err := HealthCheckDialer(DialContext)(ctx, addr, nil) + require.Equal(t, status.Error(codes.Unauthenticated, "authentication required"), err, "should fail without token configured") + + cc, err := HealthCheckDialer(DialContext)(ctx, addr, []grpc.DialOption{grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2("token"))}) + require.NoError(t, err) + cc.Close() +} diff --git a/cmd/gitaly/main.go b/cmd/gitaly/main.go index bc7c27847..9d4cc45c0 100644 --- a/cmd/gitaly/main.go +++ b/cmd/gitaly/main.go @@ -128,7 +128,10 @@ func run(b *bootstrap.Bootstrap) error { hookManager := hook.NewManager(config.NewLocator(config.Config), gitlabAPI, config.Config) prometheus.MustRegister(hookManager) - conns := client.NewPoolWithOptions(client.WithDialOptions(client.FailOnNonTempDialError()...)) + conns := client.NewPoolWithOptions( + client.WithDialer(client.HealthCheckDialer(client.DialContext)), + client.WithDialOptions(client.FailOnNonTempDialError()...), + ) defer conns.Close() servers := server.NewGitalyServerFactory(hookManager, conns) diff --git a/internal/gitaly/service/conflicts/server.go b/internal/gitaly/service/conflicts/server.go index 2eb29c8b6..16f00ca24 100644 --- a/internal/gitaly/service/conflicts/server.go +++ b/internal/gitaly/service/conflicts/server.go @@ -21,6 +21,9 @@ func NewServer(rs *rubyserver.Server, cfg config.Cfg, locator storage.Locator) g ruby: rs, cfg: cfg, locator: locator, - pool: client.NewPoolWithOptions(client.WithDialOptions(client.FailOnNonTempDialError()...)), + pool: client.NewPoolWithOptions( + client.WithDialer(client.HealthCheckDialer(client.DialContext)), + client.WithDialOptions(client.FailOnNonTempDialError()...), + ), } } diff --git a/internal/gitaly/service/remote/server.go b/internal/gitaly/service/remote/server.go index 9420dda02..7da72b02d 100644 --- a/internal/gitaly/service/remote/server.go +++ b/internal/gitaly/service/remote/server.go @@ -19,6 +19,9 @@ func NewServer(rs *rubyserver.Server, locator storage.Locator) gitalypb.RemoteSe return &server{ ruby: rs, locator: locator, - conns: client.NewPoolWithOptions(client.WithDialOptions(client.FailOnNonTempDialError()...)), + conns: client.NewPoolWithOptions( + client.WithDialer(client.HealthCheckDialer(client.DialContext)), + client.WithDialOptions(client.FailOnNonTempDialError()...), + ), } } diff --git a/internal/gitaly/service/repository/server.go b/internal/gitaly/service/repository/server.go index 01d363361..4a4045c85 100644 --- a/internal/gitaly/service/repository/server.go +++ b/internal/gitaly/service/repository/server.go @@ -20,9 +20,12 @@ type server struct { // NewServer creates a new instance of a gRPC repo server func NewServer(cfg config.Cfg, rs *rubyserver.Server, locator storage.Locator) gitalypb.RepositoryServiceServer { return &server{ - ruby: rs, - locator: locator, - conns: client.NewPoolWithOptions(client.WithDialOptions(client.FailOnNonTempDialError()...)), + ruby: rs, + locator: locator, + conns: client.NewPoolWithOptions( + client.WithDialer(client.HealthCheckDialer(client.DialContext)), + client.WithDialOptions(client.FailOnNonTempDialError()...), + ), cfg: cfg, binDir: cfg.BinDir, loggingCfg: cfg.Logging, diff --git a/internal/gitaly/service/smarthttp/receive_pack_test.go b/internal/gitaly/service/smarthttp/receive_pack_test.go index 529a9cb3b..3a126f70b 100644 --- a/internal/gitaly/service/smarthttp/receive_pack_test.go +++ b/internal/gitaly/service/smarthttp/receive_pack_test.go @@ -28,6 +28,8 @@ import ( "gitlab.com/gitlab-org/gitaly/streamio" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/health" + healthpb "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/reflection" ) @@ -559,6 +561,7 @@ func TestPostReceiveWithReferenceTransactionHook(t *testing.T) { gitalypb.RegisterSmartHTTPServiceServer(gitalyServer, NewServer(locator)) gitalypb.RegisterHookServiceServer(gitalyServer, hook.NewServer(config.Config, gitalyhook.NewManager(locator, gitalyhook.GitlabAPIStub, config.Config))) gitalypb.RegisterRefTransactionServer(gitalyServer, refTransactionServer) + healthpb.RegisterHealthServer(gitalyServer, health.NewServer()) reflection.Register(gitalyServer) gitalySocketPath := testhelper.GetTemporaryGitalySocketFileName() |