Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-pages.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--app.go13
-rw-r--r--internal/ratelimiter/middleware.go31
-rw-r--r--internal/ratelimiter/middleware_test.go24
-rw-r--r--internal/ratelimiter/ratelimiter.go86
-rw-r--r--internal/ratelimiter/ratelimiter_test.go25
-rw-r--r--test/acceptance/ratelimiter_test.go2
6 files changed, 100 insertions, 81 deletions
diff --git a/app.go b/app.go
index 508b11fa..6030cad8 100644
--- a/app.go
+++ b/app.go
@@ -260,15 +260,16 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) {
if a.config.RateLimit.SourceIPLimitPerSecond > 0 {
rl := ratelimiter.New(
- ratelimiter.WithSourceIPCacheMaxSize(ratelimiter.DefaultSourceIPCacheSize),
- ratelimiter.WithSourceIPCachedEntriesMetric(metrics.RateLimitSourceIPCachedEntries),
- ratelimiter.WithSourceIPCachedRequestsMetric(metrics.RateLimitSourceIPCacheRequests),
+ "source_ip",
+ ratelimiter.WithCacheMaxSize(ratelimiter.DefaultSourceIPCacheSize),
+ ratelimiter.WithCachedEntriesMetric(metrics.RateLimitSourceIPCachedEntries),
+ ratelimiter.WithCachedRequestsMetric(metrics.RateLimitSourceIPCacheRequests),
ratelimiter.WithBlockedCountMetric(metrics.RateLimitSourceIPBlockedCount),
- ratelimiter.WithSourceIPLimitPerSecond(a.config.RateLimit.SourceIPLimitPerSecond),
- ratelimiter.WithSourceIPBurstSize(a.config.RateLimit.SourceIPBurst),
+ ratelimiter.WithLimitPerSecond(a.config.RateLimit.SourceIPLimitPerSecond),
+ ratelimiter.WithBurstSize(a.config.RateLimit.SourceIPBurst),
)
- handler = rl.SourceIPLimiter(handler)
+ handler = rl.Middleware(handler)
}
// Health Check
diff --git a/internal/ratelimiter/middleware.go b/internal/ratelimiter/middleware.go
index f26cb2e0..86beaf39 100644
--- a/internal/ratelimiter/middleware.go
+++ b/internal/ratelimiter/middleware.go
@@ -18,25 +18,24 @@ const (
headerXForwardedProto = "X-Forwarded-Proto"
)
-// SourceIPLimiter returns middleware for rate-limiting clients based on their IP
-func (rl *RateLimiter) SourceIPLimiter(handler http.Handler) http.Handler {
+// Middleware returns middleware for rate-limiting clients
+func (rl *RateLimiter) Middleware(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- host, sourceIP := request.GetHostWithoutPort(r), request.GetRemoteAddrWithoutPort(r)
- if !rl.SourceIPAllowed(sourceIP) {
- rl.logSourceIP(r, host, sourceIP)
+ if !rl.RequestAllowed(r) {
+ rl.logRateLimitedRequest(r)
// Only drop requests once FF_ENABLE_RATE_LIMITER is enabled
// https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629
if rateLimiterEnabled() {
- if rl.sourceIPBlockedCount != nil {
- rl.sourceIPBlockedCount.WithLabelValues("true").Inc()
+ if rl.blockedCount != nil {
+ rl.blockedCount.WithLabelValues("true").Inc()
}
httperrors.Serve429(w)
return
}
- if rl.sourceIPBlockedCount != nil {
- rl.sourceIPBlockedCount.WithLabelValues("false").Inc()
+ if rl.blockedCount != nil {
+ rl.blockedCount.WithLabelValues("false").Inc()
}
}
@@ -44,24 +43,24 @@ func (rl *RateLimiter) SourceIPLimiter(handler http.Handler) http.Handler {
})
}
-func (rl *RateLimiter) logSourceIP(r *http.Request, host, sourceIP string) {
+func (rl *RateLimiter) logRateLimitedRequest(r *http.Request) {
log.WithFields(logrus.Fields{
- "handler": "source_ip_rate_limiter",
+ "rate_limiter_name": rl.name,
"correlation_id": correlation.ExtractFromContext(r.Context()),
"req_scheme": r.URL.Scheme,
"req_host": r.Host,
"req_path": r.URL.Path,
- "pages_domain": host,
+ "pages_domain": request.GetHostWithoutPort(r),
"remote_addr": r.RemoteAddr,
- "source_ip": sourceIP,
+ "source_ip": request.GetRemoteAddrWithoutPort(r),
"x_forwarded_proto": r.Header.Get(headerXForwardedProto),
"x_forwarded_for": r.Header.Get(headerXForwardedFor),
"gitlab_real_ip": r.Header.Get(headerGitLabRealIP),
"rate_limiter_enabled": rateLimiterEnabled(),
- "rate_limiter_limit_per_second": rl.sourceIPLimitPerSecond,
- "rate_limiter_burst_size": rl.sourceIPBurstSize,
+ "rate_limiter_limit_per_second": rl.limitPerSecond,
+ "rate_limiter_burst_size": rl.burstSize,
}). // TODO: change to Debug with https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629
- Info("source IP hit rate limit")
+ Info("request hit rate limit")
}
// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629
diff --git a/internal/ratelimiter/middleware_test.go b/internal/ratelimiter/middleware_test.go
index 2cf3b3e5..a66f7523 100644
--- a/internal/ratelimiter/middleware_test.go
+++ b/internal/ratelimiter/middleware_test.go
@@ -28,9 +28,10 @@ func TestSourceIPLimiterWithDifferentLimits(t *testing.T) {
for tn, tc := range sharedTestCases {
t.Run(tn, func(t *testing.T) {
rl := New(
+ "rate_limiter",
WithNow(mockNow),
- WithSourceIPLimitPerSecond(tc.sourceIPLimit),
- WithSourceIPBurstSize(tc.sourceIPBurstSize),
+ WithLimitPerSecond(tc.sourceIPLimit),
+ WithBurstSize(tc.sourceIPBurstSize),
)
for i := 0; i < tc.reqNum; i++ {
@@ -38,7 +39,7 @@ func TestSourceIPLimiterWithDifferentLimits(t *testing.T) {
rr := httptest.NewRequest(http.MethodGet, "https://domain.gitlab.io", nil)
rr.RemoteAddr = remoteAddr
- handler := rl.SourceIPLimiter(next)
+ handler := rl.Middleware(next)
handler.ServeHTTP(ww, rr)
res := ww.Result()
@@ -83,12 +84,13 @@ func TestSourceIPLimiterDenyRequestsAfterBurst(t *testing.T) {
for tn, tc := range tcs {
t.Run(tn, func(t *testing.T) {
rl := New(
- WithSourceIPCachedEntriesMetric(cachedEntries),
- WithSourceIPCachedRequestsMetric(cacheReqs),
+ "rate_limiter",
+ WithCachedEntriesMetric(cachedEntries),
+ WithCachedRequestsMetric(cacheReqs),
WithBlockedCountMetric(blocked),
WithNow(mockNow),
- WithSourceIPLimitPerSecond(1),
- WithSourceIPBurstSize(1),
+ WithLimitPerSecond(1),
+ WithBurstSize(1),
)
for i := 0; i < 5; i++ {
@@ -103,7 +105,7 @@ func TestSourceIPLimiterDenyRequestsAfterBurst(t *testing.T) {
rr.RemoteAddr = remoteAddr
// middleware is evaluated in reverse order
- handler := rl.SourceIPLimiter(next)
+ handler := rl.Middleware(next)
handler.ServeHTTP(ww, rr)
res := ww.Result()
@@ -126,13 +128,13 @@ func TestSourceIPLimiterDenyRequestsAfterBurst(t *testing.T) {
}
blocked.Reset()
- cachedCount := testutil.ToFloat64(cachedEntries.WithLabelValues("source_ip"))
+ cachedCount := testutil.ToFloat64(cachedEntries.WithLabelValues("rate_limiter"))
require.Equal(t, float64(1), cachedCount, "cached count")
cachedEntries.Reset()
- cacheReqMiss := testutil.ToFloat64(cacheReqs.WithLabelValues("source_ip", "miss"))
+ cacheReqMiss := testutil.ToFloat64(cacheReqs.WithLabelValues("rate_limiter", "miss"))
require.Equal(t, float64(1), cacheReqMiss, "miss count")
- cacheReqHit := testutil.ToFloat64(cacheReqs.WithLabelValues("source_ip", "hit"))
+ cacheReqHit := testutil.ToFloat64(cacheReqs.WithLabelValues("rate_limiter", "hit"))
require.Equal(t, float64(4), cacheReqHit, "hit count")
cacheReqs.Reset()
})
diff --git a/internal/ratelimiter/ratelimiter.go b/internal/ratelimiter/ratelimiter.go
index 37aca020..ab02c966 100644
--- a/internal/ratelimiter/ratelimiter.go
+++ b/internal/ratelimiter/ratelimiter.go
@@ -1,12 +1,14 @@
package ratelimiter
import (
+ "net/http"
"time"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/time/rate"
"gitlab.com/gitlab-org/gitlab-pages/internal/lru"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/request"
)
const (
@@ -26,35 +28,40 @@ const (
// Option function to configure a RateLimiter
type Option func(*RateLimiter)
-// RateLimiter holds an LRU cache of elements to be rate limited. Currently, it supports
-// a sourceIPCache and each item returns a rate.Limiter.
+// KeyFunc returns unique identifier for the subject of rate limit(e.g. client IP or domain)
+type KeyFunc func(*http.Request) string
+
+// RateLimiter holds an LRU cache of elements to be rate limited.
// It uses "golang.org/x/time/rate" as its Token Bucket rate limiter per source IP entry.
// See example https://www.fatalerrors.org/a/design-and-implementation-of-time-rate-limiter-for-golang-standard-library.html
// It also holds a now function that can be mocked in unit tests.
type RateLimiter struct {
- now func() time.Time
- sourceIPLimitPerSecond float64
- sourceIPBurstSize int
- sourceIPBlockedCount *prometheus.GaugeVec
- sourceIPCache *lru.Cache
-
- sourceIPCacheOptions []lru.Option
- // TODO: add domainCache https://gitlab.com/gitlab-org/gitlab-pages/-/issues/630
+ name string
+ now func() time.Time
+ limitPerSecond float64
+ burstSize int
+ blockedCount *prometheus.GaugeVec
+ cache *lru.Cache
+ key KeyFunc
+
+ cacheOptions []lru.Option
}
// New creates a new RateLimiter with default values that can be configured via Option functions
-func New(opts ...Option) *RateLimiter {
+func New(name string, opts ...Option) *RateLimiter {
rl := &RateLimiter{
- now: time.Now,
- sourceIPLimitPerSecond: DefaultSourceIPLimitPerSecond,
- sourceIPBurstSize: DefaultSourceIPBurstSize,
+ name: name,
+ now: time.Now,
+ limitPerSecond: DefaultSourceIPLimitPerSecond,
+ burstSize: DefaultSourceIPBurstSize,
+ key: request.GetRemoteAddrWithoutPort,
}
for _, opt := range opts {
opt(rl)
}
- rl.sourceIPCache = lru.New("source_ip", rl.sourceIPCacheOptions...)
+ rl.cache = lru.New(name, rl.cacheOptions...)
return rl
}
@@ -66,60 +73,61 @@ func WithNow(now func() time.Time) Option {
}
}
-// WithSourceIPLimitPerSecond allows configuring per source IP limit per second for RateLimiter
-func WithSourceIPLimitPerSecond(limit float64) Option {
+// WithLimitPerSecond allows configuring limit per second for RateLimiter
+func WithLimitPerSecond(limit float64) Option {
return func(rl *RateLimiter) {
- rl.sourceIPLimitPerSecond = limit
+ rl.limitPerSecond = limit
}
}
-// WithSourceIPBurstSize configures burst per source IP for the RateLimiter
-func WithSourceIPBurstSize(burst int) Option {
+// WithBurstSize configures burst per key for the RateLimiter
+func WithBurstSize(burst int) Option {
return func(rl *RateLimiter) {
- rl.sourceIPBurstSize = burst
+ rl.burstSize = burst
}
}
-// WithBlockedCountMetric configures metric reporting how many requests were blocked based by IP
+// WithBlockedCountMetric configures metric reporting how many requests were blocked
func WithBlockedCountMetric(m *prometheus.GaugeVec) Option {
return func(rl *RateLimiter) {
- rl.sourceIPBlockedCount = m
+ rl.blockedCount = m
}
}
-// WithSourceIPCacheMaxSize configures cache size for source IP ratelimiter
-func WithSourceIPCacheMaxSize(size int64) Option {
+// WithCacheMaxSize configures cache size for ratelimiter
+func WithCacheMaxSize(size int64) Option {
return func(rl *RateLimiter) {
- rl.sourceIPCacheOptions = append(rl.sourceIPCacheOptions, lru.WithMaxSize(size))
+ rl.cacheOptions = append(rl.cacheOptions, lru.WithMaxSize(size))
}
}
-// WithSourceIPCachedEntriesMetric configures metric reporting how many IPs are currently stored in
-// source-IP rate-limiter cache
-func WithSourceIPCachedEntriesMetric(m *prometheus.GaugeVec) Option {
+// WithCachedEntriesMetric configures metric reporting how many keys are currently stored in
+// the rate-limiter cache
+func WithCachedEntriesMetric(m *prometheus.GaugeVec) Option {
return func(rl *RateLimiter) {
- rl.sourceIPCacheOptions = append(rl.sourceIPCacheOptions, lru.WithCachedEntriesMetric(m))
+ rl.cacheOptions = append(rl.cacheOptions, lru.WithCachedEntriesMetric(m))
}
}
-// WithSourceIPCachedRequestsMetric configures metric for how many times we ask source IP cache
-func WithSourceIPCachedRequestsMetric(m *prometheus.CounterVec) Option {
+// WithCachedRequestsMetric configures metric for how many times we ask key cache
+func WithCachedRequestsMetric(m *prometheus.CounterVec) Option {
return func(rl *RateLimiter) {
- rl.sourceIPCacheOptions = append(rl.sourceIPCacheOptions, lru.WithCachedRequestsMetric(m))
+ rl.cacheOptions = append(rl.cacheOptions, lru.WithCachedRequestsMetric(m))
}
}
-func (rl *RateLimiter) getSourceIPLimiter(sourceIP string) *rate.Limiter {
- limiterI, _ := rl.sourceIPCache.FindOrFetch(sourceIP, sourceIP, func() (interface{}, error) {
- return rate.NewLimiter(rate.Limit(rl.sourceIPLimitPerSecond), rl.sourceIPBurstSize), nil
+func (rl *RateLimiter) limiter(key string) *rate.Limiter {
+ limiterI, _ := rl.cache.FindOrFetch(key, key, func() (interface{}, error) {
+ return rate.NewLimiter(rate.Limit(rl.limitPerSecond), rl.burstSize), nil
})
return limiterI.(*rate.Limiter)
}
-// SourceIPAllowed checks that the real remote IP address is allowed to perform an operation
-func (rl *RateLimiter) SourceIPAllowed(sourceIP string) bool {
- limiter := rl.getSourceIPLimiter(sourceIP)
+// RequestAllowed checks that the real remote IP address is allowed to perform an operation
+func (rl *RateLimiter) RequestAllowed(r *http.Request) bool {
+ rateLimitedKey := rl.key(r)
+ limiter := rl.limiter(rateLimitedKey)
// AllowN allows us to use the rl.now function, so we can test this more easily.
return limiter.AllowN(rl.now(), 1)
diff --git a/internal/ratelimiter/ratelimiter_test.go b/internal/ratelimiter/ratelimiter_test.go
index d6be214a..ce767b35 100644
--- a/internal/ratelimiter/ratelimiter_test.go
+++ b/internal/ratelimiter/ratelimiter_test.go
@@ -1,6 +1,8 @@
package ratelimiter
import (
+ "net/http"
+ "net/http/httptest"
"sync"
"testing"
"time"
@@ -51,17 +53,21 @@ func TestSourceIPAllowed(t *testing.T) {
for tn, tc := range sharedTestCases {
t.Run(tn, func(t *testing.T) {
rl := New(
+ "rate_limiter",
WithNow(mockNow),
- WithSourceIPLimitPerSecond(tc.sourceIPLimit),
- WithSourceIPBurstSize(tc.sourceIPBurstSize),
+ WithLimitPerSecond(tc.sourceIPLimit),
+ WithBurstSize(tc.sourceIPBurstSize),
)
for i := 0; i < tc.reqNum; i++ {
- got := rl.SourceIPAllowed("172.16.123.1")
+ r := httptest.NewRequest(http.MethodGet, "https://domain.gitlab.io", nil)
+ r.RemoteAddr = "172.16.123.1"
+
+ got := rl.RequestAllowed(r)
if i < tc.sourceIPBurstSize {
require.Truef(t, got, "expected true for request no. %d", i)
} else {
- // requests should fail after reaching tc.sourceIPBurstSize because mockNow
+ // requests should fail after reaching tc.burstSize because mockNow
// always returns the same time
require.False(t, got, "expected false for request no. %d", i)
}
@@ -74,20 +80,23 @@ func TestSingleRateLimiterWithMultipleSourceIPs(t *testing.T) {
rate := 10 * time.Millisecond
rl := New(
- WithSourceIPLimitPerSecond(float64(1/rate)),
- WithSourceIPBurstSize(1),
+ "rate_limiter",
+ WithLimitPerSecond(float64(1/rate)),
+ WithBurstSize(1),
)
wg := sync.WaitGroup{}
- testFn := func(domain string) func(t *testing.T) {
+ testFn := func(ip string) func(t *testing.T) {
return func(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 5; i++ {
- got := rl.SourceIPAllowed(domain)
+ r := httptest.NewRequest(http.MethodGet, "https://domain.gitlab.io", nil)
+ r.RemoteAddr = ip
+ got := rl.RequestAllowed(r)
require.Truef(t, got, "expected true for request no. %d", i)
time.Sleep(rate)
}
diff --git a/test/acceptance/ratelimiter_test.go b/test/acceptance/ratelimiter_test.go
index 2986d46b..5b3be98a 100644
--- a/test/acceptance/ratelimiter_test.go
+++ b/test/acceptance/ratelimiter_test.go
@@ -104,7 +104,7 @@ func TestSourceIPRateLimitMiddleware(t *testing.T) {
if tc.expectFail && i >= int(tc.rateLimit) {
require.Equal(t, http.StatusTooManyRequests, rsp.StatusCode, "group.gitlab-example.com request: %d failed", i)
- assertLogFound(t, logBuf, []string{"source IP hit rate limit", "\"source_ip\":\"" + tc.blockedIP + "\""})
+ assertLogFound(t, logBuf, []string{"request hit rate limit", "\"source_ip\":\"" + tc.blockedIP + "\""})
continue
}