diff options
Diffstat (limited to 'internal/ratelimiter/middleware_test.go')
-rw-r--r-- | internal/ratelimiter/middleware_test.go | 88 |
1 files changed, 79 insertions, 9 deletions
diff --git a/internal/ratelimiter/middleware_test.go b/internal/ratelimiter/middleware_test.go index b15698fb..ef9b23b3 100644 --- a/internal/ratelimiter/middleware_test.go +++ b/internal/ratelimiter/middleware_test.go @@ -10,10 +10,11 @@ import ( "github.com/stretchr/testify/require" ) -func TestDomainRateLimiter(t *testing.T) { - next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNoContent) - }) +var next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) +}) + +func TestSourceIPLimiter(t *testing.T) { enableRateLimiter(t) @@ -21,19 +22,21 @@ func TestDomainRateLimiter(t *testing.T) { t.Run(tn, func(t *testing.T) { rl := New( WithNow(mockNow), - WithPerDomainFrequency(tc.domainRate), - WithPerDomainBurstSize(tc.perDomainBurstPerSecond), + WithSourceIPLimitPerSecond(tc.sourceIPLimit), + WithSourceIPBurstSize(tc.sourceIPBurstSize), ) for i := 0; i < tc.reqNum; i++ { ww := httptest.NewRecorder() - rr := httptest.NewRequest(http.MethodGet, "http://domain.gitlab.io", nil) - handler := DomainRateLimiter(rl)(next) + rr := httptest.NewRequest(http.MethodGet, "https://domain.gitlab.io", nil) + rr.RemoteAddr = "172.16.123.1" + + handler := rl.SourceIPLimiter(next) handler.ServeHTTP(ww, rr) res := ww.Result() - if i < tc.perDomainBurstPerSecond { + if i < tc.sourceIPBurstSize { require.Equal(t, http.StatusNoContent, res.StatusCode, "req: %d failed", i) } else { // requests should fail after reaching tc.perDomainBurstPerSecond because mockNow @@ -50,6 +53,73 @@ func TestDomainRateLimiter(t *testing.T) { } } +func TestSourceIPRateLimit(t *testing.T) { + rl := New( + WithNow(mockNow), + WithSourceIPLimitPerSecond(1), + WithSourceIPBurstSize(1), + ) + + tcs := map[string]struct { + enabled bool + ip string + host string + expectedStatus int + }{ + "disabled_rate_limit_http": { + enabled: false, + ip: "172.16.123.1", + host: "http://gitlab.com", + expectedStatus: http.StatusNoContent, + }, + "disabled_rate_limit_https": { + enabled: false, + ip: "172.16.123.2", + host: "https://gitlab.com", + expectedStatus: http.StatusNoContent, + }, + "enabled_rate_limit_http_does_not_block": { + enabled: true, + ip: "172.16.123.3", + host: "http://gitlab.com", + expectedStatus: http.StatusNoContent, + }, + "enabled_rate_limit_https_blocks": { + enabled: true, + ip: "172.16.123.4", + host: "https://gitlab.com", + expectedStatus: http.StatusTooManyRequests, + }, + } + + for tn, tc := range tcs { + t.Run(tn, func(t *testing.T) { + + for i := 0; i < 5; i++ { + ww := httptest.NewRecorder() + rr := httptest.NewRequest(http.MethodGet, tc.host, nil) + rr.RemoteAddr = tc.ip + + if tc.enabled { + enableRateLimiter(t) + } + + handler := rl.SourceIPLimiter(next) + + handler.ServeHTTP(ww, rr) + res := ww.Result() + + if i == 0 { + require.Equal(t, http.StatusNoContent, res.StatusCode) + continue + } + // burst is 1 and limit is 1 per second, all subsequent requests should fail + require.Equal(t, tc.expectedStatus, res.StatusCode) + } + }) + } +} + func enableRateLimiter(t *testing.T) { t.Helper() |