diff options
author | Vladimir Shushlin <vshushlin@gitlab.com> | 2022-02-22 13:17:56 +0300 |
---|---|---|
committer | Vladimir Shushlin <vshushlin@gitlab.com> | 2022-02-22 13:17:56 +0300 |
commit | e7ad17b8d0818dd4c94f3a06e81781f4068c1b97 (patch) | |
tree | 026e2619a53da1abf202322c0d4e656ed2e93ba3 | |
parent | dbd3785baf9f6af3c0c6a76ef44b12f3fd49b68a (diff) | |
parent | dec4b09ac1f6fdf98487d4db61055c1e64358c15 (diff) |
Merge branch 'reject-tls-2' into 'master'
feat: add rate limits on the TLS connection level
See merge request gitlab-org/gitlab-pages!700
-rw-r--r-- | app.go | 32 | ||||
-rw-r--r-- | internal/config/config.go | 94 | ||||
-rw-r--r-- | internal/config/flags.go | 59 | ||||
-rw-r--r-- | internal/config/tls/tls.go | 100 | ||||
-rw-r--r-- | internal/config/validate.go | 22 | ||||
-rw-r--r-- | internal/config/validate_test.go | 41 | ||||
-rw-r--r-- | internal/feature/feature.go | 16 | ||||
-rw-r--r-- | internal/handlers/ratelimiter.go | 16 | ||||
-rw-r--r-- | internal/ratelimiter/middleware.go | 5 | ||||
-rw-r--r-- | internal/ratelimiter/middleware_test.go | 4 | ||||
-rw-r--r-- | internal/ratelimiter/ratelimiter.go | 31 | ||||
-rw-r--r-- | internal/ratelimiter/stub_conn_test.go | 14 | ||||
-rw-r--r-- | internal/ratelimiter/tls.go | 49 | ||||
-rw-r--r-- | internal/ratelimiter/tls_test.go | 194 | ||||
-rw-r--r-- | internal/tls/tls.go | 96 | ||||
-rw-r--r-- | internal/tls/tls_test.go (renamed from internal/config/tls/tls_test.go) | 73 | ||||
-rw-r--r-- | metrics/metrics.go | 57 | ||||
-rw-r--r-- | server.go | 5 | ||||
-rw-r--r-- | test/acceptance/helpers_test.go | 12 | ||||
-rw-r--r-- | test/acceptance/ratelimiter_test.go | 128 |
20 files changed, 777 insertions, 271 deletions
@@ -27,7 +27,6 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/artifact" "gitlab.com/gitlab-org/gitlab-pages/internal/auth" cfg "gitlab.com/gitlab-org/gitlab-pages/internal/config" - "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls" "gitlab.com/gitlab-org/gitlab-pages/internal/customheaders" "gitlab.com/gitlab-org/gitlab-pages/internal/domain" "gitlab.com/gitlab-org/gitlab-pages/internal/handlers" @@ -40,6 +39,7 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/serving/disk/zip" "gitlab.com/gitlab-org/gitlab-pages/internal/source" "gitlab.com/gitlab-org/gitlab-pages/internal/source/gitlab" + "gitlab.com/gitlab-org/gitlab-pages/internal/tls" "gitlab.com/gitlab-org/gitlab-pages/internal/urilimiter" "gitlab.com/gitlab-org/gitlab-pages/metrics" ) @@ -51,6 +51,7 @@ var ( type theApp struct { config *cfg.Config source source.Source + tlsConfig *cryptotls.Config Artifact *artifact.Artifact Auth *auth.Auth Handlers *handlers.Handlers @@ -62,19 +63,33 @@ func (a *theApp) isReady() bool { return true } -func (a *theApp) ServeTLS(ch *cryptotls.ClientHelloInfo) (*cryptotls.Certificate, error) { +func (a *theApp) GetCertificate(ch *cryptotls.ClientHelloInfo) (*cryptotls.Certificate, error) { if ch.ServerName == "" { return nil, nil } if domain, _ := a.domain(context.Background(), ch.ServerName); domain != nil { - tls, _ := domain.EnsureCertificate() - return tls, nil + certificate, _ := domain.EnsureCertificate() + return certificate, nil } return nil, nil } +func (a *theApp) getTLSConfig() (*cryptotls.Config, error) { + // we call this function only when tls config is needed, and we ignore TLS related flags otherwise + // in theory you can configure both listen-https and listen-proxyv2, + // so this return is here to have a single TLS config + if a.tlsConfig != nil { + return a.tlsConfig, nil + } + + var err error + a.tlsConfig, err = tls.GetTLSConfig(a.config, a.GetCertificate) + + return a.tlsConfig, err +} + func (a *theApp) redirectToHTTPS(w http.ResponseWriter, r *http.Request, statusCode int) { u := *r.URL u.Scheme = request.SchemeHTTPS @@ -306,7 +321,7 @@ func (a *theApp) Run() { // Listen for HTTPS for _, addr := range a.config.ListenHTTPSStrings.Split() { - tlsConfig, err := a.TLSConfig() + tlsConfig, err := a.getTLSConfig() if err != nil { log.WithError(err).Fatal("Unable to retrieve tls config") } @@ -334,7 +349,7 @@ func (a *theApp) Run() { // Listen for HTTPS PROXYv2 requests for _, addr := range a.config.ListenHTTPSProxyv2Strings.Split() { - tlsConfig, err := a.TLSConfig() + tlsConfig, err := a.getTLSConfig() if err != nil { log.WithError(err).Fatal("Unable to retrieve tls config") } @@ -478,11 +493,6 @@ func fatal(err error, message string) { log.WithError(err).Fatal(message) } -func (a *theApp) TLSConfig() (*cryptotls.Config, error) { - return tls.Create(a.config.General.RootCertificate, a.config.General.RootKey, a.ServeTLS, - a.config.General.InsecureCiphers, a.config.TLS.MinVersion, a.config.TLS.MaxVersion) -} - // handlePanicMiddleware logs and captures the recover() information from any panic func handlePanicMiddleware(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/internal/config/config.go b/internal/config/config.go index 3bb7b126..24a811ec 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -9,8 +9,6 @@ import ( "github.com/namsral/flag" "gitlab.com/gitlab-org/labkit/log" - - "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls" ) // Config stores all the config options relevant to GitLab Pages. @@ -58,10 +56,17 @@ type General struct { // RateLimit config struct type RateLimit struct { + // HTTP limits SourceIPLimitPerSecond float64 SourceIPBurst int DomainLimitPerSecond float64 DomainBurst int + + // TLS connections limits + TLSSourceIPLimitPerSecond float64 + TLSSourceIPBurst int + TLSDomainLimitPerSecond float64 + TLSDomainBurst int } // ArtifactsServer groups settings related to configuring Artifacts @@ -183,6 +188,11 @@ func loadConfig() (*Config, error) { SourceIPBurst: *rateLimitSourceIPBurst, DomainLimitPerSecond: *rateLimitDomain, DomainBurst: *rateLimitDomainBurst, + + TLSSourceIPLimitPerSecond: *rateLimitTLSSourceIP, + TLSSourceIPBurst: *rateLimitTLSSourceIPBurst, + TLSDomainLimitPerSecond: *rateLimitTLSDomain, + TLSDomainBurst: *rateLimitTLSDomainBurst, }, GitLab: GitLab{ ClientHTTPTimeout: *gitlabClientHTTPTimeout, @@ -217,8 +227,8 @@ func loadConfig() (*Config, error) { Environment: *sentryEnvironment, }, TLS: TLS{ - MinVersion: tls.AllTLSVersions[*tlsMinVersion], - MaxVersion: tls.AllTLSVersions[*tlsMaxVersion], + MinVersion: allTLSVersions[*tlsMinVersion], + MaxVersion: allTLSVersions[*tlsMaxVersion], }, Zip: ZipServing{ ExpirationInterval: *zipCacheExpiration, @@ -267,40 +277,48 @@ func loadConfig() (*Config, error) { func LogConfig(config *Config) { log.WithFields(log.Fields{ - "artifacts-server": *artifactsServer, - "artifacts-server-timeout": *artifactsServerTimeout, - "default-config-filename": flag.DefaultConfigFlagname, - "disable-cross-origin-requests": *disableCrossOriginRequests, - "domain": config.General.Domain, - "insecure-ciphers": config.General.InsecureCiphers, - "listen-http": listenHTTP, - "listen-https": listenHTTPS, - "listen-proxy": listenProxy, - "listen-https-proxyv2": listenHTTPSProxyv2, - "log-format": *logFormat, - "metrics-address": *metricsAddress, - "pages-domain": *pagesDomain, - "pages-root": *pagesRoot, - "pages-status": *pagesStatus, - "propagate-correlation-id": *propagateCorrelationID, - "redirect-http": config.General.RedirectHTTP, - "root-cert": *pagesRootKey, - "root-key": *pagesRootCert, - "status_path": config.General.StatusPath, - "tls-min-version": *tlsMinVersion, - "tls-max-version": *tlsMaxVersion, - "gitlab-server": config.GitLab.PublicServer, - "internal-gitlab-server": config.GitLab.InternalServer, - "api-secret-key": *gitLabAPISecretKey, - "enable-disk": config.GitLab.EnableDisk, - "auth-redirect-uri": config.Authentication.RedirectURI, - "auth-scope": config.Authentication.Scope, - "max-conns": config.General.MaxConns, - "max-uri-length": config.General.MaxURILength, - "zip-cache-expiration": config.Zip.ExpirationInterval, - "zip-cache-cleanup": config.Zip.CleanupInterval, - "zip-cache-refresh": config.Zip.RefreshInterval, - "zip-open-timeout": config.Zip.OpenTimeout, + "artifacts-server": *artifactsServer, + "artifacts-server-timeout": *artifactsServerTimeout, + "default-config-filename": flag.DefaultConfigFlagname, + "disable-cross-origin-requests": *disableCrossOriginRequests, + "domain": config.General.Domain, + "insecure-ciphers": config.General.InsecureCiphers, + "listen-http": listenHTTP, + "listen-https": listenHTTPS, + "listen-proxy": listenProxy, + "listen-https-proxyv2": listenHTTPSProxyv2, + "log-format": *logFormat, + "metrics-address": *metricsAddress, + "pages-domain": *pagesDomain, + "pages-root": *pagesRoot, + "pages-status": *pagesStatus, + "propagate-correlation-id": *propagateCorrelationID, + "redirect-http": config.General.RedirectHTTP, + "root-cert": *pagesRootKey, + "root-key": *pagesRootCert, + "status_path": config.General.StatusPath, + "tls-min-version": *tlsMinVersion, + "tls-max-version": *tlsMaxVersion, + "gitlab-server": config.GitLab.PublicServer, + "internal-gitlab-server": config.GitLab.InternalServer, + "api-secret-key": *gitLabAPISecretKey, + "enable-disk": config.GitLab.EnableDisk, + "auth-redirect-uri": config.Authentication.RedirectURI, + "auth-scope": config.Authentication.Scope, + "max-conns": config.General.MaxConns, + "max-uri-length": config.General.MaxURILength, + "zip-cache-expiration": config.Zip.ExpirationInterval, + "zip-cache-cleanup": config.Zip.CleanupInterval, + "zip-cache-refresh": config.Zip.RefreshInterval, + "zip-open-timeout": config.Zip.OpenTimeout, + "rate-limit-source-ip": config.RateLimit.SourceIPLimitPerSecond, + "rate-limit-source-ip-burst": config.RateLimit.SourceIPBurst, + "rate-limit-domain": config.RateLimit.DomainLimitPerSecond, + "rate-limit-domain-burst": config.RateLimit.DomainBurst, + "rate-limit-tls-source-ip": config.RateLimit.TLSSourceIPLimitPerSecond, + "rate-limit-tls-source-ip-burst": config.RateLimit.TLSSourceIPBurst, + "rate-limit-tls-domain": config.RateLimit.TLSDomainLimitPerSecond, + "rate-limit-tls-domain-burst": config.RateLimit.TLSDomainBurst, }).Debug("Start Pages with configuration") } diff --git a/internal/config/flags.go b/internal/config/flags.go index 93228827..409ecdc7 100644 --- a/internal/config/flags.go +++ b/internal/config/flags.go @@ -1,24 +1,41 @@ package config import ( + "crypto/tls" + "fmt" + "sort" + "strings" "time" "github.com/namsral/flag" - - "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls" ) var ( - pagesRootCert = flag.String("root-cert", "", "The default path to file certificate to serve static pages") - pagesRootKey = flag.String("root-key", "", "The default path to file certificate to serve static pages") - redirectHTTP = flag.Bool("redirect-http", false, "Redirect pages from HTTP to HTTPS") - _ = flag.Bool("use-http2", true, "DEPRECATED: HTTP2 is always enabled for pages") - pagesRoot = flag.String("pages-root", "shared/pages", "The directory where pages are stored") - pagesDomain = flag.String("pages-domain", "gitlab-example.com", "The domain to serve static pages") - rateLimitSourceIP = flag.Float64("rate-limit-source-ip", 0.0, "Rate limit per source IP in number of requests per second, 0 means is disabled") - rateLimitSourceIPBurst = flag.Int("rate-limit-source-ip-burst", 100, "Rate limit per source IP maximum burst allowed per second") - rateLimitDomain = flag.Float64("rate-limit-domain", 0.0, "Rate limit per domain in number of requests per second, 0 means is disabled") - rateLimitDomainBurst = flag.Int("rate-limit-domain-burst", 100, "Rate limit per domain maximum burst allowed per second") + // allTLSVersions has all supported flag values + allTLSVersions = map[string]uint16{ + "": 0, // Default value in tls.Config + "tls1.2": tls.VersionTLS12, + "tls1.3": tls.VersionTLS13, + } + + pagesRootCert = flag.String("root-cert", "", "The default path to file certificate to serve static pages") + pagesRootKey = flag.String("root-key", "", "The default path to file certificate to serve static pages") + redirectHTTP = flag.Bool("redirect-http", false, "Redirect pages from HTTP to HTTPS") + _ = flag.Bool("use-http2", true, "DEPRECATED: HTTP2 is always enabled for pages") + pagesRoot = flag.String("pages-root", "shared/pages", "The directory where pages are stored") + pagesDomain = flag.String("pages-domain", "gitlab-example.com", "The domain to serve static pages") + + // HTTP rate limits + rateLimitSourceIP = flag.Float64("rate-limit-source-ip", 0.0, "Rate limit HTTP requests per second from a single IP, 0 means is disabled") + rateLimitSourceIPBurst = flag.Int("rate-limit-source-ip-burst", 100, "Rate limit HTTP requests from a single IP, maximum burst allowed per second") + rateLimitDomain = flag.Float64("rate-limit-domain", 0.0, "Rate limit HTTP requests per second to a single domain, 0 means is disabled") + rateLimitDomainBurst = flag.Int("rate-limit-domain-burst", 100, "Rate limit HTTP requests to a single domain, maximum burst allowed per second") + // TLS connections rate limits + rateLimitTLSSourceIP = flag.Float64("rate-limit-tls-source-ip", 0.0, "Rate limit new TLS connections per second from a single IP, 0 means is disabled") + rateLimitTLSSourceIPBurst = flag.Int("rate-limit-tls-source-ip-burst", 100, "Rate limit new TLS connections from a single IP, maximum burst allowed per second") + rateLimitTLSDomain = flag.Float64("rate-limit-tls-domain", 0.0, "Rate limit new TLS connections per second from to a single domain, 0 means is disabled") + rateLimitTLSDomainBurst = flag.Int("rate-limit-tls-domain-burst", 100, "Rate limit new TLS connections from a single domain, maximum burst allowed per second") + artifactsServer = flag.String("artifacts-server", "", "API URL to proxy artifact requests to, e.g.: 'https://gitlab.com/api/v4'") artifactsServerTimeout = flag.Int("artifacts-server-timeout", 10, "Timeout (in seconds) for a proxied request to the artifacts server") pagesStatus = flag.String("pages-status", "", "The url path for a status page, e.g., /@status") @@ -55,8 +72,8 @@ var ( maxConns = flag.Int("max-conns", 0, "Limit on the number of concurrent connections to the HTTP, HTTPS or proxy listeners, 0 for no limit") maxURILength = flag.Int("max-uri-length", 1024, "Limit the length of URI, 0 for unlimited.") insecureCiphers = flag.Bool("insecure-ciphers", false, "Use default list of cipher suites, may contain insecure ones like 3DES and RC4") - tlsMinVersion = flag.String("tls-min-version", "tls1.2", tls.FlagUsage("min")) - tlsMaxVersion = flag.String("tls-max-version", "", tls.FlagUsage("max")) + tlsMinVersion = flag.String("tls-min-version", "tls1.2", tlsVersionFlagUsage("min")) + tlsMaxVersion = flag.String("tls-max-version", "", tlsVersionFlagUsage("max")) zipCacheExpiration = flag.Duration("zip-cache-expiration", 60*time.Second, "Zip serving archive cache expiration interval") zipCacheCleanup = flag.Duration("zip-cache-cleanup", 30*time.Second, "Zip serving archive cache cleanup interval") zipCacheRefresh = flag.Duration("zip-cache-refresh", 30*time.Second, "Zip serving archive cache refresh interval") @@ -88,3 +105,17 @@ func initFlags() { flag.Parse() } + +// tlsVersionFlagUsage returns string with explanation how to use the tls version CLI flag +func tlsVersionFlagUsage(minOrMax string) string { + versions := []string{} + + for version := range allTLSVersions { + if version != "" { + versions = append(versions, fmt.Sprintf("%q", version)) + } + } + sort.Strings(versions) + + return fmt.Sprintf("Specifies the "+minOrMax+"imum SSL/TLS version, supported values are %s", strings.Join(versions, ", ")) +} diff --git a/internal/config/tls/tls.go b/internal/config/tls/tls.go deleted file mode 100644 index c76bbff7..00000000 --- a/internal/config/tls/tls.go +++ /dev/null @@ -1,100 +0,0 @@ -package tls - -import ( - "crypto/tls" - "fmt" - "sort" - "strings" -) - -// GetCertificateFunc returns the certificate to be used for given domain -type GetCertificateFunc func(*tls.ClientHelloInfo) (*tls.Certificate, error) - -var ( - preferredCipherSuites = []uint16{ - tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, - tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - tls.TLS_AES_128_GCM_SHA256, - tls.TLS_AES_256_GCM_SHA384, - tls.TLS_CHACHA20_POLY1305_SHA256, - } - - // AllTLSVersions has all supported flag values - AllTLSVersions = map[string]uint16{ - "": 0, // Default value in tls.Config - "tls1.2": tls.VersionTLS12, - "tls1.3": tls.VersionTLS13, - } -) - -// FlagUsage returns string with explanation how to use the CLI flag -func FlagUsage(minOrMax string) string { - versions := []string{} - - for version := range AllTLSVersions { - if version != "" { - versions = append(versions, fmt.Sprintf("%q", version)) - } - } - sort.Strings(versions) - - return fmt.Sprintf("Specifies the "+minOrMax+"imum SSL/TLS version, supported values are %s", strings.Join(versions, ", ")) -} - -// Create returns tls.Config for given app configuration -func Create(cert, key []byte, getCertificate GetCertificateFunc, insecureCiphers bool, tlsMinVersion uint16, tlsMaxVersion uint16) (*tls.Config, error) { - // set MinVersion to fix gosec: G402 - tlsConfig := &tls.Config{GetCertificate: getCertificate, MinVersion: tls.VersionTLS12} - - err := configureCertificate(tlsConfig, cert, key) - if err != nil { - return nil, err - } - - if !insecureCiphers { - configureTLSCiphers(tlsConfig) - } - - tlsConfig.MinVersion = tlsMinVersion - tlsConfig.MaxVersion = tlsMaxVersion - - return tlsConfig, nil -} - -// ValidateTLSVersions returns error if the provided TLS versions config values are not valid -func ValidateTLSVersions(min, max string) error { - tlsMin, tlsMinOk := AllTLSVersions[min] - tlsMax, tlsMaxOk := AllTLSVersions[max] - - if !tlsMinOk { - return fmt.Errorf("invalid minimum TLS version: %s", min) - } - if !tlsMaxOk { - return fmt.Errorf("invalid maximum TLS version: %s", max) - } - if tlsMin > tlsMax && tlsMax > 0 { - return fmt.Errorf("invalid maximum TLS version: %s; should be at least %s", max, min) - } - - return nil -} - -func configureCertificate(tlsConfig *tls.Config, cert, key []byte) error { - certificate, err := tls.X509KeyPair(cert, key) - if err != nil { - return err - } - - tlsConfig.Certificates = []tls.Certificate{certificate} - - return nil -} - -func configureTLSCiphers(tlsConfig *tls.Config) { - tlsConfig.PreferServerCipherSuites = true - tlsConfig.CipherSuites = preferredCipherSuites -} diff --git a/internal/config/validate.go b/internal/config/validate.go index a3dbcc3b..bb247287 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -6,8 +6,6 @@ import ( "net/url" "github.com/hashicorp/go-multierror" - - "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls" ) var ( @@ -30,7 +28,7 @@ func Validate(config *Config) error { validateListeners(config), validateAuthConfig(config), validateArtifactsServerConfig(config), - tls.ValidateTLSVersions(*tlsMinVersion, *tlsMaxVersion), + validateTLSVersions(*tlsMinVersion, *tlsMaxVersion), ) return result.ErrorOrNil() @@ -115,3 +113,21 @@ func validateArtifactsServerConfig(config *Config) error { return result.ErrorOrNil() } + +// validateTLSVersions returns error if the provided TLS versions config values are not valid +func validateTLSVersions(min, max string) error { + tlsMin, tlsMinOk := allTLSVersions[min] + tlsMax, tlsMaxOk := allTLSVersions[max] + + if !tlsMinOk { + return fmt.Errorf("invalid minimum TLS version: %s", min) + } + if !tlsMaxOk { + return fmt.Errorf("invalid maximum TLS version: %s", max) + } + if tlsMin > tlsMax && tlsMax > 0 { + return fmt.Errorf("invalid maximum TLS version: %s; should be at least %s", max, min) + } + + return nil +} diff --git a/internal/config/validate_test.go b/internal/config/validate_test.go index 60e37732..80e4ded3 100644 --- a/internal/config/validate_test.go +++ b/internal/config/validate_test.go @@ -159,3 +159,44 @@ func validConfig() Config { return cfg } + +func TestValidTLSVersions(t *testing.T) { + tests := map[string]struct { + tlsMin string + tlsMax string + }{ + "tls 1.3 only": {tlsMin: "tls1.3", tlsMax: "tls1.3"}, + "tls 1.2 only": {tlsMin: "tls1.2", tlsMax: "tls1.2"}, + "tls 1.3 max": {tlsMax: "tls1.3"}, + "tls 1.2 max": {tlsMax: "tls1.2"}, + "tls 1.3+": {tlsMin: "tls1.3"}, + "tls 1.2+": {tlsMin: "tls1.2"}, + "default": {}, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + err := validateTLSVersions(tc.tlsMin, tc.tlsMax) + require.NoError(t, err) + }) + } +} + +func TestInvalidTLSVersions(t *testing.T) { + tests := map[string]struct { + tlsMin string + tlsMax string + err string + }{ + "invalid minimum TLS version": {tlsMin: "tls123", tlsMax: "", err: "invalid minimum TLS version: tls123"}, + "invalid maximum TLS version": {tlsMin: "", tlsMax: "tls123", err: "invalid maximum TLS version: tls123"}, + "TLS versions conflict": {tlsMin: "tls1.3", tlsMax: "tls1.2", err: "invalid maximum TLS version: tls1.2; should be at least tls1.3"}, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + err := validateTLSVersions(tc.tlsMin, tc.tlsMax) + require.EqualError(t, err, tc.err) + }) + } +} diff --git a/internal/feature/feature.go b/internal/feature/feature.go index 81eef9a0..c98fe85d 100644 --- a/internal/feature/feature.go +++ b/internal/feature/feature.go @@ -8,17 +8,29 @@ type Feature struct { } // EnforceIPRateLimits enforces IP rate limiter to drop requests -// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629 +// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/706 var EnforceIPRateLimits = Feature{ EnvVariable: "FF_ENFORCE_IP_RATE_LIMITS", } // EnforceDomainRateLimits enforces domain rate limiter to drop requests -// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/655 +// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/706 var EnforceDomainRateLimits = Feature{ EnvVariable: "FF_ENFORCE_DOMAIN_RATE_LIMITS", } +// EnforceDomainTLSRateLimits enforces domain rate limits on establishing new TLS connections +// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/706 +var EnforceDomainTLSRateLimits = Feature{ + EnvVariable: "FF_ENFORCE_DOMAIN_TLS_RATE_LIMITS", +} + +// EnforceIPTLSRateLimits enforces domain rate limits on establishing new TLS connections +// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/706 +var EnforceIPTLSRateLimits = Feature{ + EnvVariable: "FF_ENFORCE_IP_TLS_RATE_LIMITS", +} + // RedirectsPlaceholders enables support for placeholders in redirects file // TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/620 var RedirectsPlaceholders = Feature{ diff --git a/internal/handlers/ratelimiter.go b/internal/handlers/ratelimiter.go index 52281f6e..a8eee005 100644 --- a/internal/handlers/ratelimiter.go +++ b/internal/handlers/ratelimiter.go @@ -14,11 +14,11 @@ import ( // TODO: make this unexported once https://gitlab.com/gitlab-org/gitlab-pages/-/issues/670 is done func Ratelimiter(handler http.Handler, config *config.RateLimit) http.Handler { sourceIPLimiter := ratelimiter.New( - "source_ip", + "http_requests_by_source_ip", ratelimiter.WithCacheMaxSize(ratelimiter.DefaultSourceIPCacheSize), - ratelimiter.WithCachedEntriesMetric(metrics.RateLimitSourceIPCachedEntries), - ratelimiter.WithCachedRequestsMetric(metrics.RateLimitSourceIPCacheRequests), - ratelimiter.WithBlockedCountMetric(metrics.RateLimitSourceIPBlockedCount), + ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries), + ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests), + ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount), ratelimiter.WithLimitPerSecond(config.SourceIPLimitPerSecond), ratelimiter.WithBurstSize(config.SourceIPBurst), ratelimiter.WithEnforce(feature.EnforceIPRateLimits.Enabled()), @@ -27,12 +27,12 @@ func Ratelimiter(handler http.Handler, config *config.RateLimit) http.Handler { handler = sourceIPLimiter.Middleware(handler) domainLimiter := ratelimiter.New( - "domain", + "http_requests_by_domain", ratelimiter.WithCacheMaxSize(ratelimiter.DefaultDomainCacheSize), ratelimiter.WithKeyFunc(request.GetHostWithoutPort), - ratelimiter.WithCachedEntriesMetric(metrics.RateLimitDomainCachedEntries), - ratelimiter.WithCachedRequestsMetric(metrics.RateLimitDomainCacheRequests), - ratelimiter.WithBlockedCountMetric(metrics.RateLimitDomainBlockedCount), + ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries), + ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests), + ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount), ratelimiter.WithLimitPerSecond(config.DomainLimitPerSecond), ratelimiter.WithBurstSize(config.DomainBurst), ratelimiter.WithEnforce(feature.EnforceDomainRateLimits.Enabled()), diff --git a/internal/ratelimiter/middleware.go b/internal/ratelimiter/middleware.go index af7b0881..1be6f642 100644 --- a/internal/ratelimiter/middleware.go +++ b/internal/ratelimiter/middleware.go @@ -8,7 +8,6 @@ import ( "gitlab.com/gitlab-org/labkit/correlation" "gitlab.com/gitlab-org/labkit/log" - "gitlab.com/gitlab-org/gitlab-pages/internal/feature" "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/request" ) @@ -34,7 +33,7 @@ func (rl *RateLimiter) Middleware(handler http.Handler) http.Handler { rl.logRateLimitedRequest(r) if rl.blockedCount != nil { - rl.blockedCount.WithLabelValues(strconv.FormatBool(feature.EnforceIPRateLimits.Enabled())).Inc() + rl.blockedCount.WithLabelValues(rl.name, strconv.FormatBool(rl.enforce)).Inc() } if rl.enforce { @@ -59,7 +58,7 @@ func (rl *RateLimiter) logRateLimitedRequest(r *http.Request) { "x_forwarded_proto": r.Header.Get(headerXForwardedProto), "x_forwarded_for": r.Header.Get(headerXForwardedFor), "gitlab_real_ip": r.Header.Get(headerGitLabRealIP), - "rate_limiter_enabled": feature.EnforceIPRateLimits.Enabled(), + "enforced": rl.enforce, "rate_limiter_limit_per_second": rl.limitPerSecond, "rate_limiter_burst_size": rl.burstSize, }). // TODO: change to Debug with https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629 diff --git a/internal/ratelimiter/middleware_test.go b/internal/ratelimiter/middleware_test.go index 1f753fc4..25ac08b5 100644 --- a/internal/ratelimiter/middleware_test.go +++ b/internal/ratelimiter/middleware_test.go @@ -107,7 +107,7 @@ func TestMiddlewareDenyRequestsAfterBurst(t *testing.T) { assertSourceIPLog(t, remoteAddr, hook) } - blockedCount := testutil.ToFloat64(blocked.WithLabelValues(strconv.FormatBool(tc.enforce))) + blockedCount := testutil.ToFloat64(blocked.WithLabelValues("rate_limiter", strconv.FormatBool(tc.enforce))) require.Equal(t, float64(4), blockedCount, "blocked count") blocked.Reset() @@ -226,7 +226,7 @@ func newTestMetrics(t *testing.T) (*prometheus.GaugeVec, *prometheus.GaugeVec, * prometheus.GaugeOpts{ Name: t.Name(), }, - []string{"enforced"}, + []string{"limit_name", "enforced"}, ) cachedEntries := prometheus.NewGaugeVec(prometheus.GaugeOpts{ diff --git a/internal/ratelimiter/ratelimiter.go b/internal/ratelimiter/ratelimiter.go index feeb8cb4..b64cb4ce 100644 --- a/internal/ratelimiter/ratelimiter.go +++ b/internal/ratelimiter/ratelimiter.go @@ -1,6 +1,8 @@ package ratelimiter import ( + "crypto/tls" + "net" "net/http" "time" @@ -27,6 +29,9 @@ type Option func(*RateLimiter) // KeyFunc returns unique identifier for the subject of rate limit(e.g. client IP or domain) type KeyFunc func(*http.Request) string +// TLSKeyFunc is used by GetCertificateMiddleware to identify the subject of rate limit (client IP or SNI servername) +type TLSKeyFunc func(*tls.ClientHelloInfo) string + // RateLimiter holds an LRU cache of elements to be rate limited. // It uses "golang.org/x/time/rate" as its Token Bucket rate limiter per source IP entry. // See example https://www.fatalerrors.org/a/design-and-implementation-of-time-rate-limiter-for-golang-standard-library.html @@ -35,6 +40,7 @@ type RateLimiter struct { name string now func() time.Time keyFunc KeyFunc + tlsKeyFunc TLSKeyFunc limitPerSecond float64 burstSize int blockedCount *prometheus.GaugeVec @@ -120,6 +126,26 @@ func WithKeyFunc(f KeyFunc) Option { } } +func TLSHostnameKey(info *tls.ClientHelloInfo) string { + return info.ServerName +} + +func TLSClientIPKey(info *tls.ClientHelloInfo) string { + remoteAddr := info.Conn.RemoteAddr().String() + remoteAddr, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + return remoteAddr + } + + return remoteAddr +} + +func WithTLSKeyFunc(keyFunc TLSKeyFunc) Option { + return func(rl *RateLimiter) { + rl.tlsKeyFunc = keyFunc + } +} + // WithEnforce configures if requests are actually rejected, or we just report them as rejected in metrics func WithEnforce(enforce bool) Option { return func(rl *RateLimiter) { @@ -138,6 +164,11 @@ func (rl *RateLimiter) limiter(key string) *rate.Limiter { // requestAllowed checks if request is within the rate-limit func (rl *RateLimiter) requestAllowed(r *http.Request) bool { rateLimitedKey := rl.keyFunc(r) + + return rl.allowed(rateLimitedKey) +} + +func (rl *RateLimiter) allowed(rateLimitedKey string) bool { limiter := rl.limiter(rateLimitedKey) // AllowN allows us to use the rl.now function, so we can test this more easily. diff --git a/internal/ratelimiter/stub_conn_test.go b/internal/ratelimiter/stub_conn_test.go new file mode 100644 index 00000000..9d7a4a9a --- /dev/null +++ b/internal/ratelimiter/stub_conn_test.go @@ -0,0 +1,14 @@ +package ratelimiter + +import ( + "net" +) + +type stubConn struct { + net.Conn + remoteAddr net.Addr +} + +func (s stubConn) RemoteAddr() net.Addr { + return s.remoteAddr +} diff --git a/internal/ratelimiter/tls.go b/internal/ratelimiter/tls.go new file mode 100644 index 00000000..3bebbc38 --- /dev/null +++ b/internal/ratelimiter/tls.go @@ -0,0 +1,49 @@ +package ratelimiter + +import ( + "crypto/tls" + "errors" + "strconv" + + "github.com/sirupsen/logrus" + "gitlab.com/gitlab-org/labkit/log" +) + +var ErrTLSRateLimited = errors.New("too many connections, please retry later") + +type GetCertificateFunc func(*tls.ClientHelloInfo) (*tls.Certificate, error) + +func (rl *RateLimiter) GetCertificateMiddleware(getCertificate GetCertificateFunc) GetCertificateFunc { + if rl.limitPerSecond <= 0.0 { + return getCertificate + } + + return func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + if rl.allowed(rl.tlsKeyFunc(info)) { + return getCertificate(info) + } + + rl.logRateLimitedTLS(info) + + if rl.blockedCount != nil { + rl.blockedCount.WithLabelValues(rl.name, strconv.FormatBool(rl.enforce)).Inc() + } + + if !rl.enforce { + return getCertificate(info) + } + + return nil, ErrTLSRateLimited + } +} + +func (rl *RateLimiter) logRateLimitedTLS(info *tls.ClientHelloInfo) { + log.WithFields(logrus.Fields{ + "rate_limiter_name": rl.name, + "source_ip": TLSClientIPKey(info), + "req_host": info.ServerName, + "rate_limiter_limit_per_second": rl.limitPerSecond, + "rate_limiter_burst_size": rl.burstSize, + "enforced": rl.enforce, + }).Info("TLS connection rate-limited") +} diff --git a/internal/ratelimiter/tls_test.go b/internal/ratelimiter/tls_test.go new file mode 100644 index 00000000..6763514b --- /dev/null +++ b/internal/ratelimiter/tls_test.go @@ -0,0 +1,194 @@ +package ratelimiter + +import ( + "crypto/tls" + "errors" + "net" + "strconv" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/sirupsen/logrus" + testlog "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/require" +) + +func TestTLSHostnameKey(t *testing.T) { + info := &tls.ClientHelloInfo{ServerName: "group.gitlab.io"} + + require.Equal(t, "group.gitlab.io", TLSHostnameKey(info)) +} + +func TestTLSClientIPKey(t *testing.T) { + tests := []struct { + addr string + expected string + }{ + { + "10.1.2.3:1234", + "10.1.2.3", + }, + { + "[2001:db8:3333:4444:5555:6666:7777:8888]:1234", + "2001:db8:3333:4444:5555:6666:7777:8888", + }, + } + + for _, tt := range tests { + addr, err := net.ResolveTCPAddr("tcp", tt.addr) + require.NoError(t, err) + conn := stubConn{remoteAddr: addr} + info := &tls.ClientHelloInfo{Conn: conn} + + require.Equal(t, tt.expected, TLSClientIPKey(info)) + } +} + +func TestGetCertificateMiddleware(t *testing.T) { + tests := map[string]struct { + useHostnameAsKey bool + enforced bool + limitPerSecond float64 + burst int + successfulReqCnt int + }{ + "ip_limiter": { + useHostnameAsKey: false, + enforced: true, + limitPerSecond: 0.1, + burst: 5, + successfulReqCnt: 5, + }, + "hostname_limiter": { + useHostnameAsKey: true, + enforced: true, + limitPerSecond: 0.1, + burst: 5, + successfulReqCnt: 5, + }, + "not_enforced": { + useHostnameAsKey: false, + enforced: false, + limitPerSecond: 0.1, + burst: 5, + successfulReqCnt: 5, + }, + "disabled": { + useHostnameAsKey: false, + enforced: true, + limitPerSecond: 0, + burst: 5, + successfulReqCnt: 10, + }, + "slowly_approach_limit": { + useHostnameAsKey: false, + enforced: true, + limitPerSecond: 0.2, + burst: 5, + successfulReqCnt: 6, // 5 * 0.2 gives another 1 request + }, + } + + expectedCert := &tls.Certificate{} + expectedErr := errors.New("expected error") + + getCertificate := func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return expectedCert, expectedErr + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + hook := testlog.NewGlobal() + blocked, cachedEntries, cacheReqs := newTestMetrics(t) + + keyFunc := TLSClientIPKey + if tt.useHostnameAsKey { + keyFunc = TLSHostnameKey + } + + rl := New("limit_name", + WithCachedEntriesMetric(cachedEntries), + WithCachedRequestsMetric(cacheReqs), + WithBlockedCountMetric(blocked), + WithNow(stubNow()), + WithLimitPerSecond(tt.limitPerSecond), + WithBurstSize(tt.burst), + WithEnforce(tt.enforced), + WithTLSKeyFunc(keyFunc)) + + middlewareGetCert := rl.GetCertificateMiddleware(getCertificate) + + addr, err := net.ResolveTCPAddr("tcp", "10.1.2.3:12345") + require.NoError(t, err) + conn := stubConn{remoteAddr: addr} + info := &tls.ClientHelloInfo{Conn: conn, ServerName: "group.gitlab.io"} + + for i := 0; i < tt.successfulReqCnt; i++ { + cert, err := middlewareGetCert(info) + require.Equal(t, expectedCert, cert) + require.Equal(t, expectedErr, err) + } + + // When rate-limiter disabled altogether + if tt.limitPerSecond <= 0 { + return + } + + cert, err := middlewareGetCert(info) + if tt.enforced { + require.Nil(t, cert) + require.Equal(t, err, ErrTLSRateLimited) + } else { + require.Equal(t, expectedCert, cert) + require.Equal(t, expectedErr, err) + } + + require.NotNil(t, hook.LastEntry()) + require.Equal(t, "TLS connection rate-limited", hook.LastEntry().Message) + expectedFields := logrus.Fields{ + "rate_limiter_name": "limit_name", + "source_ip": "10.1.2.3", + "req_host": "group.gitlab.io", + "rate_limiter_limit_per_second": tt.limitPerSecond, + "rate_limiter_burst_size": tt.burst, + "enforced": tt.enforced, + } + require.Equal(t, expectedFields, hook.LastEntry().Data) + + // make another request with different key and expect success + if tt.useHostnameAsKey { + info.ServerName = "another-group.gitlab.io" + } else { + addr, err := net.ResolveTCPAddr("tcp", "10.10.20.30:12345") + require.NoError(t, err) + conn = stubConn{remoteAddr: addr} + info.Conn = conn + } + + cert, err = middlewareGetCert(info) + require.Equal(t, expectedCert, cert) + require.Equal(t, expectedErr, err) + + blockedCount := testutil.ToFloat64(blocked.WithLabelValues("limit_name", strconv.FormatBool(tt.enforced))) + require.Equal(t, float64(1), blockedCount, "blocked count") + + cachedCount := testutil.ToFloat64(cachedEntries.WithLabelValues("limit_name")) + require.Equal(t, float64(2), cachedCount, "cached count") // 1 for first key + 1 for different one + + cacheReqMiss := testutil.ToFloat64(cacheReqs.WithLabelValues("limit_name", "miss")) + require.Equal(t, float64(2), cacheReqMiss, "miss count") // 1 for first key + 1 for different one + cacheReqHit := testutil.ToFloat64(cacheReqs.WithLabelValues("limit_name", "hit")) + require.Equal(t, float64(tt.successfulReqCnt), cacheReqHit, "hit count") + }) + } +} + +func stubNow() func() time.Time { + now := time.Now() + return func() time.Time { + now = now.Add(time.Second) + + return now + } +} diff --git a/internal/tls/tls.go b/internal/tls/tls.go new file mode 100644 index 00000000..6d7397af --- /dev/null +++ b/internal/tls/tls.go @@ -0,0 +1,96 @@ +package tls + +import ( + "crypto/tls" + + "gitlab.com/gitlab-org/gitlab-pages/internal/config" + "gitlab.com/gitlab-org/gitlab-pages/internal/feature" + "gitlab.com/gitlab-org/gitlab-pages/internal/ratelimiter" + "gitlab.com/gitlab-org/gitlab-pages/metrics" +) + +var preferredCipherSuites = []uint16{ + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_AES_128_GCM_SHA256, + tls.TLS_AES_256_GCM_SHA384, + tls.TLS_CHACHA20_POLY1305_SHA256, +} + +// GetCertificateFunc returns the certificate to be used for given domain +type GetCertificateFunc func(*tls.ClientHelloInfo) (*tls.Certificate, error) + +// GetTLSConfig initializes tls.Config based on config flags +// getCertificateByServerName obtains certificate based on domain +func GetTLSConfig(cfg *config.Config, getCertificateByServerName GetCertificateFunc) (*tls.Config, error) { + wildcardCertificate, err := tls.X509KeyPair(cfg.General.RootCertificate, cfg.General.RootKey) + if err != nil { + return nil, err + } + + getCertificate := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + // Golang calls tls.Config.GetCertificate only if it's set and + // 1. ServerName != "" + // 2. Or tls.Config.Certificates is empty array + // tls.Config.Certificates contain wildcard certificate + // We want to implement rate limits via GetCertificate, so we need to call it every time + // So we don't set tls.Config.Certificates, but simulate the behavior of golang: + // 1. try to get certificate by name + // 2. if we can't, fallback to default(wildcard) certificate + customCertificate, err := getCertificateByServerName(info) + + if customCertificate != nil || err != nil { + return customCertificate, err + } + + return &wildcardCertificate, nil + } + + TLSDomainRateLimiter := ratelimiter.New( + "tls_connections_by_domain", + ratelimiter.WithTLSKeyFunc(ratelimiter.TLSHostnameKey), + ratelimiter.WithCacheMaxSize(ratelimiter.DefaultDomainCacheSize), + ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries), + ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests), + ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount), + ratelimiter.WithLimitPerSecond(cfg.RateLimit.TLSDomainLimitPerSecond), + ratelimiter.WithBurstSize(cfg.RateLimit.TLSDomainBurst), + ratelimiter.WithEnforce(feature.EnforceDomainTLSRateLimits.Enabled()), + ) + + TLSSourceIPRateLimiter := ratelimiter.New( + "tls_connections_by_source_ip", + ratelimiter.WithTLSKeyFunc(ratelimiter.TLSClientIPKey), + ratelimiter.WithCacheMaxSize(ratelimiter.DefaultSourceIPCacheSize), + ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries), + ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests), + ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount), + ratelimiter.WithLimitPerSecond(cfg.RateLimit.TLSSourceIPLimitPerSecond), + ratelimiter.WithBurstSize(cfg.RateLimit.TLSSourceIPBurst), + ratelimiter.WithEnforce(feature.EnforceIPTLSRateLimits.Enabled()), + ) + + getCertificate = TLSDomainRateLimiter.GetCertificateMiddleware(getCertificate) + getCertificate = TLSSourceIPRateLimiter.GetCertificateMiddleware(getCertificate) + + // set MinVersion to fix gosec: G402 + tlsConfig := &tls.Config{GetCertificate: getCertificate, MinVersion: tls.VersionTLS12} + + if !cfg.General.InsecureCiphers { + configureTLSCiphers(tlsConfig) + } + + tlsConfig.MinVersion = cfg.TLS.MinVersion + tlsConfig.MaxVersion = cfg.TLS.MaxVersion + + return tlsConfig, nil +} + +func configureTLSCiphers(tlsConfig *tls.Config) { + tlsConfig.PreferServerCipherSuites = true + tlsConfig.CipherSuites = preferredCipherSuites +} diff --git a/internal/config/tls/tls_test.go b/internal/tls/tls_test.go index 7146d904..e65b2671 100644 --- a/internal/config/tls/tls_test.go +++ b/internal/tls/tls_test.go @@ -5,6 +5,8 @@ import ( "testing" "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-pages/internal/config" ) var cert = []byte(`-----BEGIN CERTIFICATE----- @@ -29,65 +31,46 @@ var getCertificate = func(ch *tls.ClientHelloInfo) (*tls.Certificate, error) { return nil, nil } -func TestInvalidTLSVersions(t *testing.T) { - tests := map[string]struct { - tlsMin string - tlsMax string - err string - }{ - "invalid minimum TLS version": {tlsMin: "tls123", tlsMax: "", err: "invalid minimum TLS version: tls123"}, - "invalid maximum TLS version": {tlsMin: "", tlsMax: "tls123", err: "invalid maximum TLS version: tls123"}, - "TLS versions conflict": {tlsMin: "tls1.3", tlsMax: "tls1.2", err: "invalid maximum TLS version: tls1.2; should be at least tls1.3"}, - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - err := ValidateTLSVersions(tc.tlsMin, tc.tlsMax) - require.EqualError(t, err, tc.err) - }) - } -} - -func TestValidTLSVersions(t *testing.T) { - tests := map[string]struct { - tlsMin string - tlsMax string - }{ - "tls 1.3 only": {tlsMin: "tls1.3", tlsMax: "tls1.3"}, - "tls 1.2 only": {tlsMin: "tls1.2", tlsMax: "tls1.2"}, - "tls 1.3 max": {tlsMax: "tls1.3"}, - "tls 1.2 max": {tlsMax: "tls1.2"}, - "tls 1.3+": {tlsMin: "tls1.3"}, - "tls 1.2+": {tlsMin: "tls1.2"}, - "default": {}, - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - err := ValidateTLSVersions(tc.tlsMin, tc.tlsMax) - require.NoError(t, err) - }) - } -} - func TestInvalidKeyPair(t *testing.T) { - _, err := Create([]byte(``), []byte(``), getCertificate, false, tls.VersionTLS11, tls.VersionTLS12) + cfg := &config.Config{} + _, err := GetTLSConfig(cfg, getCertificate) require.EqualError(t, err, "tls: failed to find any PEM data in certificate input") } func TestInsecureCiphers(t *testing.T) { - tlsConfig, err := Create(cert, key, getCertificate, true, tls.VersionTLS11, tls.VersionTLS12) + cfg := &config.Config{ + General: config.General{ + RootCertificate: cert, + RootKey: key, + InsecureCiphers: true, + }, + } + tlsConfig, err := GetTLSConfig(cfg, getCertificate) require.NoError(t, err) require.False(t, tlsConfig.PreferServerCipherSuites) require.Empty(t, tlsConfig.CipherSuites) } -func TestCreate(t *testing.T) { - tlsConfig, err := Create(cert, key, getCertificate, false, tls.VersionTLS11, tls.VersionTLS12) +func TestGetTLSConfig(t *testing.T) { + cfg := &config.Config{ + General: config.General{ + RootCertificate: cert, + RootKey: key, + }, + TLS: config.TLS{ + MinVersion: tls.VersionTLS11, + MaxVersion: tls.VersionTLS12, + }, + } + tlsConfig, err := GetTLSConfig(cfg, getCertificate) require.NoError(t, err) require.IsType(t, getCertificate, tlsConfig.GetCertificate) require.True(t, tlsConfig.PreferServerCipherSuites) require.Equal(t, preferredCipherSuites, tlsConfig.CipherSuites) require.Equal(t, uint16(tls.VersionTLS11), tlsConfig.MinVersion) require.Equal(t, uint16(tls.VersionTLS12), tlsConfig.MaxVersion) + + cert, err := tlsConfig.GetCertificate(&tls.ClientHelloInfo{}) + require.NoError(t, err) + require.NotNil(t, cert) } diff --git a/metrics/metrics.go b/metrics/metrics.go index e0e4ab29..26c7413f 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -193,60 +193,31 @@ var ( }, ) - // RateLimitSourceIPCacheRequests is the number of cache hits/misses - RateLimitSourceIPCacheRequests = prometheus.NewCounterVec( + // RateLimitCacheRequests is the number of cache hits/misses + RateLimitCacheRequests = prometheus.NewCounterVec( prometheus.CounterOpts{ - Name: "gitlab_pages_rate_limit_source_ip_cache_requests", + Name: "gitlab_pages_rate_limit_cache_requests", Help: "The number of source_ip cache hits/misses in the rate limiter", }, []string{"op", "cache"}, ) - // RateLimitSourceIPCachedEntries is the number of entries in the cache - RateLimitSourceIPCachedEntries = prometheus.NewGaugeVec( + // RateLimitCachedEntries is the number of entries in the cache + RateLimitCachedEntries = prometheus.NewGaugeVec( prometheus.GaugeOpts{ - Name: "gitlab_pages_rate_limit_source_ip_cached_entries", + Name: "gitlab_pages_rate_limit_cached_entries", Help: "The number of entries in the cache", }, []string{"op"}, ) - // RateLimitSourceIPBlockedCount is the number of requests that have been blocked by the - // source IP rate limiter - RateLimitSourceIPBlockedCount = prometheus.NewGaugeVec( + // RateLimitBlockedCount is the number of requests that have been blocked + RateLimitBlockedCount = prometheus.NewGaugeVec( prometheus.GaugeOpts{ - Name: "gitlab_pages_rate_limit_source_ip_blocked_count", - Help: "The number of requests that have been blocked by the IP rate limiter", + Name: "gitlab_pages_rate_limit_blocked_count", + Help: "The number of requests/connections that have been blocked by rate limiter", }, - []string{"enforced"}, - ) - - // RateLimitDomainCacheRequests is the number of cache hits/misses - RateLimitDomainCacheRequests = prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: "gitlab_pages_rate_limit_domain_cache_requests", - Help: "The number of source_ip cache hits/misses in the rate limiter", - }, - []string{"op", "cache"}, - ) - - // RateLimitDomainCachedEntries is the number of entries in the cache - RateLimitDomainCachedEntries = prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Name: "gitlab_pages_rate_limit_domain_cached_entries", - Help: "The number of entries in the cache", - }, - []string{"op"}, - ) - - // RateLimitDomainBlockedCount is the number of requests that have been blocked by the - // domain rate limiter - RateLimitDomainBlockedCount = prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Name: "gitlab_pages_rate_limit_domain_blocked_count", - Help: "The number of requests addresses that have been blocked by the domain rate limiter", - }, - []string{"enforced"}, + []string{"limit_name", "enforced"}, ) ) @@ -276,8 +247,8 @@ func MustRegister() { LimitListenerConcurrentConns, LimitListenerWaitingConns, PanicRecoveredCount, - RateLimitSourceIPCacheRequests, - RateLimitSourceIPCachedEntries, - RateLimitSourceIPBlockedCount, + RateLimitCacheRequests, + RateLimitCachedEntries, + RateLimitBlockedCount, ) } @@ -4,11 +4,13 @@ import ( "context" "crypto/tls" "fmt" + stdlog "log" "net" "net/http" "time" - proxyproto "github.com/pires/go-proxyproto" + "github.com/pires/go-proxyproto" + "github.com/sirupsen/logrus" "gitlab.com/gitlab-org/labkit/log" "gitlab.com/gitlab-org/gitlab-pages/internal/netutil" @@ -31,6 +33,7 @@ func (a *theApp) listenAndServe(server *http.Server, addr string, h http.Handler server.Handler = h server.TLSConfig = config.tlsConfig + server.ErrorLog = stdlog.New(logrus.StandardLogger().Writer(), "", 0) // ensure http2 is enabled even if TLSConfig is not null // See https://github.com/golang/go/blob/97cee43c93cfccded197cd281f0a5885cdb605b4/src/net/http/server.go#L2947-L2954 if server.TLSConfig != nil { diff --git a/test/acceptance/helpers_test.go b/test/acceptance/helpers_test.go index c44058ba..1b514a85 100644 --- a/test/acceptance/helpers_test.go +++ b/test/acceptance/helpers_test.go @@ -602,3 +602,15 @@ func copyFile(dest, src string) error { _, err = io.Copy(destFile, srcFile) return err } + +// RequireMetricEqual requests prometheus metrics and makes sure metric is there +func RequireMetricEqual(t *testing.T, metricsAddress, metricWithValue string) { + resp, err := http.Get(fmt.Sprintf("http://%s/metrics", metricsAddress)) + require.NoError(t, err) + + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + require.Contains(t, string(body), metricWithValue) +} diff --git a/test/acceptance/ratelimiter_test.go b/test/acceptance/ratelimiter_test.go index 02ba54f5..a97fdfb1 100644 --- a/test/acceptance/ratelimiter_test.go +++ b/test/acceptance/ratelimiter_test.go @@ -3,6 +3,7 @@ package acceptance_test import ( "fmt" "net/http" + "strconv" "testing" "time" @@ -77,7 +78,7 @@ func TestIPRateLimits(t *testing.T) { } } -func TestDomainateLimits(t *testing.T) { +func TestDomainRateLimits(t *testing.T) { testhelpers.StubFeatureFlagValue(t, feature.EnforceDomainRateLimits.EnvVariable, true) for name, tc := range ratelimitedListeners { @@ -112,6 +113,131 @@ func TestDomainateLimits(t *testing.T) { } } +func TestTLSRateLimits(t *testing.T) { + tests := map[string]struct { + spec ListenSpec + domainLimit bool + sourceIP string + enforceEnabled bool + }{ + "https_with_domain_limit": { + spec: httpsListener, + domainLimit: true, + sourceIP: "127.0.0.1", + enforceEnabled: true, + }, + "https_with_domain_limit_not_enforced": { + spec: httpsListener, + domainLimit: true, + sourceIP: "127.0.0.1", + enforceEnabled: false, + }, + "https_with_ip_limit": { + spec: httpsListener, + sourceIP: "127.0.0.1", + enforceEnabled: true, + }, + "https_with_ip_limit_not_enforced": { + spec: httpsListener, + sourceIP: "127.0.0.1", + enforceEnabled: false, + }, + "proxyv2_with_domain_limit": { + spec: httpsProxyv2Listener, + domainLimit: true, + sourceIP: "10.1.1.1", + enforceEnabled: true, + }, + "proxyv2_with_domain_limit_not_enforced": { + spec: httpsProxyv2Listener, + domainLimit: true, + sourceIP: "10.1.1.1", + enforceEnabled: false, + }, + "proxyv2_with_ip_limit": { + spec: httpsProxyv2Listener, + sourceIP: "10.1.1.1", + enforceEnabled: true, + }, + "proxyv2_with_ip_limit_not_enforced": { + spec: httpsProxyv2Listener, + sourceIP: "10.1.1.1", + enforceEnabled: false, + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + rateLimit := 5 + + options := []processOption{ + withListeners([]ListenSpec{tt.spec}), + withExtraArgument("metrics-address", ":42345"), + } + + featureName := feature.EnforceIPTLSRateLimits.EnvVariable + limitName := "tls_connections_by_source_ip" + + if tt.domainLimit { + options = append(options, + withExtraArgument("rate-limit-tls-domain", fmt.Sprint(rateLimit)), + withExtraArgument("rate-limit-tls-domain-burst", fmt.Sprint(rateLimit))) + + featureName = feature.EnforceDomainTLSRateLimits.EnvVariable + limitName = "tls_connections_by_domain" + } else { + options = append(options, + withExtraArgument("rate-limit-tls-source-ip", fmt.Sprint(rateLimit)), + withExtraArgument("rate-limit-tls-source-ip-burst", fmt.Sprint(rateLimit))) + } + + testhelpers.StubFeatureFlagValue(t, featureName, tt.enforceEnabled) + logBuf := RunPagesProcess(t, options...) + + // when we start the process we make 1 requests to verify that process is up + // it gets counted in the rate limit for IP, but host is different + if !tt.domainLimit { + rateLimit-- + } + + for i := 0; i < 10; i++ { + rsp, err := makeTLSRequest(t, tt.spec) + + if i >= rateLimit { + assertLogFound(t, logBuf, []string{ + "TLS connection rate-limited", + "\"req_host\":\"group.gitlab-example.com\"", + fmt.Sprintf("\"source_ip\":\"%s\"", tt.sourceIP), + "\"enforced\":" + strconv.FormatBool(tt.enforceEnabled)}) + + if tt.enforceEnabled { + require.Error(t, err) + require.Contains(t, err.Error(), "remote error: tls: internal error") + } + + continue + } + + require.NoError(t, err, "request: %d failed", i) + require.NoError(t, rsp.Body.Close()) + require.Equal(t, http.StatusOK, rsp.StatusCode, "request: %d failed", i) + } + expectedMetric := fmt.Sprintf( + "gitlab_pages_rate_limit_blocked_count{enforced=\"%t\",limit_name=\"%s\"} %v", + tt.enforceEnabled, limitName, 10-rateLimit) + + RequireMetricEqual(t, "127.0.0.1:42345", expectedMetric) + }) + } +} + +func makeTLSRequest(t *testing.T, spec ListenSpec) (*http.Response, error) { + req, err := http.NewRequest("GET", "https://group.gitlab-example.com/project", nil) + require.NoError(t, err) + + return spec.Client().Do(req) +} + func assertLogFound(t *testing.T, logBuf *LogCaptureBuffer, expectedLogs []string) { t.Helper() |