diff options
Diffstat (limited to 'v14/client/dial_test.go')
-rw-r--r-- | v14/client/dial_test.go | 630 |
1 files changed, 630 insertions, 0 deletions
diff --git a/v14/client/dial_test.go b/v14/client/dial_test.go new file mode 100644 index 000000000..6dd9ec839 --- /dev/null +++ b/v14/client/dial_test.go @@ -0,0 +1,630 @@ +package client + +import ( + "context" + "crypto/tls" + "fmt" + "io" + "net" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/opentracing/opentracing-go" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/uber/jaeger-client-go" + gitalyauth "gitlab.com/gitlab-org/gitaly/v14/auth" + proxytestdata "gitlab.com/gitlab-org/gitaly/internal/praefect/grpc-proxy/testdata" + "gitlab.com/gitlab-org/gitaly/internal/testhelper" + gitalyx509 "gitlab.com/gitlab-org/gitaly/internal/x509" + "gitlab.com/gitlab-org/labkit/correlation" + grpccorrelation "gitlab.com/gitlab-org/labkit/correlation/grpc" + grpctracing "gitlab.com/gitlab-org/labkit/tracing/grpc" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + healthpb "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/status" +) + +var proxyEnvironmentKeys = []string{"http_proxy", "https_proxy", "no_proxy"} + +func TestDial(t *testing.T) { + if emitProxyWarning() { + t.Log("WARNING. Proxy configuration detected from environment settings. This test failure may be related to proxy configuration. Please process with caution") + } + + stop, connectionMap := startListeners(t, func(creds credentials.TransportCredentials) *grpc.Server { + srv := grpc.NewServer(grpc.Creds(creds)) + healthpb.RegisterHealthServer(srv, &healthServer{}) + return srv + }) + defer stop() + + unixSocketAbsPath := connectionMap["unix"] + + tempDir := testhelper.TempDir(t) + + unixSocketPath := filepath.Join(tempDir, "gitaly.socket") + require.NoError(t, os.Symlink(unixSocketAbsPath, unixSocketPath)) + + tests := []struct { + name string + rawAddress string + envSSLCertFile string + dialOpts []grpc.DialOption + expectDialFailure bool + expectHealthFailure bool + }{ + { + name: "tcp localhost with prefix", + rawAddress: "tcp://localhost:" + connectionMap["tcp"], // "tcp://localhost:1234" + expectDialFailure: false, + expectHealthFailure: false, + }, + { + name: "tls localhost", + rawAddress: "tls://localhost:" + connectionMap["tls"], // "tls://localhost:1234" + envSSLCertFile: "./testdata/gitalycert.pem", + expectDialFailure: false, + expectHealthFailure: false, + }, + { + name: "unix absolute", + rawAddress: "unix:" + unixSocketAbsPath, // "unix:/tmp/temp-socket" + expectDialFailure: false, + expectHealthFailure: false, + }, + { + name: "unix relative", + rawAddress: "unix:" + unixSocketPath, // "unix:../../tmp/temp-socket" + expectDialFailure: false, + expectHealthFailure: false, + }, + { + name: "unix absolute does not exist", + rawAddress: "unix:" + unixSocketAbsPath + ".does_not_exist", // "unix:/tmp/temp-socket.does_not_exist" + expectDialFailure: false, + expectHealthFailure: true, + }, + { + name: "unix relative does not exist", + rawAddress: "unix:" + unixSocketPath + ".does_not_exist", // "unix:../../tmp/temp-socket.does_not_exist" + expectDialFailure: false, + expectHealthFailure: true, + }, + { + // Gitaly does not support connections that do not have a scheme. + name: "tcp localhost no prefix", + rawAddress: "localhost:" + connectionMap["tcp"], // "localhost:1234" + expectDialFailure: true, + }, + { + name: "invalid", + rawAddress: ".", + expectDialFailure: true, + }, + { + name: "empty", + rawAddress: "", + expectDialFailure: true, + }, + { + name: "dial fail if there is no listener on address", + rawAddress: "tcp://invalid.address", + dialOpts: FailOnNonTempDialError(), + expectDialFailure: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if emitProxyWarning() { + t.Log("WARNING. Proxy configuration detected from environment settings. This test failure may be related to proxy configuration. Please process with caution") + } + + if tt.envSSLCertFile != "" { + testhelper.ModifyEnvironment(t, gitalyx509.SSLCertFile, tt.envSSLCertFile) + } + ctx := testhelper.Context(t) + + conn, err := Dial(tt.rawAddress, tt.dialOpts) + if tt.expectDialFailure { + require.Error(t, err) + return + } + require.NoError(t, err) + defer conn.Close() + + _, err = healthpb.NewHealthClient(conn).Check(ctx, &healthpb.HealthCheckRequest{}) + if tt.expectHealthFailure { + require.Error(t, err) + return + } + require.NoError(t, err) + }) + } +} + +func TestDialSidechannel(t *testing.T) { + if emitProxyWarning() { + t.Log("WARNING. Proxy configuration detected from environment settings. This test failure may be related to proxy configuration. Please process with caution") + } + + stop, connectionMap := startListeners(t, func(creds credentials.TransportCredentials) *grpc.Server { + return grpc.NewServer(TestSidechannelServer(newLogger(t), creds, func( + _ interface{}, + stream grpc.ServerStream, + sidechannelConn io.ReadWriteCloser, + ) error { + if method, ok := grpc.Method(stream.Context()); !ok || method != "/grpc.health.v1.Health/Check" { + return fmt.Errorf("unexpected method: %s", method) + } + + var req healthpb.HealthCheckRequest + if err := stream.RecvMsg(&req); err != nil { + return fmt.Errorf("recv msg: %w", err) + } + + if _, err := io.Copy(sidechannelConn, sidechannelConn); err != nil { + return fmt.Errorf("copy: %w", err) + } + + if err := stream.SendMsg(&healthpb.HealthCheckResponse{}); err != nil { + return fmt.Errorf("send msg: %w", err) + } + + return nil + })...) + }) + defer stop() + + unixSocketAbsPath := connectionMap["unix"] + + tempDir := testhelper.TempDir(t) + + unixSocketPath := filepath.Join(tempDir, "gitaly.socket") + require.NoError(t, os.Symlink(unixSocketAbsPath, unixSocketPath)) + + registry := NewSidechannelRegistry(newLogger(t)) + + tests := []struct { + name string + rawAddress string + envSSLCertFile string + dialOpts []grpc.DialOption + }{ + { + name: "tcp sidechannel", + rawAddress: "tcp://localhost:" + connectionMap["tcp"], // "tcp://localhost:1234" + }, + { + name: "tls sidechannel", + rawAddress: "tls://localhost:" + connectionMap["tls"], // "tls://localhost:1234" + envSSLCertFile: "./testdata/gitalycert.pem", + }, + { + name: "unix sidechannel", + rawAddress: "unix:" + unixSocketAbsPath, // "unix:/tmp/temp-socket" + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.envSSLCertFile != "" { + testhelper.ModifyEnvironment(t, gitalyx509.SSLCertFile, tt.envSSLCertFile) + } + ctx := testhelper.Context(t) + + conn, err := DialSidechannel(ctx, tt.rawAddress, registry, tt.dialOpts) + require.NoError(t, err) + defer conn.Close() + + ctx, scw := registry.Register(ctx, func(conn SidechannelConn) error { + const message = "hello world" + if _, err := io.WriteString(conn, message); err != nil { + return err + } + if err := conn.CloseWrite(); err != nil { + return err + } + buf, err := io.ReadAll(conn) + if err != nil { + return err + } + if string(buf) != message { + return fmt.Errorf("expected %q, got %q", message, buf) + } + + return nil + }) + defer scw.Close() + + req := &healthpb.HealthCheckRequest{Service: "test sidechannel"} + _, err = healthpb.NewHealthClient(conn).Check(ctx, req) + require.NoError(t, err) + require.NoError(t, scw.Close()) + }) + } +} + +type testSvc struct { + proxytestdata.UnimplementedTestServiceServer + PingMethod func(context.Context, *proxytestdata.PingRequest) (*proxytestdata.PingResponse, error) + PingStreamMethod func(stream proxytestdata.TestService_PingStreamServer) error +} + +func (ts *testSvc) Ping(ctx context.Context, r *proxytestdata.PingRequest) (*proxytestdata.PingResponse, error) { + if ts.PingMethod != nil { + return ts.PingMethod(ctx, r) + } + + return &proxytestdata.PingResponse{}, nil +} + +func (ts *testSvc) PingStream(stream proxytestdata.TestService_PingStreamServer) error { + if ts.PingStreamMethod != nil { + return ts.PingStreamMethod(stream) + } + + return nil +} + +func TestDial_Correlation(t *testing.T) { + t.Run("unary", func(t *testing.T) { + serverSocketPath := testhelper.GetTemporaryGitalySocketFileName(t) + + listener, err := net.Listen("unix", serverSocketPath) + require.NoError(t, err) + + grpcServer := grpc.NewServer(grpc.UnaryInterceptor(grpccorrelation.UnaryServerCorrelationInterceptor())) + svc := &testSvc{ + PingMethod: func(ctx context.Context, r *proxytestdata.PingRequest) (*proxytestdata.PingResponse, error) { + cid := correlation.ExtractFromContext(ctx) + assert.Equal(t, "correlation-id-1", cid) + return &proxytestdata.PingResponse{}, nil + }, + } + proxytestdata.RegisterTestServiceServer(grpcServer, svc) + + go func() { assert.NoError(t, grpcServer.Serve(listener)) }() + + defer grpcServer.Stop() + ctx := testhelper.Context(t) + + cc, err := DialContext(ctx, "unix://"+serverSocketPath, nil) + require.NoError(t, err) + defer cc.Close() + + client := proxytestdata.NewTestServiceClient(cc) + + ctx = correlation.ContextWithCorrelation(ctx, "correlation-id-1") + _, err = client.Ping(ctx, &proxytestdata.PingRequest{}) + require.NoError(t, err) + }) + + t.Run("stream", func(t *testing.T) { + serverSocketPath := testhelper.GetTemporaryGitalySocketFileName(t) + + listener, err := net.Listen("unix", serverSocketPath) + require.NoError(t, err) + + grpcServer := grpc.NewServer(grpc.StreamInterceptor(grpccorrelation.StreamServerCorrelationInterceptor())) + svc := &testSvc{ + PingStreamMethod: func(stream proxytestdata.TestService_PingStreamServer) error { + cid := correlation.ExtractFromContext(stream.Context()) + assert.Equal(t, "correlation-id-1", cid) + _, err := stream.Recv() + assert.NoError(t, err) + return stream.Send(&proxytestdata.PingResponse{}) + }, + } + proxytestdata.RegisterTestServiceServer(grpcServer, svc) + + go func() { assert.NoError(t, grpcServer.Serve(listener)) }() + defer grpcServer.Stop() + ctx := testhelper.Context(t) + + cc, err := DialContext(ctx, "unix://"+serverSocketPath, nil) + require.NoError(t, err) + defer cc.Close() + + client := proxytestdata.NewTestServiceClient(cc) + + ctx = correlation.ContextWithCorrelation(ctx, "correlation-id-1") + stream, err := client.PingStream(ctx) + require.NoError(t, err) + + require.NoError(t, stream.Send(&proxytestdata.PingRequest{})) + require.NoError(t, stream.CloseSend()) + + _, err = stream.Recv() + require.NoError(t, err) + }) +} + +func TestDial_Tracing(t *testing.T) { + serverSocketPath := testhelper.GetTemporaryGitalySocketFileName(t) + + listener, err := net.Listen("unix", serverSocketPath) + require.NoError(t, err) + + clientSendClosed := make(chan struct{}) + + // This is our test service. All it does is to create additional spans + // which should in the end be visible when collecting all registered + // spans. + grpcServer := grpc.NewServer( + grpc.UnaryInterceptor(grpctracing.UnaryServerTracingInterceptor()), + grpc.StreamInterceptor(grpctracing.StreamServerTracingInterceptor()), + ) + svc := &testSvc{ + PingMethod: func(ctx context.Context, r *proxytestdata.PingRequest) (*proxytestdata.PingResponse, error) { + span, _ := opentracing.StartSpanFromContext(ctx, "nested-span") + defer span.Finish() + span.LogKV("was", "called") + return &proxytestdata.PingResponse{}, nil + }, + PingStreamMethod: func(stream proxytestdata.TestService_PingStreamServer) error { + // synchronize the client has returned from CloseSend as the client span finishing + // races with sending the stream close to the server + select { + case <-clientSendClosed: + case <-stream.Context().Done(): + return stream.Context().Err() + } + + span, _ := opentracing.StartSpanFromContext(stream.Context(), "nested-span") + defer span.Finish() + span.LogKV("was", "called") + return nil + }, + } + proxytestdata.RegisterTestServiceServer(grpcServer, svc) + + go func() { require.NoError(t, grpcServer.Serve(listener)) }() + defer grpcServer.Stop() + ctx := testhelper.Context(t) + + t.Run("unary", func(t *testing.T) { + reporter := jaeger.NewInMemoryReporter() + tracer, tracerCloser := jaeger.NewTracer("", jaeger.NewConstSampler(true), reporter) + defer tracerCloser.Close() + + defer func(old opentracing.Tracer) { opentracing.SetGlobalTracer(old) }(opentracing.GlobalTracer()) + opentracing.SetGlobalTracer(tracer) + + // This needs to be run after setting up the global tracer as it will cause us to + // create the span when executing the RPC call further down below. + cc, err := DialContext(ctx, "unix://"+serverSocketPath, nil) + require.NoError(t, err) + defer cc.Close() + + // We set up a "main" span here, which is going to be what the + // other spans inherit from. In order to check whether baggage + // works correctly, we also set up a "stub" baggage item which + // should be inherited to child contexts. + span := tracer.StartSpan("unary-check") + span = span.SetBaggageItem("service", "stub") + ctx := opentracing.ContextWithSpan(ctx, span) + + // We're now invoking the unary RPC with the span injected into + // the context. This should create a span that's nested into + // the "stream-check" span. + _, err = proxytestdata.NewTestServiceClient(cc).Ping(ctx, &proxytestdata.PingRequest{}) + require.NoError(t, err) + + span.Finish() + + spans := reporter.GetSpans() + require.Len(t, spans, 3) + + for i, expectedSpan := range []struct { + baggage string + operation string + }{ + // This is the first span we expect, which is the + // "health" span which we've manually created inside of + // PingMethod. + {baggage: "", operation: "nested-span"}, + // This span is the RPC call to TestService/Ping. It + // inherits the "unary-check" we set up and thus has + // baggage. + {baggage: "stub", operation: "/mwitkow.testproto.TestService/Ping"}, + // And this finally is the outermost span which we + // manually set up before the RPC call. + {baggage: "stub", operation: "unary-check"}, + } { + assert.IsType(t, spans[i], &jaeger.Span{}) + span := spans[i].(*jaeger.Span) + + assert.Equal(t, expectedSpan.baggage, span.BaggageItem("service"), "wrong baggage item for span %d", i) + assert.Equal(t, expectedSpan.operation, span.OperationName(), "wrong operation name for span %d", i) + } + }) + + t.Run("stream", func(t *testing.T) { + reporter := jaeger.NewInMemoryReporter() + tracer, tracerCloser := jaeger.NewTracer("", jaeger.NewConstSampler(true), reporter) + defer tracerCloser.Close() + + defer func(old opentracing.Tracer) { opentracing.SetGlobalTracer(old) }(opentracing.GlobalTracer()) + opentracing.SetGlobalTracer(tracer) + + // This needs to be run after setting up the global tracer as it will cause us to + // create the span when executing the RPC call further down below. + cc, err := DialContext(ctx, "unix://"+serverSocketPath, nil) + require.NoError(t, err) + defer cc.Close() + + // We set up a "main" span here, which is going to be what the other spans inherit + // from. In order to check whether baggage works correctly, we also set up a "stub" + // baggage item which should be inherited to child contexts. + span := tracer.StartSpan("stream-check") + span = span.SetBaggageItem("service", "stub") + ctx := opentracing.ContextWithSpan(ctx, span) + + // We're now invoking the streaming RPC with the span injected into the context. + // This should create a span that's nested into the "stream-check" span. + stream, err := proxytestdata.NewTestServiceClient(cc).PingStream(ctx) + require.NoError(t, err) + require.NoError(t, stream.CloseSend()) + close(clientSendClosed) + + // wait for the server to finish its spans and close the stream + resp, err := stream.Recv() + require.Equal(t, err, io.EOF) + require.Nil(t, resp) + + span.Finish() + + spans := reporter.GetSpans() + require.Len(t, spans, 3) + + for i, expectedSpan := range []struct { + baggage string + operation string + }{ + // This span is the RPC call to TestService/Ping. + {baggage: "stub", operation: "/mwitkow.testproto.TestService/PingStream"}, + // This is the second span we expect, which is the "nested-span" span which + // we've manually created inside of PingMethod. This is different than for + // unary RPCs: given that one can send multiple messages to the RPC, we may + // see multiple such "nested-span"s being created. And the PingStream span + // will only be finalized last. + {baggage: "", operation: "nested-span"}, + // And this finally is the outermost span which we + // manually set up before the RPC call. + {baggage: "stub", operation: "stream-check"}, + } { + if !assert.IsType(t, spans[i], &jaeger.Span{}) { + continue + } + + span := spans[i].(*jaeger.Span) + assert.Equal(t, expectedSpan.baggage, span.BaggageItem("service"), "wrong baggage item for span %d", i) + assert.Equal(t, expectedSpan.operation, span.OperationName(), "wrong operation name for span %d", i) + } + }) +} + +// healthServer provide a basic GRPC health service endpoint for testing purposes +type healthServer struct { + healthpb.UnimplementedHealthServer +} + +func (*healthServer) Check(context.Context, *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) { + return &healthpb.HealthCheckResponse{}, nil +} + +// startTCPListener will start a insecure TCP listener on a random unused port +func startTCPListener(t testing.TB, factory func(credentials.TransportCredentials) *grpc.Server) (func(), string) { + listener, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + + tcpPort := listener.Addr().(*net.TCPAddr).Port + address := fmt.Sprintf("%d", tcpPort) + + grpcServer := factory(insecure.NewCredentials()) + go grpcServer.Serve(listener) + + return func() { + grpcServer.Stop() + }, address +} + +// startUnixListener will start a unix socket listener using a temporary file +func startUnixListener(t testing.TB, factory func(credentials.TransportCredentials) *grpc.Server) (func(), string) { + serverSocketPath := testhelper.GetTemporaryGitalySocketFileName(t) + + listener, err := net.Listen("unix", serverSocketPath) + require.NoError(t, err) + + grpcServer := factory(insecure.NewCredentials()) + go grpcServer.Serve(listener) + + return func() { + grpcServer.Stop() + }, serverSocketPath +} + +// startTLSListener will start a secure TLS listener on a random unused port +//go:generate openssl req -newkey rsa:4096 -new -nodes -x509 -days 3650 -out testdata/gitalycert.pem -keyout testdata/gitalykey.pem -subj "/C=US/ST=California/L=San Francisco/O=GitLab/OU=GitLab-Shell/CN=localhost" -addext "subjectAltName = IP:127.0.0.1, DNS:localhost" +func startTLSListener(t testing.TB, factory func(credentials.TransportCredentials) *grpc.Server) (func(), string) { + listener, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + + tcpPort := listener.Addr().(*net.TCPAddr).Port + address := fmt.Sprintf("%d", tcpPort) + + cert, err := tls.LoadX509KeyPair("testdata/gitalycert.pem", "testdata/gitalykey.pem") + require.NoError(t, err) + + grpcServer := factory( + credentials.NewTLS(&tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS12, + }), + ) + go grpcServer.Serve(listener) + + return func() { + grpcServer.Stop() + }, address +} + +var listeners = map[string]func(testing.TB, func(credentials.TransportCredentials) *grpc.Server) (func(), string){ + "tcp": startTCPListener, + "unix": startUnixListener, + "tls": startTLSListener, +} + +// startListeners will start all the different listeners used in this test +func startListeners(t testing.TB, factory func(credentials.TransportCredentials) *grpc.Server) (func(), map[string]string) { + var closers []func() + connectionMap := map[string]string{} + for k, v := range listeners { + closer, address := v(t, factory) + closers = append(closers, closer) + connectionMap[k] = address + } + + return func() { + for _, v := range closers { + v() + } + }, connectionMap +} + +func emitProxyWarning() bool { + for _, key := range proxyEnvironmentKeys { + value := os.Getenv(key) + if value != "" { + return true + } + value = os.Getenv(strings.ToUpper(key)) + if value != "" { + return true + } + } + return false +} + +func TestHealthCheckDialer(t *testing.T) { + _, addr, cleanup := runServer(t, "token") + defer cleanup() + ctx := testhelper.Context(t) + + _, err := HealthCheckDialer(DialContext)(ctx, addr, nil) + testhelper.RequireGrpcError(t, status.Error(codes.Unauthenticated, "authentication required"), err) + + cc, err := HealthCheckDialer(DialContext)(ctx, addr, []grpc.DialOption{grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2("token"))}) + require.NoError(t, err) + require.NoError(t, cc.Close()) +} + +func newLogger(t testing.TB) *logrus.Entry { return logrus.NewEntry(testhelper.NewDiscardingLogger(t)) } |