diff options
author | Jaime Martinez <jmartinez@gitlab.com> | 2021-09-29 07:44:23 +0300 |
---|---|---|
committer | Jaime Martinez <jmartinez@gitlab.com> | 2021-09-30 06:38:48 +0300 |
commit | 3b9adb4067e88417652edc0a08f596d5196a8912 (patch) | |
tree | 45a38335a59a08dbb720568edb9b216b2d7a5210 | |
parent | db4dd9522ba7e20713cab59664b85412d5a03cbb (diff) |
feat: source IP middleware for rate limiting
-rw-r--r-- | go.mod | 1 | ||||
-rw-r--r-- | internal/ratelimiter/middleware.go | 55 | ||||
-rw-r--r-- | internal/ratelimiter/middleware_test.go | 88 |
3 files changed, 110 insertions, 34 deletions
@@ -16,6 +16,7 @@ require ( github.com/pires/go-proxyproto v0.2.0 github.com/prometheus/client_golang v1.6.0 github.com/rs/cors v1.7.0 + github.com/sebest/xff v0.0.0-20160910043805-6c115e0ffa35 github.com/sirupsen/logrus v1.7.0 github.com/stretchr/testify v1.6.1 github.com/tj/assert v0.0.3 // indirect diff --git a/internal/ratelimiter/middleware.go b/internal/ratelimiter/middleware.go index b3963c6f..39292dbb 100644 --- a/internal/ratelimiter/middleware.go +++ b/internal/ratelimiter/middleware.go @@ -5,47 +5,52 @@ import ( "net/http" "os" + "github.com/sebest/xff" "github.com/sirupsen/logrus" "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/logging" ) -// DomainRateLimiter middleware ensures that the requested domain can be served by the current +// SourceIPLimiter middleware ensures that the originating // rate limit. See -rate-limiter -func DomainRateLimiter(rl *RateLimiter) func(http.Handler) http.Handler { - return func(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - host := getHost(r) - - if !rl.DomainAllowed(host) { - logging.LogRequest(r).WithFields(logrus.Fields{ - "handler": "domain_rate_limiter", - "pages_domain": host, - "rate_limiter_enabled": rateLimiterEnabled(), - "rate_limiter_frequency": rl.perDomainFrequency, - "rate_limiter_burst_size": rl.perDomainBurstSize, - }).Info("domain hit rate limit") - - // Only drop requests once FF_ENABLE_RATE_LIMITER is enabled - if rateLimiterEnabled() { - httperrors.Serve429(w) - return - } +func (rl *RateLimiter) SourceIPLimiter(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + host, ip, https := getReqDetails(r) + + // http requests do not contain real IP information yet + if !rl.SourceIPAllowed(ip) && https { + logging.LogRequest(r).WithFields(logrus.Fields{ + "handler": "source_ip_rate_limiter", + "pages_domain": host, + "pages_https": https, + "source_ip": ip, + "rate_limiter_enabled": rateLimiterEnabled(), + "rate_limiter_limit_per_second": rl.sourceIPLimitPerSecond, + "rate_limiter_burst_size": rl.sourceIPBurstSize, + }).Info("source IP hit rate limit") + + // Only drop requests once FF_ENABLE_RATE_LIMITER is enabled + if rateLimiterEnabled() { + httperrors.Serve429(w) + return } + } - handler.ServeHTTP(w, r) - }) - } + handler.ServeHTTP(w, r) + }) } -func getHost(r *http.Request) string { +func getReqDetails(r *http.Request) (string, string, bool) { host, _, err := net.SplitHostPort(r.Host) if err != nil { host = r.Host } - return host + https := r.URL.Scheme == "https" + ip := xff.GetRemoteAddr(r) + + return host, ip, https } func rateLimiterEnabled() bool { diff --git a/internal/ratelimiter/middleware_test.go b/internal/ratelimiter/middleware_test.go index b15698fb..ef9b23b3 100644 --- a/internal/ratelimiter/middleware_test.go +++ b/internal/ratelimiter/middleware_test.go @@ -10,10 +10,11 @@ import ( "github.com/stretchr/testify/require" ) -func TestDomainRateLimiter(t *testing.T) { - next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNoContent) - }) +var next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) +}) + +func TestSourceIPLimiter(t *testing.T) { enableRateLimiter(t) @@ -21,19 +22,21 @@ func TestDomainRateLimiter(t *testing.T) { t.Run(tn, func(t *testing.T) { rl := New( WithNow(mockNow), - WithPerDomainFrequency(tc.domainRate), - WithPerDomainBurstSize(tc.perDomainBurstPerSecond), + WithSourceIPLimitPerSecond(tc.sourceIPLimit), + WithSourceIPBurstSize(tc.sourceIPBurstSize), ) for i := 0; i < tc.reqNum; i++ { ww := httptest.NewRecorder() - rr := httptest.NewRequest(http.MethodGet, "http://domain.gitlab.io", nil) - handler := DomainRateLimiter(rl)(next) + rr := httptest.NewRequest(http.MethodGet, "https://domain.gitlab.io", nil) + rr.RemoteAddr = "172.16.123.1" + + handler := rl.SourceIPLimiter(next) handler.ServeHTTP(ww, rr) res := ww.Result() - if i < tc.perDomainBurstPerSecond { + if i < tc.sourceIPBurstSize { require.Equal(t, http.StatusNoContent, res.StatusCode, "req: %d failed", i) } else { // requests should fail after reaching tc.perDomainBurstPerSecond because mockNow @@ -50,6 +53,73 @@ func TestDomainRateLimiter(t *testing.T) { } } +func TestSourceIPRateLimit(t *testing.T) { + rl := New( + WithNow(mockNow), + WithSourceIPLimitPerSecond(1), + WithSourceIPBurstSize(1), + ) + + tcs := map[string]struct { + enabled bool + ip string + host string + expectedStatus int + }{ + "disabled_rate_limit_http": { + enabled: false, + ip: "172.16.123.1", + host: "http://gitlab.com", + expectedStatus: http.StatusNoContent, + }, + "disabled_rate_limit_https": { + enabled: false, + ip: "172.16.123.2", + host: "https://gitlab.com", + expectedStatus: http.StatusNoContent, + }, + "enabled_rate_limit_http_does_not_block": { + enabled: true, + ip: "172.16.123.3", + host: "http://gitlab.com", + expectedStatus: http.StatusNoContent, + }, + "enabled_rate_limit_https_blocks": { + enabled: true, + ip: "172.16.123.4", + host: "https://gitlab.com", + expectedStatus: http.StatusTooManyRequests, + }, + } + + for tn, tc := range tcs { + t.Run(tn, func(t *testing.T) { + + for i := 0; i < 5; i++ { + ww := httptest.NewRecorder() + rr := httptest.NewRequest(http.MethodGet, tc.host, nil) + rr.RemoteAddr = tc.ip + + if tc.enabled { + enableRateLimiter(t) + } + + handler := rl.SourceIPLimiter(next) + + handler.ServeHTTP(ww, rr) + res := ww.Result() + + if i == 0 { + require.Equal(t, http.StatusNoContent, res.StatusCode) + continue + } + // burst is 1 and limit is 1 per second, all subsequent requests should fail + require.Equal(t, tc.expectedStatus, res.StatusCode) + } + }) + } +} + func enableRateLimiter(t *testing.T) { t.Helper() |