diff options
-rw-r--r-- | app.go | 13 | ||||
-rw-r--r-- | internal/config/config.go | 11 | ||||
-rw-r--r-- | internal/config/flags.go | 2 | ||||
-rw-r--r-- | internal/httperrors/httperrors.go | 13 | ||||
-rw-r--r-- | internal/ratelimiter/middleware.go | 66 | ||||
-rw-r--r-- | internal/ratelimiter/middleware_test.go | 149 | ||||
-rw-r--r-- | internal/ratelimiter/ratelimiter.go | 9 | ||||
-rw-r--r-- | internal/ratelimiter/ratelimiter_test.go | 85 | ||||
-rw-r--r-- | internal/request/request.go | 10 | ||||
-rw-r--r-- | internal/testhelpers/testhelpers.go | 18 | ||||
-rw-r--r-- | metrics/metrics.go | 31 | ||||
-rw-r--r-- | test/acceptance/helpers_test.go | 15 | ||||
-rw-r--r-- | test/acceptance/ratelimiter_test.go | 128 |
13 files changed, 516 insertions, 34 deletions
@@ -32,6 +32,7 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/logging" "gitlab.com/gitlab-org/gitlab-pages/internal/netutil" + "gitlab.com/gitlab-org/gitlab-pages/internal/ratelimiter" "gitlab.com/gitlab-org/gitlab-pages/internal/rejectmethods" "gitlab.com/gitlab-org/gitlab-pages/internal/request" "gitlab.com/gitlab-org/gitlab-pages/internal/routing" @@ -262,6 +263,18 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) { handler = routing.NewMiddleware(handler, a.source) + if a.config.RateLimit.SourceIPLimitPerSecond > 0 { + rl := ratelimiter.New( + metrics.RateLimitSourceIPBlockedCount, + metrics.RateLimitSourceIPCachedEntries, + metrics.RateLimitSourceIPCacheRequests, + ratelimiter.WithSourceIPLimitPerSecond(a.config.RateLimit.SourceIPLimitPerSecond), + ratelimiter.WithSourceIPBurstSize(a.config.RateLimit.SourceIPBurst), + ) + + handler = rl.SourceIPLimiter(handler) + } + // Health Check handler, err = a.healthCheckMiddleware(handler) if err != nil { diff --git a/internal/config/config.go b/internal/config/config.go index 94c22328..3e03f7d6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -16,6 +16,7 @@ import ( // Config stores all the config options relevant to GitLab Pages. type Config struct { General General + RateLimit RateLimit ArtifactsServer ArtifactsServer Authentication Auth GitLab GitLab @@ -62,6 +63,12 @@ type General struct { CustomHeaders []string } +// RateLimit config struct +type RateLimit struct { + SourceIPLimitPerSecond float64 + SourceIPBurst int +} + // ArtifactsServer groups settings related to configuring Artifacts // server type ArtifactsServer struct { @@ -184,6 +191,10 @@ func loadConfig() (*Config, error) { CustomHeaders: header.Split(), ShowVersion: *showVersion, }, + RateLimit: RateLimit{ + SourceIPLimitPerSecond: *rateLimitSourceIP, + SourceIPBurst: *rateLimitSourceIPBurst, + }, GitLab: GitLab{ ClientHTTPTimeout: *gitlabClientHTTPTimeout, JWTTokenExpiration: *gitlabClientJWTExpiry, diff --git a/internal/config/flags.go b/internal/config/flags.go index 52b7be18..c61447c7 100644 --- a/internal/config/flags.go +++ b/internal/config/flags.go @@ -15,6 +15,8 @@ var ( _ = flag.Bool("use-http2", true, "DEPRECATED: HTTP2 is always enabled for pages") pagesRoot = flag.String("pages-root", "shared/pages", "The directory where pages are stored") pagesDomain = flag.String("pages-domain", "gitlab-example.com", "The domain to serve static pages") + rateLimitSourceIP = flag.Float64("rate-limit-source-ip", 0.0, "Rate limit per source IP in number of requests per second, 0 means is disabled") + rateLimitSourceIPBurst = flag.Int("rate-limit-source-ip-burst", 100, "Rate limit per source IP maximum burst allowed per second") artifactsServer = flag.String("artifacts-server", "", "API URL to proxy artifact requests to, e.g.: 'https://gitlab.com/api/v4'") artifactsServerTimeout = flag.Int("artifacts-server-timeout", 10, "Timeout (in seconds) for a proxied request to the artifacts server") pagesStatus = flag.String("pages-status", "", "The url path for a status page, e.g., /@status") diff --git a/internal/httperrors/httperrors.go b/internal/httperrors/httperrors.go index ed56ee10..8e61d590 100644 --- a/internal/httperrors/httperrors.go +++ b/internal/httperrors/httperrors.go @@ -34,6 +34,14 @@ var ( <p>Make sure the address is correct and that the page hasn't moved.</p> <p>Please contact your GitLab administrator if you think this is a mistake.</p>`, } + + content429 = content{ + http.StatusTooManyRequests, + "Too many requests (429)", + "429", + "Too many requests.", + `<p>The resource that you are attempting to access is being rate limited.</p>`, + } content500 = content{ http.StatusInternalServerError, "Something went wrong (500)", @@ -176,6 +184,11 @@ func Serve404(w http.ResponseWriter) { serveErrorPage(w, content404) } +// Serve429 returns a 429 error response / HTML page to the http.ResponseWriter +func Serve429(w http.ResponseWriter) { + serveErrorPage(w, content429) +} + // Serve500 returns a 500 error response / HTML page to the http.ResponseWriter func Serve500(w http.ResponseWriter) { serveErrorPage(w, content500) diff --git a/internal/ratelimiter/middleware.go b/internal/ratelimiter/middleware.go new file mode 100644 index 00000000..0cd5b81e --- /dev/null +++ b/internal/ratelimiter/middleware.go @@ -0,0 +1,66 @@ +package ratelimiter + +import ( + "net/http" + "os" + + "github.com/sirupsen/logrus" + "gitlab.com/gitlab-org/labkit/correlation" + "gitlab.com/gitlab-org/labkit/log" + + "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" + "gitlab.com/gitlab-org/gitlab-pages/internal/request" +) + +const ( + headerGitLabRealIP = "GitLab-Real-IP" + headerXForwardedFor = "X-Forwarded-For" + headerXForwardedProto = "X-Forwarded-Proto" +) + +// SourceIPLimiter returns middleware for rate-limiting clients based on their IP +func (rl *RateLimiter) SourceIPLimiter(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + host, sourceIP := request.GetHostWithoutPort(r), request.GetRemoteAddrWithoutPort(r) + if !rl.SourceIPAllowed(sourceIP) { + rl.logSourceIP(r, host, sourceIP) + + // Only drop requests once FF_ENABLE_RATE_LIMITER is enabled + // https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629 + if rateLimiterEnabled() { + rl.sourceIPBlockedCount.WithLabelValues("true").Inc() + httperrors.Serve429(w) + return + } + + rl.sourceIPBlockedCount.WithLabelValues("false").Inc() + } + + handler.ServeHTTP(w, r) + }) +} + +func (rl *RateLimiter) logSourceIP(r *http.Request, host, sourceIP string) { + log.WithFields(logrus.Fields{ + "handler": "source_ip_rate_limiter", + "correlation_id": correlation.ExtractFromContext(r.Context()), + "req_scheme": r.URL.Scheme, + "req_host": r.Host, + "req_path": r.URL.Path, + "pages_domain": host, + "remote_addr": r.RemoteAddr, + "source_ip": sourceIP, + "x_forwarded_proto": r.Header.Get(headerXForwardedProto), + "x_forwarded_for": r.Header.Get(headerXForwardedFor), + "gitlab_real_ip": r.Header.Get(headerGitLabRealIP), + "rate_limiter_enabled": rateLimiterEnabled(), + "rate_limiter_limit_per_second": rl.sourceIPLimitPerSecond, + "rate_limiter_burst_size": rl.sourceIPBurstSize, + }). // TODO: change to Debug with https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629 + Info("source IP hit rate limit") +} + +// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629 +func rateLimiterEnabled() bool { + return os.Getenv("FF_ENABLE_RATE_LIMITER") == "true" +} diff --git a/internal/ratelimiter/middleware_test.go b/internal/ratelimiter/middleware_test.go new file mode 100644 index 00000000..2e51fcad --- /dev/null +++ b/internal/ratelimiter/middleware_test.go @@ -0,0 +1,149 @@ +package ratelimiter + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/prometheus/client_golang/prometheus/testutil" + testlog "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers" +) + +const ( + remoteAddr = "192.168.1.1" +) + +var next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) +}) + +func TestSourceIPLimiterWithDifferentLimits(t *testing.T) { + hook := testlog.NewGlobal() + testhelpers.SetEnvironmentVariable(t, testhelpers.FFEnableRateLimiter, "true") + blocked, cachedEntries, cacheReqs := newTestMetrics(t) + + for tn, tc := range sharedTestCases { + t.Run(tn, func(t *testing.T) { + rl := New(blocked, cachedEntries, cacheReqs, + WithNow(mockNow), + WithSourceIPLimitPerSecond(tc.sourceIPLimit), + WithSourceIPBurstSize(tc.sourceIPBurstSize), + ) + + for i := 0; i < tc.reqNum; i++ { + ww := httptest.NewRecorder() + rr := httptest.NewRequest(http.MethodGet, "https://domain.gitlab.io", nil) + rr.RemoteAddr = remoteAddr + + handler := rl.SourceIPLimiter(next) + + handler.ServeHTTP(ww, rr) + res := ww.Result() + + if i < tc.sourceIPBurstSize { + require.Equal(t, http.StatusNoContent, res.StatusCode, "req: %d failed", i) + } else { + // requests should fail after reaching tc.perDomainBurstPerSecond because mockNow + // always returns the same time + require.Equal(t, http.StatusTooManyRequests, res.StatusCode, "req: %d failed", i) + b, err := io.ReadAll(res.Body) + require.NoError(t, err) + + require.Contains(t, string(b), "Too many requests.") + res.Body.Close() + + assertSourceIPLog(t, remoteAddr, hook) + } + } + }) + } +} + +func TestSourceIPLimiterDenyRequestsAfterBurst(t *testing.T) { + hook := testlog.NewGlobal() + blocked, cachedEntries, cacheReqs := newTestMetrics(t) + + tcs := map[string]struct { + enabled bool + expectedStatus int + }{ + "disabled_rate_limit_http": { + enabled: false, + expectedStatus: http.StatusNoContent, + }, + "enabled_rate_limit_http_blocks": { + enabled: true, + expectedStatus: http.StatusTooManyRequests, + }, + } + + for tn, tc := range tcs { + t.Run(tn, func(t *testing.T) { + rl := New(blocked, cachedEntries, cacheReqs, + WithNow(mockNow), + WithSourceIPLimitPerSecond(1), + WithSourceIPBurstSize(1), + ) + + for i := 0; i < 5; i++ { + ww := httptest.NewRecorder() + rr := httptest.NewRequest(http.MethodGet, "http://gitlab.com", nil) + if tc.enabled { + testhelpers.SetEnvironmentVariable(t, testhelpers.FFEnableRateLimiter, "true") + } else { + testhelpers.SetEnvironmentVariable(t, testhelpers.FFEnableRateLimiter, "false") + } + + rr.RemoteAddr = remoteAddr + + // middleware is evaluated in reverse order + handler := rl.SourceIPLimiter(next) + + handler.ServeHTTP(ww, rr) + res := ww.Result() + + if i == 0 { + require.Equal(t, http.StatusNoContent, res.StatusCode) + continue + } + + // burst is 1 and limit is 1 per second, all subsequent requests should fail + require.Equal(t, tc.expectedStatus, res.StatusCode) + assertSourceIPLog(t, remoteAddr, hook) + } + + blockedCount := testutil.ToFloat64(blocked.WithLabelValues("true")) + if tc.enabled { + require.Equal(t, float64(4), blockedCount, "blocked count") + } else { + require.Equal(t, float64(0), blockedCount, "blocked count") + } + blocked.Reset() + + cachedCount := testutil.ToFloat64(cachedEntries.WithLabelValues("source_ip")) + require.Equal(t, float64(1), cachedCount, "cached count") + cachedEntries.Reset() + + cacheReqMiss := testutil.ToFloat64(cacheReqs.WithLabelValues("source_ip", "miss")) + require.Equal(t, float64(1), cacheReqMiss, "miss count") + cacheReqHit := testutil.ToFloat64(cacheReqs.WithLabelValues("source_ip", "hit")) + require.Equal(t, float64(4), cacheReqHit, "hit count") + cacheReqs.Reset() + }) + } +} + +func assertSourceIPLog(t *testing.T, remoteAddr string, hook *testlog.Hook) { + t.Helper() + + require.NotNil(t, hook.LastEntry()) + + // source_ip that was rate limited + require.Equal(t, remoteAddr, hook.LastEntry().Data["source_ip"]) + + hook.Reset() +} diff --git a/internal/ratelimiter/ratelimiter.go b/internal/ratelimiter/ratelimiter.go index e1cf076d..1359b19c 100644 --- a/internal/ratelimiter/ratelimiter.go +++ b/internal/ratelimiter/ratelimiter.go @@ -36,23 +36,24 @@ type RateLimiter struct { now func() time.Time sourceIPLimitPerSecond float64 sourceIPBurstSize int + sourceIPBlockedCount *prometheus.GaugeVec sourceIPCache *lru.Cache // TODO: add domainCache https://gitlab.com/gitlab-org/gitlab-pages/-/issues/630 } // New creates a new RateLimiter with default values that can be configured via Option functions -func New(opts ...Option) *RateLimiter { +func New(blockCountMetric, cachedEntriesMetric *prometheus.GaugeVec, cacheRequestsMetric *prometheus.CounterVec, opts ...Option) *RateLimiter { rl := &RateLimiter{ now: time.Now, sourceIPLimitPerSecond: DefaultSourceIPLimitPerSecond, sourceIPBurstSize: DefaultSourceIPBurstSize, + sourceIPBlockedCount: blockCountMetric, sourceIPCache: lru.New( "source_ip", defaultSourceIPItems, defaultSourceIPExpirationInterval, - // TODO: @jaime to add proper metrics in subsequent MR - prometheus.NewGaugeVec(prometheus.GaugeOpts{}, []string{"op"}), - prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"op", "cache"}), + cachedEntriesMetric, + cacheRequestsMetric, ), } diff --git a/internal/ratelimiter/ratelimiter_test.go b/internal/ratelimiter/ratelimiter_test.go index cdf12fe6..03393fb0 100644 --- a/internal/ratelimiter/ratelimiter_test.go +++ b/internal/ratelimiter/ratelimiter_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" ) @@ -17,40 +18,41 @@ func mockNow() time.Time { return validTime } +var sharedTestCases = map[string]struct { + sourceIPLimit float64 + sourceIPBurstSize int + reqNum int +}{ + "one_request_per_second": { + sourceIPLimit: 1, + sourceIPBurstSize: 1, + reqNum: 2, + }, + "one_request_per_second_but_big_bucket": { + sourceIPLimit: 1, + sourceIPBurstSize: 10, + reqNum: 11, + }, + "three_req_per_second_bucket_size_one": { + sourceIPLimit: 3, + sourceIPBurstSize: 1, // max burst 1 means 1 at a time + reqNum: 3, + }, + "10_requests_per_second": { + sourceIPLimit: 10, + sourceIPBurstSize: 10, + reqNum: 11, + }, +} + func TestSourceIPAllowed(t *testing.T) { t.Parallel() - tcs := map[string]struct { - now string - sourceIPLimit float64 - sourceIPBurstSize int - reqNum int - }{ - "one_request_per_second": { - sourceIPLimit: 1, - sourceIPBurstSize: 1, - reqNum: 2, - }, - "one_request_per_second_but_big_bucket": { - sourceIPLimit: 1, - sourceIPBurstSize: 10, - reqNum: 11, - }, - "three_req_per_second_bucket_size_one": { - sourceIPLimit: 3, - sourceIPBurstSize: 1, // max burst 1 means 1 at a time - reqNum: 3, - }, - "10_requests_per_second": { - sourceIPLimit: 10, - sourceIPBurstSize: 10, - reqNum: 11, - }, - } + blocked, cachedEntries, cacheReqs := newTestMetrics(t) - for tn, tc := range tcs { + for tn, tc := range sharedTestCases { t.Run(tn, func(t *testing.T) { - rl := New( + rl := New(blocked, cachedEntries, cacheReqs, WithNow(mockNow), WithSourceIPLimitPerSecond(tc.sourceIPLimit), WithSourceIPBurstSize(tc.sourceIPBurstSize), @@ -72,7 +74,9 @@ func TestSourceIPAllowed(t *testing.T) { func TestSingleRateLimiterWithMultipleSourceIPs(t *testing.T) { rate := 10 * time.Millisecond - rl := New( + blocked, cachedEntries, cacheReqs := newTestMetrics(t) + + rl := New(blocked, cachedEntries, cacheReqs, WithSourceIPLimitPerSecond(float64(1/rate)), WithSourceIPBurstSize(1), ) @@ -105,3 +109,24 @@ func TestSingleRateLimiterWithMultipleSourceIPs(t *testing.T) { wg.Wait() } + +func newTestMetrics(t *testing.T) (*prometheus.GaugeVec, *prometheus.GaugeVec, *prometheus.CounterVec) { + t.Helper() + + blockedGauge := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: t.Name(), + }, + []string{"enforced"}, + ) + + cachedEntries := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: t.Name(), + }, []string{"op"}) + + cacheReqs := prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: t.Name(), + }, []string{"op", "cache"}) + + return blockedGauge, cachedEntries, cacheReqs +} diff --git a/internal/request/request.go b/internal/request/request.go index cbda16e5..77cc4a76 100644 --- a/internal/request/request.go +++ b/internal/request/request.go @@ -55,3 +55,13 @@ func GetHostWithoutPort(r *http.Request) string { return host } + +// GetRemoteAddrWithoutPort strips the port from the r.RemoteAddr +func GetRemoteAddrWithoutPort(r *http.Request) string { + remoteAddr, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + + return remoteAddr +} diff --git a/internal/testhelpers/testhelpers.go b/internal/testhelpers/testhelpers.go index 3ec97a79..de48cd7a 100644 --- a/internal/testhelpers/testhelpers.go +++ b/internal/testhelpers/testhelpers.go @@ -13,6 +13,10 @@ import ( "github.com/stretchr/testify/require" ) +// FFEnableRateLimiter enforces ratelimiter package to drop requests +// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629 +const FFEnableRateLimiter = "FF_ENABLE_RATE_LIMITER" + // AssertHTTP404 asserts handler returns 404 with provided str body func AssertHTTP404(t *testing.T, handler http.HandlerFunc, mode, url string, values url.Values, str interface{}) { w := httptest.NewRecorder() @@ -77,3 +81,17 @@ func Getwd(t *testing.T) string { return wd } + +// SetEnvironmentVariable for testing, restoring the original value on t.Cleanup +func SetEnvironmentVariable(t *testing.T, key, value string) { + t.Helper() + + orig := os.Getenv(key) + + err := os.Setenv(key, value) + require.NoError(t, err) + + t.Cleanup(func() { + os.Setenv(FFEnableRateLimiter, orig) + }) +} diff --git a/metrics/metrics.go b/metrics/metrics.go index b4ac9415..23962dc4 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -184,6 +184,34 @@ var ( Help: "The number of backlogged connections waiting on concurrency limit.", }, ) + + // RateLimitSourceIPCacheRequests is the number of cache hits/misses + RateLimitSourceIPCacheRequests = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_pages_rate_limit_source_ip_cache_requests", + Help: "The number of source_ip cache hits/misses in the rate limiter", + }, + []string{"op", "cache"}, + ) + + // RateLimitSourceIPCachedEntries is the number of entries in the cache + RateLimitSourceIPCachedEntries = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "gitlab_pages_rate_limit_source_ip_cached_entries", + Help: "The number of entries in the cache", + }, + []string{"op"}, + ) + + // RateLimitSourceIPBlockedCount is the number of source IPs that have been blocked by the + // source IP rate limiter + RateLimitSourceIPBlockedCount = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "gitlab_pages_rate_limit_source_ip_blocked_count", + Help: "The number of source IP addresses that have been blocked by the rate limiter", + }, + []string{"enforced"}, + ) ) // MustRegister collectors with the Prometheus client @@ -211,5 +239,8 @@ func MustRegister() { LimitListenerMaxConns, LimitListenerConcurrentConns, LimitListenerWaitingConns, + RateLimitSourceIPCacheRequests, + RateLimitSourceIPCachedEntries, + RateLimitSourceIPBlockedCount, ) } diff --git a/test/acceptance/helpers_test.go b/test/acceptance/helpers_test.go index e2e1c1d0..4abbe33d 100644 --- a/test/acceptance/helpers_test.go +++ b/test/acceptance/helpers_test.go @@ -399,6 +399,21 @@ func GetCompressedPageFromListener(t *testing.T, spec ListenSpec, host, urlsuffi return DoPagesRequest(t, spec, req) } +func GetPageFromListenerWithHeaders(t *testing.T, spec ListenSpec, host, urlSuffix string, header http.Header) (*http.Response, error) { + t.Helper() + + url := spec.URL(urlSuffix) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + + req.Host = host + req.Header = header + + return DoPagesRequest(t, spec, req) +} + func GetProxiedPageFromListener(t *testing.T, spec ListenSpec, host, xForwardedHost, urlsuffix string) (*http.Response, error) { url := spec.URL(urlsuffix) req, err := http.NewRequest("GET", url, nil) diff --git a/test/acceptance/ratelimiter_test.go b/test/acceptance/ratelimiter_test.go new file mode 100644 index 00000000..2986d46b --- /dev/null +++ b/test/acceptance/ratelimiter_test.go @@ -0,0 +1,128 @@ +package acceptance_test + +import ( + "fmt" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers" +) + +func TestSourceIPRateLimitMiddleware(t *testing.T) { + testhelpers.SetEnvironmentVariable(t, testhelpers.FFEnableRateLimiter, "true") + + tcs := map[string]struct { + listener ListenSpec + rateLimit float64 + rateBurst string + blockedIP string + header http.Header + expectFail bool + sleep time.Duration + }{ + "http_slow_requests_should_not_be_blocked": { + listener: httpListener, + rateLimit: 1000, + // RunPagesProcess makes one request, so we need to allow a burst of 2 + // because r.RemoteAddr == 127.0.0.1 and X-Forwarded-For is ignored for non-proxy requests + rateBurst: "2", + sleep: 10 * time.Millisecond, + }, + "https_slow_requests_should_not_be_blocked": { + listener: httpsListener, + rateLimit: 1000, + rateBurst: "2", + sleep: 10 * time.Millisecond, + }, + "proxy_slow_requests_should_not_be_blocked": { + listener: proxyListener, + rateLimit: 1000, + // listen-proxy uses X-Forwarded-For + rateBurst: "1", + header: http.Header{ + "X-Forwarded-For": []string{"172.16.123.1"}, + "X-Forwarded-Host": []string{"group.gitlab-example.com"}, + }, + sleep: 10 * time.Millisecond, + }, + "proxyv2_slow_requests_should_not_be_blocked": { + listener: httpsProxyv2Listener, + rateLimit: 1000, + rateBurst: "2", + sleep: 10 * time.Millisecond, + }, + "http_fast_requests_blocked_after_burst": { + listener: httpListener, + rateLimit: 1, + rateBurst: "2", + expectFail: true, + blockedIP: "127.0.0.1", + }, + "https_fast_requests_blocked_after_burst": { + listener: httpsListener, + rateLimit: 1, + rateBurst: "2", + expectFail: true, + blockedIP: "127.0.0.1", + }, + "proxy_fast_requests_blocked_after_burst": { + listener: proxyListener, + rateLimit: 1, + rateBurst: "1", + header: http.Header{ + "X-Forwarded-For": []string{"172.16.123.1"}, + "X-Forwarded-Host": []string{"group.gitlab-example.com"}, + }, + expectFail: true, + blockedIP: "172.16.123.1", + }, + "proxyv2_fast_requests_blocked_after_burst": { + listener: httpsProxyv2Listener, + rateLimit: 1, + rateBurst: "2", + expectFail: true, + // use TestProxyv2Client SourceIP + blockedIP: "10.1.1.1", + }, + } + + for tn, tc := range tcs { + t.Run(tn, func(t *testing.T) { + logBuf := RunPagesProcess(t, + withListeners([]ListenSpec{tc.listener}), + withExtraArgument("rate-limit-source-ip", fmt.Sprint(tc.rateLimit)), + withExtraArgument("rate-limit-source-ip-burst", tc.rateBurst), + ) + + for i := 0; i < 5; i++ { + rsp, err := GetPageFromListenerWithHeaders(t, tc.listener, "group.gitlab-example.com", "project/", tc.header) + require.NoError(t, err) + rsp.Body.Close() + + if tc.expectFail && i >= int(tc.rateLimit) { + require.Equal(t, http.StatusTooManyRequests, rsp.StatusCode, "group.gitlab-example.com request: %d failed", i) + assertLogFound(t, logBuf, []string{"source IP hit rate limit", "\"source_ip\":\"" + tc.blockedIP + "\""}) + continue + } + + require.Equal(t, http.StatusOK, rsp.StatusCode, "request: %d failed", i) + time.Sleep(tc.sleep) + } + }) + } +} + +func assertLogFound(t *testing.T, logBuf *LogCaptureBuffer, expectedLogs []string) { + t.Helper() + + // give the process enough time to write the log message + require.Eventually(t, func() bool { + for _, e := range expectedLogs { + require.Contains(t, logBuf.String(), e, "log mismatch") + } + return true + }, 100*time.Millisecond, 10*time.Millisecond) +} |