diff options
-rw-r--r-- | internal/config/config.go | 4 | ||||
-rw-r--r-- | internal/config/flags.go | 2 | ||||
-rw-r--r-- | internal/feature/feature.go | 8 | ||||
-rw-r--r-- | internal/handlers/ratelimiter.go | 21 | ||||
-rw-r--r-- | internal/handlers/ratelimiter_test.go | 102 | ||||
-rw-r--r-- | internal/ratelimiter/middleware.go | 25 | ||||
-rw-r--r-- | internal/ratelimiter/middleware_test.go | 181 | ||||
-rw-r--r-- | internal/ratelimiter/ratelimiter.go | 35 | ||||
-rw-r--r-- | internal/ratelimiter/ratelimiter_test.go | 68 | ||||
-rw-r--r-- | internal/testhelpers/testhelpers.go | 18 | ||||
-rw-r--r-- | metrics/metrics.go | 32 | ||||
-rw-r--r-- | test/acceptance/ratelimiter_test.go | 172 |
12 files changed, 466 insertions, 202 deletions
diff --git a/internal/config/config.go b/internal/config/config.go index c6f91db0..c29fa65d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -68,6 +68,8 @@ type General struct { type RateLimit struct { SourceIPLimitPerSecond float64 SourceIPBurst int + DomainLimitPerSecond float64 + DomainBurst int } // ArtifactsServer groups settings related to configuring Artifacts @@ -196,6 +198,8 @@ func loadConfig() (*Config, error) { RateLimit: RateLimit{ SourceIPLimitPerSecond: *rateLimitSourceIP, SourceIPBurst: *rateLimitSourceIPBurst, + DomainLimitPerSecond: *rateLimitDomain, + DomainBurst: *rateLimitDomainBurst, }, GitLab: GitLab{ ClientHTTPTimeout: *gitlabClientHTTPTimeout, diff --git a/internal/config/flags.go b/internal/config/flags.go index 6c9bd4a6..93228827 100644 --- a/internal/config/flags.go +++ b/internal/config/flags.go @@ -17,6 +17,8 @@ var ( pagesDomain = flag.String("pages-domain", "gitlab-example.com", "The domain to serve static pages") rateLimitSourceIP = flag.Float64("rate-limit-source-ip", 0.0, "Rate limit per source IP in number of requests per second, 0 means is disabled") rateLimitSourceIPBurst = flag.Int("rate-limit-source-ip-burst", 100, "Rate limit per source IP maximum burst allowed per second") + rateLimitDomain = flag.Float64("rate-limit-domain", 0.0, "Rate limit per domain in number of requests per second, 0 means is disabled") + rateLimitDomainBurst = flag.Int("rate-limit-domain-burst", 100, "Rate limit per domain maximum burst allowed per second") artifactsServer = flag.String("artifacts-server", "", "API URL to proxy artifact requests to, e.g.: 'https://gitlab.com/api/v4'") artifactsServerTimeout = flag.Int("artifacts-server-timeout", 10, "Timeout (in seconds) for a proxied request to the artifacts server") pagesStatus = flag.String("pages-status", "", "The url path for a status page, e.g., /@status") diff --git a/internal/feature/feature.go b/internal/feature/feature.go index 0a29f4a4..81eef9a0 100644 --- a/internal/feature/feature.go +++ b/internal/feature/feature.go @@ -7,12 +7,18 @@ type Feature struct { defaultEnabled bool } -// EnforceIPRateLimits enforces ratelimiter package to drop requests +// EnforceIPRateLimits enforces IP rate limiter to drop requests // TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629 var EnforceIPRateLimits = Feature{ EnvVariable: "FF_ENFORCE_IP_RATE_LIMITS", } +// EnforceDomainRateLimits enforces domain rate limiter to drop requests +// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/655 +var EnforceDomainRateLimits = Feature{ + EnvVariable: "FF_ENFORCE_DOMAIN_RATE_LIMITS", +} + // RedirectsPlaceholders enables support for placeholders in redirects file // TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/620 var RedirectsPlaceholders = Feature{ diff --git a/internal/handlers/ratelimiter.go b/internal/handlers/ratelimiter.go index 8263f497..52281f6e 100644 --- a/internal/handlers/ratelimiter.go +++ b/internal/handlers/ratelimiter.go @@ -4,14 +4,16 @@ import ( "net/http" "gitlab.com/gitlab-org/gitlab-pages/internal/config" + "gitlab.com/gitlab-org/gitlab-pages/internal/feature" "gitlab.com/gitlab-org/gitlab-pages/internal/ratelimiter" + "gitlab.com/gitlab-org/gitlab-pages/internal/request" "gitlab.com/gitlab-org/gitlab-pages/metrics" ) // Ratelimiter configures the ratelimiter middleware // TODO: make this unexported once https://gitlab.com/gitlab-org/gitlab-pages/-/issues/670 is done func Ratelimiter(handler http.Handler, config *config.RateLimit) http.Handler { - rl := ratelimiter.New( + sourceIPLimiter := ratelimiter.New( "source_ip", ratelimiter.WithCacheMaxSize(ratelimiter.DefaultSourceIPCacheSize), ratelimiter.WithCachedEntriesMetric(metrics.RateLimitSourceIPCachedEntries), @@ -19,7 +21,22 @@ func Ratelimiter(handler http.Handler, config *config.RateLimit) http.Handler { ratelimiter.WithBlockedCountMetric(metrics.RateLimitSourceIPBlockedCount), ratelimiter.WithLimitPerSecond(config.SourceIPLimitPerSecond), ratelimiter.WithBurstSize(config.SourceIPBurst), + ratelimiter.WithEnforce(feature.EnforceIPRateLimits.Enabled()), ) - return rl.Middleware(handler) + handler = sourceIPLimiter.Middleware(handler) + + domainLimiter := ratelimiter.New( + "domain", + ratelimiter.WithCacheMaxSize(ratelimiter.DefaultDomainCacheSize), + ratelimiter.WithKeyFunc(request.GetHostWithoutPort), + ratelimiter.WithCachedEntriesMetric(metrics.RateLimitDomainCachedEntries), + ratelimiter.WithCachedRequestsMetric(metrics.RateLimitDomainCacheRequests), + ratelimiter.WithBlockedCountMetric(metrics.RateLimitDomainBlockedCount), + ratelimiter.WithLimitPerSecond(config.DomainLimitPerSecond), + ratelimiter.WithBurstSize(config.DomainBurst), + ratelimiter.WithEnforce(feature.EnforceDomainRateLimits.Enabled()), + ) + + return domainLimiter.Middleware(handler) } diff --git a/internal/handlers/ratelimiter_test.go b/internal/handlers/ratelimiter_test.go new file mode 100644 index 00000000..43acfc9a --- /dev/null +++ b/internal/handlers/ratelimiter_test.go @@ -0,0 +1,102 @@ +package handlers + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-pages/internal/config" + "gitlab.com/gitlab-org/gitlab-pages/internal/feature" + "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers" +) + +var next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) +}) + +func TestRatelimiter(t *testing.T) { + tt := map[string]struct { + firstRemoteAddr string + firstTarget string + secondRemoteAddr string + secondTarget string + sourceIPEnforced bool + domainEnforced bool + expectedSecondCode int + }{ + "rejected_by_ip": { + firstRemoteAddr: "10.0.0.1", + firstTarget: "https://domain.gitlab.io", + secondRemoteAddr: "10.0.0.1", + secondTarget: "https://different.gitlab.io", + sourceIPEnforced: true, + domainEnforced: true, + expectedSecondCode: http.StatusTooManyRequests, + }, + "rejected_by_domain": { + firstRemoteAddr: "10.0.0.1", + firstTarget: "https://domain.gitlab.io", + secondRemoteAddr: "10.0.0.2", + secondTarget: "https://domain.gitlab.io", + sourceIPEnforced: true, + domainEnforced: true, + expectedSecondCode: http.StatusTooManyRequests, + }, + "ip_rate_limiter_disabled": { + firstRemoteAddr: "10.0.0.1", + firstTarget: "https://domain.gitlab.io", + secondRemoteAddr: "10.0.0.1", + secondTarget: "https://different.gitlab.io", + sourceIPEnforced: false, + domainEnforced: true, + expectedSecondCode: http.StatusNoContent, + }, + "domain_rate_limiter_disabled": { + firstRemoteAddr: "10.0.0.1", + firstTarget: "https://domain.gitlab.io", + secondRemoteAddr: "10.0.0.2", + secondTarget: "https://domain.gitlab.io", + sourceIPEnforced: true, + domainEnforced: false, + expectedSecondCode: http.StatusNoContent, + }, + "different_ip_and_domain_passes": { + firstRemoteAddr: "10.0.0.1", + firstTarget: "https://domain.gitlab.io", + secondRemoteAddr: "10.0.0.2", + secondTarget: "https://different.gitlab.io", + sourceIPEnforced: true, + domainEnforced: true, + expectedSecondCode: http.StatusNoContent, + }, + } + + for name, tc := range tt { + t.Run(name, func(t *testing.T) { + testhelpers.StubFeatureFlagValue(t, feature.EnforceIPRateLimits.EnvVariable, tc.sourceIPEnforced) + testhelpers.StubFeatureFlagValue(t, feature.EnforceDomainRateLimits.EnvVariable, tc.domainEnforced) + + conf := config.RateLimit{ + SourceIPLimitPerSecond: 0.1, + SourceIPBurst: 1, + DomainLimitPerSecond: 0.1, + DomainBurst: 1, + } + + handler := Ratelimiter(next, &conf) + + r1 := httptest.NewRequest(http.MethodGet, tc.firstTarget, nil) + r1.RemoteAddr = tc.firstRemoteAddr + + firstCode, _ := testhelpers.PerformRequest(t, handler, r1) + require.Equal(t, http.StatusNoContent, firstCode) + + r2 := httptest.NewRequest(http.MethodGet, tc.secondTarget, nil) + r2.RemoteAddr = tc.secondRemoteAddr + secondCode, _ := testhelpers.PerformRequest(t, handler, r2) + require.Equal(t, tc.expectedSecondCode, secondCode) + }) + } +} diff --git a/internal/ratelimiter/middleware.go b/internal/ratelimiter/middleware.go index 155a87a6..af7b0881 100644 --- a/internal/ratelimiter/middleware.go +++ b/internal/ratelimiter/middleware.go @@ -2,6 +2,7 @@ package ratelimiter import ( "net/http" + "strconv" "github.com/sirupsen/logrus" "gitlab.com/gitlab-org/labkit/correlation" @@ -25,20 +26,20 @@ func (rl *RateLimiter) Middleware(handler http.Handler) http.Handler { } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !rl.requestAllowed(r) { - rl.logRateLimitedRequest(r) + if rl.requestAllowed(r) { + handler.ServeHTTP(w, r) + return + } + + rl.logRateLimitedRequest(r) - if feature.EnforceIPRateLimits.Enabled() { - if rl.blockedCount != nil { - rl.blockedCount.WithLabelValues("true").Inc() - } - httperrors.Serve429(w) - return - } + if rl.blockedCount != nil { + rl.blockedCount.WithLabelValues(strconv.FormatBool(feature.EnforceIPRateLimits.Enabled())).Inc() + } - if rl.blockedCount != nil { - rl.blockedCount.WithLabelValues("false").Inc() - } + if rl.enforce { + httperrors.Serve429(w) + return } handler.ServeHTTP(w, r) diff --git a/internal/ratelimiter/middleware_test.go b/internal/ratelimiter/middleware_test.go index 95d95e95..1f753fc4 100644 --- a/internal/ratelimiter/middleware_test.go +++ b/internal/ratelimiter/middleware_test.go @@ -1,16 +1,18 @@ package ratelimiter import ( - "io" "net/http" "net/http/httptest" + "strconv" "testing" + "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/testutil" testlog "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab-pages/internal/feature" + "gitlab.com/gitlab-org/gitlab-pages/internal/request" "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers" ) @@ -22,41 +24,32 @@ var next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) }) -func TestSourceIPLimiterWithDifferentLimits(t *testing.T) { +func TestMiddlewareWithDifferentLimits(t *testing.T) { hook := testlog.NewGlobal() - testhelpers.StubFeatureFlagValue(t, feature.EnforceIPRateLimits.EnvVariable, true) for tn, tc := range sharedTestCases { t.Run(tn, func(t *testing.T) { rl := New( "rate_limiter", WithNow(mockNow), - WithLimitPerSecond(tc.sourceIPLimit), - WithBurstSize(tc.sourceIPBurstSize), + WithLimitPerSecond(tc.limit), + WithBurstSize(tc.burstSize), + WithEnforce(true), ) - for i := 0; i < tc.reqNum; i++ { - ww := httptest.NewRecorder() - rr := httptest.NewRequest(http.MethodGet, "https://domain.gitlab.io", nil) - rr.RemoteAddr = remoteAddr - - handler := rl.Middleware(next) + handler := rl.Middleware(next) - handler.ServeHTTP(ww, rr) - res := ww.Result() + for i := 0; i < tc.reqNum; i++ { + r := requestFor(remoteAddr, "http://gitlab.com") + code, body := testhelpers.PerformRequest(t, handler, r) - if i < tc.sourceIPBurstSize { - require.Equal(t, http.StatusNoContent, res.StatusCode, "req: %d failed", i) + if i < tc.burstSize { + require.Equal(t, http.StatusNoContent, code, "req: %d failed", i) } else { // requests should fail after reaching tc.perDomainBurstPerSecond because mockNow // always returns the same time - require.Equal(t, http.StatusTooManyRequests, res.StatusCode, "req: %d failed", i) - b, err := io.ReadAll(res.Body) - require.NoError(t, err) - - require.Contains(t, string(b), "Too many requests.") - res.Body.Close() - + require.Equal(t, http.StatusTooManyRequests, code, "req: %d failed", i) + require.Contains(t, body, "Too many requests.") assertSourceIPLog(t, remoteAddr, hook) } } @@ -64,26 +57,28 @@ func TestSourceIPLimiterWithDifferentLimits(t *testing.T) { } } -func TestSourceIPLimiterDenyRequestsAfterBurst(t *testing.T) { +func TestMiddlewareDenyRequestsAfterBurst(t *testing.T) { hook := testlog.NewGlobal() blocked, cachedEntries, cacheReqs := newTestMetrics(t) tcs := map[string]struct { - enabled bool + enforce bool expectedStatus int }{ "disabled_rate_limit_http": { - enabled: false, + enforce: false, expectedStatus: http.StatusNoContent, }, "enabled_rate_limit_http_blocks": { - enabled: true, + enforce: true, expectedStatus: http.StatusTooManyRequests, }, } for tn, tc := range tcs { t.Run(tn, func(t *testing.T) { + testhelpers.StubFeatureFlagValue(t, feature.EnforceIPRateLimits.EnvVariable, tc.enforce) + rl := New( "rate_limiter", WithCachedEntriesMetric(cachedEntries), @@ -92,37 +87,28 @@ func TestSourceIPLimiterDenyRequestsAfterBurst(t *testing.T) { WithNow(mockNow), WithLimitPerSecond(1), WithBurstSize(1), + WithEnforce(tc.enforce), ) - for i := 0; i < 5; i++ { - ww := httptest.NewRecorder() - rr := httptest.NewRequest(http.MethodGet, "http://gitlab.com", nil) - testhelpers.StubFeatureFlagValue(t, feature.EnforceIPRateLimits.EnvVariable, tc.enabled) - - rr.RemoteAddr = remoteAddr + // middleware is evaluated in reverse order + handler := rl.Middleware(next) - // middleware is evaluated in reverse order - handler := rl.Middleware(next) - - handler.ServeHTTP(ww, rr) - res := ww.Result() + for i := 0; i < 5; i++ { + r := requestFor(remoteAddr, "http://gitlab.com") + code, _ := testhelpers.PerformRequest(t, handler, r) if i == 0 { - require.Equal(t, http.StatusNoContent, res.StatusCode) + require.Equal(t, http.StatusNoContent, code) continue } // burst is 1 and limit is 1 per second, all subsequent requests should fail - require.Equal(t, tc.expectedStatus, res.StatusCode) + require.Equal(t, tc.expectedStatus, code) assertSourceIPLog(t, remoteAddr, hook) } - blockedCount := testutil.ToFloat64(blocked.WithLabelValues("true")) - if tc.enabled { - require.Equal(t, float64(4), blockedCount, "blocked count") - } else { - require.Equal(t, float64(0), blockedCount, "blocked count") - } + blockedCount := testutil.ToFloat64(blocked.WithLabelValues(strconv.FormatBool(tc.enforce))) + require.Equal(t, float64(4), blockedCount, "blocked count") blocked.Reset() cachedCount := testutil.ToFloat64(cachedEntries.WithLabelValues("rate_limiter")) @@ -138,6 +124,90 @@ func TestSourceIPLimiterDenyRequestsAfterBurst(t *testing.T) { } } +func TestKeyFunc(t *testing.T) { + tt := map[string]struct { + keyFunc KeyFunc + firstRemoteAddr string + firstTarget string + secondRemoteAddr string + secondTarget string + expectedSecondCode int + }{ + "rejected_by_ip": { + keyFunc: request.GetRemoteAddrWithoutPort, + firstRemoteAddr: "10.0.0.1", + firstTarget: "https://domain.gitlab.io", + secondRemoteAddr: "10.0.0.1", + secondTarget: "https://different.gitlab.io", + expectedSecondCode: http.StatusTooManyRequests, + }, + "rejected_by_ip_with_different_port": { + keyFunc: request.GetRemoteAddrWithoutPort, + firstRemoteAddr: "10.0.0.1:41000", + firstTarget: "https://domain.gitlab.io", + secondRemoteAddr: "10.0.0.1:41001", + secondTarget: "https://different.gitlab.io", + expectedSecondCode: http.StatusTooManyRequests, + }, + "rejected_by_domain": { + keyFunc: request.GetHostWithoutPort, + firstRemoteAddr: "10.0.0.1", + firstTarget: "https://domain.gitlab.io", + secondRemoteAddr: "10.0.0.2", + secondTarget: "https://domain.gitlab.io", + expectedSecondCode: http.StatusTooManyRequests, + }, + "rejected_by_domain_with_different_protocol": { + keyFunc: request.GetHostWithoutPort, + firstRemoteAddr: "10.0.0.1", + firstTarget: "https://domain.gitlab.io", + secondRemoteAddr: "10.0.0.2", + secondTarget: "http://domain.gitlab.io", + expectedSecondCode: http.StatusTooManyRequests, + }, + "domain_limiter_allows_same_ip": { + keyFunc: request.GetHostWithoutPort, + firstRemoteAddr: "10.0.0.1", + firstTarget: "https://domain.gitlab.io", + secondRemoteAddr: "10.0.0.1", + secondTarget: "https://different.gitlab.io", + expectedSecondCode: http.StatusNoContent, + }, + "ip_limiter_allows_same_domain": { + keyFunc: request.GetRemoteAddrWithoutPort, + firstRemoteAddr: "10.0.0.1", + firstTarget: "https://domain.gitlab.io", + secondRemoteAddr: "10.0.0.2", + secondTarget: "https://domain.gitlab.io", + expectedSecondCode: http.StatusNoContent, + }, + } + + for name, tc := range tt { + t.Run(name, func(t *testing.T) { + handler := New( + "rate_limiter", + WithNow(mockNow), + WithLimitPerSecond(1), + WithBurstSize(1), + WithKeyFunc(tc.keyFunc), + WithEnforce(true), + ).Middleware(next) + + r1 := httptest.NewRequest(http.MethodGet, tc.firstTarget, nil) + r1.RemoteAddr = tc.firstRemoteAddr + + firstCode, _ := testhelpers.PerformRequest(t, handler, r1) + require.Equal(t, http.StatusNoContent, firstCode) + + r2 := httptest.NewRequest(http.MethodGet, tc.secondTarget, nil) + r2.RemoteAddr = tc.secondRemoteAddr + secondCode, _ := testhelpers.PerformRequest(t, handler, r2) + require.Equal(t, tc.expectedSecondCode, secondCode) + }) + } +} + func assertSourceIPLog(t *testing.T, remoteAddr string, hook *testlog.Hook) { t.Helper() @@ -148,3 +218,24 @@ func assertSourceIPLog(t *testing.T, remoteAddr string, hook *testlog.Hook) { hook.Reset() } + +func newTestMetrics(t *testing.T) (*prometheus.GaugeVec, *prometheus.GaugeVec, *prometheus.CounterVec) { + t.Helper() + + blockedGauge := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: t.Name(), + }, + []string{"enforced"}, + ) + + cachedEntries := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: t.Name(), + }, []string{"op"}) + + cacheReqs := prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: t.Name(), + }, []string{"op", "cache"}) + + return blockedGauge, cachedEntries, cacheReqs +} diff --git a/internal/ratelimiter/ratelimiter.go b/internal/ratelimiter/ratelimiter.go index 510a44ae..feeb8cb4 100644 --- a/internal/ratelimiter/ratelimiter.go +++ b/internal/ratelimiter/ratelimiter.go @@ -15,6 +15,10 @@ const ( // based on an avg ~4,000 unique IPs per minute // https://log.gprd.gitlab.net/app/lens#/edit/f7110d00-2013-11ec-8c8e-ed83b5469915?_g=h@e78830b DefaultSourceIPCacheSize = 5000 + + // we have less than 4000 different hosts per minute + // https://log.gprd.gitlab.net/app/dashboards#/view/d52ab740-61a4-11ec-b20d-65f14d890d9b?_a=(viewMode:edit)&_g=h@42b0d52 + DefaultDomainCacheSize = 4000 ) // Option function to configure a RateLimiter @@ -30,11 +34,12 @@ type KeyFunc func(*http.Request) string type RateLimiter struct { name string now func() time.Time + keyFunc KeyFunc limitPerSecond float64 burstSize int blockedCount *prometheus.GaugeVec cache *lru.Cache - key KeyFunc + enforce bool cacheOptions []lru.Option } @@ -42,9 +47,9 @@ type RateLimiter struct { // New creates a new RateLimiter with default values that can be configured via Option functions func New(name string, opts ...Option) *RateLimiter { rl := &RateLimiter{ - name: name, - now: time.Now, - key: request.GetRemoteAddrWithoutPort, + name: name, + now: time.Now, + keyFunc: request.GetRemoteAddrWithoutPort, } for _, opt := range opts { @@ -72,7 +77,7 @@ func WithLimitPerSecond(limit float64) Option { } } -// WithBurstSize configures burst per key for the RateLimiter +// WithBurstSize configures burst per keyFunc value for the RateLimiter func WithBurstSize(burst int) Option { return func(rl *RateLimiter) { rl.burstSize = burst @@ -101,13 +106,27 @@ func WithCachedEntriesMetric(m *prometheus.GaugeVec) Option { } } -// WithCachedRequestsMetric configures metric for how many times we ask key cache +// WithCachedRequestsMetric configures metric for how many times we access cache func WithCachedRequestsMetric(m *prometheus.CounterVec) Option { return func(rl *RateLimiter) { rl.cacheOptions = append(rl.cacheOptions, lru.WithCachedRequestsMetric(m)) } } +// WithKeyFunc configures keyFunc +func WithKeyFunc(f KeyFunc) Option { + return func(rl *RateLimiter) { + rl.keyFunc = f + } +} + +// WithEnforce configures if requests are actually rejected, or we just report them as rejected in metrics +func WithEnforce(enforce bool) Option { + return func(rl *RateLimiter) { + rl.enforce = enforce + } +} + func (rl *RateLimiter) limiter(key string) *rate.Limiter { limiterI, _ := rl.cache.FindOrFetch(key, key, func() (interface{}, error) { return rate.NewLimiter(rate.Limit(rl.limitPerSecond), rl.burstSize), nil @@ -116,9 +135,9 @@ func (rl *RateLimiter) limiter(key string) *rate.Limiter { return limiterI.(*rate.Limiter) } -// requestAllowed checks that the real remote IP address is allowed to perform an operation +// requestAllowed checks if request is within the rate-limit func (rl *RateLimiter) requestAllowed(r *http.Request) bool { - rateLimitedKey := rl.key(r) + rateLimitedKey := rl.keyFunc(r) limiter := rl.limiter(rateLimitedKey) // AllowN allows us to use the rl.now function, so we can test this more easily. diff --git a/internal/ratelimiter/ratelimiter_test.go b/internal/ratelimiter/ratelimiter_test.go index 926a90c1..930cbd06 100644 --- a/internal/ratelimiter/ratelimiter_test.go +++ b/internal/ratelimiter/ratelimiter_test.go @@ -6,7 +6,6 @@ import ( "testing" "time" - "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" ) @@ -20,33 +19,33 @@ func mockNow() time.Time { } var sharedTestCases = map[string]struct { - sourceIPLimit float64 - sourceIPBurstSize int - reqNum int + limit float64 + burstSize int + reqNum int }{ "one_request_per_second": { - sourceIPLimit: 1, - sourceIPBurstSize: 1, - reqNum: 2, + limit: 1, + burstSize: 1, + reqNum: 2, }, "one_request_per_second_but_big_bucket": { - sourceIPLimit: 1, - sourceIPBurstSize: 10, - reqNum: 11, + limit: 1, + burstSize: 10, + reqNum: 11, }, "three_req_per_second_bucket_size_one": { - sourceIPLimit: 3, - sourceIPBurstSize: 1, // max burst 1 means 1 at a time - reqNum: 3, + limit: 3, + burstSize: 1, // max burst 1 means 1 at a time + reqNum: 3, }, "10_requests_per_second": { - sourceIPLimit: 10, - sourceIPBurstSize: 10, - reqNum: 11, + limit: 10, + burstSize: 10, + reqNum: 11, }, } -func TestSourceIPAllowed(t *testing.T) { +func TestRequestAllowed(t *testing.T) { t.Parallel() for tn, tc := range sharedTestCases { @@ -54,16 +53,15 @@ func TestSourceIPAllowed(t *testing.T) { rl := New( "rate_limiter", WithNow(mockNow), - WithLimitPerSecond(tc.sourceIPLimit), - WithBurstSize(tc.sourceIPBurstSize), + WithLimitPerSecond(tc.limit), + WithBurstSize(tc.burstSize), ) for i := 0; i < tc.reqNum; i++ { - r := httptest.NewRequest(http.MethodGet, "https://domain.gitlab.io", nil) - r.RemoteAddr = "172.16.123.1" + r := requestFor("172.16.123.1", "https://domain.gitlab.io") got := rl.requestAllowed(r) - if i < tc.sourceIPBurstSize { + if i < tc.burstSize { require.Truef(t, got, "expected true for request no. %d", i) } else { // requests should fail after reaching tc.burstSize because mockNow @@ -88,8 +86,7 @@ func TestSingleRateLimiterWithMultipleSourceIPs(t *testing.T) { ) testRequest := func(ip string, i int) { - r := httptest.NewRequest(http.MethodGet, "https://domain.gitlab.io", nil) - r.RemoteAddr = ip + r := requestFor(ip, "https://domain.gitlab.io") got := rl.requestAllowed(r) require.Truef(t, got, "expected true for %v request no. %d", ip, i) } @@ -102,23 +99,8 @@ func TestSingleRateLimiterWithMultipleSourceIPs(t *testing.T) { } } -func newTestMetrics(t *testing.T) (*prometheus.GaugeVec, *prometheus.GaugeVec, *prometheus.CounterVec) { - t.Helper() - - blockedGauge := prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Name: t.Name(), - }, - []string{"enforced"}, - ) - - cachedEntries := prometheus.NewGaugeVec(prometheus.GaugeOpts{ - Name: t.Name(), - }, []string{"op"}) - - cacheReqs := prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: t.Name(), - }, []string{"op", "cache"}) - - return blockedGauge, cachedEntries, cacheReqs +func requestFor(remoteAddr, domain string) *http.Request { + r := httptest.NewRequest(http.MethodGet, domain, nil) + r.RemoteAddr = remoteAddr + return r } diff --git a/internal/testhelpers/testhelpers.go b/internal/testhelpers/testhelpers.go index df2252d7..bb02b698 100644 --- a/internal/testhelpers/testhelpers.go +++ b/internal/testhelpers/testhelpers.go @@ -2,6 +2,7 @@ package testhelpers import ( "fmt" + "io" "mime" "net/http" "net/http/httptest" @@ -89,10 +90,25 @@ func SetEnvironmentVariable(t testing.TB, key, value string) { require.NoError(t, err) t.Cleanup(func() { - os.Setenv(key, orig) + require.NoError(t, os.Setenv(key, orig)) }) } func StubFeatureFlagValue(t testing.TB, envVar string, value bool) { SetEnvironmentVariable(t, envVar, strconv.FormatBool(value)) } + +func PerformRequest(t *testing.T, handler http.Handler, r *http.Request) (int, string) { + t.Helper() + + ww := httptest.NewRecorder() + + handler.ServeHTTP(ww, r) + res := ww.Result() + + b, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.NoError(t, res.Body.Close()) + + return res.StatusCode, string(b) +} diff --git a/metrics/metrics.go b/metrics/metrics.go index 680f7a05..e0e4ab29 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -211,12 +211,40 @@ var ( []string{"op"}, ) - // RateLimitSourceIPBlockedCount is the number of source IPs that have been blocked by the + // RateLimitSourceIPBlockedCount is the number of requests that have been blocked by the // source IP rate limiter RateLimitSourceIPBlockedCount = prometheus.NewGaugeVec( prometheus.GaugeOpts{ Name: "gitlab_pages_rate_limit_source_ip_blocked_count", - Help: "The number of source IP addresses that have been blocked by the rate limiter", + Help: "The number of requests that have been blocked by the IP rate limiter", + }, + []string{"enforced"}, + ) + + // RateLimitDomainCacheRequests is the number of cache hits/misses + RateLimitDomainCacheRequests = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_pages_rate_limit_domain_cache_requests", + Help: "The number of source_ip cache hits/misses in the rate limiter", + }, + []string{"op", "cache"}, + ) + + // RateLimitDomainCachedEntries is the number of entries in the cache + RateLimitDomainCachedEntries = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "gitlab_pages_rate_limit_domain_cached_entries", + Help: "The number of entries in the cache", + }, + []string{"op"}, + ) + + // RateLimitDomainBlockedCount is the number of requests that have been blocked by the + // domain rate limiter + RateLimitDomainBlockedCount = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "gitlab_pages_rate_limit_domain_blocked_count", + Help: "The number of requests addresses that have been blocked by the domain rate limiter", }, []string{"enforced"}, ) diff --git a/test/acceptance/ratelimiter_test.go b/test/acceptance/ratelimiter_test.go index ae6184ca..037eef96 100644 --- a/test/acceptance/ratelimiter_test.go +++ b/test/acceptance/ratelimiter_test.go @@ -6,112 +6,108 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" - "gitlab.com/gitlab-org/gitlab-pages/internal/feature" "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers" + + "github.com/stretchr/testify/require" ) -func TestSourceIPRateLimitMiddleware(t *testing.T) { +var ratelimitedListeners = map[string]struct { + listener ListenSpec + header http.Header + clientIP string + // We perform requests to server while we're waiting for it to boot up, + // successful request gets counted in IP rate limit + includeWaitRequest bool +}{ + "http_listener": { + listener: httpListener, + clientIP: "127.0.0.1", + includeWaitRequest: true, + }, + "https_listener": { + listener: httpsListener, + clientIP: "127.0.0.1", + includeWaitRequest: true, + }, + "proxy_listener": { + listener: proxyListener, + header: http.Header{ + "X-Forwarded-For": []string{"172.16.123.1"}, + "X-Forwarded-Host": []string{"group.gitlab-example.com"}, + }, + clientIP: "172.16.123.1", + }, + "proxyv2_listener": { + listener: httpsProxyv2Listener, + clientIP: "10.1.1.1", + includeWaitRequest: true, + }, +} + +func TestIPRateLimits(t *testing.T) { testhelpers.StubFeatureFlagValue(t, feature.EnforceIPRateLimits.EnvVariable, true) - tcs := map[string]struct { - listener ListenSpec - rateLimit float64 - rateBurst string - blockedIP string - header http.Header - expectFail bool - sleep time.Duration - }{ - "http_slow_requests_should_not_be_blocked": { - listener: httpListener, - rateLimit: 1000, - // RunPagesProcess makes one request, so we need to allow a burst of 2 - // because r.RemoteAddr == 127.0.0.1 and X-Forwarded-For is ignored for non-proxy requests - rateBurst: "2", - sleep: 10 * time.Millisecond, - }, - "https_slow_requests_should_not_be_blocked": { - listener: httpsListener, - rateLimit: 1000, - rateBurst: "2", - sleep: 10 * time.Millisecond, - }, - "proxy_slow_requests_should_not_be_blocked": { - listener: proxyListener, - rateLimit: 1000, - // listen-proxy uses X-Forwarded-For - rateBurst: "1", - header: http.Header{ - "X-Forwarded-For": []string{"172.16.123.1"}, - "X-Forwarded-Host": []string{"group.gitlab-example.com"}, - }, - sleep: 10 * time.Millisecond, - }, - "proxyv2_slow_requests_should_not_be_blocked": { - listener: httpsProxyv2Listener, - rateLimit: 1000, - rateBurst: "2", - sleep: 10 * time.Millisecond, - }, - "http_fast_requests_blocked_after_burst": { - listener: httpListener, - rateLimit: 1, - rateBurst: "2", - expectFail: true, - blockedIP: "127.0.0.1", - }, - "https_fast_requests_blocked_after_burst": { - listener: httpsListener, - rateLimit: 1, - rateBurst: "2", - expectFail: true, - blockedIP: "127.0.0.1", - }, - "proxy_fast_requests_blocked_after_burst": { - listener: proxyListener, - rateLimit: 1, - rateBurst: "1", - header: http.Header{ - "X-Forwarded-For": []string{"172.16.123.1"}, - "X-Forwarded-Host": []string{"group.gitlab-example.com"}, - }, - expectFail: true, - blockedIP: "172.16.123.1", - }, - "proxyv2_fast_requests_blocked_after_burst": { - listener: httpsProxyv2Listener, - rateLimit: 1, - rateBurst: "2", - expectFail: true, - // use TestProxyv2Client SourceIP - blockedIP: "10.1.1.1", - }, + for name, tc := range ratelimitedListeners { + t.Run(name, func(t *testing.T) { + rateLimit := 5 + logBuf := RunPagesProcess(t, + withListeners([]ListenSpec{tc.listener}), + withExtraArgument("rate-limit-source-ip", fmt.Sprint(rateLimit)), + withExtraArgument("rate-limit-source-ip-burst", fmt.Sprint(rateLimit)), + ) + + if tc.includeWaitRequest { + rateLimit-- // we've already used one of requests while checking if server is up + } + + for i := 0; i < 10; i++ { + rsp, err := GetPageFromListenerWithHeaders(t, tc.listener, "group.gitlab-example.com", "project/", tc.header) + require.NoError(t, err) + require.NoError(t, rsp.Body.Close()) + + if i >= rateLimit { + require.Equal(t, http.StatusTooManyRequests, rsp.StatusCode, "group.gitlab-example.com request: %d failed", i) + assertLogFound(t, logBuf, []string{"request hit rate limit", "\"source_ip\":\"" + tc.clientIP + "\""}) + } else { + require.Equal(t, http.StatusOK, rsp.StatusCode, "request: %d failed", i) + } + } + }) } +} - for tn, tc := range tcs { - t.Run(tn, func(t *testing.T) { +func TestDomainateLimits(t *testing.T) { + testhelpers.StubFeatureFlagValue(t, feature.EnforceDomainRateLimits.EnvVariable, true) + + for name, tc := range ratelimitedListeners { + t.Run(name, func(t *testing.T) { + rateLimit := 5 logBuf := RunPagesProcess(t, withListeners([]ListenSpec{tc.listener}), - withExtraArgument("rate-limit-source-ip", fmt.Sprint(tc.rateLimit)), - withExtraArgument("rate-limit-source-ip-burst", tc.rateBurst), + withExtraArgument("rate-limit-domain", fmt.Sprint(rateLimit)), + withExtraArgument("rate-limit-domain-burst", fmt.Sprint(rateLimit)), ) - for i := 0; i < 5; i++ { + for i := 0; i < 10; i++ { rsp, err := GetPageFromListenerWithHeaders(t, tc.listener, "group.gitlab-example.com", "project/", tc.header) require.NoError(t, err) - rsp.Body.Close() + require.NoError(t, rsp.Body.Close()) - if tc.expectFail && i >= int(tc.rateLimit) { + if i >= rateLimit { require.Equal(t, http.StatusTooManyRequests, rsp.StatusCode, "group.gitlab-example.com request: %d failed", i) - assertLogFound(t, logBuf, []string{"request hit rate limit", "\"source_ip\":\"" + tc.blockedIP + "\""}) - continue + assertLogFound(t, logBuf, []string{"request hit rate limit", "\"source_ip\":\"" + tc.clientIP + "\""}) + } else { + require.Equal(t, http.StatusOK, rsp.StatusCode, "request: %d failed", i) } - - require.Equal(t, http.StatusOK, rsp.StatusCode, "request: %d failed", i) - time.Sleep(tc.sleep) } + + // make sure that requests to other domains are passing + rsp, err := GetPageFromListener(t, tc.listener, "CapitalGroup.gitlab-example.com", "project/") + require.NoError(t, err) + require.NoError(t, rsp.Body.Close()) + + require.Equal(t, http.StatusOK, rsp.StatusCode, "request to unrelated domain failed") }) } } |