Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/client
diff options
context:
space:
mode:
Diffstat (limited to 'client')
-rw-r--r--client/dial.go24
-rw-r--r--client/dial_test.go214
2 files changed, 232 insertions, 6 deletions
diff --git a/client/dial.go b/client/dial.go
index 4fce2ac5b..d60138268 100644
--- a/client/dial.go
+++ b/client/dial.go
@@ -8,6 +8,8 @@ import (
"time"
gitaly_x509 "gitlab.com/gitlab-org/gitaly/internal/x509"
+ grpccorrelation "gitlab.com/gitlab-org/labkit/correlation/grpc"
+ grpctracing "gitlab.com/gitlab-org/labkit/tracing/grpc"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
@@ -74,12 +76,22 @@ func DialContext(ctx context.Context, rawAddress string, connOpts []grpc.DialOpt
)
}
- // grpc.KeepaliveParams must be specified at least as large as what is allowed by the
- // server-side grpc.KeepaliveEnforcementPolicy
- connOpts = append(connOpts, grpc.WithKeepaliveParams(keepalive.ClientParameters{
- Time: 20 * time.Second,
- PermitWithoutStream: true,
- }))
+ connOpts = append(connOpts,
+ // grpc.KeepaliveParams must be specified at least as large as what is allowed by the
+ // server-side grpc.KeepaliveEnforcementPolicy
+ grpc.WithKeepaliveParams(keepalive.ClientParameters{
+ Time: 20 * time.Second,
+ PermitWithoutStream: true,
+ }),
+ grpc.WithChainUnaryInterceptor(
+ grpctracing.UnaryClientTracingInterceptor(),
+ grpccorrelation.UnaryClientCorrelationInterceptor(),
+ ),
+ grpc.WithChainStreamInterceptor(
+ grpctracing.StreamClientTracingInterceptor(),
+ grpccorrelation.StreamClientCorrelationInterceptor(),
+ ),
+ )
conn, err := grpc.DialContext(ctx, canonicalAddress, connOpts...)
if err != nil {
diff --git a/client/dial_test.go b/client/dial_test.go
index 05015df62..4df297918 100644
--- a/client/dial_test.go
+++ b/client/dial_test.go
@@ -9,9 +9,16 @@ import (
"strings"
"testing"
+ "github.com/opentracing/opentracing-go"
+ "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ "github.com/uber/jaeger-client-go"
+ proxytestdata "gitlab.com/gitlab-org/gitaly/internal/praefect/grpc-proxy/testdata"
"gitlab.com/gitlab-org/gitaly/internal/testhelper"
gitaly_x509 "gitlab.com/gitlab-org/gitaly/internal/x509"
+ "gitlab.com/gitlab-org/labkit/correlation"
+ grpccorrelation "gitlab.com/gitlab-org/labkit/correlation/grpc"
+ grpctracing "gitlab.com/gitlab-org/labkit/tracing/grpc"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
@@ -126,6 +133,213 @@ func TestDial(t *testing.T) {
}
}
+type testSvc struct {
+ proxytestdata.TestServiceServer
+ PingMethod func(context.Context, *proxytestdata.PingRequest) (*proxytestdata.PingResponse, error)
+ PingStreamMethod func(stream proxytestdata.TestService_PingStreamServer) error
+}
+
+func (ts *testSvc) Ping(ctx context.Context, r *proxytestdata.PingRequest) (*proxytestdata.PingResponse, error) {
+ if ts.PingMethod != nil {
+ return ts.PingMethod(ctx, r)
+ }
+
+ return &proxytestdata.PingResponse{}, nil
+}
+
+func (ts *testSvc) PingStream(stream proxytestdata.TestService_PingStreamServer) error {
+ if ts.PingStreamMethod != nil {
+ return ts.PingStreamMethod(stream)
+ }
+
+ return nil
+}
+
+func TestDial_Correlation(t *testing.T) {
+ t.Run("unary", func(t *testing.T) {
+ serverSocketPath := testhelper.GetTemporaryGitalySocketFileName()
+
+ listener, err := net.Listen("unix", serverSocketPath)
+ require.NoError(t, err)
+
+ grpcServer := grpc.NewServer(grpc.UnaryInterceptor(grpccorrelation.UnaryServerCorrelationInterceptor()))
+ svc := &testSvc{
+ PingMethod: func(ctx context.Context, r *proxytestdata.PingRequest) (*proxytestdata.PingResponse, error) {
+ cid := correlation.ExtractFromContext(ctx)
+ assert.Equal(t, "correlation-id-1", cid)
+ return &proxytestdata.PingResponse{}, nil
+ },
+ }
+ proxytestdata.RegisterTestServiceServer(grpcServer, svc)
+
+ go func() { assert.NoError(t, grpcServer.Serve(listener)) }()
+
+ defer grpcServer.Stop()
+
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ cc, err := DialContext(ctx, "unix://"+serverSocketPath, nil)
+ require.NoError(t, err)
+ defer cc.Close()
+
+ client := proxytestdata.NewTestServiceClient(cc)
+
+ ctx = correlation.ContextWithCorrelation(ctx, "correlation-id-1")
+ _, err = client.Ping(ctx, &proxytestdata.PingRequest{})
+ require.NoError(t, err)
+ })
+
+ t.Run("stream", func(t *testing.T) {
+ serverSocketPath := testhelper.GetTemporaryGitalySocketFileName()
+
+ listener, err := net.Listen("unix", serverSocketPath)
+ require.NoError(t, err)
+
+ grpcServer := grpc.NewServer(grpc.StreamInterceptor(grpccorrelation.StreamServerCorrelationInterceptor()))
+ svc := &testSvc{
+ PingStreamMethod: func(stream proxytestdata.TestService_PingStreamServer) error {
+ cid := correlation.ExtractFromContext(stream.Context())
+ assert.Equal(t, "correlation-id-1", cid)
+ return stream.Send(&proxytestdata.PingResponse{})
+ },
+ }
+ proxytestdata.RegisterTestServiceServer(grpcServer, svc)
+
+ go func() { assert.NoError(t, grpcServer.Serve(listener)) }()
+ defer grpcServer.Stop()
+
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ cc, err := DialContext(ctx, "unix://"+serverSocketPath, nil)
+ require.NoError(t, err)
+ defer cc.Close()
+
+ client := proxytestdata.NewTestServiceClient(cc)
+
+ ctx = correlation.ContextWithCorrelation(ctx, "correlation-id-1")
+ stream, err := client.PingStream(ctx)
+ require.NoError(t, err)
+
+ require.NoError(t, stream.Send(&proxytestdata.PingRequest{}))
+ require.NoError(t, stream.CloseSend())
+
+ _, err = stream.Recv()
+ require.NoError(t, err)
+ })
+}
+
+func TestDial_Tracing(t *testing.T) {
+ t.Run("unary", func(t *testing.T) {
+ serverSocketPath := testhelper.GetTemporaryGitalySocketFileName()
+
+ listener, err := net.Listen("unix", serverSocketPath)
+ require.NoError(t, err)
+
+ grpcServer := grpc.NewServer(grpc.UnaryInterceptor(grpctracing.UnaryServerTracingInterceptor()))
+ svc := &testSvc{
+ PingMethod: func(ctx context.Context, r *proxytestdata.PingRequest) (*proxytestdata.PingResponse, error) {
+ span, _ := opentracing.StartSpanFromContext(ctx, "health")
+ defer span.Finish()
+ span.LogKV("was", "called")
+ return &proxytestdata.PingResponse{}, nil
+ },
+ }
+ proxytestdata.RegisterTestServiceServer(grpcServer, svc)
+
+ go func() { assert.NoError(t, grpcServer.Serve(listener)) }()
+ defer grpcServer.Stop()
+
+ reporter := jaeger.NewInMemoryReporter()
+ tracer, closer := jaeger.NewTracer("", jaeger.NewConstSampler(true), reporter)
+ defer closer.Close()
+
+ defer func(old opentracing.Tracer) { opentracing.SetGlobalTracer(old) }(opentracing.GlobalTracer())
+ opentracing.SetGlobalTracer(tracer)
+
+ span := tracer.StartSpan("unary-check")
+ span = span.SetBaggageItem("service", "stub")
+
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ cc, err := DialContext(ctx, "unix://"+serverSocketPath, nil)
+ require.NoError(t, err)
+ defer cc.Close()
+
+ client := proxytestdata.NewTestServiceClient(cc)
+
+ ctx = opentracing.ContextWithSpan(ctx, span)
+ _, err = client.Ping(ctx, &proxytestdata.PingRequest{})
+ require.NoError(t, err)
+
+ span.Finish()
+
+ spans := reporter.GetSpans()
+ require.Len(t, spans, 3)
+ require.Equal(t, "stub", spans[1].BaggageItem("service"))
+ require.Equal(t, "stub", spans[2].BaggageItem("service"))
+ })
+
+ t.Run("stream", func(t *testing.T) {
+ serverSocketPath := testhelper.GetTemporaryGitalySocketFileName()
+
+ listener, err := net.Listen("unix", serverSocketPath)
+ require.NoError(t, err)
+
+ grpcServer := grpc.NewServer(grpc.StreamInterceptor(grpctracing.StreamServerTracingInterceptor()))
+ svc := &testSvc{
+ PingStreamMethod: func(stream proxytestdata.TestService_PingStreamServer) error {
+ span, _ := opentracing.StartSpanFromContext(stream.Context(), "health")
+ defer span.Finish()
+ span.LogKV("was", "called")
+ return stream.Send(&proxytestdata.PingResponse{})
+ },
+ }
+ proxytestdata.RegisterTestServiceServer(grpcServer, svc)
+
+ go func() { assert.NoError(t, grpcServer.Serve(listener)) }()
+ defer grpcServer.Stop()
+
+ reporter := jaeger.NewInMemoryReporter()
+ tracer, closer := jaeger.NewTracer("", jaeger.NewConstSampler(true), reporter)
+ defer closer.Close()
+
+ defer func(old opentracing.Tracer) { opentracing.SetGlobalTracer(old) }(opentracing.GlobalTracer())
+ opentracing.SetGlobalTracer(tracer)
+
+ span := tracer.StartSpan("stream-check")
+ span = span.SetBaggageItem("service", "stub")
+
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ cc, err := DialContext(ctx, "unix://"+serverSocketPath, nil)
+ require.NoError(t, err)
+ defer cc.Close()
+
+ client := proxytestdata.NewTestServiceClient(cc)
+
+ ctx = opentracing.ContextWithSpan(ctx, span)
+ stream, err := client.PingStream(ctx)
+ require.NoError(t, err)
+
+ require.NoError(t, stream.Send(&proxytestdata.PingRequest{}))
+ require.NoError(t, stream.CloseSend())
+
+ _, err = stream.Recv()
+ require.NoError(t, err)
+
+ span.Finish()
+
+ spans := reporter.GetSpans()
+ require.Len(t, spans, 2)
+ require.Equal(t, "", spans[0].BaggageItem("service"))
+ require.Equal(t, "stub", spans[1].BaggageItem("service"))
+ })
+}
+
// healthServer provide a basic GRPC health service endpoint for testing purposes
type healthServer struct {
}