diff options
-rw-r--r-- | app.go | 25 | ||||
-rw-r--r-- | app_config.go | 6 | ||||
-rw-r--r-- | go.mod | 1 | ||||
-rw-r--r-- | go.sum | 2 | ||||
-rw-r--r-- | internal/rate_limiting/http_handler.go | 32 | ||||
-rw-r--r-- | internal/rate_limiting/rate_limiting.go | 62 | ||||
-rw-r--r-- | main.go | 9 |
7 files changed, 134 insertions, 3 deletions
@@ -26,8 +26,10 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/logging" "gitlab.com/gitlab-org/gitlab-pages/internal/netutil" + "gitlab.com/gitlab-org/gitlab-pages/internal/rate_limiting" "gitlab.com/gitlab-org/gitlab-pages/internal/request" "gitlab.com/gitlab-org/gitlab-pages/internal/source" + "gitlab.com/gitlab-org/gitlab-pages/internal/tlsconfig" "gitlab.com/gitlab-org/gitlab-pages/metrics" ) @@ -352,6 +354,23 @@ func (a *theApp) Run() { httpHandler := a.httpInitialMiddleware(commonHandlerPipeline) + tlsGetCertificate := a.ServeTLS + + if a.appConfig.HostRateLimit > 0 { + hostRateLimiter := rate_limiting.NewRateLimiting( + a.appConfig.HostRateLimitWindow, a.appConfig.HostRateLimit) + + httpHandler = hostRateLimiter.LimitHostHandler(httpHandler) + proxyHandler = hostRateLimiter.LimitHostHandler(proxyHandler) + } + + if a.appConfig.TLSSNIRateLimit > 0 { + tlsRateLimiter := rate_limiting.NewRateLimiting( + a.appConfig.HostRateLimitWindow, a.appConfig.HostRateLimit) + + tlsGetCertificate = tlsRateLimiter.LimitServeTLS(a.ServeTLS) + } + // Listen for HTTP for _, fd := range a.ListenHTTP { a.listenHTTPFD(&wg, fd, httpHandler, limiter) @@ -359,7 +378,7 @@ func (a *theApp) Run() { // Listen for HTTPS for _, fd := range a.ListenHTTPS { - a.listenHTTPSFD(&wg, fd, httpHandler, limiter) + a.listenHTTPSFD(&wg, fd, httpHandler, tlsGetCertificate, limiter) } // Listen for HTTP proxy requests @@ -388,11 +407,11 @@ func (a *theApp) listenHTTPFD(wg *sync.WaitGroup, fd uintptr, httpHandler http.H }() } -func (a *theApp) listenHTTPSFD(wg *sync.WaitGroup, fd uintptr, httpHandler http.Handler, limiter *netutil.Limiter) { +func (a *theApp) listenHTTPSFD(wg *sync.WaitGroup, fd uintptr, httpHandler http.Handler, tlsGetCertificate tlsconfig.GetCertificateFunc, limiter *netutil.Limiter) { wg.Add(1) go func() { defer wg.Done() - err := listenAndServeTLS(fd, a.RootCertificate, a.RootKey, httpHandler, a.ServeTLS, a.InsecureCiphers, a.TLSMinVersion, a.TLSMaxVersion, a.HTTP2, limiter) + err := listenAndServeTLS(fd, a.RootCertificate, a.RootKey, httpHandler, tlsGetCertificate, a.InsecureCiphers, a.TLSMinVersion, a.TLSMaxVersion, a.HTTP2, limiter) if err != nil { capturingFatal(err, errortracking.WithField("listener", request.SchemeHTTPS)) } diff --git a/app_config.go b/app_config.go index 3bc2197b..96dd8f92 100644 --- a/app_config.go +++ b/app_config.go @@ -10,6 +10,12 @@ type appConfig struct { RootKey []byte MaxConns int + HostRateLimit uint + HostRateLimitWindow time.Duration + + TLSSNIRateLimit uint + TLSSNIRateLimitWindow time.Duration + ListenHTTP []uintptr ListenHTTPS []uintptr ListenProxy []uintptr @@ -34,6 +34,7 @@ require ( golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f golang.org/x/net v0.0.0-20200226121028-0de0cce0169b golang.org/x/sys v0.0.0-20200420163511-1957bb5e6d1f + golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 golang.org/x/tools v0.0.0-20200502202811-ed308ab3e770 gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect honnef.co/go/tools v0.0.1-2020.1.3 // indirect @@ -53,6 +53,7 @@ github.com/client9/reopen v1.0.0 h1:8tpLVR74DLpLObrn2KvsyxJY++2iORGR17WLUdSzUws= github.com/client9/reopen v1.0.0/go.mod h1:caXVCEr+lUtoN1FlsRiOWdfQtdRHIYfcb0ai8qKWtkQ= github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd/go.mod h1:sE/e/2PUdi/liOCUjSTXgM1o87ZssimdTWN964YiIeI= github.com/codegangsta/inject v0.0.0-20150114235600-33e0aa1cb7c0/go.mod h1:4Zcjuz89kmFXt9morQgcfYZAYZ5n8WHjt81YYWIwtTM= +github.com/coreos/etcd v3.3.10+incompatible h1:jFneRYjIvLMLhDLCzuTuU4rSJUjRplcJQ7pD7MnhC04= github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= @@ -448,6 +449,7 @@ golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3 golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 h1:SvFZT6jyqRaOeXpc5h/JSfZenJ2O330aBsf7JfSUXmQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181221001348-537d06c36207/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/internal/rate_limiting/http_handler.go b/internal/rate_limiting/http_handler.go new file mode 100644 index 00000000..d6341781 --- /dev/null +++ b/internal/rate_limiting/http_handler.go @@ -0,0 +1,32 @@ +package rate_limiting + +import ( + "crypto/tls" + "errors" + "net/http" + + "gitlab.com/gitlab-org/gitlab-pages/internal/tlsconfig" +) + +func (r *RateLimiting) LimitHostHandler(handler http.Handler) http.Handler { + fn := func(rw http.ResponseWriter, req *http.Request) { + if r.Allow(req.Host) { + handler.ServeHTTP(rw, req) + return + } + + rw.WriteHeader(http.StatusTooManyRequests) + } + + return http.HandlerFunc(fn) +} + +func (r *RateLimiting) LimitServeTLS(handler tlsconfig.GetCertificateFunc) tlsconfig.GetCertificateFunc { + return func(ch *tls.ClientHelloInfo) (*tls.Certificate, error) { + if r.Allow(ch.ServerName) { + return handler(ch) + } + + return nil, errors.New("rate limited") + } +} diff --git a/internal/rate_limiting/rate_limiting.go b/internal/rate_limiting/rate_limiting.go new file mode 100644 index 00000000..d69dd873 --- /dev/null +++ b/internal/rate_limiting/rate_limiting.go @@ -0,0 +1,62 @@ +package rate_limiting + +import ( + "time" + + "github.com/patrickmn/go-cache" + "golang.org/x/time/rate" +) + +type rateLimit struct { + *rate.Limiter +} + +type RateLimiting struct { + cache *cache.Cache + + window time.Duration + limit uint +} + +func NewRateLimiting(window time.Duration, limit uint) *RateLimiting { + return &RateLimiting{ + cache: cache.New(window*2, window), + window: window, + limit: limit, + } +} + +func (r *RateLimiting) newRateLimiter() rateLimit { + // we divide a window by amount of requests + // the bucket is refilled every interval + // allowing to consume up to the defined `limit` + everyNs := r.window.Nanoseconds() / int64(r.limit) + every := time.Duration(everyNs) + + return rateLimit{ + rate.NewLimiter(rate.Every(every), int(r.limit)), + } +} + +func (r *RateLimiting) findOrCreate(key string) rateLimit { + for { + // try to get existing item + if item, expiry, found := r.cache.GetWithExpiration(key); found { + // extend item window + if time.Until(expiry) > r.window { + r.cache.SetDefault(key, item) + } + + return item.(rateLimit) + } + + // add a new item + if rateLimiter := r.newRateLimiter(); r.cache.Add(key, rateLimiter, cache.DefaultExpiration) == nil { + return rateLimiter + } + } +} + +func (r *RateLimiting) Allow(key string) bool { + return r.findOrCreate(key).Allow() +} @@ -75,6 +75,11 @@ var ( tlsMinVersion = flag.String("tls-min-version", "tls1.2", tlsconfig.FlagUsage("min")) tlsMaxVersion = flag.String("tls-max-version", "", tlsconfig.FlagUsage("max")) + hostRateLimit = flag.Uint("host-rate-limit", 0, "Set to non-zero value to enable host-based rate limiting. Requests over rate-limit will respond with 429.") + hostRateLimitWindow = flag.Duration("host-rate-limit-window", 10*time.Minute, "Define a host-bassed rate limiting window") + tlsSniRateLimit = flag.Uint("tls-sni-rate-limit", 0, "Set to non-zero value to enable tls-sni-based rate limiting. New connections over that limit will be rejected.") + tlsSniRateLimitWindow = flag.Duration("tls-sni-limit-window", 10*time.Minute, "Define a tls-sni-bassed rate limiting window") + disableCrossOriginRequests = flag.Bool("disable-cross-origin-requests", false, "Disable cross-origin requests") // See init() @@ -175,6 +180,10 @@ func configFromFlags() appConfig { config.TLSMinVersion = tlsconfig.AllTLSVersions[*tlsMinVersion] config.TLSMaxVersion = tlsconfig.AllTLSVersions[*tlsMaxVersion] config.CustomHeaders = header + config.HostRateLimit = *hostRateLimit + config.HostRateLimitWindow = *hostRateLimitWindow + config.TLSSNIRateLimit = *tlsSniRateLimit + config.TLSSNIRateLimitWindow = *tlsSniRateLimitWindow for _, file := range []struct { contents *[]byte |