Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-pages.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--app.go13
-rw-r--r--internal/config/config.go11
-rw-r--r--internal/config/flags.go2
-rw-r--r--internal/httperrors/httperrors.go13
-rw-r--r--internal/ratelimiter/middleware.go66
-rw-r--r--internal/ratelimiter/middleware_test.go149
-rw-r--r--internal/ratelimiter/ratelimiter.go9
-rw-r--r--internal/ratelimiter/ratelimiter_test.go85
-rw-r--r--internal/request/request.go10
-rw-r--r--internal/testhelpers/testhelpers.go18
-rw-r--r--metrics/metrics.go31
-rw-r--r--test/acceptance/helpers_test.go15
-rw-r--r--test/acceptance/ratelimiter_test.go128
13 files changed, 516 insertions, 34 deletions
diff --git a/app.go b/app.go
index 23e8a3cd..b97ec0bc 100644
--- a/app.go
+++ b/app.go
@@ -32,6 +32,7 @@ import (
"gitlab.com/gitlab-org/gitlab-pages/internal/httperrors"
"gitlab.com/gitlab-org/gitlab-pages/internal/logging"
"gitlab.com/gitlab-org/gitlab-pages/internal/netutil"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/ratelimiter"
"gitlab.com/gitlab-org/gitlab-pages/internal/rejectmethods"
"gitlab.com/gitlab-org/gitlab-pages/internal/request"
"gitlab.com/gitlab-org/gitlab-pages/internal/routing"
@@ -262,6 +263,18 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) {
handler = routing.NewMiddleware(handler, a.source)
+ if a.config.RateLimit.SourceIPLimitPerSecond > 0 {
+ rl := ratelimiter.New(
+ metrics.RateLimitSourceIPBlockedCount,
+ metrics.RateLimitSourceIPCachedEntries,
+ metrics.RateLimitSourceIPCacheRequests,
+ ratelimiter.WithSourceIPLimitPerSecond(a.config.RateLimit.SourceIPLimitPerSecond),
+ ratelimiter.WithSourceIPBurstSize(a.config.RateLimit.SourceIPBurst),
+ )
+
+ handler = rl.SourceIPLimiter(handler)
+ }
+
// Health Check
handler, err = a.healthCheckMiddleware(handler)
if err != nil {
diff --git a/internal/config/config.go b/internal/config/config.go
index 94c22328..3e03f7d6 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -16,6 +16,7 @@ import (
// Config stores all the config options relevant to GitLab Pages.
type Config struct {
General General
+ RateLimit RateLimit
ArtifactsServer ArtifactsServer
Authentication Auth
GitLab GitLab
@@ -62,6 +63,12 @@ type General struct {
CustomHeaders []string
}
+// RateLimit config struct
+type RateLimit struct {
+ SourceIPLimitPerSecond float64
+ SourceIPBurst int
+}
+
// ArtifactsServer groups settings related to configuring Artifacts
// server
type ArtifactsServer struct {
@@ -184,6 +191,10 @@ func loadConfig() (*Config, error) {
CustomHeaders: header.Split(),
ShowVersion: *showVersion,
},
+ RateLimit: RateLimit{
+ SourceIPLimitPerSecond: *rateLimitSourceIP,
+ SourceIPBurst: *rateLimitSourceIPBurst,
+ },
GitLab: GitLab{
ClientHTTPTimeout: *gitlabClientHTTPTimeout,
JWTTokenExpiration: *gitlabClientJWTExpiry,
diff --git a/internal/config/flags.go b/internal/config/flags.go
index 52b7be18..c61447c7 100644
--- a/internal/config/flags.go
+++ b/internal/config/flags.go
@@ -15,6 +15,8 @@ var (
_ = flag.Bool("use-http2", true, "DEPRECATED: HTTP2 is always enabled for pages")
pagesRoot = flag.String("pages-root", "shared/pages", "The directory where pages are stored")
pagesDomain = flag.String("pages-domain", "gitlab-example.com", "The domain to serve static pages")
+ rateLimitSourceIP = flag.Float64("rate-limit-source-ip", 0.0, "Rate limit per source IP in number of requests per second, 0 means is disabled")
+ rateLimitSourceIPBurst = flag.Int("rate-limit-source-ip-burst", 100, "Rate limit per source IP maximum burst allowed per second")
artifactsServer = flag.String("artifacts-server", "", "API URL to proxy artifact requests to, e.g.: 'https://gitlab.com/api/v4'")
artifactsServerTimeout = flag.Int("artifacts-server-timeout", 10, "Timeout (in seconds) for a proxied request to the artifacts server")
pagesStatus = flag.String("pages-status", "", "The url path for a status page, e.g., /@status")
diff --git a/internal/httperrors/httperrors.go b/internal/httperrors/httperrors.go
index ed56ee10..8e61d590 100644
--- a/internal/httperrors/httperrors.go
+++ b/internal/httperrors/httperrors.go
@@ -34,6 +34,14 @@ var (
<p>Make sure the address is correct and that the page hasn't moved.</p>
<p>Please contact your GitLab administrator if you think this is a mistake.</p>`,
}
+
+ content429 = content{
+ http.StatusTooManyRequests,
+ "Too many requests (429)",
+ "429",
+ "Too many requests.",
+ `<p>The resource that you are attempting to access is being rate limited.</p>`,
+ }
content500 = content{
http.StatusInternalServerError,
"Something went wrong (500)",
@@ -176,6 +184,11 @@ func Serve404(w http.ResponseWriter) {
serveErrorPage(w, content404)
}
+// Serve429 returns a 429 error response / HTML page to the http.ResponseWriter
+func Serve429(w http.ResponseWriter) {
+ serveErrorPage(w, content429)
+}
+
// Serve500 returns a 500 error response / HTML page to the http.ResponseWriter
func Serve500(w http.ResponseWriter) {
serveErrorPage(w, content500)
diff --git a/internal/ratelimiter/middleware.go b/internal/ratelimiter/middleware.go
new file mode 100644
index 00000000..0cd5b81e
--- /dev/null
+++ b/internal/ratelimiter/middleware.go
@@ -0,0 +1,66 @@
+package ratelimiter
+
+import (
+ "net/http"
+ "os"
+
+ "github.com/sirupsen/logrus"
+ "gitlab.com/gitlab-org/labkit/correlation"
+ "gitlab.com/gitlab-org/labkit/log"
+
+ "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/request"
+)
+
+const (
+ headerGitLabRealIP = "GitLab-Real-IP"
+ headerXForwardedFor = "X-Forwarded-For"
+ headerXForwardedProto = "X-Forwarded-Proto"
+)
+
+// SourceIPLimiter returns middleware for rate-limiting clients based on their IP
+func (rl *RateLimiter) SourceIPLimiter(handler http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ host, sourceIP := request.GetHostWithoutPort(r), request.GetRemoteAddrWithoutPort(r)
+ if !rl.SourceIPAllowed(sourceIP) {
+ rl.logSourceIP(r, host, sourceIP)
+
+ // Only drop requests once FF_ENABLE_RATE_LIMITER is enabled
+ // https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629
+ if rateLimiterEnabled() {
+ rl.sourceIPBlockedCount.WithLabelValues("true").Inc()
+ httperrors.Serve429(w)
+ return
+ }
+
+ rl.sourceIPBlockedCount.WithLabelValues("false").Inc()
+ }
+
+ handler.ServeHTTP(w, r)
+ })
+}
+
+func (rl *RateLimiter) logSourceIP(r *http.Request, host, sourceIP string) {
+ log.WithFields(logrus.Fields{
+ "handler": "source_ip_rate_limiter",
+ "correlation_id": correlation.ExtractFromContext(r.Context()),
+ "req_scheme": r.URL.Scheme,
+ "req_host": r.Host,
+ "req_path": r.URL.Path,
+ "pages_domain": host,
+ "remote_addr": r.RemoteAddr,
+ "source_ip": sourceIP,
+ "x_forwarded_proto": r.Header.Get(headerXForwardedProto),
+ "x_forwarded_for": r.Header.Get(headerXForwardedFor),
+ "gitlab_real_ip": r.Header.Get(headerGitLabRealIP),
+ "rate_limiter_enabled": rateLimiterEnabled(),
+ "rate_limiter_limit_per_second": rl.sourceIPLimitPerSecond,
+ "rate_limiter_burst_size": rl.sourceIPBurstSize,
+ }). // TODO: change to Debug with https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629
+ Info("source IP hit rate limit")
+}
+
+// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629
+func rateLimiterEnabled() bool {
+ return os.Getenv("FF_ENABLE_RATE_LIMITER") == "true"
+}
diff --git a/internal/ratelimiter/middleware_test.go b/internal/ratelimiter/middleware_test.go
new file mode 100644
index 00000000..2e51fcad
--- /dev/null
+++ b/internal/ratelimiter/middleware_test.go
@@ -0,0 +1,149 @@
+package ratelimiter
+
+import (
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/prometheus/client_golang/prometheus/testutil"
+ testlog "github.com/sirupsen/logrus/hooks/test"
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers"
+)
+
+const (
+ remoteAddr = "192.168.1.1"
+)
+
+var next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusNoContent)
+})
+
+func TestSourceIPLimiterWithDifferentLimits(t *testing.T) {
+ hook := testlog.NewGlobal()
+ testhelpers.SetEnvironmentVariable(t, testhelpers.FFEnableRateLimiter, "true")
+ blocked, cachedEntries, cacheReqs := newTestMetrics(t)
+
+ for tn, tc := range sharedTestCases {
+ t.Run(tn, func(t *testing.T) {
+ rl := New(blocked, cachedEntries, cacheReqs,
+ WithNow(mockNow),
+ WithSourceIPLimitPerSecond(tc.sourceIPLimit),
+ WithSourceIPBurstSize(tc.sourceIPBurstSize),
+ )
+
+ for i := 0; i < tc.reqNum; i++ {
+ ww := httptest.NewRecorder()
+ rr := httptest.NewRequest(http.MethodGet, "https://domain.gitlab.io", nil)
+ rr.RemoteAddr = remoteAddr
+
+ handler := rl.SourceIPLimiter(next)
+
+ handler.ServeHTTP(ww, rr)
+ res := ww.Result()
+
+ if i < tc.sourceIPBurstSize {
+ require.Equal(t, http.StatusNoContent, res.StatusCode, "req: %d failed", i)
+ } else {
+ // requests should fail after reaching tc.perDomainBurstPerSecond because mockNow
+ // always returns the same time
+ require.Equal(t, http.StatusTooManyRequests, res.StatusCode, "req: %d failed", i)
+ b, err := io.ReadAll(res.Body)
+ require.NoError(t, err)
+
+ require.Contains(t, string(b), "Too many requests.")
+ res.Body.Close()
+
+ assertSourceIPLog(t, remoteAddr, hook)
+ }
+ }
+ })
+ }
+}
+
+func TestSourceIPLimiterDenyRequestsAfterBurst(t *testing.T) {
+ hook := testlog.NewGlobal()
+ blocked, cachedEntries, cacheReqs := newTestMetrics(t)
+
+ tcs := map[string]struct {
+ enabled bool
+ expectedStatus int
+ }{
+ "disabled_rate_limit_http": {
+ enabled: false,
+ expectedStatus: http.StatusNoContent,
+ },
+ "enabled_rate_limit_http_blocks": {
+ enabled: true,
+ expectedStatus: http.StatusTooManyRequests,
+ },
+ }
+
+ for tn, tc := range tcs {
+ t.Run(tn, func(t *testing.T) {
+ rl := New(blocked, cachedEntries, cacheReqs,
+ WithNow(mockNow),
+ WithSourceIPLimitPerSecond(1),
+ WithSourceIPBurstSize(1),
+ )
+
+ for i := 0; i < 5; i++ {
+ ww := httptest.NewRecorder()
+ rr := httptest.NewRequest(http.MethodGet, "http://gitlab.com", nil)
+ if tc.enabled {
+ testhelpers.SetEnvironmentVariable(t, testhelpers.FFEnableRateLimiter, "true")
+ } else {
+ testhelpers.SetEnvironmentVariable(t, testhelpers.FFEnableRateLimiter, "false")
+ }
+
+ rr.RemoteAddr = remoteAddr
+
+ // middleware is evaluated in reverse order
+ handler := rl.SourceIPLimiter(next)
+
+ handler.ServeHTTP(ww, rr)
+ res := ww.Result()
+
+ if i == 0 {
+ require.Equal(t, http.StatusNoContent, res.StatusCode)
+ continue
+ }
+
+ // burst is 1 and limit is 1 per second, all subsequent requests should fail
+ require.Equal(t, tc.expectedStatus, res.StatusCode)
+ assertSourceIPLog(t, remoteAddr, hook)
+ }
+
+ blockedCount := testutil.ToFloat64(blocked.WithLabelValues("true"))
+ if tc.enabled {
+ require.Equal(t, float64(4), blockedCount, "blocked count")
+ } else {
+ require.Equal(t, float64(0), blockedCount, "blocked count")
+ }
+ blocked.Reset()
+
+ cachedCount := testutil.ToFloat64(cachedEntries.WithLabelValues("source_ip"))
+ require.Equal(t, float64(1), cachedCount, "cached count")
+ cachedEntries.Reset()
+
+ cacheReqMiss := testutil.ToFloat64(cacheReqs.WithLabelValues("source_ip", "miss"))
+ require.Equal(t, float64(1), cacheReqMiss, "miss count")
+ cacheReqHit := testutil.ToFloat64(cacheReqs.WithLabelValues("source_ip", "hit"))
+ require.Equal(t, float64(4), cacheReqHit, "hit count")
+ cacheReqs.Reset()
+ })
+ }
+}
+
+func assertSourceIPLog(t *testing.T, remoteAddr string, hook *testlog.Hook) {
+ t.Helper()
+
+ require.NotNil(t, hook.LastEntry())
+
+ // source_ip that was rate limited
+ require.Equal(t, remoteAddr, hook.LastEntry().Data["source_ip"])
+
+ hook.Reset()
+}
diff --git a/internal/ratelimiter/ratelimiter.go b/internal/ratelimiter/ratelimiter.go
index e1cf076d..1359b19c 100644
--- a/internal/ratelimiter/ratelimiter.go
+++ b/internal/ratelimiter/ratelimiter.go
@@ -36,23 +36,24 @@ type RateLimiter struct {
now func() time.Time
sourceIPLimitPerSecond float64
sourceIPBurstSize int
+ sourceIPBlockedCount *prometheus.GaugeVec
sourceIPCache *lru.Cache
// TODO: add domainCache https://gitlab.com/gitlab-org/gitlab-pages/-/issues/630
}
// New creates a new RateLimiter with default values that can be configured via Option functions
-func New(opts ...Option) *RateLimiter {
+func New(blockCountMetric, cachedEntriesMetric *prometheus.GaugeVec, cacheRequestsMetric *prometheus.CounterVec, opts ...Option) *RateLimiter {
rl := &RateLimiter{
now: time.Now,
sourceIPLimitPerSecond: DefaultSourceIPLimitPerSecond,
sourceIPBurstSize: DefaultSourceIPBurstSize,
+ sourceIPBlockedCount: blockCountMetric,
sourceIPCache: lru.New(
"source_ip",
defaultSourceIPItems,
defaultSourceIPExpirationInterval,
- // TODO: @jaime to add proper metrics in subsequent MR
- prometheus.NewGaugeVec(prometheus.GaugeOpts{}, []string{"op"}),
- prometheus.NewCounterVec(prometheus.CounterOpts{}, []string{"op", "cache"}),
+ cachedEntriesMetric,
+ cacheRequestsMetric,
),
}
diff --git a/internal/ratelimiter/ratelimiter_test.go b/internal/ratelimiter/ratelimiter_test.go
index cdf12fe6..03393fb0 100644
--- a/internal/ratelimiter/ratelimiter_test.go
+++ b/internal/ratelimiter/ratelimiter_test.go
@@ -5,6 +5,7 @@ import (
"testing"
"time"
+ "github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/require"
)
@@ -17,40 +18,41 @@ func mockNow() time.Time {
return validTime
}
+var sharedTestCases = map[string]struct {
+ sourceIPLimit float64
+ sourceIPBurstSize int
+ reqNum int
+}{
+ "one_request_per_second": {
+ sourceIPLimit: 1,
+ sourceIPBurstSize: 1,
+ reqNum: 2,
+ },
+ "one_request_per_second_but_big_bucket": {
+ sourceIPLimit: 1,
+ sourceIPBurstSize: 10,
+ reqNum: 11,
+ },
+ "three_req_per_second_bucket_size_one": {
+ sourceIPLimit: 3,
+ sourceIPBurstSize: 1, // max burst 1 means 1 at a time
+ reqNum: 3,
+ },
+ "10_requests_per_second": {
+ sourceIPLimit: 10,
+ sourceIPBurstSize: 10,
+ reqNum: 11,
+ },
+}
+
func TestSourceIPAllowed(t *testing.T) {
t.Parallel()
- tcs := map[string]struct {
- now string
- sourceIPLimit float64
- sourceIPBurstSize int
- reqNum int
- }{
- "one_request_per_second": {
- sourceIPLimit: 1,
- sourceIPBurstSize: 1,
- reqNum: 2,
- },
- "one_request_per_second_but_big_bucket": {
- sourceIPLimit: 1,
- sourceIPBurstSize: 10,
- reqNum: 11,
- },
- "three_req_per_second_bucket_size_one": {
- sourceIPLimit: 3,
- sourceIPBurstSize: 1, // max burst 1 means 1 at a time
- reqNum: 3,
- },
- "10_requests_per_second": {
- sourceIPLimit: 10,
- sourceIPBurstSize: 10,
- reqNum: 11,
- },
- }
+ blocked, cachedEntries, cacheReqs := newTestMetrics(t)
- for tn, tc := range tcs {
+ for tn, tc := range sharedTestCases {
t.Run(tn, func(t *testing.T) {
- rl := New(
+ rl := New(blocked, cachedEntries, cacheReqs,
WithNow(mockNow),
WithSourceIPLimitPerSecond(tc.sourceIPLimit),
WithSourceIPBurstSize(tc.sourceIPBurstSize),
@@ -72,7 +74,9 @@ func TestSourceIPAllowed(t *testing.T) {
func TestSingleRateLimiterWithMultipleSourceIPs(t *testing.T) {
rate := 10 * time.Millisecond
- rl := New(
+ blocked, cachedEntries, cacheReqs := newTestMetrics(t)
+
+ rl := New(blocked, cachedEntries, cacheReqs,
WithSourceIPLimitPerSecond(float64(1/rate)),
WithSourceIPBurstSize(1),
)
@@ -105,3 +109,24 @@ func TestSingleRateLimiterWithMultipleSourceIPs(t *testing.T) {
wg.Wait()
}
+
+func newTestMetrics(t *testing.T) (*prometheus.GaugeVec, *prometheus.GaugeVec, *prometheus.CounterVec) {
+ t.Helper()
+
+ blockedGauge := prometheus.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Name: t.Name(),
+ },
+ []string{"enforced"},
+ )
+
+ cachedEntries := prometheus.NewGaugeVec(prometheus.GaugeOpts{
+ Name: t.Name(),
+ }, []string{"op"})
+
+ cacheReqs := prometheus.NewCounterVec(prometheus.CounterOpts{
+ Name: t.Name(),
+ }, []string{"op", "cache"})
+
+ return blockedGauge, cachedEntries, cacheReqs
+}
diff --git a/internal/request/request.go b/internal/request/request.go
index cbda16e5..77cc4a76 100644
--- a/internal/request/request.go
+++ b/internal/request/request.go
@@ -55,3 +55,13 @@ func GetHostWithoutPort(r *http.Request) string {
return host
}
+
+// GetRemoteAddrWithoutPort strips the port from the r.RemoteAddr
+func GetRemoteAddrWithoutPort(r *http.Request) string {
+ remoteAddr, _, err := net.SplitHostPort(r.RemoteAddr)
+ if err != nil {
+ return r.RemoteAddr
+ }
+
+ return remoteAddr
+}
diff --git a/internal/testhelpers/testhelpers.go b/internal/testhelpers/testhelpers.go
index 3ec97a79..de48cd7a 100644
--- a/internal/testhelpers/testhelpers.go
+++ b/internal/testhelpers/testhelpers.go
@@ -13,6 +13,10 @@ import (
"github.com/stretchr/testify/require"
)
+// FFEnableRateLimiter enforces ratelimiter package to drop requests
+// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629
+const FFEnableRateLimiter = "FF_ENABLE_RATE_LIMITER"
+
// AssertHTTP404 asserts handler returns 404 with provided str body
func AssertHTTP404(t *testing.T, handler http.HandlerFunc, mode, url string, values url.Values, str interface{}) {
w := httptest.NewRecorder()
@@ -77,3 +81,17 @@ func Getwd(t *testing.T) string {
return wd
}
+
+// SetEnvironmentVariable for testing, restoring the original value on t.Cleanup
+func SetEnvironmentVariable(t *testing.T, key, value string) {
+ t.Helper()
+
+ orig := os.Getenv(key)
+
+ err := os.Setenv(key, value)
+ require.NoError(t, err)
+
+ t.Cleanup(func() {
+ os.Setenv(FFEnableRateLimiter, orig)
+ })
+}
diff --git a/metrics/metrics.go b/metrics/metrics.go
index b4ac9415..23962dc4 100644
--- a/metrics/metrics.go
+++ b/metrics/metrics.go
@@ -184,6 +184,34 @@ var (
Help: "The number of backlogged connections waiting on concurrency limit.",
},
)
+
+ // RateLimitSourceIPCacheRequests is the number of cache hits/misses
+ RateLimitSourceIPCacheRequests = prometheus.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gitlab_pages_rate_limit_source_ip_cache_requests",
+ Help: "The number of source_ip cache hits/misses in the rate limiter",
+ },
+ []string{"op", "cache"},
+ )
+
+ // RateLimitSourceIPCachedEntries is the number of entries in the cache
+ RateLimitSourceIPCachedEntries = prometheus.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Name: "gitlab_pages_rate_limit_source_ip_cached_entries",
+ Help: "The number of entries in the cache",
+ },
+ []string{"op"},
+ )
+
+ // RateLimitSourceIPBlockedCount is the number of source IPs that have been blocked by the
+ // source IP rate limiter
+ RateLimitSourceIPBlockedCount = prometheus.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Name: "gitlab_pages_rate_limit_source_ip_blocked_count",
+ Help: "The number of source IP addresses that have been blocked by the rate limiter",
+ },
+ []string{"enforced"},
+ )
)
// MustRegister collectors with the Prometheus client
@@ -211,5 +239,8 @@ func MustRegister() {
LimitListenerMaxConns,
LimitListenerConcurrentConns,
LimitListenerWaitingConns,
+ RateLimitSourceIPCacheRequests,
+ RateLimitSourceIPCachedEntries,
+ RateLimitSourceIPBlockedCount,
)
}
diff --git a/test/acceptance/helpers_test.go b/test/acceptance/helpers_test.go
index e2e1c1d0..4abbe33d 100644
--- a/test/acceptance/helpers_test.go
+++ b/test/acceptance/helpers_test.go
@@ -399,6 +399,21 @@ func GetCompressedPageFromListener(t *testing.T, spec ListenSpec, host, urlsuffi
return DoPagesRequest(t, spec, req)
}
+func GetPageFromListenerWithHeaders(t *testing.T, spec ListenSpec, host, urlSuffix string, header http.Header) (*http.Response, error) {
+ t.Helper()
+
+ url := spec.URL(urlSuffix)
+ req, err := http.NewRequest("GET", url, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ req.Host = host
+ req.Header = header
+
+ return DoPagesRequest(t, spec, req)
+}
+
func GetProxiedPageFromListener(t *testing.T, spec ListenSpec, host, xForwardedHost, urlsuffix string) (*http.Response, error) {
url := spec.URL(urlsuffix)
req, err := http.NewRequest("GET", url, nil)
diff --git a/test/acceptance/ratelimiter_test.go b/test/acceptance/ratelimiter_test.go
new file mode 100644
index 00000000..2986d46b
--- /dev/null
+++ b/test/acceptance/ratelimiter_test.go
@@ -0,0 +1,128 @@
+package acceptance_test
+
+import (
+ "fmt"
+ "net/http"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers"
+)
+
+func TestSourceIPRateLimitMiddleware(t *testing.T) {
+ testhelpers.SetEnvironmentVariable(t, testhelpers.FFEnableRateLimiter, "true")
+
+ tcs := map[string]struct {
+ listener ListenSpec
+ rateLimit float64
+ rateBurst string
+ blockedIP string
+ header http.Header
+ expectFail bool
+ sleep time.Duration
+ }{
+ "http_slow_requests_should_not_be_blocked": {
+ listener: httpListener,
+ rateLimit: 1000,
+ // RunPagesProcess makes one request, so we need to allow a burst of 2
+ // because r.RemoteAddr == 127.0.0.1 and X-Forwarded-For is ignored for non-proxy requests
+ rateBurst: "2",
+ sleep: 10 * time.Millisecond,
+ },
+ "https_slow_requests_should_not_be_blocked": {
+ listener: httpsListener,
+ rateLimit: 1000,
+ rateBurst: "2",
+ sleep: 10 * time.Millisecond,
+ },
+ "proxy_slow_requests_should_not_be_blocked": {
+ listener: proxyListener,
+ rateLimit: 1000,
+ // listen-proxy uses X-Forwarded-For
+ rateBurst: "1",
+ header: http.Header{
+ "X-Forwarded-For": []string{"172.16.123.1"},
+ "X-Forwarded-Host": []string{"group.gitlab-example.com"},
+ },
+ sleep: 10 * time.Millisecond,
+ },
+ "proxyv2_slow_requests_should_not_be_blocked": {
+ listener: httpsProxyv2Listener,
+ rateLimit: 1000,
+ rateBurst: "2",
+ sleep: 10 * time.Millisecond,
+ },
+ "http_fast_requests_blocked_after_burst": {
+ listener: httpListener,
+ rateLimit: 1,
+ rateBurst: "2",
+ expectFail: true,
+ blockedIP: "127.0.0.1",
+ },
+ "https_fast_requests_blocked_after_burst": {
+ listener: httpsListener,
+ rateLimit: 1,
+ rateBurst: "2",
+ expectFail: true,
+ blockedIP: "127.0.0.1",
+ },
+ "proxy_fast_requests_blocked_after_burst": {
+ listener: proxyListener,
+ rateLimit: 1,
+ rateBurst: "1",
+ header: http.Header{
+ "X-Forwarded-For": []string{"172.16.123.1"},
+ "X-Forwarded-Host": []string{"group.gitlab-example.com"},
+ },
+ expectFail: true,
+ blockedIP: "172.16.123.1",
+ },
+ "proxyv2_fast_requests_blocked_after_burst": {
+ listener: httpsProxyv2Listener,
+ rateLimit: 1,
+ rateBurst: "2",
+ expectFail: true,
+ // use TestProxyv2Client SourceIP
+ blockedIP: "10.1.1.1",
+ },
+ }
+
+ for tn, tc := range tcs {
+ t.Run(tn, func(t *testing.T) {
+ logBuf := RunPagesProcess(t,
+ withListeners([]ListenSpec{tc.listener}),
+ withExtraArgument("rate-limit-source-ip", fmt.Sprint(tc.rateLimit)),
+ withExtraArgument("rate-limit-source-ip-burst", tc.rateBurst),
+ )
+
+ for i := 0; i < 5; i++ {
+ rsp, err := GetPageFromListenerWithHeaders(t, tc.listener, "group.gitlab-example.com", "project/", tc.header)
+ require.NoError(t, err)
+ rsp.Body.Close()
+
+ if tc.expectFail && i >= int(tc.rateLimit) {
+ require.Equal(t, http.StatusTooManyRequests, rsp.StatusCode, "group.gitlab-example.com request: %d failed", i)
+ assertLogFound(t, logBuf, []string{"source IP hit rate limit", "\"source_ip\":\"" + tc.blockedIP + "\""})
+ continue
+ }
+
+ require.Equal(t, http.StatusOK, rsp.StatusCode, "request: %d failed", i)
+ time.Sleep(tc.sleep)
+ }
+ })
+ }
+}
+
+func assertLogFound(t *testing.T, logBuf *LogCaptureBuffer, expectedLogs []string) {
+ t.Helper()
+
+ // give the process enough time to write the log message
+ require.Eventually(t, func() bool {
+ for _, e := range expectedLogs {
+ require.Contains(t, logBuf.String(), e, "log mismatch")
+ }
+ return true
+ }, 100*time.Millisecond, 10*time.Millisecond)
+}