diff options
author | Vladimir Shushlin <v.shushlin@gmail.com> | 2022-02-17 13:27:27 +0300 |
---|---|---|
committer | Vladimir Shushlin <v.shushlin@gmail.com> | 2022-02-21 16:44:20 +0300 |
commit | 62a6491652aa6975d9ecf3b9e258766c886d49d4 (patch) | |
tree | 18a2ddf45d3e997dbb8ea6a6c27a7da26d6f88be | |
parent | 92fb5e54ad42ed489c4dd93eec69fb5876d11efe (diff) |
feat: Add TLS rate limits
Changelog: added
-rw-r--r-- | app.go | 61 | ||||
-rw-r--r-- | internal/config/config.go | 12 | ||||
-rw-r--r-- | internal/config/flags.go | 28 | ||||
-rw-r--r-- | internal/feature/feature.go | 16 | ||||
-rw-r--r-- | internal/handlers/ratelimiter.go | 16 | ||||
-rw-r--r-- | internal/ratelimiter/middleware.go | 5 | ||||
-rw-r--r-- | internal/ratelimiter/middleware_test.go | 4 | ||||
-rw-r--r-- | internal/ratelimiter/ratelimiter.go | 31 | ||||
-rw-r--r-- | internal/ratelimiter/stubconn.go | 45 | ||||
-rw-r--r-- | internal/ratelimiter/tls.go | 49 | ||||
-rw-r--r-- | internal/ratelimiter/tls_test.go | 194 | ||||
-rw-r--r-- | metrics/metrics.go | 57 | ||||
-rw-r--r-- | test/acceptance/helpers_test.go | 12 | ||||
-rw-r--r-- | test/acceptance/ratelimiter_test.go | 157 |
14 files changed, 608 insertions, 79 deletions
@@ -30,10 +30,12 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls" "gitlab.com/gitlab-org/gitlab-pages/internal/customheaders" "gitlab.com/gitlab-org/gitlab-pages/internal/domain" + "gitlab.com/gitlab-org/gitlab-pages/internal/feature" "gitlab.com/gitlab-org/gitlab-pages/internal/handlers" "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/logging" "gitlab.com/gitlab-org/gitlab-pages/internal/netutil" + "gitlab.com/gitlab-org/gitlab-pages/internal/ratelimiter" "gitlab.com/gitlab-org/gitlab-pages/internal/rejectmethods" "gitlab.com/gitlab-org/gitlab-pages/internal/request" "gitlab.com/gitlab-org/gitlab-pages/internal/routing" @@ -51,6 +53,7 @@ var ( type theApp struct { config *cfg.Config source source.Source + tlsConfig *cryptotls.Config Artifact *artifact.Artifact Auth *auth.Auth Handlers *handlers.Handlers @@ -62,19 +65,62 @@ func (a *theApp) isReady() bool { return true } -func (a *theApp) ServeTLS(ch *cryptotls.ClientHelloInfo) (*cryptotls.Certificate, error) { +func (a *theApp) GetCertificate(ch *cryptotls.ClientHelloInfo) (*cryptotls.Certificate, error) { + log.Info("GetCertificate called") if ch.ServerName == "" { return nil, nil } if domain, _ := a.domain(context.Background(), ch.ServerName); domain != nil { - tls, _ := domain.EnsureCertificate() - return tls, nil + certificate, _ := domain.EnsureCertificate() + return certificate, nil } return nil, nil } +// TODO: find a better place than app.go for all the TLS logic https://gitlab.com/gitlab-org/gitlab-pages/-/issues/707 +// right now we have config/tls, but I think does more than config +// related logic should +func (a *theApp) getTLSConfig() (*cryptotls.Config, error) { + if a.tlsConfig != nil { + return a.tlsConfig, nil + } + TLSDomainRateLimiter := ratelimiter.New( + "tls_connections_by_domain", + ratelimiter.WithTLSKeyFunc(ratelimiter.TLSHostnameKey), + ratelimiter.WithCacheMaxSize(ratelimiter.DefaultDomainCacheSize), + ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries), + ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests), + ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount), + ratelimiter.WithLimitPerSecond(a.config.RateLimit.TLSDomainLimitPerSecond), + ratelimiter.WithBurstSize(a.config.RateLimit.TLSDomainBurst), + ratelimiter.WithEnforce(feature.EnforceDomainTLSRateLimits.Enabled()), + ) + + TLSSourceIPRateLimiter := ratelimiter.New( + "tls_connections_by_source_ip", + ratelimiter.WithTLSKeyFunc(ratelimiter.TLSClientIPKey), + ratelimiter.WithCacheMaxSize(ratelimiter.DefaultSourceIPCacheSize), + ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries), + ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests), + ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount), + ratelimiter.WithLimitPerSecond(a.config.RateLimit.TLSSourceIPLimitPerSecond), + ratelimiter.WithBurstSize(a.config.RateLimit.TLSSourceIPBurst), + ratelimiter.WithEnforce(feature.EnforceIPTLSRateLimits.Enabled()), + ) + + getCertificate := TLSDomainRateLimiter.GetCertificateMiddleware(a.GetCertificate) + getCertificate = TLSSourceIPRateLimiter.GetCertificateMiddleware(getCertificate) + + tlsConfig, err := tls.Create(a.config.General.RootCertificate, a.config.General.RootKey, getCertificate, + a.config.General.InsecureCiphers, a.config.TLS.MinVersion, a.config.TLS.MaxVersion) + + a.tlsConfig = tlsConfig + + return a.tlsConfig, err +} + func (a *theApp) redirectToHTTPS(w http.ResponseWriter, r *http.Request, statusCode int) { u := *r.URL u.Scheme = request.SchemeHTTPS @@ -306,7 +352,7 @@ func (a *theApp) Run() { // Listen for HTTPS for _, addr := range a.config.ListenHTTPSStrings.Split() { - tlsConfig, err := a.TLSConfig() + tlsConfig, err := a.getTLSConfig() if err != nil { log.WithError(err).Fatal("Unable to retrieve tls config") } @@ -334,7 +380,7 @@ func (a *theApp) Run() { // Listen for HTTPS PROXYv2 requests for _, addr := range a.config.ListenHTTPSProxyv2Strings.Split() { - tlsConfig, err := a.TLSConfig() + tlsConfig, err := a.getTLSConfig() if err != nil { log.WithError(err).Fatal("Unable to retrieve tls config") } @@ -478,11 +524,6 @@ func fatal(err error, message string) { log.WithError(err).Fatal(message) } -func (a *theApp) TLSConfig() (*cryptotls.Config, error) { - return tls.Create(a.config.General.RootCertificate, a.config.General.RootKey, a.ServeTLS, - a.config.General.InsecureCiphers, a.config.TLS.MinVersion, a.config.TLS.MaxVersion) -} - // handlePanicMiddleware logs and captures the recover() information from any panic func handlePanicMiddleware(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/internal/config/config.go b/internal/config/config.go index 3bb7b126..3dd4ecb3 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -58,10 +58,17 @@ type General struct { // RateLimit config struct type RateLimit struct { + // HTTP limits SourceIPLimitPerSecond float64 SourceIPBurst int DomainLimitPerSecond float64 DomainBurst int + + // TLS connections limits + TLSSourceIPLimitPerSecond float64 + TLSSourceIPBurst int + TLSDomainLimitPerSecond float64 + TLSDomainBurst int } // ArtifactsServer groups settings related to configuring Artifacts @@ -183,6 +190,11 @@ func loadConfig() (*Config, error) { SourceIPBurst: *rateLimitSourceIPBurst, DomainLimitPerSecond: *rateLimitDomain, DomainBurst: *rateLimitDomainBurst, + + TLSSourceIPLimitPerSecond: *rateLimitTLSSourceIP, + TLSSourceIPBurst: *rateLimitTLSSourceIPBurst, + TLSDomainLimitPerSecond: *rateLimitTLSDomain, + TLSDomainBurst: *rateLimitTLSDomainBurst, }, GitLab: GitLab{ ClientHTTPTimeout: *gitlabClientHTTPTimeout, diff --git a/internal/config/flags.go b/internal/config/flags.go index 93228827..3778a677 100644 --- a/internal/config/flags.go +++ b/internal/config/flags.go @@ -9,16 +9,24 @@ import ( ) var ( - pagesRootCert = flag.String("root-cert", "", "The default path to file certificate to serve static pages") - pagesRootKey = flag.String("root-key", "", "The default path to file certificate to serve static pages") - redirectHTTP = flag.Bool("redirect-http", false, "Redirect pages from HTTP to HTTPS") - _ = flag.Bool("use-http2", true, "DEPRECATED: HTTP2 is always enabled for pages") - pagesRoot = flag.String("pages-root", "shared/pages", "The directory where pages are stored") - pagesDomain = flag.String("pages-domain", "gitlab-example.com", "The domain to serve static pages") - rateLimitSourceIP = flag.Float64("rate-limit-source-ip", 0.0, "Rate limit per source IP in number of requests per second, 0 means is disabled") - rateLimitSourceIPBurst = flag.Int("rate-limit-source-ip-burst", 100, "Rate limit per source IP maximum burst allowed per second") - rateLimitDomain = flag.Float64("rate-limit-domain", 0.0, "Rate limit per domain in number of requests per second, 0 means is disabled") - rateLimitDomainBurst = flag.Int("rate-limit-domain-burst", 100, "Rate limit per domain maximum burst allowed per second") + pagesRootCert = flag.String("root-cert", "", "The default path to file certificate to serve static pages") + pagesRootKey = flag.String("root-key", "", "The default path to file certificate to serve static pages") + redirectHTTP = flag.Bool("redirect-http", false, "Redirect pages from HTTP to HTTPS") + _ = flag.Bool("use-http2", true, "DEPRECATED: HTTP2 is always enabled for pages") + pagesRoot = flag.String("pages-root", "shared/pages", "The directory where pages are stored") + pagesDomain = flag.String("pages-domain", "gitlab-example.com", "The domain to serve static pages") + + // HTTP rate limits + rateLimitSourceIP = flag.Float64("rate-limit-source-ip", 0.0, "Rate limit HTTP requests per second from a single IP, 0 means is disabled") + rateLimitSourceIPBurst = flag.Int("rate-limit-source-ip-burst", 100, "Rate limit HTTP requests from a single IP, maximum burst allowed per second") + rateLimitDomain = flag.Float64("rate-limit-domain", 0.0, "Rate limit HTTP requests per second to a single domain, 0 means is disabled") + rateLimitDomainBurst = flag.Int("rate-limit-domain-burst", 100, "Rate limit HTTP requests to a single domain, maximum burst allowed per second") + // TLS connections rate limits + rateLimitTLSSourceIP = flag.Float64("rate-limit-tls-source-ip", 0.0, "Rate limit new TLS connections per second from a single IP, 0 means is disabled") + rateLimitTLSSourceIPBurst = flag.Int("rate-limit-tls-source-ip-burst", 100, "Rate limit new TLS connections from a single IP, maximum burst allowed per second") + rateLimitTLSDomain = flag.Float64("rate-limit-tls-domain", 0.0, "Rate limit new TLS connections per second from to a single domain, 0 means is disabled") + rateLimitTLSDomainBurst = flag.Int("rate-limit-tls-domain-burst", 100, "Rate limit new TLS connections from a single domain, maximum burst allowed per second") + artifactsServer = flag.String("artifacts-server", "", "API URL to proxy artifact requests to, e.g.: 'https://gitlab.com/api/v4'") artifactsServerTimeout = flag.Int("artifacts-server-timeout", 10, "Timeout (in seconds) for a proxied request to the artifacts server") pagesStatus = flag.String("pages-status", "", "The url path for a status page, e.g., /@status") diff --git a/internal/feature/feature.go b/internal/feature/feature.go index 81eef9a0..c98fe85d 100644 --- a/internal/feature/feature.go +++ b/internal/feature/feature.go @@ -8,17 +8,29 @@ type Feature struct { } // EnforceIPRateLimits enforces IP rate limiter to drop requests -// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629 +// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/706 var EnforceIPRateLimits = Feature{ EnvVariable: "FF_ENFORCE_IP_RATE_LIMITS", } // EnforceDomainRateLimits enforces domain rate limiter to drop requests -// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/655 +// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/706 var EnforceDomainRateLimits = Feature{ EnvVariable: "FF_ENFORCE_DOMAIN_RATE_LIMITS", } +// EnforceDomainTLSRateLimits enforces domain rate limits on establishing new TLS connections +// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/706 +var EnforceDomainTLSRateLimits = Feature{ + EnvVariable: "FF_ENFORCE_DOMAIN_TLS_RATE_LIMITS", +} + +// EnforceIPTLSRateLimits enforces domain rate limits on establishing new TLS connections +// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/706 +var EnforceIPTLSRateLimits = Feature{ + EnvVariable: "FF_ENFORCE_IP_TLS_RATE_LIMITS", +} + // RedirectsPlaceholders enables support for placeholders in redirects file // TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/620 var RedirectsPlaceholders = Feature{ diff --git a/internal/handlers/ratelimiter.go b/internal/handlers/ratelimiter.go index 52281f6e..a8eee005 100644 --- a/internal/handlers/ratelimiter.go +++ b/internal/handlers/ratelimiter.go @@ -14,11 +14,11 @@ import ( // TODO: make this unexported once https://gitlab.com/gitlab-org/gitlab-pages/-/issues/670 is done func Ratelimiter(handler http.Handler, config *config.RateLimit) http.Handler { sourceIPLimiter := ratelimiter.New( - "source_ip", + "http_requests_by_source_ip", ratelimiter.WithCacheMaxSize(ratelimiter.DefaultSourceIPCacheSize), - ratelimiter.WithCachedEntriesMetric(metrics.RateLimitSourceIPCachedEntries), - ratelimiter.WithCachedRequestsMetric(metrics.RateLimitSourceIPCacheRequests), - ratelimiter.WithBlockedCountMetric(metrics.RateLimitSourceIPBlockedCount), + ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries), + ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests), + ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount), ratelimiter.WithLimitPerSecond(config.SourceIPLimitPerSecond), ratelimiter.WithBurstSize(config.SourceIPBurst), ratelimiter.WithEnforce(feature.EnforceIPRateLimits.Enabled()), @@ -27,12 +27,12 @@ func Ratelimiter(handler http.Handler, config *config.RateLimit) http.Handler { handler = sourceIPLimiter.Middleware(handler) domainLimiter := ratelimiter.New( - "domain", + "http_requests_by_domain", ratelimiter.WithCacheMaxSize(ratelimiter.DefaultDomainCacheSize), ratelimiter.WithKeyFunc(request.GetHostWithoutPort), - ratelimiter.WithCachedEntriesMetric(metrics.RateLimitDomainCachedEntries), - ratelimiter.WithCachedRequestsMetric(metrics.RateLimitDomainCacheRequests), - ratelimiter.WithBlockedCountMetric(metrics.RateLimitDomainBlockedCount), + ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries), + ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests), + ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount), ratelimiter.WithLimitPerSecond(config.DomainLimitPerSecond), ratelimiter.WithBurstSize(config.DomainBurst), ratelimiter.WithEnforce(feature.EnforceDomainRateLimits.Enabled()), diff --git a/internal/ratelimiter/middleware.go b/internal/ratelimiter/middleware.go index af7b0881..1be6f642 100644 --- a/internal/ratelimiter/middleware.go +++ b/internal/ratelimiter/middleware.go @@ -8,7 +8,6 @@ import ( "gitlab.com/gitlab-org/labkit/correlation" "gitlab.com/gitlab-org/labkit/log" - "gitlab.com/gitlab-org/gitlab-pages/internal/feature" "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/request" ) @@ -34,7 +33,7 @@ func (rl *RateLimiter) Middleware(handler http.Handler) http.Handler { rl.logRateLimitedRequest(r) if rl.blockedCount != nil { - rl.blockedCount.WithLabelValues(strconv.FormatBool(feature.EnforceIPRateLimits.Enabled())).Inc() + rl.blockedCount.WithLabelValues(rl.name, strconv.FormatBool(rl.enforce)).Inc() } if rl.enforce { @@ -59,7 +58,7 @@ func (rl *RateLimiter) logRateLimitedRequest(r *http.Request) { "x_forwarded_proto": r.Header.Get(headerXForwardedProto), "x_forwarded_for": r.Header.Get(headerXForwardedFor), "gitlab_real_ip": r.Header.Get(headerGitLabRealIP), - "rate_limiter_enabled": feature.EnforceIPRateLimits.Enabled(), + "enforced": rl.enforce, "rate_limiter_limit_per_second": rl.limitPerSecond, "rate_limiter_burst_size": rl.burstSize, }). // TODO: change to Debug with https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629 diff --git a/internal/ratelimiter/middleware_test.go b/internal/ratelimiter/middleware_test.go index 1f753fc4..25ac08b5 100644 --- a/internal/ratelimiter/middleware_test.go +++ b/internal/ratelimiter/middleware_test.go @@ -107,7 +107,7 @@ func TestMiddlewareDenyRequestsAfterBurst(t *testing.T) { assertSourceIPLog(t, remoteAddr, hook) } - blockedCount := testutil.ToFloat64(blocked.WithLabelValues(strconv.FormatBool(tc.enforce))) + blockedCount := testutil.ToFloat64(blocked.WithLabelValues("rate_limiter", strconv.FormatBool(tc.enforce))) require.Equal(t, float64(4), blockedCount, "blocked count") blocked.Reset() @@ -226,7 +226,7 @@ func newTestMetrics(t *testing.T) (*prometheus.GaugeVec, *prometheus.GaugeVec, * prometheus.GaugeOpts{ Name: t.Name(), }, - []string{"enforced"}, + []string{"limit_name", "enforced"}, ) cachedEntries := prometheus.NewGaugeVec(prometheus.GaugeOpts{ diff --git a/internal/ratelimiter/ratelimiter.go b/internal/ratelimiter/ratelimiter.go index feeb8cb4..b64cb4ce 100644 --- a/internal/ratelimiter/ratelimiter.go +++ b/internal/ratelimiter/ratelimiter.go @@ -1,6 +1,8 @@ package ratelimiter import ( + "crypto/tls" + "net" "net/http" "time" @@ -27,6 +29,9 @@ type Option func(*RateLimiter) // KeyFunc returns unique identifier for the subject of rate limit(e.g. client IP or domain) type KeyFunc func(*http.Request) string +// TLSKeyFunc is used by GetCertificateMiddleware to identify the subject of rate limit (client IP or SNI servername) +type TLSKeyFunc func(*tls.ClientHelloInfo) string + // RateLimiter holds an LRU cache of elements to be rate limited. // It uses "golang.org/x/time/rate" as its Token Bucket rate limiter per source IP entry. // See example https://www.fatalerrors.org/a/design-and-implementation-of-time-rate-limiter-for-golang-standard-library.html @@ -35,6 +40,7 @@ type RateLimiter struct { name string now func() time.Time keyFunc KeyFunc + tlsKeyFunc TLSKeyFunc limitPerSecond float64 burstSize int blockedCount *prometheus.GaugeVec @@ -120,6 +126,26 @@ func WithKeyFunc(f KeyFunc) Option { } } +func TLSHostnameKey(info *tls.ClientHelloInfo) string { + return info.ServerName +} + +func TLSClientIPKey(info *tls.ClientHelloInfo) string { + remoteAddr := info.Conn.RemoteAddr().String() + remoteAddr, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + return remoteAddr + } + + return remoteAddr +} + +func WithTLSKeyFunc(keyFunc TLSKeyFunc) Option { + return func(rl *RateLimiter) { + rl.tlsKeyFunc = keyFunc + } +} + // WithEnforce configures if requests are actually rejected, or we just report them as rejected in metrics func WithEnforce(enforce bool) Option { return func(rl *RateLimiter) { @@ -138,6 +164,11 @@ func (rl *RateLimiter) limiter(key string) *rate.Limiter { // requestAllowed checks if request is within the rate-limit func (rl *RateLimiter) requestAllowed(r *http.Request) bool { rateLimitedKey := rl.keyFunc(r) + + return rl.allowed(rateLimitedKey) +} + +func (rl *RateLimiter) allowed(rateLimitedKey string) bool { limiter := rl.limiter(rateLimitedKey) // AllowN allows us to use the rl.now function, so we can test this more easily. diff --git a/internal/ratelimiter/stubconn.go b/internal/ratelimiter/stubconn.go new file mode 100644 index 00000000..b351e4bd --- /dev/null +++ b/internal/ratelimiter/stubconn.go @@ -0,0 +1,45 @@ +package ratelimiter + +import ( + "net" + "time" +) + +type stubConn struct { + remoteAddr net.Addr +} + +func (s stubConn) Read(b []byte) (n int, err error) { + return 0, nil +} + +func (s stubConn) Write(b []byte) (n int, err error) { + return 0, nil +} + +func (s stubConn) Close() error { + return nil +} + +func (s stubConn) LocalAddr() net.Addr { + return &net.IPAddr{ + IP: net.IPv4(10, 10, 10, 10), + Zone: "", + } +} + +func (s stubConn) RemoteAddr() net.Addr { + return s.remoteAddr +} + +func (s stubConn) SetDeadline(t time.Time) error { + return nil +} + +func (s stubConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (s stubConn) SetWriteDeadline(t time.Time) error { + return nil +} diff --git a/internal/ratelimiter/tls.go b/internal/ratelimiter/tls.go new file mode 100644 index 00000000..15b10cb4 --- /dev/null +++ b/internal/ratelimiter/tls.go @@ -0,0 +1,49 @@ +package ratelimiter + +import ( + "crypto/tls" + "errors" + "strconv" + + "github.com/sirupsen/logrus" + "gitlab.com/gitlab-org/labkit/log" + + tlsconfig "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls" +) + +var ErrTLSRateLimited = errors.New("too many connections, please retry later") + +func (rl *RateLimiter) GetCertificateMiddleware(getCertificate tlsconfig.GetCertificateFunc) tlsconfig.GetCertificateFunc { + if rl.limitPerSecond <= 0.0 { + return getCertificate + } + + return func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + if rl.allowed(rl.tlsKeyFunc(info)) { + return getCertificate(info) + } + + rl.logRateLimitedTLS(info) + + if rl.blockedCount != nil { + rl.blockedCount.WithLabelValues(rl.name, strconv.FormatBool(rl.enforce)).Inc() + } + + if !rl.enforce { + return getCertificate(info) + } + + return nil, ErrTLSRateLimited + } +} + +func (rl *RateLimiter) logRateLimitedTLS(info *tls.ClientHelloInfo) { + log.WithFields(logrus.Fields{ + "rate_limiter_name": rl.name, + "source_ip": TLSClientIPKey(info), + "req_host": info.ServerName, + "rate_limiter_limit_per_second": rl.limitPerSecond, + "rate_limiter_burst_size": rl.burstSize, + "enforced": rl.enforce, + }).Info("TLS connection rate-limited") +} diff --git a/internal/ratelimiter/tls_test.go b/internal/ratelimiter/tls_test.go new file mode 100644 index 00000000..9b13192b --- /dev/null +++ b/internal/ratelimiter/tls_test.go @@ -0,0 +1,194 @@ +package ratelimiter + +import ( + "crypto/tls" + "errors" + "net" + "strconv" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/sirupsen/logrus" + testlog "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/require" +) + +func TestTLSHostnameKey(t *testing.T) { + info := &tls.ClientHelloInfo{ServerName: "group.gitlab.io"} + + require.Equal(t, "group.gitlab.io", TLSHostnameKey(info)) +} + +func TestTLSClientIPKey(t *testing.T) { + tests := []struct { + addr string + expected string + }{ + { + "10.1.2.3:1234", + "10.1.2.3", + }, + { + "[2001:db8:3333:4444:5555:6666:7777:8888]:1234", + "2001:db8:3333:4444:5555:6666:7777:8888", + }, + } + + for _, tt := range tests { + addr, err := net.ResolveTCPAddr("tcp", tt.addr) + require.NoError(t, err) + conn := stubConn{addr} + info := &tls.ClientHelloInfo{Conn: conn} + + require.Equal(t, tt.expected, TLSClientIPKey(info)) + } +} + +func TestGetCertificateMiddleware(t *testing.T) { + tests := map[string]struct { + useHostnameAsKey bool + enforced bool + limitPerSecond float64 + burst int + successfulReqCnt int + }{ + "ip_limiter": { + useHostnameAsKey: false, + enforced: true, + limitPerSecond: 0.1, + burst: 5, + successfulReqCnt: 5, + }, + "hostname_limiter": { + useHostnameAsKey: true, + enforced: true, + limitPerSecond: 0.1, + burst: 5, + successfulReqCnt: 5, + }, + "not_enforced": { + useHostnameAsKey: false, + enforced: false, + limitPerSecond: 0.1, + burst: 5, + successfulReqCnt: 5, + }, + "disabled": { + useHostnameAsKey: false, + enforced: true, + limitPerSecond: 0, + burst: 5, + successfulReqCnt: 10, + }, + "slowly_approach_limit": { + useHostnameAsKey: false, + enforced: true, + limitPerSecond: 0.2, + burst: 5, + successfulReqCnt: 6, // 5 * 0.2 gives another 1 request + }, + } + + expectedCert := &tls.Certificate{} + expectedErr := errors.New("expected error") + + getCertificate := func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return expectedCert, expectedErr + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + hook := testlog.NewGlobal() + blocked, cachedEntries, cacheReqs := newTestMetrics(t) + + keyFunc := TLSClientIPKey + if tt.useHostnameAsKey { + keyFunc = TLSHostnameKey + } + + rl := New("limit_name", + WithCachedEntriesMetric(cachedEntries), + WithCachedRequestsMetric(cacheReqs), + WithBlockedCountMetric(blocked), + WithNow(stubNow()), + WithLimitPerSecond(tt.limitPerSecond), + WithBurstSize(tt.burst), + WithEnforce(tt.enforced), + WithTLSKeyFunc(keyFunc)) + + middlewareGetCert := rl.GetCertificateMiddleware(getCertificate) + + addr, err := net.ResolveTCPAddr("tcp", "10.1.2.3:12345") + require.NoError(t, err) + conn := stubConn{addr} + info := &tls.ClientHelloInfo{Conn: conn, ServerName: "group.gitlab.io"} + + for i := 0; i < tt.successfulReqCnt; i++ { + cert, err := middlewareGetCert(info) + require.Equal(t, expectedCert, cert) + require.Equal(t, expectedErr, err) + } + + // When rate-limiter disabled altogether + if tt.limitPerSecond <= 0 { + return + } + + cert, err := middlewareGetCert(info) + if tt.enforced { + require.Nil(t, cert) + require.Equal(t, err, ErrTLSRateLimited) + } else { + require.Equal(t, expectedCert, cert) + require.Equal(t, expectedErr, err) + } + + require.NotNil(t, hook.LastEntry()) + require.Equal(t, "TLS connection rate-limited", hook.LastEntry().Message) + expectedFields := logrus.Fields{ + "rate_limiter_name": "limit_name", + "source_ip": "10.1.2.3", + "req_host": "group.gitlab.io", + "rate_limiter_limit_per_second": tt.limitPerSecond, + "rate_limiter_burst_size": tt.burst, + "enforced": tt.enforced, + } + require.Equal(t, expectedFields, hook.LastEntry().Data) + + // make another request with different key and expect success + if tt.useHostnameAsKey { + info.ServerName = "another-group.gitlab.io" + } else { + addr, err := net.ResolveTCPAddr("tcp", "10.10.20.30:12345") + require.NoError(t, err) + conn = stubConn{addr} + info.Conn = conn + } + + cert, err = middlewareGetCert(info) + require.Equal(t, expectedCert, cert) + require.Equal(t, expectedErr, err) + + blockedCount := testutil.ToFloat64(blocked.WithLabelValues("limit_name", strconv.FormatBool(tt.enforced))) + require.Equal(t, float64(1), blockedCount, "blocked count") + + cachedCount := testutil.ToFloat64(cachedEntries.WithLabelValues("limit_name")) + require.Equal(t, float64(2), cachedCount, "cached count") // 1 for first key + 1 for different one + + cacheReqMiss := testutil.ToFloat64(cacheReqs.WithLabelValues("limit_name", "miss")) + require.Equal(t, float64(2), cacheReqMiss, "miss count") // 1 for first key + 1 for different one + cacheReqHit := testutil.ToFloat64(cacheReqs.WithLabelValues("limit_name", "hit")) + require.Equal(t, float64(tt.successfulReqCnt), cacheReqHit, "hit count") + }) + } +} + +func stubNow() func() time.Time { + now := time.Now() + return func() time.Time { + now = now.Add(time.Second) + + return now + } +} diff --git a/metrics/metrics.go b/metrics/metrics.go index e0e4ab29..26c7413f 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -193,60 +193,31 @@ var ( }, ) - // RateLimitSourceIPCacheRequests is the number of cache hits/misses - RateLimitSourceIPCacheRequests = prometheus.NewCounterVec( + // RateLimitCacheRequests is the number of cache hits/misses + RateLimitCacheRequests = prometheus.NewCounterVec( prometheus.CounterOpts{ - Name: "gitlab_pages_rate_limit_source_ip_cache_requests", + Name: "gitlab_pages_rate_limit_cache_requests", Help: "The number of source_ip cache hits/misses in the rate limiter", }, []string{"op", "cache"}, ) - // RateLimitSourceIPCachedEntries is the number of entries in the cache - RateLimitSourceIPCachedEntries = prometheus.NewGaugeVec( + // RateLimitCachedEntries is the number of entries in the cache + RateLimitCachedEntries = prometheus.NewGaugeVec( prometheus.GaugeOpts{ - Name: "gitlab_pages_rate_limit_source_ip_cached_entries", + Name: "gitlab_pages_rate_limit_cached_entries", Help: "The number of entries in the cache", }, []string{"op"}, ) - // RateLimitSourceIPBlockedCount is the number of requests that have been blocked by the - // source IP rate limiter - RateLimitSourceIPBlockedCount = prometheus.NewGaugeVec( + // RateLimitBlockedCount is the number of requests that have been blocked + RateLimitBlockedCount = prometheus.NewGaugeVec( prometheus.GaugeOpts{ - Name: "gitlab_pages_rate_limit_source_ip_blocked_count", - Help: "The number of requests that have been blocked by the IP rate limiter", + Name: "gitlab_pages_rate_limit_blocked_count", + Help: "The number of requests/connections that have been blocked by rate limiter", }, - []string{"enforced"}, - ) - - // RateLimitDomainCacheRequests is the number of cache hits/misses - RateLimitDomainCacheRequests = prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: "gitlab_pages_rate_limit_domain_cache_requests", - Help: "The number of source_ip cache hits/misses in the rate limiter", - }, - []string{"op", "cache"}, - ) - - // RateLimitDomainCachedEntries is the number of entries in the cache - RateLimitDomainCachedEntries = prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Name: "gitlab_pages_rate_limit_domain_cached_entries", - Help: "The number of entries in the cache", - }, - []string{"op"}, - ) - - // RateLimitDomainBlockedCount is the number of requests that have been blocked by the - // domain rate limiter - RateLimitDomainBlockedCount = prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Name: "gitlab_pages_rate_limit_domain_blocked_count", - Help: "The number of requests addresses that have been blocked by the domain rate limiter", - }, - []string{"enforced"}, + []string{"limit_name", "enforced"}, ) ) @@ -276,8 +247,8 @@ func MustRegister() { LimitListenerConcurrentConns, LimitListenerWaitingConns, PanicRecoveredCount, - RateLimitSourceIPCacheRequests, - RateLimitSourceIPCachedEntries, - RateLimitSourceIPBlockedCount, + RateLimitCacheRequests, + RateLimitCachedEntries, + RateLimitBlockedCount, ) } diff --git a/test/acceptance/helpers_test.go b/test/acceptance/helpers_test.go index c44058ba..1b514a85 100644 --- a/test/acceptance/helpers_test.go +++ b/test/acceptance/helpers_test.go @@ -602,3 +602,15 @@ func copyFile(dest, src string) error { _, err = io.Copy(destFile, srcFile) return err } + +// RequireMetricEqual requests prometheus metrics and makes sure metric is there +func RequireMetricEqual(t *testing.T, metricsAddress, metricWithValue string) { + resp, err := http.Get(fmt.Sprintf("http://%s/metrics", metricsAddress)) + require.NoError(t, err) + + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + require.Contains(t, string(body), metricWithValue) +} diff --git a/test/acceptance/ratelimiter_test.go b/test/acceptance/ratelimiter_test.go index 02ba54f5..365ba4cd 100644 --- a/test/acceptance/ratelimiter_test.go +++ b/test/acceptance/ratelimiter_test.go @@ -3,6 +3,7 @@ package acceptance_test import ( "fmt" "net/http" + "strconv" "testing" "time" @@ -77,7 +78,7 @@ func TestIPRateLimits(t *testing.T) { } } -func TestDomainateLimits(t *testing.T) { +func TestDomainRateLimits(t *testing.T) { testhelpers.StubFeatureFlagValue(t, feature.EnforceDomainRateLimits.EnvVariable, true) for name, tc := range ratelimitedListeners { @@ -112,6 +113,160 @@ func TestDomainateLimits(t *testing.T) { } } +func TestTLSRateLimits(t *testing.T) { + rateLimit := 5 + + tests := map[string]struct { + spec ListenSpec + options []processOption + sourceIP string + featureName string + enforceEnabled bool + limitName string + }{ + "https_with_domain_limit": { + spec: httpsListener, + options: []processOption{ + withExtraArgument("metrics-address", ":42345"), + withExtraArgument("rate-limit-tls-domain", fmt.Sprint(rateLimit)), + withExtraArgument("rate-limit-tls-domain-burst", fmt.Sprint(rateLimit)), + }, + sourceIP: "127.0.0.1", + featureName: feature.EnforceDomainTLSRateLimits.EnvVariable, + enforceEnabled: true, + limitName: "tls_connections_by_domain", + }, + "https_with_domain_limit_not_enforced": { + spec: httpsListener, + options: []processOption{ + withExtraArgument("metrics-address", ":42345"), + withExtraArgument("rate-limit-tls-domain", fmt.Sprint(rateLimit)), + withExtraArgument("rate-limit-tls-domain-burst", fmt.Sprint(rateLimit)), + }, + sourceIP: "127.0.0.1", + featureName: feature.EnforceDomainTLSRateLimits.EnvVariable, + enforceEnabled: false, + limitName: "tls_connections_by_domain", + }, + "https_with_ip_limit": { + spec: httpsListener, + options: []processOption{ + withExtraArgument("metrics-address", ":42345"), + withExtraArgument("rate-limit-tls-source-ip", fmt.Sprint(rateLimit)), + withExtraArgument("rate-limit-tls-source-ip-burst", fmt.Sprint(rateLimit)), + }, + sourceIP: "127.0.0.1", + featureName: feature.EnforceIPTLSRateLimits.EnvVariable, + enforceEnabled: true, + limitName: "tls_connections_by_source_ip", + }, + "https_with_ip_limit_not_enforced": { + spec: httpsListener, + options: []processOption{ + withExtraArgument("metrics-address", ":42345"), + withExtraArgument("rate-limit-tls-source-ip", fmt.Sprint(rateLimit)), + withExtraArgument("rate-limit-tls-source-ip-burst", fmt.Sprint(rateLimit)), + }, + sourceIP: "127.0.0.1", + featureName: feature.EnforceIPTLSRateLimits.EnvVariable, + enforceEnabled: false, + limitName: "tls_connections_by_source_ip", + }, + "proxyv2_with_domain_limit": { + spec: httpsProxyv2Listener, + options: []processOption{ + withExtraArgument("metrics-address", ":42345"), + withExtraArgument("rate-limit-tls-domain", fmt.Sprint(rateLimit)), + withExtraArgument("rate-limit-tls-domain-burst", fmt.Sprint(rateLimit)), + }, + sourceIP: "10.1.1.1", + featureName: feature.EnforceDomainTLSRateLimits.EnvVariable, + enforceEnabled: true, + limitName: "tls_connections_by_domain", + }, + "proxyv2_with_domain_limit_not_enforced": { + spec: httpsProxyv2Listener, + options: []processOption{ + withExtraArgument("metrics-address", ":42345"), + withExtraArgument("rate-limit-tls-domain", fmt.Sprint(rateLimit)), + withExtraArgument("rate-limit-tls-domain-burst", fmt.Sprint(rateLimit)), + }, + sourceIP: "10.1.1.1", + featureName: feature.EnforceDomainTLSRateLimits.EnvVariable, + enforceEnabled: false, + limitName: "tls_connections_by_domain", + }, + "proxyv2_with_ip_limit": { + spec: httpsProxyv2Listener, + options: []processOption{ + withExtraArgument("metrics-address", ":42345"), + withExtraArgument("rate-limit-tls-source-ip", fmt.Sprint(rateLimit)), + withExtraArgument("rate-limit-tls-source-ip-burst", fmt.Sprint(rateLimit)), + }, + sourceIP: "10.1.1.1", + featureName: feature.EnforceIPTLSRateLimits.EnvVariable, + enforceEnabled: true, + limitName: "tls_connections_by_source_ip", + }, + "proxyv2_with_ip_limit_not_enforced": { + spec: httpsProxyv2Listener, + options: []processOption{ + withExtraArgument("metrics-address", ":42345"), + withExtraArgument("rate-limit-tls-source-ip", fmt.Sprint(rateLimit)), + withExtraArgument("rate-limit-tls-source-ip-burst", fmt.Sprint(rateLimit)), + }, + sourceIP: "10.1.1.1", + featureName: feature.EnforceIPTLSRateLimits.EnvVariable, + enforceEnabled: false, + limitName: "tls_connections_by_source_ip", + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + testhelpers.StubFeatureFlagValue(t, tt.featureName, tt.enforceEnabled) + + options := append(tt.options, withListeners([]ListenSpec{tt.spec})) + logBuf := RunPagesProcess(t, options...) + + for i := 0; i < 10; i++ { + rsp, err := makeTLSRequest(t, tt.spec) + + if i >= rateLimit { + assertLogFound(t, logBuf, []string{ + "TLS connection rate-limited", + "\"req_host\":\"group.gitlab-example.com\"", + fmt.Sprintf("\"source_ip\":\"%s\"", tt.sourceIP), + "\"enforced\":" + strconv.FormatBool(tt.enforceEnabled)}) + + if tt.enforceEnabled { + require.Error(t, err) + require.Contains(t, err.Error(), "remote error: tls: internal error") + } + + continue + } + + require.NoError(t, err) + require.NoError(t, rsp.Body.Close()) + require.Equal(t, http.StatusOK, rsp.StatusCode, "request: %d failed", i) + } + expectedMetric := fmt.Sprintf( + "gitlab_pages_rate_limit_blocked_count{enforced=\"%t\",limit_name=\"%s\"} 5", + tt.enforceEnabled, tt.limitName) + + RequireMetricEqual(t, "127.0.0.1:42345", expectedMetric) + }) + } +} + +func makeTLSRequest(t *testing.T, spec ListenSpec) (*http.Response, error) { + req, err := http.NewRequest("GET", "https://group.gitlab-example.com/project", nil) + require.NoError(t, err) + + return spec.Client().Do(req) +} + func assertLogFound(t *testing.T, logBuf *LogCaptureBuffer, expectedLogs []string) { t.Helper() |