diff options
author | Jaime Martinez <jmartinez@gitlab.com> | 2021-09-22 09:38:25 +0300 |
---|---|---|
committer | Jaime Martinez <jmartinez@gitlab.com> | 2021-09-30 06:38:15 +0300 |
commit | 545099df5a11149df62a46b6af4ebb8e424b1155 (patch) | |
tree | 4bf16e071acc3595181623b06fb9a8ed63066a03 | |
parent | 25eeea495282065e82d7e72c6d8ffd01a6f79602 (diff) |
feat: add ratelimiter middleware
Changelog: added
(cherry picked from commit f2275574d0097692131c6cbcedb9a6ecde251340)
-rw-r--r-- | internal/httperrors/httperrors.go | 13 | ||||
-rw-r--r-- | internal/ratelimiter/middleware.go | 34 | ||||
-rw-r--r-- | internal/ratelimiter/middleware_test.go | 48 | ||||
-rw-r--r-- | internal/ratelimiter/ratelimiter_test.go | 58 |
4 files changed, 124 insertions, 29 deletions
diff --git a/internal/httperrors/httperrors.go b/internal/httperrors/httperrors.go index ed56ee10..8e61d590 100644 --- a/internal/httperrors/httperrors.go +++ b/internal/httperrors/httperrors.go @@ -34,6 +34,14 @@ var ( <p>Make sure the address is correct and that the page hasn't moved.</p> <p>Please contact your GitLab administrator if you think this is a mistake.</p>`, } + + content429 = content{ + http.StatusTooManyRequests, + "Too many requests (429)", + "429", + "Too many requests.", + `<p>The resource that you are attempting to access is being rate limited.</p>`, + } content500 = content{ http.StatusInternalServerError, "Something went wrong (500)", @@ -176,6 +184,11 @@ func Serve404(w http.ResponseWriter) { serveErrorPage(w, content404) } +// Serve429 returns a 429 error response / HTML page to the http.ResponseWriter +func Serve429(w http.ResponseWriter) { + serveErrorPage(w, content429) +} + // Serve500 returns a 500 error response / HTML page to the http.ResponseWriter func Serve500(w http.ResponseWriter) { serveErrorPage(w, content500) diff --git a/internal/ratelimiter/middleware.go b/internal/ratelimiter/middleware.go new file mode 100644 index 00000000..088b9414 --- /dev/null +++ b/internal/ratelimiter/middleware.go @@ -0,0 +1,34 @@ +package ratelimiter + +import ( + "net" + "net/http" + + "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" +) + +// DomainRateLimiter middleware ensures that the requested domain can be served by the current +// rate limit. See -rate-limiter +func DomainRateLimiter(rl *RateLimiter) func(http.Handler) http.Handler { + return func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + host := getHost(r) + + if !rl.DomainAllowed(host) { + httperrors.Serve429(w) + return + } + + handler.ServeHTTP(w, r) + }) + } +} + +func getHost(r *http.Request) string { + host, _, err := net.SplitHostPort(r.Host) + if err != nil { + host = r.Host + } + + return host +} diff --git a/internal/ratelimiter/middleware_test.go b/internal/ratelimiter/middleware_test.go new file mode 100644 index 00000000..b1b74dfd --- /dev/null +++ b/internal/ratelimiter/middleware_test.go @@ -0,0 +1,48 @@ +package ratelimiter + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDomainRateLimiter(t *testing.T) { + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + }) + + for tn, tc := range sharedTestCases { + t.Run(tn, func(t *testing.T) { + rl := New( + WithNow(mockNow), + WithPerDomainFrequency(tc.domainRate), + WithPerDomainBurstSize(tc.perDomainBurstPerSecond), + ) + + for i := 0; i < tc.reqNum; i++ { + ww := httptest.NewRecorder() + rr := httptest.NewRequest(http.MethodGet, "http://domain.gitlab.io", nil) + handler := DomainRateLimiter(rl)(next) + + handler.ServeHTTP(ww, rr) + res := ww.Result() + + if i < tc.perDomainBurstPerSecond { + require.Equal(t, http.StatusNoContent, res.StatusCode, "req: %d failed", i) + } else { + // requests should fail after reaching tc.perDomainBurstPerSecond because mockNow + // always returns the same time + require.Equal(t, http.StatusTooManyRequests, res.StatusCode, "req: %d failed", i) + b, err := io.ReadAll(res.Body) + require.NoError(t, err) + + require.Contains(t, string(b), "Too many requests.") + res.Body.Close() + } + } + }) + } +} diff --git a/internal/ratelimiter/ratelimiter_test.go b/internal/ratelimiter/ratelimiter_test.go index cdf12fe6..74b465d1 100644 --- a/internal/ratelimiter/ratelimiter_test.go +++ b/internal/ratelimiter/ratelimiter_test.go @@ -17,38 +17,38 @@ func mockNow() time.Time { return validTime } +var sharedTestCases = map[string]struct { + now string + sourceIPLimit float64 + sourceIPBurstSize int + reqNum int +}{ + "one_request_per_second": { + sourceIPLimit: 1, + sourceIPBurstSize: 1, + reqNum: 2, + }, + "one_request_per_second_but_big_bucket": { + sourceIPLimit: 1, + sourceIPBurstSize: 10, + reqNum: 11, + }, + "three_req_per_second_bucket_size_one": { + sourceIPLimit: 3, + sourceIPBurstSize: 1, // max burst 1 means 1 at a time + reqNum: 3, + }, + "10_requests_per_second": { + sourceIPLimit: 10, + sourceIPBurstSize: 10, + reqNum: 11, + }, +} + func TestSourceIPAllowed(t *testing.T) { t.Parallel() - tcs := map[string]struct { - now string - sourceIPLimit float64 - sourceIPBurstSize int - reqNum int - }{ - "one_request_per_second": { - sourceIPLimit: 1, - sourceIPBurstSize: 1, - reqNum: 2, - }, - "one_request_per_second_but_big_bucket": { - sourceIPLimit: 1, - sourceIPBurstSize: 10, - reqNum: 11, - }, - "three_req_per_second_bucket_size_one": { - sourceIPLimit: 3, - sourceIPBurstSize: 1, // max burst 1 means 1 at a time - reqNum: 3, - }, - "10_requests_per_second": { - sourceIPLimit: 10, - sourceIPBurstSize: 10, - reqNum: 11, - }, - } - - for tn, tc := range tcs { + for tn, tc := range sharedTestCases { t.Run(tn, func(t *testing.T) { rl := New( WithNow(mockNow), |