Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-pages.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVladimir Shushlin <v.shushlin@gmail.com>2022-02-17 13:27:27 +0300
committerVladimir Shushlin <v.shushlin@gmail.com>2022-02-21 16:44:20 +0300
commit62a6491652aa6975d9ecf3b9e258766c886d49d4 (patch)
tree18a2ddf45d3e997dbb8ea6a6c27a7da26d6f88be
parent92fb5e54ad42ed489c4dd93eec69fb5876d11efe (diff)
feat: Add TLS rate limits
Changelog: added
-rw-r--r--app.go61
-rw-r--r--internal/config/config.go12
-rw-r--r--internal/config/flags.go28
-rw-r--r--internal/feature/feature.go16
-rw-r--r--internal/handlers/ratelimiter.go16
-rw-r--r--internal/ratelimiter/middleware.go5
-rw-r--r--internal/ratelimiter/middleware_test.go4
-rw-r--r--internal/ratelimiter/ratelimiter.go31
-rw-r--r--internal/ratelimiter/stubconn.go45
-rw-r--r--internal/ratelimiter/tls.go49
-rw-r--r--internal/ratelimiter/tls_test.go194
-rw-r--r--metrics/metrics.go57
-rw-r--r--test/acceptance/helpers_test.go12
-rw-r--r--test/acceptance/ratelimiter_test.go157
14 files changed, 608 insertions, 79 deletions
diff --git a/app.go b/app.go
index ddafb0bf..3e24e538 100644
--- a/app.go
+++ b/app.go
@@ -30,10 +30,12 @@ import (
"gitlab.com/gitlab-org/gitlab-pages/internal/config/tls"
"gitlab.com/gitlab-org/gitlab-pages/internal/customheaders"
"gitlab.com/gitlab-org/gitlab-pages/internal/domain"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/feature"
"gitlab.com/gitlab-org/gitlab-pages/internal/handlers"
"gitlab.com/gitlab-org/gitlab-pages/internal/httperrors"
"gitlab.com/gitlab-org/gitlab-pages/internal/logging"
"gitlab.com/gitlab-org/gitlab-pages/internal/netutil"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/ratelimiter"
"gitlab.com/gitlab-org/gitlab-pages/internal/rejectmethods"
"gitlab.com/gitlab-org/gitlab-pages/internal/request"
"gitlab.com/gitlab-org/gitlab-pages/internal/routing"
@@ -51,6 +53,7 @@ var (
type theApp struct {
config *cfg.Config
source source.Source
+ tlsConfig *cryptotls.Config
Artifact *artifact.Artifact
Auth *auth.Auth
Handlers *handlers.Handlers
@@ -62,19 +65,62 @@ func (a *theApp) isReady() bool {
return true
}
-func (a *theApp) ServeTLS(ch *cryptotls.ClientHelloInfo) (*cryptotls.Certificate, error) {
+func (a *theApp) GetCertificate(ch *cryptotls.ClientHelloInfo) (*cryptotls.Certificate, error) {
+ log.Info("GetCertificate called")
if ch.ServerName == "" {
return nil, nil
}
if domain, _ := a.domain(context.Background(), ch.ServerName); domain != nil {
- tls, _ := domain.EnsureCertificate()
- return tls, nil
+ certificate, _ := domain.EnsureCertificate()
+ return certificate, nil
}
return nil, nil
}
+// TODO: find a better place than app.go for all the TLS logic https://gitlab.com/gitlab-org/gitlab-pages/-/issues/707
+// right now we have config/tls, but I think does more than config
+// related logic should
+func (a *theApp) getTLSConfig() (*cryptotls.Config, error) {
+ if a.tlsConfig != nil {
+ return a.tlsConfig, nil
+ }
+ TLSDomainRateLimiter := ratelimiter.New(
+ "tls_connections_by_domain",
+ ratelimiter.WithTLSKeyFunc(ratelimiter.TLSHostnameKey),
+ ratelimiter.WithCacheMaxSize(ratelimiter.DefaultDomainCacheSize),
+ ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries),
+ ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests),
+ ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount),
+ ratelimiter.WithLimitPerSecond(a.config.RateLimit.TLSDomainLimitPerSecond),
+ ratelimiter.WithBurstSize(a.config.RateLimit.TLSDomainBurst),
+ ratelimiter.WithEnforce(feature.EnforceDomainTLSRateLimits.Enabled()),
+ )
+
+ TLSSourceIPRateLimiter := ratelimiter.New(
+ "tls_connections_by_source_ip",
+ ratelimiter.WithTLSKeyFunc(ratelimiter.TLSClientIPKey),
+ ratelimiter.WithCacheMaxSize(ratelimiter.DefaultSourceIPCacheSize),
+ ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries),
+ ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests),
+ ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount),
+ ratelimiter.WithLimitPerSecond(a.config.RateLimit.TLSSourceIPLimitPerSecond),
+ ratelimiter.WithBurstSize(a.config.RateLimit.TLSSourceIPBurst),
+ ratelimiter.WithEnforce(feature.EnforceIPTLSRateLimits.Enabled()),
+ )
+
+ getCertificate := TLSDomainRateLimiter.GetCertificateMiddleware(a.GetCertificate)
+ getCertificate = TLSSourceIPRateLimiter.GetCertificateMiddleware(getCertificate)
+
+ tlsConfig, err := tls.Create(a.config.General.RootCertificate, a.config.General.RootKey, getCertificate,
+ a.config.General.InsecureCiphers, a.config.TLS.MinVersion, a.config.TLS.MaxVersion)
+
+ a.tlsConfig = tlsConfig
+
+ return a.tlsConfig, err
+}
+
func (a *theApp) redirectToHTTPS(w http.ResponseWriter, r *http.Request, statusCode int) {
u := *r.URL
u.Scheme = request.SchemeHTTPS
@@ -306,7 +352,7 @@ func (a *theApp) Run() {
// Listen for HTTPS
for _, addr := range a.config.ListenHTTPSStrings.Split() {
- tlsConfig, err := a.TLSConfig()
+ tlsConfig, err := a.getTLSConfig()
if err != nil {
log.WithError(err).Fatal("Unable to retrieve tls config")
}
@@ -334,7 +380,7 @@ func (a *theApp) Run() {
// Listen for HTTPS PROXYv2 requests
for _, addr := range a.config.ListenHTTPSProxyv2Strings.Split() {
- tlsConfig, err := a.TLSConfig()
+ tlsConfig, err := a.getTLSConfig()
if err != nil {
log.WithError(err).Fatal("Unable to retrieve tls config")
}
@@ -478,11 +524,6 @@ func fatal(err error, message string) {
log.WithError(err).Fatal(message)
}
-func (a *theApp) TLSConfig() (*cryptotls.Config, error) {
- return tls.Create(a.config.General.RootCertificate, a.config.General.RootKey, a.ServeTLS,
- a.config.General.InsecureCiphers, a.config.TLS.MinVersion, a.config.TLS.MaxVersion)
-}
-
// handlePanicMiddleware logs and captures the recover() information from any panic
func handlePanicMiddleware(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
diff --git a/internal/config/config.go b/internal/config/config.go
index 3bb7b126..3dd4ecb3 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -58,10 +58,17 @@ type General struct {
// RateLimit config struct
type RateLimit struct {
+ // HTTP limits
SourceIPLimitPerSecond float64
SourceIPBurst int
DomainLimitPerSecond float64
DomainBurst int
+
+ // TLS connections limits
+ TLSSourceIPLimitPerSecond float64
+ TLSSourceIPBurst int
+ TLSDomainLimitPerSecond float64
+ TLSDomainBurst int
}
// ArtifactsServer groups settings related to configuring Artifacts
@@ -183,6 +190,11 @@ func loadConfig() (*Config, error) {
SourceIPBurst: *rateLimitSourceIPBurst,
DomainLimitPerSecond: *rateLimitDomain,
DomainBurst: *rateLimitDomainBurst,
+
+ TLSSourceIPLimitPerSecond: *rateLimitTLSSourceIP,
+ TLSSourceIPBurst: *rateLimitTLSSourceIPBurst,
+ TLSDomainLimitPerSecond: *rateLimitTLSDomain,
+ TLSDomainBurst: *rateLimitTLSDomainBurst,
},
GitLab: GitLab{
ClientHTTPTimeout: *gitlabClientHTTPTimeout,
diff --git a/internal/config/flags.go b/internal/config/flags.go
index 93228827..3778a677 100644
--- a/internal/config/flags.go
+++ b/internal/config/flags.go
@@ -9,16 +9,24 @@ import (
)
var (
- pagesRootCert = flag.String("root-cert", "", "The default path to file certificate to serve static pages")
- pagesRootKey = flag.String("root-key", "", "The default path to file certificate to serve static pages")
- redirectHTTP = flag.Bool("redirect-http", false, "Redirect pages from HTTP to HTTPS")
- _ = flag.Bool("use-http2", true, "DEPRECATED: HTTP2 is always enabled for pages")
- pagesRoot = flag.String("pages-root", "shared/pages", "The directory where pages are stored")
- pagesDomain = flag.String("pages-domain", "gitlab-example.com", "The domain to serve static pages")
- rateLimitSourceIP = flag.Float64("rate-limit-source-ip", 0.0, "Rate limit per source IP in number of requests per second, 0 means is disabled")
- rateLimitSourceIPBurst = flag.Int("rate-limit-source-ip-burst", 100, "Rate limit per source IP maximum burst allowed per second")
- rateLimitDomain = flag.Float64("rate-limit-domain", 0.0, "Rate limit per domain in number of requests per second, 0 means is disabled")
- rateLimitDomainBurst = flag.Int("rate-limit-domain-burst", 100, "Rate limit per domain maximum burst allowed per second")
+ pagesRootCert = flag.String("root-cert", "", "The default path to file certificate to serve static pages")
+ pagesRootKey = flag.String("root-key", "", "The default path to file certificate to serve static pages")
+ redirectHTTP = flag.Bool("redirect-http", false, "Redirect pages from HTTP to HTTPS")
+ _ = flag.Bool("use-http2", true, "DEPRECATED: HTTP2 is always enabled for pages")
+ pagesRoot = flag.String("pages-root", "shared/pages", "The directory where pages are stored")
+ pagesDomain = flag.String("pages-domain", "gitlab-example.com", "The domain to serve static pages")
+
+ // HTTP rate limits
+ rateLimitSourceIP = flag.Float64("rate-limit-source-ip", 0.0, "Rate limit HTTP requests per second from a single IP, 0 means is disabled")
+ rateLimitSourceIPBurst = flag.Int("rate-limit-source-ip-burst", 100, "Rate limit HTTP requests from a single IP, maximum burst allowed per second")
+ rateLimitDomain = flag.Float64("rate-limit-domain", 0.0, "Rate limit HTTP requests per second to a single domain, 0 means is disabled")
+ rateLimitDomainBurst = flag.Int("rate-limit-domain-burst", 100, "Rate limit HTTP requests to a single domain, maximum burst allowed per second")
+ // TLS connections rate limits
+ rateLimitTLSSourceIP = flag.Float64("rate-limit-tls-source-ip", 0.0, "Rate limit new TLS connections per second from a single IP, 0 means is disabled")
+ rateLimitTLSSourceIPBurst = flag.Int("rate-limit-tls-source-ip-burst", 100, "Rate limit new TLS connections from a single IP, maximum burst allowed per second")
+ rateLimitTLSDomain = flag.Float64("rate-limit-tls-domain", 0.0, "Rate limit new TLS connections per second from to a single domain, 0 means is disabled")
+ rateLimitTLSDomainBurst = flag.Int("rate-limit-tls-domain-burst", 100, "Rate limit new TLS connections from a single domain, maximum burst allowed per second")
+
artifactsServer = flag.String("artifacts-server", "", "API URL to proxy artifact requests to, e.g.: 'https://gitlab.com/api/v4'")
artifactsServerTimeout = flag.Int("artifacts-server-timeout", 10, "Timeout (in seconds) for a proxied request to the artifacts server")
pagesStatus = flag.String("pages-status", "", "The url path for a status page, e.g., /@status")
diff --git a/internal/feature/feature.go b/internal/feature/feature.go
index 81eef9a0..c98fe85d 100644
--- a/internal/feature/feature.go
+++ b/internal/feature/feature.go
@@ -8,17 +8,29 @@ type Feature struct {
}
// EnforceIPRateLimits enforces IP rate limiter to drop requests
-// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629
+// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/706
var EnforceIPRateLimits = Feature{
EnvVariable: "FF_ENFORCE_IP_RATE_LIMITS",
}
// EnforceDomainRateLimits enforces domain rate limiter to drop requests
-// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/655
+// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/706
var EnforceDomainRateLimits = Feature{
EnvVariable: "FF_ENFORCE_DOMAIN_RATE_LIMITS",
}
+// EnforceDomainTLSRateLimits enforces domain rate limits on establishing new TLS connections
+// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/706
+var EnforceDomainTLSRateLimits = Feature{
+ EnvVariable: "FF_ENFORCE_DOMAIN_TLS_RATE_LIMITS",
+}
+
+// EnforceIPTLSRateLimits enforces domain rate limits on establishing new TLS connections
+// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/706
+var EnforceIPTLSRateLimits = Feature{
+ EnvVariable: "FF_ENFORCE_IP_TLS_RATE_LIMITS",
+}
+
// RedirectsPlaceholders enables support for placeholders in redirects file
// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/620
var RedirectsPlaceholders = Feature{
diff --git a/internal/handlers/ratelimiter.go b/internal/handlers/ratelimiter.go
index 52281f6e..a8eee005 100644
--- a/internal/handlers/ratelimiter.go
+++ b/internal/handlers/ratelimiter.go
@@ -14,11 +14,11 @@ import (
// TODO: make this unexported once https://gitlab.com/gitlab-org/gitlab-pages/-/issues/670 is done
func Ratelimiter(handler http.Handler, config *config.RateLimit) http.Handler {
sourceIPLimiter := ratelimiter.New(
- "source_ip",
+ "http_requests_by_source_ip",
ratelimiter.WithCacheMaxSize(ratelimiter.DefaultSourceIPCacheSize),
- ratelimiter.WithCachedEntriesMetric(metrics.RateLimitSourceIPCachedEntries),
- ratelimiter.WithCachedRequestsMetric(metrics.RateLimitSourceIPCacheRequests),
- ratelimiter.WithBlockedCountMetric(metrics.RateLimitSourceIPBlockedCount),
+ ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries),
+ ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests),
+ ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount),
ratelimiter.WithLimitPerSecond(config.SourceIPLimitPerSecond),
ratelimiter.WithBurstSize(config.SourceIPBurst),
ratelimiter.WithEnforce(feature.EnforceIPRateLimits.Enabled()),
@@ -27,12 +27,12 @@ func Ratelimiter(handler http.Handler, config *config.RateLimit) http.Handler {
handler = sourceIPLimiter.Middleware(handler)
domainLimiter := ratelimiter.New(
- "domain",
+ "http_requests_by_domain",
ratelimiter.WithCacheMaxSize(ratelimiter.DefaultDomainCacheSize),
ratelimiter.WithKeyFunc(request.GetHostWithoutPort),
- ratelimiter.WithCachedEntriesMetric(metrics.RateLimitDomainCachedEntries),
- ratelimiter.WithCachedRequestsMetric(metrics.RateLimitDomainCacheRequests),
- ratelimiter.WithBlockedCountMetric(metrics.RateLimitDomainBlockedCount),
+ ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries),
+ ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests),
+ ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount),
ratelimiter.WithLimitPerSecond(config.DomainLimitPerSecond),
ratelimiter.WithBurstSize(config.DomainBurst),
ratelimiter.WithEnforce(feature.EnforceDomainRateLimits.Enabled()),
diff --git a/internal/ratelimiter/middleware.go b/internal/ratelimiter/middleware.go
index af7b0881..1be6f642 100644
--- a/internal/ratelimiter/middleware.go
+++ b/internal/ratelimiter/middleware.go
@@ -8,7 +8,6 @@ import (
"gitlab.com/gitlab-org/labkit/correlation"
"gitlab.com/gitlab-org/labkit/log"
- "gitlab.com/gitlab-org/gitlab-pages/internal/feature"
"gitlab.com/gitlab-org/gitlab-pages/internal/httperrors"
"gitlab.com/gitlab-org/gitlab-pages/internal/request"
)
@@ -34,7 +33,7 @@ func (rl *RateLimiter) Middleware(handler http.Handler) http.Handler {
rl.logRateLimitedRequest(r)
if rl.blockedCount != nil {
- rl.blockedCount.WithLabelValues(strconv.FormatBool(feature.EnforceIPRateLimits.Enabled())).Inc()
+ rl.blockedCount.WithLabelValues(rl.name, strconv.FormatBool(rl.enforce)).Inc()
}
if rl.enforce {
@@ -59,7 +58,7 @@ func (rl *RateLimiter) logRateLimitedRequest(r *http.Request) {
"x_forwarded_proto": r.Header.Get(headerXForwardedProto),
"x_forwarded_for": r.Header.Get(headerXForwardedFor),
"gitlab_real_ip": r.Header.Get(headerGitLabRealIP),
- "rate_limiter_enabled": feature.EnforceIPRateLimits.Enabled(),
+ "enforced": rl.enforce,
"rate_limiter_limit_per_second": rl.limitPerSecond,
"rate_limiter_burst_size": rl.burstSize,
}). // TODO: change to Debug with https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629
diff --git a/internal/ratelimiter/middleware_test.go b/internal/ratelimiter/middleware_test.go
index 1f753fc4..25ac08b5 100644
--- a/internal/ratelimiter/middleware_test.go
+++ b/internal/ratelimiter/middleware_test.go
@@ -107,7 +107,7 @@ func TestMiddlewareDenyRequestsAfterBurst(t *testing.T) {
assertSourceIPLog(t, remoteAddr, hook)
}
- blockedCount := testutil.ToFloat64(blocked.WithLabelValues(strconv.FormatBool(tc.enforce)))
+ blockedCount := testutil.ToFloat64(blocked.WithLabelValues("rate_limiter", strconv.FormatBool(tc.enforce)))
require.Equal(t, float64(4), blockedCount, "blocked count")
blocked.Reset()
@@ -226,7 +226,7 @@ func newTestMetrics(t *testing.T) (*prometheus.GaugeVec, *prometheus.GaugeVec, *
prometheus.GaugeOpts{
Name: t.Name(),
},
- []string{"enforced"},
+ []string{"limit_name", "enforced"},
)
cachedEntries := prometheus.NewGaugeVec(prometheus.GaugeOpts{
diff --git a/internal/ratelimiter/ratelimiter.go b/internal/ratelimiter/ratelimiter.go
index feeb8cb4..b64cb4ce 100644
--- a/internal/ratelimiter/ratelimiter.go
+++ b/internal/ratelimiter/ratelimiter.go
@@ -1,6 +1,8 @@
package ratelimiter
import (
+ "crypto/tls"
+ "net"
"net/http"
"time"
@@ -27,6 +29,9 @@ type Option func(*RateLimiter)
// KeyFunc returns unique identifier for the subject of rate limit(e.g. client IP or domain)
type KeyFunc func(*http.Request) string
+// TLSKeyFunc is used by GetCertificateMiddleware to identify the subject of rate limit (client IP or SNI servername)
+type TLSKeyFunc func(*tls.ClientHelloInfo) string
+
// RateLimiter holds an LRU cache of elements to be rate limited.
// It uses "golang.org/x/time/rate" as its Token Bucket rate limiter per source IP entry.
// See example https://www.fatalerrors.org/a/design-and-implementation-of-time-rate-limiter-for-golang-standard-library.html
@@ -35,6 +40,7 @@ type RateLimiter struct {
name string
now func() time.Time
keyFunc KeyFunc
+ tlsKeyFunc TLSKeyFunc
limitPerSecond float64
burstSize int
blockedCount *prometheus.GaugeVec
@@ -120,6 +126,26 @@ func WithKeyFunc(f KeyFunc) Option {
}
}
+func TLSHostnameKey(info *tls.ClientHelloInfo) string {
+ return info.ServerName
+}
+
+func TLSClientIPKey(info *tls.ClientHelloInfo) string {
+ remoteAddr := info.Conn.RemoteAddr().String()
+ remoteAddr, _, err := net.SplitHostPort(remoteAddr)
+ if err != nil {
+ return remoteAddr
+ }
+
+ return remoteAddr
+}
+
+func WithTLSKeyFunc(keyFunc TLSKeyFunc) Option {
+ return func(rl *RateLimiter) {
+ rl.tlsKeyFunc = keyFunc
+ }
+}
+
// WithEnforce configures if requests are actually rejected, or we just report them as rejected in metrics
func WithEnforce(enforce bool) Option {
return func(rl *RateLimiter) {
@@ -138,6 +164,11 @@ func (rl *RateLimiter) limiter(key string) *rate.Limiter {
// requestAllowed checks if request is within the rate-limit
func (rl *RateLimiter) requestAllowed(r *http.Request) bool {
rateLimitedKey := rl.keyFunc(r)
+
+ return rl.allowed(rateLimitedKey)
+}
+
+func (rl *RateLimiter) allowed(rateLimitedKey string) bool {
limiter := rl.limiter(rateLimitedKey)
// AllowN allows us to use the rl.now function, so we can test this more easily.
diff --git a/internal/ratelimiter/stubconn.go b/internal/ratelimiter/stubconn.go
new file mode 100644
index 00000000..b351e4bd
--- /dev/null
+++ b/internal/ratelimiter/stubconn.go
@@ -0,0 +1,45 @@
+package ratelimiter
+
+import (
+ "net"
+ "time"
+)
+
+type stubConn struct {
+ remoteAddr net.Addr
+}
+
+func (s stubConn) Read(b []byte) (n int, err error) {
+ return 0, nil
+}
+
+func (s stubConn) Write(b []byte) (n int, err error) {
+ return 0, nil
+}
+
+func (s stubConn) Close() error {
+ return nil
+}
+
+func (s stubConn) LocalAddr() net.Addr {
+ return &net.IPAddr{
+ IP: net.IPv4(10, 10, 10, 10),
+ Zone: "",
+ }
+}
+
+func (s stubConn) RemoteAddr() net.Addr {
+ return s.remoteAddr
+}
+
+func (s stubConn) SetDeadline(t time.Time) error {
+ return nil
+}
+
+func (s stubConn) SetReadDeadline(t time.Time) error {
+ return nil
+}
+
+func (s stubConn) SetWriteDeadline(t time.Time) error {
+ return nil
+}
diff --git a/internal/ratelimiter/tls.go b/internal/ratelimiter/tls.go
new file mode 100644
index 00000000..15b10cb4
--- /dev/null
+++ b/internal/ratelimiter/tls.go
@@ -0,0 +1,49 @@
+package ratelimiter
+
+import (
+ "crypto/tls"
+ "errors"
+ "strconv"
+
+ "github.com/sirupsen/logrus"
+ "gitlab.com/gitlab-org/labkit/log"
+
+ tlsconfig "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls"
+)
+
+var ErrTLSRateLimited = errors.New("too many connections, please retry later")
+
+func (rl *RateLimiter) GetCertificateMiddleware(getCertificate tlsconfig.GetCertificateFunc) tlsconfig.GetCertificateFunc {
+ if rl.limitPerSecond <= 0.0 {
+ return getCertificate
+ }
+
+ return func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
+ if rl.allowed(rl.tlsKeyFunc(info)) {
+ return getCertificate(info)
+ }
+
+ rl.logRateLimitedTLS(info)
+
+ if rl.blockedCount != nil {
+ rl.blockedCount.WithLabelValues(rl.name, strconv.FormatBool(rl.enforce)).Inc()
+ }
+
+ if !rl.enforce {
+ return getCertificate(info)
+ }
+
+ return nil, ErrTLSRateLimited
+ }
+}
+
+func (rl *RateLimiter) logRateLimitedTLS(info *tls.ClientHelloInfo) {
+ log.WithFields(logrus.Fields{
+ "rate_limiter_name": rl.name,
+ "source_ip": TLSClientIPKey(info),
+ "req_host": info.ServerName,
+ "rate_limiter_limit_per_second": rl.limitPerSecond,
+ "rate_limiter_burst_size": rl.burstSize,
+ "enforced": rl.enforce,
+ }).Info("TLS connection rate-limited")
+}
diff --git a/internal/ratelimiter/tls_test.go b/internal/ratelimiter/tls_test.go
new file mode 100644
index 00000000..9b13192b
--- /dev/null
+++ b/internal/ratelimiter/tls_test.go
@@ -0,0 +1,194 @@
+package ratelimiter
+
+import (
+ "crypto/tls"
+ "errors"
+ "net"
+ "strconv"
+ "testing"
+ "time"
+
+ "github.com/prometheus/client_golang/prometheus/testutil"
+ "github.com/sirupsen/logrus"
+ testlog "github.com/sirupsen/logrus/hooks/test"
+ "github.com/stretchr/testify/require"
+)
+
+func TestTLSHostnameKey(t *testing.T) {
+ info := &tls.ClientHelloInfo{ServerName: "group.gitlab.io"}
+
+ require.Equal(t, "group.gitlab.io", TLSHostnameKey(info))
+}
+
+func TestTLSClientIPKey(t *testing.T) {
+ tests := []struct {
+ addr string
+ expected string
+ }{
+ {
+ "10.1.2.3:1234",
+ "10.1.2.3",
+ },
+ {
+ "[2001:db8:3333:4444:5555:6666:7777:8888]:1234",
+ "2001:db8:3333:4444:5555:6666:7777:8888",
+ },
+ }
+
+ for _, tt := range tests {
+ addr, err := net.ResolveTCPAddr("tcp", tt.addr)
+ require.NoError(t, err)
+ conn := stubConn{addr}
+ info := &tls.ClientHelloInfo{Conn: conn}
+
+ require.Equal(t, tt.expected, TLSClientIPKey(info))
+ }
+}
+
+func TestGetCertificateMiddleware(t *testing.T) {
+ tests := map[string]struct {
+ useHostnameAsKey bool
+ enforced bool
+ limitPerSecond float64
+ burst int
+ successfulReqCnt int
+ }{
+ "ip_limiter": {
+ useHostnameAsKey: false,
+ enforced: true,
+ limitPerSecond: 0.1,
+ burst: 5,
+ successfulReqCnt: 5,
+ },
+ "hostname_limiter": {
+ useHostnameAsKey: true,
+ enforced: true,
+ limitPerSecond: 0.1,
+ burst: 5,
+ successfulReqCnt: 5,
+ },
+ "not_enforced": {
+ useHostnameAsKey: false,
+ enforced: false,
+ limitPerSecond: 0.1,
+ burst: 5,
+ successfulReqCnt: 5,
+ },
+ "disabled": {
+ useHostnameAsKey: false,
+ enforced: true,
+ limitPerSecond: 0,
+ burst: 5,
+ successfulReqCnt: 10,
+ },
+ "slowly_approach_limit": {
+ useHostnameAsKey: false,
+ enforced: true,
+ limitPerSecond: 0.2,
+ burst: 5,
+ successfulReqCnt: 6, // 5 * 0.2 gives another 1 request
+ },
+ }
+
+ expectedCert := &tls.Certificate{}
+ expectedErr := errors.New("expected error")
+
+ getCertificate := func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
+ return expectedCert, expectedErr
+ }
+
+ for name, tt := range tests {
+ t.Run(name, func(t *testing.T) {
+ hook := testlog.NewGlobal()
+ blocked, cachedEntries, cacheReqs := newTestMetrics(t)
+
+ keyFunc := TLSClientIPKey
+ if tt.useHostnameAsKey {
+ keyFunc = TLSHostnameKey
+ }
+
+ rl := New("limit_name",
+ WithCachedEntriesMetric(cachedEntries),
+ WithCachedRequestsMetric(cacheReqs),
+ WithBlockedCountMetric(blocked),
+ WithNow(stubNow()),
+ WithLimitPerSecond(tt.limitPerSecond),
+ WithBurstSize(tt.burst),
+ WithEnforce(tt.enforced),
+ WithTLSKeyFunc(keyFunc))
+
+ middlewareGetCert := rl.GetCertificateMiddleware(getCertificate)
+
+ addr, err := net.ResolveTCPAddr("tcp", "10.1.2.3:12345")
+ require.NoError(t, err)
+ conn := stubConn{addr}
+ info := &tls.ClientHelloInfo{Conn: conn, ServerName: "group.gitlab.io"}
+
+ for i := 0; i < tt.successfulReqCnt; i++ {
+ cert, err := middlewareGetCert(info)
+ require.Equal(t, expectedCert, cert)
+ require.Equal(t, expectedErr, err)
+ }
+
+ // When rate-limiter disabled altogether
+ if tt.limitPerSecond <= 0 {
+ return
+ }
+
+ cert, err := middlewareGetCert(info)
+ if tt.enforced {
+ require.Nil(t, cert)
+ require.Equal(t, err, ErrTLSRateLimited)
+ } else {
+ require.Equal(t, expectedCert, cert)
+ require.Equal(t, expectedErr, err)
+ }
+
+ require.NotNil(t, hook.LastEntry())
+ require.Equal(t, "TLS connection rate-limited", hook.LastEntry().Message)
+ expectedFields := logrus.Fields{
+ "rate_limiter_name": "limit_name",
+ "source_ip": "10.1.2.3",
+ "req_host": "group.gitlab.io",
+ "rate_limiter_limit_per_second": tt.limitPerSecond,
+ "rate_limiter_burst_size": tt.burst,
+ "enforced": tt.enforced,
+ }
+ require.Equal(t, expectedFields, hook.LastEntry().Data)
+
+ // make another request with different key and expect success
+ if tt.useHostnameAsKey {
+ info.ServerName = "another-group.gitlab.io"
+ } else {
+ addr, err := net.ResolveTCPAddr("tcp", "10.10.20.30:12345")
+ require.NoError(t, err)
+ conn = stubConn{addr}
+ info.Conn = conn
+ }
+
+ cert, err = middlewareGetCert(info)
+ require.Equal(t, expectedCert, cert)
+ require.Equal(t, expectedErr, err)
+
+ blockedCount := testutil.ToFloat64(blocked.WithLabelValues("limit_name", strconv.FormatBool(tt.enforced)))
+ require.Equal(t, float64(1), blockedCount, "blocked count")
+
+ cachedCount := testutil.ToFloat64(cachedEntries.WithLabelValues("limit_name"))
+ require.Equal(t, float64(2), cachedCount, "cached count") // 1 for first key + 1 for different one
+
+ cacheReqMiss := testutil.ToFloat64(cacheReqs.WithLabelValues("limit_name", "miss"))
+ require.Equal(t, float64(2), cacheReqMiss, "miss count") // 1 for first key + 1 for different one
+ cacheReqHit := testutil.ToFloat64(cacheReqs.WithLabelValues("limit_name", "hit"))
+ require.Equal(t, float64(tt.successfulReqCnt), cacheReqHit, "hit count")
+ })
+ }
+}
+
+func stubNow() func() time.Time {
+ now := time.Now()
+ return func() time.Time {
+ now = now.Add(time.Second)
+
+ return now
+ }
+}
diff --git a/metrics/metrics.go b/metrics/metrics.go
index e0e4ab29..26c7413f 100644
--- a/metrics/metrics.go
+++ b/metrics/metrics.go
@@ -193,60 +193,31 @@ var (
},
)
- // RateLimitSourceIPCacheRequests is the number of cache hits/misses
- RateLimitSourceIPCacheRequests = prometheus.NewCounterVec(
+ // RateLimitCacheRequests is the number of cache hits/misses
+ RateLimitCacheRequests = prometheus.NewCounterVec(
prometheus.CounterOpts{
- Name: "gitlab_pages_rate_limit_source_ip_cache_requests",
+ Name: "gitlab_pages_rate_limit_cache_requests",
Help: "The number of source_ip cache hits/misses in the rate limiter",
},
[]string{"op", "cache"},
)
- // RateLimitSourceIPCachedEntries is the number of entries in the cache
- RateLimitSourceIPCachedEntries = prometheus.NewGaugeVec(
+ // RateLimitCachedEntries is the number of entries in the cache
+ RateLimitCachedEntries = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
- Name: "gitlab_pages_rate_limit_source_ip_cached_entries",
+ Name: "gitlab_pages_rate_limit_cached_entries",
Help: "The number of entries in the cache",
},
[]string{"op"},
)
- // RateLimitSourceIPBlockedCount is the number of requests that have been blocked by the
- // source IP rate limiter
- RateLimitSourceIPBlockedCount = prometheus.NewGaugeVec(
+ // RateLimitBlockedCount is the number of requests that have been blocked
+ RateLimitBlockedCount = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
- Name: "gitlab_pages_rate_limit_source_ip_blocked_count",
- Help: "The number of requests that have been blocked by the IP rate limiter",
+ Name: "gitlab_pages_rate_limit_blocked_count",
+ Help: "The number of requests/connections that have been blocked by rate limiter",
},
- []string{"enforced"},
- )
-
- // RateLimitDomainCacheRequests is the number of cache hits/misses
- RateLimitDomainCacheRequests = prometheus.NewCounterVec(
- prometheus.CounterOpts{
- Name: "gitlab_pages_rate_limit_domain_cache_requests",
- Help: "The number of source_ip cache hits/misses in the rate limiter",
- },
- []string{"op", "cache"},
- )
-
- // RateLimitDomainCachedEntries is the number of entries in the cache
- RateLimitDomainCachedEntries = prometheus.NewGaugeVec(
- prometheus.GaugeOpts{
- Name: "gitlab_pages_rate_limit_domain_cached_entries",
- Help: "The number of entries in the cache",
- },
- []string{"op"},
- )
-
- // RateLimitDomainBlockedCount is the number of requests that have been blocked by the
- // domain rate limiter
- RateLimitDomainBlockedCount = prometheus.NewGaugeVec(
- prometheus.GaugeOpts{
- Name: "gitlab_pages_rate_limit_domain_blocked_count",
- Help: "The number of requests addresses that have been blocked by the domain rate limiter",
- },
- []string{"enforced"},
+ []string{"limit_name", "enforced"},
)
)
@@ -276,8 +247,8 @@ func MustRegister() {
LimitListenerConcurrentConns,
LimitListenerWaitingConns,
PanicRecoveredCount,
- RateLimitSourceIPCacheRequests,
- RateLimitSourceIPCachedEntries,
- RateLimitSourceIPBlockedCount,
+ RateLimitCacheRequests,
+ RateLimitCachedEntries,
+ RateLimitBlockedCount,
)
}
diff --git a/test/acceptance/helpers_test.go b/test/acceptance/helpers_test.go
index c44058ba..1b514a85 100644
--- a/test/acceptance/helpers_test.go
+++ b/test/acceptance/helpers_test.go
@@ -602,3 +602,15 @@ func copyFile(dest, src string) error {
_, err = io.Copy(destFile, srcFile)
return err
}
+
+// RequireMetricEqual requests prometheus metrics and makes sure metric is there
+func RequireMetricEqual(t *testing.T, metricsAddress, metricWithValue string) {
+ resp, err := http.Get(fmt.Sprintf("http://%s/metrics", metricsAddress))
+ require.NoError(t, err)
+
+ defer resp.Body.Close()
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ require.Contains(t, string(body), metricWithValue)
+}
diff --git a/test/acceptance/ratelimiter_test.go b/test/acceptance/ratelimiter_test.go
index 02ba54f5..365ba4cd 100644
--- a/test/acceptance/ratelimiter_test.go
+++ b/test/acceptance/ratelimiter_test.go
@@ -3,6 +3,7 @@ package acceptance_test
import (
"fmt"
"net/http"
+ "strconv"
"testing"
"time"
@@ -77,7 +78,7 @@ func TestIPRateLimits(t *testing.T) {
}
}
-func TestDomainateLimits(t *testing.T) {
+func TestDomainRateLimits(t *testing.T) {
testhelpers.StubFeatureFlagValue(t, feature.EnforceDomainRateLimits.EnvVariable, true)
for name, tc := range ratelimitedListeners {
@@ -112,6 +113,160 @@ func TestDomainateLimits(t *testing.T) {
}
}
+func TestTLSRateLimits(t *testing.T) {
+ rateLimit := 5
+
+ tests := map[string]struct {
+ spec ListenSpec
+ options []processOption
+ sourceIP string
+ featureName string
+ enforceEnabled bool
+ limitName string
+ }{
+ "https_with_domain_limit": {
+ spec: httpsListener,
+ options: []processOption{
+ withExtraArgument("metrics-address", ":42345"),
+ withExtraArgument("rate-limit-tls-domain", fmt.Sprint(rateLimit)),
+ withExtraArgument("rate-limit-tls-domain-burst", fmt.Sprint(rateLimit)),
+ },
+ sourceIP: "127.0.0.1",
+ featureName: feature.EnforceDomainTLSRateLimits.EnvVariable,
+ enforceEnabled: true,
+ limitName: "tls_connections_by_domain",
+ },
+ "https_with_domain_limit_not_enforced": {
+ spec: httpsListener,
+ options: []processOption{
+ withExtraArgument("metrics-address", ":42345"),
+ withExtraArgument("rate-limit-tls-domain", fmt.Sprint(rateLimit)),
+ withExtraArgument("rate-limit-tls-domain-burst", fmt.Sprint(rateLimit)),
+ },
+ sourceIP: "127.0.0.1",
+ featureName: feature.EnforceDomainTLSRateLimits.EnvVariable,
+ enforceEnabled: false,
+ limitName: "tls_connections_by_domain",
+ },
+ "https_with_ip_limit": {
+ spec: httpsListener,
+ options: []processOption{
+ withExtraArgument("metrics-address", ":42345"),
+ withExtraArgument("rate-limit-tls-source-ip", fmt.Sprint(rateLimit)),
+ withExtraArgument("rate-limit-tls-source-ip-burst", fmt.Sprint(rateLimit)),
+ },
+ sourceIP: "127.0.0.1",
+ featureName: feature.EnforceIPTLSRateLimits.EnvVariable,
+ enforceEnabled: true,
+ limitName: "tls_connections_by_source_ip",
+ },
+ "https_with_ip_limit_not_enforced": {
+ spec: httpsListener,
+ options: []processOption{
+ withExtraArgument("metrics-address", ":42345"),
+ withExtraArgument("rate-limit-tls-source-ip", fmt.Sprint(rateLimit)),
+ withExtraArgument("rate-limit-tls-source-ip-burst", fmt.Sprint(rateLimit)),
+ },
+ sourceIP: "127.0.0.1",
+ featureName: feature.EnforceIPTLSRateLimits.EnvVariable,
+ enforceEnabled: false,
+ limitName: "tls_connections_by_source_ip",
+ },
+ "proxyv2_with_domain_limit": {
+ spec: httpsProxyv2Listener,
+ options: []processOption{
+ withExtraArgument("metrics-address", ":42345"),
+ withExtraArgument("rate-limit-tls-domain", fmt.Sprint(rateLimit)),
+ withExtraArgument("rate-limit-tls-domain-burst", fmt.Sprint(rateLimit)),
+ },
+ sourceIP: "10.1.1.1",
+ featureName: feature.EnforceDomainTLSRateLimits.EnvVariable,
+ enforceEnabled: true,
+ limitName: "tls_connections_by_domain",
+ },
+ "proxyv2_with_domain_limit_not_enforced": {
+ spec: httpsProxyv2Listener,
+ options: []processOption{
+ withExtraArgument("metrics-address", ":42345"),
+ withExtraArgument("rate-limit-tls-domain", fmt.Sprint(rateLimit)),
+ withExtraArgument("rate-limit-tls-domain-burst", fmt.Sprint(rateLimit)),
+ },
+ sourceIP: "10.1.1.1",
+ featureName: feature.EnforceDomainTLSRateLimits.EnvVariable,
+ enforceEnabled: false,
+ limitName: "tls_connections_by_domain",
+ },
+ "proxyv2_with_ip_limit": {
+ spec: httpsProxyv2Listener,
+ options: []processOption{
+ withExtraArgument("metrics-address", ":42345"),
+ withExtraArgument("rate-limit-tls-source-ip", fmt.Sprint(rateLimit)),
+ withExtraArgument("rate-limit-tls-source-ip-burst", fmt.Sprint(rateLimit)),
+ },
+ sourceIP: "10.1.1.1",
+ featureName: feature.EnforceIPTLSRateLimits.EnvVariable,
+ enforceEnabled: true,
+ limitName: "tls_connections_by_source_ip",
+ },
+ "proxyv2_with_ip_limit_not_enforced": {
+ spec: httpsProxyv2Listener,
+ options: []processOption{
+ withExtraArgument("metrics-address", ":42345"),
+ withExtraArgument("rate-limit-tls-source-ip", fmt.Sprint(rateLimit)),
+ withExtraArgument("rate-limit-tls-source-ip-burst", fmt.Sprint(rateLimit)),
+ },
+ sourceIP: "10.1.1.1",
+ featureName: feature.EnforceIPTLSRateLimits.EnvVariable,
+ enforceEnabled: false,
+ limitName: "tls_connections_by_source_ip",
+ },
+ }
+
+ for name, tt := range tests {
+ t.Run(name, func(t *testing.T) {
+ testhelpers.StubFeatureFlagValue(t, tt.featureName, tt.enforceEnabled)
+
+ options := append(tt.options, withListeners([]ListenSpec{tt.spec}))
+ logBuf := RunPagesProcess(t, options...)
+
+ for i := 0; i < 10; i++ {
+ rsp, err := makeTLSRequest(t, tt.spec)
+
+ if i >= rateLimit {
+ assertLogFound(t, logBuf, []string{
+ "TLS connection rate-limited",
+ "\"req_host\":\"group.gitlab-example.com\"",
+ fmt.Sprintf("\"source_ip\":\"%s\"", tt.sourceIP),
+ "\"enforced\":" + strconv.FormatBool(tt.enforceEnabled)})
+
+ if tt.enforceEnabled {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "remote error: tls: internal error")
+ }
+
+ continue
+ }
+
+ require.NoError(t, err)
+ require.NoError(t, rsp.Body.Close())
+ require.Equal(t, http.StatusOK, rsp.StatusCode, "request: %d failed", i)
+ }
+ expectedMetric := fmt.Sprintf(
+ "gitlab_pages_rate_limit_blocked_count{enforced=\"%t\",limit_name=\"%s\"} 5",
+ tt.enforceEnabled, tt.limitName)
+
+ RequireMetricEqual(t, "127.0.0.1:42345", expectedMetric)
+ })
+ }
+}
+
+func makeTLSRequest(t *testing.T, spec ListenSpec) (*http.Response, error) {
+ req, err := http.NewRequest("GET", "https://group.gitlab-example.com/project", nil)
+ require.NoError(t, err)
+
+ return spec.Client().Do(req)
+}
+
func assertLogFound(t *testing.T, logBuf *LogCaptureBuffer, expectedLogs []string) {
t.Helper()