diff options
author | Jaime Martinez <jmartinez@gitlab.com> | 2021-10-13 03:35:45 +0300 |
---|---|---|
committer | Jaime Martinez <jmartinez@gitlab.com> | 2021-10-14 09:01:02 +0300 |
commit | 88e1154f726aba6e3a36ad3f31cd78fa3c9313c5 (patch) | |
tree | d469637c5036c06293a478068566b7eb8c9d0fdd /internal | |
parent | d21ad4d3f334774bbbcd9a586c7bdfd32a0ae804 (diff) |
refactor: remove WithProxied setting
Diffstat (limited to 'internal')
-rw-r--r-- | internal/logging/logging.go | 1 | ||||
-rw-r--r-- | internal/ratelimiter/middleware.go | 28 | ||||
-rw-r--r-- | internal/ratelimiter/middleware_test.go | 64 | ||||
-rw-r--r-- | internal/ratelimiter/ratelimiter.go | 7 | ||||
-rw-r--r-- | internal/ratelimiter/ratelimiter_test.go | 25 | ||||
-rw-r--r-- | internal/request/request.go | 10 | ||||
-rw-r--r-- | internal/testhelpers/testhelpers.go | 27 |
7 files changed, 28 insertions, 134 deletions
diff --git a/internal/logging/logging.go b/internal/logging/logging.go index 27edf865..4ffbeb4b 100644 --- a/internal/logging/logging.go +++ b/internal/logging/logging.go @@ -85,7 +85,6 @@ func BasicAccessLogger(handler http.Handler, format string, extraFields log.Extr return log.AccessLogger(handler, log.WithExtraFields(enrichExtraFields(extraFields)), log.WithAccessLogger(accessLogger), - // TODO: log IP for HTTP requests https://gitlab.com/gitlab-org/gitlab-pages/-/issues/640 log.WithXFFAllowed(func(sip string) bool { return false }), ), nil } diff --git a/internal/ratelimiter/middleware.go b/internal/ratelimiter/middleware.go index 5db004b7..0cd5b81e 100644 --- a/internal/ratelimiter/middleware.go +++ b/internal/ratelimiter/middleware.go @@ -1,16 +1,15 @@ package ratelimiter import ( - "net" "net/http" "os" - "github.com/sebest/xff" "github.com/sirupsen/logrus" "gitlab.com/gitlab-org/labkit/correlation" "gitlab.com/gitlab-org/labkit/log" "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" + "gitlab.com/gitlab-org/gitlab-pages/internal/request" ) const ( @@ -22,7 +21,7 @@ const ( // SourceIPLimiter returns middleware for rate-limiting clients based on their IP func (rl *RateLimiter) SourceIPLimiter(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - host, sourceIP := rl.getReqDetails(r) + host, sourceIP := request.GetHostWithoutPort(r), request.GetRemoteAddrWithoutPort(r) if !rl.SourceIPAllowed(sourceIP) { rl.logSourceIP(r, host, sourceIP) @@ -41,28 +40,6 @@ func (rl *RateLimiter) SourceIPLimiter(handler http.Handler) http.Handler { }) } -func (rl *RateLimiter) getReqDetails(r *http.Request) (string, string) { - host, _, err := net.SplitHostPort(r.Host) - if err != nil { - host = r.Host - } - - // TODO: consider using X-Real-IP https://gitlab.com/gitlab-org/gitlab-pages/-/issues/644 - // choose between r.RemoteAddr and X-Forwarded-For. Only uses XFF when proxied - remoteAddr := xff.GetRemoteAddrIfAllowed(r, func(sip string) bool { - // We enable github.com/gorilla/handlers.ProxyHeaders which sets r.RemoteAddr - // with the value of X-Forwarded-For when --listen-proxy is set - return rl.proxied - }) - - ip, _, err := net.SplitHostPort(remoteAddr) - if err != nil { - ip = remoteAddr - } - - return host, ip -} - func (rl *RateLimiter) logSourceIP(r *http.Request, host, sourceIP string) { log.WithFields(logrus.Fields{ "handler": "source_ip_rate_limiter", @@ -73,7 +50,6 @@ func (rl *RateLimiter) logSourceIP(r *http.Request, host, sourceIP string) { "pages_domain": host, "remote_addr": r.RemoteAddr, "source_ip": sourceIP, - "proxied": rl.proxied, "x_forwarded_proto": r.Header.Get(headerXForwardedProto), "x_forwarded_for": r.Header.Get(headerXForwardedFor), "gitlab_real_ip": r.Header.Get(headerGitLabRealIP), diff --git a/internal/ratelimiter/middleware_test.go b/internal/ratelimiter/middleware_test.go index a6de0d10..b0134b29 100644 --- a/internal/ratelimiter/middleware_test.go +++ b/internal/ratelimiter/middleware_test.go @@ -6,7 +6,6 @@ import ( "net/http/httptest" "testing" - ghandlers "github.com/gorilla/handlers" testlog "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" @@ -14,8 +13,7 @@ import ( ) const ( - xForwardedFor = "172.16.123.1" - remoteAddr = "192.168.1.1" + remoteAddr = "192.168.1.1" ) var next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -24,7 +22,7 @@ var next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { func TestSourceIPLimiterWithDifferentLimits(t *testing.T) { hook := testlog.NewGlobal() - testhelpers.EnableRateLimiter(t) + testhelpers.SetEnvironmentVariable(t, testhelpers.FFEnableRateLimiter, "true") for tn, tc := range sharedTestCases { t.Run(tn, func(t *testing.T) { @@ -32,19 +30,14 @@ func TestSourceIPLimiterWithDifferentLimits(t *testing.T) { WithNow(mockNow), WithSourceIPLimitPerSecond(tc.sourceIPLimit), WithSourceIPBurstSize(tc.sourceIPBurstSize), - WithProxied(tc.proxied), ) for i := 0; i < tc.reqNum; i++ { ww := httptest.NewRecorder() rr := httptest.NewRequest(http.MethodGet, "https://domain.gitlab.io", nil) - rr.Header.Set(headerXForwardedFor, xForwardedFor) rr.RemoteAddr = remoteAddr handler := rl.SourceIPLimiter(next) - if tc.proxied { - handler = ghandlers.ProxyHeaders(handler) - } handler.ServeHTTP(ww, rr) res := ww.Result() @@ -61,7 +54,7 @@ func TestSourceIPLimiterWithDifferentLimits(t *testing.T) { require.Contains(t, string(b), "Too many requests.") res.Body.Close() - assertSourceIPLog(t, tc.proxied, xForwardedFor, remoteAddr, hook) + assertSourceIPLog(t, remoteAddr, hook) } } }) @@ -73,52 +66,22 @@ func TestSourceIPLimiterDenyRequestsAfterBurst(t *testing.T) { tcs := map[string]struct { enabled bool - proxied bool - host string expectedStatus int }{ "disabled_rate_limit_http": { enabled: false, - host: "http://gitlab.com", expectedStatus: http.StatusNoContent, }, "disabled_rate_limit_https": { enabled: false, - host: "https://gitlab.com", expectedStatus: http.StatusNoContent, }, "enabled_rate_limit_http_blocks": { enabled: true, - host: "http://gitlab.com", expectedStatus: http.StatusTooManyRequests, }, "enabled_rate_limit_https_blocks": { enabled: true, - host: "https://gitlab.com", - expectedStatus: http.StatusTooManyRequests, - }, - "disabled_rate_limit_http_proxied": { - enabled: false, - proxied: true, - host: "http://gitlab.com", - expectedStatus: http.StatusNoContent, - }, - "disabled_rate_limit_https_proxied": { - enabled: false, - proxied: true, - host: "https://gitlab.com", - expectedStatus: http.StatusNoContent, - }, - "enabled_rate_limit_http_blocks_proxied": { - enabled: true, - proxied: true, - host: "http://gitlab.com", - expectedStatus: http.StatusTooManyRequests, - }, - "enabled_rate_limit_https_blocks_proxied": { - enabled: true, - proxied: true, - host: "https://gitlab.com", expectedStatus: http.StatusTooManyRequests, }, } @@ -129,26 +92,21 @@ func TestSourceIPLimiterDenyRequestsAfterBurst(t *testing.T) { WithNow(mockNow), WithSourceIPLimitPerSecond(1), WithSourceIPBurstSize(1), - WithProxied(tc.proxied), ) for i := 0; i < 5; i++ { ww := httptest.NewRecorder() - rr := httptest.NewRequest(http.MethodGet, tc.host, nil) + rr := httptest.NewRequest(http.MethodGet, "http://gitlab.com", nil) if tc.enabled { - testhelpers.EnableRateLimiter(t) + testhelpers.SetEnvironmentVariable(t, testhelpers.FFEnableRateLimiter, "true") } else { - testhelpers.DisableRateLimiter(t) + testhelpers.SetEnvironmentVariable(t, testhelpers.FFEnableRateLimiter, "false") } - rr.Header.Set(headerXForwardedFor, xForwardedFor) rr.RemoteAddr = remoteAddr // middleware is evaluated in reverse order handler := rl.SourceIPLimiter(next) - if tc.proxied { - handler = ghandlers.ProxyHeaders(handler) - } handler.ServeHTTP(ww, rr) res := ww.Result() @@ -160,23 +118,19 @@ func TestSourceIPLimiterDenyRequestsAfterBurst(t *testing.T) { // burst is 1 and limit is 1 per second, all subsequent requests should fail require.Equal(t, tc.expectedStatus, res.StatusCode) - assertSourceIPLog(t, tc.proxied, xForwardedFor, remoteAddr, hook) + assertSourceIPLog(t, remoteAddr, hook) } }) } } -func assertSourceIPLog(t *testing.T, proxied bool, xForwardedFor, remoteAddr string, hook *testlog.Hook) { +func assertSourceIPLog(t *testing.T, remoteAddr string, hook *testlog.Hook) { t.Helper() require.NotNil(t, hook.LastEntry()) // source_ip that was rate limited - if proxied { - require.Equal(t, xForwardedFor, hook.LastEntry().Data["source_ip"]) - } else { - require.Equal(t, remoteAddr, hook.LastEntry().Data["source_ip"]) - } + require.Equal(t, remoteAddr, hook.LastEntry().Data["source_ip"]) hook.Reset() } diff --git a/internal/ratelimiter/ratelimiter.go b/internal/ratelimiter/ratelimiter.go index 36c72cde..1a622c6f 100644 --- a/internal/ratelimiter/ratelimiter.go +++ b/internal/ratelimiter/ratelimiter.go @@ -35,7 +35,6 @@ type Option func(*RateLimiter) // It also holds a now function that can be mocked in unit tests. type RateLimiter struct { now func() time.Time - proxied bool sourceIPLimitPerSecond float64 sourceIPBurstSize int sourceIPBlockedCount *prometheus.GaugeVec @@ -87,12 +86,6 @@ func WithSourceIPBurstSize(burst int) Option { } } -// WithProxied sets the proxy flag to true. Used by the SourceIPLimiter middleware. -func WithProxied(proxied bool) Option { - return func(rl *RateLimiter) { - rl.proxied = proxied - } -} func (rl *RateLimiter) getSourceIPLimiter(sourceIP string) *rate.Limiter { limiterI, _ := rl.sourceIPCache.FindOrFetch(sourceIP, sourceIP, func() (interface{}, error) { return rate.NewLimiter(rate.Limit(rl.sourceIPLimitPerSecond), rl.sourceIPBurstSize), nil diff --git a/internal/ratelimiter/ratelimiter_test.go b/internal/ratelimiter/ratelimiter_test.go index 03e764f2..77da8e81 100644 --- a/internal/ratelimiter/ratelimiter_test.go +++ b/internal/ratelimiter/ratelimiter_test.go @@ -21,7 +21,6 @@ var sharedTestCases = map[string]struct { sourceIPLimit float64 sourceIPBurstSize int reqNum int - proxied bool }{ "one_request_per_second": { sourceIPLimit: 1, @@ -43,30 +42,6 @@ var sharedTestCases = map[string]struct { sourceIPBurstSize: 10, reqNum: 11, }, - "one_request_per_second_proxied": { - proxied: true, - sourceIPLimit: 1, - sourceIPBurstSize: 1, - reqNum: 2, - }, - "one_request_per_second_but_big_bucket_proxied": { - proxied: true, - sourceIPLimit: 1, - sourceIPBurstSize: 10, - reqNum: 11, - }, - "three_req_per_second_bucket_size_one_proxied": { - proxied: true, - sourceIPLimit: 3, - sourceIPBurstSize: 1, // max burst 1 means 1 at a time - reqNum: 3, - }, - "10_requests_per_second_proxied": { - proxied: true, - sourceIPLimit: 10, - sourceIPBurstSize: 10, - reqNum: 11, - }, } func TestSourceIPAllowed(t *testing.T) { diff --git a/internal/request/request.go b/internal/request/request.go index cbda16e5..77cc4a76 100644 --- a/internal/request/request.go +++ b/internal/request/request.go @@ -55,3 +55,13 @@ func GetHostWithoutPort(r *http.Request) string { return host } + +// GetRemoteAddrWithoutPort strips the port from the r.RemoteAddr +func GetRemoteAddrWithoutPort(r *http.Request) string { + remoteAddr, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + + return remoteAddr +} diff --git a/internal/testhelpers/testhelpers.go b/internal/testhelpers/testhelpers.go index 2f1f7c27..de48cd7a 100644 --- a/internal/testhelpers/testhelpers.go +++ b/internal/testhelpers/testhelpers.go @@ -13,8 +13,9 @@ import ( "github.com/stretchr/testify/require" ) +// FFEnableRateLimiter enforces ratelimiter package to drop requests // TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629 -const ffEnableRateLimiter = "FF_ENABLE_RATE_LIMITER" +const FFEnableRateLimiter = "FF_ENABLE_RATE_LIMITER" // AssertHTTP404 asserts handler returns 404 with provided str body func AssertHTTP404(t *testing.T, handler http.HandlerFunc, mode, url string, values url.Values, str interface{}) { @@ -81,30 +82,16 @@ func Getwd(t *testing.T) string { return wd } -// EnableRateLimiter environment variable -func EnableRateLimiter(t *testing.T) { +// SetEnvironmentVariable for testing, restoring the original value on t.Cleanup +func SetEnvironmentVariable(t *testing.T, key, value string) { t.Helper() - orig := os.Getenv(ffEnableRateLimiter) + orig := os.Getenv(key) - err := os.Setenv(ffEnableRateLimiter, "true") + err := os.Setenv(key, value) require.NoError(t, err) t.Cleanup(func() { - os.Setenv(ffEnableRateLimiter, orig) - }) -} - -// DisableRateLimiter environment variable -func DisableRateLimiter(t *testing.T) { - t.Helper() - - orig := os.Getenv(ffEnableRateLimiter) - - err := os.Setenv(ffEnableRateLimiter, "false") - require.NoError(t, err) - - t.Cleanup(func() { - os.Setenv(ffEnableRateLimiter, orig) + os.Setenv(FFEnableRateLimiter, orig) }) } |