author     John Cai <jcai@gitlab.com>  2020-05-31 09:15:10 +0300
committer  John Cai <jcai@gitlab.com>  2020-06-22 19:53:14 +0300
commit     bf62340c329cf0c8e1595f9fa620af38ee4b1f3b (patch)
tree       a9237309832e84707cafdf10854f21464645c888
parent     543d00cffaaf69b66c18e1f992b60d8f3671ed7f (diff)
Add error handler middleware to keep track of read/write errors
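The interceptor is intended to be installed on Praefect's client connections to its Gitaly backends, so that every failed accessor or mutator stream counts against the originating node. A minimal wiring sketch, modeled on the test in this commit (the socket path, storage name, thresholds, and registry construction are illustrative, not part of this change):

package main

import (
	"context"
	"log"
	"time"

	"gitlab.com/gitlab-org/gitaly/internal/praefect/middleware"
	"gitlab.com/gitlab-org/gitaly/internal/praefect/nodes"
	"gitlab.com/gitlab-org/gitaly/internal/praefect/protoregistry"
	"google.golang.org/grpc"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Trip the read threshold after 100 errors and the write threshold
	// after 1500 errors observed within a 20s sliding window (the values
	// mirror the testdata config below).
	errTracker, err := nodes.NewErrors(ctx, 20*time.Second, 100, 1500)
	if err != nil {
		log.Fatal(err)
	}

	// GitalyProtoFileDescriptors holds gitaly's generated file
	// descriptors; assumed helper from the protoregistry package.
	registry, err := protoregistry.New(protoregistry.GitalyProtoFileDescriptors...)
	if err != nil {
		log.Fatal(err)
	}

	// Every stream opened through this connection now records failures
	// against the "praefect-internal-1" storage.
	conn, err := grpc.Dial("unix:///var/run/gitaly.socket",
		grpc.WithInsecure(),
		grpc.WithStreamInterceptor(middleware.StreamErrorHandler(registry, errTracker, "praefect-internal-1")),
	)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()
}

With this in place, a failed stream counts toward the node's read or write threshold depending on whether the RPC is registered as an accessor or a mutator.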
-rw-r--r--  internal/praefect/config/config.go                |   9
-rw-r--r--  internal/praefect/config/config_test.go           |   9
-rw-r--r--  internal/praefect/config/testdata/config.toml     |   5
-rw-r--r--  internal/praefect/middleware/errorhandler.go      |  72
-rw-r--r--  internal/praefect/middleware/errorhandler_test.go | 147
-rw-r--r--  internal/praefect/nodes/errors.go                 | 186
-rw-r--r--  internal/praefect/nodes/errors_test.go            | 134
7 files changed, 556 insertions, 6 deletions
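The first three files extend Praefect's [failover] configuration with the new threshold keys. As a self-contained sketch of how such keys decode (decoding with BurntSushi/toml directly, which the struct's toml tags suggest; the local Duration type stands in for gitaly's config.Duration so the sketch compiles outside the gitaly tree):

package main

import (
	"fmt"
	"time"

	"github.com/BurntSushi/toml"
)

// Duration is a stand-in for gitaly's config.Duration TOML wrapper.
type Duration time.Duration

func (d *Duration) UnmarshalText(text []byte) error {
	parsed, err := time.ParseDuration(string(text))
	*d = Duration(parsed)
	return err
}

// Failover mirrors the new fields added in this commit.
type Failover struct {
	ErrorThresholdWindow     Duration `toml:"error_threshold_window"`
	WriteErrorThresholdCount uint32   `toml:"write_error_threshold_count"`
	ReadErrorThresholdCount  uint32   `toml:"read_error_threshold_count"`
}

func main() {
	var cfg struct {
		Failover Failover `toml:"failover"`
	}
	_, err := toml.Decode(`
[failover]
error_threshold_window = "20s"
write_error_threshold_count = 1500
read_error_threshold_count = 100
`, &cfg)
	if err != nil {
		panic(err)
	}
	fmt.Println(time.Duration(cfg.Failover.ErrorThresholdWindow)) // 20s
	fmt.Println(cfg.Failover.WriteErrorThresholdCount)            // 1500
	fmt.Println(cfg.Failover.ReadErrorThresholdCount)             // 100
}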
diff --git a/internal/praefect/config/config.go b/internal/praefect/config/config.go
index e25f909ff..dc9024d05 100644
--- a/internal/praefect/config/config.go
+++ b/internal/praefect/config/config.go
@@ -15,9 +15,12 @@ import (
 )
 
 type Failover struct {
-	Enabled               bool   `toml:"enabled"`
-	ElectionStrategy      string `toml:"election_strategy"`
-	ReadOnlyAfterFailover bool   `toml:"read_only_after_failover"`
+	Enabled                  bool            `toml:"enabled"`
+	ElectionStrategy         string          `toml:"election_strategy"`
+	ReadOnlyAfterFailover    bool            `toml:"read_only_after_failover"`
+	ErrorThresholdWindow     config.Duration `toml:"error_threshold_window"`
+	WriteErrorThresholdCount uint32          `toml:"write_error_threshold_count"`
+	ReadErrorThresholdCount  uint32          `toml:"read_error_threshold_count"`
 }
 
 const sqlFailoverValue = "sql"
diff --git a/internal/praefect/config/config_test.go b/internal/praefect/config/config_test.go
index 982565229..1944727f0 100644
--- a/internal/praefect/config/config_test.go
+++ b/internal/praefect/config/config_test.go
@@ -266,9 +266,12 @@ func TestConfigParsing(t *testing.T) {
 				MemoryQueueEnabled:  true,
 				GracefulStopTimeout: config.Duration(30 * time.Second),
 				Failover: Failover{
-					Enabled:               true,
-					ElectionStrategy:      sqlFailoverValue,
-					ReadOnlyAfterFailover: true,
+					Enabled:                  true,
+					ElectionStrategy:         sqlFailoverValue,
+					ReadOnlyAfterFailover:    true,
+					ErrorThresholdWindow:     config.Duration(20 * time.Second),
+					WriteErrorThresholdCount: 1500,
+					ReadErrorThresholdCount:  100,
 				},
 			},
 		},
diff --git a/internal/praefect/config/testdata/config.toml b/internal/praefect/config/testdata/config.toml
index 0628d9034..32b4c94db 100644
--- a/internal/praefect/config/testdata/config.toml
+++ b/internal/praefect/config/testdata/config.toml
@@ -41,3 +41,8 @@ sslmode = "require"
 sslcert = "/path/to/cert"
 sslkey = "/path/to/key"
 sslrootcert = "/path/to/root-cert"
+
+[failover]
+error_threshold_window = "20s"
+write_error_threshold_count = 1500
+read_error_threshold_count = 100
diff --git a/internal/praefect/middleware/errorhandler.go b/internal/praefect/middleware/errorhandler.go
new file mode 100644
index 000000000..c6d0e36ca
--- /dev/null
+++ b/internal/praefect/middleware/errorhandler.go
@@ -0,0 +1,72 @@
+package middleware
+
+import (
+	"context"
+	"fmt"
+	"io"
+
+	"gitlab.com/gitlab-org/gitaly/internal/praefect/nodes"
+	"gitlab.com/gitlab-org/gitaly/internal/praefect/protoregistry"
+	"google.golang.org/grpc"
+)
+
+// StreamErrorHandler returns a client interceptor that will track accessor/mutator errors from internal gitaly nodes
+func StreamErrorHandler(registry *protoregistry.Registry, errorTracker nodes.ErrorTracker, nodeStorage string) grpc.StreamClientInterceptor {
+	return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
+		stream, err := streamer(ctx, desc, cc, method, opts...)
+
+		mi, lookupErr := registry.LookupMethod(method)
+		if lookupErr != nil {
+			return nil, fmt.Errorf("error when looking up method: %w", lookupErr)
+		}
+
+		return newCatchErrorStreamer(stream, errorTracker, mi.Operation, nodeStorage), err
+	}
+}
+
+// catchErrorStreamer is a custom ClientStream that adheres to grpc.ClientStream but keeps track of accessor/mutator errors
+type catchErrorStreamer struct {
+	grpc.ClientStream
+	errors      nodes.ErrorTracker
+	operation   protoregistry.OpType
+	nodeStorage string
+}
+
+func newCatchErrorStreamer(streamer grpc.ClientStream, errors nodes.ErrorTracker, operation protoregistry.OpType, nodeStorage string) *catchErrorStreamer {
+	return &catchErrorStreamer{
+		ClientStream: streamer,
+		errors:       errors,
+		operation:    operation,
+		nodeStorage:  nodeStorage,
+	}
+}
+
+// SendMsg proxies the send but records any errors
+func (c *catchErrorStreamer) SendMsg(m interface{}) error {
+	err := c.ClientStream.SendMsg(m)
+	if err != nil {
+		switch c.operation {
+		case protoregistry.OpAccessor:
+			c.errors.IncrReadErr(c.nodeStorage)
+		case protoregistry.OpMutator:
+			c.errors.IncrWriteErr(c.nodeStorage)
+		}
+	}
+
+	return err
+}
+
+// RecvMsg proxies the receive but records any errors other than io.EOF
+func (c *catchErrorStreamer) RecvMsg(m interface{}) error {
+	err := c.ClientStream.RecvMsg(m)
+	if err != nil && err != io.EOF {
+		switch c.operation {
+		case protoregistry.OpAccessor:
+			c.errors.IncrReadErr(c.nodeStorage)
+		case protoregistry.OpMutator:
+			c.errors.IncrWriteErr(c.nodeStorage)
+		}
+	}
+
+	return err
+}
diff --git a/internal/praefect/middleware/errorhandler_test.go b/internal/praefect/middleware/errorhandler_test.go
new file mode 100644
index 000000000..0ee85d548
--- /dev/null
+++ b/internal/praefect/middleware/errorhandler_test.go
@@ -0,0 +1,147 @@
+package middleware
+
+import (
+	"context"
+	"net"
+	"testing"
+	"time"
+
+	"github.com/golang/protobuf/proto"
+	"github.com/golang/protobuf/ptypes/empty"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"gitlab.com/gitlab-org/gitaly/internal/helper"
+	"gitlab.com/gitlab-org/gitaly/internal/praefect/grpc-proxy/proxy"
+	"gitlab.com/gitlab-org/gitaly/internal/praefect/mock"
+	"gitlab.com/gitlab-org/gitaly/internal/praefect/nodes"
+	"gitlab.com/gitlab-org/gitaly/internal/praefect/protoregistry"
+	"gitlab.com/gitlab-org/gitaly/internal/testhelper"
+	"google.golang.org/grpc"
+)
+
+type simpleService struct{}
+
+func (s *simpleService) RepoAccessorUnary(ctx context.Context, in *mock.RepoRequest) (*empty.Empty, error) {
+	if in.GetRepo() == nil {
+		return nil, helper.ErrInternalf("error")
+	}
+
+	return &empty.Empty{}, nil
+}
+
+func (s *simpleService) RepoMutatorUnary(ctx context.Context, in *mock.RepoRequest) (*empty.Empty, error) {
+	if in.GetRepo() == nil {
+		return nil, helper.ErrInternalf("error")
+	}
+
+	return &empty.Empty{}, nil
+}
+
+func (s *simpleService) ServerAccessor(ctx context.Context, in *mock.SimpleRequest) (*mock.SimpleResponse, error) {
+	return &mock.SimpleResponse{}, nil
+}
+
+func TestStreamInterceptor(t *testing.T) {
+	ctx, cancel := testhelper.Context()
+	defer cancel()
+
+	window := 1 * time.Second
+	threshold := 5
+	errTracker, err := nodes.NewErrors(ctx, window, uint32(threshold), uint32(threshold))
+	require.NoError(t, err)
+	nodeName := "node-1"
+
+	internalSrv := grpc.NewServer()
+
+	internalServerSocketPath := testhelper.GetTemporaryGitalySocketFileName()
+	lis, err := net.Listen("unix", internalServerSocketPath)
+	require.NoError(t, err)
+
+	gz := proto.FileDescriptor("mock.proto")
+	fd, err := protoregistry.ExtractFileDescriptor(gz)
+	require.NoError(t, err)
+
+	registry, err := protoregistry.New(fd)
+	require.NoError(t, err)
+
+	mock.RegisterSimpleServiceServer(internalSrv, &simpleService{})
+
+	go internalSrv.Serve(lis)
+	defer internalSrv.Stop()
+
+	srvOptions := []grpc.ServerOption{
+		grpc.CustomCodec(proxy.NewCodec()),
+		grpc.UnknownServiceHandler(proxy.TransparentHandler(func(ctx context.Context,
+			fullMethodName string,
+			peeker proxy.StreamModifier,
+		) (*proxy.StreamParameters, error) {
+			cc, err := grpc.Dial("unix://"+internalServerSocketPath,
+				grpc.WithDefaultCallOptions(grpc.ForceCodec(proxy.NewCodec())),
+				grpc.WithInsecure(),
+				grpc.WithStreamInterceptor(StreamErrorHandler(registry, errTracker, nodeName)),
+			)
+			require.NoError(t, err)
+			return proxy.NewStreamParameters(ctx, cc, func() {}, nil), nil
+		})),
+	}
+
+	praefectSocket := testhelper.GetTemporaryGitalySocketFileName()
+	praefectLis, err := net.Listen("unix", praefectSocket)
+	require.NoError(t, err)
+
+	praefectSrv := grpc.NewServer(srvOptions...)
+	defer praefectSrv.Stop()
+	go praefectSrv.Serve(praefectLis)
+
+	praefectCC, err := grpc.Dial("unix://"+praefectSocket, grpc.WithInsecure())
+	require.NoError(t, err)
+
+	simpleClient := mock.NewSimpleServiceClient(praefectCC)
+
+	testRepo, _, cleanup := testhelper.NewTestRepo(t)
+	defer cleanup()
+
+	for i := 0; i < threshold; i++ {
+		_, err = simpleClient.RepoAccessorUnary(ctx, &mock.RepoRequest{
+			Repo: testRepo,
+		})
+		require.NoError(t, err)
+		_, err = simpleClient.RepoMutatorUnary(ctx, &mock.RepoRequest{
+			Repo: testRepo,
+		})
+		require.NoError(t, err)
+	}
+
+	assert.False(t, errTracker.WriteThresholdReached(nodeName))
+	assert.False(t, errTracker.ReadThresholdReached(nodeName))
+
+	for i := 0; i < threshold; i++ {
+		_, err = simpleClient.RepoAccessorUnary(ctx, &mock.RepoRequest{
+			Repo: nil,
+		})
+		require.Error(t, err)
+		_, err = simpleClient.RepoMutatorUnary(ctx, &mock.RepoRequest{
+			Repo: nil,
+		})
+		require.Error(t, err)
+	}
+
+	assert.True(t, errTracker.WriteThresholdReached(nodeName))
+	assert.True(t, errTracker.ReadThresholdReached(nodeName))
+
+	time.Sleep(window)
+
+	for i := 0; i < threshold; i++ {
+		_, err = simpleClient.RepoAccessorUnary(ctx, &mock.RepoRequest{
+			Repo: testRepo,
+		})
+		require.NoError(t, err)
+		_, err = simpleClient.RepoMutatorUnary(ctx, &mock.RepoRequest{
+			Repo: testRepo,
+		})
+		require.NoError(t, err)
+	}
+
+	assert.False(t, errTracker.WriteThresholdReached(nodeName))
+	assert.False(t, errTracker.ReadThresholdReached(nodeName))
+}
diff --git a/internal/praefect/nodes/errors.go b/internal/praefect/nodes/errors.go
new file mode 100644
index 000000000..2faaeb3e0
--- /dev/null
+++ b/internal/praefect/nodes/errors.go
@@ -0,0 +1,186 @@
+package nodes
+
+import (
+	"context"
+	"errors"
+	"sync"
+	"time"
+)
+
+// ErrorTracker allows tracking how many read/write errors have occurred, and whether or not it has
+// exceeded a configured threshold in a configured time window
+type ErrorTracker interface {
+	// IncrReadErr increases read errors by 1
+	IncrReadErr(nodeStorage string)
+	// IncrWriteErr increases write errors by 1
+	IncrWriteErr(nodeStorage string)
+	// ReadThresholdReached returns whether or not the read threshold was reached
+	ReadThresholdReached(nodeStorage string) bool
+	// WriteThresholdReached returns whether or not the write threshold was reached
+	WriteThresholdReached(nodeStorage string) bool
+}
+
+type errorTracker struct {
+	olderThan                     func() time.Time
+	m                             sync.RWMutex
+	writeThreshold, readThreshold int
+	readErrors, writeErrors       map[string][]time.Time
+	ctx                           context.Context
+}
+
+func newErrors(ctx context.Context, errorWindow time.Duration, readThreshold, writeThreshold uint32) (*errorTracker, error) {
+	if errorWindow == 0 {
+		return nil, errors.New("errorWindow must be non-zero")
+	}
+
+	if readThreshold == 0 {
+		return nil, errors.New("readThreshold must be non-zero")
+	}
+
+	if writeThreshold == 0 {
+		return nil, errors.New("writeThreshold must be non-zero")
+	}
+
+	e := &errorTracker{
+		olderThan: func() time.Time {
+			return time.Now().Add(-errorWindow)
+		},
+		readErrors:     make(map[string][]time.Time),
+		writeErrors:    make(map[string][]time.Time),
+		readThreshold:  int(readThreshold),
+		writeThreshold: int(writeThreshold),
+		ctx:            ctx,
+	}
+	go e.periodicallyClear()
+
+	return e, nil
+}
+
+// NewErrors creates a new ErrorTracker given a sliding error window and read and write error thresholds
+func NewErrors(ctx context.Context, errorWindow time.Duration, readThreshold, writeThreshold uint32) (ErrorTracker, error) {
+	return newErrors(ctx, errorWindow, readThreshold, writeThreshold)
+}
+
+// IncrReadErr increases the read errors for a node by 1
+func (e *errorTracker) IncrReadErr(node string) {
+	select {
+	case <-e.ctx.Done():
+		return
+	default:
+		e.m.Lock()
+		defer e.m.Unlock()
+
+		e.readErrors[node] = append(e.readErrors[node], time.Now())
+
+		if len(e.readErrors[node]) > e.readThreshold {
+			e.readErrors[node] = e.readErrors[node][1:]
+		}
+	}
+}
+
+// IncrWriteErr increases the write errors for a node by 1
+func (e *errorTracker) IncrWriteErr(node string) {
+	select {
+	case <-e.ctx.Done():
+		return
+	default:
+		e.m.Lock()
+		defer e.m.Unlock()
+
+		e.writeErrors[node] = append(e.writeErrors[node], time.Now())
+
+		if len(e.writeErrors[node]) > e.writeThreshold {
+			e.writeErrors[node] = e.writeErrors[node][1:]
+		}
+	}
+}
+
+// ReadThresholdReached indicates whether or not the read threshold has been reached within the time window
+func (e *errorTracker) ReadThresholdReached(node string) bool {
+	select {
+	case <-e.ctx.Done():
+		break
+	default:
+		e.m.RLock()
+		defer e.m.RUnlock()
+
+		olderThanTime := e.olderThan()
+
+		for i, errTime := range e.readErrors[node] {
+			if errTime.After(olderThanTime) {
+				if i == 0 {
+					return len(e.readErrors[node]) >= e.readThreshold
+				}
+				return len(e.readErrors[node][i:]) >= e.readThreshold
+			}
+		}
+	}
+
+	return false
+}
+
+// WriteThresholdReached indicates whether or not the write threshold has been reached within the time window
+func (e *errorTracker) WriteThresholdReached(node string) bool {
+	select {
+	case <-e.ctx.Done():
+		break
+	default:
+		e.m.RLock()
+		defer e.m.RUnlock()
+
+		olderThanTime := e.olderThan()
+
+		for i, errTime := range e.writeErrors[node] {
+			if errTime.After(olderThanTime) {
+				if i == 0 {
+					return len(e.writeErrors[node]) >= e.writeThreshold
+				}
+				return len(e.writeErrors[node][i:]) >= e.writeThreshold
+			}
+		}
+	}
+
+	return false
+}
+
+// periodicallyClear runs in an infinite loop clearing out old error entries
+func (e *errorTracker) periodicallyClear() {
+	ticker := time.NewTicker(1 * time.Second)
+	for {
+		select {
+		case <-ticker.C:
+			e.clear()
+		case <-e.ctx.Done():
+			e.m.Lock()
+			defer e.m.Unlock()
+			e.readErrors = nil
+			e.writeErrors = nil
+			return
+		}
+	}
+}
+
+func (e *errorTracker) clear() {
+	e.m.Lock()
+	defer e.m.Unlock()
+
+	olderThanTime := e.olderThan()
+
+	clearErrors(e.writeErrors, olderThanTime)
+	clearErrors(e.readErrors, olderThanTime)
+}
+
+func clearErrors(errs map[string][]time.Time, olderThan time.Time) {
+	for node, errors := range errs {
+		for i, errTime := range errors {
+			if errTime.After(olderThan) {
+				errs[node] = errors[i:]
+				break
+			}
+
+			if i == len(errors)-1 {
+				errs[node] = errs[node][:0]
+			}
+		}
+	}
+}
diff --git a/internal/praefect/nodes/errors_test.go b/internal/praefect/nodes/errors_test.go
new file mode 100644
index 000000000..f5b42e06a
--- /dev/null
+++ b/internal/praefect/nodes/errors_test.go
@@ -0,0 +1,134 @@
+package nodes
+
+import (
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"gitlab.com/gitlab-org/gitaly/internal/testhelper"
+)
+
+func TestErrorTracker_IncrErrors(t *testing.T) {
+	ctx, cancel := testhelper.Context()
+	defer cancel()
+
+	writeThreshold, readThreshold := 10, 10
+
+	errors, err := newErrors(ctx, time.Second, uint32(readThreshold), uint32(writeThreshold))
+	require.NoError(t, err)
+
+	node := "backend-node-1"
+
+	assert.False(t, errors.WriteThresholdReached(node))
+	assert.False(t, errors.ReadThresholdReached(node))
+
+	for i := 0; i < writeThreshold; i++ {
+		errors.IncrWriteErr(node)
+	}
+
+	assert.True(t, errors.WriteThresholdReached(node))
+
+	for i := 0; i < readThreshold; i++ {
+		errors.IncrReadErr(node)
+	}
+
+	assert.True(t, errors.ReadThresholdReached(node))
+
+	// use a negative window so that all of the errors in the queue are guaranteed to be cleared
+	errors, err = newErrors(ctx, -time.Second, uint32(readThreshold), uint32(writeThreshold))
+	require.NoError(t, err)
+
+	errors.clear()
+
+	assert.False(t, errors.WriteThresholdReached(node))
+	assert.False(t, errors.ReadThresholdReached(node))
+}
+
+func TestErrorTracker_Concurrency(t *testing.T) {
+	ctx, cancel := testhelper.Context()
+	defer cancel()
+
+	readAndWriteThreshold := 10
+	errors, err := newErrors(ctx, 1*time.Second, uint32(readAndWriteThreshold), uint32(readAndWriteThreshold))
+	require.NoError(t, err)
+
+	node := "backend-node-1"
+
+	assert.False(t, errors.WriteThresholdReached(node))
+	assert.False(t, errors.ReadThresholdReached(node))
+
+	var g sync.WaitGroup
+	for i := 0; i < readAndWriteThreshold; i++ {
+		g.Add(1)
+		go func() {
+			errors.IncrWriteErr(node)
+			errors.IncrReadErr(node)
+			errors.ReadThresholdReached(node)
+			errors.WriteThresholdReached(node)
+
+			g.Done()
+		}()
+	}
+
+	g.Wait()
+}
+
+func TestErrorTracker_ClearErrors(t *testing.T) {
+	ctx, cancel := testhelper.Context()
+	defer cancel()
+
+	writeThreshold, readThreshold := 10, 10
+	errors, err := newErrors(ctx, time.Second, uint32(readThreshold), uint32(writeThreshold))
+	require.NoError(t, err)
+
+	node := "backend-node-1"
+
+	errors.IncrWriteErr(node)
+	errors.IncrReadErr(node)
+
+	clearBeforeNow := time.Now()
+
+	errors.olderThan = func() time.Time {
+		return clearBeforeNow
+	}
+
+	errors.IncrWriteErr(node)
+	errors.IncrReadErr(node)
+
+	errors.clear()
+	assert.Len(t, errors.readErrors[node], 1, "clear should only have cleared the read error older than the specified time")
+	assert.Len(t, errors.writeErrors[node], 1, "clear should only have cleared the write error older than the specified time")
+}
+
+func TestErrorTracker_Expired(t *testing.T) {
+	ctx, cancel := testhelper.Context()
+	defer cancel()
+
+	threshold := 10
+	errors, err := newErrors(ctx, 10*time.Second, uint32(threshold), uint32(threshold))
+	require.NoError(t, err)
+
+	node := "node"
+	for i := 0; i < threshold; i++ {
+		errors.IncrWriteErr(node)
+		errors.IncrReadErr(node)
+	}
+
+	assert.True(t, errors.ReadThresholdReached(node))
+	assert.True(t, errors.WriteThresholdReached(node))
+
+	cancel()
+
+	assert.False(t, errors.ReadThresholdReached(node))
+	assert.False(t, errors.WriteThresholdReached(node))
+
+	for i := 0; i < threshold; i++ {
+		errors.IncrWriteErr(node)
+		errors.IncrReadErr(node)
+	}
+
+	assert.False(t, errors.ReadThresholdReached(node))
+	assert.False(t, errors.WriteThresholdReached(node))
+}
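Taken together, errors.go implements a per-node sliding window: each IncrReadErr/IncrWriteErr appends a timestamp and caps the slice at the threshold count, and a threshold counts as reached only while at least that many timestamps still fall inside the window. A standalone sketch of those semantics against the exported API (node name and numbers are illustrative):

package main

import (
	"context"
	"fmt"
	"time"

	"gitlab.com/gitlab-org/gitaly/internal/praefect/nodes"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Five read or write errors within a one-second window trip the thresholds.
	tracker, err := nodes.NewErrors(ctx, time.Second, 5, 5)
	if err != nil {
		panic(err)
	}

	for i := 0; i < 5; i++ {
		tracker.IncrReadErr("gitaly-1")
	}
	fmt.Println(tracker.ReadThresholdReached("gitaly-1")) // true: 5 errors inside the window

	// Once the window elapses, the recorded errors age out of the
	// threshold checks and the background sweep prunes them.
	time.Sleep(time.Second + 100*time.Millisecond)
	fmt.Println(tracker.ReadThresholdReached("gitaly-1")) // false
}

Because periodicallyClear prunes entries every second and the threshold checks ignore timestamps older than the window, a node recovers automatically once errors stop arriving.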