diff options
author | Vladimir Shushlin <v.shushlin@gmail.com> | 2022-02-21 20:00:46 +0300 |
---|---|---|
committer | Vladimir Shushlin <v.shushlin@gmail.com> | 2022-02-22 12:38:49 +0300 |
commit | 58581c5a2ff3e95e1dc3acc69913412477a37557 (patch) | |
tree | f6f89f7ebe87075601c7cb94c608701b37d3d40a /internal | |
parent | 62a6491652aa6975d9ecf3b9e258766c886d49d4 (diff) |
feat: Always apply TLS limits even without ServerName
Diffstat (limited to 'internal')
-rw-r--r-- | internal/config/config.go | 6 | ||||
-rw-r--r-- | internal/config/flags.go | 31 | ||||
-rw-r--r-- | internal/config/tls/tls.go | 100 | ||||
-rw-r--r-- | internal/config/validate.go | 22 | ||||
-rw-r--r-- | internal/config/validate_test.go | 41 | ||||
-rw-r--r-- | internal/ratelimiter/tls.go | 6 | ||||
-rw-r--r-- | internal/tls/tls.go | 96 | ||||
-rw-r--r-- | internal/tls/tls_test.go (renamed from internal/config/tls/tls_test.go) | 73 |
8 files changed, 216 insertions, 159 deletions
diff --git a/internal/config/config.go b/internal/config/config.go index 3dd4ecb3..2e31612a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -9,8 +9,6 @@ import ( "github.com/namsral/flag" "gitlab.com/gitlab-org/labkit/log" - - "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls" ) // Config stores all the config options relevant to GitLab Pages. @@ -229,8 +227,8 @@ func loadConfig() (*Config, error) { Environment: *sentryEnvironment, }, TLS: TLS{ - MinVersion: tls.AllTLSVersions[*tlsMinVersion], - MaxVersion: tls.AllTLSVersions[*tlsMaxVersion], + MinVersion: allTLSVersions[*tlsMinVersion], + MaxVersion: allTLSVersions[*tlsMaxVersion], }, Zip: ZipServing{ ExpirationInterval: *zipCacheExpiration, diff --git a/internal/config/flags.go b/internal/config/flags.go index 3778a677..409ecdc7 100644 --- a/internal/config/flags.go +++ b/internal/config/flags.go @@ -1,14 +1,23 @@ package config import ( + "crypto/tls" + "fmt" + "sort" + "strings" "time" "github.com/namsral/flag" - - "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls" ) var ( + // allTLSVersions has all supported flag values + allTLSVersions = map[string]uint16{ + "": 0, // Default value in tls.Config + "tls1.2": tls.VersionTLS12, + "tls1.3": tls.VersionTLS13, + } + pagesRootCert = flag.String("root-cert", "", "The default path to file certificate to serve static pages") pagesRootKey = flag.String("root-key", "", "The default path to file certificate to serve static pages") redirectHTTP = flag.Bool("redirect-http", false, "Redirect pages from HTTP to HTTPS") @@ -63,8 +72,8 @@ var ( maxConns = flag.Int("max-conns", 0, "Limit on the number of concurrent connections to the HTTP, HTTPS or proxy listeners, 0 for no limit") maxURILength = flag.Int("max-uri-length", 1024, "Limit the length of URI, 0 for unlimited.") insecureCiphers = flag.Bool("insecure-ciphers", false, "Use default list of cipher suites, may contain insecure ones like 3DES and RC4") - tlsMinVersion = flag.String("tls-min-version", "tls1.2", tls.FlagUsage("min")) - tlsMaxVersion = flag.String("tls-max-version", "", tls.FlagUsage("max")) + tlsMinVersion = flag.String("tls-min-version", "tls1.2", tlsVersionFlagUsage("min")) + tlsMaxVersion = flag.String("tls-max-version", "", tlsVersionFlagUsage("max")) zipCacheExpiration = flag.Duration("zip-cache-expiration", 60*time.Second, "Zip serving archive cache expiration interval") zipCacheCleanup = flag.Duration("zip-cache-cleanup", 30*time.Second, "Zip serving archive cache cleanup interval") zipCacheRefresh = flag.Duration("zip-cache-refresh", 30*time.Second, "Zip serving archive cache refresh interval") @@ -96,3 +105,17 @@ func initFlags() { flag.Parse() } + +// tlsVersionFlagUsage returns string with explanation how to use the tls version CLI flag +func tlsVersionFlagUsage(minOrMax string) string { + versions := []string{} + + for version := range allTLSVersions { + if version != "" { + versions = append(versions, fmt.Sprintf("%q", version)) + } + } + sort.Strings(versions) + + return fmt.Sprintf("Specifies the "+minOrMax+"imum SSL/TLS version, supported values are %s", strings.Join(versions, ", ")) +} diff --git a/internal/config/tls/tls.go b/internal/config/tls/tls.go deleted file mode 100644 index c76bbff7..00000000 --- a/internal/config/tls/tls.go +++ /dev/null @@ -1,100 +0,0 @@ -package tls - -import ( - "crypto/tls" - "fmt" - "sort" - "strings" -) - -// GetCertificateFunc returns the certificate to be used for given domain -type GetCertificateFunc func(*tls.ClientHelloInfo) (*tls.Certificate, error) - -var ( - preferredCipherSuites = []uint16{ - tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, - tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - tls.TLS_AES_128_GCM_SHA256, - tls.TLS_AES_256_GCM_SHA384, - tls.TLS_CHACHA20_POLY1305_SHA256, - } - - // AllTLSVersions has all supported flag values - AllTLSVersions = map[string]uint16{ - "": 0, // Default value in tls.Config - "tls1.2": tls.VersionTLS12, - "tls1.3": tls.VersionTLS13, - } -) - -// FlagUsage returns string with explanation how to use the CLI flag -func FlagUsage(minOrMax string) string { - versions := []string{} - - for version := range AllTLSVersions { - if version != "" { - versions = append(versions, fmt.Sprintf("%q", version)) - } - } - sort.Strings(versions) - - return fmt.Sprintf("Specifies the "+minOrMax+"imum SSL/TLS version, supported values are %s", strings.Join(versions, ", ")) -} - -// Create returns tls.Config for given app configuration -func Create(cert, key []byte, getCertificate GetCertificateFunc, insecureCiphers bool, tlsMinVersion uint16, tlsMaxVersion uint16) (*tls.Config, error) { - // set MinVersion to fix gosec: G402 - tlsConfig := &tls.Config{GetCertificate: getCertificate, MinVersion: tls.VersionTLS12} - - err := configureCertificate(tlsConfig, cert, key) - if err != nil { - return nil, err - } - - if !insecureCiphers { - configureTLSCiphers(tlsConfig) - } - - tlsConfig.MinVersion = tlsMinVersion - tlsConfig.MaxVersion = tlsMaxVersion - - return tlsConfig, nil -} - -// ValidateTLSVersions returns error if the provided TLS versions config values are not valid -func ValidateTLSVersions(min, max string) error { - tlsMin, tlsMinOk := AllTLSVersions[min] - tlsMax, tlsMaxOk := AllTLSVersions[max] - - if !tlsMinOk { - return fmt.Errorf("invalid minimum TLS version: %s", min) - } - if !tlsMaxOk { - return fmt.Errorf("invalid maximum TLS version: %s", max) - } - if tlsMin > tlsMax && tlsMax > 0 { - return fmt.Errorf("invalid maximum TLS version: %s; should be at least %s", max, min) - } - - return nil -} - -func configureCertificate(tlsConfig *tls.Config, cert, key []byte) error { - certificate, err := tls.X509KeyPair(cert, key) - if err != nil { - return err - } - - tlsConfig.Certificates = []tls.Certificate{certificate} - - return nil -} - -func configureTLSCiphers(tlsConfig *tls.Config) { - tlsConfig.PreferServerCipherSuites = true - tlsConfig.CipherSuites = preferredCipherSuites -} diff --git a/internal/config/validate.go b/internal/config/validate.go index a3dbcc3b..bb247287 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -6,8 +6,6 @@ import ( "net/url" "github.com/hashicorp/go-multierror" - - "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls" ) var ( @@ -30,7 +28,7 @@ func Validate(config *Config) error { validateListeners(config), validateAuthConfig(config), validateArtifactsServerConfig(config), - tls.ValidateTLSVersions(*tlsMinVersion, *tlsMaxVersion), + validateTLSVersions(*tlsMinVersion, *tlsMaxVersion), ) return result.ErrorOrNil() @@ -115,3 +113,21 @@ func validateArtifactsServerConfig(config *Config) error { return result.ErrorOrNil() } + +// validateTLSVersions returns error if the provided TLS versions config values are not valid +func validateTLSVersions(min, max string) error { + tlsMin, tlsMinOk := allTLSVersions[min] + tlsMax, tlsMaxOk := allTLSVersions[max] + + if !tlsMinOk { + return fmt.Errorf("invalid minimum TLS version: %s", min) + } + if !tlsMaxOk { + return fmt.Errorf("invalid maximum TLS version: %s", max) + } + if tlsMin > tlsMax && tlsMax > 0 { + return fmt.Errorf("invalid maximum TLS version: %s; should be at least %s", max, min) + } + + return nil +} diff --git a/internal/config/validate_test.go b/internal/config/validate_test.go index 60e37732..80e4ded3 100644 --- a/internal/config/validate_test.go +++ b/internal/config/validate_test.go @@ -159,3 +159,44 @@ func validConfig() Config { return cfg } + +func TestValidTLSVersions(t *testing.T) { + tests := map[string]struct { + tlsMin string + tlsMax string + }{ + "tls 1.3 only": {tlsMin: "tls1.3", tlsMax: "tls1.3"}, + "tls 1.2 only": {tlsMin: "tls1.2", tlsMax: "tls1.2"}, + "tls 1.3 max": {tlsMax: "tls1.3"}, + "tls 1.2 max": {tlsMax: "tls1.2"}, + "tls 1.3+": {tlsMin: "tls1.3"}, + "tls 1.2+": {tlsMin: "tls1.2"}, + "default": {}, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + err := validateTLSVersions(tc.tlsMin, tc.tlsMax) + require.NoError(t, err) + }) + } +} + +func TestInvalidTLSVersions(t *testing.T) { + tests := map[string]struct { + tlsMin string + tlsMax string + err string + }{ + "invalid minimum TLS version": {tlsMin: "tls123", tlsMax: "", err: "invalid minimum TLS version: tls123"}, + "invalid maximum TLS version": {tlsMin: "", tlsMax: "tls123", err: "invalid maximum TLS version: tls123"}, + "TLS versions conflict": {tlsMin: "tls1.3", tlsMax: "tls1.2", err: "invalid maximum TLS version: tls1.2; should be at least tls1.3"}, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + err := validateTLSVersions(tc.tlsMin, tc.tlsMax) + require.EqualError(t, err, tc.err) + }) + } +} diff --git a/internal/ratelimiter/tls.go b/internal/ratelimiter/tls.go index 15b10cb4..3bebbc38 100644 --- a/internal/ratelimiter/tls.go +++ b/internal/ratelimiter/tls.go @@ -7,13 +7,13 @@ import ( "github.com/sirupsen/logrus" "gitlab.com/gitlab-org/labkit/log" - - tlsconfig "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls" ) var ErrTLSRateLimited = errors.New("too many connections, please retry later") -func (rl *RateLimiter) GetCertificateMiddleware(getCertificate tlsconfig.GetCertificateFunc) tlsconfig.GetCertificateFunc { +type GetCertificateFunc func(*tls.ClientHelloInfo) (*tls.Certificate, error) + +func (rl *RateLimiter) GetCertificateMiddleware(getCertificate GetCertificateFunc) GetCertificateFunc { if rl.limitPerSecond <= 0.0 { return getCertificate } diff --git a/internal/tls/tls.go b/internal/tls/tls.go new file mode 100644 index 00000000..c222fcce --- /dev/null +++ b/internal/tls/tls.go @@ -0,0 +1,96 @@ +package tls + +import ( + "crypto/tls" + + "gitlab.com/gitlab-org/gitlab-pages/internal/config" + "gitlab.com/gitlab-org/gitlab-pages/internal/feature" + "gitlab.com/gitlab-org/gitlab-pages/internal/ratelimiter" + "gitlab.com/gitlab-org/gitlab-pages/metrics" +) + +var preferredCipherSuites = []uint16{ + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_AES_128_GCM_SHA256, + tls.TLS_AES_256_GCM_SHA384, + tls.TLS_CHACHA20_POLY1305_SHA256, +} + +// GetCertificateFunc returns the certificate to be used for given domain +type GetCertificateFunc func(*tls.ClientHelloInfo) (*tls.Certificate, error) + +// GetTLSConfig initializes tls.Config based on config flags +// getCertificateByServerName obtains certificate based on domain +func GetTLSConfig(cfg *config.Config, getCertificateByServerName GetCertificateFunc) (*tls.Config, error) { + certificate, err := tls.X509KeyPair(cfg.General.RootCertificate, cfg.General.RootKey) + if err != nil { + return nil, err + } + + getCertificate := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + // Golang calls tls.Config.GetCertificate only if it's set and + // 1. ServerName != "" + // 2. Or tls.Config.Certificates is empty array + // tls.Config.Certificates contain wildcard certificate + // We want to implement rate limits via GetCertificate, so we need to call it every time + // So we don't set tls.Config.Certificates, but simulate the behavior of golang: + // 1. try to get certificate by name + // 2. if we can't, fallback to default(wildcard) certificate + cert, err := getCertificateByServerName(info) + + if cert != nil || err != nil { + return cert, err + } + + return &certificate, nil + } + + TLSDomainRateLimiter := ratelimiter.New( + "tls_connections_by_domain", + ratelimiter.WithTLSKeyFunc(ratelimiter.TLSHostnameKey), + ratelimiter.WithCacheMaxSize(ratelimiter.DefaultDomainCacheSize), + ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries), + ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests), + ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount), + ratelimiter.WithLimitPerSecond(cfg.RateLimit.TLSDomainLimitPerSecond), + ratelimiter.WithBurstSize(cfg.RateLimit.TLSDomainBurst), + ratelimiter.WithEnforce(feature.EnforceDomainTLSRateLimits.Enabled()), + ) + + TLSSourceIPRateLimiter := ratelimiter.New( + "tls_connections_by_source_ip", + ratelimiter.WithTLSKeyFunc(ratelimiter.TLSClientIPKey), + ratelimiter.WithCacheMaxSize(ratelimiter.DefaultSourceIPCacheSize), + ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries), + ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests), + ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount), + ratelimiter.WithLimitPerSecond(cfg.RateLimit.TLSSourceIPLimitPerSecond), + ratelimiter.WithBurstSize(cfg.RateLimit.TLSSourceIPBurst), + ratelimiter.WithEnforce(feature.EnforceIPTLSRateLimits.Enabled()), + ) + + getCertificate = TLSDomainRateLimiter.GetCertificateMiddleware(getCertificate) + getCertificate = TLSSourceIPRateLimiter.GetCertificateMiddleware(getCertificate) + + // set MinVersion to fix gosec: G402 + tlsConfig := &tls.Config{GetCertificate: getCertificate, MinVersion: tls.VersionTLS12} + + if !cfg.General.InsecureCiphers { + configureTLSCiphers(tlsConfig) + } + + tlsConfig.MinVersion = cfg.TLS.MinVersion + tlsConfig.MaxVersion = cfg.TLS.MaxVersion + + return tlsConfig, nil +} + +func configureTLSCiphers(tlsConfig *tls.Config) { + tlsConfig.PreferServerCipherSuites = true + tlsConfig.CipherSuites = preferredCipherSuites +} diff --git a/internal/config/tls/tls_test.go b/internal/tls/tls_test.go index 7146d904..e65b2671 100644 --- a/internal/config/tls/tls_test.go +++ b/internal/tls/tls_test.go @@ -5,6 +5,8 @@ import ( "testing" "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-pages/internal/config" ) var cert = []byte(`-----BEGIN CERTIFICATE----- @@ -29,65 +31,46 @@ var getCertificate = func(ch *tls.ClientHelloInfo) (*tls.Certificate, error) { return nil, nil } -func TestInvalidTLSVersions(t *testing.T) { - tests := map[string]struct { - tlsMin string - tlsMax string - err string - }{ - "invalid minimum TLS version": {tlsMin: "tls123", tlsMax: "", err: "invalid minimum TLS version: tls123"}, - "invalid maximum TLS version": {tlsMin: "", tlsMax: "tls123", err: "invalid maximum TLS version: tls123"}, - "TLS versions conflict": {tlsMin: "tls1.3", tlsMax: "tls1.2", err: "invalid maximum TLS version: tls1.2; should be at least tls1.3"}, - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - err := ValidateTLSVersions(tc.tlsMin, tc.tlsMax) - require.EqualError(t, err, tc.err) - }) - } -} - -func TestValidTLSVersions(t *testing.T) { - tests := map[string]struct { - tlsMin string - tlsMax string - }{ - "tls 1.3 only": {tlsMin: "tls1.3", tlsMax: "tls1.3"}, - "tls 1.2 only": {tlsMin: "tls1.2", tlsMax: "tls1.2"}, - "tls 1.3 max": {tlsMax: "tls1.3"}, - "tls 1.2 max": {tlsMax: "tls1.2"}, - "tls 1.3+": {tlsMin: "tls1.3"}, - "tls 1.2+": {tlsMin: "tls1.2"}, - "default": {}, - } - - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - err := ValidateTLSVersions(tc.tlsMin, tc.tlsMax) - require.NoError(t, err) - }) - } -} - func TestInvalidKeyPair(t *testing.T) { - _, err := Create([]byte(``), []byte(``), getCertificate, false, tls.VersionTLS11, tls.VersionTLS12) + cfg := &config.Config{} + _, err := GetTLSConfig(cfg, getCertificate) require.EqualError(t, err, "tls: failed to find any PEM data in certificate input") } func TestInsecureCiphers(t *testing.T) { - tlsConfig, err := Create(cert, key, getCertificate, true, tls.VersionTLS11, tls.VersionTLS12) + cfg := &config.Config{ + General: config.General{ + RootCertificate: cert, + RootKey: key, + InsecureCiphers: true, + }, + } + tlsConfig, err := GetTLSConfig(cfg, getCertificate) require.NoError(t, err) require.False(t, tlsConfig.PreferServerCipherSuites) require.Empty(t, tlsConfig.CipherSuites) } -func TestCreate(t *testing.T) { - tlsConfig, err := Create(cert, key, getCertificate, false, tls.VersionTLS11, tls.VersionTLS12) +func TestGetTLSConfig(t *testing.T) { + cfg := &config.Config{ + General: config.General{ + RootCertificate: cert, + RootKey: key, + }, + TLS: config.TLS{ + MinVersion: tls.VersionTLS11, + MaxVersion: tls.VersionTLS12, + }, + } + tlsConfig, err := GetTLSConfig(cfg, getCertificate) require.NoError(t, err) require.IsType(t, getCertificate, tlsConfig.GetCertificate) require.True(t, tlsConfig.PreferServerCipherSuites) require.Equal(t, preferredCipherSuites, tlsConfig.CipherSuites) require.Equal(t, uint16(tls.VersionTLS11), tlsConfig.MinVersion) require.Equal(t, uint16(tls.VersionTLS12), tlsConfig.MaxVersion) + + cert, err := tlsConfig.GetCertificate(&tls.ClientHelloInfo{}) + require.NoError(t, err) + require.NotNil(t, cert) } |