diff options
author | Toon Claes <toon@gitlab.com> | 2021-11-16 15:53:30 +0300 |
---|---|---|
committer | Toon Claes <toon@gitlab.com> | 2021-11-16 15:53:30 +0300 |
commit | bb7fa728e9206c7b2f9ede313f3db7b616af9350 (patch) | |
tree | fea7680412f6fa1247f38fc1d11fc9dd709441df | |
parent | 724e844c814b91bb4a98f7eb3607756b7504657d (diff) | |
parent | d56f6c64667e53107f6c40f3af4698e1f620290f (diff) |
Merge branch 'ps-track-payload-size' into 'master'
Track payload bytes for RPC
Closes #3867
See merge request gitlab-org/gitaly!4030
-rw-r--r-- | internal/gitaly/server/server.go | 21 | ||||
-rw-r--r-- | internal/gitaly/server/server_factory_test.go | 25 | ||||
-rw-r--r-- | internal/grpcstats/stats.go | 75 | ||||
-rw-r--r-- | internal/grpcstats/stats_test.go | 77 | ||||
-rw-r--r-- | internal/grpcstats/testhelper_test.go | 11 | ||||
-rw-r--r-- | internal/log/log.go | 152 | ||||
-rw-r--r-- | internal/log/log_test.go | 343 | ||||
-rw-r--r-- | internal/middleware/commandstatshandler/commandstatshandler.go | 23 | ||||
-rw-r--r-- | internal/middleware/commandstatshandler/commandstatshandler_test.go | 16 |
9 files changed, 720 insertions, 23 deletions
diff --git a/internal/gitaly/server/server.go b/internal/gitaly/server/server.go index ffdf55c10..9dcf07aa4 100644 --- a/internal/gitaly/server/server.go +++ b/internal/gitaly/server/server.go @@ -16,6 +16,7 @@ import ( "gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/client" "gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/config" "gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/server/auth" + "gitlab.com/gitlab-org/gitaly/v14/internal/grpcstats" "gitlab.com/gitlab-org/gitaly/v14/internal/helper/fieldextractors" "gitlab.com/gitlab-org/gitaly/v14/internal/listenmux" gitalylog "gitlab.com/gitlab-org/gitaly/v14/internal/log" @@ -103,7 +104,19 @@ func New( []grpc.DialOption{client.UnaryInterceptor()}, )) + logMsgProducer := grpcmwlogrus.WithMessageProducer( + gitalylog.MessageProducer( + gitalylog.PropagationMessageProducer(grpcmwlogrus.DefaultMessageProducer), + commandstatshandler.FieldsProducer, + grpcstats.FieldsProducer, + ), + ) + opts := []grpc.ServerOption{ + grpc.StatsHandler(gitalylog.PerRPCLogHandler{ + Underlying: &grpcstats.PayloadBytes{}, + FieldProducers: []gitalylog.FieldsProducer{grpcstats.FieldsProducer}, + }), grpc.Creds(lm), grpc.StreamInterceptor(grpcmw.ChainStreamServer( grpcmwtags.StreamServerInterceptor(ctxTagOpts...), @@ -113,7 +126,9 @@ func New( commandstatshandler.StreamInterceptor, grpcmwlogrus.StreamServerInterceptor(logrusEntry, grpcmwlogrus.WithTimestampFormat(gitalylog.LogTimestampFormat), - grpcmwlogrus.WithMessageProducer(commandstatshandler.CommandStatsMessageProducer)), + logMsgProducer, + ), + gitalylog.StreamLogDataCatcherServerInterceptor(), sentryhandler.StreamLogHandler, cancelhandler.Stream, // Should be below LogHandler auth.StreamServerInterceptor(cfg.Auth), @@ -132,7 +147,9 @@ func New( commandstatshandler.UnaryInterceptor, grpcmwlogrus.UnaryServerInterceptor(logrusEntry, grpcmwlogrus.WithTimestampFormat(gitalylog.LogTimestampFormat), - grpcmwlogrus.WithMessageProducer(commandstatshandler.CommandStatsMessageProducer)), + logMsgProducer, + ), + gitalylog.UnaryLogDataCatcherServerInterceptor(), sentryhandler.UnaryLogHandler, cancelhandler.Unary, // Should be below LogHandler auth.UnaryServerInterceptor(cfg.Auth), diff --git a/internal/gitaly/server/server_factory_test.go b/internal/gitaly/server/server_factory_test.go index 369cf3775..a9ad14696 100644 --- a/internal/gitaly/server/server_factory_test.go +++ b/internal/gitaly/server/server_factory_test.go @@ -9,6 +9,8 @@ import ( "os" "testing" + "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitaly/v14/client" @@ -125,6 +127,29 @@ func TestGitalyServerFactory(t *testing.T) { _, socketErr := socketHealthClient.Check(ctx, &healthpb.HealthCheckRequest{}) require.Equal(t, codes.Unavailable, status.Code(socketErr)) }) + + t.Run("logging check", func(t *testing.T) { + cfg := testcfg.Build(t) + logger, hook := test.NewNullLogger() + sf := NewGitalyServerFactory(cfg, logger.WithContext(ctx), backchannel.NewRegistry(), cache.New(cfg, config.NewLocator(cfg))) + + checkHealth(t, sf, starter.TCP, "localhost:0") + + var entry *logrus.Entry + for _, e := range hook.AllEntries() { + if e.Message == "finished unary call with code OK" { + entry = e + break + } + } + require.NotNil(t, entry) + reqSize, found := entry.Data["grpc.request.payload_bytes"] + assert.EqualValues(t, 0, reqSize) + require.True(t, found) + respSize, found := entry.Data["grpc.response.payload_bytes"] + assert.EqualValues(t, 2, respSize) + require.True(t, found) + }) } func TestGitalyServerFactory_closeOrder(t *testing.T) { diff --git a/internal/grpcstats/stats.go b/internal/grpcstats/stats.go new file mode 100644 index 000000000..f656aead5 --- /dev/null +++ b/internal/grpcstats/stats.go @@ -0,0 +1,75 @@ +package grpcstats + +import ( + "context" + + "github.com/sirupsen/logrus" + "google.golang.org/grpc/stats" +) + +// PayloadBytes implements stats.Handler and tracks amount of bytes received and send by gRPC service +// for each method call. The information about statistics is added into the context and can be +// extracted with payloadBytesStatsFromContext. +type PayloadBytes struct{} + +type payloadBytesStatsKey struct{} + +// HandleConn exists to satisfy gRPC stats.Handler. +func (s *PayloadBytes) HandleConn(context.Context, stats.ConnStats) {} + +// TagConn exists to satisfy gRPC stats.Handler. We don't gather connection level stats +// and are thus not currently using it. +func (s *PayloadBytes) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context { + return ctx +} + +// HandleRPC implements per-RPC tracing and stats instrumentation. +func (s *PayloadBytes) HandleRPC(ctx context.Context, rs stats.RPCStats) { + switch st := rs.(type) { + case *stats.InPayload: + bytesStats := ctx.Value(payloadBytesStatsKey{}).(*PayloadBytesStats) + bytesStats.InPayloadBytes += int64(st.Length) + case *stats.OutPayload: + bytesStats := ctx.Value(payloadBytesStatsKey{}).(*PayloadBytesStats) + bytesStats.OutPayloadBytes += int64(st.Length) + } +} + +// TagRPC initializes context with an RPC specific stats collector. +// The returned context will be used in method invocation as is passed into HandleRPC. +func (s *PayloadBytes) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context { + return context.WithValue(ctx, payloadBytesStatsKey{}, new(PayloadBytesStats)) +} + +// PayloadBytesStats contains info about bytes received and sent by the gRPC method. +type PayloadBytesStats struct { + // InPayloadBytes amount of bytes received. + InPayloadBytes int64 + // OutPayloadBytes amount of bytes sent. + OutPayloadBytes int64 +} + +// Fields returns logging info. +func (s *PayloadBytesStats) Fields() logrus.Fields { + return logrus.Fields{ + "grpc.request.payload_bytes": s.InPayloadBytes, + "grpc.response.payload_bytes": s.OutPayloadBytes, + } +} + +// FieldsProducer extracts stats info from the context and returns it as a logging fields. +func FieldsProducer(ctx context.Context) logrus.Fields { + payloadBytesStats := payloadBytesStatsFromContext(ctx) + if payloadBytesStats != nil { + return payloadBytesStats.Fields() + } + return nil +} + +func payloadBytesStatsFromContext(ctx context.Context) *PayloadBytesStats { + v, ok := ctx.Value(payloadBytesStatsKey{}).(*PayloadBytesStats) + if !ok { + return nil + } + return v +} diff --git a/internal/grpcstats/stats_test.go b/internal/grpcstats/stats_test.go new file mode 100644 index 000000000..84d9b349c --- /dev/null +++ b/internal/grpcstats/stats_test.go @@ -0,0 +1,77 @@ +package grpcstats + +import ( + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper" + "google.golang.org/grpc/stats" +) + +func TestPayloadBytes_TagRPC(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + ctx = (&PayloadBytes{}).TagRPC(ctx, nil) + require.Equal(t, + logrus.Fields{"grpc.request.payload_bytes": int64(0), "grpc.response.payload_bytes": int64(0)}, + FieldsProducer(ctx), + ) +} + +func TestPayloadBytes_HandleRPC(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + handler := &PayloadBytes{} + ctx = handler.TagRPC(ctx, nil) + handler.HandleRPC(ctx, nil) // sanity check we don't fail anything + handler.HandleRPC(ctx, &stats.Begin{}) // sanity check we don't fail anything + handler.HandleRPC(ctx, &stats.InPayload{Length: 42}) + require.Equal(t, + logrus.Fields{"grpc.request.payload_bytes": int64(42), "grpc.response.payload_bytes": int64(0)}, + FieldsProducer(ctx), + ) + handler.HandleRPC(ctx, &stats.OutPayload{Length: 24}) + require.Equal(t, + logrus.Fields{"grpc.request.payload_bytes": int64(42), "grpc.response.payload_bytes": int64(24)}, + FieldsProducer(ctx), + ) + handler.HandleRPC(ctx, &stats.InPayload{Length: 38}) + require.Equal(t, + logrus.Fields{"grpc.request.payload_bytes": int64(80), "grpc.response.payload_bytes": int64(24)}, + FieldsProducer(ctx), + ) + handler.HandleRPC(ctx, &stats.OutPayload{Length: 66}) + require.Equal(t, + logrus.Fields{"grpc.request.payload_bytes": int64(80), "grpc.response.payload_bytes": int64(90)}, + FieldsProducer(ctx), + ) +} + +func TestPayloadBytesStats_Fields(t *testing.T) { + bytesStats := PayloadBytesStats{InPayloadBytes: 80, OutPayloadBytes: 90} + require.Equal(t, logrus.Fields{ + "grpc.request.payload_bytes": int64(80), + "grpc.response.payload_bytes": int64(90), + }, bytesStats.Fields()) +} + +func TestFieldsProducer(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + + t.Run("ok", func(t *testing.T) { + handler := &PayloadBytes{} + ctx := handler.TagRPC(ctx, nil) + handler.HandleRPC(ctx, &stats.InPayload{Length: 42}) + handler.HandleRPC(ctx, &stats.OutPayload{Length: 24}) + require.Equal(t, logrus.Fields{ + "grpc.request.payload_bytes": int64(42), + "grpc.response.payload_bytes": int64(24), + }, FieldsProducer(ctx)) + }) + + t.Run("no data", func(t *testing.T) { + require.Nil(t, FieldsProducer(ctx)) + }) +} diff --git a/internal/grpcstats/testhelper_test.go b/internal/grpcstats/testhelper_test.go new file mode 100644 index 000000000..b41bdc29e --- /dev/null +++ b/internal/grpcstats/testhelper_test.go @@ -0,0 +1,11 @@ +package grpcstats_test + +import ( + "testing" + + "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper" +) + +func TestMain(m *testing.M) { + testhelper.Run(m) +} diff --git a/internal/log/log.go b/internal/log/log.go index 99381418d..5c4e5f6db 100644 --- a/internal/log/log.go +++ b/internal/log/log.go @@ -1,9 +1,15 @@ package log import ( + "context" "os" + grpcmwlogrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" + "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/sirupsen/logrus" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/stats" ) const ( @@ -114,3 +120,149 @@ func Default() *logrus.Entry { return defaultLogger.WithField("pid", os.Getpid() // GrpcGo is a dedicated logrus logger for the grpc-go library. We use it // to control the library's chattiness. func GrpcGo() *logrus.Entry { return grpcGo.WithField("pid", os.Getpid()) } + +// FieldsProducer returns fields that need to be added into the logging context. +type FieldsProducer func(context.Context) logrus.Fields + +// MessageProducer returns a wrapper that extends passed mp to accept additional fields generated +// by each of the fieldsProducers. +func MessageProducer(mp grpcmwlogrus.MessageProducer, fieldsProducers ...FieldsProducer) grpcmwlogrus.MessageProducer { + return func(ctx context.Context, format string, level logrus.Level, code codes.Code, err error, fields logrus.Fields) { + for _, fieldsProducer := range fieldsProducers { + for key, val := range fieldsProducer(ctx) { + fields[key] = val + } + } + mp(ctx, format, level, code, err, fields) + } +} + +type messageProducerHolder struct { + logger *logrus.Entry + actual grpcmwlogrus.MessageProducer + format string + level logrus.Level + code codes.Code + err error + fields logrus.Fields +} + +type messageProducerHolderKey struct{} + +// messageProducerPropagationFrom extracts *messageProducerHolder from context +// and returns to the caller. +// It returns nil in case it is not found. +func messageProducerPropagationFrom(ctx context.Context) *messageProducerHolder { + mpp, ok := ctx.Value(messageProducerHolderKey{}).(*messageProducerHolder) + if !ok { + return nil + } + return mpp +} + +// PropagationMessageProducer catches logging information from the context and populates it +// to the special holder that should be present in the context. +// Should be used only in combination with PerRPCLogHandler. +func PropagationMessageProducer(actual grpcmwlogrus.MessageProducer) grpcmwlogrus.MessageProducer { + return func(ctx context.Context, format string, level logrus.Level, code codes.Code, err error, fields logrus.Fields) { + mpp := messageProducerPropagationFrom(ctx) + if mpp == nil { + return + } + *mpp = messageProducerHolder{ + logger: ctxlogrus.Extract(ctx), + actual: actual, + format: format, + level: level, + code: code, + err: err, + fields: fields, + } + } +} + +// PerRPCLogHandler is designed to collect stats that are accessible +// from the google.golang.org/grpc/stats.Handler, because some information +// can't be extracted on the interceptors level. +type PerRPCLogHandler struct { + Underlying stats.Handler + FieldProducers []FieldsProducer +} + +// HandleConn only calls Underlying and exists to satisfy gRPC stats.Handler. +func (lh PerRPCLogHandler) HandleConn(ctx context.Context, cs stats.ConnStats) { + lh.Underlying.HandleConn(ctx, cs) +} + +// TagConn only calls Underlying and exists to satisfy gRPC stats.Handler. +func (lh PerRPCLogHandler) TagConn(ctx context.Context, cti *stats.ConnTagInfo) context.Context { + return lh.Underlying.TagConn(ctx, cti) +} + +// HandleRPC catches each RPC call and for the *stats.End stats invokes +// custom message producers to populate logging data. Once all data is collected +// the actual logging happens by using logger that is caught by PropagationMessageProducer. +func (lh PerRPCLogHandler) HandleRPC(ctx context.Context, rs stats.RPCStats) { + lh.Underlying.HandleRPC(ctx, rs) + switch rs.(type) { + case *stats.End: + // This code runs once all interceptors are finished their execution. + // That is why any logging info collected after interceptors completion + // is not at the logger's context. That is why we need to manually propagate + // it to the logger. + mpp := messageProducerPropagationFrom(ctx) + if mpp == nil || (mpp != nil && mpp.actual == nil) { + return + } + + if mpp.fields == nil { + mpp.fields = logrus.Fields{} + } + for _, fp := range lh.FieldProducers { + for k, v := range fp(ctx) { + mpp.fields[k] = v + } + } + // Once again because all interceptors are finished and context doesn't contain + // a logger we need to set logger manually into the context. + // It's needed because github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus.DefaultMessageProducer + // extracts logger from the context and use it to write the logs. + ctx = ctxlogrus.ToContext(ctx, mpp.logger) + mpp.actual(ctx, mpp.format, mpp.level, mpp.code, mpp.err, mpp.fields) + return + } +} + +// TagRPC propagates a special data holder into the context that is responsible to +// hold logging information produced by the logging interceptor. +// The logging data should be caught by the UnaryLogDataCatcherServerInterceptor. It needs to +// be included into the interceptor chain below logging interceptor. +func (lh PerRPCLogHandler) TagRPC(ctx context.Context, rti *stats.RPCTagInfo) context.Context { + ctx = context.WithValue(ctx, messageProducerHolderKey{}, new(messageProducerHolder)) + return lh.Underlying.TagRPC(ctx, rti) +} + +// UnaryLogDataCatcherServerInterceptor catches logging data produced by the upper interceptors and +// propagates it into the holder to pop up it to the HandleRPC method of the PerRPCLogHandler. +func UnaryLogDataCatcherServerInterceptor() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + mpp := messageProducerPropagationFrom(ctx) + if mpp != nil { + mpp.fields = ctxlogrus.Extract(ctx).Data + } + return handler(ctx, req) + } +} + +// StreamLogDataCatcherServerInterceptor catches logging data produced by the upper interceptors and +// propagates it into the holder to pop up it to the HandleRPC method of the PerRPCLogHandler. +func StreamLogDataCatcherServerInterceptor() grpc.StreamServerInterceptor { + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + ctx := ss.Context() + mpp := messageProducerPropagationFrom(ctx) + if mpp != nil { + mpp.fields = ctxlogrus.Extract(ctx).Data + } + return handler(srv, ss) + } +} diff --git a/internal/log/log_test.go b/internal/log/log_test.go index 5be348027..fcc3d0573 100644 --- a/internal/log/log_test.go +++ b/internal/log/log_test.go @@ -2,14 +2,175 @@ package log import ( "bytes" + "context" + "io" + "net" + "os" + "sync" "testing" "time" + grpcmw "github.com/grpc-ecosystem/go-grpc-middleware" + grpcmwlogrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" + "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitaly/v14/client" + "gitlab.com/gitlab-org/gitaly/v14/internal/grpcstats" + "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper/testassert" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/stats" + "google.golang.org/grpc/test/grpc_testing" + "google.golang.org/protobuf/proto" ) +func TestPayloadBytes(t *testing.T) { + ctx := context.Background() + + logger, hook := test.NewNullLogger() + + opts := []grpc.ServerOption{ + grpc.StatsHandler(PerRPCLogHandler{ + Underlying: &grpcstats.PayloadBytes{}, + FieldProducers: []FieldsProducer{grpcstats.FieldsProducer}, + }), + grpc.UnaryInterceptor( + grpcmw.ChainUnaryServer( + grpcmwlogrus.UnaryServerInterceptor( + logrus.NewEntry(logger), + grpcmwlogrus.WithMessageProducer( + MessageProducer( + PropagationMessageProducer(grpcmwlogrus.DefaultMessageProducer), + grpcstats.FieldsProducer, + ), + ), + ), + UnaryLogDataCatcherServerInterceptor(), + ), + ), + grpc.StreamInterceptor( + grpcmw.ChainStreamServer( + grpcmwlogrus.StreamServerInterceptor( + logrus.NewEntry(logger), + grpcmwlogrus.WithMessageProducer( + MessageProducer( + PropagationMessageProducer(grpcmwlogrus.DefaultMessageProducer), + grpcstats.FieldsProducer, + ), + ), + ), + StreamLogDataCatcherServerInterceptor(), + ), + ), + } + + srv := grpc.NewServer(opts...) + grpc_testing.RegisterTestServiceServer(srv, testService{}) + sock, err := os.CreateTemp("", "") + require.NoError(t, err) + require.NoError(t, sock.Close()) + require.NoError(t, os.RemoveAll(sock.Name())) + t.Cleanup(func() { require.NoError(t, os.RemoveAll(sock.Name())) }) + + lis, err := net.Listen("unix", sock.Name()) + require.NoError(t, err) + + t.Cleanup(srv.GracefulStop) + go func() { assert.NoError(t, srv.Serve(lis)) }() + + cc, err := client.DialContext(ctx, "unix://"+sock.Name(), nil) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, cc.Close()) }) + + testClient := grpc_testing.NewTestServiceClient(cc) + const invocations = 2 + var wg sync.WaitGroup + for i := 0; i < invocations; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + resp, err := testClient.UnaryCall(ctx, &grpc_testing.SimpleRequest{Payload: newStubPayload()}) + if !assert.NoError(t, err) { + return + } + testassert.ProtoEqual(t, newStubPayload(), resp.Payload) + + call, err := testClient.HalfDuplexCall(ctx) + if !assert.NoError(t, err) { + return + } + + done := make(chan struct{}) + go func() { + defer close(done) + for { + _, err := call.Recv() + if err == io.EOF { + return + } + assert.NoError(t, err) + } + }() + assert.NoError(t, call.Send(&grpc_testing.StreamingOutputCallRequest{Payload: newStubPayload()})) + assert.NoError(t, call.Send(&grpc_testing.StreamingOutputCallRequest{Payload: newStubPayload()})) + assert.NoError(t, call.CloseSend()) + <-done + }() + } + wg.Wait() + + entries := hook.AllEntries() + require.Len(t, entries, 4) + var unary, stream int + for _, e := range entries { + if e.Message == "finished unary call with code OK" { + unary++ + require.EqualValues(t, 8, e.Data["grpc.request.payload_bytes"]) + require.EqualValues(t, 8, e.Data["grpc.response.payload_bytes"]) + } + if e.Message == "finished streaming call with code OK" { + stream++ + require.EqualValues(t, 16, e.Data["grpc.request.payload_bytes"]) + require.EqualValues(t, 16, e.Data["grpc.response.payload_bytes"]) + } + } + require.Equal(t, invocations, unary) + require.Equal(t, invocations, stream) +} + +func newStubPayload() *grpc_testing.Payload { + return &grpc_testing.Payload{Body: []byte("stub")} +} + +type testService struct { + grpc_testing.UnimplementedTestServiceServer +} + +func (ts testService) UnaryCall(context.Context, *grpc_testing.SimpleRequest) (*grpc_testing.SimpleResponse, error) { + return &grpc_testing.SimpleResponse{Payload: newStubPayload()}, nil +} + +func (ts testService) HalfDuplexCall(stream grpc_testing.TestService_HalfDuplexCallServer) error { + for { + if _, err := stream.Recv(); err != nil { + if err == io.EOF { + break + } + return err + } + } + + resp := &grpc_testing.StreamingOutputCallResponse{Payload: newStubPayload()} + if err := stream.Send(proto.Clone(resp).(*grpc_testing.StreamingOutputCallResponse)); err != nil { + return err + } + return stream.Send(proto.Clone(resp).(*grpc_testing.StreamingOutputCallResponse)) +} + func TestConfigure(t *testing.T) { for _, tc := range []struct { desc string @@ -95,3 +256,185 @@ func TestConfigure(t *testing.T) { }) } } + +func TestMessageProducer(t *testing.T) { + triggered := false + MessageProducer(func(ctx context.Context, format string, level logrus.Level, code codes.Code, err error, fields logrus.Fields) { + require.Equal(t, context.Background(), ctx) + require.Equal(t, "format-stub", format) + require.Equal(t, logrus.DebugLevel, level) + require.Equal(t, codes.OutOfRange, code) + require.Equal(t, assert.AnError, err) + require.Equal(t, logrus.Fields{"a": 1, "b": "test", "c": "stub"}, fields) + triggered = true + }, func(context.Context) logrus.Fields { + return logrus.Fields{"a": 1} + }, func(context.Context) logrus.Fields { + return logrus.Fields{"b": "test"} + })(context.Background(), "format-stub", logrus.DebugLevel, codes.OutOfRange, assert.AnError, logrus.Fields{"c": "stub"}) + require.True(t, triggered) +} + +func TestPropagationMessageProducer(t *testing.T) { + t.Run("empty context", func(t *testing.T) { + ctx := context.Background() + mp := PropagationMessageProducer(func(context.Context, string, logrus.Level, codes.Code, error, logrus.Fields) {}) + mp(ctx, "", logrus.DebugLevel, codes.OK, nil, nil) + }) + + t.Run("context with holder", func(t *testing.T) { + holder := new(messageProducerHolder) + ctx := context.WithValue(context.Background(), messageProducerHolderKey{}, holder) + triggered := false + mp := PropagationMessageProducer(func(ctx context.Context, format string, level logrus.Level, code codes.Code, err error, fields logrus.Fields) { + triggered = true + }) + mp(ctx, "format-stub", logrus.DebugLevel, codes.OutOfRange, assert.AnError, logrus.Fields{"a": 1}) + require.Equal(t, "format-stub", holder.format) + require.Equal(t, logrus.DebugLevel, holder.level) + require.Equal(t, codes.OutOfRange, holder.code) + require.Equal(t, assert.AnError, holder.err) + require.Equal(t, logrus.Fields{"a": 1}, holder.fields) + holder.actual(ctx, "", logrus.DebugLevel, codes.OK, nil, nil) + require.True(t, triggered) + }) +} + +func TestPerRPCLogHandler(t *testing.T) { + msh := &mockStatHandler{Calls: map[string][]interface{}{}} + + lh := PerRPCLogHandler{ + Underlying: msh, + FieldProducers: []FieldsProducer{ + func(ctx context.Context) logrus.Fields { return logrus.Fields{"a": 1} }, + func(ctx context.Context) logrus.Fields { return logrus.Fields{"b": "2"} }, + }, + } + + t.Run("check propagation", func(t *testing.T) { + ctx := context.Background() + ctx = lh.TagConn(ctx, &stats.ConnTagInfo{}) + lh.HandleConn(ctx, &stats.ConnBegin{}) + ctx = lh.TagRPC(ctx, &stats.RPCTagInfo{}) + lh.HandleRPC(ctx, &stats.Begin{}) + lh.HandleRPC(ctx, &stats.InHeader{}) + lh.HandleRPC(ctx, &stats.InPayload{}) + lh.HandleRPC(ctx, &stats.OutHeader{}) + lh.HandleRPC(ctx, &stats.OutPayload{}) + lh.HandleRPC(ctx, &stats.End{}) + lh.HandleConn(ctx, &stats.ConnEnd{}) + + assert.Equal(t, map[string][]interface{}{ + "TagConn": {&stats.ConnTagInfo{}}, + "HandleConn": {&stats.ConnBegin{}, &stats.ConnEnd{}}, + "TagRPC": {&stats.RPCTagInfo{}}, + "HandleRPC": {&stats.Begin{}, &stats.InHeader{}, &stats.InPayload{}, &stats.OutHeader{}, &stats.OutPayload{}, &stats.End{}}, + }, msh.Calls) + }) + + t.Run("log handling", func(t *testing.T) { + ctx := ctxlogrus.ToContext(context.Background(), logrus.NewEntry(logrus.New())) + ctx = lh.TagRPC(ctx, &stats.RPCTagInfo{}) + mpp := ctx.Value(messageProducerHolderKey{}).(*messageProducerHolder) + mpp.format = "message" + mpp.level = logrus.InfoLevel + mpp.code = codes.InvalidArgument + mpp.err = assert.AnError + mpp.actual = func(ctx context.Context, format string, level logrus.Level, code codes.Code, err error, fields logrus.Fields) { + assert.Equal(t, "message", format) + assert.Equal(t, logrus.InfoLevel, level) + assert.Equal(t, codes.InvalidArgument, code) + assert.Equal(t, assert.AnError, err) + assert.Equal(t, logrus.Fields{"a": 1, "b": "2"}, mpp.fields) + } + lh.HandleRPC(ctx, &stats.End{}) + }) +} + +type mockStatHandler struct { + Calls map[string][]interface{} +} + +func (m *mockStatHandler) TagRPC(ctx context.Context, s *stats.RPCTagInfo) context.Context { + m.Calls["TagRPC"] = append(m.Calls["TagRPC"], s) + return ctx +} + +func (m *mockStatHandler) HandleRPC(ctx context.Context, s stats.RPCStats) { + m.Calls["HandleRPC"] = append(m.Calls["HandleRPC"], s) +} + +func (m *mockStatHandler) TagConn(ctx context.Context, s *stats.ConnTagInfo) context.Context { + m.Calls["TagConn"] = append(m.Calls["TagConn"], s) + return ctx +} + +func (m *mockStatHandler) HandleConn(ctx context.Context, s stats.ConnStats) { + m.Calls["HandleConn"] = append(m.Calls["HandleConn"], s) +} + +func TestUnaryLogDataCatcherServerInterceptor(t *testing.T) { + handlerStub := func(context.Context, interface{}) (interface{}, error) { + return nil, nil + } + + t.Run("propagates call", func(t *testing.T) { + interceptor := UnaryLogDataCatcherServerInterceptor() + resp, err := interceptor(context.Background(), nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) { + return 42, assert.AnError + }) + + assert.Equal(t, 42, resp) + assert.Equal(t, assert.AnError, err) + }) + + t.Run("no logger", func(t *testing.T) { + mpp := &messageProducerHolder{} + ctx := context.WithValue(context.Background(), messageProducerHolderKey{}, mpp) + + interceptor := UnaryLogDataCatcherServerInterceptor() + _, _ = interceptor(ctx, nil, nil, handlerStub) + assert.Empty(t, mpp.fields) + }) + + t.Run("caught", func(t *testing.T) { + mpp := &messageProducerHolder{} + ctx := context.WithValue(context.Background(), messageProducerHolderKey{}, mpp) + ctx = ctxlogrus.ToContext(ctx, logrus.New().WithField("a", 1)) + interceptor := UnaryLogDataCatcherServerInterceptor() + _, _ = interceptor(ctx, nil, nil, handlerStub) + assert.Equal(t, logrus.Fields{"a": 1}, mpp.fields) + }) +} + +func TestStreamLogDataCatcherServerInterceptor(t *testing.T) { + t.Run("propagates call", func(t *testing.T) { + interceptor := StreamLogDataCatcherServerInterceptor() + ss := &grpcmw.WrappedServerStream{WrappedContext: context.Background()} + err := interceptor(nil, ss, nil, func(interface{}, grpc.ServerStream) error { + return assert.AnError + }) + + assert.Equal(t, assert.AnError, err) + }) + + t.Run("no logger", func(t *testing.T) { + mpp := &messageProducerHolder{} + ctx := context.WithValue(context.Background(), messageProducerHolderKey{}, mpp) + + interceptor := StreamLogDataCatcherServerInterceptor() + ss := &grpcmw.WrappedServerStream{WrappedContext: ctx} + _ = interceptor(nil, ss, nil, func(interface{}, grpc.ServerStream) error { return nil }) + }) + + t.Run("caught", func(t *testing.T) { + mpp := &messageProducerHolder{} + ctx := context.WithValue(context.Background(), messageProducerHolderKey{}, mpp) + ctx = ctxlogrus.ToContext(ctx, logrus.New().WithField("a", 1)) + + interceptor := StreamLogDataCatcherServerInterceptor() + ss := &grpcmw.WrappedServerStream{WrappedContext: ctx} + _ = interceptor(nil, ss, nil, func(interface{}, grpc.ServerStream) error { return nil }) + assert.Equal(t, logrus.Fields{"a": 1}, mpp.fields) + }) +} diff --git a/internal/middleware/commandstatshandler/commandstatshandler.go b/internal/middleware/commandstatshandler/commandstatshandler.go index 34f100a51..50f251b08 100644 --- a/internal/middleware/commandstatshandler/commandstatshandler.go +++ b/internal/middleware/commandstatshandler/commandstatshandler.go @@ -4,11 +4,9 @@ import ( "context" grpcmw "github.com/grpc-ecosystem/go-grpc-middleware" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/sirupsen/logrus" "gitlab.com/gitlab-org/gitaly/v14/internal/command" "google.golang.org/grpc" - "google.golang.org/grpc/codes" ) // UnaryInterceptor returns a Unary Interceptor @@ -33,23 +31,10 @@ func StreamInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.Str return err } -// CommandStatsMessageProducer hooks into grpc_logrus to add more fields. -// -// It replaces github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus.DefaultMessageProducer. -// -// We cannot use ctxlogrus.AddFields() as it is not concurrency safe, and -// we may be logging concurrently. Conversely, command.Stats.Fields() is -// protected by a lock and can safely be called here. -func CommandStatsMessageProducer(ctx context.Context, format string, level logrus.Level, code codes.Code, err error, fields logrus.Fields) { - if err != nil { - fields[logrus.ErrorKey] = err - } - entry := ctxlogrus.Extract(ctx).WithContext(ctx).WithFields(fields) - - // safely inject commandstats +// FieldsProducer extracts stats info from the context and returns it as a logging fields. +func FieldsProducer(ctx context.Context) logrus.Fields { if stats := command.StatsFromContext(ctx); stats != nil { - entry = entry.WithFields(stats.Fields()) + return stats.Fields() } - - entry.Logf(level, format) + return nil } diff --git a/internal/middleware/commandstatshandler/commandstatshandler_test.go b/internal/middleware/commandstatshandler/commandstatshandler_test.go index d474744bd..01c1e6479 100644 --- a/internal/middleware/commandstatshandler/commandstatshandler_test.go +++ b/internal/middleware/commandstatshandler/commandstatshandler_test.go @@ -12,6 +12,7 @@ import ( "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitaly/v14/internal/backchannel" + "gitlab.com/gitlab-org/gitaly/v14/internal/command" "gitlab.com/gitlab-org/gitaly/v14/internal/git" "gitlab.com/gitlab-org/gitaly/v14/internal/git/catfile" "gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/config" @@ -38,13 +39,13 @@ func createNewServer(t *testing.T, cfg config.Cfg) *grpc.Server { StreamInterceptor, grpcmwlogrus.StreamServerInterceptor(logrusEntry, grpcmwlogrus.WithTimestampFormat(log.LogTimestampFormat), - grpcmwlogrus.WithMessageProducer(CommandStatsMessageProducer)), + grpcmwlogrus.WithMessageProducer(log.MessageProducer(grpcmwlogrus.DefaultMessageProducer, FieldsProducer))), )), grpc.UnaryInterceptor(grpcmw.ChainUnaryServer( UnaryInterceptor, grpcmwlogrus.UnaryServerInterceptor(logrusEntry, grpcmwlogrus.WithTimestampFormat(log.LogTimestampFormat), - grpcmwlogrus.WithMessageProducer(CommandStatsMessageProducer)), + grpcmwlogrus.WithMessageProducer(log.MessageProducer(grpcmwlogrus.DefaultMessageProducer, FieldsProducer))), )), } @@ -141,3 +142,14 @@ func TestInterceptor(t *testing.T) { }) } } + +func TestFieldsProducer(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + + ctx = command.InitContextStats(ctx) + stats := command.StatsFromContext(ctx) + stats.RecordMax("stub", 42) + + require.Equal(t, logrus.Fields{"stub": 42}, FieldsProducer(ctx)) +} |