diff options
author | John Cai <jcai@gitlab.com> | 2022-03-30 17:29:33 +0300 |
---|---|---|
committer | John Cai <jcai@gitlab.com> | 2022-04-06 22:27:35 +0300 |
commit | 7ce0a1bb170eb89da15a12acf9446f2cb43262a1 (patch) | |
tree | 84e374118aeb9c558f98c8e4cc192e5ca184b5fb | |
parent | 00266a9213efcda1138e1977c476d6ceff4ba9b3 (diff) |
limithandler: Prune unused limiters in RateLimiter
RateLimiter contains a limiter per rpc/repo pair. We don't want this to
grow monotinically since it will incur a heavy memory burden on the
machine. Instead, introduce a background process that looks through the
limiters and removes the ones that have not been used in the past 10
refill intervals.
-rw-r--r-- | internal/middleware/limithandler/middleware_test.go | 4 | ||||
-rw-r--r-- | internal/middleware/limithandler/rate_limiter.go | 78 | ||||
-rw-r--r-- | internal/middleware/limithandler/rate_limiter_test.go | 94 |
3 files changed, 153 insertions, 23 deletions
diff --git a/internal/middleware/limithandler/middleware_test.go b/internal/middleware/limithandler/middleware_test.go index 3a6b3e920..53f0f53f2 100644 --- a/internal/middleware/limithandler/middleware_test.go +++ b/internal/middleware/limithandler/middleware_test.go @@ -333,7 +333,7 @@ func testRateLimitHandler(t *testing.T, ctx context.Context) { t.Run("rate has hit max", func(t *testing.T) { s := &server{blockCh: make(chan struct{})} - lh := limithandler.New(cfg, fixedLockKey, limithandler.WithRateLimiters) + lh := limithandler.New(cfg, fixedLockKey, limithandler.WithRateLimiters(ctx)) interceptor := lh.UnaryInterceptor() srv, serverSocketPath := runServer(t, s, grpc.UnaryInterceptor(interceptor)) defer srv.Stop() @@ -376,7 +376,7 @@ gitaly_requests_dropped_total{grpc_method="Unary",grpc_service="test.limithandle t.Run("rate has not hit max", func(t *testing.T) { s := &server{blockCh: make(chan struct{})} - lh := limithandler.New(cfg, fixedLockKey, limithandler.WithRateLimiters) + lh := limithandler.New(cfg, fixedLockKey, limithandler.WithRateLimiters(ctx)) interceptor := lh.UnaryInterceptor() srv, serverSocketPath := runServer(t, s, grpc.UnaryInterceptor(interceptor)) defer srv.Stop() diff --git a/internal/middleware/limithandler/rate_limiter.go b/internal/middleware/limithandler/rate_limiter.go index b8bc605db..f47bb5409 100644 --- a/internal/middleware/limithandler/rate_limiter.go +++ b/internal/middleware/limithandler/rate_limiter.go @@ -16,10 +16,11 @@ import ( // RateLimiter is an implementation of Limiter that puts a hard limit on the // number of requests per second type RateLimiter struct { - limitersByKey sync.Map - refillInterval time.Duration - burst int - requestsDroppedMetric prometheus.Counter + limitersByKey, lastAccessedByKey sync.Map + refillInterval time.Duration + burst int + requestsDroppedMetric prometheus.Counter + ticker helper.Ticker } // Limit rejects an incoming reequest if the maximum number of requests per @@ -29,6 +30,8 @@ func (r *RateLimiter) Limit(ctx context.Context, lockKey string, f LimitedFunc) lockKey, rate.NewLimiter(rate.Every(r.refillInterval), r.burst), ) + r.lastAccessedByKey.Store(lockKey, time.Now()) + if !limiter.(*rate.Limiter).Allow() { // For now, we are only emitting this metric to get an idea of the shape // of traffic. @@ -41,16 +44,44 @@ func (r *RateLimiter) Limit(ctx context.Context, lockKey string, f LimitedFunc) return f() } +// PruneUnusedLimiters enters an infinite loop to periodically check if any +// limiters can be cleaned up. This is meant to be called in a separate +// goroutine. +func (r *RateLimiter) PruneUnusedLimiters(ctx context.Context) { + defer r.ticker.Stop() + for { + r.ticker.Reset() + select { + case <-r.ticker.C(): + r.pruneUnusedLimiters() + case <-ctx.Done(): + return + } + } +} + +func (r *RateLimiter) pruneUnusedLimiters() { + r.lastAccessedByKey.Range(func(key, value interface{}) bool { + if value.(time.Time).Before(time.Now().Add(-10 * r.refillInterval)) { + r.limitersByKey.Delete(key) + } + + return true + }) +} + // NewRateLimiter creates a new instance of RateLimiter func NewRateLimiter( refillInterval time.Duration, burst int, + ticker helper.Ticker, requestsDroppedMetric prometheus.Counter, ) *RateLimiter { r := &RateLimiter{ refillInterval: refillInterval, burst: burst, requestsDroppedMetric: requestsDroppedMetric, + ticker: ticker, } return r @@ -58,24 +89,29 @@ func NewRateLimiter( // WithRateLimiters sets up a middleware with limiters that limit requests // based on its rate per second per RPC -func WithRateLimiters(cfg config.Cfg, middleware *LimiterMiddleware) { - result := make(map[string]Limiter) +func WithRateLimiters(ctx context.Context) SetupFunc { + return func(cfg config.Cfg, middleware *LimiterMiddleware) { + result := make(map[string]Limiter) - for _, limitCfg := range cfg.RateLimiting { - if limitCfg.Burst > 0 && limitCfg.Interval > 0 { - serviceName, methodName := splitMethodName(limitCfg.RPC) - result[limitCfg.RPC] = NewRateLimiter( - limitCfg.Interval, - limitCfg.Burst, - middleware.requestsDroppedMetric.With(prometheus.Labels{ - "system": "gitaly", - "grpc_service": serviceName, - "grpc_method": methodName, - "reason": "rate", - }), - ) + for _, limitCfg := range cfg.RateLimiting { + if limitCfg.Burst > 0 && limitCfg.Interval > 0 { + serviceName, methodName := splitMethodName(limitCfg.RPC) + rateLimiter := NewRateLimiter( + limitCfg.Interval, + limitCfg.Burst, + helper.NewTimerTicker(5*time.Minute), + middleware.requestsDroppedMetric.With(prometheus.Labels{ + "system": "gitaly", + "grpc_service": serviceName, + "grpc_method": methodName, + "reason": "rate", + }), + ) + result[limitCfg.RPC] = rateLimiter + go rateLimiter.PruneUnusedLimiters(ctx) + } } - } - middleware.methodLimiters = result + middleware.methodLimiters = result + } } diff --git a/internal/middleware/limithandler/rate_limiter_test.go b/internal/middleware/limithandler/rate_limiter_test.go new file mode 100644 index 000000000..8ee9b64e6 --- /dev/null +++ b/internal/middleware/limithandler/rate_limiter_test.go @@ -0,0 +1,94 @@ +package limithandler + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "gitlab.com/gitlab-org/gitaly/v14/internal/helper" + "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper" +) + +func TestRateLimiter_pruneUnusedLimiters(t *testing.T) { + t.Parallel() + + testCases := []struct { + desc string + setup func(r *RateLimiter) + expectedLimiters, expectedRemovedLimiters []string + }{ + { + desc: "none are prunable", + setup: func(r *RateLimiter) { + r.limitersByKey.Store("a", struct{}{}) + r.limitersByKey.Store("b", struct{}{}) + r.limitersByKey.Store("c", struct{}{}) + r.lastAccessedByKey.Store("a", time.Now()) + r.lastAccessedByKey.Store("b", time.Now()) + r.lastAccessedByKey.Store("c", time.Now()) + }, + expectedLimiters: []string{"a", "b", "c"}, + expectedRemovedLimiters: []string{}, + }, + { + desc: "all are prunable", + setup: func(r *RateLimiter) { + r.limitersByKey.Store("a", struct{}{}) + r.limitersByKey.Store("b", struct{}{}) + r.limitersByKey.Store("c", struct{}{}) + r.lastAccessedByKey.Store("a", time.Now().Add(-1*time.Minute)) + r.lastAccessedByKey.Store("b", time.Now().Add(-1*time.Minute)) + r.lastAccessedByKey.Store("c", time.Now().Add(-1*time.Minute)) + }, + expectedLimiters: []string{}, + expectedRemovedLimiters: []string{"a", "b", "c"}, + }, + { + desc: "one is prunable", + setup: func(r *RateLimiter) { + r.limitersByKey.Store("a", struct{}{}) + r.limitersByKey.Store("b", struct{}{}) + r.limitersByKey.Store("c", struct{}{}) + r.lastAccessedByKey.Store("a", time.Now()) + r.lastAccessedByKey.Store("b", time.Now()) + r.lastAccessedByKey.Store("c", time.Now().Add(-1*time.Minute)) + }, + expectedLimiters: []string{"a", "b"}, + expectedRemovedLimiters: []string{"c"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + ctx := testhelper.Context(t) + ticker := helper.NewManualTicker() + ch := make(chan struct{}) + ticker.ResetFunc = func() { + ch <- struct{}{} + } + + rateLimiter := &RateLimiter{ + refillInterval: time.Second, + ticker: ticker, + } + + tc.setup(rateLimiter) + + go rateLimiter.PruneUnusedLimiters(ctx) + <-ch + + ticker.Tick() + <-ch + + for _, expectedLimiter := range tc.expectedLimiters { + _, ok := rateLimiter.limitersByKey.Load(expectedLimiter) + assert.True(t, ok) + } + + for _, expectedRemovedLimiter := range tc.expectedRemovedLimiters { + _, ok := rateLimiter.limitersByKey.Load(expectedRemovedLimiter) + assert.False(t, ok) + } + }) + } +} |