Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-pages.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVladimir Shushlin <vshushlin@gitlab.com>2022-02-22 13:17:56 +0300
committerVladimir Shushlin <vshushlin@gitlab.com>2022-02-22 13:17:56 +0300
commite7ad17b8d0818dd4c94f3a06e81781f4068c1b97 (patch)
tree026e2619a53da1abf202322c0d4e656ed2e93ba3
parentdbd3785baf9f6af3c0c6a76ef44b12f3fd49b68a (diff)
parentdec4b09ac1f6fdf98487d4db61055c1e64358c15 (diff)
Merge branch 'reject-tls-2' into 'master'
feat: add rate limits on the TLS connection level See merge request gitlab-org/gitlab-pages!700
-rw-r--r--app.go32
-rw-r--r--internal/config/config.go94
-rw-r--r--internal/config/flags.go59
-rw-r--r--internal/config/tls/tls.go100
-rw-r--r--internal/config/validate.go22
-rw-r--r--internal/config/validate_test.go41
-rw-r--r--internal/feature/feature.go16
-rw-r--r--internal/handlers/ratelimiter.go16
-rw-r--r--internal/ratelimiter/middleware.go5
-rw-r--r--internal/ratelimiter/middleware_test.go4
-rw-r--r--internal/ratelimiter/ratelimiter.go31
-rw-r--r--internal/ratelimiter/stub_conn_test.go14
-rw-r--r--internal/ratelimiter/tls.go49
-rw-r--r--internal/ratelimiter/tls_test.go194
-rw-r--r--internal/tls/tls.go96
-rw-r--r--internal/tls/tls_test.go (renamed from internal/config/tls/tls_test.go)73
-rw-r--r--metrics/metrics.go57
-rw-r--r--server.go5
-rw-r--r--test/acceptance/helpers_test.go12
-rw-r--r--test/acceptance/ratelimiter_test.go128
20 files changed, 777 insertions, 271 deletions
diff --git a/app.go b/app.go
index ddafb0bf..ec928300 100644
--- a/app.go
+++ b/app.go
@@ -27,7 +27,6 @@ import (
"gitlab.com/gitlab-org/gitlab-pages/internal/artifact"
"gitlab.com/gitlab-org/gitlab-pages/internal/auth"
cfg "gitlab.com/gitlab-org/gitlab-pages/internal/config"
- "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls"
"gitlab.com/gitlab-org/gitlab-pages/internal/customheaders"
"gitlab.com/gitlab-org/gitlab-pages/internal/domain"
"gitlab.com/gitlab-org/gitlab-pages/internal/handlers"
@@ -40,6 +39,7 @@ import (
"gitlab.com/gitlab-org/gitlab-pages/internal/serving/disk/zip"
"gitlab.com/gitlab-org/gitlab-pages/internal/source"
"gitlab.com/gitlab-org/gitlab-pages/internal/source/gitlab"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/tls"
"gitlab.com/gitlab-org/gitlab-pages/internal/urilimiter"
"gitlab.com/gitlab-org/gitlab-pages/metrics"
)
@@ -51,6 +51,7 @@ var (
type theApp struct {
config *cfg.Config
source source.Source
+ tlsConfig *cryptotls.Config
Artifact *artifact.Artifact
Auth *auth.Auth
Handlers *handlers.Handlers
@@ -62,19 +63,33 @@ func (a *theApp) isReady() bool {
return true
}
-func (a *theApp) ServeTLS(ch *cryptotls.ClientHelloInfo) (*cryptotls.Certificate, error) {
+func (a *theApp) GetCertificate(ch *cryptotls.ClientHelloInfo) (*cryptotls.Certificate, error) {
if ch.ServerName == "" {
return nil, nil
}
if domain, _ := a.domain(context.Background(), ch.ServerName); domain != nil {
- tls, _ := domain.EnsureCertificate()
- return tls, nil
+ certificate, _ := domain.EnsureCertificate()
+ return certificate, nil
}
return nil, nil
}
+func (a *theApp) getTLSConfig() (*cryptotls.Config, error) {
+ // we call this function only when tls config is needed, and we ignore TLS related flags otherwise
+ // in theory you can configure both listen-https and listen-proxyv2,
+ // so this return is here to have a single TLS config
+ if a.tlsConfig != nil {
+ return a.tlsConfig, nil
+ }
+
+ var err error
+ a.tlsConfig, err = tls.GetTLSConfig(a.config, a.GetCertificate)
+
+ return a.tlsConfig, err
+}
+
func (a *theApp) redirectToHTTPS(w http.ResponseWriter, r *http.Request, statusCode int) {
u := *r.URL
u.Scheme = request.SchemeHTTPS
@@ -306,7 +321,7 @@ func (a *theApp) Run() {
// Listen for HTTPS
for _, addr := range a.config.ListenHTTPSStrings.Split() {
- tlsConfig, err := a.TLSConfig()
+ tlsConfig, err := a.getTLSConfig()
if err != nil {
log.WithError(err).Fatal("Unable to retrieve tls config")
}
@@ -334,7 +349,7 @@ func (a *theApp) Run() {
// Listen for HTTPS PROXYv2 requests
for _, addr := range a.config.ListenHTTPSProxyv2Strings.Split() {
- tlsConfig, err := a.TLSConfig()
+ tlsConfig, err := a.getTLSConfig()
if err != nil {
log.WithError(err).Fatal("Unable to retrieve tls config")
}
@@ -478,11 +493,6 @@ func fatal(err error, message string) {
log.WithError(err).Fatal(message)
}
-func (a *theApp) TLSConfig() (*cryptotls.Config, error) {
- return tls.Create(a.config.General.RootCertificate, a.config.General.RootKey, a.ServeTLS,
- a.config.General.InsecureCiphers, a.config.TLS.MinVersion, a.config.TLS.MaxVersion)
-}
-
// handlePanicMiddleware logs and captures the recover() information from any panic
func handlePanicMiddleware(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
diff --git a/internal/config/config.go b/internal/config/config.go
index 3bb7b126..24a811ec 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -9,8 +9,6 @@ import (
"github.com/namsral/flag"
"gitlab.com/gitlab-org/labkit/log"
-
- "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls"
)
// Config stores all the config options relevant to GitLab Pages.
@@ -58,10 +56,17 @@ type General struct {
// RateLimit config struct
type RateLimit struct {
+ // HTTP limits
SourceIPLimitPerSecond float64
SourceIPBurst int
DomainLimitPerSecond float64
DomainBurst int
+
+ // TLS connections limits
+ TLSSourceIPLimitPerSecond float64
+ TLSSourceIPBurst int
+ TLSDomainLimitPerSecond float64
+ TLSDomainBurst int
}
// ArtifactsServer groups settings related to configuring Artifacts
@@ -183,6 +188,11 @@ func loadConfig() (*Config, error) {
SourceIPBurst: *rateLimitSourceIPBurst,
DomainLimitPerSecond: *rateLimitDomain,
DomainBurst: *rateLimitDomainBurst,
+
+ TLSSourceIPLimitPerSecond: *rateLimitTLSSourceIP,
+ TLSSourceIPBurst: *rateLimitTLSSourceIPBurst,
+ TLSDomainLimitPerSecond: *rateLimitTLSDomain,
+ TLSDomainBurst: *rateLimitTLSDomainBurst,
},
GitLab: GitLab{
ClientHTTPTimeout: *gitlabClientHTTPTimeout,
@@ -217,8 +227,8 @@ func loadConfig() (*Config, error) {
Environment: *sentryEnvironment,
},
TLS: TLS{
- MinVersion: tls.AllTLSVersions[*tlsMinVersion],
- MaxVersion: tls.AllTLSVersions[*tlsMaxVersion],
+ MinVersion: allTLSVersions[*tlsMinVersion],
+ MaxVersion: allTLSVersions[*tlsMaxVersion],
},
Zip: ZipServing{
ExpirationInterval: *zipCacheExpiration,
@@ -267,40 +277,48 @@ func loadConfig() (*Config, error) {
func LogConfig(config *Config) {
log.WithFields(log.Fields{
- "artifacts-server": *artifactsServer,
- "artifacts-server-timeout": *artifactsServerTimeout,
- "default-config-filename": flag.DefaultConfigFlagname,
- "disable-cross-origin-requests": *disableCrossOriginRequests,
- "domain": config.General.Domain,
- "insecure-ciphers": config.General.InsecureCiphers,
- "listen-http": listenHTTP,
- "listen-https": listenHTTPS,
- "listen-proxy": listenProxy,
- "listen-https-proxyv2": listenHTTPSProxyv2,
- "log-format": *logFormat,
- "metrics-address": *metricsAddress,
- "pages-domain": *pagesDomain,
- "pages-root": *pagesRoot,
- "pages-status": *pagesStatus,
- "propagate-correlation-id": *propagateCorrelationID,
- "redirect-http": config.General.RedirectHTTP,
- "root-cert": *pagesRootKey,
- "root-key": *pagesRootCert,
- "status_path": config.General.StatusPath,
- "tls-min-version": *tlsMinVersion,
- "tls-max-version": *tlsMaxVersion,
- "gitlab-server": config.GitLab.PublicServer,
- "internal-gitlab-server": config.GitLab.InternalServer,
- "api-secret-key": *gitLabAPISecretKey,
- "enable-disk": config.GitLab.EnableDisk,
- "auth-redirect-uri": config.Authentication.RedirectURI,
- "auth-scope": config.Authentication.Scope,
- "max-conns": config.General.MaxConns,
- "max-uri-length": config.General.MaxURILength,
- "zip-cache-expiration": config.Zip.ExpirationInterval,
- "zip-cache-cleanup": config.Zip.CleanupInterval,
- "zip-cache-refresh": config.Zip.RefreshInterval,
- "zip-open-timeout": config.Zip.OpenTimeout,
+ "artifacts-server": *artifactsServer,
+ "artifacts-server-timeout": *artifactsServerTimeout,
+ "default-config-filename": flag.DefaultConfigFlagname,
+ "disable-cross-origin-requests": *disableCrossOriginRequests,
+ "domain": config.General.Domain,
+ "insecure-ciphers": config.General.InsecureCiphers,
+ "listen-http": listenHTTP,
+ "listen-https": listenHTTPS,
+ "listen-proxy": listenProxy,
+ "listen-https-proxyv2": listenHTTPSProxyv2,
+ "log-format": *logFormat,
+ "metrics-address": *metricsAddress,
+ "pages-domain": *pagesDomain,
+ "pages-root": *pagesRoot,
+ "pages-status": *pagesStatus,
+ "propagate-correlation-id": *propagateCorrelationID,
+ "redirect-http": config.General.RedirectHTTP,
+ "root-cert": *pagesRootKey,
+ "root-key": *pagesRootCert,
+ "status_path": config.General.StatusPath,
+ "tls-min-version": *tlsMinVersion,
+ "tls-max-version": *tlsMaxVersion,
+ "gitlab-server": config.GitLab.PublicServer,
+ "internal-gitlab-server": config.GitLab.InternalServer,
+ "api-secret-key": *gitLabAPISecretKey,
+ "enable-disk": config.GitLab.EnableDisk,
+ "auth-redirect-uri": config.Authentication.RedirectURI,
+ "auth-scope": config.Authentication.Scope,
+ "max-conns": config.General.MaxConns,
+ "max-uri-length": config.General.MaxURILength,
+ "zip-cache-expiration": config.Zip.ExpirationInterval,
+ "zip-cache-cleanup": config.Zip.CleanupInterval,
+ "zip-cache-refresh": config.Zip.RefreshInterval,
+ "zip-open-timeout": config.Zip.OpenTimeout,
+ "rate-limit-source-ip": config.RateLimit.SourceIPLimitPerSecond,
+ "rate-limit-source-ip-burst": config.RateLimit.SourceIPBurst,
+ "rate-limit-domain": config.RateLimit.DomainLimitPerSecond,
+ "rate-limit-domain-burst": config.RateLimit.DomainBurst,
+ "rate-limit-tls-source-ip": config.RateLimit.TLSSourceIPLimitPerSecond,
+ "rate-limit-tls-source-ip-burst": config.RateLimit.TLSSourceIPBurst,
+ "rate-limit-tls-domain": config.RateLimit.TLSDomainLimitPerSecond,
+ "rate-limit-tls-domain-burst": config.RateLimit.TLSDomainBurst,
}).Debug("Start Pages with configuration")
}
diff --git a/internal/config/flags.go b/internal/config/flags.go
index 93228827..409ecdc7 100644
--- a/internal/config/flags.go
+++ b/internal/config/flags.go
@@ -1,24 +1,41 @@
package config
import (
+ "crypto/tls"
+ "fmt"
+ "sort"
+ "strings"
"time"
"github.com/namsral/flag"
-
- "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls"
)
var (
- pagesRootCert = flag.String("root-cert", "", "The default path to file certificate to serve static pages")
- pagesRootKey = flag.String("root-key", "", "The default path to file certificate to serve static pages")
- redirectHTTP = flag.Bool("redirect-http", false, "Redirect pages from HTTP to HTTPS")
- _ = flag.Bool("use-http2", true, "DEPRECATED: HTTP2 is always enabled for pages")
- pagesRoot = flag.String("pages-root", "shared/pages", "The directory where pages are stored")
- pagesDomain = flag.String("pages-domain", "gitlab-example.com", "The domain to serve static pages")
- rateLimitSourceIP = flag.Float64("rate-limit-source-ip", 0.0, "Rate limit per source IP in number of requests per second, 0 means is disabled")
- rateLimitSourceIPBurst = flag.Int("rate-limit-source-ip-burst", 100, "Rate limit per source IP maximum burst allowed per second")
- rateLimitDomain = flag.Float64("rate-limit-domain", 0.0, "Rate limit per domain in number of requests per second, 0 means is disabled")
- rateLimitDomainBurst = flag.Int("rate-limit-domain-burst", 100, "Rate limit per domain maximum burst allowed per second")
+ // allTLSVersions has all supported flag values
+ allTLSVersions = map[string]uint16{
+ "": 0, // Default value in tls.Config
+ "tls1.2": tls.VersionTLS12,
+ "tls1.3": tls.VersionTLS13,
+ }
+
+ pagesRootCert = flag.String("root-cert", "", "The default path to file certificate to serve static pages")
+ pagesRootKey = flag.String("root-key", "", "The default path to file certificate to serve static pages")
+ redirectHTTP = flag.Bool("redirect-http", false, "Redirect pages from HTTP to HTTPS")
+ _ = flag.Bool("use-http2", true, "DEPRECATED: HTTP2 is always enabled for pages")
+ pagesRoot = flag.String("pages-root", "shared/pages", "The directory where pages are stored")
+ pagesDomain = flag.String("pages-domain", "gitlab-example.com", "The domain to serve static pages")
+
+ // HTTP rate limits
+ rateLimitSourceIP = flag.Float64("rate-limit-source-ip", 0.0, "Rate limit HTTP requests per second from a single IP, 0 means is disabled")
+ rateLimitSourceIPBurst = flag.Int("rate-limit-source-ip-burst", 100, "Rate limit HTTP requests from a single IP, maximum burst allowed per second")
+ rateLimitDomain = flag.Float64("rate-limit-domain", 0.0, "Rate limit HTTP requests per second to a single domain, 0 means is disabled")
+ rateLimitDomainBurst = flag.Int("rate-limit-domain-burst", 100, "Rate limit HTTP requests to a single domain, maximum burst allowed per second")
+ // TLS connections rate limits
+ rateLimitTLSSourceIP = flag.Float64("rate-limit-tls-source-ip", 0.0, "Rate limit new TLS connections per second from a single IP, 0 means is disabled")
+ rateLimitTLSSourceIPBurst = flag.Int("rate-limit-tls-source-ip-burst", 100, "Rate limit new TLS connections from a single IP, maximum burst allowed per second")
+ rateLimitTLSDomain = flag.Float64("rate-limit-tls-domain", 0.0, "Rate limit new TLS connections per second from to a single domain, 0 means is disabled")
+ rateLimitTLSDomainBurst = flag.Int("rate-limit-tls-domain-burst", 100, "Rate limit new TLS connections from a single domain, maximum burst allowed per second")
+
artifactsServer = flag.String("artifacts-server", "", "API URL to proxy artifact requests to, e.g.: 'https://gitlab.com/api/v4'")
artifactsServerTimeout = flag.Int("artifacts-server-timeout", 10, "Timeout (in seconds) for a proxied request to the artifacts server")
pagesStatus = flag.String("pages-status", "", "The url path for a status page, e.g., /@status")
@@ -55,8 +72,8 @@ var (
maxConns = flag.Int("max-conns", 0, "Limit on the number of concurrent connections to the HTTP, HTTPS or proxy listeners, 0 for no limit")
maxURILength = flag.Int("max-uri-length", 1024, "Limit the length of URI, 0 for unlimited.")
insecureCiphers = flag.Bool("insecure-ciphers", false, "Use default list of cipher suites, may contain insecure ones like 3DES and RC4")
- tlsMinVersion = flag.String("tls-min-version", "tls1.2", tls.FlagUsage("min"))
- tlsMaxVersion = flag.String("tls-max-version", "", tls.FlagUsage("max"))
+ tlsMinVersion = flag.String("tls-min-version", "tls1.2", tlsVersionFlagUsage("min"))
+ tlsMaxVersion = flag.String("tls-max-version", "", tlsVersionFlagUsage("max"))
zipCacheExpiration = flag.Duration("zip-cache-expiration", 60*time.Second, "Zip serving archive cache expiration interval")
zipCacheCleanup = flag.Duration("zip-cache-cleanup", 30*time.Second, "Zip serving archive cache cleanup interval")
zipCacheRefresh = flag.Duration("zip-cache-refresh", 30*time.Second, "Zip serving archive cache refresh interval")
@@ -88,3 +105,17 @@ func initFlags() {
flag.Parse()
}
+
+// tlsVersionFlagUsage returns string with explanation how to use the tls version CLI flag
+func tlsVersionFlagUsage(minOrMax string) string {
+ versions := []string{}
+
+ for version := range allTLSVersions {
+ if version != "" {
+ versions = append(versions, fmt.Sprintf("%q", version))
+ }
+ }
+ sort.Strings(versions)
+
+ return fmt.Sprintf("Specifies the "+minOrMax+"imum SSL/TLS version, supported values are %s", strings.Join(versions, ", "))
+}
diff --git a/internal/config/tls/tls.go b/internal/config/tls/tls.go
deleted file mode 100644
index c76bbff7..00000000
--- a/internal/config/tls/tls.go
+++ /dev/null
@@ -1,100 +0,0 @@
-package tls
-
-import (
- "crypto/tls"
- "fmt"
- "sort"
- "strings"
-)
-
-// GetCertificateFunc returns the certificate to be used for given domain
-type GetCertificateFunc func(*tls.ClientHelloInfo) (*tls.Certificate, error)
-
-var (
- preferredCipherSuites = []uint16{
- tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
- tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
- tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
- tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
- tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
- tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
- tls.TLS_AES_128_GCM_SHA256,
- tls.TLS_AES_256_GCM_SHA384,
- tls.TLS_CHACHA20_POLY1305_SHA256,
- }
-
- // AllTLSVersions has all supported flag values
- AllTLSVersions = map[string]uint16{
- "": 0, // Default value in tls.Config
- "tls1.2": tls.VersionTLS12,
- "tls1.3": tls.VersionTLS13,
- }
-)
-
-// FlagUsage returns string with explanation how to use the CLI flag
-func FlagUsage(minOrMax string) string {
- versions := []string{}
-
- for version := range AllTLSVersions {
- if version != "" {
- versions = append(versions, fmt.Sprintf("%q", version))
- }
- }
- sort.Strings(versions)
-
- return fmt.Sprintf("Specifies the "+minOrMax+"imum SSL/TLS version, supported values are %s", strings.Join(versions, ", "))
-}
-
-// Create returns tls.Config for given app configuration
-func Create(cert, key []byte, getCertificate GetCertificateFunc, insecureCiphers bool, tlsMinVersion uint16, tlsMaxVersion uint16) (*tls.Config, error) {
- // set MinVersion to fix gosec: G402
- tlsConfig := &tls.Config{GetCertificate: getCertificate, MinVersion: tls.VersionTLS12}
-
- err := configureCertificate(tlsConfig, cert, key)
- if err != nil {
- return nil, err
- }
-
- if !insecureCiphers {
- configureTLSCiphers(tlsConfig)
- }
-
- tlsConfig.MinVersion = tlsMinVersion
- tlsConfig.MaxVersion = tlsMaxVersion
-
- return tlsConfig, nil
-}
-
-// ValidateTLSVersions returns error if the provided TLS versions config values are not valid
-func ValidateTLSVersions(min, max string) error {
- tlsMin, tlsMinOk := AllTLSVersions[min]
- tlsMax, tlsMaxOk := AllTLSVersions[max]
-
- if !tlsMinOk {
- return fmt.Errorf("invalid minimum TLS version: %s", min)
- }
- if !tlsMaxOk {
- return fmt.Errorf("invalid maximum TLS version: %s", max)
- }
- if tlsMin > tlsMax && tlsMax > 0 {
- return fmt.Errorf("invalid maximum TLS version: %s; should be at least %s", max, min)
- }
-
- return nil
-}
-
-func configureCertificate(tlsConfig *tls.Config, cert, key []byte) error {
- certificate, err := tls.X509KeyPair(cert, key)
- if err != nil {
- return err
- }
-
- tlsConfig.Certificates = []tls.Certificate{certificate}
-
- return nil
-}
-
-func configureTLSCiphers(tlsConfig *tls.Config) {
- tlsConfig.PreferServerCipherSuites = true
- tlsConfig.CipherSuites = preferredCipherSuites
-}
diff --git a/internal/config/validate.go b/internal/config/validate.go
index a3dbcc3b..bb247287 100644
--- a/internal/config/validate.go
+++ b/internal/config/validate.go
@@ -6,8 +6,6 @@ import (
"net/url"
"github.com/hashicorp/go-multierror"
-
- "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls"
)
var (
@@ -30,7 +28,7 @@ func Validate(config *Config) error {
validateListeners(config),
validateAuthConfig(config),
validateArtifactsServerConfig(config),
- tls.ValidateTLSVersions(*tlsMinVersion, *tlsMaxVersion),
+ validateTLSVersions(*tlsMinVersion, *tlsMaxVersion),
)
return result.ErrorOrNil()
@@ -115,3 +113,21 @@ func validateArtifactsServerConfig(config *Config) error {
return result.ErrorOrNil()
}
+
+// validateTLSVersions returns error if the provided TLS versions config values are not valid
+func validateTLSVersions(min, max string) error {
+ tlsMin, tlsMinOk := allTLSVersions[min]
+ tlsMax, tlsMaxOk := allTLSVersions[max]
+
+ if !tlsMinOk {
+ return fmt.Errorf("invalid minimum TLS version: %s", min)
+ }
+ if !tlsMaxOk {
+ return fmt.Errorf("invalid maximum TLS version: %s", max)
+ }
+ if tlsMin > tlsMax && tlsMax > 0 {
+ return fmt.Errorf("invalid maximum TLS version: %s; should be at least %s", max, min)
+ }
+
+ return nil
+}
diff --git a/internal/config/validate_test.go b/internal/config/validate_test.go
index 60e37732..80e4ded3 100644
--- a/internal/config/validate_test.go
+++ b/internal/config/validate_test.go
@@ -159,3 +159,44 @@ func validConfig() Config {
return cfg
}
+
+func TestValidTLSVersions(t *testing.T) {
+ tests := map[string]struct {
+ tlsMin string
+ tlsMax string
+ }{
+ "tls 1.3 only": {tlsMin: "tls1.3", tlsMax: "tls1.3"},
+ "tls 1.2 only": {tlsMin: "tls1.2", tlsMax: "tls1.2"},
+ "tls 1.3 max": {tlsMax: "tls1.3"},
+ "tls 1.2 max": {tlsMax: "tls1.2"},
+ "tls 1.3+": {tlsMin: "tls1.3"},
+ "tls 1.2+": {tlsMin: "tls1.2"},
+ "default": {},
+ }
+
+ for name, tc := range tests {
+ t.Run(name, func(t *testing.T) {
+ err := validateTLSVersions(tc.tlsMin, tc.tlsMax)
+ require.NoError(t, err)
+ })
+ }
+}
+
+func TestInvalidTLSVersions(t *testing.T) {
+ tests := map[string]struct {
+ tlsMin string
+ tlsMax string
+ err string
+ }{
+ "invalid minimum TLS version": {tlsMin: "tls123", tlsMax: "", err: "invalid minimum TLS version: tls123"},
+ "invalid maximum TLS version": {tlsMin: "", tlsMax: "tls123", err: "invalid maximum TLS version: tls123"},
+ "TLS versions conflict": {tlsMin: "tls1.3", tlsMax: "tls1.2", err: "invalid maximum TLS version: tls1.2; should be at least tls1.3"},
+ }
+
+ for name, tc := range tests {
+ t.Run(name, func(t *testing.T) {
+ err := validateTLSVersions(tc.tlsMin, tc.tlsMax)
+ require.EqualError(t, err, tc.err)
+ })
+ }
+}
diff --git a/internal/feature/feature.go b/internal/feature/feature.go
index 81eef9a0..c98fe85d 100644
--- a/internal/feature/feature.go
+++ b/internal/feature/feature.go
@@ -8,17 +8,29 @@ type Feature struct {
}
// EnforceIPRateLimits enforces IP rate limiter to drop requests
-// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629
+// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/706
var EnforceIPRateLimits = Feature{
EnvVariable: "FF_ENFORCE_IP_RATE_LIMITS",
}
// EnforceDomainRateLimits enforces domain rate limiter to drop requests
-// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/655
+// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/706
var EnforceDomainRateLimits = Feature{
EnvVariable: "FF_ENFORCE_DOMAIN_RATE_LIMITS",
}
+// EnforceDomainTLSRateLimits enforces domain rate limits on establishing new TLS connections
+// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/706
+var EnforceDomainTLSRateLimits = Feature{
+ EnvVariable: "FF_ENFORCE_DOMAIN_TLS_RATE_LIMITS",
+}
+
+// EnforceIPTLSRateLimits enforces domain rate limits on establishing new TLS connections
+// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/706
+var EnforceIPTLSRateLimits = Feature{
+ EnvVariable: "FF_ENFORCE_IP_TLS_RATE_LIMITS",
+}
+
// RedirectsPlaceholders enables support for placeholders in redirects file
// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/620
var RedirectsPlaceholders = Feature{
diff --git a/internal/handlers/ratelimiter.go b/internal/handlers/ratelimiter.go
index 52281f6e..a8eee005 100644
--- a/internal/handlers/ratelimiter.go
+++ b/internal/handlers/ratelimiter.go
@@ -14,11 +14,11 @@ import (
// TODO: make this unexported once https://gitlab.com/gitlab-org/gitlab-pages/-/issues/670 is done
func Ratelimiter(handler http.Handler, config *config.RateLimit) http.Handler {
sourceIPLimiter := ratelimiter.New(
- "source_ip",
+ "http_requests_by_source_ip",
ratelimiter.WithCacheMaxSize(ratelimiter.DefaultSourceIPCacheSize),
- ratelimiter.WithCachedEntriesMetric(metrics.RateLimitSourceIPCachedEntries),
- ratelimiter.WithCachedRequestsMetric(metrics.RateLimitSourceIPCacheRequests),
- ratelimiter.WithBlockedCountMetric(metrics.RateLimitSourceIPBlockedCount),
+ ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries),
+ ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests),
+ ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount),
ratelimiter.WithLimitPerSecond(config.SourceIPLimitPerSecond),
ratelimiter.WithBurstSize(config.SourceIPBurst),
ratelimiter.WithEnforce(feature.EnforceIPRateLimits.Enabled()),
@@ -27,12 +27,12 @@ func Ratelimiter(handler http.Handler, config *config.RateLimit) http.Handler {
handler = sourceIPLimiter.Middleware(handler)
domainLimiter := ratelimiter.New(
- "domain",
+ "http_requests_by_domain",
ratelimiter.WithCacheMaxSize(ratelimiter.DefaultDomainCacheSize),
ratelimiter.WithKeyFunc(request.GetHostWithoutPort),
- ratelimiter.WithCachedEntriesMetric(metrics.RateLimitDomainCachedEntries),
- ratelimiter.WithCachedRequestsMetric(metrics.RateLimitDomainCacheRequests),
- ratelimiter.WithBlockedCountMetric(metrics.RateLimitDomainBlockedCount),
+ ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries),
+ ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests),
+ ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount),
ratelimiter.WithLimitPerSecond(config.DomainLimitPerSecond),
ratelimiter.WithBurstSize(config.DomainBurst),
ratelimiter.WithEnforce(feature.EnforceDomainRateLimits.Enabled()),
diff --git a/internal/ratelimiter/middleware.go b/internal/ratelimiter/middleware.go
index af7b0881..1be6f642 100644
--- a/internal/ratelimiter/middleware.go
+++ b/internal/ratelimiter/middleware.go
@@ -8,7 +8,6 @@ import (
"gitlab.com/gitlab-org/labkit/correlation"
"gitlab.com/gitlab-org/labkit/log"
- "gitlab.com/gitlab-org/gitlab-pages/internal/feature"
"gitlab.com/gitlab-org/gitlab-pages/internal/httperrors"
"gitlab.com/gitlab-org/gitlab-pages/internal/request"
)
@@ -34,7 +33,7 @@ func (rl *RateLimiter) Middleware(handler http.Handler) http.Handler {
rl.logRateLimitedRequest(r)
if rl.blockedCount != nil {
- rl.blockedCount.WithLabelValues(strconv.FormatBool(feature.EnforceIPRateLimits.Enabled())).Inc()
+ rl.blockedCount.WithLabelValues(rl.name, strconv.FormatBool(rl.enforce)).Inc()
}
if rl.enforce {
@@ -59,7 +58,7 @@ func (rl *RateLimiter) logRateLimitedRequest(r *http.Request) {
"x_forwarded_proto": r.Header.Get(headerXForwardedProto),
"x_forwarded_for": r.Header.Get(headerXForwardedFor),
"gitlab_real_ip": r.Header.Get(headerGitLabRealIP),
- "rate_limiter_enabled": feature.EnforceIPRateLimits.Enabled(),
+ "enforced": rl.enforce,
"rate_limiter_limit_per_second": rl.limitPerSecond,
"rate_limiter_burst_size": rl.burstSize,
}). // TODO: change to Debug with https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629
diff --git a/internal/ratelimiter/middleware_test.go b/internal/ratelimiter/middleware_test.go
index 1f753fc4..25ac08b5 100644
--- a/internal/ratelimiter/middleware_test.go
+++ b/internal/ratelimiter/middleware_test.go
@@ -107,7 +107,7 @@ func TestMiddlewareDenyRequestsAfterBurst(t *testing.T) {
assertSourceIPLog(t, remoteAddr, hook)
}
- blockedCount := testutil.ToFloat64(blocked.WithLabelValues(strconv.FormatBool(tc.enforce)))
+ blockedCount := testutil.ToFloat64(blocked.WithLabelValues("rate_limiter", strconv.FormatBool(tc.enforce)))
require.Equal(t, float64(4), blockedCount, "blocked count")
blocked.Reset()
@@ -226,7 +226,7 @@ func newTestMetrics(t *testing.T) (*prometheus.GaugeVec, *prometheus.GaugeVec, *
prometheus.GaugeOpts{
Name: t.Name(),
},
- []string{"enforced"},
+ []string{"limit_name", "enforced"},
)
cachedEntries := prometheus.NewGaugeVec(prometheus.GaugeOpts{
diff --git a/internal/ratelimiter/ratelimiter.go b/internal/ratelimiter/ratelimiter.go
index feeb8cb4..b64cb4ce 100644
--- a/internal/ratelimiter/ratelimiter.go
+++ b/internal/ratelimiter/ratelimiter.go
@@ -1,6 +1,8 @@
package ratelimiter
import (
+ "crypto/tls"
+ "net"
"net/http"
"time"
@@ -27,6 +29,9 @@ type Option func(*RateLimiter)
// KeyFunc returns unique identifier for the subject of rate limit(e.g. client IP or domain)
type KeyFunc func(*http.Request) string
+// TLSKeyFunc is used by GetCertificateMiddleware to identify the subject of rate limit (client IP or SNI servername)
+type TLSKeyFunc func(*tls.ClientHelloInfo) string
+
// RateLimiter holds an LRU cache of elements to be rate limited.
// It uses "golang.org/x/time/rate" as its Token Bucket rate limiter per source IP entry.
// See example https://www.fatalerrors.org/a/design-and-implementation-of-time-rate-limiter-for-golang-standard-library.html
@@ -35,6 +40,7 @@ type RateLimiter struct {
name string
now func() time.Time
keyFunc KeyFunc
+ tlsKeyFunc TLSKeyFunc
limitPerSecond float64
burstSize int
blockedCount *prometheus.GaugeVec
@@ -120,6 +126,26 @@ func WithKeyFunc(f KeyFunc) Option {
}
}
+func TLSHostnameKey(info *tls.ClientHelloInfo) string {
+ return info.ServerName
+}
+
+func TLSClientIPKey(info *tls.ClientHelloInfo) string {
+ remoteAddr := info.Conn.RemoteAddr().String()
+ remoteAddr, _, err := net.SplitHostPort(remoteAddr)
+ if err != nil {
+ return remoteAddr
+ }
+
+ return remoteAddr
+}
+
+func WithTLSKeyFunc(keyFunc TLSKeyFunc) Option {
+ return func(rl *RateLimiter) {
+ rl.tlsKeyFunc = keyFunc
+ }
+}
+
// WithEnforce configures if requests are actually rejected, or we just report them as rejected in metrics
func WithEnforce(enforce bool) Option {
return func(rl *RateLimiter) {
@@ -138,6 +164,11 @@ func (rl *RateLimiter) limiter(key string) *rate.Limiter {
// requestAllowed checks if request is within the rate-limit
func (rl *RateLimiter) requestAllowed(r *http.Request) bool {
rateLimitedKey := rl.keyFunc(r)
+
+ return rl.allowed(rateLimitedKey)
+}
+
+func (rl *RateLimiter) allowed(rateLimitedKey string) bool {
limiter := rl.limiter(rateLimitedKey)
// AllowN allows us to use the rl.now function, so we can test this more easily.
diff --git a/internal/ratelimiter/stub_conn_test.go b/internal/ratelimiter/stub_conn_test.go
new file mode 100644
index 00000000..9d7a4a9a
--- /dev/null
+++ b/internal/ratelimiter/stub_conn_test.go
@@ -0,0 +1,14 @@
+package ratelimiter
+
+import (
+ "net"
+)
+
+type stubConn struct {
+ net.Conn
+ remoteAddr net.Addr
+}
+
+func (s stubConn) RemoteAddr() net.Addr {
+ return s.remoteAddr
+}
diff --git a/internal/ratelimiter/tls.go b/internal/ratelimiter/tls.go
new file mode 100644
index 00000000..3bebbc38
--- /dev/null
+++ b/internal/ratelimiter/tls.go
@@ -0,0 +1,49 @@
+package ratelimiter
+
+import (
+ "crypto/tls"
+ "errors"
+ "strconv"
+
+ "github.com/sirupsen/logrus"
+ "gitlab.com/gitlab-org/labkit/log"
+)
+
+var ErrTLSRateLimited = errors.New("too many connections, please retry later")
+
+type GetCertificateFunc func(*tls.ClientHelloInfo) (*tls.Certificate, error)
+
+func (rl *RateLimiter) GetCertificateMiddleware(getCertificate GetCertificateFunc) GetCertificateFunc {
+ if rl.limitPerSecond <= 0.0 {
+ return getCertificate
+ }
+
+ return func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
+ if rl.allowed(rl.tlsKeyFunc(info)) {
+ return getCertificate(info)
+ }
+
+ rl.logRateLimitedTLS(info)
+
+ if rl.blockedCount != nil {
+ rl.blockedCount.WithLabelValues(rl.name, strconv.FormatBool(rl.enforce)).Inc()
+ }
+
+ if !rl.enforce {
+ return getCertificate(info)
+ }
+
+ return nil, ErrTLSRateLimited
+ }
+}
+
+func (rl *RateLimiter) logRateLimitedTLS(info *tls.ClientHelloInfo) {
+ log.WithFields(logrus.Fields{
+ "rate_limiter_name": rl.name,
+ "source_ip": TLSClientIPKey(info),
+ "req_host": info.ServerName,
+ "rate_limiter_limit_per_second": rl.limitPerSecond,
+ "rate_limiter_burst_size": rl.burstSize,
+ "enforced": rl.enforce,
+ }).Info("TLS connection rate-limited")
+}
diff --git a/internal/ratelimiter/tls_test.go b/internal/ratelimiter/tls_test.go
new file mode 100644
index 00000000..6763514b
--- /dev/null
+++ b/internal/ratelimiter/tls_test.go
@@ -0,0 +1,194 @@
+package ratelimiter
+
+import (
+ "crypto/tls"
+ "errors"
+ "net"
+ "strconv"
+ "testing"
+ "time"
+
+ "github.com/prometheus/client_golang/prometheus/testutil"
+ "github.com/sirupsen/logrus"
+ testlog "github.com/sirupsen/logrus/hooks/test"
+ "github.com/stretchr/testify/require"
+)
+
+func TestTLSHostnameKey(t *testing.T) {
+ info := &tls.ClientHelloInfo{ServerName: "group.gitlab.io"}
+
+ require.Equal(t, "group.gitlab.io", TLSHostnameKey(info))
+}
+
+func TestTLSClientIPKey(t *testing.T) {
+ tests := []struct {
+ addr string
+ expected string
+ }{
+ {
+ "10.1.2.3:1234",
+ "10.1.2.3",
+ },
+ {
+ "[2001:db8:3333:4444:5555:6666:7777:8888]:1234",
+ "2001:db8:3333:4444:5555:6666:7777:8888",
+ },
+ }
+
+ for _, tt := range tests {
+ addr, err := net.ResolveTCPAddr("tcp", tt.addr)
+ require.NoError(t, err)
+ conn := stubConn{remoteAddr: addr}
+ info := &tls.ClientHelloInfo{Conn: conn}
+
+ require.Equal(t, tt.expected, TLSClientIPKey(info))
+ }
+}
+
+func TestGetCertificateMiddleware(t *testing.T) {
+ tests := map[string]struct {
+ useHostnameAsKey bool
+ enforced bool
+ limitPerSecond float64
+ burst int
+ successfulReqCnt int
+ }{
+ "ip_limiter": {
+ useHostnameAsKey: false,
+ enforced: true,
+ limitPerSecond: 0.1,
+ burst: 5,
+ successfulReqCnt: 5,
+ },
+ "hostname_limiter": {
+ useHostnameAsKey: true,
+ enforced: true,
+ limitPerSecond: 0.1,
+ burst: 5,
+ successfulReqCnt: 5,
+ },
+ "not_enforced": {
+ useHostnameAsKey: false,
+ enforced: false,
+ limitPerSecond: 0.1,
+ burst: 5,
+ successfulReqCnt: 5,
+ },
+ "disabled": {
+ useHostnameAsKey: false,
+ enforced: true,
+ limitPerSecond: 0,
+ burst: 5,
+ successfulReqCnt: 10,
+ },
+ "slowly_approach_limit": {
+ useHostnameAsKey: false,
+ enforced: true,
+ limitPerSecond: 0.2,
+ burst: 5,
+ successfulReqCnt: 6, // 5 * 0.2 gives another 1 request
+ },
+ }
+
+ expectedCert := &tls.Certificate{}
+ expectedErr := errors.New("expected error")
+
+ getCertificate := func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
+ return expectedCert, expectedErr
+ }
+
+ for name, tt := range tests {
+ t.Run(name, func(t *testing.T) {
+ hook := testlog.NewGlobal()
+ blocked, cachedEntries, cacheReqs := newTestMetrics(t)
+
+ keyFunc := TLSClientIPKey
+ if tt.useHostnameAsKey {
+ keyFunc = TLSHostnameKey
+ }
+
+ rl := New("limit_name",
+ WithCachedEntriesMetric(cachedEntries),
+ WithCachedRequestsMetric(cacheReqs),
+ WithBlockedCountMetric(blocked),
+ WithNow(stubNow()),
+ WithLimitPerSecond(tt.limitPerSecond),
+ WithBurstSize(tt.burst),
+ WithEnforce(tt.enforced),
+ WithTLSKeyFunc(keyFunc))
+
+ middlewareGetCert := rl.GetCertificateMiddleware(getCertificate)
+
+ addr, err := net.ResolveTCPAddr("tcp", "10.1.2.3:12345")
+ require.NoError(t, err)
+ conn := stubConn{remoteAddr: addr}
+ info := &tls.ClientHelloInfo{Conn: conn, ServerName: "group.gitlab.io"}
+
+ for i := 0; i < tt.successfulReqCnt; i++ {
+ cert, err := middlewareGetCert(info)
+ require.Equal(t, expectedCert, cert)
+ require.Equal(t, expectedErr, err)
+ }
+
+ // When rate-limiter disabled altogether
+ if tt.limitPerSecond <= 0 {
+ return
+ }
+
+ cert, err := middlewareGetCert(info)
+ if tt.enforced {
+ require.Nil(t, cert)
+ require.Equal(t, err, ErrTLSRateLimited)
+ } else {
+ require.Equal(t, expectedCert, cert)
+ require.Equal(t, expectedErr, err)
+ }
+
+ require.NotNil(t, hook.LastEntry())
+ require.Equal(t, "TLS connection rate-limited", hook.LastEntry().Message)
+ expectedFields := logrus.Fields{
+ "rate_limiter_name": "limit_name",
+ "source_ip": "10.1.2.3",
+ "req_host": "group.gitlab.io",
+ "rate_limiter_limit_per_second": tt.limitPerSecond,
+ "rate_limiter_burst_size": tt.burst,
+ "enforced": tt.enforced,
+ }
+ require.Equal(t, expectedFields, hook.LastEntry().Data)
+
+ // make another request with different key and expect success
+ if tt.useHostnameAsKey {
+ info.ServerName = "another-group.gitlab.io"
+ } else {
+ addr, err := net.ResolveTCPAddr("tcp", "10.10.20.30:12345")
+ require.NoError(t, err)
+ conn = stubConn{remoteAddr: addr}
+ info.Conn = conn
+ }
+
+ cert, err = middlewareGetCert(info)
+ require.Equal(t, expectedCert, cert)
+ require.Equal(t, expectedErr, err)
+
+ blockedCount := testutil.ToFloat64(blocked.WithLabelValues("limit_name", strconv.FormatBool(tt.enforced)))
+ require.Equal(t, float64(1), blockedCount, "blocked count")
+
+ cachedCount := testutil.ToFloat64(cachedEntries.WithLabelValues("limit_name"))
+ require.Equal(t, float64(2), cachedCount, "cached count") // 1 for first key + 1 for different one
+
+ cacheReqMiss := testutil.ToFloat64(cacheReqs.WithLabelValues("limit_name", "miss"))
+ require.Equal(t, float64(2), cacheReqMiss, "miss count") // 1 for first key + 1 for different one
+ cacheReqHit := testutil.ToFloat64(cacheReqs.WithLabelValues("limit_name", "hit"))
+ require.Equal(t, float64(tt.successfulReqCnt), cacheReqHit, "hit count")
+ })
+ }
+}
+
+func stubNow() func() time.Time {
+ now := time.Now()
+ return func() time.Time {
+ now = now.Add(time.Second)
+
+ return now
+ }
+}
diff --git a/internal/tls/tls.go b/internal/tls/tls.go
new file mode 100644
index 00000000..6d7397af
--- /dev/null
+++ b/internal/tls/tls.go
@@ -0,0 +1,96 @@
+package tls
+
+import (
+ "crypto/tls"
+
+ "gitlab.com/gitlab-org/gitlab-pages/internal/config"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/feature"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/ratelimiter"
+ "gitlab.com/gitlab-org/gitlab-pages/metrics"
+)
+
+var preferredCipherSuites = []uint16{
+ tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
+ tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
+ tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
+ tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
+ tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
+ tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
+ tls.TLS_AES_128_GCM_SHA256,
+ tls.TLS_AES_256_GCM_SHA384,
+ tls.TLS_CHACHA20_POLY1305_SHA256,
+}
+
+// GetCertificateFunc returns the certificate to be used for given domain
+type GetCertificateFunc func(*tls.ClientHelloInfo) (*tls.Certificate, error)
+
+// GetTLSConfig initializes tls.Config based on config flags
+// getCertificateByServerName obtains certificate based on domain
+func GetTLSConfig(cfg *config.Config, getCertificateByServerName GetCertificateFunc) (*tls.Config, error) {
+ wildcardCertificate, err := tls.X509KeyPair(cfg.General.RootCertificate, cfg.General.RootKey)
+ if err != nil {
+ return nil, err
+ }
+
+ getCertificate := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
+ // Golang calls tls.Config.GetCertificate only if it's set and
+ // 1. ServerName != ""
+ // 2. Or tls.Config.Certificates is empty array
+ // tls.Config.Certificates contain wildcard certificate
+ // We want to implement rate limits via GetCertificate, so we need to call it every time
+ // So we don't set tls.Config.Certificates, but simulate the behavior of golang:
+ // 1. try to get certificate by name
+ // 2. if we can't, fallback to default(wildcard) certificate
+ customCertificate, err := getCertificateByServerName(info)
+
+ if customCertificate != nil || err != nil {
+ return customCertificate, err
+ }
+
+ return &wildcardCertificate, nil
+ }
+
+ TLSDomainRateLimiter := ratelimiter.New(
+ "tls_connections_by_domain",
+ ratelimiter.WithTLSKeyFunc(ratelimiter.TLSHostnameKey),
+ ratelimiter.WithCacheMaxSize(ratelimiter.DefaultDomainCacheSize),
+ ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries),
+ ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests),
+ ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount),
+ ratelimiter.WithLimitPerSecond(cfg.RateLimit.TLSDomainLimitPerSecond),
+ ratelimiter.WithBurstSize(cfg.RateLimit.TLSDomainBurst),
+ ratelimiter.WithEnforce(feature.EnforceDomainTLSRateLimits.Enabled()),
+ )
+
+ TLSSourceIPRateLimiter := ratelimiter.New(
+ "tls_connections_by_source_ip",
+ ratelimiter.WithTLSKeyFunc(ratelimiter.TLSClientIPKey),
+ ratelimiter.WithCacheMaxSize(ratelimiter.DefaultSourceIPCacheSize),
+ ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries),
+ ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests),
+ ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount),
+ ratelimiter.WithLimitPerSecond(cfg.RateLimit.TLSSourceIPLimitPerSecond),
+ ratelimiter.WithBurstSize(cfg.RateLimit.TLSSourceIPBurst),
+ ratelimiter.WithEnforce(feature.EnforceIPTLSRateLimits.Enabled()),
+ )
+
+ getCertificate = TLSDomainRateLimiter.GetCertificateMiddleware(getCertificate)
+ getCertificate = TLSSourceIPRateLimiter.GetCertificateMiddleware(getCertificate)
+
+ // set MinVersion to fix gosec: G402
+ tlsConfig := &tls.Config{GetCertificate: getCertificate, MinVersion: tls.VersionTLS12}
+
+ if !cfg.General.InsecureCiphers {
+ configureTLSCiphers(tlsConfig)
+ }
+
+ tlsConfig.MinVersion = cfg.TLS.MinVersion
+ tlsConfig.MaxVersion = cfg.TLS.MaxVersion
+
+ return tlsConfig, nil
+}
+
+func configureTLSCiphers(tlsConfig *tls.Config) {
+ tlsConfig.PreferServerCipherSuites = true
+ tlsConfig.CipherSuites = preferredCipherSuites
+}
diff --git a/internal/config/tls/tls_test.go b/internal/tls/tls_test.go
index 7146d904..e65b2671 100644
--- a/internal/config/tls/tls_test.go
+++ b/internal/tls/tls_test.go
@@ -5,6 +5,8 @@ import (
"testing"
"github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-pages/internal/config"
)
var cert = []byte(`-----BEGIN CERTIFICATE-----
@@ -29,65 +31,46 @@ var getCertificate = func(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {
return nil, nil
}
-func TestInvalidTLSVersions(t *testing.T) {
- tests := map[string]struct {
- tlsMin string
- tlsMax string
- err string
- }{
- "invalid minimum TLS version": {tlsMin: "tls123", tlsMax: "", err: "invalid minimum TLS version: tls123"},
- "invalid maximum TLS version": {tlsMin: "", tlsMax: "tls123", err: "invalid maximum TLS version: tls123"},
- "TLS versions conflict": {tlsMin: "tls1.3", tlsMax: "tls1.2", err: "invalid maximum TLS version: tls1.2; should be at least tls1.3"},
- }
-
- for name, tc := range tests {
- t.Run(name, func(t *testing.T) {
- err := ValidateTLSVersions(tc.tlsMin, tc.tlsMax)
- require.EqualError(t, err, tc.err)
- })
- }
-}
-
-func TestValidTLSVersions(t *testing.T) {
- tests := map[string]struct {
- tlsMin string
- tlsMax string
- }{
- "tls 1.3 only": {tlsMin: "tls1.3", tlsMax: "tls1.3"},
- "tls 1.2 only": {tlsMin: "tls1.2", tlsMax: "tls1.2"},
- "tls 1.3 max": {tlsMax: "tls1.3"},
- "tls 1.2 max": {tlsMax: "tls1.2"},
- "tls 1.3+": {tlsMin: "tls1.3"},
- "tls 1.2+": {tlsMin: "tls1.2"},
- "default": {},
- }
-
- for name, tc := range tests {
- t.Run(name, func(t *testing.T) {
- err := ValidateTLSVersions(tc.tlsMin, tc.tlsMax)
- require.NoError(t, err)
- })
- }
-}
-
func TestInvalidKeyPair(t *testing.T) {
- _, err := Create([]byte(``), []byte(``), getCertificate, false, tls.VersionTLS11, tls.VersionTLS12)
+ cfg := &config.Config{}
+ _, err := GetTLSConfig(cfg, getCertificate)
require.EqualError(t, err, "tls: failed to find any PEM data in certificate input")
}
func TestInsecureCiphers(t *testing.T) {
- tlsConfig, err := Create(cert, key, getCertificate, true, tls.VersionTLS11, tls.VersionTLS12)
+ cfg := &config.Config{
+ General: config.General{
+ RootCertificate: cert,
+ RootKey: key,
+ InsecureCiphers: true,
+ },
+ }
+ tlsConfig, err := GetTLSConfig(cfg, getCertificate)
require.NoError(t, err)
require.False(t, tlsConfig.PreferServerCipherSuites)
require.Empty(t, tlsConfig.CipherSuites)
}
-func TestCreate(t *testing.T) {
- tlsConfig, err := Create(cert, key, getCertificate, false, tls.VersionTLS11, tls.VersionTLS12)
+func TestGetTLSConfig(t *testing.T) {
+ cfg := &config.Config{
+ General: config.General{
+ RootCertificate: cert,
+ RootKey: key,
+ },
+ TLS: config.TLS{
+ MinVersion: tls.VersionTLS11,
+ MaxVersion: tls.VersionTLS12,
+ },
+ }
+ tlsConfig, err := GetTLSConfig(cfg, getCertificate)
require.NoError(t, err)
require.IsType(t, getCertificate, tlsConfig.GetCertificate)
require.True(t, tlsConfig.PreferServerCipherSuites)
require.Equal(t, preferredCipherSuites, tlsConfig.CipherSuites)
require.Equal(t, uint16(tls.VersionTLS11), tlsConfig.MinVersion)
require.Equal(t, uint16(tls.VersionTLS12), tlsConfig.MaxVersion)
+
+ cert, err := tlsConfig.GetCertificate(&tls.ClientHelloInfo{})
+ require.NoError(t, err)
+ require.NotNil(t, cert)
}
diff --git a/metrics/metrics.go b/metrics/metrics.go
index e0e4ab29..26c7413f 100644
--- a/metrics/metrics.go
+++ b/metrics/metrics.go
@@ -193,60 +193,31 @@ var (
},
)
- // RateLimitSourceIPCacheRequests is the number of cache hits/misses
- RateLimitSourceIPCacheRequests = prometheus.NewCounterVec(
+ // RateLimitCacheRequests is the number of cache hits/misses
+ RateLimitCacheRequests = prometheus.NewCounterVec(
prometheus.CounterOpts{
- Name: "gitlab_pages_rate_limit_source_ip_cache_requests",
+ Name: "gitlab_pages_rate_limit_cache_requests",
Help: "The number of source_ip cache hits/misses in the rate limiter",
},
[]string{"op", "cache"},
)
- // RateLimitSourceIPCachedEntries is the number of entries in the cache
- RateLimitSourceIPCachedEntries = prometheus.NewGaugeVec(
+ // RateLimitCachedEntries is the number of entries in the cache
+ RateLimitCachedEntries = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
- Name: "gitlab_pages_rate_limit_source_ip_cached_entries",
+ Name: "gitlab_pages_rate_limit_cached_entries",
Help: "The number of entries in the cache",
},
[]string{"op"},
)
- // RateLimitSourceIPBlockedCount is the number of requests that have been blocked by the
- // source IP rate limiter
- RateLimitSourceIPBlockedCount = prometheus.NewGaugeVec(
+ // RateLimitBlockedCount is the number of requests that have been blocked
+ RateLimitBlockedCount = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
- Name: "gitlab_pages_rate_limit_source_ip_blocked_count",
- Help: "The number of requests that have been blocked by the IP rate limiter",
+ Name: "gitlab_pages_rate_limit_blocked_count",
+ Help: "The number of requests/connections that have been blocked by rate limiter",
},
- []string{"enforced"},
- )
-
- // RateLimitDomainCacheRequests is the number of cache hits/misses
- RateLimitDomainCacheRequests = prometheus.NewCounterVec(
- prometheus.CounterOpts{
- Name: "gitlab_pages_rate_limit_domain_cache_requests",
- Help: "The number of source_ip cache hits/misses in the rate limiter",
- },
- []string{"op", "cache"},
- )
-
- // RateLimitDomainCachedEntries is the number of entries in the cache
- RateLimitDomainCachedEntries = prometheus.NewGaugeVec(
- prometheus.GaugeOpts{
- Name: "gitlab_pages_rate_limit_domain_cached_entries",
- Help: "The number of entries in the cache",
- },
- []string{"op"},
- )
-
- // RateLimitDomainBlockedCount is the number of requests that have been blocked by the
- // domain rate limiter
- RateLimitDomainBlockedCount = prometheus.NewGaugeVec(
- prometheus.GaugeOpts{
- Name: "gitlab_pages_rate_limit_domain_blocked_count",
- Help: "The number of requests addresses that have been blocked by the domain rate limiter",
- },
- []string{"enforced"},
+ []string{"limit_name", "enforced"},
)
)
@@ -276,8 +247,8 @@ func MustRegister() {
LimitListenerConcurrentConns,
LimitListenerWaitingConns,
PanicRecoveredCount,
- RateLimitSourceIPCacheRequests,
- RateLimitSourceIPCachedEntries,
- RateLimitSourceIPBlockedCount,
+ RateLimitCacheRequests,
+ RateLimitCachedEntries,
+ RateLimitBlockedCount,
)
}
diff --git a/server.go b/server.go
index 0af582ff..b5aecc37 100644
--- a/server.go
+++ b/server.go
@@ -4,11 +4,13 @@ import (
"context"
"crypto/tls"
"fmt"
+ stdlog "log"
"net"
"net/http"
"time"
- proxyproto "github.com/pires/go-proxyproto"
+ "github.com/pires/go-proxyproto"
+ "github.com/sirupsen/logrus"
"gitlab.com/gitlab-org/labkit/log"
"gitlab.com/gitlab-org/gitlab-pages/internal/netutil"
@@ -31,6 +33,7 @@ func (a *theApp) listenAndServe(server *http.Server, addr string, h http.Handler
server.Handler = h
server.TLSConfig = config.tlsConfig
+ server.ErrorLog = stdlog.New(logrus.StandardLogger().Writer(), "", 0)
// ensure http2 is enabled even if TLSConfig is not null
// See https://github.com/golang/go/blob/97cee43c93cfccded197cd281f0a5885cdb605b4/src/net/http/server.go#L2947-L2954
if server.TLSConfig != nil {
diff --git a/test/acceptance/helpers_test.go b/test/acceptance/helpers_test.go
index c44058ba..1b514a85 100644
--- a/test/acceptance/helpers_test.go
+++ b/test/acceptance/helpers_test.go
@@ -602,3 +602,15 @@ func copyFile(dest, src string) error {
_, err = io.Copy(destFile, srcFile)
return err
}
+
+// RequireMetricEqual requests prometheus metrics and makes sure metric is there
+func RequireMetricEqual(t *testing.T, metricsAddress, metricWithValue string) {
+ resp, err := http.Get(fmt.Sprintf("http://%s/metrics", metricsAddress))
+ require.NoError(t, err)
+
+ defer resp.Body.Close()
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ require.Contains(t, string(body), metricWithValue)
+}
diff --git a/test/acceptance/ratelimiter_test.go b/test/acceptance/ratelimiter_test.go
index 02ba54f5..a97fdfb1 100644
--- a/test/acceptance/ratelimiter_test.go
+++ b/test/acceptance/ratelimiter_test.go
@@ -3,6 +3,7 @@ package acceptance_test
import (
"fmt"
"net/http"
+ "strconv"
"testing"
"time"
@@ -77,7 +78,7 @@ func TestIPRateLimits(t *testing.T) {
}
}
-func TestDomainateLimits(t *testing.T) {
+func TestDomainRateLimits(t *testing.T) {
testhelpers.StubFeatureFlagValue(t, feature.EnforceDomainRateLimits.EnvVariable, true)
for name, tc := range ratelimitedListeners {
@@ -112,6 +113,131 @@ func TestDomainateLimits(t *testing.T) {
}
}
+func TestTLSRateLimits(t *testing.T) {
+ tests := map[string]struct {
+ spec ListenSpec
+ domainLimit bool
+ sourceIP string
+ enforceEnabled bool
+ }{
+ "https_with_domain_limit": {
+ spec: httpsListener,
+ domainLimit: true,
+ sourceIP: "127.0.0.1",
+ enforceEnabled: true,
+ },
+ "https_with_domain_limit_not_enforced": {
+ spec: httpsListener,
+ domainLimit: true,
+ sourceIP: "127.0.0.1",
+ enforceEnabled: false,
+ },
+ "https_with_ip_limit": {
+ spec: httpsListener,
+ sourceIP: "127.0.0.1",
+ enforceEnabled: true,
+ },
+ "https_with_ip_limit_not_enforced": {
+ spec: httpsListener,
+ sourceIP: "127.0.0.1",
+ enforceEnabled: false,
+ },
+ "proxyv2_with_domain_limit": {
+ spec: httpsProxyv2Listener,
+ domainLimit: true,
+ sourceIP: "10.1.1.1",
+ enforceEnabled: true,
+ },
+ "proxyv2_with_domain_limit_not_enforced": {
+ spec: httpsProxyv2Listener,
+ domainLimit: true,
+ sourceIP: "10.1.1.1",
+ enforceEnabled: false,
+ },
+ "proxyv2_with_ip_limit": {
+ spec: httpsProxyv2Listener,
+ sourceIP: "10.1.1.1",
+ enforceEnabled: true,
+ },
+ "proxyv2_with_ip_limit_not_enforced": {
+ spec: httpsProxyv2Listener,
+ sourceIP: "10.1.1.1",
+ enforceEnabled: false,
+ },
+ }
+
+ for name, tt := range tests {
+ t.Run(name, func(t *testing.T) {
+ rateLimit := 5
+
+ options := []processOption{
+ withListeners([]ListenSpec{tt.spec}),
+ withExtraArgument("metrics-address", ":42345"),
+ }
+
+ featureName := feature.EnforceIPTLSRateLimits.EnvVariable
+ limitName := "tls_connections_by_source_ip"
+
+ if tt.domainLimit {
+ options = append(options,
+ withExtraArgument("rate-limit-tls-domain", fmt.Sprint(rateLimit)),
+ withExtraArgument("rate-limit-tls-domain-burst", fmt.Sprint(rateLimit)))
+
+ featureName = feature.EnforceDomainTLSRateLimits.EnvVariable
+ limitName = "tls_connections_by_domain"
+ } else {
+ options = append(options,
+ withExtraArgument("rate-limit-tls-source-ip", fmt.Sprint(rateLimit)),
+ withExtraArgument("rate-limit-tls-source-ip-burst", fmt.Sprint(rateLimit)))
+ }
+
+ testhelpers.StubFeatureFlagValue(t, featureName, tt.enforceEnabled)
+ logBuf := RunPagesProcess(t, options...)
+
+ // when we start the process we make 1 requests to verify that process is up
+ // it gets counted in the rate limit for IP, but host is different
+ if !tt.domainLimit {
+ rateLimit--
+ }
+
+ for i := 0; i < 10; i++ {
+ rsp, err := makeTLSRequest(t, tt.spec)
+
+ if i >= rateLimit {
+ assertLogFound(t, logBuf, []string{
+ "TLS connection rate-limited",
+ "\"req_host\":\"group.gitlab-example.com\"",
+ fmt.Sprintf("\"source_ip\":\"%s\"", tt.sourceIP),
+ "\"enforced\":" + strconv.FormatBool(tt.enforceEnabled)})
+
+ if tt.enforceEnabled {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "remote error: tls: internal error")
+ }
+
+ continue
+ }
+
+ require.NoError(t, err, "request: %d failed", i)
+ require.NoError(t, rsp.Body.Close())
+ require.Equal(t, http.StatusOK, rsp.StatusCode, "request: %d failed", i)
+ }
+ expectedMetric := fmt.Sprintf(
+ "gitlab_pages_rate_limit_blocked_count{enforced=\"%t\",limit_name=\"%s\"} %v",
+ tt.enforceEnabled, limitName, 10-rateLimit)
+
+ RequireMetricEqual(t, "127.0.0.1:42345", expectedMetric)
+ })
+ }
+}
+
+func makeTLSRequest(t *testing.T, spec ListenSpec) (*http.Response, error) {
+ req, err := http.NewRequest("GET", "https://group.gitlab-example.com/project", nil)
+ require.NoError(t, err)
+
+ return spec.Client().Do(req)
+}
+
func assertLogFound(t *testing.T, logBuf *LogCaptureBuffer, expectedLogs []string) {
t.Helper()