Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPavlo Strokov <pstrokov@gitlab.com>2020-09-08 13:07:22 +0300
committerZeger-Jan van de Weg <git@zjvandeweg.nl>2020-09-08 13:07:22 +0300
commitd87f0757c880a28b20338fd9533495688eeece95 (patch)
tree35f26f53712adee29f8c92c36a57e00c0b249bcc
parent785bebf7fb05c6c4457c75c011a01b1bede3f3a9 (diff)
Pass correlation_id over to gitaly-ssh
When gitaly-ssh is called it doesn't receive correlation_id. That is why it is hard to understand what to what chain of calls this call is related. correlation_id passed as env var 'CORRELATION_ID' to gitaly-ssh process. So it would be picked by it and used in outgoing requests. In order to cover other missing parts where correlation_id should be passed the client.DialContext method includes interceptors for it by default. If correlation_id is present in the context.Context used to invoke the method it will be passed to the remote. Part of: https://gitlab.com/gitlab-org/gitaly/-/issues/3047
-rw-r--r--changelogs/unreleased/ps-gitaly-ssh-correlation-id.yml5
-rw-r--r--client/dial.go24
-rw-r--r--client/dial_test.go214
-rw-r--r--cmd/gitaly-ssh/main.go16
-rw-r--r--go.mod2
-rw-r--r--internal/gitalyssh/gitalyssh.go25
-rw-r--r--internal/gitalyssh/gitalyssh_test.go23
-rw-r--r--internal/praefect/nodes/manager.go13
-rw-r--r--internal/testhelper/testserver.go16
9 files changed, 300 insertions, 38 deletions
diff --git a/changelogs/unreleased/ps-gitaly-ssh-correlation-id.yml b/changelogs/unreleased/ps-gitaly-ssh-correlation-id.yml
new file mode 100644
index 000000000..49f7d7744
--- /dev/null
+++ b/changelogs/unreleased/ps-gitaly-ssh-correlation-id.yml
@@ -0,0 +1,5 @@
+---
+title: Pass correlation_id over to gitaly-ssh
+merge_request: 2530
+author:
+type: fixed
diff --git a/client/dial.go b/client/dial.go
index 4fce2ac5b..d60138268 100644
--- a/client/dial.go
+++ b/client/dial.go
@@ -8,6 +8,8 @@ import (
"time"
gitaly_x509 "gitlab.com/gitlab-org/gitaly/internal/x509"
+ grpccorrelation "gitlab.com/gitlab-org/labkit/correlation/grpc"
+ grpctracing "gitlab.com/gitlab-org/labkit/tracing/grpc"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
@@ -74,12 +76,22 @@ func DialContext(ctx context.Context, rawAddress string, connOpts []grpc.DialOpt
)
}
- // grpc.KeepaliveParams must be specified at least as large as what is allowed by the
- // server-side grpc.KeepaliveEnforcementPolicy
- connOpts = append(connOpts, grpc.WithKeepaliveParams(keepalive.ClientParameters{
- Time: 20 * time.Second,
- PermitWithoutStream: true,
- }))
+ connOpts = append(connOpts,
+ // grpc.KeepaliveParams must be specified at least as large as what is allowed by the
+ // server-side grpc.KeepaliveEnforcementPolicy
+ grpc.WithKeepaliveParams(keepalive.ClientParameters{
+ Time: 20 * time.Second,
+ PermitWithoutStream: true,
+ }),
+ grpc.WithChainUnaryInterceptor(
+ grpctracing.UnaryClientTracingInterceptor(),
+ grpccorrelation.UnaryClientCorrelationInterceptor(),
+ ),
+ grpc.WithChainStreamInterceptor(
+ grpctracing.StreamClientTracingInterceptor(),
+ grpccorrelation.StreamClientCorrelationInterceptor(),
+ ),
+ )
conn, err := grpc.DialContext(ctx, canonicalAddress, connOpts...)
if err != nil {
diff --git a/client/dial_test.go b/client/dial_test.go
index 05015df62..4df297918 100644
--- a/client/dial_test.go
+++ b/client/dial_test.go
@@ -9,9 +9,16 @@ import (
"strings"
"testing"
+ "github.com/opentracing/opentracing-go"
+ "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ "github.com/uber/jaeger-client-go"
+ proxytestdata "gitlab.com/gitlab-org/gitaly/internal/praefect/grpc-proxy/testdata"
"gitlab.com/gitlab-org/gitaly/internal/testhelper"
gitaly_x509 "gitlab.com/gitlab-org/gitaly/internal/x509"
+ "gitlab.com/gitlab-org/labkit/correlation"
+ grpccorrelation "gitlab.com/gitlab-org/labkit/correlation/grpc"
+ grpctracing "gitlab.com/gitlab-org/labkit/tracing/grpc"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
@@ -126,6 +133,213 @@ func TestDial(t *testing.T) {
}
}
+type testSvc struct {
+ proxytestdata.TestServiceServer
+ PingMethod func(context.Context, *proxytestdata.PingRequest) (*proxytestdata.PingResponse, error)
+ PingStreamMethod func(stream proxytestdata.TestService_PingStreamServer) error
+}
+
+func (ts *testSvc) Ping(ctx context.Context, r *proxytestdata.PingRequest) (*proxytestdata.PingResponse, error) {
+ if ts.PingMethod != nil {
+ return ts.PingMethod(ctx, r)
+ }
+
+ return &proxytestdata.PingResponse{}, nil
+}
+
+func (ts *testSvc) PingStream(stream proxytestdata.TestService_PingStreamServer) error {
+ if ts.PingStreamMethod != nil {
+ return ts.PingStreamMethod(stream)
+ }
+
+ return nil
+}
+
+func TestDial_Correlation(t *testing.T) {
+ t.Run("unary", func(t *testing.T) {
+ serverSocketPath := testhelper.GetTemporaryGitalySocketFileName()
+
+ listener, err := net.Listen("unix", serverSocketPath)
+ require.NoError(t, err)
+
+ grpcServer := grpc.NewServer(grpc.UnaryInterceptor(grpccorrelation.UnaryServerCorrelationInterceptor()))
+ svc := &testSvc{
+ PingMethod: func(ctx context.Context, r *proxytestdata.PingRequest) (*proxytestdata.PingResponse, error) {
+ cid := correlation.ExtractFromContext(ctx)
+ assert.Equal(t, "correlation-id-1", cid)
+ return &proxytestdata.PingResponse{}, nil
+ },
+ }
+ proxytestdata.RegisterTestServiceServer(grpcServer, svc)
+
+ go func() { assert.NoError(t, grpcServer.Serve(listener)) }()
+
+ defer grpcServer.Stop()
+
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ cc, err := DialContext(ctx, "unix://"+serverSocketPath, nil)
+ require.NoError(t, err)
+ defer cc.Close()
+
+ client := proxytestdata.NewTestServiceClient(cc)
+
+ ctx = correlation.ContextWithCorrelation(ctx, "correlation-id-1")
+ _, err = client.Ping(ctx, &proxytestdata.PingRequest{})
+ require.NoError(t, err)
+ })
+
+ t.Run("stream", func(t *testing.T) {
+ serverSocketPath := testhelper.GetTemporaryGitalySocketFileName()
+
+ listener, err := net.Listen("unix", serverSocketPath)
+ require.NoError(t, err)
+
+ grpcServer := grpc.NewServer(grpc.StreamInterceptor(grpccorrelation.StreamServerCorrelationInterceptor()))
+ svc := &testSvc{
+ PingStreamMethod: func(stream proxytestdata.TestService_PingStreamServer) error {
+ cid := correlation.ExtractFromContext(stream.Context())
+ assert.Equal(t, "correlation-id-1", cid)
+ return stream.Send(&proxytestdata.PingResponse{})
+ },
+ }
+ proxytestdata.RegisterTestServiceServer(grpcServer, svc)
+
+ go func() { assert.NoError(t, grpcServer.Serve(listener)) }()
+ defer grpcServer.Stop()
+
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ cc, err := DialContext(ctx, "unix://"+serverSocketPath, nil)
+ require.NoError(t, err)
+ defer cc.Close()
+
+ client := proxytestdata.NewTestServiceClient(cc)
+
+ ctx = correlation.ContextWithCorrelation(ctx, "correlation-id-1")
+ stream, err := client.PingStream(ctx)
+ require.NoError(t, err)
+
+ require.NoError(t, stream.Send(&proxytestdata.PingRequest{}))
+ require.NoError(t, stream.CloseSend())
+
+ _, err = stream.Recv()
+ require.NoError(t, err)
+ })
+}
+
+func TestDial_Tracing(t *testing.T) {
+ t.Run("unary", func(t *testing.T) {
+ serverSocketPath := testhelper.GetTemporaryGitalySocketFileName()
+
+ listener, err := net.Listen("unix", serverSocketPath)
+ require.NoError(t, err)
+
+ grpcServer := grpc.NewServer(grpc.UnaryInterceptor(grpctracing.UnaryServerTracingInterceptor()))
+ svc := &testSvc{
+ PingMethod: func(ctx context.Context, r *proxytestdata.PingRequest) (*proxytestdata.PingResponse, error) {
+ span, _ := opentracing.StartSpanFromContext(ctx, "health")
+ defer span.Finish()
+ span.LogKV("was", "called")
+ return &proxytestdata.PingResponse{}, nil
+ },
+ }
+ proxytestdata.RegisterTestServiceServer(grpcServer, svc)
+
+ go func() { assert.NoError(t, grpcServer.Serve(listener)) }()
+ defer grpcServer.Stop()
+
+ reporter := jaeger.NewInMemoryReporter()
+ tracer, closer := jaeger.NewTracer("", jaeger.NewConstSampler(true), reporter)
+ defer closer.Close()
+
+ defer func(old opentracing.Tracer) { opentracing.SetGlobalTracer(old) }(opentracing.GlobalTracer())
+ opentracing.SetGlobalTracer(tracer)
+
+ span := tracer.StartSpan("unary-check")
+ span = span.SetBaggageItem("service", "stub")
+
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ cc, err := DialContext(ctx, "unix://"+serverSocketPath, nil)
+ require.NoError(t, err)
+ defer cc.Close()
+
+ client := proxytestdata.NewTestServiceClient(cc)
+
+ ctx = opentracing.ContextWithSpan(ctx, span)
+ _, err = client.Ping(ctx, &proxytestdata.PingRequest{})
+ require.NoError(t, err)
+
+ span.Finish()
+
+ spans := reporter.GetSpans()
+ require.Len(t, spans, 3)
+ require.Equal(t, "stub", spans[1].BaggageItem("service"))
+ require.Equal(t, "stub", spans[2].BaggageItem("service"))
+ })
+
+ t.Run("stream", func(t *testing.T) {
+ serverSocketPath := testhelper.GetTemporaryGitalySocketFileName()
+
+ listener, err := net.Listen("unix", serverSocketPath)
+ require.NoError(t, err)
+
+ grpcServer := grpc.NewServer(grpc.StreamInterceptor(grpctracing.StreamServerTracingInterceptor()))
+ svc := &testSvc{
+ PingStreamMethod: func(stream proxytestdata.TestService_PingStreamServer) error {
+ span, _ := opentracing.StartSpanFromContext(stream.Context(), "health")
+ defer span.Finish()
+ span.LogKV("was", "called")
+ return stream.Send(&proxytestdata.PingResponse{})
+ },
+ }
+ proxytestdata.RegisterTestServiceServer(grpcServer, svc)
+
+ go func() { assert.NoError(t, grpcServer.Serve(listener)) }()
+ defer grpcServer.Stop()
+
+ reporter := jaeger.NewInMemoryReporter()
+ tracer, closer := jaeger.NewTracer("", jaeger.NewConstSampler(true), reporter)
+ defer closer.Close()
+
+ defer func(old opentracing.Tracer) { opentracing.SetGlobalTracer(old) }(opentracing.GlobalTracer())
+ opentracing.SetGlobalTracer(tracer)
+
+ span := tracer.StartSpan("stream-check")
+ span = span.SetBaggageItem("service", "stub")
+
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ cc, err := DialContext(ctx, "unix://"+serverSocketPath, nil)
+ require.NoError(t, err)
+ defer cc.Close()
+
+ client := proxytestdata.NewTestServiceClient(cc)
+
+ ctx = opentracing.ContextWithSpan(ctx, span)
+ stream, err := client.PingStream(ctx)
+ require.NoError(t, err)
+
+ require.NoError(t, stream.Send(&proxytestdata.PingRequest{}))
+ require.NoError(t, stream.CloseSend())
+
+ _, err = stream.Recv()
+ require.NoError(t, err)
+
+ span.Finish()
+
+ spans := reporter.GetSpans()
+ require.Len(t, spans, 2)
+ require.Equal(t, "", spans[0].BaggageItem("service"))
+ require.Equal(t, "stub", spans[1].BaggageItem("service"))
+ })
+}
+
// healthServer provide a basic GRPC health service endpoint for testing purposes
type healthServer struct {
}
diff --git a/cmd/gitaly-ssh/main.go b/cmd/gitaly-ssh/main.go
index d611da5c7..1a99ae849 100644
--- a/cmd/gitaly-ssh/main.go
+++ b/cmd/gitaly-ssh/main.go
@@ -7,13 +7,10 @@ import (
"os"
"strings"
- grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
gitalyauth "gitlab.com/gitlab-org/gitaly/auth"
"gitlab.com/gitlab-org/gitaly/client"
"gitlab.com/gitlab-org/gitaly/internal/metadata/featureflag"
- grpccorrelation "gitlab.com/gitlab-org/labkit/correlation/grpc"
"gitlab.com/gitlab-org/labkit/tracing"
- grpctracing "gitlab.com/gitlab-org/labkit/tracing/grpc"
"google.golang.org/grpc"
)
@@ -128,18 +125,5 @@ func dialOpts() []grpc.DialOption {
connOpts = append(connOpts, grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2(token)))
}
- // Add grpc client interceptors
- connOpts = append(connOpts, grpc.WithStreamInterceptor(
- grpc_middleware.ChainStreamClient(
- grpctracing.StreamClientTracingInterceptor(), // Tracing
- grpccorrelation.StreamClientCorrelationInterceptor(), // Correlation
- )),
-
- grpc.WithUnaryInterceptor(
- grpc_middleware.ChainUnaryClient(
- grpctracing.UnaryClientTracingInterceptor(), // Tracing
- grpccorrelation.UnaryClientCorrelationInterceptor(), // Correlation
- )))
-
return connOpts
}
diff --git a/go.mod b/go.mod
index 7c6a3a8f3..e7d51f7da 100644
--- a/go.mod
+++ b/go.mod
@@ -13,11 +13,13 @@ require (
github.com/lib/pq v1.2.0
github.com/libgit2/git2go/v30 v30.0.5
github.com/olekukonko/tablewriter v0.0.2
+ github.com/opentracing/opentracing-go v1.0.2
github.com/prometheus/client_golang v1.0.0
github.com/prometheus/procfs v0.0.3 // indirect
github.com/rubenv/sql-migrate v0.0.0-20191213152630-06338513c237
github.com/sirupsen/logrus v1.4.2
github.com/stretchr/testify v1.4.0
+ github.com/uber/jaeger-client-go v2.15.0+incompatible
gitlab.com/gitlab-org/gitlab-shell v0.0.0-20200821152636-82ec8144fb2a
gitlab.com/gitlab-org/labkit v0.0.0-20200507062444-0149780c759d
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e
diff --git a/internal/gitalyssh/gitalyssh.go b/internal/gitalyssh/gitalyssh.go
index 0ffda1e69..0300100ac 100644
--- a/internal/gitalyssh/gitalyssh.go
+++ b/internal/gitalyssh/gitalyssh.go
@@ -3,9 +3,11 @@ package gitalyssh
import (
"context"
"fmt"
+ "math/rand"
"os"
"path/filepath"
"strings"
+ "time"
"github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/proto"
@@ -14,12 +16,16 @@ import (
"gitlab.com/gitlab-org/gitaly/internal/metadata/featureflag"
gitaly_x509 "gitlab.com/gitlab-org/gitaly/internal/x509"
"gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+ "gitlab.com/gitlab-org/labkit/correlation"
"gitlab.com/gitlab-org/labkit/tracing"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
-var envInjector = tracing.NewEnvInjector()
+var (
+ envInjector = tracing.NewEnvInjector()
+ correlationIDRand = rand.New(rand.NewSource(time.Now().UnixNano()))
+)
func UploadPackEnv(ctx context.Context, req *gitalypb.SSHUploadPackRequest) ([]string, error) {
env, err := commandEnv(ctx, req.Repository.StorageName, "upload-pack", req)
@@ -61,6 +67,7 @@ func commandEnv(ctx context.Context, storageName, command string, message proto.
fmt.Sprintf("GITALY_ADDRESS=%s", address),
fmt.Sprintf("GITALY_TOKEN=%s", token),
fmt.Sprintf("GITALY_FEATUREFLAGS=%s", strings.Join(featureFlagPairs, ",")),
+ fmt.Sprintf("CORRELATION_ID=%s", getCorrelationID(ctx)),
// Pass through the SSL_CERT_* variables that indicate which
// system certs to trust
fmt.Sprintf("%s=%s", gitaly_x509.SSLCertDir, os.Getenv(gitaly_x509.SSLCertDir)),
@@ -71,3 +78,19 @@ func commandEnv(ctx context.Context, storageName, command string, message proto.
func gitalySSHPath() string {
return filepath.Join(config.Config.BinDir, "gitaly-ssh")
}
+
+func getCorrelationID(ctx context.Context) string {
+ correlationID := correlation.ExtractFromContext(ctx)
+ if correlationID != "" {
+ return correlationID
+ }
+
+ correlationID, _ = correlation.RandomID()
+ if correlationID == "" {
+ source := []byte("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
+ correlationIDRand.Shuffle(len(source), func(i, j int) { source[i], source[j] = source[j], source[i] })
+ return correlationID[:32]
+ }
+
+ return correlationID
+}
diff --git a/internal/gitalyssh/gitalyssh_test.go b/internal/gitalyssh/gitalyssh_test.go
index 5c7cb93c4..e4e926b9a 100644
--- a/internal/gitalyssh/gitalyssh_test.go
+++ b/internal/gitalyssh/gitalyssh_test.go
@@ -1,6 +1,7 @@
package gitalyssh
import (
+ "context"
"encoding/base64"
"fmt"
"path/filepath"
@@ -11,6 +12,7 @@ import (
"gitlab.com/gitlab-org/gitaly/internal/config"
"gitlab.com/gitlab-org/gitaly/internal/testhelper"
"gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+ "gitlab.com/gitlab-org/labkit/correlation"
"google.golang.org/grpc/metadata"
)
@@ -23,6 +25,7 @@ func TestUploadPackEnv(t *testing.T) {
md := metadata.Pairs("gitaly-servers", base64.StdEncoding.EncodeToString([]byte(`{"default":{"address":"unix:///tmp/sock","token":"hunter1"}}`)))
ctx = metadata.NewIncomingContext(ctx, md)
+ ctx = correlation.ContextWithCorrelation(ctx, "correlation-id-1")
req := gitalypb.SSHUploadPackRequest{
Repository: testRepo,
@@ -38,5 +41,25 @@ func TestUploadPackEnv(t *testing.T) {
require.Subset(t, env, []string{
fmt.Sprintf("GIT_SSH_COMMAND=%s upload-pack", filepath.Join(config.Config.BinDir, "gitaly-ssh")),
fmt.Sprintf("GITALY_PAYLOAD=%s", expectedPayload),
+ "CORRELATION_ID=correlation-id-1",
+ })
+}
+
+func TestGetCorrelationID(t *testing.T) {
+ t.Run("not provided in context", func(t *testing.T) {
+ ctx := context.Background()
+ cid1 := getCorrelationID(ctx)
+ require.NotEmpty(t, cid1)
+
+ cid2 := getCorrelationID(ctx)
+ require.NotEqual(t, cid1, cid2, "it should return a new correlation_id each time as it is not injected into the context")
+ })
+
+ t.Run("provided in context", func(t *testing.T) {
+ const cid = "1-2-3-4"
+ ctx := correlation.ContextWithCorrelation(context.Background(), cid)
+
+ require.Equal(t, cid, getCorrelationID(ctx))
+ require.Equal(t, cid, getCorrelationID(ctx))
})
}
diff --git a/internal/praefect/nodes/manager.go b/internal/praefect/nodes/manager.go
index 6b3ab19fb..20c50201c 100644
--- a/internal/praefect/nodes/manager.go
+++ b/internal/praefect/nodes/manager.go
@@ -9,7 +9,6 @@ import (
"sync"
"time"
- grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
"github.com/sirupsen/logrus"
@@ -24,8 +23,6 @@ import (
"gitlab.com/gitlab-org/gitaly/internal/praefect/nodes/tracker"
"gitlab.com/gitlab-org/gitaly/internal/praefect/protoregistry"
prommetrics "gitlab.com/gitlab-org/gitaly/internal/prometheus/metrics"
- correlation "gitlab.com/gitlab-org/labkit/correlation/grpc"
- grpctracing "gitlab.com/gitlab-org/labkit/tracing/grpc"
"google.golang.org/grpc"
healthpb "google.golang.org/grpc/health/grpc_health_v1"
)
@@ -137,8 +134,6 @@ func NewManager(
for _, node := range virtualStorage.Nodes {
streamInterceptors := []grpc.StreamClientInterceptor{
grpc_prometheus.StreamClientInterceptor,
- grpctracing.StreamClientTracingInterceptor(),
- correlation.StreamClientCorrelationInterceptor(),
}
if c.Failover.Enabled && errorTracker != nil {
@@ -150,12 +145,8 @@ func NewManager(
[]grpc.DialOption{
grpc.WithDefaultCallOptions(grpc.ForceCodec(proxy.NewCodec())),
grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2(node.Token)),
- grpc.WithStreamInterceptor(grpc_middleware.ChainStreamClient(streamInterceptors...)),
- grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient(
- grpc_prometheus.UnaryClientInterceptor,
- grpctracing.UnaryClientTracingInterceptor(),
- correlation.UnaryClientCorrelationInterceptor(),
- )),
+ grpc.WithChainStreamInterceptor(streamInterceptors...),
+ grpc.WithChainUnaryInterceptor(grpc_prometheus.UnaryClientInterceptor),
}, dialOpts...),
)
if err != nil {
diff --git a/internal/testhelper/testserver.go b/internal/testhelper/testserver.go
index 8ccbcb6eb..c4c112e18 100644
--- a/internal/testhelper/testserver.go
+++ b/internal/testhelper/testserver.go
@@ -35,6 +35,7 @@ import (
praefectconfig "gitlab.com/gitlab-org/gitaly/internal/praefect/config"
serverauth "gitlab.com/gitlab-org/gitaly/internal/server/auth"
"gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+ grpccorrelation "gitlab.com/gitlab-org/labkit/correlation/grpc"
"google.golang.org/grpc"
"google.golang.org/grpc/health"
healthpb "google.golang.org/grpc/health/grpc_health_v1"
@@ -298,11 +299,18 @@ func NewServer(tb testing.TB, streamInterceptors []grpc.StreamServerInterceptor,
logrusEntry := log.NewEntry(logger).WithField("test", tb.Name())
ctxTagger := grpc_ctxtags.WithFieldExtractorForInitialReq(fieldextractors.FieldExtractor)
- ctxStreamTagger := grpc_ctxtags.StreamServerInterceptor(ctxTagger)
- ctxUnaryTagger := grpc_ctxtags.UnaryServerInterceptor(ctxTagger)
- streamInterceptors = append([]grpc.StreamServerInterceptor{ctxStreamTagger, grpc_logrus.StreamServerInterceptor(logrusEntry)}, streamInterceptors...)
- unaryInterceptors = append([]grpc.UnaryServerInterceptor{ctxUnaryTagger, grpc_logrus.UnaryServerInterceptor(logrusEntry)}, unaryInterceptors...)
+ streamInterceptors = append([]grpc.StreamServerInterceptor{
+ grpc_ctxtags.StreamServerInterceptor(ctxTagger),
+ grpccorrelation.StreamServerCorrelationInterceptor(),
+ grpc_logrus.StreamServerInterceptor(logrusEntry),
+ }, streamInterceptors...)
+
+ unaryInterceptors = append([]grpc.UnaryServerInterceptor{
+ grpc_ctxtags.UnaryServerInterceptor(ctxTagger),
+ grpccorrelation.UnaryServerCorrelationInterceptor(),
+ grpc_logrus.UnaryServerInterceptor(logrusEntry),
+ }, unaryInterceptors...)
return NewTestServer(
grpc.NewServer(