Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-pages.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'internal/ratelimiter/middleware_test.go')
-rw-r--r--internal/ratelimiter/middleware_test.go97
1 files changed, 96 insertions, 1 deletions
diff --git a/internal/ratelimiter/middleware_test.go b/internal/ratelimiter/middleware_test.go
index 79f63a16..4d3e01be 100644
--- a/internal/ratelimiter/middleware_test.go
+++ b/internal/ratelimiter/middleware_test.go
@@ -10,6 +10,7 @@ import (
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/feature"
"gitlab.com/gitlab-org/gitlab-pages/internal/request"
"gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers"
)
@@ -39,6 +40,99 @@ func TestMiddlewareWithDifferentLimits(t *testing.T) {
for i := 0; i < tc.reqNum; i++ {
r := requestFor(remoteAddr, "http://gitlab.com")
+ code, body, _ := testhelpers.PerformRequest(t, handler, r)
+
+ if i < tc.burstSize {
+ require.Equal(t, http.StatusNoContent, code, "req: %d failed", i)
+ } else {
+ // requests should fail after reaching tc.perDomainBurstPerSecond because mockNow
+ // always returns the same time
+ require.Equal(t, http.StatusTooManyRequests, code, "req: %d failed", i)
+ require.Contains(t, body, "Too many requests.")
+ assertSourceIPLog(t, remoteAddr, hook)
+ }
+ }
+ })
+ }
+}
+
+func TestMiddlewareDenyRequestsAfterBurst(t *testing.T) {
+ hook := testlog.NewGlobal()
+ blocked, cachedEntries, cacheReqs := newTestMetrics(t)
+
+ tcs := map[string]struct {
+ expectedStatus int
+ }{
+ "enabled_rate_limit_http_blocks": {
+ expectedStatus: http.StatusTooManyRequests,
+ },
+ }
+
+ for tn, tc := range tcs {
+ t.Run(tn, func(t *testing.T) {
+ rl := New(
+ "rate_limiter",
+ WithCachedEntriesMetric(cachedEntries),
+ WithCachedRequestsMetric(cacheReqs),
+ WithBlockedCountMetric(blocked),
+ WithNow(mockNow),
+ WithLimitPerSecond(1),
+ WithBurstSize(1),
+ WithCloseConnection(true),
+ )
+
+ // middleware is evaluated in reverse order
+ handler := rl.Middleware(next)
+
+ for i := 0; i < 5; i++ {
+ r := requestFor(remoteAddr, "http://gitlab.com")
+ code, _, _ := testhelpers.PerformRequest(t, handler, r)
+
+ if i == 0 {
+ require.Equal(t, http.StatusNoContent, code)
+ continue
+ }
+
+ // burst is 1 and limit is 1 per second, all subsequent requests should fail
+ require.Equal(t, tc.expectedStatus, code)
+ assertSourceIPLog(t, remoteAddr, hook)
+ }
+
+ blockedCount := testutil.ToFloat64(blocked.WithLabelValues("rate_limiter"))
+ require.Equal(t, float64(4), blockedCount, "blocked count")
+ blocked.Reset()
+
+ cachedCount := testutil.ToFloat64(cachedEntries.WithLabelValues("rate_limiter"))
+ require.Equal(t, float64(1), cachedCount, "cached count")
+ cachedEntries.Reset()
+
+ cacheReqMiss := testutil.ToFloat64(cacheReqs.WithLabelValues("rate_limiter", "miss"))
+ require.Equal(t, float64(1), cacheReqMiss, "miss count")
+ cacheReqHit := testutil.ToFloat64(cacheReqs.WithLabelValues("rate_limiter", "hit"))
+ require.Equal(t, float64(4), cacheReqHit, "hit count")
+ cacheReqs.Reset()
+ })
+ }
+}
+
+func TestMiddlewareWithDifferentLimitsWithFFCloseConnectionEnabled(t *testing.T) {
+ hook := testlog.NewGlobal()
+ t.Setenv(feature.RateLimiterCloseConnection.EnvVariable, "true")
+
+ for tn, tc := range sharedTestCases {
+ t.Run(tn, func(t *testing.T) {
+ rl := New(
+ "rate_limiter",
+ WithNow(mockNow),
+ WithLimitPerSecond(tc.limit),
+ WithBurstSize(tc.burstSize),
+ WithCloseConnection(true),
+ )
+
+ handler := rl.Middleware(next)
+
+ for i := 0; i < tc.reqNum; i++ {
+ r := requestFor(remoteAddr, "http://gitlab.com")
code, body, header := testhelpers.PerformRequest(t, handler, r)
if i < tc.burstSize {
@@ -56,8 +150,9 @@ func TestMiddlewareWithDifferentLimits(t *testing.T) {
}
}
-func TestMiddlewareDenyRequestsAfterBurst(t *testing.T) {
+func TestMiddlewareDenyRequestsAfterBurstWithFFCloseConnectionEnabled(t *testing.T) {
hook := testlog.NewGlobal()
+ t.Setenv(feature.RateLimiterCloseConnection.EnvVariable, "true")
blocked, cachedEntries, cacheReqs := newTestMetrics(t)
tcs := map[string]struct {