Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-pages.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'internal/ratelimiter/middleware_test.go')
-rw-r--r--internal/ratelimiter/middleware_test.go88
1 files changed, 79 insertions, 9 deletions
diff --git a/internal/ratelimiter/middleware_test.go b/internal/ratelimiter/middleware_test.go
index b15698fb..ef9b23b3 100644
--- a/internal/ratelimiter/middleware_test.go
+++ b/internal/ratelimiter/middleware_test.go
@@ -10,10 +10,11 @@ import (
"github.com/stretchr/testify/require"
)
-func TestDomainRateLimiter(t *testing.T) {
- next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusNoContent)
- })
+var next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusNoContent)
+})
+
+func TestSourceIPLimiter(t *testing.T) {
enableRateLimiter(t)
@@ -21,19 +22,21 @@ func TestDomainRateLimiter(t *testing.T) {
t.Run(tn, func(t *testing.T) {
rl := New(
WithNow(mockNow),
- WithPerDomainFrequency(tc.domainRate),
- WithPerDomainBurstSize(tc.perDomainBurstPerSecond),
+ WithSourceIPLimitPerSecond(tc.sourceIPLimit),
+ WithSourceIPBurstSize(tc.sourceIPBurstSize),
)
for i := 0; i < tc.reqNum; i++ {
ww := httptest.NewRecorder()
- rr := httptest.NewRequest(http.MethodGet, "http://domain.gitlab.io", nil)
- handler := DomainRateLimiter(rl)(next)
+ rr := httptest.NewRequest(http.MethodGet, "https://domain.gitlab.io", nil)
+ rr.RemoteAddr = "172.16.123.1"
+
+ handler := rl.SourceIPLimiter(next)
handler.ServeHTTP(ww, rr)
res := ww.Result()
- if i < tc.perDomainBurstPerSecond {
+ if i < tc.sourceIPBurstSize {
require.Equal(t, http.StatusNoContent, res.StatusCode, "req: %d failed", i)
} else {
// requests should fail after reaching tc.perDomainBurstPerSecond because mockNow
@@ -50,6 +53,73 @@ func TestDomainRateLimiter(t *testing.T) {
}
}
+func TestSourceIPRateLimit(t *testing.T) {
+ rl := New(
+ WithNow(mockNow),
+ WithSourceIPLimitPerSecond(1),
+ WithSourceIPBurstSize(1),
+ )
+
+ tcs := map[string]struct {
+ enabled bool
+ ip string
+ host string
+ expectedStatus int
+ }{
+ "disabled_rate_limit_http": {
+ enabled: false,
+ ip: "172.16.123.1",
+ host: "http://gitlab.com",
+ expectedStatus: http.StatusNoContent,
+ },
+ "disabled_rate_limit_https": {
+ enabled: false,
+ ip: "172.16.123.2",
+ host: "https://gitlab.com",
+ expectedStatus: http.StatusNoContent,
+ },
+ "enabled_rate_limit_http_does_not_block": {
+ enabled: true,
+ ip: "172.16.123.3",
+ host: "http://gitlab.com",
+ expectedStatus: http.StatusNoContent,
+ },
+ "enabled_rate_limit_https_blocks": {
+ enabled: true,
+ ip: "172.16.123.4",
+ host: "https://gitlab.com",
+ expectedStatus: http.StatusTooManyRequests,
+ },
+ }
+
+ for tn, tc := range tcs {
+ t.Run(tn, func(t *testing.T) {
+
+ for i := 0; i < 5; i++ {
+ ww := httptest.NewRecorder()
+ rr := httptest.NewRequest(http.MethodGet, tc.host, nil)
+ rr.RemoteAddr = tc.ip
+
+ if tc.enabled {
+ enableRateLimiter(t)
+ }
+
+ handler := rl.SourceIPLimiter(next)
+
+ handler.ServeHTTP(ww, rr)
+ res := ww.Result()
+
+ if i == 0 {
+ require.Equal(t, http.StatusNoContent, res.StatusCode)
+ continue
+ }
+ // burst is 1 and limit is 1 per second, all subsequent requests should fail
+ require.Equal(t, tc.expectedStatus, res.StatusCode)
+ }
+ })
+ }
+}
+
func enableRateLimiter(t *testing.T) {
t.Helper()