diff options
author | Av1o <django@dcas.dev> | 2023-10-17 03:40:19 +0300 |
---|---|---|
committer | Jaime Martinez <jmartinez@gitlab.com> | 2023-10-17 03:40:19 +0300 |
commit | 82a9981e93e32a2ae7528a247a7065d98240b924 (patch) | |
tree | 88049b38d8f3f0f0926b106c4e545d470bf62749 /internal | |
parent | 4de55dca4d48748583a0d459e5795c00295aabc3 (diff) |
Support for Mutual TLS
Diffstat (limited to 'internal')
-rw-r--r-- | internal/config/config.go | 47 | ||||
-rw-r--r-- | internal/config/config_test.go | 65 | ||||
-rw-r--r-- | internal/config/flags.go | 4 | ||||
-rw-r--r-- | internal/domain/domain.go | 29 | ||||
-rw-r--r-- | internal/domain/domain_test.go | 8 | ||||
-rw-r--r-- | internal/logging/logging.go | 14 | ||||
-rw-r--r-- | internal/source/gitlab/api/virtual_domain.go | 5 | ||||
-rw-r--r-- | internal/source/gitlab/gitlab.go | 2 | ||||
-rw-r--r-- | internal/tls/testdata/cert.crt | 11 | ||||
-rw-r--r-- | internal/tls/tls.go | 64 | ||||
-rw-r--r-- | internal/tls/tls_test.go | 32 |
11 files changed, 256 insertions, 25 deletions
diff --git a/internal/config/config.go b/internal/config/config.go index afd7982e..d2016550 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -140,8 +140,11 @@ type Sentry struct { // TLS groups settings related to configuring TLS type TLS struct { - MinVersion uint16 - MaxVersion uint16 + MinVersion uint16 + MaxVersion uint16 + ClientAuth tls.ClientAuthType + ClientCert string + ClientAuthDomains []string } // ZipServing groups settings to be used by the zip VFS opening and caching @@ -246,6 +249,31 @@ func loadMetricsConfig() (metrics Metrics, err error) { return metrics, nil } +// parseClientAuthType converts the tls.ClientAuthType enum from names +// to the underlying value. Passing an empty string assumes +// tls.NoClientCert +func parseClientAuthType(clientAuth string) (tls.ClientAuthType, error) { + switch strings.ToLower(clientAuth) { + // if nothing is provided, assume that + // the user does not want to enable any form + // of client authentication + case "": + fallthrough + case "noclientcert": + return tls.NoClientCert, nil + case "requestclientcert": + return tls.RequestClientCert, nil + case "requireanyclientcert": + return tls.RequireAnyClientCert, nil + case "verifyclientcertifgiven": + return tls.VerifyClientCertIfGiven, nil + case "requireandverifyclientcert": + return tls.RequireAndVerifyClientCert, nil + default: + return -1, fmt.Errorf("unknown client auth type %s: supported values can be found at https://pkg.go.dev/crypto/tls#ClientAuthType", clientAuth) + } +} + func parseHeaderString(customHeaders []string) (http.Header, error) { headers := make(http.Header, len(customHeaders)) @@ -341,8 +369,10 @@ func loadConfig() (*Config, error) { Environment: *sentryEnvironment, }, TLS: TLS{ - MinVersion: allTLSVersions[*tlsMinVersion], - MaxVersion: allTLSVersions[*tlsMaxVersion], + MinVersion: allTLSVersions[*tlsMinVersion], + MaxVersion: allTLSVersions[*tlsMaxVersion], + ClientCert: *tlsClientCert, + ClientAuthDomains: tlsClientAuthDomains.value, }, Zip: ZipServing{ ExpirationInterval: *zipCacheExpiration, @@ -394,6 +424,12 @@ func loadConfig() (*Config, error) { return nil, fmt.Errorf("unable to parse header string: %w", err) } + clientAuthType, err := parseClientAuthType(*tlsClientAuth) + if err != nil { + return nil, err + } + config.TLS.ClientAuth = clientAuthType + config.General.CustomHeaders = customHeaders // Populating remaining GitLab settings @@ -437,6 +473,9 @@ func logFields(config *Config) map[string]any { "status_path": config.General.StatusPath, "tls-min-version": *tlsMinVersion, "tls-max-version": *tlsMaxVersion, + "tls-client-auth": *tlsClientAuth, + "tls-client-cert": *tlsClientCert, + "tls-client-auth-domains": tlsClientAuthDomains, "gitlab-server": config.GitLab.PublicServer, "internal-gitlab-server": config.GitLab.InternalServer, "api-secret-key": *gitLabAPISecretKey, diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 9db88acc..516d7472 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1,6 +1,7 @@ package config import ( + "crypto/tls" "os" "path/filepath" "testing" @@ -107,6 +108,70 @@ func setupHTTPSFixture(t *testing.T) (dir string, key string, cert string) { return tmpDir, keyfile.Name(), certfile.Name() } +func TestParseClientAuthType(t *testing.T) { + tests := []struct { + name string + clientAuth string + valid bool + expected tls.ClientAuthType + }{ + { + name: "empty string", + clientAuth: "", + valid: true, + expected: tls.NoClientCert, + }, + { + name: "unknown value", + clientAuth: "no cert", + valid: false, + expected: -1, + }, + { + name: "explicitly no cert", + clientAuth: "NoClientCert", + valid: true, + expected: tls.NoClientCert, + }, + { + name: "request cert", + clientAuth: "RequestClientCert", + valid: true, + expected: tls.RequestClientCert, + }, + { + name: "require any cert", + clientAuth: "RequireAnyClientCert", + valid: true, + expected: tls.RequireAnyClientCert, + }, + { + name: "verify cert if given", + clientAuth: "VerifyClientCertIfGiven", + valid: true, + expected: tls.VerifyClientCertIfGiven, + }, + { + name: "require and verify cert", + clientAuth: "RequireAndVerifyClientCert", + valid: true, + expected: tls.RequireAndVerifyClientCert, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + authType, err := parseClientAuthType(tt.clientAuth) + if tt.valid { + require.NoError(t, err) + require.EqualValues(t, tt.expected, authType) + return + } + require.Error(t, err) + }) + } +} + func TestParseHeaderString(t *testing.T) { tests := []struct { name string diff --git a/internal/config/flags.go b/internal/config/flags.go index c05f2f32..bce3b07e 100644 --- a/internal/config/flags.go +++ b/internal/config/flags.go @@ -80,6 +80,9 @@ var ( insecureCiphers = flag.Bool("insecure-ciphers", false, "Use default list of cipher suites, may contain insecure ones like 3DES and RC4") tlsMinVersion = flag.String("tls-min-version", "tls1.2", tlsVersionFlagUsage("min")) tlsMaxVersion = flag.String("tls-max-version", "", tlsVersionFlagUsage("max")) + tlsClientAuth = flag.String("tls-client-auth", "noclientcert", "Determines the TLS servers policy for client authentication. Defaults to no client certificate. Values can be found at https://pkg.go.dev/crypto/tls#ClientAuthType") + tlsClientCert = flag.String("tls-client-cert", "", "Path to the certificate authority used to validate client certificates against") + tlsClientAuthDomains = MultiStringFlag{separator: ","} zipCacheExpiration = flag.Duration("zip-cache-expiration", 60*time.Second, "Zip serving archive cache expiration interval") zipCacheCleanup = flag.Duration("zip-cache-cleanup", 30*time.Second, "Zip serving archive cache cleanup interval") zipCacheRefresh = flag.Duration("zip-cache-refresh", 30*time.Second, "Zip serving archive cache refresh interval") @@ -120,6 +123,7 @@ func initFlags() { flag.Var(&listenProxy, "listen-proxy", "The address(es) or unix socket paths to listen on for proxy requests") flag.Var(&listenHTTPSProxyv2, "listen-https-proxyv2", "The address(es) or unix socket paths to listen on for HTTPS PROXYv2 requests (https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt)") flag.Var(&header, "header", "The additional http header(s) that should be send to the client") + flag.Var(&tlsClientAuthDomains, "tls-client-auth-domains", "The domain(s) that require client certificate authentication") // read from -config=/path/to/gitlab-pages-config flag.String(flag.DefaultConfigFlagname, "", "path to config file") diff --git a/internal/domain/domain.go b/internal/domain/domain.go index 1cfee100..9623ce51 100644 --- a/internal/domain/domain.go +++ b/internal/domain/domain.go @@ -3,6 +3,7 @@ package domain import ( "context" "crypto/tls" + "crypto/x509" "errors" "net/http" "sync" @@ -19,9 +20,10 @@ var ErrDomainDoesNotExist = errors.New("domain does not exist") // Domain is a domain that gitlab-pages can serve. type Domain struct { - Name string - CertificateCert string - CertificateKey string + Name string + CertificateCert string + CertificateKey string + ClientCertificateCert string Resolver Resolver @@ -31,12 +33,13 @@ type Domain struct { } // New creates a new domain with a resolver and existing certificates -func New(name, cert, key string, resolver Resolver) *Domain { +func New(name, cert, key, clientCert string, resolver Resolver) *Domain { return &Domain{ - Name: name, - CertificateCert: cert, - CertificateKey: key, - Resolver: resolver, + Name: name, + CertificateCert: cert, + CertificateKey: key, + ClientCertificateCert: clientCert, + Resolver: resolver, } } @@ -121,6 +124,16 @@ func (d *Domain) EnsureCertificate() (*tls.Certificate, error) { return d.certificate, d.certificateError } +func (d *Domain) EnsureClientCertPool() (*x509.CertPool, error) { + if d == nil || len(d.ClientCertificateCert) == 0 { + return nil, errors.New("tls client certificates can be loaded only for pages with configuration") + } + + certPool := x509.NewCertPool() + certPool.AppendCertsFromPEM([]byte(d.ClientCertificateCert)) + return certPool, nil +} + // ServeFileHTTP returns true if something was served, false if not. func (d *Domain) ServeFileHTTP(w http.ResponseWriter, r *http.Request) bool { request, err := d.resolve(r) diff --git a/internal/domain/domain_test.go b/internal/domain/domain_test.go index 5698aead..66a3a043 100644 --- a/internal/domain/domain_test.go +++ b/internal/domain/domain_test.go @@ -33,7 +33,7 @@ func TestIsHTTPSOnly(t *testing.T) { }{ { name: "Custom domain with HTTPS-only enabled", - domain: domain.New("custom-domain", "", "", + domain: domain.New("custom-domain", "", "", "", mockResolver(t, &serving.LookupPath{ Path: "group/project/public", @@ -48,7 +48,7 @@ func TestIsHTTPSOnly(t *testing.T) { }, { name: "Custom domain with HTTPS-only disabled", - domain: domain.New("custom-domain", "", "", + domain: domain.New("custom-domain", "", "", "", mockResolver(t, &serving.LookupPath{ Path: "group/project/public", @@ -63,7 +63,7 @@ func TestIsHTTPSOnly(t *testing.T) { }, { name: "Unknown project", - domain: domain.New("", "", "", mockResolver(t, nil, "", domain.ErrDomainDoesNotExist)), + domain: domain.New("", "", "", "", mockResolver(t, nil, "", domain.ErrDomainDoesNotExist)), url: "http://test-domain/project", expected: false, }, @@ -81,7 +81,7 @@ func TestPredefined404ServeHTTP(t *testing.T) { cleanup := setUpTests(t) defer cleanup() - testDomain := domain.New("", "", "", mockResolver(t, nil, "", domain.ErrDomainDoesNotExist)) + testDomain := domain.New("", "", "", "", mockResolver(t, nil, "", domain.ErrDomainDoesNotExist)) require.HTTPStatusCode(t, serveFileOrNotFound(testDomain), http.MethodGet, "http://group.test.io/not-existing-file", nil, http.StatusNotFound) require.HTTPBodyContains(t, serveFileOrNotFound(testDomain), http.MethodGet, "http://group.test.io/not-existing-file", nil, "The page you're looking for could not be found") diff --git a/internal/logging/logging.go b/internal/logging/logging.go index 2faaddc5..47083d8e 100644 --- a/internal/logging/logging.go +++ b/internal/logging/logging.go @@ -1,6 +1,7 @@ package logging import ( + "fmt" "net/http" "github.com/sirupsen/logrus" @@ -66,9 +67,20 @@ func BasicAccessLogger(handler http.Handler, format string) (http.Handler, error } func extraFields(r *http.Request) log.Fields { - return log.Fields{ + fields := log.Fields{ "pages_https": request.IsHTTPS(r), } + // if there's no client cert, return early + if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { + return fields + } + + // log the client certificate information + for i := range r.TLS.PeerCertificates { + fields[fmt.Sprintf("x509_subject_%d", i)] = r.TLS.PeerCertificates[i].Subject.ToRDNSequence().String() + fields[fmt.Sprintf("x509_issuer_%d", i)] = r.TLS.PeerCertificates[i].Issuer.ToRDNSequence().String() + } + return fields } // LogRequest will inject request host and path to the logged messages diff --git a/internal/source/gitlab/api/virtual_domain.go b/internal/source/gitlab/api/virtual_domain.go index 200c06de..ba55db04 100644 --- a/internal/source/gitlab/api/virtual_domain.go +++ b/internal/source/gitlab/api/virtual_domain.go @@ -3,8 +3,9 @@ package api // VirtualDomain represents a GitLab Pages virtual domain that is being sent // from GitLab API type VirtualDomain struct { - Certificate string `json:"certificate,omitempty"` - Key string `json:"key,omitempty"` + Certificate string `json:"certificate,omitempty"` + Key string `json:"key,omitempty"` + ClientCertificate string `json:"client_certificate,omitempty"` LookupPaths []LookupPath `json:"lookup_paths"` } diff --git a/internal/source/gitlab/gitlab.go b/internal/source/gitlab/gitlab.go index ab5bc490..d3fbc8cc 100644 --- a/internal/source/gitlab/gitlab.go +++ b/internal/source/gitlab/gitlab.go @@ -56,7 +56,7 @@ func (g *Gitlab) GetDomain(ctx context.Context, name string) (*domain.Domain, er // TODO introduce a second-level cache for domains, invalidate using etags // from first-level cache - d := domain.New(name, lookup.Domain.Certificate, lookup.Domain.Key, g) + d := domain.New(name, lookup.Domain.Certificate, lookup.Domain.Key, lookup.Domain.ClientCertificate, g) return d, nil } diff --git a/internal/tls/testdata/cert.crt b/internal/tls/testdata/cert.crt new file mode 100644 index 00000000..4952c7ed --- /dev/null +++ b/internal/tls/testdata/cert.crt @@ -0,0 +1,11 @@ +-----BEGIN CERTIFICATE----- +MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw +DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow +EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d +7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B +5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr +BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1 +NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l +Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc +6MF9+Yw1Yy0t +-----END CERTIFICATE-----
\ No newline at end of file diff --git a/internal/tls/tls.go b/internal/tls/tls.go index eb8e4e64..a01cceb8 100644 --- a/internal/tls/tls.go +++ b/internal/tls/tls.go @@ -2,6 +2,8 @@ package tls import ( "crypto/tls" + "crypto/x509" + "os" "gitlab.com/gitlab-org/gitlab-pages/internal/config" "gitlab.com/gitlab-org/gitlab-pages/internal/ratelimiter" @@ -23,9 +25,13 @@ var preferredCipherSuites = []uint16{ // GetCertificateFunc returns the certificate to be used for given domain type GetCertificateFunc func(*tls.ClientHelloInfo) (*tls.Certificate, error) +// GetConfigFunc returns a tls.Config with populated client +// auth values. +type GetConfigFunc func(*tls.ClientHelloInfo) (*tls.Config, error) + // GetTLSConfig initializes tls.Config based on config flags // getCertificateByServerName obtains certificate based on domain -func GetTLSConfig(cfg *config.Config, getCertificateByServerName GetCertificateFunc) (*tls.Config, error) { +func GetTLSConfig(cfg *config.Config, getCertificateByServerName GetCertificateFunc, getConfigByServerName GetConfigFunc) (*tls.Config, error) { wildcardCertificate, err := tls.X509KeyPair(cfg.General.RootCertificate, cfg.General.RootKey) if err != nil { return nil, err @@ -74,8 +80,48 @@ func GetTLSConfig(cfg *config.Config, getCertificateByServerName GetCertificateF getCertificate = TLSDomainRateLimiter.GetCertificateMiddleware(getCertificate) getCertificate = TLSSourceIPRateLimiter.GetCertificateMiddleware(getCertificate) + tlsConfig, err := getTLSConfig(cfg, getCertificate) + if err != nil { + return nil, err + } + + tlsConfig.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { + return getOptionalConfig(cfg, info, getCertificate, getConfigByServerName) + } + + return tlsConfig, nil +} + +func getOptionalConfig(cfg *config.Config, info *tls.ClientHelloInfo, getCertificate GetCertificateFunc, getConfigByServerName GetConfigFunc) (*tls.Config, error) { + customConfig, err := getConfigByServerName(info) + + if customConfig != nil || err != nil { + customConfig.GetCertificate = getCertificate + return customConfig, err + } + + if cfg.TLS.ClientAuth == tls.NoClientCert { + return nil, nil + } + + for _, i := range cfg.TLS.ClientAuthDomains { + if i != info.ServerName { + continue + } + tlsConfig, err := getTLSConfig(cfg, getCertificate) + if err != nil { + return nil, err + } + tlsConfig.ClientAuth = cfg.TLS.ClientAuth + return tlsConfig, nil + } + + return nil, nil +} + +func getTLSConfig(cfg *config.Config, getCertificateByServerName GetCertificateFunc) (*tls.Config, error) { // set MinVersion to fix gosec: G402 - tlsConfig := &tls.Config{GetCertificate: getCertificate, MinVersion: tls.VersionTLS12} + tlsConfig := &tls.Config{GetCertificate: getCertificateByServerName, MinVersion: tls.VersionTLS12} if !cfg.General.InsecureCiphers { tlsConfig.CipherSuites = preferredCipherSuites @@ -84,5 +130,19 @@ func GetTLSConfig(cfg *config.Config, getCertificateByServerName GetCertificateF tlsConfig.MinVersion = cfg.TLS.MinVersion tlsConfig.MaxVersion = cfg.TLS.MaxVersion + if len(cfg.TLS.ClientAuthDomains) == 0 { + tlsConfig.ClientAuth = cfg.TLS.ClientAuth + } + + if cfg.TLS.ClientAuth > tls.RequestClientCert { + caCert, err := os.ReadFile(cfg.TLS.ClientCert) + if err != nil { + return nil, err + } + certPool := x509.NewCertPool() + certPool.AppendCertsFromPEM(caCert) + tlsConfig.ClientCAs = certPool + } + return tlsConfig, nil } diff --git a/internal/tls/tls_test.go b/internal/tls/tls_test.go index 012d335b..358cc72a 100644 --- a/internal/tls/tls_test.go +++ b/internal/tls/tls_test.go @@ -31,9 +31,13 @@ var getCertificate = func(ch *tls.ClientHelloInfo) (*tls.Certificate, error) { return nil, nil } +var getConfig = func(ch *tls.ClientHelloInfo) (*tls.Config, error) { + return &tls.Config{MinVersion: tls.VersionTLS12}, nil +} + func TestInvalidKeyPair(t *testing.T) { cfg := &config.Config{} - _, err := GetTLSConfig(cfg, getCertificate) + _, err := GetTLSConfig(cfg, getCertificate, getConfig) require.EqualError(t, err, "tls: failed to find any PEM data in certificate input") } @@ -45,11 +49,30 @@ func TestInsecureCiphers(t *testing.T) { InsecureCiphers: true, }, } - tlsConfig, err := GetTLSConfig(cfg, getCertificate) + tlsConfig, err := GetTLSConfig(cfg, getCertificate, getConfig) require.NoError(t, err) require.Empty(t, tlsConfig.CipherSuites) } +func TestClientCert(t *testing.T) { + cfg := &config.Config{ + General: config.General{ + RootCertificate: cert, + RootKey: key, + }, + TLS: config.TLS{ + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCert: "./testdata/cert.crt", + }, + } + tlsConfig, err := GetTLSConfig(cfg, getCertificate, getConfig) + require.NoError(t, err) + require.IsType(t, getCertificate, tlsConfig.GetCertificate) + require.IsType(t, getConfig, tlsConfig.GetConfigForClient) + require.Equal(t, tls.RequireAndVerifyClientCert, tlsConfig.ClientAuth) + require.NotNil(t, tlsConfig.ClientCAs) +} + func TestGetTLSConfig(t *testing.T) { cfg := &config.Config{ General: config.General{ @@ -61,12 +84,15 @@ func TestGetTLSConfig(t *testing.T) { MaxVersion: tls.VersionTLS12, }, } - tlsConfig, err := GetTLSConfig(cfg, getCertificate) + tlsConfig, err := GetTLSConfig(cfg, getCertificate, getConfig) require.NoError(t, err) require.IsType(t, getCertificate, tlsConfig.GetCertificate) + require.IsType(t, getConfig, tlsConfig.GetConfigForClient) require.Equal(t, preferredCipherSuites, tlsConfig.CipherSuites) require.Equal(t, uint16(tls.VersionTLS11), tlsConfig.MinVersion) require.Equal(t, uint16(tls.VersionTLS12), tlsConfig.MaxVersion) + require.Equal(t, tls.NoClientCert, tlsConfig.ClientAuth) + require.Nil(t, tlsConfig.ClientCAs) cert, err := tlsConfig.GetCertificate(&tls.ClientHelloInfo{}) require.NoError(t, err) |