diff options
author | ngala <ngala@gitlab.com> | 2024-01-24 10:26:33 +0300 |
---|---|---|
committer | ngala <ngala@gitlab.com> | 2024-01-24 10:38:26 +0300 |
commit | bc8e55a3e6592001f8e29b58bc32b74c7e768d51 (patch) | |
tree | 4655b9c8f034fcde3490656fb001a0bce83d64b1 | |
parent | ce1ac9056ca1375bc7fc08a8563c91dbf74fbcb6 (diff) |
Add feature flag FF_RATE_LIMITER_CLOSE_CONNECTION
-rw-r--r-- | internal/feature/feature.go | 6 | ||||
-rw-r--r-- | internal/ratelimiter/middleware.go | 3 | ||||
-rw-r--r-- | internal/ratelimiter/middleware_test.go | 97 |
3 files changed, 104 insertions, 2 deletions
diff --git a/internal/feature/feature.go b/internal/feature/feature.go index b01e9c7a..1b97e935 100644 --- a/internal/feature/feature.go +++ b/internal/feature/feature.go @@ -43,3 +43,9 @@ var DomainRedirects = Feature{ EnvVariable: "FF_ENABLE_DOMAIN_REDIRECT", defaultEnabled: false, } + +// RateLimiterCloseConnection enables support for rate limiter close connection +var RateLimiterCloseConnection = Feature{ + EnvVariable: "FF_RATE_LIMITER_CLOSE_CONNECTION", + defaultEnabled: false, +} diff --git a/internal/ratelimiter/middleware.go b/internal/ratelimiter/middleware.go index 29519d5d..84dd6453 100644 --- a/internal/ratelimiter/middleware.go +++ b/internal/ratelimiter/middleware.go @@ -5,6 +5,7 @@ import ( "github.com/sirupsen/logrus" + "gitlab.com/gitlab-org/gitlab-pages/internal/feature" "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/logging" "gitlab.com/gitlab-org/gitlab-pages/internal/request" @@ -34,7 +35,7 @@ func (rl *RateLimiter) Middleware(handler http.Handler) http.Handler { rl.blockedCount.WithLabelValues(rl.name).Inc() } - if rl.closeConnection { + if feature.RateLimiterCloseConnection.Enabled() && rl.closeConnection { w.Header().Set("Connection", "close") } httperrors.Serve429(w) diff --git a/internal/ratelimiter/middleware_test.go b/internal/ratelimiter/middleware_test.go index 79f63a16..4d3e01be 100644 --- a/internal/ratelimiter/middleware_test.go +++ b/internal/ratelimiter/middleware_test.go @@ -10,6 +10,7 @@ import ( testlog "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitlab-pages/internal/feature" "gitlab.com/gitlab-org/gitlab-pages/internal/request" "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers" ) @@ -39,6 +40,99 @@ func TestMiddlewareWithDifferentLimits(t *testing.T) { for i := 0; i < tc.reqNum; i++ { r := requestFor(remoteAddr, "http://gitlab.com") + code, body, _ := testhelpers.PerformRequest(t, handler, r) + + if i < tc.burstSize { + require.Equal(t, http.StatusNoContent, code, "req: %d failed", i) + } else { + // requests should fail after reaching tc.perDomainBurstPerSecond because mockNow + // always returns the same time + require.Equal(t, http.StatusTooManyRequests, code, "req: %d failed", i) + require.Contains(t, body, "Too many requests.") + assertSourceIPLog(t, remoteAddr, hook) + } + } + }) + } +} + +func TestMiddlewareDenyRequestsAfterBurst(t *testing.T) { + hook := testlog.NewGlobal() + blocked, cachedEntries, cacheReqs := newTestMetrics(t) + + tcs := map[string]struct { + expectedStatus int + }{ + "enabled_rate_limit_http_blocks": { + expectedStatus: http.StatusTooManyRequests, + }, + } + + for tn, tc := range tcs { + t.Run(tn, func(t *testing.T) { + rl := New( + "rate_limiter", + WithCachedEntriesMetric(cachedEntries), + WithCachedRequestsMetric(cacheReqs), + WithBlockedCountMetric(blocked), + WithNow(mockNow), + WithLimitPerSecond(1), + WithBurstSize(1), + WithCloseConnection(true), + ) + + // middleware is evaluated in reverse order + handler := rl.Middleware(next) + + for i := 0; i < 5; i++ { + r := requestFor(remoteAddr, "http://gitlab.com") + code, _, _ := testhelpers.PerformRequest(t, handler, r) + + if i == 0 { + require.Equal(t, http.StatusNoContent, code) + continue + } + + // burst is 1 and limit is 1 per second, all subsequent requests should fail + require.Equal(t, tc.expectedStatus, code) + assertSourceIPLog(t, remoteAddr, hook) + } + + blockedCount := testutil.ToFloat64(blocked.WithLabelValues("rate_limiter")) + require.Equal(t, float64(4), blockedCount, "blocked count") + blocked.Reset() + + cachedCount := testutil.ToFloat64(cachedEntries.WithLabelValues("rate_limiter")) + require.Equal(t, float64(1), cachedCount, "cached count") + cachedEntries.Reset() + + cacheReqMiss := testutil.ToFloat64(cacheReqs.WithLabelValues("rate_limiter", "miss")) + require.Equal(t, float64(1), cacheReqMiss, "miss count") + cacheReqHit := testutil.ToFloat64(cacheReqs.WithLabelValues("rate_limiter", "hit")) + require.Equal(t, float64(4), cacheReqHit, "hit count") + cacheReqs.Reset() + }) + } +} + +func TestMiddlewareWithDifferentLimitsWithFFCloseConnectionEnabled(t *testing.T) { + hook := testlog.NewGlobal() + t.Setenv(feature.RateLimiterCloseConnection.EnvVariable, "true") + + for tn, tc := range sharedTestCases { + t.Run(tn, func(t *testing.T) { + rl := New( + "rate_limiter", + WithNow(mockNow), + WithLimitPerSecond(tc.limit), + WithBurstSize(tc.burstSize), + WithCloseConnection(true), + ) + + handler := rl.Middleware(next) + + for i := 0; i < tc.reqNum; i++ { + r := requestFor(remoteAddr, "http://gitlab.com") code, body, header := testhelpers.PerformRequest(t, handler, r) if i < tc.burstSize { @@ -56,8 +150,9 @@ func TestMiddlewareWithDifferentLimits(t *testing.T) { } } -func TestMiddlewareDenyRequestsAfterBurst(t *testing.T) { +func TestMiddlewareDenyRequestsAfterBurstWithFFCloseConnectionEnabled(t *testing.T) { hook := testlog.NewGlobal() + t.Setenv(feature.RateLimiterCloseConnection.EnvVariable, "true") blocked, cachedEntries, cacheReqs := newTestMetrics(t) tcs := map[string]struct { |