Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-pages.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVladimir Shushlin <v.shushlin@gmail.com>2021-12-09 17:17:17 +0300
committerVladimir Shushlin <v.shushlin@gmail.com>2021-12-20 14:10:35 +0300
commit91dd7bc9011640dd02b497acf9fa78bee35a8402 (patch)
treea42e4d7ca3b2d169fc8454952ff90e2075589481
parentf8512edbec4ec83b426c8ca2dda467de424685e4 (diff)
refactor: handle defaults in ratelimiter package itself
also fix tests: * float64(1/time.Milesecond) == 0 * rate package doesn't actually refill the bucket on fractional seconds, so we need to use integers
-rw-r--r--app.go2
-rw-r--r--internal/handlers/ratelimiter.go10
-rw-r--r--internal/ratelimiter/middleware.go6
-rw-r--r--internal/ratelimiter/ratelimiter.go24
-rw-r--r--internal/ratelimiter/ratelimiter_test.go48
5 files changed, 34 insertions, 56 deletions
diff --git a/app.go b/app.go
index 84eb8956..a552bfd3 100644
--- a/app.go
+++ b/app.go
@@ -258,7 +258,7 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) {
handler = routing.NewMiddleware(handler, a.source)
- handler = handlers.Ratelimiter(handler, a.config)
+ handler = handlers.Ratelimiter(handler, &a.config.RateLimit)
// Health Check
handler, err = a.healthCheckMiddleware(handler)
diff --git a/internal/handlers/ratelimiter.go b/internal/handlers/ratelimiter.go
index 9c66c15d..8263f497 100644
--- a/internal/handlers/ratelimiter.go
+++ b/internal/handlers/ratelimiter.go
@@ -10,19 +10,15 @@ import (
// Ratelimiter configures the ratelimiter middleware
// TODO: make this unexported once https://gitlab.com/gitlab-org/gitlab-pages/-/issues/670 is done
-func Ratelimiter(handler http.Handler, config *config.Config) http.Handler {
- if config.RateLimit.SourceIPLimitPerSecond == 0 {
- return handler
- }
-
+func Ratelimiter(handler http.Handler, config *config.RateLimit) http.Handler {
rl := ratelimiter.New(
"source_ip",
ratelimiter.WithCacheMaxSize(ratelimiter.DefaultSourceIPCacheSize),
ratelimiter.WithCachedEntriesMetric(metrics.RateLimitSourceIPCachedEntries),
ratelimiter.WithCachedRequestsMetric(metrics.RateLimitSourceIPCacheRequests),
ratelimiter.WithBlockedCountMetric(metrics.RateLimitSourceIPBlockedCount),
- ratelimiter.WithLimitPerSecond(config.RateLimit.SourceIPLimitPerSecond),
- ratelimiter.WithBurstSize(config.RateLimit.SourceIPBurst),
+ ratelimiter.WithLimitPerSecond(config.SourceIPLimitPerSecond),
+ ratelimiter.WithBurstSize(config.SourceIPBurst),
)
return rl.Middleware(handler)
diff --git a/internal/ratelimiter/middleware.go b/internal/ratelimiter/middleware.go
index a0fccff4..155a87a6 100644
--- a/internal/ratelimiter/middleware.go
+++ b/internal/ratelimiter/middleware.go
@@ -20,8 +20,12 @@ const (
// Middleware returns middleware for rate-limiting clients
func (rl *RateLimiter) Middleware(handler http.Handler) http.Handler {
+ if rl.limitPerSecond <= 0.0 {
+ return handler
+ }
+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if !rl.RequestAllowed(r) {
+ if !rl.requestAllowed(r) {
rl.logRateLimitedRequest(r)
if feature.EnforceIPRateLimits.Enabled() {
diff --git a/internal/ratelimiter/ratelimiter.go b/internal/ratelimiter/ratelimiter.go
index ab02c966..510a44ae 100644
--- a/internal/ratelimiter/ratelimiter.go
+++ b/internal/ratelimiter/ratelimiter.go
@@ -12,14 +12,6 @@ import (
)
const (
- // DefaultSourceIPLimitPerSecond is the limit per second that rate.Limiter
- // needs to generate tokens every second.
- // The default value is 20 requests per second.
- DefaultSourceIPLimitPerSecond = 20.0
- // DefaultSourceIPBurstSize is the maximum burst allowed per rate limiter.
- // E.g. The first 100 requests within 1s will succeed, but the 101st will fail.
- DefaultSourceIPBurstSize = 100
-
// based on an avg ~4,000 unique IPs per minute
// https://log.gprd.gitlab.net/app/lens#/edit/f7110d00-2013-11ec-8c8e-ed83b5469915?_g=h@e78830b
DefaultSourceIPCacheSize = 5000
@@ -50,18 +42,18 @@ type RateLimiter struct {
// New creates a new RateLimiter with default values that can be configured via Option functions
func New(name string, opts ...Option) *RateLimiter {
rl := &RateLimiter{
- name: name,
- now: time.Now,
- limitPerSecond: DefaultSourceIPLimitPerSecond,
- burstSize: DefaultSourceIPBurstSize,
- key: request.GetRemoteAddrWithoutPort,
+ name: name,
+ now: time.Now,
+ key: request.GetRemoteAddrWithoutPort,
}
for _, opt := range opts {
opt(rl)
}
- rl.cache = lru.New(name, rl.cacheOptions...)
+ if rl.limitPerSecond > 0.0 {
+ rl.cache = lru.New(name, rl.cacheOptions...)
+ }
return rl
}
@@ -124,8 +116,8 @@ func (rl *RateLimiter) limiter(key string) *rate.Limiter {
return limiterI.(*rate.Limiter)
}
-// RequestAllowed checks that the real remote IP address is allowed to perform an operation
-func (rl *RateLimiter) RequestAllowed(r *http.Request) bool {
+// requestAllowed checks that the real remote IP address is allowed to perform an operation
+func (rl *RateLimiter) requestAllowed(r *http.Request) bool {
rateLimitedKey := rl.key(r)
limiter := rl.limiter(rateLimitedKey)
diff --git a/internal/ratelimiter/ratelimiter_test.go b/internal/ratelimiter/ratelimiter_test.go
index ce767b35..926a90c1 100644
--- a/internal/ratelimiter/ratelimiter_test.go
+++ b/internal/ratelimiter/ratelimiter_test.go
@@ -3,7 +3,6 @@ package ratelimiter
import (
"net/http"
"net/http/httptest"
- "sync"
"testing"
"time"
@@ -63,7 +62,7 @@ func TestSourceIPAllowed(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "https://domain.gitlab.io", nil)
r.RemoteAddr = "172.16.123.1"
- got := rl.RequestAllowed(r)
+ got := rl.requestAllowed(r)
if i < tc.sourceIPBurstSize {
require.Truef(t, got, "expected true for request no. %d", i)
} else {
@@ -77,43 +76,30 @@ func TestSourceIPAllowed(t *testing.T) {
}
func TestSingleRateLimiterWithMultipleSourceIPs(t *testing.T) {
- rate := 10 * time.Millisecond
+ now := time.Now()
rl := New(
"rate_limiter",
- WithLimitPerSecond(float64(1/rate)),
+ WithLimitPerSecond(1),
WithBurstSize(1),
+ WithNow(func() time.Time {
+ return now
+ }),
)
- wg := sync.WaitGroup{}
-
- testFn := func(ip string) func(t *testing.T) {
- return func(t *testing.T) {
- wg.Add(1)
- go func() {
- defer wg.Done()
-
- for i := 0; i < 5; i++ {
- r := httptest.NewRequest(http.MethodGet, "https://domain.gitlab.io", nil)
- r.RemoteAddr = ip
- got := rl.RequestAllowed(r)
- require.Truef(t, got, "expected true for request no. %d", i)
- time.Sleep(rate)
- }
- }()
- }
+ testRequest := func(ip string, i int) {
+ r := httptest.NewRequest(http.MethodGet, "https://domain.gitlab.io", nil)
+ r.RemoteAddr = ip
+ got := rl.requestAllowed(r)
+ require.Truef(t, got, "expected true for %v request no. %d", ip, i)
}
- first := "172.16.123.10"
- t.Run(first, testFn(first))
-
- second := "172.16.123.20"
- t.Run(second, testFn(second))
-
- third := "172.16.123.30"
- t.Run(third, testFn(third))
-
- wg.Wait()
+ for i := 0; i < 5; i++ {
+ testRequest("172.16.123.10", i)
+ testRequest("172.16.123.20", i)
+ testRequest("172.16.123.30", i)
+ now = now.Add(time.Second)
+ }
}
func newTestMetrics(t *testing.T) (*prometheus.GaugeVec, *prometheus.GaugeVec, *prometheus.CounterVec) {