Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorToon Claes <toon@gitlab.com>2021-11-16 15:53:30 +0300
committerToon Claes <toon@gitlab.com>2021-11-16 15:53:30 +0300
commitbb7fa728e9206c7b2f9ede313f3db7b616af9350 (patch)
treefea7680412f6fa1247f38fc1d11fc9dd709441df
parent724e844c814b91bb4a98f7eb3607756b7504657d (diff)
parentd56f6c64667e53107f6c40f3af4698e1f620290f (diff)
Merge branch 'ps-track-payload-size' into 'master'
Track payload bytes for RPC Closes #3867 See merge request gitlab-org/gitaly!4030
-rw-r--r--internal/gitaly/server/server.go21
-rw-r--r--internal/gitaly/server/server_factory_test.go25
-rw-r--r--internal/grpcstats/stats.go75
-rw-r--r--internal/grpcstats/stats_test.go77
-rw-r--r--internal/grpcstats/testhelper_test.go11
-rw-r--r--internal/log/log.go152
-rw-r--r--internal/log/log_test.go343
-rw-r--r--internal/middleware/commandstatshandler/commandstatshandler.go23
-rw-r--r--internal/middleware/commandstatshandler/commandstatshandler_test.go16
9 files changed, 720 insertions, 23 deletions
diff --git a/internal/gitaly/server/server.go b/internal/gitaly/server/server.go
index ffdf55c10..9dcf07aa4 100644
--- a/internal/gitaly/server/server.go
+++ b/internal/gitaly/server/server.go
@@ -16,6 +16,7 @@ import (
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/client"
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/config"
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/server/auth"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/grpcstats"
"gitlab.com/gitlab-org/gitaly/v14/internal/helper/fieldextractors"
"gitlab.com/gitlab-org/gitaly/v14/internal/listenmux"
gitalylog "gitlab.com/gitlab-org/gitaly/v14/internal/log"
@@ -103,7 +104,19 @@ func New(
[]grpc.DialOption{client.UnaryInterceptor()},
))
+ logMsgProducer := grpcmwlogrus.WithMessageProducer(
+ gitalylog.MessageProducer(
+ gitalylog.PropagationMessageProducer(grpcmwlogrus.DefaultMessageProducer),
+ commandstatshandler.FieldsProducer,
+ grpcstats.FieldsProducer,
+ ),
+ )
+
opts := []grpc.ServerOption{
+ grpc.StatsHandler(gitalylog.PerRPCLogHandler{
+ Underlying: &grpcstats.PayloadBytes{},
+ FieldProducers: []gitalylog.FieldsProducer{grpcstats.FieldsProducer},
+ }),
grpc.Creds(lm),
grpc.StreamInterceptor(grpcmw.ChainStreamServer(
grpcmwtags.StreamServerInterceptor(ctxTagOpts...),
@@ -113,7 +126,9 @@ func New(
commandstatshandler.StreamInterceptor,
grpcmwlogrus.StreamServerInterceptor(logrusEntry,
grpcmwlogrus.WithTimestampFormat(gitalylog.LogTimestampFormat),
- grpcmwlogrus.WithMessageProducer(commandstatshandler.CommandStatsMessageProducer)),
+ logMsgProducer,
+ ),
+ gitalylog.StreamLogDataCatcherServerInterceptor(),
sentryhandler.StreamLogHandler,
cancelhandler.Stream, // Should be below LogHandler
auth.StreamServerInterceptor(cfg.Auth),
@@ -132,7 +147,9 @@ func New(
commandstatshandler.UnaryInterceptor,
grpcmwlogrus.UnaryServerInterceptor(logrusEntry,
grpcmwlogrus.WithTimestampFormat(gitalylog.LogTimestampFormat),
- grpcmwlogrus.WithMessageProducer(commandstatshandler.CommandStatsMessageProducer)),
+ logMsgProducer,
+ ),
+ gitalylog.UnaryLogDataCatcherServerInterceptor(),
sentryhandler.UnaryLogHandler,
cancelhandler.Unary, // Should be below LogHandler
auth.UnaryServerInterceptor(cfg.Auth),
diff --git a/internal/gitaly/server/server_factory_test.go b/internal/gitaly/server/server_factory_test.go
index 369cf3775..a9ad14696 100644
--- a/internal/gitaly/server/server_factory_test.go
+++ b/internal/gitaly/server/server_factory_test.go
@@ -9,6 +9,8 @@ import (
"os"
"testing"
+ "github.com/sirupsen/logrus"
+ "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitaly/v14/client"
@@ -125,6 +127,29 @@ func TestGitalyServerFactory(t *testing.T) {
_, socketErr := socketHealthClient.Check(ctx, &healthpb.HealthCheckRequest{})
require.Equal(t, codes.Unavailable, status.Code(socketErr))
})
+
+ t.Run("logging check", func(t *testing.T) {
+ cfg := testcfg.Build(t)
+ logger, hook := test.NewNullLogger()
+ sf := NewGitalyServerFactory(cfg, logger.WithContext(ctx), backchannel.NewRegistry(), cache.New(cfg, config.NewLocator(cfg)))
+
+ checkHealth(t, sf, starter.TCP, "localhost:0")
+
+ var entry *logrus.Entry
+ for _, e := range hook.AllEntries() {
+ if e.Message == "finished unary call with code OK" {
+ entry = e
+ break
+ }
+ }
+ require.NotNil(t, entry)
+ reqSize, found := entry.Data["grpc.request.payload_bytes"]
+ assert.EqualValues(t, 0, reqSize)
+ require.True(t, found)
+ respSize, found := entry.Data["grpc.response.payload_bytes"]
+ assert.EqualValues(t, 2, respSize)
+ require.True(t, found)
+ })
}
func TestGitalyServerFactory_closeOrder(t *testing.T) {
diff --git a/internal/grpcstats/stats.go b/internal/grpcstats/stats.go
new file mode 100644
index 000000000..f656aead5
--- /dev/null
+++ b/internal/grpcstats/stats.go
@@ -0,0 +1,75 @@
+package grpcstats
+
+import (
+ "context"
+
+ "github.com/sirupsen/logrus"
+ "google.golang.org/grpc/stats"
+)
+
+// PayloadBytes implements stats.Handler and tracks amount of bytes received and send by gRPC service
+// for each method call. The information about statistics is added into the context and can be
+// extracted with payloadBytesStatsFromContext.
+type PayloadBytes struct{}
+
+type payloadBytesStatsKey struct{}
+
+// HandleConn exists to satisfy gRPC stats.Handler.
+func (s *PayloadBytes) HandleConn(context.Context, stats.ConnStats) {}
+
+// TagConn exists to satisfy gRPC stats.Handler. We don't gather connection level stats
+// and are thus not currently using it.
+func (s *PayloadBytes) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context {
+ return ctx
+}
+
+// HandleRPC implements per-RPC tracing and stats instrumentation.
+func (s *PayloadBytes) HandleRPC(ctx context.Context, rs stats.RPCStats) {
+ switch st := rs.(type) {
+ case *stats.InPayload:
+ bytesStats := ctx.Value(payloadBytesStatsKey{}).(*PayloadBytesStats)
+ bytesStats.InPayloadBytes += int64(st.Length)
+ case *stats.OutPayload:
+ bytesStats := ctx.Value(payloadBytesStatsKey{}).(*PayloadBytesStats)
+ bytesStats.OutPayloadBytes += int64(st.Length)
+ }
+}
+
+// TagRPC initializes context with an RPC specific stats collector.
+// The returned context will be used in method invocation as is passed into HandleRPC.
+func (s *PayloadBytes) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context {
+ return context.WithValue(ctx, payloadBytesStatsKey{}, new(PayloadBytesStats))
+}
+
+// PayloadBytesStats contains info about bytes received and sent by the gRPC method.
+type PayloadBytesStats struct {
+ // InPayloadBytes amount of bytes received.
+ InPayloadBytes int64
+ // OutPayloadBytes amount of bytes sent.
+ OutPayloadBytes int64
+}
+
+// Fields returns logging info.
+func (s *PayloadBytesStats) Fields() logrus.Fields {
+ return logrus.Fields{
+ "grpc.request.payload_bytes": s.InPayloadBytes,
+ "grpc.response.payload_bytes": s.OutPayloadBytes,
+ }
+}
+
+// FieldsProducer extracts stats info from the context and returns it as a logging fields.
+func FieldsProducer(ctx context.Context) logrus.Fields {
+ payloadBytesStats := payloadBytesStatsFromContext(ctx)
+ if payloadBytesStats != nil {
+ return payloadBytesStats.Fields()
+ }
+ return nil
+}
+
+func payloadBytesStatsFromContext(ctx context.Context) *PayloadBytesStats {
+ v, ok := ctx.Value(payloadBytesStatsKey{}).(*PayloadBytesStats)
+ if !ok {
+ return nil
+ }
+ return v
+}
diff --git a/internal/grpcstats/stats_test.go b/internal/grpcstats/stats_test.go
new file mode 100644
index 000000000..84d9b349c
--- /dev/null
+++ b/internal/grpcstats/stats_test.go
@@ -0,0 +1,77 @@
+package grpcstats
+
+import (
+ "testing"
+
+ "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
+ "google.golang.org/grpc/stats"
+)
+
+func TestPayloadBytes_TagRPC(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+ ctx = (&PayloadBytes{}).TagRPC(ctx, nil)
+ require.Equal(t,
+ logrus.Fields{"grpc.request.payload_bytes": int64(0), "grpc.response.payload_bytes": int64(0)},
+ FieldsProducer(ctx),
+ )
+}
+
+func TestPayloadBytes_HandleRPC(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+ handler := &PayloadBytes{}
+ ctx = handler.TagRPC(ctx, nil)
+ handler.HandleRPC(ctx, nil) // sanity check we don't fail anything
+ handler.HandleRPC(ctx, &stats.Begin{}) // sanity check we don't fail anything
+ handler.HandleRPC(ctx, &stats.InPayload{Length: 42})
+ require.Equal(t,
+ logrus.Fields{"grpc.request.payload_bytes": int64(42), "grpc.response.payload_bytes": int64(0)},
+ FieldsProducer(ctx),
+ )
+ handler.HandleRPC(ctx, &stats.OutPayload{Length: 24})
+ require.Equal(t,
+ logrus.Fields{"grpc.request.payload_bytes": int64(42), "grpc.response.payload_bytes": int64(24)},
+ FieldsProducer(ctx),
+ )
+ handler.HandleRPC(ctx, &stats.InPayload{Length: 38})
+ require.Equal(t,
+ logrus.Fields{"grpc.request.payload_bytes": int64(80), "grpc.response.payload_bytes": int64(24)},
+ FieldsProducer(ctx),
+ )
+ handler.HandleRPC(ctx, &stats.OutPayload{Length: 66})
+ require.Equal(t,
+ logrus.Fields{"grpc.request.payload_bytes": int64(80), "grpc.response.payload_bytes": int64(90)},
+ FieldsProducer(ctx),
+ )
+}
+
+func TestPayloadBytesStats_Fields(t *testing.T) {
+ bytesStats := PayloadBytesStats{InPayloadBytes: 80, OutPayloadBytes: 90}
+ require.Equal(t, logrus.Fields{
+ "grpc.request.payload_bytes": int64(80),
+ "grpc.response.payload_bytes": int64(90),
+ }, bytesStats.Fields())
+}
+
+func TestFieldsProducer(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ t.Run("ok", func(t *testing.T) {
+ handler := &PayloadBytes{}
+ ctx := handler.TagRPC(ctx, nil)
+ handler.HandleRPC(ctx, &stats.InPayload{Length: 42})
+ handler.HandleRPC(ctx, &stats.OutPayload{Length: 24})
+ require.Equal(t, logrus.Fields{
+ "grpc.request.payload_bytes": int64(42),
+ "grpc.response.payload_bytes": int64(24),
+ }, FieldsProducer(ctx))
+ })
+
+ t.Run("no data", func(t *testing.T) {
+ require.Nil(t, FieldsProducer(ctx))
+ })
+}
diff --git a/internal/grpcstats/testhelper_test.go b/internal/grpcstats/testhelper_test.go
new file mode 100644
index 000000000..b41bdc29e
--- /dev/null
+++ b/internal/grpcstats/testhelper_test.go
@@ -0,0 +1,11 @@
+package grpcstats_test
+
+import (
+ "testing"
+
+ "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
+)
+
+func TestMain(m *testing.M) {
+ testhelper.Run(m)
+}
diff --git a/internal/log/log.go b/internal/log/log.go
index 99381418d..5c4e5f6db 100644
--- a/internal/log/log.go
+++ b/internal/log/log.go
@@ -1,9 +1,15 @@
package log
import (
+ "context"
"os"
+ grpcmwlogrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus"
+ "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus"
"github.com/sirupsen/logrus"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/stats"
)
const (
@@ -114,3 +120,149 @@ func Default() *logrus.Entry { return defaultLogger.WithField("pid", os.Getpid()
// GrpcGo is a dedicated logrus logger for the grpc-go library. We use it
// to control the library's chattiness.
func GrpcGo() *logrus.Entry { return grpcGo.WithField("pid", os.Getpid()) }
+
+// FieldsProducer returns fields that need to be added into the logging context.
+type FieldsProducer func(context.Context) logrus.Fields
+
+// MessageProducer returns a wrapper that extends passed mp to accept additional fields generated
+// by each of the fieldsProducers.
+func MessageProducer(mp grpcmwlogrus.MessageProducer, fieldsProducers ...FieldsProducer) grpcmwlogrus.MessageProducer {
+ return func(ctx context.Context, format string, level logrus.Level, code codes.Code, err error, fields logrus.Fields) {
+ for _, fieldsProducer := range fieldsProducers {
+ for key, val := range fieldsProducer(ctx) {
+ fields[key] = val
+ }
+ }
+ mp(ctx, format, level, code, err, fields)
+ }
+}
+
+type messageProducerHolder struct {
+ logger *logrus.Entry
+ actual grpcmwlogrus.MessageProducer
+ format string
+ level logrus.Level
+ code codes.Code
+ err error
+ fields logrus.Fields
+}
+
+type messageProducerHolderKey struct{}
+
+// messageProducerPropagationFrom extracts *messageProducerHolder from context
+// and returns to the caller.
+// It returns nil in case it is not found.
+func messageProducerPropagationFrom(ctx context.Context) *messageProducerHolder {
+ mpp, ok := ctx.Value(messageProducerHolderKey{}).(*messageProducerHolder)
+ if !ok {
+ return nil
+ }
+ return mpp
+}
+
+// PropagationMessageProducer catches logging information from the context and populates it
+// to the special holder that should be present in the context.
+// Should be used only in combination with PerRPCLogHandler.
+func PropagationMessageProducer(actual grpcmwlogrus.MessageProducer) grpcmwlogrus.MessageProducer {
+ return func(ctx context.Context, format string, level logrus.Level, code codes.Code, err error, fields logrus.Fields) {
+ mpp := messageProducerPropagationFrom(ctx)
+ if mpp == nil {
+ return
+ }
+ *mpp = messageProducerHolder{
+ logger: ctxlogrus.Extract(ctx),
+ actual: actual,
+ format: format,
+ level: level,
+ code: code,
+ err: err,
+ fields: fields,
+ }
+ }
+}
+
+// PerRPCLogHandler is designed to collect stats that are accessible
+// from the google.golang.org/grpc/stats.Handler, because some information
+// can't be extracted on the interceptors level.
+type PerRPCLogHandler struct {
+ Underlying stats.Handler
+ FieldProducers []FieldsProducer
+}
+
+// HandleConn only calls Underlying and exists to satisfy gRPC stats.Handler.
+func (lh PerRPCLogHandler) HandleConn(ctx context.Context, cs stats.ConnStats) {
+ lh.Underlying.HandleConn(ctx, cs)
+}
+
+// TagConn only calls Underlying and exists to satisfy gRPC stats.Handler.
+func (lh PerRPCLogHandler) TagConn(ctx context.Context, cti *stats.ConnTagInfo) context.Context {
+ return lh.Underlying.TagConn(ctx, cti)
+}
+
+// HandleRPC catches each RPC call and for the *stats.End stats invokes
+// custom message producers to populate logging data. Once all data is collected
+// the actual logging happens by using logger that is caught by PropagationMessageProducer.
+func (lh PerRPCLogHandler) HandleRPC(ctx context.Context, rs stats.RPCStats) {
+ lh.Underlying.HandleRPC(ctx, rs)
+ switch rs.(type) {
+ case *stats.End:
+ // This code runs once all interceptors are finished their execution.
+ // That is why any logging info collected after interceptors completion
+ // is not at the logger's context. That is why we need to manually propagate
+ // it to the logger.
+ mpp := messageProducerPropagationFrom(ctx)
+ if mpp == nil || (mpp != nil && mpp.actual == nil) {
+ return
+ }
+
+ if mpp.fields == nil {
+ mpp.fields = logrus.Fields{}
+ }
+ for _, fp := range lh.FieldProducers {
+ for k, v := range fp(ctx) {
+ mpp.fields[k] = v
+ }
+ }
+ // Once again because all interceptors are finished and context doesn't contain
+ // a logger we need to set logger manually into the context.
+ // It's needed because github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus.DefaultMessageProducer
+ // extracts logger from the context and use it to write the logs.
+ ctx = ctxlogrus.ToContext(ctx, mpp.logger)
+ mpp.actual(ctx, mpp.format, mpp.level, mpp.code, mpp.err, mpp.fields)
+ return
+ }
+}
+
+// TagRPC propagates a special data holder into the context that is responsible to
+// hold logging information produced by the logging interceptor.
+// The logging data should be caught by the UnaryLogDataCatcherServerInterceptor. It needs to
+// be included into the interceptor chain below logging interceptor.
+func (lh PerRPCLogHandler) TagRPC(ctx context.Context, rti *stats.RPCTagInfo) context.Context {
+ ctx = context.WithValue(ctx, messageProducerHolderKey{}, new(messageProducerHolder))
+ return lh.Underlying.TagRPC(ctx, rti)
+}
+
+// UnaryLogDataCatcherServerInterceptor catches logging data produced by the upper interceptors and
+// propagates it into the holder to pop up it to the HandleRPC method of the PerRPCLogHandler.
+func UnaryLogDataCatcherServerInterceptor() grpc.UnaryServerInterceptor {
+ return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
+ mpp := messageProducerPropagationFrom(ctx)
+ if mpp != nil {
+ mpp.fields = ctxlogrus.Extract(ctx).Data
+ }
+ return handler(ctx, req)
+ }
+}
+
+// StreamLogDataCatcherServerInterceptor catches logging data produced by the upper interceptors and
+// propagates it into the holder to pop up it to the HandleRPC method of the PerRPCLogHandler.
+func StreamLogDataCatcherServerInterceptor() grpc.StreamServerInterceptor {
+ return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+ ctx := ss.Context()
+ mpp := messageProducerPropagationFrom(ctx)
+ if mpp != nil {
+ mpp.fields = ctxlogrus.Extract(ctx).Data
+ }
+ return handler(srv, ss)
+ }
+}
diff --git a/internal/log/log_test.go b/internal/log/log_test.go
index 5be348027..fcc3d0573 100644
--- a/internal/log/log_test.go
+++ b/internal/log/log_test.go
@@ -2,14 +2,175 @@ package log
import (
"bytes"
+ "context"
+ "io"
+ "net"
+ "os"
+ "sync"
"testing"
"time"
+ grpcmw "github.com/grpc-ecosystem/go-grpc-middleware"
+ grpcmwlogrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus"
+ "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus"
"github.com/sirupsen/logrus"
+ "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitaly/v14/client"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/grpcstats"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper/testassert"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/stats"
+ "google.golang.org/grpc/test/grpc_testing"
+ "google.golang.org/protobuf/proto"
)
+func TestPayloadBytes(t *testing.T) {
+ ctx := context.Background()
+
+ logger, hook := test.NewNullLogger()
+
+ opts := []grpc.ServerOption{
+ grpc.StatsHandler(PerRPCLogHandler{
+ Underlying: &grpcstats.PayloadBytes{},
+ FieldProducers: []FieldsProducer{grpcstats.FieldsProducer},
+ }),
+ grpc.UnaryInterceptor(
+ grpcmw.ChainUnaryServer(
+ grpcmwlogrus.UnaryServerInterceptor(
+ logrus.NewEntry(logger),
+ grpcmwlogrus.WithMessageProducer(
+ MessageProducer(
+ PropagationMessageProducer(grpcmwlogrus.DefaultMessageProducer),
+ grpcstats.FieldsProducer,
+ ),
+ ),
+ ),
+ UnaryLogDataCatcherServerInterceptor(),
+ ),
+ ),
+ grpc.StreamInterceptor(
+ grpcmw.ChainStreamServer(
+ grpcmwlogrus.StreamServerInterceptor(
+ logrus.NewEntry(logger),
+ grpcmwlogrus.WithMessageProducer(
+ MessageProducer(
+ PropagationMessageProducer(grpcmwlogrus.DefaultMessageProducer),
+ grpcstats.FieldsProducer,
+ ),
+ ),
+ ),
+ StreamLogDataCatcherServerInterceptor(),
+ ),
+ ),
+ }
+
+ srv := grpc.NewServer(opts...)
+ grpc_testing.RegisterTestServiceServer(srv, testService{})
+ sock, err := os.CreateTemp("", "")
+ require.NoError(t, err)
+ require.NoError(t, sock.Close())
+ require.NoError(t, os.RemoveAll(sock.Name()))
+ t.Cleanup(func() { require.NoError(t, os.RemoveAll(sock.Name())) })
+
+ lis, err := net.Listen("unix", sock.Name())
+ require.NoError(t, err)
+
+ t.Cleanup(srv.GracefulStop)
+ go func() { assert.NoError(t, srv.Serve(lis)) }()
+
+ cc, err := client.DialContext(ctx, "unix://"+sock.Name(), nil)
+ require.NoError(t, err)
+ t.Cleanup(func() { require.NoError(t, cc.Close()) })
+
+ testClient := grpc_testing.NewTestServiceClient(cc)
+ const invocations = 2
+ var wg sync.WaitGroup
+ for i := 0; i < invocations; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ resp, err := testClient.UnaryCall(ctx, &grpc_testing.SimpleRequest{Payload: newStubPayload()})
+ if !assert.NoError(t, err) {
+ return
+ }
+ testassert.ProtoEqual(t, newStubPayload(), resp.Payload)
+
+ call, err := testClient.HalfDuplexCall(ctx)
+ if !assert.NoError(t, err) {
+ return
+ }
+
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ for {
+ _, err := call.Recv()
+ if err == io.EOF {
+ return
+ }
+ assert.NoError(t, err)
+ }
+ }()
+ assert.NoError(t, call.Send(&grpc_testing.StreamingOutputCallRequest{Payload: newStubPayload()}))
+ assert.NoError(t, call.Send(&grpc_testing.StreamingOutputCallRequest{Payload: newStubPayload()}))
+ assert.NoError(t, call.CloseSend())
+ <-done
+ }()
+ }
+ wg.Wait()
+
+ entries := hook.AllEntries()
+ require.Len(t, entries, 4)
+ var unary, stream int
+ for _, e := range entries {
+ if e.Message == "finished unary call with code OK" {
+ unary++
+ require.EqualValues(t, 8, e.Data["grpc.request.payload_bytes"])
+ require.EqualValues(t, 8, e.Data["grpc.response.payload_bytes"])
+ }
+ if e.Message == "finished streaming call with code OK" {
+ stream++
+ require.EqualValues(t, 16, e.Data["grpc.request.payload_bytes"])
+ require.EqualValues(t, 16, e.Data["grpc.response.payload_bytes"])
+ }
+ }
+ require.Equal(t, invocations, unary)
+ require.Equal(t, invocations, stream)
+}
+
+func newStubPayload() *grpc_testing.Payload {
+ return &grpc_testing.Payload{Body: []byte("stub")}
+}
+
+type testService struct {
+ grpc_testing.UnimplementedTestServiceServer
+}
+
+func (ts testService) UnaryCall(context.Context, *grpc_testing.SimpleRequest) (*grpc_testing.SimpleResponse, error) {
+ return &grpc_testing.SimpleResponse{Payload: newStubPayload()}, nil
+}
+
+func (ts testService) HalfDuplexCall(stream grpc_testing.TestService_HalfDuplexCallServer) error {
+ for {
+ if _, err := stream.Recv(); err != nil {
+ if err == io.EOF {
+ break
+ }
+ return err
+ }
+ }
+
+ resp := &grpc_testing.StreamingOutputCallResponse{Payload: newStubPayload()}
+ if err := stream.Send(proto.Clone(resp).(*grpc_testing.StreamingOutputCallResponse)); err != nil {
+ return err
+ }
+ return stream.Send(proto.Clone(resp).(*grpc_testing.StreamingOutputCallResponse))
+}
+
func TestConfigure(t *testing.T) {
for _, tc := range []struct {
desc string
@@ -95,3 +256,185 @@ func TestConfigure(t *testing.T) {
})
}
}
+
+func TestMessageProducer(t *testing.T) {
+ triggered := false
+ MessageProducer(func(ctx context.Context, format string, level logrus.Level, code codes.Code, err error, fields logrus.Fields) {
+ require.Equal(t, context.Background(), ctx)
+ require.Equal(t, "format-stub", format)
+ require.Equal(t, logrus.DebugLevel, level)
+ require.Equal(t, codes.OutOfRange, code)
+ require.Equal(t, assert.AnError, err)
+ require.Equal(t, logrus.Fields{"a": 1, "b": "test", "c": "stub"}, fields)
+ triggered = true
+ }, func(context.Context) logrus.Fields {
+ return logrus.Fields{"a": 1}
+ }, func(context.Context) logrus.Fields {
+ return logrus.Fields{"b": "test"}
+ })(context.Background(), "format-stub", logrus.DebugLevel, codes.OutOfRange, assert.AnError, logrus.Fields{"c": "stub"})
+ require.True(t, triggered)
+}
+
+func TestPropagationMessageProducer(t *testing.T) {
+ t.Run("empty context", func(t *testing.T) {
+ ctx := context.Background()
+ mp := PropagationMessageProducer(func(context.Context, string, logrus.Level, codes.Code, error, logrus.Fields) {})
+ mp(ctx, "", logrus.DebugLevel, codes.OK, nil, nil)
+ })
+
+ t.Run("context with holder", func(t *testing.T) {
+ holder := new(messageProducerHolder)
+ ctx := context.WithValue(context.Background(), messageProducerHolderKey{}, holder)
+ triggered := false
+ mp := PropagationMessageProducer(func(ctx context.Context, format string, level logrus.Level, code codes.Code, err error, fields logrus.Fields) {
+ triggered = true
+ })
+ mp(ctx, "format-stub", logrus.DebugLevel, codes.OutOfRange, assert.AnError, logrus.Fields{"a": 1})
+ require.Equal(t, "format-stub", holder.format)
+ require.Equal(t, logrus.DebugLevel, holder.level)
+ require.Equal(t, codes.OutOfRange, holder.code)
+ require.Equal(t, assert.AnError, holder.err)
+ require.Equal(t, logrus.Fields{"a": 1}, holder.fields)
+ holder.actual(ctx, "", logrus.DebugLevel, codes.OK, nil, nil)
+ require.True(t, triggered)
+ })
+}
+
+func TestPerRPCLogHandler(t *testing.T) {
+ msh := &mockStatHandler{Calls: map[string][]interface{}{}}
+
+ lh := PerRPCLogHandler{
+ Underlying: msh,
+ FieldProducers: []FieldsProducer{
+ func(ctx context.Context) logrus.Fields { return logrus.Fields{"a": 1} },
+ func(ctx context.Context) logrus.Fields { return logrus.Fields{"b": "2"} },
+ },
+ }
+
+ t.Run("check propagation", func(t *testing.T) {
+ ctx := context.Background()
+ ctx = lh.TagConn(ctx, &stats.ConnTagInfo{})
+ lh.HandleConn(ctx, &stats.ConnBegin{})
+ ctx = lh.TagRPC(ctx, &stats.RPCTagInfo{})
+ lh.HandleRPC(ctx, &stats.Begin{})
+ lh.HandleRPC(ctx, &stats.InHeader{})
+ lh.HandleRPC(ctx, &stats.InPayload{})
+ lh.HandleRPC(ctx, &stats.OutHeader{})
+ lh.HandleRPC(ctx, &stats.OutPayload{})
+ lh.HandleRPC(ctx, &stats.End{})
+ lh.HandleConn(ctx, &stats.ConnEnd{})
+
+ assert.Equal(t, map[string][]interface{}{
+ "TagConn": {&stats.ConnTagInfo{}},
+ "HandleConn": {&stats.ConnBegin{}, &stats.ConnEnd{}},
+ "TagRPC": {&stats.RPCTagInfo{}},
+ "HandleRPC": {&stats.Begin{}, &stats.InHeader{}, &stats.InPayload{}, &stats.OutHeader{}, &stats.OutPayload{}, &stats.End{}},
+ }, msh.Calls)
+ })
+
+ t.Run("log handling", func(t *testing.T) {
+ ctx := ctxlogrus.ToContext(context.Background(), logrus.NewEntry(logrus.New()))
+ ctx = lh.TagRPC(ctx, &stats.RPCTagInfo{})
+ mpp := ctx.Value(messageProducerHolderKey{}).(*messageProducerHolder)
+ mpp.format = "message"
+ mpp.level = logrus.InfoLevel
+ mpp.code = codes.InvalidArgument
+ mpp.err = assert.AnError
+ mpp.actual = func(ctx context.Context, format string, level logrus.Level, code codes.Code, err error, fields logrus.Fields) {
+ assert.Equal(t, "message", format)
+ assert.Equal(t, logrus.InfoLevel, level)
+ assert.Equal(t, codes.InvalidArgument, code)
+ assert.Equal(t, assert.AnError, err)
+ assert.Equal(t, logrus.Fields{"a": 1, "b": "2"}, mpp.fields)
+ }
+ lh.HandleRPC(ctx, &stats.End{})
+ })
+}
+
+type mockStatHandler struct {
+ Calls map[string][]interface{}
+}
+
+func (m *mockStatHandler) TagRPC(ctx context.Context, s *stats.RPCTagInfo) context.Context {
+ m.Calls["TagRPC"] = append(m.Calls["TagRPC"], s)
+ return ctx
+}
+
+func (m *mockStatHandler) HandleRPC(ctx context.Context, s stats.RPCStats) {
+ m.Calls["HandleRPC"] = append(m.Calls["HandleRPC"], s)
+}
+
+func (m *mockStatHandler) TagConn(ctx context.Context, s *stats.ConnTagInfo) context.Context {
+ m.Calls["TagConn"] = append(m.Calls["TagConn"], s)
+ return ctx
+}
+
+func (m *mockStatHandler) HandleConn(ctx context.Context, s stats.ConnStats) {
+ m.Calls["HandleConn"] = append(m.Calls["HandleConn"], s)
+}
+
+func TestUnaryLogDataCatcherServerInterceptor(t *testing.T) {
+ handlerStub := func(context.Context, interface{}) (interface{}, error) {
+ return nil, nil
+ }
+
+ t.Run("propagates call", func(t *testing.T) {
+ interceptor := UnaryLogDataCatcherServerInterceptor()
+ resp, err := interceptor(context.Background(), nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) {
+ return 42, assert.AnError
+ })
+
+ assert.Equal(t, 42, resp)
+ assert.Equal(t, assert.AnError, err)
+ })
+
+ t.Run("no logger", func(t *testing.T) {
+ mpp := &messageProducerHolder{}
+ ctx := context.WithValue(context.Background(), messageProducerHolderKey{}, mpp)
+
+ interceptor := UnaryLogDataCatcherServerInterceptor()
+ _, _ = interceptor(ctx, nil, nil, handlerStub)
+ assert.Empty(t, mpp.fields)
+ })
+
+ t.Run("caught", func(t *testing.T) {
+ mpp := &messageProducerHolder{}
+ ctx := context.WithValue(context.Background(), messageProducerHolderKey{}, mpp)
+ ctx = ctxlogrus.ToContext(ctx, logrus.New().WithField("a", 1))
+ interceptor := UnaryLogDataCatcherServerInterceptor()
+ _, _ = interceptor(ctx, nil, nil, handlerStub)
+ assert.Equal(t, logrus.Fields{"a": 1}, mpp.fields)
+ })
+}
+
+func TestStreamLogDataCatcherServerInterceptor(t *testing.T) {
+ t.Run("propagates call", func(t *testing.T) {
+ interceptor := StreamLogDataCatcherServerInterceptor()
+ ss := &grpcmw.WrappedServerStream{WrappedContext: context.Background()}
+ err := interceptor(nil, ss, nil, func(interface{}, grpc.ServerStream) error {
+ return assert.AnError
+ })
+
+ assert.Equal(t, assert.AnError, err)
+ })
+
+ t.Run("no logger", func(t *testing.T) {
+ mpp := &messageProducerHolder{}
+ ctx := context.WithValue(context.Background(), messageProducerHolderKey{}, mpp)
+
+ interceptor := StreamLogDataCatcherServerInterceptor()
+ ss := &grpcmw.WrappedServerStream{WrappedContext: ctx}
+ _ = interceptor(nil, ss, nil, func(interface{}, grpc.ServerStream) error { return nil })
+ })
+
+ t.Run("caught", func(t *testing.T) {
+ mpp := &messageProducerHolder{}
+ ctx := context.WithValue(context.Background(), messageProducerHolderKey{}, mpp)
+ ctx = ctxlogrus.ToContext(ctx, logrus.New().WithField("a", 1))
+
+ interceptor := StreamLogDataCatcherServerInterceptor()
+ ss := &grpcmw.WrappedServerStream{WrappedContext: ctx}
+ _ = interceptor(nil, ss, nil, func(interface{}, grpc.ServerStream) error { return nil })
+ assert.Equal(t, logrus.Fields{"a": 1}, mpp.fields)
+ })
+}
diff --git a/internal/middleware/commandstatshandler/commandstatshandler.go b/internal/middleware/commandstatshandler/commandstatshandler.go
index 34f100a51..50f251b08 100644
--- a/internal/middleware/commandstatshandler/commandstatshandler.go
+++ b/internal/middleware/commandstatshandler/commandstatshandler.go
@@ -4,11 +4,9 @@ import (
"context"
grpcmw "github.com/grpc-ecosystem/go-grpc-middleware"
- "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus"
"github.com/sirupsen/logrus"
"gitlab.com/gitlab-org/gitaly/v14/internal/command"
"google.golang.org/grpc"
- "google.golang.org/grpc/codes"
)
// UnaryInterceptor returns a Unary Interceptor
@@ -33,23 +31,10 @@ func StreamInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.Str
return err
}
-// CommandStatsMessageProducer hooks into grpc_logrus to add more fields.
-//
-// It replaces github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus.DefaultMessageProducer.
-//
-// We cannot use ctxlogrus.AddFields() as it is not concurrency safe, and
-// we may be logging concurrently. Conversely, command.Stats.Fields() is
-// protected by a lock and can safely be called here.
-func CommandStatsMessageProducer(ctx context.Context, format string, level logrus.Level, code codes.Code, err error, fields logrus.Fields) {
- if err != nil {
- fields[logrus.ErrorKey] = err
- }
- entry := ctxlogrus.Extract(ctx).WithContext(ctx).WithFields(fields)
-
- // safely inject commandstats
+// FieldsProducer extracts stats info from the context and returns it as a logging fields.
+func FieldsProducer(ctx context.Context) logrus.Fields {
if stats := command.StatsFromContext(ctx); stats != nil {
- entry = entry.WithFields(stats.Fields())
+ return stats.Fields()
}
-
- entry.Logf(level, format)
+ return nil
}
diff --git a/internal/middleware/commandstatshandler/commandstatshandler_test.go b/internal/middleware/commandstatshandler/commandstatshandler_test.go
index d474744bd..01c1e6479 100644
--- a/internal/middleware/commandstatshandler/commandstatshandler_test.go
+++ b/internal/middleware/commandstatshandler/commandstatshandler_test.go
@@ -12,6 +12,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitaly/v14/internal/backchannel"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/command"
"gitlab.com/gitlab-org/gitaly/v14/internal/git"
"gitlab.com/gitlab-org/gitaly/v14/internal/git/catfile"
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/config"
@@ -38,13 +39,13 @@ func createNewServer(t *testing.T, cfg config.Cfg) *grpc.Server {
StreamInterceptor,
grpcmwlogrus.StreamServerInterceptor(logrusEntry,
grpcmwlogrus.WithTimestampFormat(log.LogTimestampFormat),
- grpcmwlogrus.WithMessageProducer(CommandStatsMessageProducer)),
+ grpcmwlogrus.WithMessageProducer(log.MessageProducer(grpcmwlogrus.DefaultMessageProducer, FieldsProducer))),
)),
grpc.UnaryInterceptor(grpcmw.ChainUnaryServer(
UnaryInterceptor,
grpcmwlogrus.UnaryServerInterceptor(logrusEntry,
grpcmwlogrus.WithTimestampFormat(log.LogTimestampFormat),
- grpcmwlogrus.WithMessageProducer(CommandStatsMessageProducer)),
+ grpcmwlogrus.WithMessageProducer(log.MessageProducer(grpcmwlogrus.DefaultMessageProducer, FieldsProducer))),
)),
}
@@ -141,3 +142,14 @@ func TestInterceptor(t *testing.T) {
})
}
}
+
+func TestFieldsProducer(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ ctx = command.InitContextStats(ctx)
+ stats := command.StatsFromContext(ctx)
+ stats.RecordMax("stub", 42)
+
+ require.Equal(t, logrus.Fields{"stub": 42}, FieldsProducer(ctx))
+}