From 197d53c3f8c20942f76f111cc7bf6aac04c0dad4 Mon Sep 17 00:00:00 2001 From: Vladimir Shushlin Date: Wed, 8 Dec 2021 18:50:25 +0300 Subject: refactor: abstract ratelimiter package We want to add domain-based ratelimiter. The logic will be identical, but we'll use host instead of IP address. --- app.go | 13 ++--- internal/ratelimiter/middleware.go | 31 ++++++------ internal/ratelimiter/middleware_test.go | 24 +++++---- internal/ratelimiter/ratelimiter.go | 86 +++++++++++++++++--------------- internal/ratelimiter/ratelimiter_test.go | 25 +++++++--- test/acceptance/ratelimiter_test.go | 2 +- 6 files changed, 100 insertions(+), 81 deletions(-) diff --git a/app.go b/app.go index 508b11fa..6030cad8 100644 --- a/app.go +++ b/app.go @@ -260,15 +260,16 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) { if a.config.RateLimit.SourceIPLimitPerSecond > 0 { rl := ratelimiter.New( - ratelimiter.WithSourceIPCacheMaxSize(ratelimiter.DefaultSourceIPCacheSize), - ratelimiter.WithSourceIPCachedEntriesMetric(metrics.RateLimitSourceIPCachedEntries), - ratelimiter.WithSourceIPCachedRequestsMetric(metrics.RateLimitSourceIPCacheRequests), + "source_ip", + ratelimiter.WithCacheMaxSize(ratelimiter.DefaultSourceIPCacheSize), + ratelimiter.WithCachedEntriesMetric(metrics.RateLimitSourceIPCachedEntries), + ratelimiter.WithCachedRequestsMetric(metrics.RateLimitSourceIPCacheRequests), ratelimiter.WithBlockedCountMetric(metrics.RateLimitSourceIPBlockedCount), - ratelimiter.WithSourceIPLimitPerSecond(a.config.RateLimit.SourceIPLimitPerSecond), - ratelimiter.WithSourceIPBurstSize(a.config.RateLimit.SourceIPBurst), + ratelimiter.WithLimitPerSecond(a.config.RateLimit.SourceIPLimitPerSecond), + ratelimiter.WithBurstSize(a.config.RateLimit.SourceIPBurst), ) - handler = rl.SourceIPLimiter(handler) + handler = rl.Middleware(handler) } // Health Check diff --git a/internal/ratelimiter/middleware.go b/internal/ratelimiter/middleware.go index f26cb2e0..86beaf39 100644 --- a/internal/ratelimiter/middleware.go +++ b/internal/ratelimiter/middleware.go @@ -18,25 +18,24 @@ const ( headerXForwardedProto = "X-Forwarded-Proto" ) -// SourceIPLimiter returns middleware for rate-limiting clients based on their IP -func (rl *RateLimiter) SourceIPLimiter(handler http.Handler) http.Handler { +// Middleware returns middleware for rate-limiting clients +func (rl *RateLimiter) Middleware(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - host, sourceIP := request.GetHostWithoutPort(r), request.GetRemoteAddrWithoutPort(r) - if !rl.SourceIPAllowed(sourceIP) { - rl.logSourceIP(r, host, sourceIP) + if !rl.RequestAllowed(r) { + rl.logRateLimitedRequest(r) // Only drop requests once FF_ENABLE_RATE_LIMITER is enabled // https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629 if rateLimiterEnabled() { - if rl.sourceIPBlockedCount != nil { - rl.sourceIPBlockedCount.WithLabelValues("true").Inc() + if rl.blockedCount != nil { + rl.blockedCount.WithLabelValues("true").Inc() } httperrors.Serve429(w) return } - if rl.sourceIPBlockedCount != nil { - rl.sourceIPBlockedCount.WithLabelValues("false").Inc() + if rl.blockedCount != nil { + rl.blockedCount.WithLabelValues("false").Inc() } } @@ -44,24 +43,24 @@ func (rl *RateLimiter) SourceIPLimiter(handler http.Handler) http.Handler { }) } -func (rl *RateLimiter) logSourceIP(r *http.Request, host, sourceIP string) { +func (rl *RateLimiter) logRateLimitedRequest(r *http.Request) { log.WithFields(logrus.Fields{ - "handler": "source_ip_rate_limiter", + "rate_limiter_name": rl.name, "correlation_id": correlation.ExtractFromContext(r.Context()), "req_scheme": r.URL.Scheme, "req_host": r.Host, "req_path": r.URL.Path, - "pages_domain": host, + "pages_domain": request.GetHostWithoutPort(r), "remote_addr": r.RemoteAddr, - "source_ip": sourceIP, + "source_ip": request.GetRemoteAddrWithoutPort(r), "x_forwarded_proto": r.Header.Get(headerXForwardedProto), "x_forwarded_for": r.Header.Get(headerXForwardedFor), "gitlab_real_ip": r.Header.Get(headerGitLabRealIP), "rate_limiter_enabled": rateLimiterEnabled(), - "rate_limiter_limit_per_second": rl.sourceIPLimitPerSecond, - "rate_limiter_burst_size": rl.sourceIPBurstSize, + "rate_limiter_limit_per_second": rl.limitPerSecond, + "rate_limiter_burst_size": rl.burstSize, }). // TODO: change to Debug with https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629 - Info("source IP hit rate limit") + Info("request hit rate limit") } // TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629 diff --git a/internal/ratelimiter/middleware_test.go b/internal/ratelimiter/middleware_test.go index 2cf3b3e5..a66f7523 100644 --- a/internal/ratelimiter/middleware_test.go +++ b/internal/ratelimiter/middleware_test.go @@ -28,9 +28,10 @@ func TestSourceIPLimiterWithDifferentLimits(t *testing.T) { for tn, tc := range sharedTestCases { t.Run(tn, func(t *testing.T) { rl := New( + "rate_limiter", WithNow(mockNow), - WithSourceIPLimitPerSecond(tc.sourceIPLimit), - WithSourceIPBurstSize(tc.sourceIPBurstSize), + WithLimitPerSecond(tc.sourceIPLimit), + WithBurstSize(tc.sourceIPBurstSize), ) for i := 0; i < tc.reqNum; i++ { @@ -38,7 +39,7 @@ func TestSourceIPLimiterWithDifferentLimits(t *testing.T) { rr := httptest.NewRequest(http.MethodGet, "https://domain.gitlab.io", nil) rr.RemoteAddr = remoteAddr - handler := rl.SourceIPLimiter(next) + handler := rl.Middleware(next) handler.ServeHTTP(ww, rr) res := ww.Result() @@ -83,12 +84,13 @@ func TestSourceIPLimiterDenyRequestsAfterBurst(t *testing.T) { for tn, tc := range tcs { t.Run(tn, func(t *testing.T) { rl := New( - WithSourceIPCachedEntriesMetric(cachedEntries), - WithSourceIPCachedRequestsMetric(cacheReqs), + "rate_limiter", + WithCachedEntriesMetric(cachedEntries), + WithCachedRequestsMetric(cacheReqs), WithBlockedCountMetric(blocked), WithNow(mockNow), - WithSourceIPLimitPerSecond(1), - WithSourceIPBurstSize(1), + WithLimitPerSecond(1), + WithBurstSize(1), ) for i := 0; i < 5; i++ { @@ -103,7 +105,7 @@ func TestSourceIPLimiterDenyRequestsAfterBurst(t *testing.T) { rr.RemoteAddr = remoteAddr // middleware is evaluated in reverse order - handler := rl.SourceIPLimiter(next) + handler := rl.Middleware(next) handler.ServeHTTP(ww, rr) res := ww.Result() @@ -126,13 +128,13 @@ func TestSourceIPLimiterDenyRequestsAfterBurst(t *testing.T) { } blocked.Reset() - cachedCount := testutil.ToFloat64(cachedEntries.WithLabelValues("source_ip")) + cachedCount := testutil.ToFloat64(cachedEntries.WithLabelValues("rate_limiter")) require.Equal(t, float64(1), cachedCount, "cached count") cachedEntries.Reset() - cacheReqMiss := testutil.ToFloat64(cacheReqs.WithLabelValues("source_ip", "miss")) + cacheReqMiss := testutil.ToFloat64(cacheReqs.WithLabelValues("rate_limiter", "miss")) require.Equal(t, float64(1), cacheReqMiss, "miss count") - cacheReqHit := testutil.ToFloat64(cacheReqs.WithLabelValues("source_ip", "hit")) + cacheReqHit := testutil.ToFloat64(cacheReqs.WithLabelValues("rate_limiter", "hit")) require.Equal(t, float64(4), cacheReqHit, "hit count") cacheReqs.Reset() }) diff --git a/internal/ratelimiter/ratelimiter.go b/internal/ratelimiter/ratelimiter.go index 37aca020..ab02c966 100644 --- a/internal/ratelimiter/ratelimiter.go +++ b/internal/ratelimiter/ratelimiter.go @@ -1,12 +1,14 @@ package ratelimiter import ( + "net/http" "time" "github.com/prometheus/client_golang/prometheus" "golang.org/x/time/rate" "gitlab.com/gitlab-org/gitlab-pages/internal/lru" + "gitlab.com/gitlab-org/gitlab-pages/internal/request" ) const ( @@ -26,35 +28,40 @@ const ( // Option function to configure a RateLimiter type Option func(*RateLimiter) -// RateLimiter holds an LRU cache of elements to be rate limited. Currently, it supports -// a sourceIPCache and each item returns a rate.Limiter. +// KeyFunc returns unique identifier for the subject of rate limit(e.g. client IP or domain) +type KeyFunc func(*http.Request) string + +// RateLimiter holds an LRU cache of elements to be rate limited. // It uses "golang.org/x/time/rate" as its Token Bucket rate limiter per source IP entry. // See example https://www.fatalerrors.org/a/design-and-implementation-of-time-rate-limiter-for-golang-standard-library.html // It also holds a now function that can be mocked in unit tests. type RateLimiter struct { - now func() time.Time - sourceIPLimitPerSecond float64 - sourceIPBurstSize int - sourceIPBlockedCount *prometheus.GaugeVec - sourceIPCache *lru.Cache - - sourceIPCacheOptions []lru.Option - // TODO: add domainCache https://gitlab.com/gitlab-org/gitlab-pages/-/issues/630 + name string + now func() time.Time + limitPerSecond float64 + burstSize int + blockedCount *prometheus.GaugeVec + cache *lru.Cache + key KeyFunc + + cacheOptions []lru.Option } // New creates a new RateLimiter with default values that can be configured via Option functions -func New(opts ...Option) *RateLimiter { +func New(name string, opts ...Option) *RateLimiter { rl := &RateLimiter{ - now: time.Now, - sourceIPLimitPerSecond: DefaultSourceIPLimitPerSecond, - sourceIPBurstSize: DefaultSourceIPBurstSize, + name: name, + now: time.Now, + limitPerSecond: DefaultSourceIPLimitPerSecond, + burstSize: DefaultSourceIPBurstSize, + key: request.GetRemoteAddrWithoutPort, } for _, opt := range opts { opt(rl) } - rl.sourceIPCache = lru.New("source_ip", rl.sourceIPCacheOptions...) + rl.cache = lru.New(name, rl.cacheOptions...) return rl } @@ -66,60 +73,61 @@ func WithNow(now func() time.Time) Option { } } -// WithSourceIPLimitPerSecond allows configuring per source IP limit per second for RateLimiter -func WithSourceIPLimitPerSecond(limit float64) Option { +// WithLimitPerSecond allows configuring limit per second for RateLimiter +func WithLimitPerSecond(limit float64) Option { return func(rl *RateLimiter) { - rl.sourceIPLimitPerSecond = limit + rl.limitPerSecond = limit } } -// WithSourceIPBurstSize configures burst per source IP for the RateLimiter -func WithSourceIPBurstSize(burst int) Option { +// WithBurstSize configures burst per key for the RateLimiter +func WithBurstSize(burst int) Option { return func(rl *RateLimiter) { - rl.sourceIPBurstSize = burst + rl.burstSize = burst } } -// WithBlockedCountMetric configures metric reporting how many requests were blocked based by IP +// WithBlockedCountMetric configures metric reporting how many requests were blocked func WithBlockedCountMetric(m *prometheus.GaugeVec) Option { return func(rl *RateLimiter) { - rl.sourceIPBlockedCount = m + rl.blockedCount = m } } -// WithSourceIPCacheMaxSize configures cache size for source IP ratelimiter -func WithSourceIPCacheMaxSize(size int64) Option { +// WithCacheMaxSize configures cache size for ratelimiter +func WithCacheMaxSize(size int64) Option { return func(rl *RateLimiter) { - rl.sourceIPCacheOptions = append(rl.sourceIPCacheOptions, lru.WithMaxSize(size)) + rl.cacheOptions = append(rl.cacheOptions, lru.WithMaxSize(size)) } } -// WithSourceIPCachedEntriesMetric configures metric reporting how many IPs are currently stored in -// source-IP rate-limiter cache -func WithSourceIPCachedEntriesMetric(m *prometheus.GaugeVec) Option { +// WithCachedEntriesMetric configures metric reporting how many keys are currently stored in +// the rate-limiter cache +func WithCachedEntriesMetric(m *prometheus.GaugeVec) Option { return func(rl *RateLimiter) { - rl.sourceIPCacheOptions = append(rl.sourceIPCacheOptions, lru.WithCachedEntriesMetric(m)) + rl.cacheOptions = append(rl.cacheOptions, lru.WithCachedEntriesMetric(m)) } } -// WithSourceIPCachedRequestsMetric configures metric for how many times we ask source IP cache -func WithSourceIPCachedRequestsMetric(m *prometheus.CounterVec) Option { +// WithCachedRequestsMetric configures metric for how many times we ask key cache +func WithCachedRequestsMetric(m *prometheus.CounterVec) Option { return func(rl *RateLimiter) { - rl.sourceIPCacheOptions = append(rl.sourceIPCacheOptions, lru.WithCachedRequestsMetric(m)) + rl.cacheOptions = append(rl.cacheOptions, lru.WithCachedRequestsMetric(m)) } } -func (rl *RateLimiter) getSourceIPLimiter(sourceIP string) *rate.Limiter { - limiterI, _ := rl.sourceIPCache.FindOrFetch(sourceIP, sourceIP, func() (interface{}, error) { - return rate.NewLimiter(rate.Limit(rl.sourceIPLimitPerSecond), rl.sourceIPBurstSize), nil +func (rl *RateLimiter) limiter(key string) *rate.Limiter { + limiterI, _ := rl.cache.FindOrFetch(key, key, func() (interface{}, error) { + return rate.NewLimiter(rate.Limit(rl.limitPerSecond), rl.burstSize), nil }) return limiterI.(*rate.Limiter) } -// SourceIPAllowed checks that the real remote IP address is allowed to perform an operation -func (rl *RateLimiter) SourceIPAllowed(sourceIP string) bool { - limiter := rl.getSourceIPLimiter(sourceIP) +// RequestAllowed checks that the real remote IP address is allowed to perform an operation +func (rl *RateLimiter) RequestAllowed(r *http.Request) bool { + rateLimitedKey := rl.key(r) + limiter := rl.limiter(rateLimitedKey) // AllowN allows us to use the rl.now function, so we can test this more easily. return limiter.AllowN(rl.now(), 1) diff --git a/internal/ratelimiter/ratelimiter_test.go b/internal/ratelimiter/ratelimiter_test.go index d6be214a..ce767b35 100644 --- a/internal/ratelimiter/ratelimiter_test.go +++ b/internal/ratelimiter/ratelimiter_test.go @@ -1,6 +1,8 @@ package ratelimiter import ( + "net/http" + "net/http/httptest" "sync" "testing" "time" @@ -51,17 +53,21 @@ func TestSourceIPAllowed(t *testing.T) { for tn, tc := range sharedTestCases { t.Run(tn, func(t *testing.T) { rl := New( + "rate_limiter", WithNow(mockNow), - WithSourceIPLimitPerSecond(tc.sourceIPLimit), - WithSourceIPBurstSize(tc.sourceIPBurstSize), + WithLimitPerSecond(tc.sourceIPLimit), + WithBurstSize(tc.sourceIPBurstSize), ) for i := 0; i < tc.reqNum; i++ { - got := rl.SourceIPAllowed("172.16.123.1") + r := httptest.NewRequest(http.MethodGet, "https://domain.gitlab.io", nil) + r.RemoteAddr = "172.16.123.1" + + got := rl.RequestAllowed(r) if i < tc.sourceIPBurstSize { require.Truef(t, got, "expected true for request no. %d", i) } else { - // requests should fail after reaching tc.sourceIPBurstSize because mockNow + // requests should fail after reaching tc.burstSize because mockNow // always returns the same time require.False(t, got, "expected false for request no. %d", i) } @@ -74,20 +80,23 @@ func TestSingleRateLimiterWithMultipleSourceIPs(t *testing.T) { rate := 10 * time.Millisecond rl := New( - WithSourceIPLimitPerSecond(float64(1/rate)), - WithSourceIPBurstSize(1), + "rate_limiter", + WithLimitPerSecond(float64(1/rate)), + WithBurstSize(1), ) wg := sync.WaitGroup{} - testFn := func(domain string) func(t *testing.T) { + testFn := func(ip string) func(t *testing.T) { return func(t *testing.T) { wg.Add(1) go func() { defer wg.Done() for i := 0; i < 5; i++ { - got := rl.SourceIPAllowed(domain) + r := httptest.NewRequest(http.MethodGet, "https://domain.gitlab.io", nil) + r.RemoteAddr = ip + got := rl.RequestAllowed(r) require.Truef(t, got, "expected true for request no. %d", i) time.Sleep(rate) } diff --git a/test/acceptance/ratelimiter_test.go b/test/acceptance/ratelimiter_test.go index 2986d46b..5b3be98a 100644 --- a/test/acceptance/ratelimiter_test.go +++ b/test/acceptance/ratelimiter_test.go @@ -104,7 +104,7 @@ func TestSourceIPRateLimitMiddleware(t *testing.T) { if tc.expectFail && i >= int(tc.rateLimit) { require.Equal(t, http.StatusTooManyRequests, rsp.StatusCode, "group.gitlab-example.com request: %d failed", i) - assertLogFound(t, logBuf, []string{"source IP hit rate limit", "\"source_ip\":\"" + tc.blockedIP + "\""}) + assertLogFound(t, logBuf, []string{"request hit rate limit", "\"source_ip\":\"" + tc.blockedIP + "\""}) continue } -- cgit v1.2.3