diff options
author | John Cai <jcai@gitlab.com> | 2022-04-07 00:26:24 +0300 |
---|---|---|
committer | John Cai <jcai@gitlab.com> | 2022-04-07 00:26:24 +0300 |
commit | 778bdf56489619646d20600a217d464257b4fe1e (patch) | |
tree | f712f1b38feb20366e9dfdd8207ac39f9a700289 | |
parent | 3c00908c2743ef03a2d5b7057bf8f87d5b267a75 (diff) | |
parent | 57db9d3f3c2945dfbe3af16392b2568a0081240a (diff) |
Merge branch 'jc-rate-limiter' into 'master'
Add RateLimiting
Closes #4026
See merge request gitlab-org/gitaly!4427
20 files changed, 587 insertions, 176 deletions
@@ -16405,6 +16405,36 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +LICENSE - golang.org/x/time/rate +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LICENSE - golang.org/x/xerrors Copyright (c) 2019 The Go Authors. All rights reserved. diff --git a/cmd/gitaly-ssh/auth_test.go b/cmd/gitaly-ssh/auth_test.go index cc3b9fd29..070006ce5 100644 --- a/cmd/gitaly-ssh/auth_test.go +++ b/cmd/gitaly-ssh/auth_test.go @@ -150,9 +150,9 @@ func runServer(t *testing.T, secure bool, cfg config.Cfg, connectionType string, hookManager := hook.NewManager(cfg, locator, gitCmdFactory, txManager, gitlab.NewMockClient( t, gitlab.MockAllowed, gitlab.MockPreReceive, gitlab.MockPostReceive, )) - limitHandler := limithandler.New(cfg, limithandler.LimitConcurrencyByRepo) + limitHandler := limithandler.New(cfg, limithandler.LimitConcurrencyByRepo, limithandler.WithConcurrencyLimiters) diskCache := cache.New(cfg, locator) - srv, err := server.New(secure, cfg, testhelper.NewDiscardingLogEntry(t), registry, diskCache, limitHandler) + srv, err := server.New(secure, cfg, testhelper.NewDiscardingLogEntry(t), registry, diskCache, []*limithandler.LimiterMiddleware{limitHandler}) require.NoError(t, err) setup.RegisterAll(srv, &service.Dependencies{ Cfg: cfg, diff --git a/cmd/gitaly/main.go b/cmd/gitaly/main.go index 4bd8e234b..c7386180d 100644 --- a/cmd/gitaly/main.go +++ b/cmd/gitaly/main.go @@ -216,10 +216,26 @@ func run(cfg config.Cfg) error { return fmt.Errorf("disk cache walkers: %w", err) } - limitHandler := limithandler.New(cfg, limithandler.LimitConcurrencyByRepo) - prometheus.MustRegister(limitHandler) + concurrencyLimitHandler := limithandler.New( + cfg, + limithandler.LimitConcurrencyByRepo, + limithandler.WithConcurrencyLimiters, + ) + + rateLimitHandler := limithandler.New( + cfg, + limithandler.LimitConcurrencyByRepo, + limithandler.WithRateLimiters(ctx), + ) + prometheus.MustRegister(concurrencyLimitHandler, rateLimitHandler) - gitalyServerFactory := server.NewGitalyServerFactory(cfg, glog.Default(), registry, diskCache, limitHandler) + gitalyServerFactory := server.NewGitalyServerFactory( + cfg, + glog.Default(), + registry, + diskCache, + []*limithandler.LimiterMiddleware{concurrencyLimitHandler, rateLimitHandler}, + ) defer gitalyServerFactory.Stop() ling, err := linguist.New(cfg, gitCmdFactory) @@ -45,6 +45,7 @@ require ( golang.org/x/exp v0.0.0-20200331195152-e8c3332aa8e5 // indirect golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sys v0.0.0-20211102192858-4dd72447c267 + golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 google.golang.org/grpc v1.38.0 google.golang.org/protobuf v1.26.0 ) @@ -1209,8 +1209,9 @@ golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba h1:O8mE0/t419eoIwhTFpKVkHiTs/Igowgfkj25AcZrtiE= golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 h1:M73Iuj3xbbb9Uk1DYhzydthsj6oOd6l9bpuFcNoUvTs= +golang.org/x/time v0.0.0-20220224211638-0e9765cccd65/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/internal/gitaly/config/config.go b/internal/gitaly/config/config.go index 3f6757db3..49b6cc509 100644 --- a/internal/gitaly/config/config.go +++ b/internal/gitaly/config/config.go @@ -60,6 +60,7 @@ type Cfg struct { GitlabShell GitlabShell `toml:"gitlab-shell"` Hooks Hooks `toml:"hooks"` Concurrency []Concurrency `toml:"concurrency"` + RateLimiting []RateLimiting `toml:"rate_limiting"` GracefulRestartTimeout Duration `toml:"graceful_restart_timeout"` InternalSocketDir string `toml:"internal_socket_dir"` DailyMaintenance DailyJob `toml:"daily_maintenance"` @@ -149,6 +150,22 @@ type Concurrency struct { MaxQueueWait Duration `toml:"max_queue_wait"` } +// RateLimiting allows endpoints to be limited to a maximum request rate per +// second. The rate limiter uses a concept of a "token bucket". In order to serve a +// request, a token is retrieved from the token bucket. The size of the token +// bucket is configured through the Burst value, while the rate at which the +// token bucket is refilled per second is configured through the RequestsPerSecond +// value. +type RateLimiting struct { + // RPC is the full name of the RPC including the service name + RPC string `toml:"rpc"` + // Interval sets the interval with which the token bucket will + // be refilled to what is configured in Burst. + Interval time.Duration `toml:"interval"` + // Burst sets the capacity of the token bucket (see above). + Burst int `toml:"burst"` +} + // StreamCacheConfig contains settings for a streamcache instance. type StreamCacheConfig struct { Enabled bool `toml:"enabled"` // Default: false diff --git a/internal/gitaly/server/auth_test.go b/internal/gitaly/server/auth_test.go index 82be922fd..78107cbe7 100644 --- a/internal/gitaly/server/auth_test.go +++ b/internal/gitaly/server/auth_test.go @@ -200,10 +200,10 @@ func runServer(t *testing.T, cfg config.Cfg) string { catfileCache := catfile.NewCache(cfg) t.Cleanup(catfileCache.Stop) diskCache := cache.New(cfg, locator) - limitHandler := limithandler.New(cfg, limithandler.LimitConcurrencyByRepo) + limitHandler := limithandler.New(cfg, limithandler.LimitConcurrencyByRepo, limithandler.WithConcurrencyLimiters) updaterWithHooks := updateref.NewUpdaterWithHooks(cfg, locator, hookManager, gitCmdFactory, catfileCache) - srv, err := New(false, cfg, testhelper.NewDiscardingLogEntry(t), registry, diskCache, limitHandler) + srv, err := New(false, cfg, testhelper.NewDiscardingLogEntry(t), registry, diskCache, []*limithandler.LimiterMiddleware{limitHandler}) require.NoError(t, err) setup.RegisterAll(srv, &service.Dependencies{ @@ -244,7 +244,7 @@ func runSecureServer(t *testing.T, cfg config.Cfg) string { testhelper.NewDiscardingLogEntry(t), backchannel.NewRegistry(), cache.New(cfg, config.NewLocator(cfg)), - limithandler.New(cfg, limithandler.LimitConcurrencyByRepo), + []*limithandler.LimiterMiddleware{limithandler.New(cfg, limithandler.LimitConcurrencyByRepo, limithandler.WithConcurrencyLimiters)}, ) require.NoError(t, err) diff --git a/internal/gitaly/server/server.go b/internal/gitaly/server/server.go index 2304c0b8f..ab00acd6e 100644 --- a/internal/gitaly/server/server.go +++ b/internal/gitaly/server/server.go @@ -59,7 +59,7 @@ func New( logrusEntry *log.Entry, registry *backchannel.Registry, cacheInvalidator diskcache.Invalidator, - limitHandler *limithandler.LimiterMiddleware, + limitHandlers []*limithandler.LimiterMiddleware, ) (*grpc.Server, error) { ctxTagOpts := []grpcmwtags.Option{ grpcmwtags.WithFieldExtractorForInitialReq(fieldextractors.FieldExtractor), @@ -95,56 +95,67 @@ func New( ), ) + streamServerInterceptors := []grpc.StreamServerInterceptor{ + grpcmwtags.StreamServerInterceptor(ctxTagOpts...), + grpccorrelation.StreamServerCorrelationInterceptor(), // Must be above the metadata handler + metadatahandler.StreamInterceptor, + grpcprometheus.StreamServerInterceptor, + commandstatshandler.StreamInterceptor, + grpcmwlogrus.StreamServerInterceptor(logrusEntry, + grpcmwlogrus.WithTimestampFormat(gitalylog.LogTimestampFormat), + logMsgProducer, + gitalylog.DeciderOption(), + ), + gitalylog.StreamLogDataCatcherServerInterceptor(), + sentryhandler.StreamLogHandler, + cancelhandler.Stream, // Should be below LogHandler + auth.StreamServerInterceptor(cfg.Auth), + } + unaryServerInterceptors := []grpc.UnaryServerInterceptor{ + grpcmwtags.UnaryServerInterceptor(ctxTagOpts...), + grpccorrelation.UnaryServerCorrelationInterceptor(), // Must be above the metadata handler + metadatahandler.UnaryInterceptor, + grpcprometheus.UnaryServerInterceptor, + commandstatshandler.UnaryInterceptor, + grpcmwlogrus.UnaryServerInterceptor(logrusEntry, + grpcmwlogrus.WithTimestampFormat(gitalylog.LogTimestampFormat), + logMsgProducer, + gitalylog.DeciderOption(), + ), + gitalylog.UnaryLogDataCatcherServerInterceptor(), + sentryhandler.UnaryLogHandler, + cancelhandler.Unary, // Should be below LogHandler + auth.UnaryServerInterceptor(cfg.Auth), + } + // Should be below auth handler to prevent v2 hmac tokens from timing out while queued + for _, limitHandler := range limitHandlers { + streamServerInterceptors = append(streamServerInterceptors, limitHandler.StreamInterceptor()) + unaryServerInterceptors = append(unaryServerInterceptors, limitHandler.UnaryInterceptor()) + } + + streamServerInterceptors = append(streamServerInterceptors, + grpctracing.StreamServerTracingInterceptor(), + cache.StreamInvalidator(cacheInvalidator, protoregistry.GitalyProtoPreregistered), + // Panic handler should remain last so that application panics will be + // converted to errors and logged + panichandler.StreamPanicHandler, + ) + + unaryServerInterceptors = append(unaryServerInterceptors, + cache.UnaryInvalidator(cacheInvalidator, protoregistry.GitalyProtoPreregistered), + // Panic handler should remain last so that application panics will be + // converted to errors and logged + panichandler.UnaryPanicHandler, + ) + opts := []grpc.ServerOption{ grpc.StatsHandler(gitalylog.PerRPCLogHandler{ Underlying: &grpcstats.PayloadBytes{}, FieldProducers: []gitalylog.FieldsProducer{grpcstats.FieldsProducer}, }), grpc.Creds(lm), - grpc.StreamInterceptor(grpcmw.ChainStreamServer( - grpcmwtags.StreamServerInterceptor(ctxTagOpts...), - grpccorrelation.StreamServerCorrelationInterceptor(), // Must be above the metadata handler - metadatahandler.StreamInterceptor, - grpcprometheus.StreamServerInterceptor, - commandstatshandler.StreamInterceptor, - grpcmwlogrus.StreamServerInterceptor(logrusEntry, - grpcmwlogrus.WithTimestampFormat(gitalylog.LogTimestampFormat), - logMsgProducer, - gitalylog.DeciderOption(), - ), - gitalylog.StreamLogDataCatcherServerInterceptor(), - sentryhandler.StreamLogHandler, - cancelhandler.Stream, // Should be below LogHandler - auth.StreamServerInterceptor(cfg.Auth), - limitHandler.StreamInterceptor(), // Should be below auth handler to prevent v2 hmac tokens from timing out while queued - grpctracing.StreamServerTracingInterceptor(), - cache.StreamInvalidator(cacheInvalidator, protoregistry.GitalyProtoPreregistered), - // Panic handler should remain last so that application panics will be - // converted to errors and logged - panichandler.StreamPanicHandler, - )), - grpc.UnaryInterceptor(grpcmw.ChainUnaryServer( - grpcmwtags.UnaryServerInterceptor(ctxTagOpts...), - grpccorrelation.UnaryServerCorrelationInterceptor(), // Must be above the metadata handler - metadatahandler.UnaryInterceptor, - grpcprometheus.UnaryServerInterceptor, - commandstatshandler.UnaryInterceptor, - grpcmwlogrus.UnaryServerInterceptor(logrusEntry, - grpcmwlogrus.WithTimestampFormat(gitalylog.LogTimestampFormat), - logMsgProducer, - gitalylog.DeciderOption(), - ), - gitalylog.UnaryLogDataCatcherServerInterceptor(), - sentryhandler.UnaryLogHandler, - cancelhandler.Unary, // Should be below LogHandler - auth.UnaryServerInterceptor(cfg.Auth), - limitHandler.UnaryInterceptor(), // Should be below auth handler to prevent v2 hmac tokens from timing out while queued - grpctracing.UnaryServerTracingInterceptor(), - cache.UnaryInvalidator(cacheInvalidator, protoregistry.GitalyProtoPreregistered), - // Panic handler should remain last so that application panics will be - // converted to errors and logged - panichandler.UnaryPanicHandler, - )), + grpc.StreamInterceptor(grpcmw.ChainStreamServer(streamServerInterceptors...)), + grpc.UnaryInterceptor(grpcmw.ChainUnaryServer(unaryServerInterceptors...)), grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{ MinTime: 20 * time.Second, PermitWithoutStream: true, diff --git a/internal/gitaly/server/server_factory.go b/internal/gitaly/server/server_factory.go index defea0c3b..6b33557be 100644 --- a/internal/gitaly/server/server_factory.go +++ b/internal/gitaly/server/server_factory.go @@ -15,7 +15,7 @@ import ( type GitalyServerFactory struct { registry *backchannel.Registry cacheInvalidator cache.Invalidator - limitHandler *limithandler.LimiterMiddleware + limitHandlers []*limithandler.LimiterMiddleware cfg config.Cfg logger *logrus.Entry externalServers []*grpc.Server @@ -29,14 +29,14 @@ func NewGitalyServerFactory( logger *logrus.Entry, registry *backchannel.Registry, cacheInvalidator cache.Invalidator, - limitHandler *limithandler.LimiterMiddleware, + limitHandlers []*limithandler.LimiterMiddleware, ) *GitalyServerFactory { return &GitalyServerFactory{ cfg: cfg, logger: logger, registry: registry, cacheInvalidator: cacheInvalidator, - limitHandler: limitHandler, + limitHandlers: limitHandlers, } } @@ -78,7 +78,14 @@ func (s *GitalyServerFactory) GracefulStop() { // CreateExternal creates a new external gRPC server. The external servers are closed // before the internal servers when gracefully shutting down. func (s *GitalyServerFactory) CreateExternal(secure bool) (*grpc.Server, error) { - server, err := New(secure, s.cfg, s.logger, s.registry, s.cacheInvalidator, s.limitHandler) + server, err := New( + secure, + s.cfg, + s.logger, + s.registry, + s.cacheInvalidator, + s.limitHandlers, + ) if err != nil { return nil, err } @@ -90,7 +97,13 @@ func (s *GitalyServerFactory) CreateExternal(secure bool) (*grpc.Server, error) // CreateInternal creates a new internal gRPC server. Internal servers are closed // after the external ones when gracefully shutting down. func (s *GitalyServerFactory) CreateInternal() (*grpc.Server, error) { - server, err := New(false, s.cfg, s.logger, s.registry, s.cacheInvalidator, s.limitHandler) + server, err := New( + false, + s.cfg, + s.logger, + s.registry, + s.cacheInvalidator, + s.limitHandlers) if err != nil { return nil, err } diff --git a/internal/gitaly/server/server_factory_test.go b/internal/gitaly/server/server_factory_test.go index c2afcd7c6..a28827248 100644 --- a/internal/gitaly/server/server_factory_test.go +++ b/internal/gitaly/server/server_factory_test.go @@ -93,7 +93,7 @@ func TestGitalyServerFactory(t *testing.T) { testhelper.NewDiscardingLogEntry(t), backchannel.NewRegistry(), cache.New(cfg, config.NewLocator(cfg)), - limithandler.New(cfg, limithandler.LimitConcurrencyByRepo), + []*limithandler.LimiterMiddleware{limithandler.New(cfg, limithandler.LimitConcurrencyByRepo, limithandler.WithConcurrencyLimiters)}, ) checkHealth(t, sf, starter.TCP, "localhost:0") @@ -112,7 +112,7 @@ func TestGitalyServerFactory(t *testing.T) { testhelper.NewDiscardingLogEntry(t), backchannel.NewRegistry(), cache.New(cfg, config.NewLocator(cfg)), - limithandler.New(cfg, limithandler.LimitConcurrencyByRepo), + []*limithandler.LimiterMiddleware{limithandler.New(cfg, limithandler.LimitConcurrencyByRepo, limithandler.WithConcurrencyLimiters)}, ) t.Cleanup(sf.Stop) @@ -126,7 +126,7 @@ func TestGitalyServerFactory(t *testing.T) { testhelper.NewDiscardingLogEntry(t), backchannel.NewRegistry(), cache.New(cfg, config.NewLocator(cfg)), - limithandler.New(cfg, limithandler.LimitConcurrencyByRepo), + []*limithandler.LimiterMiddleware{limithandler.New(cfg, limithandler.LimitConcurrencyByRepo, limithandler.WithConcurrencyLimiters)}, ) t.Cleanup(sf.Stop) @@ -156,7 +156,7 @@ func TestGitalyServerFactory(t *testing.T) { logger.WithContext(ctx), backchannel.NewRegistry(), cache.New(cfg, config.NewLocator(cfg)), - limithandler.New(cfg, limithandler.LimitConcurrencyByRepo), + []*limithandler.LimiterMiddleware{limithandler.New(cfg, limithandler.LimitConcurrencyByRepo, limithandler.WithConcurrencyLimiters)}, ) checkHealth(t, sf, starter.TCP, "localhost:0") @@ -190,7 +190,7 @@ func TestGitalyServerFactory_closeOrder(t *testing.T) { testhelper.NewDiscardingLogEntry(t), backchannel.NewRegistry(), cache.New(cfg, config.NewLocator(cfg)), - limithandler.New(cfg, limithandler.LimitConcurrencyByRepo), + []*limithandler.LimiterMiddleware{limithandler.New(cfg, limithandler.LimitConcurrencyByRepo, limithandler.WithConcurrencyLimiters)}, ) defer sf.Stop() diff --git a/internal/gitaly/service/repository/create_fork_test.go b/internal/gitaly/service/repository/create_fork_test.go index b8ea85654..22aa1c0da 100644 --- a/internal/gitaly/service/repository/create_fork_test.go +++ b/internal/gitaly/service/repository/create_fork_test.go @@ -259,8 +259,8 @@ func runSecureServer(t *testing.T, cfg config.Cfg, rubySrv *rubyserver.Server) s registry := backchannel.NewRegistry() locator := config.NewLocator(cfg) cache := cache.New(cfg, locator) - limitHandler := limithandler.New(cfg, limithandler.LimitConcurrencyByRepo) - server, err := gserver.New(true, cfg, testhelper.NewDiscardingLogEntry(t), registry, cache, limitHandler) + limitHandler := limithandler.New(cfg, limithandler.LimitConcurrencyByRepo, limithandler.WithConcurrencyLimiters) + server, err := gserver.New(true, cfg, testhelper.NewDiscardingLogEntry(t), registry, cache, []*limithandler.LimiterMiddleware{limitHandler}) require.NoError(t, err) listener, addr := testhelper.GetLocalhostListener(t) diff --git a/internal/metadata/featureflag/ff_rate_limiter.go b/internal/metadata/featureflag/ff_rate_limiter.go new file mode 100644 index 000000000..7207e4b37 --- /dev/null +++ b/internal/metadata/featureflag/ff_rate_limiter.go @@ -0,0 +1,5 @@ +package featureflag + +// RateLimit will enable the rate limiter to reject requests beyond a configured +// rate. +var RateLimit = NewFeatureFlag("rate_limit", false) diff --git a/internal/middleware/limithandler/concurrency_limiter.go b/internal/middleware/limithandler/concurrency_limiter.go index 9dbf87b75..f6587554a 100644 --- a/internal/middleware/limithandler/concurrency_limiter.go +++ b/internal/middleware/limithandler/concurrency_limiter.go @@ -7,6 +7,8 @@ import ( "sync" "time" + "github.com/prometheus/client_golang/prometheus" + "gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/config" "gitlab.com/gitlab-org/gitaly/v14/internal/helper" "gitlab.com/gitlab-org/gitaly/v14/internal/metadata/featureflag" ) @@ -18,9 +20,6 @@ var ErrMaxQueueTime = errors.New("maximum time in concurrency queue reached") // ErrMaxQueueSize indicates the concurrency queue has reached its maximum size var ErrMaxQueueSize = errors.New("maximum queue size reached") -// LimitedFunc represents a function that will be limited -type LimitedFunc func() (resp interface{}, err error) - // QueueTickerCreator is a function that provides a ticker type QueueTickerCreator func() helper.Ticker @@ -176,8 +175,8 @@ func (c *ConcurrencyLimiter) Limit(ctx context.Context, lockKey string, f Limite return f() } -// NewLimiter creates a new rate limiter -func NewLimiter(perKeyLimit, globalLimit int, maxWaitTickerGetter QueueTickerCreator, monitor ConcurrencyMonitor) *ConcurrencyLimiter { +// NewConcurrencyLimiter creates a new concurrency rate limiter +func NewConcurrencyLimiter(perKeyLimit, globalLimit int, maxWaitTickerGetter QueueTickerCreator, monitor ConcurrencyMonitor) *ConcurrencyLimiter { if monitor == nil { monitor = &nullConcurrencyMonitor{} } @@ -190,3 +189,80 @@ func NewLimiter(perKeyLimit, globalLimit int, maxWaitTickerGetter QueueTickerCre maxWaitTickerGetter: maxWaitTickerGetter, } } + +// WithConcurrencyLimiters sets up middleware to limit the concurrency of +// requests based on RPC and repository +func WithConcurrencyLimiters(cfg config.Cfg, middleware *LimiterMiddleware) { + acquiringSecondsMetric := prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "gitaly", + Subsystem: "concurrency_limiting", + Name: "acquiring_seconds", + Help: "Histogram of time calls are rate limited (in seconds)", + Buckets: cfg.Prometheus.GRPCLatencyBuckets, + }, + []string{"system", "grpc_service", "grpc_method"}, + ) + inProgressMetric := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: "gitaly", + Subsystem: "concurrency_limiting", + Name: "in_progress", + Help: "Gauge of number of concurrent in-progress calls", + }, + []string{"system", "grpc_service", "grpc_method"}, + ) + queuedMetric := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: "gitaly", + Subsystem: "concurrency_limiting", + Name: "queued", + Help: "Gauge of number of queued calls", + }, + []string{"system", "grpc_service", "grpc_method"}, + ) + + middleware.collect = func(metrics chan<- prometheus.Metric) { + acquiringSecondsMetric.Collect(metrics) + inProgressMetric.Collect(metrics) + queuedMetric.Collect(metrics) + } + + result := make(map[string]Limiter) + + newTickerFunc := func() helper.Ticker { + return helper.NewManualTicker() + } + + for _, limit := range cfg.Concurrency { + if limit.MaxQueueWait > 0 { + limit := limit + newTickerFunc = func() helper.Ticker { + return helper.NewTimerTicker(limit.MaxQueueWait.Duration()) + } + } + + result[limit.RPC] = NewConcurrencyLimiter( + limit.MaxPerRepo, + limit.MaxQueueSize, + newTickerFunc, + newPromMonitor("gitaly", limit.RPC, queuedMetric, inProgressMetric, + acquiringSecondsMetric, middleware.requestsDroppedMetric), + ) + } + + // Set default for ReplicateRepository. + replicateRepositoryFullMethod := "/gitaly.RepositoryService/ReplicateRepository" + if _, ok := result[replicateRepositoryFullMethod]; !ok { + result[replicateRepositoryFullMethod] = NewConcurrencyLimiter( + 1, + 0, + func() helper.Ticker { + return helper.NewManualTicker() + }, + newPromMonitor("gitaly", replicateRepositoryFullMethod, queuedMetric, + inProgressMetric, acquiringSecondsMetric, middleware.requestsDroppedMetric)) + } + + middleware.methodLimiters = result +} diff --git a/internal/middleware/limithandler/concurrency_limiter_test.go b/internal/middleware/limithandler/concurrency_limiter_test.go index bbeda4d76..4a13a3cd4 100644 --- a/internal/middleware/limithandler/concurrency_limiter_test.go +++ b/internal/middleware/limithandler/concurrency_limiter_test.go @@ -150,7 +150,7 @@ func TestLimiter(t *testing.T) { gauge := &counter{} - limiter := NewLimiter( + limiter := NewConcurrencyLimiter( tt.maxConcurrency, 0, nil, @@ -266,7 +266,7 @@ func TestConcurrencyLimiter_queueLimit(t *testing.T) { monitorCh := make(chan struct{}) monitor := &blockingQueueCounter{queuedCh: monitorCh} ch := make(chan struct{}) - limiter := NewLimiter(1, queueLimit, nil, monitor) + limiter := NewConcurrencyLimiter(1, queueLimit, nil, monitor) // occupied with one live request that takes a long time to complete go func() { @@ -355,7 +355,7 @@ func TestLimitConcurrency_queueWaitTime(t *testing.T) { dequeuedCh := make(chan struct{}) monitor := &blockingDequeueCounter{dequeuedCh: dequeuedCh} - limiter := NewLimiter( + limiter := NewConcurrencyLimiter( 1, 0, func() helper.Ticker { @@ -409,7 +409,7 @@ func TestLimitConcurrency_queueWaitTime(t *testing.T) { dequeuedCh := make(chan struct{}) monitor := &blockingDequeueCounter{dequeuedCh: dequeuedCh} - limiter := NewLimiter( + limiter := NewConcurrencyLimiter( 1, 0, func() helper.Ticker { diff --git a/internal/middleware/limithandler/middleware.go b/internal/middleware/limithandler/middleware.go index ac33ff4b1..0d4ff6bbc 100644 --- a/internal/middleware/limithandler/middleware.go +++ b/internal/middleware/limithandler/middleware.go @@ -7,7 +7,6 @@ import ( grpcmwtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" "github.com/prometheus/client_golang/prometheus" "gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/config" - "gitlab.com/gitlab-org/gitaly/v14/internal/helper" "google.golang.org/grpc" ) @@ -30,50 +29,27 @@ func LimitConcurrencyByRepo(ctx context.Context) string { return "" } +// Limiter limits incoming requests +type Limiter interface { + Limit(ctx context.Context, lockKey string, f LimitedFunc) (interface{}, error) +} + +// LimitedFunc represents a function that will be limited +type LimitedFunc func() (resp interface{}, err error) + // LimiterMiddleware contains rate limiter state type LimiterMiddleware struct { - methodLimiters map[string]*ConcurrencyLimiter - getLockKey GetLockKey - - acquiringSecondsMetric *prometheus.HistogramVec - inProgressMetric *prometheus.GaugeVec - queuedMetric *prometheus.GaugeVec - requestsDroppedMetric *prometheus.CounterVec + methodLimiters map[string]Limiter + getLockKey GetLockKey + requestsDroppedMetric *prometheus.CounterVec + collect func(metrics chan<- prometheus.Metric) } -// New creates a new rate limiter -func New(cfg config.Cfg, getLockKey GetLockKey) *LimiterMiddleware { +// New creates a new middleware that limits requests. SetupFunc sets up the +// middlware with a specific kind of limiter. +func New(cfg config.Cfg, getLockKey GetLockKey, setupMiddleware SetupFunc) *LimiterMiddleware { middleware := &LimiterMiddleware{ getLockKey: getLockKey, - - acquiringSecondsMetric: prometheus.NewHistogramVec( - prometheus.HistogramOpts{ - Namespace: "gitaly", - Subsystem: "rate_limiting", - Name: "acquiring_seconds", - Help: "Histogram of time calls are rate limited (in seconds)", - Buckets: cfg.Prometheus.GRPCLatencyBuckets, - }, - []string{"system", "grpc_service", "grpc_method"}, - ), - inProgressMetric: prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Namespace: "gitaly", - Subsystem: "rate_limiting", - Name: "in_progress", - Help: "Gauge of number of concurrent in-progress calls", - }, - []string{"system", "grpc_service", "grpc_method"}, - ), - queuedMetric: prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Namespace: "gitaly", - Subsystem: "rate_limiting", - Name: "queued", - Help: "Gauge of number of queued calls", - }, - []string{"system", "grpc_service", "grpc_method"}, - ), requestsDroppedMetric: prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "gitaly_requests_dropped_total", @@ -87,7 +63,9 @@ func New(cfg config.Cfg, getLockKey GetLockKey) *LimiterMiddleware { }, ), } - middleware.methodLimiters = createLimiterConfig(middleware, cfg) + + setupMiddleware(cfg, middleware) + return middleware } @@ -98,10 +76,10 @@ func (c *LimiterMiddleware) Describe(descs chan<- *prometheus.Desc) { // Collect is used to collect Prometheus metrics. func (c *LimiterMiddleware) Collect(metrics chan<- prometheus.Metric) { - c.acquiringSecondsMetric.Collect(metrics) - c.inProgressMetric.Collect(metrics) - c.queuedMetric.Collect(metrics) c.requestsDroppedMetric.Collect(metrics) + if c.collect != nil { + c.collect(metrics) + } } // UnaryInterceptor returns a Unary Interceptor @@ -132,43 +110,8 @@ func (c *LimiterMiddleware) StreamInterceptor() grpc.StreamServerInterceptor { } } -func createLimiterConfig(middleware *LimiterMiddleware, cfg config.Cfg) map[string]*ConcurrencyLimiter { - result := make(map[string]*ConcurrencyLimiter) - - newTickerFunc := func() helper.Ticker { - return helper.NewManualTicker() - } - - for _, limit := range cfg.Concurrency { - if limit.MaxQueueWait > 0 { - limit := limit - newTickerFunc = func() helper.Ticker { - return helper.NewTimerTicker(limit.MaxQueueWait.Duration()) - } - } - - result[limit.RPC] = NewLimiter( - limit.MaxPerRepo, - limit.MaxQueueSize, - newTickerFunc, - newPromMonitor(middleware, "gitaly", limit.RPC), - ) - } - - // Set default for ReplicateRepository. - replicateRepositoryFullMethod := "/gitaly.RepositoryService/ReplicateRepository" - if _, ok := result[replicateRepositoryFullMethod]; !ok { - result[replicateRepositoryFullMethod] = NewLimiter( - 1, - 0, - func() helper.Ticker { - return helper.NewManualTicker() - }, - newPromMonitor(middleware, "gitaly", replicateRepositoryFullMethod)) - } - - return result -} +// SetupFunc set up a middleware to limiting requests +type SetupFunc func(cfg config.Cfg, middleware *LimiterMiddleware) type wrappedStream struct { grpc.ServerStream diff --git a/internal/middleware/limithandler/middleware_test.go b/internal/middleware/limithandler/middleware_test.go index bd94228bb..53f0f53f2 100644 --- a/internal/middleware/limithandler/middleware_test.go +++ b/internal/middleware/limithandler/middleware_test.go @@ -17,6 +17,8 @@ import ( pb "gitlab.com/gitlab-org/gitaly/v14/internal/middleware/limithandler/testdata" "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func TestMain(m *testing.M) { @@ -38,7 +40,7 @@ func TestUnaryLimitHandler(t *testing.T) { }, } - lh := limithandler.New(cfg, fixedLockKey) + lh := limithandler.New(cfg, fixedLockKey, limithandler.WithConcurrencyLimiters) interceptor := lh.UnaryInterceptor() srv, serverSocketPath := runServer(t, s, grpc.UnaryInterceptor(interceptor)) defer srv.Stop() @@ -191,7 +193,7 @@ func TestStreamLimitHandler(t *testing.T) { }, } - lh := limithandler.New(cfg, fixedLockKey) + lh := limithandler.New(cfg, fixedLockKey, limithandler.WithConcurrencyLimiters) interceptor := lh.StreamInterceptor() srv, serverSocketPath := runServer(t, s, grpc.StreamInterceptor(interceptor)) defer srv.Stop() @@ -233,7 +235,7 @@ func (q *queueTestServer) Unary(ctx context.Context, in *pb.UnaryRequest) (*pb.U return &pb.UnaryResponse{Ok: true}, nil } -func TestLimitHandlerMetrics(t *testing.T) { +func TestConcurrencyLimitHandlerMetrics(t *testing.T) { s := &queueTestServer{reqArrivedCh: make(chan struct{})} s.blockCh = make(chan struct{}) @@ -244,7 +246,7 @@ func TestLimitHandlerMetrics(t *testing.T) { }, } - lh := limithandler.New(cfg, fixedLockKey) + lh := limithandler.New(cfg, fixedLockKey, limithandler.WithConcurrencyLimiters) interceptor := lh.UnaryInterceptor() srv, serverSocketPath := runServer(t, s, grpc.UnaryInterceptor(interceptor)) defer srv.Stop() @@ -290,22 +292,22 @@ func TestLimitHandlerMetrics(t *testing.T) { } } - expectedMetrics := `# HELP gitaly_rate_limiting_in_progress Gauge of number of concurrent in-progress calls -# TYPE gitaly_rate_limiting_in_progress gauge -gitaly_rate_limiting_in_progress{grpc_method="ReplicateRepository",grpc_service="gitaly.RepositoryService",system="gitaly"} 0 -gitaly_rate_limiting_in_progress{grpc_method="Unary",grpc_service="test.limithandler.Test",system="gitaly"} 1 -# HELP gitaly_rate_limiting_queued Gauge of number of queued calls -# TYPE gitaly_rate_limiting_queued gauge -gitaly_rate_limiting_queued{grpc_method="ReplicateRepository",grpc_service="gitaly.RepositoryService",system="gitaly"} 0 -gitaly_rate_limiting_queued{grpc_method="Unary",grpc_service="test.limithandler.Test",system="gitaly"} 1 + expectedMetrics := `# HELP gitaly_concurrency_limiting_in_progress Gauge of number of concurrent in-progress calls +# TYPE gitaly_concurrency_limiting_in_progress gauge +gitaly_concurrency_limiting_in_progress{grpc_method="ReplicateRepository",grpc_service="gitaly.RepositoryService",system="gitaly"} 0 +gitaly_concurrency_limiting_in_progress{grpc_method="Unary",grpc_service="test.limithandler.Test",system="gitaly"} 1 +# HELP gitaly_concurrency_limiting_queued Gauge of number of queued calls +# TYPE gitaly_concurrency_limiting_queued gauge +gitaly_concurrency_limiting_queued{grpc_method="ReplicateRepository",grpc_service="gitaly.RepositoryService",system="gitaly"} 0 +gitaly_concurrency_limiting_queued{grpc_method="Unary",grpc_service="test.limithandler.Test",system="gitaly"} 1 # HELP gitaly_requests_dropped_total Number of requests dropped from the queue # TYPE gitaly_requests_dropped_total counter gitaly_requests_dropped_total{grpc_method="Unary",grpc_service="test.limithandler.Test",reason="max_size",system="gitaly"} 9 ` assert.NoError(t, promtest.CollectAndCompare(lh, bytes.NewBufferString(expectedMetrics), - "gitaly_rate_limiting_queued", + "gitaly_concurrency_limiting_queued", "gitaly_requests_dropped_total", - "gitaly_rate_limiting_in_progress")) + "gitaly_concurrency_limiting_in_progress")) close(s.blockCh) <-s.reqArrivedCh @@ -315,6 +317,86 @@ gitaly_requests_dropped_total{grpc_method="Unary",grpc_service="test.limithandle <-respCh } +func TestRateLimitHandler(t *testing.T) { + t.Parallel() + testhelper.NewFeatureSets(featureflag.RateLimit).Run(t, testRateLimitHandler) +} + +func testRateLimitHandler(t *testing.T, ctx context.Context) { + methodName := "/test.limithandler.Test/Unary" + cfg := config.Cfg{ + RateLimiting: []config.RateLimiting{ + {RPC: methodName, Interval: 1 * time.Hour, Burst: 1}, + }, + } + + t.Run("rate has hit max", func(t *testing.T) { + s := &server{blockCh: make(chan struct{})} + + lh := limithandler.New(cfg, fixedLockKey, limithandler.WithRateLimiters(ctx)) + interceptor := lh.UnaryInterceptor() + srv, serverSocketPath := runServer(t, s, grpc.UnaryInterceptor(interceptor)) + defer srv.Stop() + + client, conn := newClient(t, serverSocketPath) + defer testhelper.MustClose(t, conn) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + _, err := client.Unary(ctx, &pb.UnaryRequest{}) + require.NoError(t, err) + }() + // wait until the first request is being processed so we know the rate + // limiter already knows about it. + s.blockCh <- struct{}{} + close(s.blockCh) + + for i := 0; i < 10; i++ { + _, err := client.Unary(ctx, &pb.UnaryRequest{}) + + if featureflag.RateLimit.IsEnabled(ctx) { + testhelper.RequireGrpcError(t, status.Error(codes.Unavailable, "too many requests"), err) + } else { + require.NoError(t, err) + } + } + + expectedMetrics := `# HELP gitaly_requests_dropped_total Number of requests dropped from the queue +# TYPE gitaly_requests_dropped_total counter +gitaly_requests_dropped_total{grpc_method="Unary",grpc_service="test.limithandler.Test",reason="rate",system="gitaly"} 10 +` + assert.NoError(t, promtest.CollectAndCompare(lh, bytes.NewBufferString(expectedMetrics), + "gitaly_requests_dropped_total")) + + wg.Wait() + }) + + t.Run("rate has not hit max", func(t *testing.T) { + s := &server{blockCh: make(chan struct{})} + + lh := limithandler.New(cfg, fixedLockKey, limithandler.WithRateLimiters(ctx)) + interceptor := lh.UnaryInterceptor() + srv, serverSocketPath := runServer(t, s, grpc.UnaryInterceptor(interceptor)) + defer srv.Stop() + + client, conn := newClient(t, serverSocketPath) + defer testhelper.MustClose(t, conn) + + close(s.blockCh) + _, err := client.Unary(ctx, &pb.UnaryRequest{}) + require.NoError(t, err) + + expectedMetrics := `# HELP gitaly_requests_dropped_total Number of requests dropped from the queue +# TYPE gitaly_requests_dropped_total counter +gitaly_requests_dropped_total{grpc_method="Unary",grpc_service="test.limithandler.Test",reason="rate",system="gitaly"} 0 +` + assert.NoError(t, promtest.CollectAndCompare(lh, bytes.NewBufferString(expectedMetrics), + "gitaly_requests_dropped_total")) + }) +} + func runServer(t *testing.T, s pb.TestServer, opt ...grpc.ServerOption) (*grpc.Server, string) { serverSocketPath := testhelper.GetTemporaryGitalySocketFileName(t) grpcServer := grpc.NewServer(opt...) diff --git a/internal/middleware/limithandler/monitor.go b/internal/middleware/limithandler/monitor.go index f77014b9d..98dabf2a0 100644 --- a/internal/middleware/limithandler/monitor.go +++ b/internal/middleware/limithandler/monitor.go @@ -37,14 +37,19 @@ type promMonitor struct { // newPromMonitor creates a new ConcurrencyMonitor that tracks limiter // activity in Prometheus. -func newPromMonitor(lh *LimiterMiddleware, system string, fullMethod string) ConcurrencyMonitor { +func newPromMonitor( + system, fullMethod string, + queuedMetric, inProgressMetric *prometheus.GaugeVec, + acquiringSecondsMetric prometheus.ObserverVec, + requestsDroppedMetric *prometheus.CounterVec, +) ConcurrencyMonitor { serviceName, methodName := splitMethodName(fullMethod) return &promMonitor{ - lh.queuedMetric.WithLabelValues(system, serviceName, methodName), - lh.inProgressMetric.WithLabelValues(system, serviceName, methodName), - lh.acquiringSecondsMetric.WithLabelValues(system, serviceName, methodName), - lh.requestsDroppedMetric.MustCurryWith(prometheus.Labels{ + queuedMetric.WithLabelValues(system, serviceName, methodName), + inProgressMetric.WithLabelValues(system, serviceName, methodName), + acquiringSecondsMetric.WithLabelValues(system, serviceName, methodName), + requestsDroppedMetric.MustCurryWith(prometheus.Labels{ "system": system, "grpc_service": serviceName, "grpc_method": methodName, diff --git a/internal/middleware/limithandler/rate_limiter.go b/internal/middleware/limithandler/rate_limiter.go new file mode 100644 index 000000000..f47bb5409 --- /dev/null +++ b/internal/middleware/limithandler/rate_limiter.go @@ -0,0 +1,117 @@ +package limithandler + +import ( + "context" + "errors" + "sync" + "time" + + "github.com/prometheus/client_golang/prometheus" + "gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/config" + "gitlab.com/gitlab-org/gitaly/v14/internal/helper" + "gitlab.com/gitlab-org/gitaly/v14/internal/metadata/featureflag" + "golang.org/x/time/rate" +) + +// RateLimiter is an implementation of Limiter that puts a hard limit on the +// number of requests per second +type RateLimiter struct { + limitersByKey, lastAccessedByKey sync.Map + refillInterval time.Duration + burst int + requestsDroppedMetric prometheus.Counter + ticker helper.Ticker +} + +// Limit rejects an incoming reequest if the maximum number of requests per +// second has been reached +func (r *RateLimiter) Limit(ctx context.Context, lockKey string, f LimitedFunc) (interface{}, error) { + limiter, _ := r.limitersByKey.LoadOrStore( + lockKey, + rate.NewLimiter(rate.Every(r.refillInterval), r.burst), + ) + r.lastAccessedByKey.Store(lockKey, time.Now()) + + if !limiter.(*rate.Limiter).Allow() { + // For now, we are only emitting this metric to get an idea of the shape + // of traffic. + r.requestsDroppedMetric.Inc() + if featureflag.RateLimit.IsEnabled(ctx) { + return nil, helper.ErrUnavailable(errors.New("too many requests")) + } + } + + return f() +} + +// PruneUnusedLimiters enters an infinite loop to periodically check if any +// limiters can be cleaned up. This is meant to be called in a separate +// goroutine. +func (r *RateLimiter) PruneUnusedLimiters(ctx context.Context) { + defer r.ticker.Stop() + for { + r.ticker.Reset() + select { + case <-r.ticker.C(): + r.pruneUnusedLimiters() + case <-ctx.Done(): + return + } + } +} + +func (r *RateLimiter) pruneUnusedLimiters() { + r.lastAccessedByKey.Range(func(key, value interface{}) bool { + if value.(time.Time).Before(time.Now().Add(-10 * r.refillInterval)) { + r.limitersByKey.Delete(key) + } + + return true + }) +} + +// NewRateLimiter creates a new instance of RateLimiter +func NewRateLimiter( + refillInterval time.Duration, + burst int, + ticker helper.Ticker, + requestsDroppedMetric prometheus.Counter, +) *RateLimiter { + r := &RateLimiter{ + refillInterval: refillInterval, + burst: burst, + requestsDroppedMetric: requestsDroppedMetric, + ticker: ticker, + } + + return r +} + +// WithRateLimiters sets up a middleware with limiters that limit requests +// based on its rate per second per RPC +func WithRateLimiters(ctx context.Context) SetupFunc { + return func(cfg config.Cfg, middleware *LimiterMiddleware) { + result := make(map[string]Limiter) + + for _, limitCfg := range cfg.RateLimiting { + if limitCfg.Burst > 0 && limitCfg.Interval > 0 { + serviceName, methodName := splitMethodName(limitCfg.RPC) + rateLimiter := NewRateLimiter( + limitCfg.Interval, + limitCfg.Burst, + helper.NewTimerTicker(5*time.Minute), + middleware.requestsDroppedMetric.With(prometheus.Labels{ + "system": "gitaly", + "grpc_service": serviceName, + "grpc_method": methodName, + "reason": "rate", + }), + ) + result[limitCfg.RPC] = rateLimiter + go rateLimiter.PruneUnusedLimiters(ctx) + } + } + + middleware.methodLimiters = result + } +} diff --git a/internal/middleware/limithandler/rate_limiter_test.go b/internal/middleware/limithandler/rate_limiter_test.go new file mode 100644 index 000000000..8ee9b64e6 --- /dev/null +++ b/internal/middleware/limithandler/rate_limiter_test.go @@ -0,0 +1,94 @@ +package limithandler + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "gitlab.com/gitlab-org/gitaly/v14/internal/helper" + "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper" +) + +func TestRateLimiter_pruneUnusedLimiters(t *testing.T) { + t.Parallel() + + testCases := []struct { + desc string + setup func(r *RateLimiter) + expectedLimiters, expectedRemovedLimiters []string + }{ + { + desc: "none are prunable", + setup: func(r *RateLimiter) { + r.limitersByKey.Store("a", struct{}{}) + r.limitersByKey.Store("b", struct{}{}) + r.limitersByKey.Store("c", struct{}{}) + r.lastAccessedByKey.Store("a", time.Now()) + r.lastAccessedByKey.Store("b", time.Now()) + r.lastAccessedByKey.Store("c", time.Now()) + }, + expectedLimiters: []string{"a", "b", "c"}, + expectedRemovedLimiters: []string{}, + }, + { + desc: "all are prunable", + setup: func(r *RateLimiter) { + r.limitersByKey.Store("a", struct{}{}) + r.limitersByKey.Store("b", struct{}{}) + r.limitersByKey.Store("c", struct{}{}) + r.lastAccessedByKey.Store("a", time.Now().Add(-1*time.Minute)) + r.lastAccessedByKey.Store("b", time.Now().Add(-1*time.Minute)) + r.lastAccessedByKey.Store("c", time.Now().Add(-1*time.Minute)) + }, + expectedLimiters: []string{}, + expectedRemovedLimiters: []string{"a", "b", "c"}, + }, + { + desc: "one is prunable", + setup: func(r *RateLimiter) { + r.limitersByKey.Store("a", struct{}{}) + r.limitersByKey.Store("b", struct{}{}) + r.limitersByKey.Store("c", struct{}{}) + r.lastAccessedByKey.Store("a", time.Now()) + r.lastAccessedByKey.Store("b", time.Now()) + r.lastAccessedByKey.Store("c", time.Now().Add(-1*time.Minute)) + }, + expectedLimiters: []string{"a", "b"}, + expectedRemovedLimiters: []string{"c"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + ctx := testhelper.Context(t) + ticker := helper.NewManualTicker() + ch := make(chan struct{}) + ticker.ResetFunc = func() { + ch <- struct{}{} + } + + rateLimiter := &RateLimiter{ + refillInterval: time.Second, + ticker: ticker, + } + + tc.setup(rateLimiter) + + go rateLimiter.PruneUnusedLimiters(ctx) + <-ch + + ticker.Tick() + <-ch + + for _, expectedLimiter := range tc.expectedLimiters { + _, ok := rateLimiter.limitersByKey.Load(expectedLimiter) + assert.True(t, ok) + } + + for _, expectedRemovedLimiter := range tc.expectedRemovedLimiters { + _, ok := rateLimiter.limitersByKey.Load(expectedRemovedLimiter) + assert.False(t, ok) + } + }) + } +} diff --git a/internal/testhelper/testserver/gitaly.go b/internal/testhelper/testserver/gitaly.go index 9f60a4afa..753a579e3 100644 --- a/internal/testhelper/testserver/gitaly.go +++ b/internal/testhelper/testserver/gitaly.go @@ -158,7 +158,7 @@ func runGitaly(t testing.TB, cfg config.Cfg, rubyServer *rubyserver.Server, regi gsd.logger.WithField("test", t.Name()), deps.GetBackchannelRegistry(), deps.GetDiskCache(), - deps.GetLimitHandler(), + []*limithandler.LimiterMiddleware{deps.GetLimitHandler()}, ) if cfg.InternalSocketDir != "" { @@ -302,7 +302,7 @@ func (gsd *gitalyServerDeps) createDependencies(t testing.TB, cfg config.Cfg, ru } if gsd.limitHandler == nil { - gsd.limitHandler = limithandler.New(cfg, limithandler.LimitConcurrencyByRepo) + gsd.limitHandler = limithandler.New(cfg, limithandler.LimitConcurrencyByRepo, limithandler.WithConcurrencyLimiters) } if gsd.git2goExecutor == nil { |