Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-pages.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJaime Martinez <jmartinez@gitlab.com>2021-09-29 07:44:23 +0300
committerJaime Martinez <jmartinez@gitlab.com>2021-09-30 06:38:48 +0300
commit3b9adb4067e88417652edc0a08f596d5196a8912 (patch)
tree45a38335a59a08dbb720568edb9b216b2d7a5210
parentdb4dd9522ba7e20713cab59664b85412d5a03cbb (diff)
feat: source IP middleware for rate limiting
-rw-r--r--go.mod1
-rw-r--r--internal/ratelimiter/middleware.go55
-rw-r--r--internal/ratelimiter/middleware_test.go88
3 files changed, 110 insertions, 34 deletions
diff --git a/go.mod b/go.mod
index 0ff516ed..be91cf20 100644
--- a/go.mod
+++ b/go.mod
@@ -16,6 +16,7 @@ require (
github.com/pires/go-proxyproto v0.2.0
github.com/prometheus/client_golang v1.6.0
github.com/rs/cors v1.7.0
+ github.com/sebest/xff v0.0.0-20160910043805-6c115e0ffa35
github.com/sirupsen/logrus v1.7.0
github.com/stretchr/testify v1.6.1
github.com/tj/assert v0.0.3 // indirect
diff --git a/internal/ratelimiter/middleware.go b/internal/ratelimiter/middleware.go
index b3963c6f..39292dbb 100644
--- a/internal/ratelimiter/middleware.go
+++ b/internal/ratelimiter/middleware.go
@@ -5,47 +5,52 @@ import (
"net/http"
"os"
+ "github.com/sebest/xff"
"github.com/sirupsen/logrus"
"gitlab.com/gitlab-org/gitlab-pages/internal/httperrors"
"gitlab.com/gitlab-org/gitlab-pages/internal/logging"
)
-// DomainRateLimiter middleware ensures that the requested domain can be served by the current
+// SourceIPLimiter middleware ensures that the originating
// rate limit. See -rate-limiter
-func DomainRateLimiter(rl *RateLimiter) func(http.Handler) http.Handler {
- return func(handler http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- host := getHost(r)
-
- if !rl.DomainAllowed(host) {
- logging.LogRequest(r).WithFields(logrus.Fields{
- "handler": "domain_rate_limiter",
- "pages_domain": host,
- "rate_limiter_enabled": rateLimiterEnabled(),
- "rate_limiter_frequency": rl.perDomainFrequency,
- "rate_limiter_burst_size": rl.perDomainBurstSize,
- }).Info("domain hit rate limit")
-
- // Only drop requests once FF_ENABLE_RATE_LIMITER is enabled
- if rateLimiterEnabled() {
- httperrors.Serve429(w)
- return
- }
+func (rl *RateLimiter) SourceIPLimiter(handler http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ host, ip, https := getReqDetails(r)
+
+ // http requests do not contain real IP information yet
+ if !rl.SourceIPAllowed(ip) && https {
+ logging.LogRequest(r).WithFields(logrus.Fields{
+ "handler": "source_ip_rate_limiter",
+ "pages_domain": host,
+ "pages_https": https,
+ "source_ip": ip,
+ "rate_limiter_enabled": rateLimiterEnabled(),
+ "rate_limiter_limit_per_second": rl.sourceIPLimitPerSecond,
+ "rate_limiter_burst_size": rl.sourceIPBurstSize,
+ }).Info("source IP hit rate limit")
+
+ // Only drop requests once FF_ENABLE_RATE_LIMITER is enabled
+ if rateLimiterEnabled() {
+ httperrors.Serve429(w)
+ return
}
+ }
- handler.ServeHTTP(w, r)
- })
- }
+ handler.ServeHTTP(w, r)
+ })
}
-func getHost(r *http.Request) string {
+func getReqDetails(r *http.Request) (string, string, bool) {
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
host = r.Host
}
- return host
+ https := r.URL.Scheme == "https"
+ ip := xff.GetRemoteAddr(r)
+
+ return host, ip, https
}
func rateLimiterEnabled() bool {
diff --git a/internal/ratelimiter/middleware_test.go b/internal/ratelimiter/middleware_test.go
index b15698fb..ef9b23b3 100644
--- a/internal/ratelimiter/middleware_test.go
+++ b/internal/ratelimiter/middleware_test.go
@@ -10,10 +10,11 @@ import (
"github.com/stretchr/testify/require"
)
-func TestDomainRateLimiter(t *testing.T) {
- next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusNoContent)
- })
+var next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusNoContent)
+})
+
+func TestSourceIPLimiter(t *testing.T) {
enableRateLimiter(t)
@@ -21,19 +22,21 @@ func TestDomainRateLimiter(t *testing.T) {
t.Run(tn, func(t *testing.T) {
rl := New(
WithNow(mockNow),
- WithPerDomainFrequency(tc.domainRate),
- WithPerDomainBurstSize(tc.perDomainBurstPerSecond),
+ WithSourceIPLimitPerSecond(tc.sourceIPLimit),
+ WithSourceIPBurstSize(tc.sourceIPBurstSize),
)
for i := 0; i < tc.reqNum; i++ {
ww := httptest.NewRecorder()
- rr := httptest.NewRequest(http.MethodGet, "http://domain.gitlab.io", nil)
- handler := DomainRateLimiter(rl)(next)
+ rr := httptest.NewRequest(http.MethodGet, "https://domain.gitlab.io", nil)
+ rr.RemoteAddr = "172.16.123.1"
+
+ handler := rl.SourceIPLimiter(next)
handler.ServeHTTP(ww, rr)
res := ww.Result()
- if i < tc.perDomainBurstPerSecond {
+ if i < tc.sourceIPBurstSize {
require.Equal(t, http.StatusNoContent, res.StatusCode, "req: %d failed", i)
} else {
// requests should fail after reaching tc.perDomainBurstPerSecond because mockNow
@@ -50,6 +53,73 @@ func TestDomainRateLimiter(t *testing.T) {
}
}
+func TestSourceIPRateLimit(t *testing.T) {
+ rl := New(
+ WithNow(mockNow),
+ WithSourceIPLimitPerSecond(1),
+ WithSourceIPBurstSize(1),
+ )
+
+ tcs := map[string]struct {
+ enabled bool
+ ip string
+ host string
+ expectedStatus int
+ }{
+ "disabled_rate_limit_http": {
+ enabled: false,
+ ip: "172.16.123.1",
+ host: "http://gitlab.com",
+ expectedStatus: http.StatusNoContent,
+ },
+ "disabled_rate_limit_https": {
+ enabled: false,
+ ip: "172.16.123.2",
+ host: "https://gitlab.com",
+ expectedStatus: http.StatusNoContent,
+ },
+ "enabled_rate_limit_http_does_not_block": {
+ enabled: true,
+ ip: "172.16.123.3",
+ host: "http://gitlab.com",
+ expectedStatus: http.StatusNoContent,
+ },
+ "enabled_rate_limit_https_blocks": {
+ enabled: true,
+ ip: "172.16.123.4",
+ host: "https://gitlab.com",
+ expectedStatus: http.StatusTooManyRequests,
+ },
+ }
+
+ for tn, tc := range tcs {
+ t.Run(tn, func(t *testing.T) {
+
+ for i := 0; i < 5; i++ {
+ ww := httptest.NewRecorder()
+ rr := httptest.NewRequest(http.MethodGet, tc.host, nil)
+ rr.RemoteAddr = tc.ip
+
+ if tc.enabled {
+ enableRateLimiter(t)
+ }
+
+ handler := rl.SourceIPLimiter(next)
+
+ handler.ServeHTTP(ww, rr)
+ res := ww.Result()
+
+ if i == 0 {
+ require.Equal(t, http.StatusNoContent, res.StatusCode)
+ continue
+ }
+ // burst is 1 and limit is 1 per second, all subsequent requests should fail
+ require.Equal(t, tc.expectedStatus, res.StatusCode)
+ }
+ })
+ }
+}
+
func enableRateLimiter(t *testing.T) {
t.Helper()