diff options
Diffstat (limited to 'client')
-rw-r--r-- | client/dial.go | 24 | ||||
-rw-r--r-- | client/dial_test.go | 214 |
2 files changed, 232 insertions, 6 deletions
diff --git a/client/dial.go b/client/dial.go index 4fce2ac5b..d60138268 100644 --- a/client/dial.go +++ b/client/dial.go @@ -8,6 +8,8 @@ import ( "time" gitaly_x509 "gitlab.com/gitlab-org/gitaly/internal/x509" + grpccorrelation "gitlab.com/gitlab-org/labkit/correlation/grpc" + grpctracing "gitlab.com/gitlab-org/labkit/tracing/grpc" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" @@ -74,12 +76,22 @@ func DialContext(ctx context.Context, rawAddress string, connOpts []grpc.DialOpt ) } - // grpc.KeepaliveParams must be specified at least as large as what is allowed by the - // server-side grpc.KeepaliveEnforcementPolicy - connOpts = append(connOpts, grpc.WithKeepaliveParams(keepalive.ClientParameters{ - Time: 20 * time.Second, - PermitWithoutStream: true, - })) + connOpts = append(connOpts, + // grpc.KeepaliveParams must be specified at least as large as what is allowed by the + // server-side grpc.KeepaliveEnforcementPolicy + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: 20 * time.Second, + PermitWithoutStream: true, + }), + grpc.WithChainUnaryInterceptor( + grpctracing.UnaryClientTracingInterceptor(), + grpccorrelation.UnaryClientCorrelationInterceptor(), + ), + grpc.WithChainStreamInterceptor( + grpctracing.StreamClientTracingInterceptor(), + grpccorrelation.StreamClientCorrelationInterceptor(), + ), + ) conn, err := grpc.DialContext(ctx, canonicalAddress, connOpts...) if err != nil { diff --git a/client/dial_test.go b/client/dial_test.go index 05015df62..4df297918 100644 --- a/client/dial_test.go +++ b/client/dial_test.go @@ -9,9 +9,16 @@ import ( "strings" "testing" + "github.com/opentracing/opentracing-go" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/uber/jaeger-client-go" + proxytestdata "gitlab.com/gitlab-org/gitaly/internal/praefect/grpc-proxy/testdata" "gitlab.com/gitlab-org/gitaly/internal/testhelper" gitaly_x509 "gitlab.com/gitlab-org/gitaly/internal/x509" + "gitlab.com/gitlab-org/labkit/correlation" + grpccorrelation "gitlab.com/gitlab-org/labkit/correlation/grpc" + grpctracing "gitlab.com/gitlab-org/labkit/tracing/grpc" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" @@ -126,6 +133,213 @@ func TestDial(t *testing.T) { } } +type testSvc struct { + proxytestdata.TestServiceServer + PingMethod func(context.Context, *proxytestdata.PingRequest) (*proxytestdata.PingResponse, error) + PingStreamMethod func(stream proxytestdata.TestService_PingStreamServer) error +} + +func (ts *testSvc) Ping(ctx context.Context, r *proxytestdata.PingRequest) (*proxytestdata.PingResponse, error) { + if ts.PingMethod != nil { + return ts.PingMethod(ctx, r) + } + + return &proxytestdata.PingResponse{}, nil +} + +func (ts *testSvc) PingStream(stream proxytestdata.TestService_PingStreamServer) error { + if ts.PingStreamMethod != nil { + return ts.PingStreamMethod(stream) + } + + return nil +} + +func TestDial_Correlation(t *testing.T) { + t.Run("unary", func(t *testing.T) { + serverSocketPath := testhelper.GetTemporaryGitalySocketFileName() + + listener, err := net.Listen("unix", serverSocketPath) + require.NoError(t, err) + + grpcServer := grpc.NewServer(grpc.UnaryInterceptor(grpccorrelation.UnaryServerCorrelationInterceptor())) + svc := &testSvc{ + PingMethod: func(ctx context.Context, r *proxytestdata.PingRequest) (*proxytestdata.PingResponse, error) { + cid := correlation.ExtractFromContext(ctx) + assert.Equal(t, "correlation-id-1", cid) + return &proxytestdata.PingResponse{}, nil + }, + } + proxytestdata.RegisterTestServiceServer(grpcServer, svc) + + go func() { assert.NoError(t, grpcServer.Serve(listener)) }() + + defer grpcServer.Stop() + + ctx, cancel := testhelper.Context() + defer cancel() + + cc, err := DialContext(ctx, "unix://"+serverSocketPath, nil) + require.NoError(t, err) + defer cc.Close() + + client := proxytestdata.NewTestServiceClient(cc) + + ctx = correlation.ContextWithCorrelation(ctx, "correlation-id-1") + _, err = client.Ping(ctx, &proxytestdata.PingRequest{}) + require.NoError(t, err) + }) + + t.Run("stream", func(t *testing.T) { + serverSocketPath := testhelper.GetTemporaryGitalySocketFileName() + + listener, err := net.Listen("unix", serverSocketPath) + require.NoError(t, err) + + grpcServer := grpc.NewServer(grpc.StreamInterceptor(grpccorrelation.StreamServerCorrelationInterceptor())) + svc := &testSvc{ + PingStreamMethod: func(stream proxytestdata.TestService_PingStreamServer) error { + cid := correlation.ExtractFromContext(stream.Context()) + assert.Equal(t, "correlation-id-1", cid) + return stream.Send(&proxytestdata.PingResponse{}) + }, + } + proxytestdata.RegisterTestServiceServer(grpcServer, svc) + + go func() { assert.NoError(t, grpcServer.Serve(listener)) }() + defer grpcServer.Stop() + + ctx, cancel := testhelper.Context() + defer cancel() + + cc, err := DialContext(ctx, "unix://"+serverSocketPath, nil) + require.NoError(t, err) + defer cc.Close() + + client := proxytestdata.NewTestServiceClient(cc) + + ctx = correlation.ContextWithCorrelation(ctx, "correlation-id-1") + stream, err := client.PingStream(ctx) + require.NoError(t, err) + + require.NoError(t, stream.Send(&proxytestdata.PingRequest{})) + require.NoError(t, stream.CloseSend()) + + _, err = stream.Recv() + require.NoError(t, err) + }) +} + +func TestDial_Tracing(t *testing.T) { + t.Run("unary", func(t *testing.T) { + serverSocketPath := testhelper.GetTemporaryGitalySocketFileName() + + listener, err := net.Listen("unix", serverSocketPath) + require.NoError(t, err) + + grpcServer := grpc.NewServer(grpc.UnaryInterceptor(grpctracing.UnaryServerTracingInterceptor())) + svc := &testSvc{ + PingMethod: func(ctx context.Context, r *proxytestdata.PingRequest) (*proxytestdata.PingResponse, error) { + span, _ := opentracing.StartSpanFromContext(ctx, "health") + defer span.Finish() + span.LogKV("was", "called") + return &proxytestdata.PingResponse{}, nil + }, + } + proxytestdata.RegisterTestServiceServer(grpcServer, svc) + + go func() { assert.NoError(t, grpcServer.Serve(listener)) }() + defer grpcServer.Stop() + + reporter := jaeger.NewInMemoryReporter() + tracer, closer := jaeger.NewTracer("", jaeger.NewConstSampler(true), reporter) + defer closer.Close() + + defer func(old opentracing.Tracer) { opentracing.SetGlobalTracer(old) }(opentracing.GlobalTracer()) + opentracing.SetGlobalTracer(tracer) + + span := tracer.StartSpan("unary-check") + span = span.SetBaggageItem("service", "stub") + + ctx, cancel := testhelper.Context() + defer cancel() + + cc, err := DialContext(ctx, "unix://"+serverSocketPath, nil) + require.NoError(t, err) + defer cc.Close() + + client := proxytestdata.NewTestServiceClient(cc) + + ctx = opentracing.ContextWithSpan(ctx, span) + _, err = client.Ping(ctx, &proxytestdata.PingRequest{}) + require.NoError(t, err) + + span.Finish() + + spans := reporter.GetSpans() + require.Len(t, spans, 3) + require.Equal(t, "stub", spans[1].BaggageItem("service")) + require.Equal(t, "stub", spans[2].BaggageItem("service")) + }) + + t.Run("stream", func(t *testing.T) { + serverSocketPath := testhelper.GetTemporaryGitalySocketFileName() + + listener, err := net.Listen("unix", serverSocketPath) + require.NoError(t, err) + + grpcServer := grpc.NewServer(grpc.StreamInterceptor(grpctracing.StreamServerTracingInterceptor())) + svc := &testSvc{ + PingStreamMethod: func(stream proxytestdata.TestService_PingStreamServer) error { + span, _ := opentracing.StartSpanFromContext(stream.Context(), "health") + defer span.Finish() + span.LogKV("was", "called") + return stream.Send(&proxytestdata.PingResponse{}) + }, + } + proxytestdata.RegisterTestServiceServer(grpcServer, svc) + + go func() { assert.NoError(t, grpcServer.Serve(listener)) }() + defer grpcServer.Stop() + + reporter := jaeger.NewInMemoryReporter() + tracer, closer := jaeger.NewTracer("", jaeger.NewConstSampler(true), reporter) + defer closer.Close() + + defer func(old opentracing.Tracer) { opentracing.SetGlobalTracer(old) }(opentracing.GlobalTracer()) + opentracing.SetGlobalTracer(tracer) + + span := tracer.StartSpan("stream-check") + span = span.SetBaggageItem("service", "stub") + + ctx, cancel := testhelper.Context() + defer cancel() + + cc, err := DialContext(ctx, "unix://"+serverSocketPath, nil) + require.NoError(t, err) + defer cc.Close() + + client := proxytestdata.NewTestServiceClient(cc) + + ctx = opentracing.ContextWithSpan(ctx, span) + stream, err := client.PingStream(ctx) + require.NoError(t, err) + + require.NoError(t, stream.Send(&proxytestdata.PingRequest{})) + require.NoError(t, stream.CloseSend()) + + _, err = stream.Recv() + require.NoError(t, err) + + span.Finish() + + spans := reporter.GetSpans() + require.Len(t, spans, 2) + require.Equal(t, "", spans[0].BaggageItem("service")) + require.Equal(t, "stub", spans[1].BaggageItem("service")) + }) +} + // healthServer provide a basic GRPC health service endpoint for testing purposes type healthServer struct { } |