diff options
author | Vladimir Shushlin <v.shushlin@gmail.com> | 2022-02-21 20:00:46 +0300 |
---|---|---|
committer | Vladimir Shushlin <v.shushlin@gmail.com> | 2022-02-22 12:38:49 +0300 |
commit | 58581c5a2ff3e95e1dc3acc69913412477a37557 (patch) | |
tree | f6f89f7ebe87075601c7cb94c608701b37d3d40a | |
parent | 62a6491652aa6975d9ecf3b9e258766c886d49d4 (diff) |
feat: Always apply TLS limits even without ServerName
-rw-r--r-- | app.go | 42 | ||||
-rw-r--r-- | internal/config/config.go | 6 | ||||
-rw-r--r-- | internal/config/flags.go | 31 | ||||
-rw-r--r-- | internal/config/tls/tls.go | 100 | ||||
-rw-r--r-- | internal/config/validate.go | 22 | ||||
-rw-r--r-- | internal/config/validate_test.go | 41 | ||||
-rw-r--r-- | internal/ratelimiter/tls.go | 6 | ||||
-rw-r--r-- | internal/tls/tls.go | 96 | ||||
-rw-r--r-- | internal/tls/tls_test.go (renamed from internal/config/tls/tls_test.go) | 73 | ||||
-rw-r--r-- | test/acceptance/ratelimiter_test.go | 119 |
10 files changed, 267 insertions, 269 deletions
@@ -27,21 +27,19 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/artifact" "gitlab.com/gitlab-org/gitlab-pages/internal/auth" cfg "gitlab.com/gitlab-org/gitlab-pages/internal/config" - "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls" "gitlab.com/gitlab-org/gitlab-pages/internal/customheaders" "gitlab.com/gitlab-org/gitlab-pages/internal/domain" - "gitlab.com/gitlab-org/gitlab-pages/internal/feature" "gitlab.com/gitlab-org/gitlab-pages/internal/handlers" "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/logging" "gitlab.com/gitlab-org/gitlab-pages/internal/netutil" - "gitlab.com/gitlab-org/gitlab-pages/internal/ratelimiter" "gitlab.com/gitlab-org/gitlab-pages/internal/rejectmethods" "gitlab.com/gitlab-org/gitlab-pages/internal/request" "gitlab.com/gitlab-org/gitlab-pages/internal/routing" "gitlab.com/gitlab-org/gitlab-pages/internal/serving/disk/zip" "gitlab.com/gitlab-org/gitlab-pages/internal/source" "gitlab.com/gitlab-org/gitlab-pages/internal/source/gitlab" + "gitlab.com/gitlab-org/gitlab-pages/internal/tls" "gitlab.com/gitlab-org/gitlab-pages/internal/urilimiter" "gitlab.com/gitlab-org/gitlab-pages/metrics" ) @@ -66,7 +64,6 @@ func (a *theApp) isReady() bool { } func (a *theApp) GetCertificate(ch *cryptotls.ClientHelloInfo) (*cryptotls.Certificate, error) { - log.Info("GetCertificate called") if ch.ServerName == "" { return nil, nil } @@ -79,42 +76,15 @@ func (a *theApp) GetCertificate(ch *cryptotls.ClientHelloInfo) (*cryptotls.Certi return nil, nil } -// TODO: find a better place than app.go for all the TLS logic https://gitlab.com/gitlab-org/gitlab-pages/-/issues/707 -// right now we have config/tls, but I think does more than config -// related logic should func (a *theApp) getTLSConfig() (*cryptotls.Config, error) { + // we call this function only when tls config is needed, and we ignore TLS related flags otherwise + // in theory you can configure both listen-https and listen-proxyv2, + // so this return is here to have a single TLS config if a.tlsConfig != nil { return a.tlsConfig, nil } - TLSDomainRateLimiter := ratelimiter.New( - "tls_connections_by_domain", - ratelimiter.WithTLSKeyFunc(ratelimiter.TLSHostnameKey), - ratelimiter.WithCacheMaxSize(ratelimiter.DefaultDomainCacheSize), - ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries), - ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests), - ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount), - ratelimiter.WithLimitPerSecond(a.config.RateLimit.TLSDomainLimitPerSecond), - ratelimiter.WithBurstSize(a.config.RateLimit.TLSDomainBurst), - ratelimiter.WithEnforce(feature.EnforceDomainTLSRateLimits.Enabled()), - ) - - TLSSourceIPRateLimiter := ratelimiter.New( - "tls_connections_by_source_ip", - ratelimiter.WithTLSKeyFunc(ratelimiter.TLSClientIPKey), - ratelimiter.WithCacheMaxSize(ratelimiter.DefaultSourceIPCacheSize), - ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries), - ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests), - ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount), - ratelimiter.WithLimitPerSecond(a.config.RateLimit.TLSSourceIPLimitPerSecond), - ratelimiter.WithBurstSize(a.config.RateLimit.TLSSourceIPBurst), - ratelimiter.WithEnforce(feature.EnforceIPTLSRateLimits.Enabled()), - ) - - getCertificate := TLSDomainRateLimiter.GetCertificateMiddleware(a.GetCertificate) - getCertificate = TLSSourceIPRateLimiter.GetCertificateMiddleware(getCertificate) - - tlsConfig, err := tls.Create(a.config.General.RootCertificate, a.config.General.RootKey, getCertificate, - a.config.General.InsecureCiphers, a.config.TLS.MinVersion, a.config.TLS.MaxVersion) + + tlsConfig, err := tls.GetTLSConfig(a.config, a.GetCertificate) a.tlsConfig = tlsConfig diff --git a/internal/config/config.go b/internal/config/config.go index 3dd4ecb3..2e31612a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -9,8 +9,6 @@ import ( "github.com/namsral/flag" "gitlab.com/gitlab-org/labkit/log" - - "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls" ) // Config stores all the config options relevant to GitLab Pages. @@ -229,8 +227,8 @@ func loadConfig() (*Config, error) { Environment: *sentryEnvironment, }, TLS: TLS{ - MinVersion: tls.AllTLSVersions[*tlsMinVersion], - MaxVersion: tls.AllTLSVersions[*tlsMaxVersion], + MinVersion: allTLSVersions[*tlsMinVersion], + MaxVersion: allTLSVersions[*tlsMaxVersion], }, Zip: ZipServing{ ExpirationInterval: *zipCacheExpiration, diff --git a/internal/config/flags.go b/internal/config/flags.go index 3778a677..409ecdc7 100644 --- a/internal/config/flags.go +++ b/internal/config/flags.go @@ -1,14 +1,23 @@ package config import ( + "crypto/tls" + "fmt" + "sort" + "strings" "time" "github.com/namsral/flag" - - "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls" ) var ( + // allTLSVersions has all supported flag values + allTLSVersions = map[string]uint16{ + "": 0, // Default value in tls.Config + "tls1.2": tls.VersionTLS12, + "tls1.3": tls.VersionTLS13, + } + pagesRootCert = flag.String("root-cert", "", "The default path to file certificate to serve static pages") pagesRootKey = flag.String("root-key", "", "The default path to file certificate to serve static pages") redirectHTTP = flag.Bool("redirect-http", false, "Redirect pages from HTTP to HTTPS") @@ -63,8 +72,8 @@ var ( maxConns = flag.Int("max-conns", 0, "Limit on the number of concurrent connections to the HTTP, HTTPS or proxy listeners, 0 for no limit") maxURILength = flag.Int("max-uri-length", 1024, "Limit the length of URI, 0 for unlimited.") insecureCiphers = flag.Bool("insecure-ciphers", false, "Use default list of cipher suites, may contain insecure ones like 3DES and RC4") - tlsMinVersion = flag.String("tls-min-version", "tls1.2", tls.FlagUsage("min")) - tlsMaxVersion = flag.String("tls-max-version", "", tls.FlagUsage("max")) + tlsMinVersion = flag.String("tls-min-version", "tls1.2", tlsVersionFlagUsage("min")) + tlsMaxVersion = flag.String("tls-max-version", "", tlsVersionFlagUsage("max")) zipCacheExpiration = flag.Duration("zip-cache-expiration", 60*time.Second, "Zip serving archive cache expiration interval") zipCacheCleanup = flag.Duration("zip-cache-cleanup", 30*time.Second, "Zip serving archive cache cleanup interval") zipCacheRefresh = flag.Duration("zip-cache-refresh", 30*time.Second, "Zip serving archive cache refresh interval") @@ -96,3 +105,17 @@ func initFlags() { flag.Parse() } + +// tlsVersionFlagUsage returns string with explanation how to use the tls version CLI flag +func tlsVersionFlagUsage(minOrMax string) string { + versions := []string{} + + for version := range allTLSVersions { + if version != "" { + versions = append(versions, fmt.Sprintf("%q", version)) + } + } + sort.Strings(versions) + + return fmt.Sprintf("Specifies the "+minOrMax+"imum SSL/TLS version, supported values are %s", strings.Join(versions, ", ")) +} diff --git a/internal/config/tls/tls.go b/internal/config/tls/tls.go deleted file mode 100644 index c76bbff7..00000000 --- a/internal/config/tls/tls.go +++ /dev/null @@ -1,100 +0,0 @@ -package tls - -import ( - "crypto/tls" - "fmt" - "sort" - "strings" -) - -// GetCertificateFunc returns the certificate to be used for given domain -type GetCertificateFunc func(*tls.ClientHelloInfo) (*tls.Certificate, error) - -var ( - preferredCipherSuites = []uint16{ - tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, - tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - tls.TLS_AES_128_GCM_SHA256, - tls.TLS_AES_256_GCM_SHA384, - tls.TLS_CHACHA20_POLY1305_SHA256, - } - - // AllTLSVersions has all supported flag values - AllTLSVersions = map[string]uint16{ - "": 0, // Default value in tls.Config - "tls1.2": tls.VersionTLS12, - "tls1.3": tls.VersionTLS13, - } -) - -// FlagUsage returns string with explanation how to use the CLI flag -func FlagUsage(minOrMax string) string { - versions := []string{} - - for version := range AllTLSVersions { - if version != "" { - versions = append(versions, fmt.Sprintf("%q", version)) - } - } - sort.Strings(versions) - - return fmt.Sprintf("Specifies the "+minOrMax+"imum SSL/TLS version, supported values are %s", strings.Join(versions, ", ")) -} - -// Create returns tls.Config for given app configuration -func Create(cert, key []byte, getCertificate GetCertificateFunc, insecureCiphers bool, tlsMinVersion uint16, tlsMaxVersion uint16) (*tls.Config, error) { - // set MinVersion to fix gosec: G402 - tlsConfig := &tls.Config{GetCertificate: getCertificate, MinVersion: tls.VersionTLS12} - - err := configureCertificate(tlsConfig, cert, key) - if err != nil { - return nil, err - } - - if !insecureCiphers { - configureTLSCiphers(tlsConfig) - } - - tlsConfig.MinVersion = tlsMinVersion - tlsConfig.MaxVersion = tlsMaxVersion - - return tlsConfig, nil -} - -// ValidateTLSVersions returns error if the provided TLS versions config values are not valid -func ValidateTLSVersions(min, max string) error { - tlsMin, tlsMinOk := AllTLSVersions[min] - tlsMax, tlsMaxOk := AllTLSVersions[max] - - if !tlsMinOk { - return fmt.Errorf("invalid minimum TLS version: %s", min) - } - if !tlsMaxOk { - return fmt.Errorf("invalid maximum TLS version: %s", max) - } - if tlsMin > tlsMax && tlsMax > 0 { - return fmt.Errorf("invalid maximum TLS version: %s; should be at least %s", max, min) - } - - return nil -} - -func configureCertificate(tlsConfig *tls.Config, cert, key []byte) error { - certificate, err := tls.X509KeyPair(cert, key) - if err != nil { - return err - } - - tlsConfig.Certificates = []tls.Certificate{certificate} - - return nil -} - -func configureTLSCiphers(tlsConfig *tls.Config) { - tlsConfig.PreferServerCipherSuites = true - tlsConfig.CipherSuites = preferredCipherSuites -} diff --git a/internal/config/validate.go b/internal/config/validate.go index a3dbcc3b..bb247287 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -6,8 +6,6 @@ import ( "net/url" "github.com/hashicorp/go-multierror" - - "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls" ) var ( @@ -30,7 +28,7 @@ func Validate(config *Config) error { validateListeners(config), validateAuthConfig(config), validateArtifactsServerConfig(config), - tls.ValidateTLSVersions(*tlsMinVersion, *tlsMaxVersion), + validateTLSVersions(*tlsMinVersion, *tlsMaxVersion), ) return result.ErrorOrNil() @@ -115,3 +113,21 @@ func validateArtifactsServerConfig(config *Config) error { return result.ErrorOrNil() } + +// validateTLSVersions returns error if the provided TLS versions config values are not valid +func validateTLSVersions(min, max string) error { + tlsMin, tlsMinOk := allTLSVersions[min] + tlsMax, tlsMaxOk := allTLSVersions[max] + + if !tlsMinOk { + return fmt.Errorf("invalid minimum TLS version: %s", min) + } + if !tlsMaxOk { + return fmt.Errorf("invalid maximum TLS version: %s", max) + } + if tlsMin > tlsMax && tlsMax > 0 { + return fmt.Errorf("invalid maximum TLS version: %s; should be at least %s", max, min) + } + + return nil +} diff --git a/internal/config/validate_test.go b/internal/config/validate_test.go index 60e37732..80e4ded3 100644 --- a/internal/config/validate_test.go +++ b/internal/config/validate_test.go @@ -159,3 +159,44 @@ func validConfig() Config { return cfg } + +func TestValidTLSVersions(t *testing.T) { + tests := map[string]struct { + tlsMin string + tlsMax string + }{ + "tls 1.3 only": {tlsMin: "tls1.3", tlsMax: "tls1.3"}, + "tls 1.2 only": {tlsMin: "tls1.2", tlsMax: "tls1.2"}, + "tls 1.3 max": {tlsMax: "tls1.3"}, + "tls 1.2 max": {tlsMax: "tls1.2"}, + "tls 1.3+": {tlsMin: "tls1.3"}, + "tls 1.2+": {tlsMin: "tls1.2"}, + "default": {}, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + err := validateTLSVersions(tc.tlsMin, tc.tlsMax) + require.NoError(t, err) + }) + } +} + +func TestInvalidTLSVersions(t *testing.T) { + tests := map[string]struct { + tlsMin string + tlsMax string + err string + }{ + "invalid minimum TLS version": {tlsMin: "tls123", tlsMax: "", err: "invalid minimum TLS version: tls123"}, + "invalid maximum TLS version": {tlsMin: "", tlsMax: "tls123", err: "invalid maximum TLS version: tls123"}, + "TLS versions conflict": {tlsMin: "tls1.3", tlsMax: "tls1.2", err: "invalid maximum TLS version: tls1.2; should be at least tls1.3"}, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + err := validateTLSVersions(tc.tlsMin, tc.tlsMax) + require.EqualError(t, err, tc.err) + }) + } +} diff --git a/internal/ratelimiter/tls.go b/internal/ratelimiter/tls.go index 15b10cb4..3bebbc38 100644 --- a/internal/ratelimiter/tls.go +++ b/internal/ratelimiter/tls.go @@ -7,13 +7,13 @@ import ( "github.com/sirupsen/logrus" "gitlab.com/gitlab-org/labkit/log" - - tlsconfig "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls" ) var ErrTLSRateLimited = errors.New("too many connections, please retry later") -func (rl *RateLimiter) GetCertificateMiddleware(getCertificate tlsconfig.GetCertificateFunc) tlsconfig.GetCertificateFunc { +type GetCertificateFunc func(*tls.ClientHelloInfo) (*tls.Certificate, error) + +func (rl *RateLimiter) GetCertificateMiddleware(getCertificate GetCertificateFunc) GetCertificateFunc { if rl.limitPerSecond <= 0.0 { return getCertificate } diff --git a/internal/tls/tls.go b/internal/tls/tls.go new file mode 100644 index 00000000..c222fcce --- /dev/null +++ b/internal/tls/tls.go @@ -0,0 +1,96 @@ +package tls + +import ( + "crypto/tls" + + "gitlab.com/gitlab-org/gitlab-pages/internal/config" + "gitlab.com/gitlab-org/gitlab-pages/internal/feature" + "gitlab.com/gitlab-org/gitlab-pages/internal/ratelimiter" + "gitlab.com/gitlab-org/gitlab-pages/metrics" +) + +var preferredCipherSuites = []uint16{ + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_AES_128_GCM_SHA256, + tls.TLS_AES_256_GCM_SHA384, + tls.TLS_CHACHA20_POLY1305_SHA256, +} + +// GetCertificateFunc returns the certificate to be used for given domain +type GetCertificateFunc func(*tls.ClientHelloInfo) (*tls.Certificate, error) + +// GetTLSConfig initializes tls.Config based on config flags +// getCertificateByServerName obtains certificate based on domain +func GetTLSConfig(cfg *config.Config, getCertificateByServerName GetCertificateFunc) (*tls.Config, error) { + certificate, err := tls.X509KeyPair(cfg.General.RootCertificate, cfg.General.RootKey) + if err != nil { + return nil, err + } + + getCertificate := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + // Golang calls tls.Config.GetCertificate only if it's set and + // 1. ServerName != "" + // 2. Or tls.Config.Certificates is empty array + // tls.Config.Certificates contain wildcard certificate + // We want to implement rate limits via GetCertificate, so we need to call it every time + // So we don't set tls.Config.Certificates, but simulate the behavior of golang: + // 1. try to get certificate by name + // 2. if we can't, fallback to default(wildcard) certificate + cert, err := getCertificateByServerName(info) + + if cert != nil || err != nil { + return cert, err + } + + return &certificate, nil + } + + TLSDomainRateLimiter := ratelimiter.New( + "tls_connections_by_domain", + ratelimiter.WithTLSKeyFunc(ratelimiter.TLSHostnameKey), + ratelimiter.WithCacheMaxSize(ratelimiter.DefaultDomainCacheSize), + ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries), + ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests), + ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount), + ratelimiter.WithLimitPerSecond(cfg.RateLimit.TLSDomainLimitPerSecond), + ratelimiter.WithBurstSize(cfg.RateLimit.TLSDomainBurst), + ratelimiter.WithEnforce(feature.EnforceDomainTLSRateLimits.Enabled()), + ) + + TLSSourceIPRateLimiter := ratelimiter.New( + "tls_connections_by_source_ip", + ratelimiter.WithTLSKeyFunc(ratelimiter.TLSClientIPKey), + ratelimiter.WithCacheMaxSize(ratelimiter.DefaultSourceIPCacheSize), + ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries), + ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests), + ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount), + ratelimiter.WithLimitPerSecond(cfg.RateLimit.TLSSourceIPLimitPerSecond), + ratelimiter.WithBurstSize(cfg.RateLimit.TLSSourceIPBurst), + ratelimiter.WithEnforce(feature.EnforceIPTLSRateLimits.Enabled()), + ) + + getCertificate = TLSDomainRateLimiter.GetCertificateMiddleware(getCertificate) + getCertificate = TLSSourceIPRateLimiter.GetCertificateMiddleware(getCertificate) + + // set MinVersion to fix gosec: G402 + tlsConfig := &tls.Config{GetCertificate: getCertificate, MinVersion: tls.VersionTLS12} + + if !cfg.General.InsecureCiphers { + configureTLSCiphers(tlsConfig) + } + + tlsConfig.MinVersion = cfg.TLS.MinVersion + tlsConfig.MaxVersion = cfg.TLS.MaxVersion + + return tlsConfig, nil +} + +func configureTLSCiphers(tlsConfig *tls.Config) { + tlsConfig.PreferServerCipherSuites = true + tlsConfig.CipherSuites = preferredCipherSuites +} diff --git a/internal/config/tls/tls_test.go b/internal/tls/tls_test.go index 7146d904..e65b2671 100644 --- a/internal/config/tls/tls_test.go +++ b/internal/tls/tls_test.go @@ -5,6 +5,8 @@ import ( "testing" "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-pages/internal/config" ) var cert = []byte(`-----BEGIN CERTIFICATE----- @@ -29,65 +31,46 @@ var getCertificate = func(ch *tls.ClientHelloInfo) (*tls.Certificate, error) { return nil, nil } -func TestInvalidTLSVersions(t *testing.T) { - tests := map[string]struct { - tlsMin string - tlsMax string - err string - }{ - "invalid minimum TLS version": {tlsMin: "tls123", tlsMax: "", err: "invalid minimum TLS version: tls123"}, - "invalid maximum TLS version": {tlsMin: "", tlsMax: "tls123", err: "invalid maximum TLS version: tls123"}, - "TLS versions conflict": {tlsMin: "tls1.3", tlsMax: "tls1.2", err: "invalid maximum TLS version: tls1.2; should be at least tls1.3"}, - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - err := ValidateTLSVersions(tc.tlsMin, tc.tlsMax) - require.EqualError(t, err, tc.err) - }) - } -} - -func TestValidTLSVersions(t *testing.T) { - tests := map[string]struct { - tlsMin string - tlsMax string - }{ - "tls 1.3 only": {tlsMin: "tls1.3", tlsMax: "tls1.3"}, - "tls 1.2 only": {tlsMin: "tls1.2", tlsMax: "tls1.2"}, - "tls 1.3 max": {tlsMax: "tls1.3"}, - "tls 1.2 max": {tlsMax: "tls1.2"}, - "tls 1.3+": {tlsMin: "tls1.3"}, - "tls 1.2+": {tlsMin: "tls1.2"}, - "default": {}, - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - err := ValidateTLSVersions(tc.tlsMin, tc.tlsMax) - require.NoError(t, err) - }) - } -} - func TestInvalidKeyPair(t *testing.T) { - _, err := Create([]byte(``), []byte(``), getCertificate, false, tls.VersionTLS11, tls.VersionTLS12) + cfg := &config.Config{} + _, err := GetTLSConfig(cfg, getCertificate) require.EqualError(t, err, "tls: failed to find any PEM data in certificate input") } func TestInsecureCiphers(t *testing.T) { - tlsConfig, err := Create(cert, key, getCertificate, true, tls.VersionTLS11, tls.VersionTLS12) + cfg := &config.Config{ + General: config.General{ + RootCertificate: cert, + RootKey: key, + InsecureCiphers: true, + }, + } + tlsConfig, err := GetTLSConfig(cfg, getCertificate) require.NoError(t, err) require.False(t, tlsConfig.PreferServerCipherSuites) require.Empty(t, tlsConfig.CipherSuites) } -func TestCreate(t *testing.T) { - tlsConfig, err := Create(cert, key, getCertificate, false, tls.VersionTLS11, tls.VersionTLS12) +func TestGetTLSConfig(t *testing.T) { + cfg := &config.Config{ + General: config.General{ + RootCertificate: cert, + RootKey: key, + }, + TLS: config.TLS{ + MinVersion: tls.VersionTLS11, + MaxVersion: tls.VersionTLS12, + }, + } + tlsConfig, err := GetTLSConfig(cfg, getCertificate) require.NoError(t, err) require.IsType(t, getCertificate, tlsConfig.GetCertificate) require.True(t, tlsConfig.PreferServerCipherSuites) require.Equal(t, preferredCipherSuites, tlsConfig.CipherSuites) require.Equal(t, uint16(tls.VersionTLS11), tlsConfig.MinVersion) require.Equal(t, uint16(tls.VersionTLS12), tlsConfig.MaxVersion) + + cert, err := tlsConfig.GetCertificate(&tls.ClientHelloInfo{}) + require.NoError(t, err) + require.NotNil(t, cert) } diff --git a/test/acceptance/ratelimiter_test.go b/test/acceptance/ratelimiter_test.go index 365ba4cd..a97fdfb1 100644 --- a/test/acceptance/ratelimiter_test.go +++ b/test/acceptance/ratelimiter_test.go @@ -114,121 +114,92 @@ func TestDomainRateLimits(t *testing.T) { } func TestTLSRateLimits(t *testing.T) { - rateLimit := 5 - tests := map[string]struct { spec ListenSpec - options []processOption + domainLimit bool sourceIP string - featureName string enforceEnabled bool - limitName string }{ "https_with_domain_limit": { - spec: httpsListener, - options: []processOption{ - withExtraArgument("metrics-address", ":42345"), - withExtraArgument("rate-limit-tls-domain", fmt.Sprint(rateLimit)), - withExtraArgument("rate-limit-tls-domain-burst", fmt.Sprint(rateLimit)), - }, + spec: httpsListener, + domainLimit: true, sourceIP: "127.0.0.1", - featureName: feature.EnforceDomainTLSRateLimits.EnvVariable, enforceEnabled: true, - limitName: "tls_connections_by_domain", }, "https_with_domain_limit_not_enforced": { - spec: httpsListener, - options: []processOption{ - withExtraArgument("metrics-address", ":42345"), - withExtraArgument("rate-limit-tls-domain", fmt.Sprint(rateLimit)), - withExtraArgument("rate-limit-tls-domain-burst", fmt.Sprint(rateLimit)), - }, + spec: httpsListener, + domainLimit: true, sourceIP: "127.0.0.1", - featureName: feature.EnforceDomainTLSRateLimits.EnvVariable, enforceEnabled: false, - limitName: "tls_connections_by_domain", }, "https_with_ip_limit": { - spec: httpsListener, - options: []processOption{ - withExtraArgument("metrics-address", ":42345"), - withExtraArgument("rate-limit-tls-source-ip", fmt.Sprint(rateLimit)), - withExtraArgument("rate-limit-tls-source-ip-burst", fmt.Sprint(rateLimit)), - }, + spec: httpsListener, sourceIP: "127.0.0.1", - featureName: feature.EnforceIPTLSRateLimits.EnvVariable, enforceEnabled: true, - limitName: "tls_connections_by_source_ip", }, "https_with_ip_limit_not_enforced": { - spec: httpsListener, - options: []processOption{ - withExtraArgument("metrics-address", ":42345"), - withExtraArgument("rate-limit-tls-source-ip", fmt.Sprint(rateLimit)), - withExtraArgument("rate-limit-tls-source-ip-burst", fmt.Sprint(rateLimit)), - }, + spec: httpsListener, sourceIP: "127.0.0.1", - featureName: feature.EnforceIPTLSRateLimits.EnvVariable, enforceEnabled: false, - limitName: "tls_connections_by_source_ip", }, "proxyv2_with_domain_limit": { - spec: httpsProxyv2Listener, - options: []processOption{ - withExtraArgument("metrics-address", ":42345"), - withExtraArgument("rate-limit-tls-domain", fmt.Sprint(rateLimit)), - withExtraArgument("rate-limit-tls-domain-burst", fmt.Sprint(rateLimit)), - }, + spec: httpsProxyv2Listener, + domainLimit: true, sourceIP: "10.1.1.1", - featureName: feature.EnforceDomainTLSRateLimits.EnvVariable, enforceEnabled: true, - limitName: "tls_connections_by_domain", }, "proxyv2_with_domain_limit_not_enforced": { - spec: httpsProxyv2Listener, - options: []processOption{ - withExtraArgument("metrics-address", ":42345"), - withExtraArgument("rate-limit-tls-domain", fmt.Sprint(rateLimit)), - withExtraArgument("rate-limit-tls-domain-burst", fmt.Sprint(rateLimit)), - }, + spec: httpsProxyv2Listener, + domainLimit: true, sourceIP: "10.1.1.1", - featureName: feature.EnforceDomainTLSRateLimits.EnvVariable, enforceEnabled: false, - limitName: "tls_connections_by_domain", }, "proxyv2_with_ip_limit": { - spec: httpsProxyv2Listener, - options: []processOption{ - withExtraArgument("metrics-address", ":42345"), - withExtraArgument("rate-limit-tls-source-ip", fmt.Sprint(rateLimit)), - withExtraArgument("rate-limit-tls-source-ip-burst", fmt.Sprint(rateLimit)), - }, + spec: httpsProxyv2Listener, sourceIP: "10.1.1.1", - featureName: feature.EnforceIPTLSRateLimits.EnvVariable, enforceEnabled: true, - limitName: "tls_connections_by_source_ip", }, "proxyv2_with_ip_limit_not_enforced": { - spec: httpsProxyv2Listener, - options: []processOption{ - withExtraArgument("metrics-address", ":42345"), - withExtraArgument("rate-limit-tls-source-ip", fmt.Sprint(rateLimit)), - withExtraArgument("rate-limit-tls-source-ip-burst", fmt.Sprint(rateLimit)), - }, + spec: httpsProxyv2Listener, sourceIP: "10.1.1.1", - featureName: feature.EnforceIPTLSRateLimits.EnvVariable, enforceEnabled: false, - limitName: "tls_connections_by_source_ip", }, } for name, tt := range tests { t.Run(name, func(t *testing.T) { - testhelpers.StubFeatureFlagValue(t, tt.featureName, tt.enforceEnabled) + rateLimit := 5 - options := append(tt.options, withListeners([]ListenSpec{tt.spec})) + options := []processOption{ + withListeners([]ListenSpec{tt.spec}), + withExtraArgument("metrics-address", ":42345"), + } + + featureName := feature.EnforceIPTLSRateLimits.EnvVariable + limitName := "tls_connections_by_source_ip" + + if tt.domainLimit { + options = append(options, + withExtraArgument("rate-limit-tls-domain", fmt.Sprint(rateLimit)), + withExtraArgument("rate-limit-tls-domain-burst", fmt.Sprint(rateLimit))) + + featureName = feature.EnforceDomainTLSRateLimits.EnvVariable + limitName = "tls_connections_by_domain" + } else { + options = append(options, + withExtraArgument("rate-limit-tls-source-ip", fmt.Sprint(rateLimit)), + withExtraArgument("rate-limit-tls-source-ip-burst", fmt.Sprint(rateLimit))) + } + + testhelpers.StubFeatureFlagValue(t, featureName, tt.enforceEnabled) logBuf := RunPagesProcess(t, options...) + // when we start the process we make 1 requests to verify that process is up + // it gets counted in the rate limit for IP, but host is different + if !tt.domainLimit { + rateLimit-- + } + for i := 0; i < 10; i++ { rsp, err := makeTLSRequest(t, tt.spec) @@ -247,13 +218,13 @@ func TestTLSRateLimits(t *testing.T) { continue } - require.NoError(t, err) + require.NoError(t, err, "request: %d failed", i) require.NoError(t, rsp.Body.Close()) require.Equal(t, http.StatusOK, rsp.StatusCode, "request: %d failed", i) } expectedMetric := fmt.Sprintf( - "gitlab_pages_rate_limit_blocked_count{enforced=\"%t\",limit_name=\"%s\"} 5", - tt.enforceEnabled, tt.limitName) + "gitlab_pages_rate_limit_blocked_count{enforced=\"%t\",limit_name=\"%s\"} %v", + tt.enforceEnabled, limitName, 10-rateLimit) RequireMetricEqual(t, "127.0.0.1:42345", expectedMetric) }) |