Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-pages.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitlab/ci/test.yml5
-rw-r--r--.gitlab/issue_templates/release.md2
-rw-r--r--.golangci.yml18
-rw-r--r--.tool-versions2
-rw-r--r--CHANGELOG.md16
-rw-r--r--Makefile.build.mk12
-rw-r--r--Makefile.internal.mk4
-rw-r--r--Makefile.util.mk9
-rw-r--r--VERSION2
-rw-r--r--app.go32
-rw-r--r--internal/auth/auth_test.go27
-rw-r--r--internal/config/config.go94
-rw-r--r--internal/config/flags.go59
-rw-r--r--internal/config/tls/tls.go100
-rw-r--r--internal/config/validate.go22
-rw-r--r--internal/config/validate_test.go41
-rw-r--r--internal/customheaders/customheaders.go39
-rw-r--r--internal/customheaders/customheaders_test.go44
-rw-r--r--internal/domain/domain_test.go4
-rw-r--r--internal/domain/mock/resolver_mock.go (renamed from internal/mocks/resolver.go)4
-rw-r--r--internal/feature/feature.go16
-rw-r--r--internal/handlers/handlers_test.go24
-rw-r--r--internal/handlers/mock/handler_mock.go (renamed from internal/mocks/mocks.go)4
-rw-r--r--internal/handlers/ratelimiter.go16
-rw-r--r--internal/httpfs/http_fs_test.go3
-rw-r--r--internal/httptransport/metered_round_tripper_test.go4
-rw-r--r--internal/ratelimiter/middleware.go5
-rw-r--r--internal/ratelimiter/middleware_test.go4
-rw-r--r--internal/ratelimiter/ratelimiter.go31
-rw-r--r--internal/ratelimiter/stub_conn_test.go14
-rw-r--r--internal/ratelimiter/tls.go49
-rw-r--r--internal/ratelimiter/tls_test.go194
-rw-r--r--internal/serving/disk/local/serving_test.go2
-rw-r--r--internal/serving/disk/zip/serving_test.go2
-rw-r--r--internal/source/gitlab/cache/retriever.go5
-rw-r--r--internal/source/gitlab/gitlab_test.go6
-rw-r--r--internal/source/gitlab/mock/client_mock.go (renamed from internal/mocks/client.go)6
-rw-r--r--internal/source/gitlab/mock/client_stub.go (renamed from internal/mocks/api/client_stub.go)2
-rw-r--r--internal/source/mock/source_mock.go (renamed from internal/mocks/source.go)4
-rw-r--r--internal/testhelpers/testhelpers.go10
-rw-r--r--internal/tls/tls.go96
-rw-r--r--internal/tls/tls_test.go (renamed from internal/config/tls/tls_test.go)73
-rw-r--r--internal/urilimiter/urilimiter_test.go4
-rw-r--r--metrics/metrics.go57
-rw-r--r--server.go5
-rw-r--r--test/acceptance/acme_test.go6
-rw-r--r--test/acceptance/artifacts_test.go19
-rw-r--r--test/acceptance/auth_test.go70
-rw-r--r--test/acceptance/encodings_test.go6
-rw-r--r--test/acceptance/helpers_test.go167
-rw-r--r--test/acceptance/metrics_test.go6
-rw-r--r--test/acceptance/proxyv2_test.go4
-rw-r--r--test/acceptance/ratelimiter_test.go128
-rw-r--r--test/acceptance/redirects_test.go7
-rw-r--r--test/acceptance/rewrites_test.go3
-rw-r--r--test/acceptance/serving_test.go32
-rw-r--r--test/acceptance/status_test.go4
-rw-r--r--test/acceptance/unknown_http_method_test.go4
-rw-r--r--test/acceptance/zip_test.go13
59 files changed, 1108 insertions, 533 deletions
diff --git a/.gitlab/ci/test.yml b/.gitlab/ci/test.yml
index d20c144e..e9bca5d6 100644
--- a/.gitlab/ci/test.yml
+++ b/.gitlab/ci/test.yml
@@ -64,3 +64,8 @@ check deps:
- echo skipping
script:
- make deps-check
+
+check mocks:
+ extends: .tests-common
+ script:
+ - make mocks-check
diff --git a/.gitlab/issue_templates/release.md b/.gitlab/issue_templates/release.md
index 930508c8..3a4efc58 100644
--- a/.gitlab/issue_templates/release.md
+++ b/.gitlab/issue_templates/release.md
@@ -2,7 +2,7 @@
- Decide on the version number by reference to
the [Versioning](https://gitlab.com/gitlab-org/gitlab-pages/blob/master/PROCESS.md#versioning)
* Typically if you want to release code from current `master` branch you will update `MINOR` version, e.g. `1.12.0` -> `1.13.0`. In that case you **don't** need to create stable branch
- * If you want to backport some bug fix or security fix you will need to update stable branch `X-Y-stable`
+ * If you want to backport some bug fix or security fix you will need to create a stable branch `X-Y-stable` on the [security project](https://gitlab.com/gitlab-org/security/gitlab-pages). You will need maintainer access to create the stable branch.
- [ ] Create an MR for [gitlab-pages project](https://gitlab.com/gitlab-org/gitlab-pages).
You can use [this MR](https://gitlab.com/gitlab-org/gitlab-pages/merge_requests/217) as an example.
- [ ] Update `VERSION`, and push your branch
diff --git a/.golangci.yml b/.golangci.yml
index c9fc0444..3b472def 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -1,17 +1,17 @@
run:
- # which dirs to skip: issues from them won't be reported;
- # can use regexp here: generated.*, regexp is applied on full path;
- # default value is empty list, but default dirs are skipped independently
- # from this option's value (see skip-dirs-use-default).
- # "/" will be replaced by current OS file path separator to properly work
- # on Windows.
- skip-dirs:
- - internal/mocks
-
# default is true. Enables skipping of directories:
# vendor$, third_party$, testdata$, examples$, Godeps$, builtin$
skip-dirs-use-default: false
+ # Which files to skip: they will be analyzed, but issues from them won't be reported.
+ # Default value is empty list,
+ # but there is no need to include all autogenerated files,
+ # we confidently recognize autogenerated files.
+ # If it's not please let us know.
+ # "/" will be replaced by current OS file path separator to properly work on Windows.
+ skip-files:
+ - _mock\.go
+
# by default isn't set. If set we pass it to "go list -mod={option}". From "go help modules":
# If invoked with -mod=readonly, the go command is disallowed from the implicit
# automatic updating of go.mod described above. Instead, it fails when any changes
diff --git a/.tool-versions b/.tool-versions
index 29cc9a03..108bdd0f 100644
--- a/.tool-versions
+++ b/.tool-versions
@@ -1 +1 @@
-golang 1.17.6
+golang 1.17.7
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 56dabab5..f65bf9b6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,19 @@
+## 1.55.0 (2022-02-22)
+
+### Added (1 change)
+
+- [feat: Add TLS rate limits](gitlab-org/gitlab-pages@62a6491652aa6975d9ecf3b9e258766c886d49d4) ([merge request](gitlab-org/gitlab-pages!700))
+
+### Fixed (1 change)
+
+- [fix: do no retry resolving the domain if there's a ctx error](gitlab-org/gitlab-pages@970531c7f80db47d209196921043aabcdf7590ef) by @feistel ([merge request](gitlab-org/gitlab-pages!691))
+
+## 1.54.0 (2022-02-10)
+
+### Fixed (1 change)
+
+- [fix: ensure logging status codes field names are consistent](gitlab-org/gitlab-pages@6f23e35ffe9665ab17af54824a7de2b014829069) ([merge request](gitlab-org/gitlab-pages!679))
+
## 1.53.0 (2022-02-01)
### Fixed (2 changes)
diff --git a/Makefile.build.mk b/Makefile.build.mk
index 350a7b99..88d74dbf 100644
--- a/Makefile.build.mk
+++ b/Makefile.build.mk
@@ -19,13 +19,13 @@ setup: .GOPATH/.ok
cisetup: .GOPATH/.ok
mkdir -p bin/
# Installing dev tools defined in go.tools
- awk '/_/ {print $$2}' ./tools/main.go | grep -v -e mockgen -e golangci | xargs -tI % go install ${V:+-v -x} -modfile=tools/go.mod -mod=mod %
+ awk '/_/ {print $$2}' ./tools/main.go | grep -v -e golangci | xargs -tI % go install ${V:+-v -x} -modfile=tools/go.mod -mod=mod %
-generate-mocks: .GOPATH/.ok
- $Q bin/mockgen -source=internal/interface.go -destination=internal/mocks/mocks.go -package=mocks
- $Q bin/mockgen -source=internal/source/source.go -destination=internal/mocks/source.go -package=mocks
- $Q bin/mockgen -source=internal/mocks/api/client_stub.go -destination=internal/mocks/client.go -package=mocks
- $Q bin/mockgen -source=internal/domain/resolver.go -destination=internal/mocks/resolver.go -package=mocks
+generate-mocks: .GOPATH/.ok bin/mockgen
+ $Q bin/mockgen -source=internal/interface.go -destination=internal/handlers/mock/handler_mock.go -package=mock
+ $Q bin/mockgen -source=internal/source/source.go -destination=internal/source/mock/source_mock.go -package=mock
+ $Q bin/mockgen -source=internal/source/gitlab/mock/client_stub.go -destination=internal/source/gitlab/mock/client_mock.go -package=mock
+ $Q bin/mockgen -source=internal/domain/resolver.go -destination=internal/domain/mock/resolver_mock.go -package=mock
build: .GOPATH/.ok
$Q GOBIN=$(BINDIR) go install $(if $V,-v) -ldflags="$(VERSION_FLAGS)" -tags "${GO_BUILD_TAGS}" -buildmode exe $(IMPORT_PATH)
diff --git a/Makefile.internal.mk b/Makefile.internal.mk
index d2340855..a3e7eb7b 100644
--- a/Makefile.internal.mk
+++ b/Makefile.internal.mk
@@ -32,3 +32,7 @@ bin/golangci-lint: .GOPATH/.ok
bin/gotestsum: .GOPATH/.ok
@test -x $@ || \
{ echo "Vendored gotestsum not found, try running 'make setup'..."; exit 1; }
+
+bin/mockgen: .GOPATH/.ok
+ @test -x $@ || \
+ { echo "Vendored mockgen not found, try running 'make setup'..."; exit 1; }
diff --git a/Makefile.util.mk b/Makefile.util.mk
index 692848c7..e9abe846 100644
--- a/Makefile.util.mk
+++ b/Makefile.util.mk
@@ -49,6 +49,15 @@ deps-check: .GOPATH/.ok
exit 1; \
fi;
+mocks-check: .GOPATH/.ok generate-mocks
+ @if git diff --color=always --exit-code -- *_mock.go; then \
+ echo "mocks are ok"; \
+ else \
+ echo ""; \
+ echo "mocks needs to be regenerated, please run 'make generate-mocks' and commit them";\
+ exit 1; \
+ fi;
+
deps-download: .GOPATH/.ok
go mod download
diff --git a/VERSION b/VERSION
index 3f483015..094d6ad0 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-1.53.0
+1.55.0
diff --git a/app.go b/app.go
index ddafb0bf..ec928300 100644
--- a/app.go
+++ b/app.go
@@ -27,7 +27,6 @@ import (
"gitlab.com/gitlab-org/gitlab-pages/internal/artifact"
"gitlab.com/gitlab-org/gitlab-pages/internal/auth"
cfg "gitlab.com/gitlab-org/gitlab-pages/internal/config"
- "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls"
"gitlab.com/gitlab-org/gitlab-pages/internal/customheaders"
"gitlab.com/gitlab-org/gitlab-pages/internal/domain"
"gitlab.com/gitlab-org/gitlab-pages/internal/handlers"
@@ -40,6 +39,7 @@ import (
"gitlab.com/gitlab-org/gitlab-pages/internal/serving/disk/zip"
"gitlab.com/gitlab-org/gitlab-pages/internal/source"
"gitlab.com/gitlab-org/gitlab-pages/internal/source/gitlab"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/tls"
"gitlab.com/gitlab-org/gitlab-pages/internal/urilimiter"
"gitlab.com/gitlab-org/gitlab-pages/metrics"
)
@@ -51,6 +51,7 @@ var (
type theApp struct {
config *cfg.Config
source source.Source
+ tlsConfig *cryptotls.Config
Artifact *artifact.Artifact
Auth *auth.Auth
Handlers *handlers.Handlers
@@ -62,19 +63,33 @@ func (a *theApp) isReady() bool {
return true
}
-func (a *theApp) ServeTLS(ch *cryptotls.ClientHelloInfo) (*cryptotls.Certificate, error) {
+func (a *theApp) GetCertificate(ch *cryptotls.ClientHelloInfo) (*cryptotls.Certificate, error) {
if ch.ServerName == "" {
return nil, nil
}
if domain, _ := a.domain(context.Background(), ch.ServerName); domain != nil {
- tls, _ := domain.EnsureCertificate()
- return tls, nil
+ certificate, _ := domain.EnsureCertificate()
+ return certificate, nil
}
return nil, nil
}
+func (a *theApp) getTLSConfig() (*cryptotls.Config, error) {
+ // we call this function only when tls config is needed, and we ignore TLS related flags otherwise
+ // in theory you can configure both listen-https and listen-proxyv2,
+ // so this return is here to have a single TLS config
+ if a.tlsConfig != nil {
+ return a.tlsConfig, nil
+ }
+
+ var err error
+ a.tlsConfig, err = tls.GetTLSConfig(a.config, a.GetCertificate)
+
+ return a.tlsConfig, err
+}
+
func (a *theApp) redirectToHTTPS(w http.ResponseWriter, r *http.Request, statusCode int) {
u := *r.URL
u.Scheme = request.SchemeHTTPS
@@ -306,7 +321,7 @@ func (a *theApp) Run() {
// Listen for HTTPS
for _, addr := range a.config.ListenHTTPSStrings.Split() {
- tlsConfig, err := a.TLSConfig()
+ tlsConfig, err := a.getTLSConfig()
if err != nil {
log.WithError(err).Fatal("Unable to retrieve tls config")
}
@@ -334,7 +349,7 @@ func (a *theApp) Run() {
// Listen for HTTPS PROXYv2 requests
for _, addr := range a.config.ListenHTTPSProxyv2Strings.Split() {
- tlsConfig, err := a.TLSConfig()
+ tlsConfig, err := a.getTLSConfig()
if err != nil {
log.WithError(err).Fatal("Unable to retrieve tls config")
}
@@ -478,11 +493,6 @@ func fatal(err error, message string) {
log.WithError(err).Fatal(message)
}
-func (a *theApp) TLSConfig() (*cryptotls.Config, error) {
- return tls.Create(a.config.General.RootCertificate, a.config.General.RootKey, a.ServeTLS,
- a.config.General.InsecureCiphers, a.config.TLS.MinVersion, a.config.TLS.MaxVersion)
-}
-
// handlePanicMiddleware logs and captures the recover() information from any panic
func handlePanicMiddleware(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go
index d55c5a46..61b4e88a 100644
--- a/internal/auth/auth_test.go
+++ b/internal/auth/auth_test.go
@@ -15,8 +15,9 @@ import (
"github.com/gorilla/sessions"
"github.com/stretchr/testify/require"
- "gitlab.com/gitlab-org/gitlab-pages/internal/mocks"
"gitlab.com/gitlab-org/gitlab-pages/internal/request"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/source/mock"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers"
)
func createTestAuth(t *testing.T, internalServer string, publicServer string) *Auth {
@@ -68,7 +69,7 @@ func setSessionValues(t *testing.T, r *http.Request, store sessions.Store, value
session.Save(tmpRequest, result)
res := result.Result()
- defer res.Body.Close()
+ testhelpers.Close(t, res.Body)
for _, cookie := range res.Cookies() {
r.AddCookie(cookie)
@@ -86,7 +87,7 @@ func TestTryAuthenticate(t *testing.T) {
mockCtrl := gomock.NewController(t)
- mockSource := mocks.NewMockSource(mockCtrl)
+ mockSource := mock.NewMockSource(mockCtrl)
require.False(t, auth.TryAuthenticate(result, r, mockSource))
}
@@ -102,7 +103,7 @@ func TestTryAuthenticateWithError(t *testing.T) {
mockCtrl := gomock.NewController(t)
- mockSource := mocks.NewMockSource(mockCtrl)
+ mockSource := mock.NewMockSource(mockCtrl)
require.True(t, auth.TryAuthenticate(result, r, mockSource))
require.Equal(t, http.StatusUnauthorized, result.Code)
}
@@ -124,7 +125,7 @@ func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) {
mockCtrl := gomock.NewController(t)
- mockSource := mocks.NewMockSource(mockCtrl)
+ mockSource := mock.NewMockSource(mockCtrl)
require.True(t, auth.TryAuthenticate(result, r, mockSource))
require.Equal(t, http.StatusUnauthorized, result.Code)
}
@@ -149,7 +150,7 @@ func TestTryAuthenticateRemoveTokenFromRedirect(t *testing.T) {
mockCtrl := gomock.NewController(t)
- mockSource := mocks.NewMockSource(mockCtrl)
+ mockSource := mock.NewMockSource(mockCtrl)
require.True(t, auth.TryAuthenticate(result, r, mockSource))
require.Equal(t, http.StatusFound, result.Code)
@@ -168,7 +169,7 @@ func TestTryAuthenticateWithDomainAndState(t *testing.T) {
mockCtrl := gomock.NewController(t)
- mockSource := mocks.NewMockSource(mockCtrl)
+ mockSource := mock.NewMockSource(mockCtrl)
require.True(t, auth.TryAuthenticate(result, r, mockSource))
require.Equal(t, http.StatusFound, result.Code)
redirect, err := url.Parse(result.Header().Get("Location"))
@@ -228,11 +229,11 @@ func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) {
mockCtrl := gomock.NewController(t)
- mockSource := mocks.NewMockSource(mockCtrl)
+ mockSource := mock.NewMockSource(mockCtrl)
require.True(t, auth.TryAuthenticate(result, r, mockSource))
res := result.Result()
- defer res.Body.Close()
+ testhelpers.Close(t, res.Body)
require.Equal(t, http.StatusFound, result.Code)
require.Equal(t, "https://pages.gitlab-example.com/project/", result.Header().Get("Location"))
@@ -318,7 +319,7 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) {
contentServed := auth.CheckAuthentication(w, r, &domainMock{projectID: 1000, notFoundContent: "Generic 404"})
require.True(t, contentServed)
res := w.Result()
- defer res.Body.Close()
+ testhelpers.Close(t, res.Body)
require.Equal(t, http.StatusNotFound, res.StatusCode)
@@ -493,11 +494,11 @@ func TestCheckResponseForInvalidTokenWhenInvalidToken(t *testing.T) {
r := &http.Request{URL: reqURL, Host: "pages.gitlab-example.com", RequestURI: "/test"}
resp := &http.Response{StatusCode: http.StatusUnauthorized, Body: io.NopCloser(bytes.NewReader([]byte("{\"error\":\"invalid_token\"}")))}
- defer resp.Body.Close()
+ testhelpers.Close(t, resp.Body)
require.True(t, auth.CheckResponseForInvalidToken(result, r, resp))
res := result.Result()
- defer res.Body.Close()
+ testhelpers.Close(t, res.Body)
require.Equal(t, http.StatusFound, res.StatusCode)
require.Equal(t, "http://pages.gitlab-example.com/test", result.Header().Get("Location"))
}
@@ -518,7 +519,7 @@ func TestCheckResponseForInvalidTokenWhenNotInvalidToken(t *testing.T) {
func TestDomainAllowed(t *testing.T) {
auth := createTestAuth(t, "", "")
mockCtrl := gomock.NewController(t)
- mockSource := mocks.NewMockSource(mockCtrl)
+ mockSource := mock.NewMockSource(mockCtrl)
testCases := []struct {
name string
diff --git a/internal/config/config.go b/internal/config/config.go
index 3bb7b126..24a811ec 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -9,8 +9,6 @@ import (
"github.com/namsral/flag"
"gitlab.com/gitlab-org/labkit/log"
-
- "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls"
)
// Config stores all the config options relevant to GitLab Pages.
@@ -58,10 +56,17 @@ type General struct {
// RateLimit config struct
type RateLimit struct {
+ // HTTP limits
SourceIPLimitPerSecond float64
SourceIPBurst int
DomainLimitPerSecond float64
DomainBurst int
+
+ // TLS connections limits
+ TLSSourceIPLimitPerSecond float64
+ TLSSourceIPBurst int
+ TLSDomainLimitPerSecond float64
+ TLSDomainBurst int
}
// ArtifactsServer groups settings related to configuring Artifacts
@@ -183,6 +188,11 @@ func loadConfig() (*Config, error) {
SourceIPBurst: *rateLimitSourceIPBurst,
DomainLimitPerSecond: *rateLimitDomain,
DomainBurst: *rateLimitDomainBurst,
+
+ TLSSourceIPLimitPerSecond: *rateLimitTLSSourceIP,
+ TLSSourceIPBurst: *rateLimitTLSSourceIPBurst,
+ TLSDomainLimitPerSecond: *rateLimitTLSDomain,
+ TLSDomainBurst: *rateLimitTLSDomainBurst,
},
GitLab: GitLab{
ClientHTTPTimeout: *gitlabClientHTTPTimeout,
@@ -217,8 +227,8 @@ func loadConfig() (*Config, error) {
Environment: *sentryEnvironment,
},
TLS: TLS{
- MinVersion: tls.AllTLSVersions[*tlsMinVersion],
- MaxVersion: tls.AllTLSVersions[*tlsMaxVersion],
+ MinVersion: allTLSVersions[*tlsMinVersion],
+ MaxVersion: allTLSVersions[*tlsMaxVersion],
},
Zip: ZipServing{
ExpirationInterval: *zipCacheExpiration,
@@ -267,40 +277,48 @@ func loadConfig() (*Config, error) {
func LogConfig(config *Config) {
log.WithFields(log.Fields{
- "artifacts-server": *artifactsServer,
- "artifacts-server-timeout": *artifactsServerTimeout,
- "default-config-filename": flag.DefaultConfigFlagname,
- "disable-cross-origin-requests": *disableCrossOriginRequests,
- "domain": config.General.Domain,
- "insecure-ciphers": config.General.InsecureCiphers,
- "listen-http": listenHTTP,
- "listen-https": listenHTTPS,
- "listen-proxy": listenProxy,
- "listen-https-proxyv2": listenHTTPSProxyv2,
- "log-format": *logFormat,
- "metrics-address": *metricsAddress,
- "pages-domain": *pagesDomain,
- "pages-root": *pagesRoot,
- "pages-status": *pagesStatus,
- "propagate-correlation-id": *propagateCorrelationID,
- "redirect-http": config.General.RedirectHTTP,
- "root-cert": *pagesRootKey,
- "root-key": *pagesRootCert,
- "status_path": config.General.StatusPath,
- "tls-min-version": *tlsMinVersion,
- "tls-max-version": *tlsMaxVersion,
- "gitlab-server": config.GitLab.PublicServer,
- "internal-gitlab-server": config.GitLab.InternalServer,
- "api-secret-key": *gitLabAPISecretKey,
- "enable-disk": config.GitLab.EnableDisk,
- "auth-redirect-uri": config.Authentication.RedirectURI,
- "auth-scope": config.Authentication.Scope,
- "max-conns": config.General.MaxConns,
- "max-uri-length": config.General.MaxURILength,
- "zip-cache-expiration": config.Zip.ExpirationInterval,
- "zip-cache-cleanup": config.Zip.CleanupInterval,
- "zip-cache-refresh": config.Zip.RefreshInterval,
- "zip-open-timeout": config.Zip.OpenTimeout,
+ "artifacts-server": *artifactsServer,
+ "artifacts-server-timeout": *artifactsServerTimeout,
+ "default-config-filename": flag.DefaultConfigFlagname,
+ "disable-cross-origin-requests": *disableCrossOriginRequests,
+ "domain": config.General.Domain,
+ "insecure-ciphers": config.General.InsecureCiphers,
+ "listen-http": listenHTTP,
+ "listen-https": listenHTTPS,
+ "listen-proxy": listenProxy,
+ "listen-https-proxyv2": listenHTTPSProxyv2,
+ "log-format": *logFormat,
+ "metrics-address": *metricsAddress,
+ "pages-domain": *pagesDomain,
+ "pages-root": *pagesRoot,
+ "pages-status": *pagesStatus,
+ "propagate-correlation-id": *propagateCorrelationID,
+ "redirect-http": config.General.RedirectHTTP,
+ "root-cert": *pagesRootKey,
+ "root-key": *pagesRootCert,
+ "status_path": config.General.StatusPath,
+ "tls-min-version": *tlsMinVersion,
+ "tls-max-version": *tlsMaxVersion,
+ "gitlab-server": config.GitLab.PublicServer,
+ "internal-gitlab-server": config.GitLab.InternalServer,
+ "api-secret-key": *gitLabAPISecretKey,
+ "enable-disk": config.GitLab.EnableDisk,
+ "auth-redirect-uri": config.Authentication.RedirectURI,
+ "auth-scope": config.Authentication.Scope,
+ "max-conns": config.General.MaxConns,
+ "max-uri-length": config.General.MaxURILength,
+ "zip-cache-expiration": config.Zip.ExpirationInterval,
+ "zip-cache-cleanup": config.Zip.CleanupInterval,
+ "zip-cache-refresh": config.Zip.RefreshInterval,
+ "zip-open-timeout": config.Zip.OpenTimeout,
+ "rate-limit-source-ip": config.RateLimit.SourceIPLimitPerSecond,
+ "rate-limit-source-ip-burst": config.RateLimit.SourceIPBurst,
+ "rate-limit-domain": config.RateLimit.DomainLimitPerSecond,
+ "rate-limit-domain-burst": config.RateLimit.DomainBurst,
+ "rate-limit-tls-source-ip": config.RateLimit.TLSSourceIPLimitPerSecond,
+ "rate-limit-tls-source-ip-burst": config.RateLimit.TLSSourceIPBurst,
+ "rate-limit-tls-domain": config.RateLimit.TLSDomainLimitPerSecond,
+ "rate-limit-tls-domain-burst": config.RateLimit.TLSDomainBurst,
}).Debug("Start Pages with configuration")
}
diff --git a/internal/config/flags.go b/internal/config/flags.go
index 93228827..409ecdc7 100644
--- a/internal/config/flags.go
+++ b/internal/config/flags.go
@@ -1,24 +1,41 @@
package config
import (
+ "crypto/tls"
+ "fmt"
+ "sort"
+ "strings"
"time"
"github.com/namsral/flag"
-
- "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls"
)
var (
- pagesRootCert = flag.String("root-cert", "", "The default path to file certificate to serve static pages")
- pagesRootKey = flag.String("root-key", "", "The default path to file certificate to serve static pages")
- redirectHTTP = flag.Bool("redirect-http", false, "Redirect pages from HTTP to HTTPS")
- _ = flag.Bool("use-http2", true, "DEPRECATED: HTTP2 is always enabled for pages")
- pagesRoot = flag.String("pages-root", "shared/pages", "The directory where pages are stored")
- pagesDomain = flag.String("pages-domain", "gitlab-example.com", "The domain to serve static pages")
- rateLimitSourceIP = flag.Float64("rate-limit-source-ip", 0.0, "Rate limit per source IP in number of requests per second, 0 means is disabled")
- rateLimitSourceIPBurst = flag.Int("rate-limit-source-ip-burst", 100, "Rate limit per source IP maximum burst allowed per second")
- rateLimitDomain = flag.Float64("rate-limit-domain", 0.0, "Rate limit per domain in number of requests per second, 0 means is disabled")
- rateLimitDomainBurst = flag.Int("rate-limit-domain-burst", 100, "Rate limit per domain maximum burst allowed per second")
+ // allTLSVersions has all supported flag values
+ allTLSVersions = map[string]uint16{
+ "": 0, // Default value in tls.Config
+ "tls1.2": tls.VersionTLS12,
+ "tls1.3": tls.VersionTLS13,
+ }
+
+ pagesRootCert = flag.String("root-cert", "", "The default path to file certificate to serve static pages")
+ pagesRootKey = flag.String("root-key", "", "The default path to file certificate to serve static pages")
+ redirectHTTP = flag.Bool("redirect-http", false, "Redirect pages from HTTP to HTTPS")
+ _ = flag.Bool("use-http2", true, "DEPRECATED: HTTP2 is always enabled for pages")
+ pagesRoot = flag.String("pages-root", "shared/pages", "The directory where pages are stored")
+ pagesDomain = flag.String("pages-domain", "gitlab-example.com", "The domain to serve static pages")
+
+ // HTTP rate limits
+ rateLimitSourceIP = flag.Float64("rate-limit-source-ip", 0.0, "Rate limit HTTP requests per second from a single IP, 0 means is disabled")
+ rateLimitSourceIPBurst = flag.Int("rate-limit-source-ip-burst", 100, "Rate limit HTTP requests from a single IP, maximum burst allowed per second")
+ rateLimitDomain = flag.Float64("rate-limit-domain", 0.0, "Rate limit HTTP requests per second to a single domain, 0 means is disabled")
+ rateLimitDomainBurst = flag.Int("rate-limit-domain-burst", 100, "Rate limit HTTP requests to a single domain, maximum burst allowed per second")
+ // TLS connections rate limits
+ rateLimitTLSSourceIP = flag.Float64("rate-limit-tls-source-ip", 0.0, "Rate limit new TLS connections per second from a single IP, 0 means is disabled")
+ rateLimitTLSSourceIPBurst = flag.Int("rate-limit-tls-source-ip-burst", 100, "Rate limit new TLS connections from a single IP, maximum burst allowed per second")
+ rateLimitTLSDomain = flag.Float64("rate-limit-tls-domain", 0.0, "Rate limit new TLS connections per second from to a single domain, 0 means is disabled")
+ rateLimitTLSDomainBurst = flag.Int("rate-limit-tls-domain-burst", 100, "Rate limit new TLS connections from a single domain, maximum burst allowed per second")
+
artifactsServer = flag.String("artifacts-server", "", "API URL to proxy artifact requests to, e.g.: 'https://gitlab.com/api/v4'")
artifactsServerTimeout = flag.Int("artifacts-server-timeout", 10, "Timeout (in seconds) for a proxied request to the artifacts server")
pagesStatus = flag.String("pages-status", "", "The url path for a status page, e.g., /@status")
@@ -55,8 +72,8 @@ var (
maxConns = flag.Int("max-conns", 0, "Limit on the number of concurrent connections to the HTTP, HTTPS or proxy listeners, 0 for no limit")
maxURILength = flag.Int("max-uri-length", 1024, "Limit the length of URI, 0 for unlimited.")
insecureCiphers = flag.Bool("insecure-ciphers", false, "Use default list of cipher suites, may contain insecure ones like 3DES and RC4")
- tlsMinVersion = flag.String("tls-min-version", "tls1.2", tls.FlagUsage("min"))
- tlsMaxVersion = flag.String("tls-max-version", "", tls.FlagUsage("max"))
+ tlsMinVersion = flag.String("tls-min-version", "tls1.2", tlsVersionFlagUsage("min"))
+ tlsMaxVersion = flag.String("tls-max-version", "", tlsVersionFlagUsage("max"))
zipCacheExpiration = flag.Duration("zip-cache-expiration", 60*time.Second, "Zip serving archive cache expiration interval")
zipCacheCleanup = flag.Duration("zip-cache-cleanup", 30*time.Second, "Zip serving archive cache cleanup interval")
zipCacheRefresh = flag.Duration("zip-cache-refresh", 30*time.Second, "Zip serving archive cache refresh interval")
@@ -88,3 +105,17 @@ func initFlags() {
flag.Parse()
}
+
+// tlsVersionFlagUsage returns string with explanation how to use the tls version CLI flag
+func tlsVersionFlagUsage(minOrMax string) string {
+ versions := []string{}
+
+ for version := range allTLSVersions {
+ if version != "" {
+ versions = append(versions, fmt.Sprintf("%q", version))
+ }
+ }
+ sort.Strings(versions)
+
+ return fmt.Sprintf("Specifies the "+minOrMax+"imum SSL/TLS version, supported values are %s", strings.Join(versions, ", "))
+}
diff --git a/internal/config/tls/tls.go b/internal/config/tls/tls.go
deleted file mode 100644
index c76bbff7..00000000
--- a/internal/config/tls/tls.go
+++ /dev/null
@@ -1,100 +0,0 @@
-package tls
-
-import (
- "crypto/tls"
- "fmt"
- "sort"
- "strings"
-)
-
-// GetCertificateFunc returns the certificate to be used for given domain
-type GetCertificateFunc func(*tls.ClientHelloInfo) (*tls.Certificate, error)
-
-var (
- preferredCipherSuites = []uint16{
- tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
- tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
- tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
- tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
- tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
- tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
- tls.TLS_AES_128_GCM_SHA256,
- tls.TLS_AES_256_GCM_SHA384,
- tls.TLS_CHACHA20_POLY1305_SHA256,
- }
-
- // AllTLSVersions has all supported flag values
- AllTLSVersions = map[string]uint16{
- "": 0, // Default value in tls.Config
- "tls1.2": tls.VersionTLS12,
- "tls1.3": tls.VersionTLS13,
- }
-)
-
-// FlagUsage returns string with explanation how to use the CLI flag
-func FlagUsage(minOrMax string) string {
- versions := []string{}
-
- for version := range AllTLSVersions {
- if version != "" {
- versions = append(versions, fmt.Sprintf("%q", version))
- }
- }
- sort.Strings(versions)
-
- return fmt.Sprintf("Specifies the "+minOrMax+"imum SSL/TLS version, supported values are %s", strings.Join(versions, ", "))
-}
-
-// Create returns tls.Config for given app configuration
-func Create(cert, key []byte, getCertificate GetCertificateFunc, insecureCiphers bool, tlsMinVersion uint16, tlsMaxVersion uint16) (*tls.Config, error) {
- // set MinVersion to fix gosec: G402
- tlsConfig := &tls.Config{GetCertificate: getCertificate, MinVersion: tls.VersionTLS12}
-
- err := configureCertificate(tlsConfig, cert, key)
- if err != nil {
- return nil, err
- }
-
- if !insecureCiphers {
- configureTLSCiphers(tlsConfig)
- }
-
- tlsConfig.MinVersion = tlsMinVersion
- tlsConfig.MaxVersion = tlsMaxVersion
-
- return tlsConfig, nil
-}
-
-// ValidateTLSVersions returns error if the provided TLS versions config values are not valid
-func ValidateTLSVersions(min, max string) error {
- tlsMin, tlsMinOk := AllTLSVersions[min]
- tlsMax, tlsMaxOk := AllTLSVersions[max]
-
- if !tlsMinOk {
- return fmt.Errorf("invalid minimum TLS version: %s", min)
- }
- if !tlsMaxOk {
- return fmt.Errorf("invalid maximum TLS version: %s", max)
- }
- if tlsMin > tlsMax && tlsMax > 0 {
- return fmt.Errorf("invalid maximum TLS version: %s; should be at least %s", max, min)
- }
-
- return nil
-}
-
-func configureCertificate(tlsConfig *tls.Config, cert, key []byte) error {
- certificate, err := tls.X509KeyPair(cert, key)
- if err != nil {
- return err
- }
-
- tlsConfig.Certificates = []tls.Certificate{certificate}
-
- return nil
-}
-
-func configureTLSCiphers(tlsConfig *tls.Config) {
- tlsConfig.PreferServerCipherSuites = true
- tlsConfig.CipherSuites = preferredCipherSuites
-}
diff --git a/internal/config/validate.go b/internal/config/validate.go
index a3dbcc3b..bb247287 100644
--- a/internal/config/validate.go
+++ b/internal/config/validate.go
@@ -6,8 +6,6 @@ import (
"net/url"
"github.com/hashicorp/go-multierror"
-
- "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls"
)
var (
@@ -30,7 +28,7 @@ func Validate(config *Config) error {
validateListeners(config),
validateAuthConfig(config),
validateArtifactsServerConfig(config),
- tls.ValidateTLSVersions(*tlsMinVersion, *tlsMaxVersion),
+ validateTLSVersions(*tlsMinVersion, *tlsMaxVersion),
)
return result.ErrorOrNil()
@@ -115,3 +113,21 @@ func validateArtifactsServerConfig(config *Config) error {
return result.ErrorOrNil()
}
+
+// validateTLSVersions returns error if the provided TLS versions config values are not valid
+func validateTLSVersions(min, max string) error {
+ tlsMin, tlsMinOk := allTLSVersions[min]
+ tlsMax, tlsMaxOk := allTLSVersions[max]
+
+ if !tlsMinOk {
+ return fmt.Errorf("invalid minimum TLS version: %s", min)
+ }
+ if !tlsMaxOk {
+ return fmt.Errorf("invalid maximum TLS version: %s", max)
+ }
+ if tlsMin > tlsMax && tlsMax > 0 {
+ return fmt.Errorf("invalid maximum TLS version: %s; should be at least %s", max, min)
+ }
+
+ return nil
+}
diff --git a/internal/config/validate_test.go b/internal/config/validate_test.go
index 2118bb93..6e96d7d3 100644
--- a/internal/config/validate_test.go
+++ b/internal/config/validate_test.go
@@ -154,3 +154,44 @@ func validConfig() Config {
return cfg
}
+
+func TestValidTLSVersions(t *testing.T) {
+ tests := map[string]struct {
+ tlsMin string
+ tlsMax string
+ }{
+ "tls 1.3 only": {tlsMin: "tls1.3", tlsMax: "tls1.3"},
+ "tls 1.2 only": {tlsMin: "tls1.2", tlsMax: "tls1.2"},
+ "tls 1.3 max": {tlsMax: "tls1.3"},
+ "tls 1.2 max": {tlsMax: "tls1.2"},
+ "tls 1.3+": {tlsMin: "tls1.3"},
+ "tls 1.2+": {tlsMin: "tls1.2"},
+ "default": {},
+ }
+
+ for name, tc := range tests {
+ t.Run(name, func(t *testing.T) {
+ err := validateTLSVersions(tc.tlsMin, tc.tlsMax)
+ require.NoError(t, err)
+ })
+ }
+}
+
+func TestInvalidTLSVersions(t *testing.T) {
+ tests := map[string]struct {
+ tlsMin string
+ tlsMax string
+ err string
+ }{
+ "invalid minimum TLS version": {tlsMin: "tls123", tlsMax: "", err: "invalid minimum TLS version: tls123"},
+ "invalid maximum TLS version": {tlsMin: "", tlsMax: "tls123", err: "invalid maximum TLS version: tls123"},
+ "TLS versions conflict": {tlsMin: "tls1.3", tlsMax: "tls1.2", err: "invalid maximum TLS version: tls1.2; should be at least tls1.3"},
+ }
+
+ for name, tc := range tests {
+ t.Run(name, func(t *testing.T) {
+ err := validateTLSVersions(tc.tlsMin, tc.tlsMax)
+ require.EqualError(t, err, tc.err)
+ })
+ }
+}
diff --git a/internal/customheaders/customheaders.go b/internal/customheaders/customheaders.go
index b54df585..92f50069 100644
--- a/internal/customheaders/customheaders.go
+++ b/internal/customheaders/customheaders.go
@@ -1,12 +1,20 @@
package customheaders
import (
+ "bufio"
"errors"
+ "fmt"
"net/http"
+ "net/textproto"
"strings"
+
+ "github.com/hashicorp/go-multierror"
)
-var errInvalidHeaderParameter = errors.New("invalid syntax specified as header parameter")
+var (
+ errInvalidHeaderParameter = errors.New("invalid syntax specified as header parameter")
+ errDuplicateHeader = errors.New("duplicate header")
+)
// AddCustomHeaders adds a map of Headers to a Response
func AddCustomHeaders(w http.ResponseWriter, headers http.Header) {
@@ -19,17 +27,30 @@ func AddCustomHeaders(w http.ResponseWriter, headers http.Header) {
// ParseHeaderString parses a string of key values into a map
func ParseHeaderString(customHeaders []string) (http.Header, error) {
- headers := http.Header{}
- for _, keyValueString := range customHeaders {
- keyValue := strings.SplitN(keyValueString, ":", 2)
- if len(keyValue) != 2 {
- return nil, errInvalidHeaderParameter
+ headers := make(http.Header, len(customHeaders))
+
+ var result *multierror.Error
+ for _, h := range customHeaders {
+ h = h + "\n\n"
+ tp := textproto.NewReader(bufio.NewReader(strings.NewReader(h)))
+
+ mimeHeader, err := tp.ReadMIMEHeader()
+ if err != nil {
+ result = multierror.Append(result, fmt.Errorf("parsing error %s: %w", h, errInvalidHeaderParameter))
}
- key := strings.TrimSpace(keyValue[0])
- value := strings.TrimSpace(keyValue[1])
+ for key, value := range mimeHeader {
+ if _, ok := headers[key]; ok {
+ result = multierror.Append(result, fmt.Errorf("%s already specified with value '%s': %w", key, value, errDuplicateHeader))
+ }
- headers[key] = append(headers[key], value)
+ headers[key] = value
+ }
+ }
+
+ if result.ErrorOrNil() != nil {
+ return nil, result
}
+
return headers, nil
}
diff --git a/internal/customheaders/customheaders_test.go b/internal/customheaders/customheaders_test.go
index a667f43a..857c45e0 100644
--- a/internal/customheaders/customheaders_test.go
+++ b/internal/customheaders/customheaders_test.go
@@ -23,18 +23,6 @@ func TestParseHeaderString(t *testing.T) {
expectedLen: 1,
},
{
- name: "Whitespace trim case",
- headerStrings: []string{" X-Test-String : Test "},
- valid: true,
- expectedLen: 1,
- },
- {
- name: "Whitespace in key, value case",
- headerStrings: []string{"My amazing header: This is a test"},
- valid: true,
- expectedLen: 1,
- },
- {
name: "Non-tracking header case",
headerStrings: []string{"Tk: N"},
valid: true,
@@ -63,6 +51,11 @@ func TestParseHeaderString(t *testing.T) {
valid: false,
},
{
+ name: "duplicate headers",
+ headerStrings: []string{"Tk: N", "Tk: M"},
+ valid: false,
+ },
+ {
name: "Not valid case",
headerStrings: []string{"X-Test-String Some-Test"},
valid: false,
@@ -99,22 +92,13 @@ func TestAddCustomHeaders(t *testing.T) {
name string
headerStrings []string
wantHeaders map[string]string
- }{{
- name: "Normal case",
- headerStrings: []string{"X-Test-String: Test"},
- wantHeaders: map[string]string{"X-Test-String": "Test"},
- },
+ }{
{
- name: "Whitespace trim case",
- headerStrings: []string{" X-Test-String : Test "},
+ name: "Normal case",
+ headerStrings: []string{"X-Test-String: Test"},
wantHeaders: map[string]string{"X-Test-String": "Test"},
},
{
- name: "Whitespace in key, value case",
- headerStrings: []string{"My amazing header: This is a test"},
- wantHeaders: map[string]string{"My amazing header": "This is a test"},
- },
- {
name: "Non-tracking header case",
headerStrings: []string{"Tk: N"},
wantHeaders: map[string]string{"Tk": "N"},
@@ -122,12 +106,12 @@ func TestAddCustomHeaders(t *testing.T) {
{
name: "Content security header case",
headerStrings: []string{"content-security-policy: default-src 'self'"},
- wantHeaders: map[string]string{"content-security-policy": "default-src 'self'"},
+ wantHeaders: map[string]string{"Content-Security-Policy": "default-src 'self'"},
},
{
name: "Multiple header strings",
- headerStrings: []string{"content-security-policy: default-src 'self'", "X-Test-String: Test", "My amazing header : Amazing"},
- wantHeaders: map[string]string{"content-security-policy": "default-src 'self'", "X-Test-String": "Test", "My amazing header": "Amazing"},
+ headerStrings: []string{"content-security-policy: default-src 'self'", "X-Test-String: Test", "My amazing header: Amazing"},
+ wantHeaders: map[string]string{"Content-Security-Policy": "default-src 'self'", "X-Test-String": "Test", "My amazing header": "Amazing"},
},
}
@@ -139,7 +123,11 @@ func TestAddCustomHeaders(t *testing.T) {
customheaders.AddCustomHeaders(w, headers)
rsp := w.Result()
for k, v := range tt.wantHeaders {
- require.Equal(t, v, rsp.Header.Get(k), "Expected header %+v, got %+v", v, rsp.Header.Get(k))
+ require.Len(t, rsp.Header[k], 1)
+
+ // use the map directly to make sure ParseHeaderString is adding the canonical keys
+ got := rsp.Header[k][0]
+ require.Equal(t, v, got, "Expected header %+v, got %+v", v, got)
}
})
}
diff --git a/internal/domain/domain_test.go b/internal/domain/domain_test.go
index 408ddf38..3496a6da 100644
--- a/internal/domain/domain_test.go
+++ b/internal/domain/domain_test.go
@@ -9,8 +9,8 @@ import (
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-pages/internal/domain"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/domain/mock"
"gitlab.com/gitlab-org/gitlab-pages/internal/fixture"
- "gitlab.com/gitlab-org/gitlab-pages/internal/mocks"
"gitlab.com/gitlab-org/gitlab-pages/internal/serving"
"gitlab.com/gitlab-org/gitlab-pages/internal/serving/disk/local"
"gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers"
@@ -208,7 +208,7 @@ func TestServeNamespaceNotFound(t *testing.T) {
func mockResolver(t *testing.T, project *serving.LookupPath, subpath string, err error) domain.Resolver {
mockCtrl := gomock.NewController(t)
- mockResolver := mocks.NewMockResolver(mockCtrl)
+ mockResolver := mock.NewMockResolver(mockCtrl)
mockResolver.EXPECT().
Resolve(gomock.Any()).
diff --git a/internal/mocks/resolver.go b/internal/domain/mock/resolver_mock.go
index 14b86a9c..8231b2b4 100644
--- a/internal/mocks/resolver.go
+++ b/internal/domain/mock/resolver_mock.go
@@ -1,8 +1,8 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: internal/domain/resolver.go
-// Package mocks is a generated GoMock package.
-package mocks
+// Package mock is a generated GoMock package.
+package mock
import (
http "net/http"
diff --git a/internal/feature/feature.go b/internal/feature/feature.go
index 81eef9a0..c98fe85d 100644
--- a/internal/feature/feature.go
+++ b/internal/feature/feature.go
@@ -8,17 +8,29 @@ type Feature struct {
}
// EnforceIPRateLimits enforces IP rate limiter to drop requests
-// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629
+// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/706
var EnforceIPRateLimits = Feature{
EnvVariable: "FF_ENFORCE_IP_RATE_LIMITS",
}
// EnforceDomainRateLimits enforces domain rate limiter to drop requests
-// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/655
+// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/706
var EnforceDomainRateLimits = Feature{
EnvVariable: "FF_ENFORCE_DOMAIN_RATE_LIMITS",
}
+// EnforceDomainTLSRateLimits enforces domain rate limits on establishing new TLS connections
+// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/706
+var EnforceDomainTLSRateLimits = Feature{
+ EnvVariable: "FF_ENFORCE_DOMAIN_TLS_RATE_LIMITS",
+}
+
+// EnforceIPTLSRateLimits enforces domain rate limits on establishing new TLS connections
+// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/706
+var EnforceIPTLSRateLimits = Feature{
+ EnvVariable: "FF_ENFORCE_IP_TLS_RATE_LIMITS",
+}
+
// RedirectsPlaceholders enables support for placeholders in redirects file
// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/620
var RedirectsPlaceholders = Feature{
diff --git a/internal/handlers/handlers_test.go b/internal/handlers/handlers_test.go
index dd940251..22027ad0 100644
--- a/internal/handlers/handlers_test.go
+++ b/internal/handlers/handlers_test.go
@@ -10,19 +10,19 @@ import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
- "gitlab.com/gitlab-org/gitlab-pages/internal/mocks"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/handlers/mock"
)
func TestNotHandleArtifactRequestReturnsFalse(t *testing.T) {
mockCtrl := gomock.NewController(t)
- mockArtifact := mocks.NewMockArtifact(mockCtrl)
+ mockArtifact := mock.NewMockArtifact(mockCtrl)
mockArtifact.EXPECT().
TryMakeRequest(gomock.Any(), gomock.Any(), gomock.Any(), "", gomock.Any()).
Return(false).
Times(1)
- mockAuth := mocks.NewMockAuth(mockCtrl)
+ mockAuth := mock.NewMockAuth(mockCtrl)
mockAuth.EXPECT().
GetTokenIfExists(gomock.Any(), gomock.Any()).
Return("", nil).
@@ -41,13 +41,13 @@ func TestNotHandleArtifactRequestReturnsFalse(t *testing.T) {
func TestHandleArtifactRequestedReturnsTrue(t *testing.T) {
mockCtrl := gomock.NewController(t)
- mockArtifact := mocks.NewMockArtifact(mockCtrl)
+ mockArtifact := mock.NewMockArtifact(mockCtrl)
mockArtifact.EXPECT().
TryMakeRequest(gomock.Any(), gomock.Any(), gomock.Any(), "", gomock.Any()).
Return(true).
Times(1)
- mockAuth := mocks.NewMockAuth(mockCtrl)
+ mockAuth := mock.NewMockAuth(mockCtrl)
mockAuth.EXPECT().
GetTokenIfExists(gomock.Any(), gomock.Any()).
Return("", nil).
@@ -64,7 +64,7 @@ func TestHandleArtifactRequestedReturnsTrue(t *testing.T) {
func TestNotFoundWithTokenIsNotHandled(t *testing.T) {
mockCtrl := gomock.NewController(t)
- mockAuth := mocks.NewMockAuth(mockCtrl)
+ mockAuth := mock.NewMockAuth(mockCtrl)
mockAuth.EXPECT().CheckResponseForInvalidToken(gomock.Any(), gomock.Any(), gomock.Any()).
Return(false)
@@ -101,7 +101,7 @@ func TestForbiddenWithTokenIsNotHandled(t *testing.T) {
t.Run(tn, func(t *testing.T) {
mockCtrl := gomock.NewController(t)
- mockAuth := mocks.NewMockAuth(mockCtrl)
+ mockAuth := mock.NewMockAuth(mockCtrl)
if tc.Token == "" {
mockAuth.EXPECT().IsAuthSupported().Return(true)
mockAuth.EXPECT().RequireAuth(gomock.Any(), gomock.Any()).Return(true)
@@ -125,7 +125,7 @@ func TestForbiddenWithTokenIsNotHandled(t *testing.T) {
func TestNotFoundWithoutTokenIsNotHandledWhenNotAuthSupport(t *testing.T) {
mockCtrl := gomock.NewController(t)
- mockAuth := mocks.NewMockAuth(mockCtrl)
+ mockAuth := mock.NewMockAuth(mockCtrl)
mockAuth.EXPECT().IsAuthSupported().Return(false)
handlers := New(mockAuth, nil)
@@ -140,7 +140,7 @@ func TestNotFoundWithoutTokenIsNotHandledWhenNotAuthSupport(t *testing.T) {
func TestNotFoundWithoutTokenIsHandled(t *testing.T) {
mockCtrl := gomock.NewController(t)
- mockAuth := mocks.NewMockAuth(mockCtrl)
+ mockAuth := mock.NewMockAuth(mockCtrl)
mockAuth.EXPECT().IsAuthSupported().Return(true)
mockAuth.EXPECT().RequireAuth(gomock.Any(), gomock.Any()).Times(1).Return(true)
@@ -156,7 +156,7 @@ func TestNotFoundWithoutTokenIsHandled(t *testing.T) {
func TestInvalidTokenResponseIsHandled(t *testing.T) {
mockCtrl := gomock.NewController(t)
- mockAuth := mocks.NewMockAuth(mockCtrl)
+ mockAuth := mock.NewMockAuth(mockCtrl)
mockAuth.EXPECT().CheckResponseForInvalidToken(gomock.Any(), gomock.Any(), gomock.Any()).
Return(true)
@@ -173,12 +173,12 @@ func TestInvalidTokenResponseIsHandled(t *testing.T) {
func TestHandleArtifactRequestButGetTokenFails(t *testing.T) {
mockCtrl := gomock.NewController(t)
- mockArtifact := mocks.NewMockArtifact(mockCtrl)
+ mockArtifact := mock.NewMockArtifact(mockCtrl)
mockArtifact.EXPECT().
TryMakeRequest(gomock.Any(), gomock.Any(), gomock.Any(), "", gomock.Any()).
Times(0)
- mockAuth := mocks.NewMockAuth(mockCtrl)
+ mockAuth := mock.NewMockAuth(mockCtrl)
mockAuth.EXPECT().GetTokenIfExists(gomock.Any(), gomock.Any()).Return("", errors.New("error when retrieving token"))
handlers := New(mockAuth, mockArtifact)
diff --git a/internal/mocks/mocks.go b/internal/handlers/mock/handler_mock.go
index b18bede4..11548221 100644
--- a/internal/mocks/mocks.go
+++ b/internal/handlers/mock/handler_mock.go
@@ -1,8 +1,8 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: internal/interface.go
-// Package mocks is a generated GoMock package.
-package mocks
+// Package mock is a generated GoMock package.
+package mock
import (
http "net/http"
diff --git a/internal/handlers/ratelimiter.go b/internal/handlers/ratelimiter.go
index 52281f6e..a8eee005 100644
--- a/internal/handlers/ratelimiter.go
+++ b/internal/handlers/ratelimiter.go
@@ -14,11 +14,11 @@ import (
// TODO: make this unexported once https://gitlab.com/gitlab-org/gitlab-pages/-/issues/670 is done
func Ratelimiter(handler http.Handler, config *config.RateLimit) http.Handler {
sourceIPLimiter := ratelimiter.New(
- "source_ip",
+ "http_requests_by_source_ip",
ratelimiter.WithCacheMaxSize(ratelimiter.DefaultSourceIPCacheSize),
- ratelimiter.WithCachedEntriesMetric(metrics.RateLimitSourceIPCachedEntries),
- ratelimiter.WithCachedRequestsMetric(metrics.RateLimitSourceIPCacheRequests),
- ratelimiter.WithBlockedCountMetric(metrics.RateLimitSourceIPBlockedCount),
+ ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries),
+ ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests),
+ ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount),
ratelimiter.WithLimitPerSecond(config.SourceIPLimitPerSecond),
ratelimiter.WithBurstSize(config.SourceIPBurst),
ratelimiter.WithEnforce(feature.EnforceIPRateLimits.Enabled()),
@@ -27,12 +27,12 @@ func Ratelimiter(handler http.Handler, config *config.RateLimit) http.Handler {
handler = sourceIPLimiter.Middleware(handler)
domainLimiter := ratelimiter.New(
- "domain",
+ "http_requests_by_domain",
ratelimiter.WithCacheMaxSize(ratelimiter.DefaultDomainCacheSize),
ratelimiter.WithKeyFunc(request.GetHostWithoutPort),
- ratelimiter.WithCachedEntriesMetric(metrics.RateLimitDomainCachedEntries),
- ratelimiter.WithCachedRequestsMetric(metrics.RateLimitDomainCacheRequests),
- ratelimiter.WithBlockedCountMetric(metrics.RateLimitDomainBlockedCount),
+ ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries),
+ ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests),
+ ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount),
ratelimiter.WithLimitPerSecond(config.DomainLimitPerSecond),
ratelimiter.WithBurstSize(config.DomainBurst),
ratelimiter.WithEnforce(feature.EnforceDomainRateLimits.Enabled()),
diff --git a/internal/httpfs/http_fs_test.go b/internal/httpfs/http_fs_test.go
index 01101984..10a1ef7b 100644
--- a/internal/httpfs/http_fs_test.go
+++ b/internal/httpfs/http_fs_test.go
@@ -12,6 +12,7 @@ import (
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-pages/internal/httptransport"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers"
)
func TestFSOpen(t *testing.T) {
@@ -161,7 +162,7 @@ func TestFileSystemPathCanServeHTTP(t *testing.T) {
res, err := client.Do(req)
require.NoError(t, err)
- defer res.Body.Close()
+ testhelpers.Close(t, res.Body)
require.Equal(t, test.expectedStatusCode, res.StatusCode)
content, err := io.ReadAll(res.Body)
diff --git a/internal/httptransport/metered_round_tripper_test.go b/internal/httptransport/metered_round_tripper_test.go
index 2b126760..2d05a1c3 100644
--- a/internal/httptransport/metered_round_tripper_test.go
+++ b/internal/httptransport/metered_round_tripper_test.go
@@ -10,6 +10,8 @@ import (
"github.com/prometheus/client_golang/prometheus/testutil"
"github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers"
)
func TestReconfigureMeteredRoundTripper(t *testing.T) {
@@ -23,7 +25,7 @@ func TestReconfigureMeteredRoundTripper(t *testing.T) {
res, err := mrt.RoundTrip(r)
require.NoError(t, err)
- defer res.Body.Close()
+ testhelpers.Close(t, res.Body)
require.Equal(t, http.StatusOK, res.StatusCode)
body, err := io.ReadAll(res.Body)
diff --git a/internal/ratelimiter/middleware.go b/internal/ratelimiter/middleware.go
index af7b0881..1be6f642 100644
--- a/internal/ratelimiter/middleware.go
+++ b/internal/ratelimiter/middleware.go
@@ -8,7 +8,6 @@ import (
"gitlab.com/gitlab-org/labkit/correlation"
"gitlab.com/gitlab-org/labkit/log"
- "gitlab.com/gitlab-org/gitlab-pages/internal/feature"
"gitlab.com/gitlab-org/gitlab-pages/internal/httperrors"
"gitlab.com/gitlab-org/gitlab-pages/internal/request"
)
@@ -34,7 +33,7 @@ func (rl *RateLimiter) Middleware(handler http.Handler) http.Handler {
rl.logRateLimitedRequest(r)
if rl.blockedCount != nil {
- rl.blockedCount.WithLabelValues(strconv.FormatBool(feature.EnforceIPRateLimits.Enabled())).Inc()
+ rl.blockedCount.WithLabelValues(rl.name, strconv.FormatBool(rl.enforce)).Inc()
}
if rl.enforce {
@@ -59,7 +58,7 @@ func (rl *RateLimiter) logRateLimitedRequest(r *http.Request) {
"x_forwarded_proto": r.Header.Get(headerXForwardedProto),
"x_forwarded_for": r.Header.Get(headerXForwardedFor),
"gitlab_real_ip": r.Header.Get(headerGitLabRealIP),
- "rate_limiter_enabled": feature.EnforceIPRateLimits.Enabled(),
+ "enforced": rl.enforce,
"rate_limiter_limit_per_second": rl.limitPerSecond,
"rate_limiter_burst_size": rl.burstSize,
}). // TODO: change to Debug with https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629
diff --git a/internal/ratelimiter/middleware_test.go b/internal/ratelimiter/middleware_test.go
index 1f753fc4..25ac08b5 100644
--- a/internal/ratelimiter/middleware_test.go
+++ b/internal/ratelimiter/middleware_test.go
@@ -107,7 +107,7 @@ func TestMiddlewareDenyRequestsAfterBurst(t *testing.T) {
assertSourceIPLog(t, remoteAddr, hook)
}
- blockedCount := testutil.ToFloat64(blocked.WithLabelValues(strconv.FormatBool(tc.enforce)))
+ blockedCount := testutil.ToFloat64(blocked.WithLabelValues("rate_limiter", strconv.FormatBool(tc.enforce)))
require.Equal(t, float64(4), blockedCount, "blocked count")
blocked.Reset()
@@ -226,7 +226,7 @@ func newTestMetrics(t *testing.T) (*prometheus.GaugeVec, *prometheus.GaugeVec, *
prometheus.GaugeOpts{
Name: t.Name(),
},
- []string{"enforced"},
+ []string{"limit_name", "enforced"},
)
cachedEntries := prometheus.NewGaugeVec(prometheus.GaugeOpts{
diff --git a/internal/ratelimiter/ratelimiter.go b/internal/ratelimiter/ratelimiter.go
index feeb8cb4..b64cb4ce 100644
--- a/internal/ratelimiter/ratelimiter.go
+++ b/internal/ratelimiter/ratelimiter.go
@@ -1,6 +1,8 @@
package ratelimiter
import (
+ "crypto/tls"
+ "net"
"net/http"
"time"
@@ -27,6 +29,9 @@ type Option func(*RateLimiter)
// KeyFunc returns unique identifier for the subject of rate limit(e.g. client IP or domain)
type KeyFunc func(*http.Request) string
+// TLSKeyFunc is used by GetCertificateMiddleware to identify the subject of rate limit (client IP or SNI servername)
+type TLSKeyFunc func(*tls.ClientHelloInfo) string
+
// RateLimiter holds an LRU cache of elements to be rate limited.
// It uses "golang.org/x/time/rate" as its Token Bucket rate limiter per source IP entry.
// See example https://www.fatalerrors.org/a/design-and-implementation-of-time-rate-limiter-for-golang-standard-library.html
@@ -35,6 +40,7 @@ type RateLimiter struct {
name string
now func() time.Time
keyFunc KeyFunc
+ tlsKeyFunc TLSKeyFunc
limitPerSecond float64
burstSize int
blockedCount *prometheus.GaugeVec
@@ -120,6 +126,26 @@ func WithKeyFunc(f KeyFunc) Option {
}
}
+func TLSHostnameKey(info *tls.ClientHelloInfo) string {
+ return info.ServerName
+}
+
+func TLSClientIPKey(info *tls.ClientHelloInfo) string {
+ remoteAddr := info.Conn.RemoteAddr().String()
+ remoteAddr, _, err := net.SplitHostPort(remoteAddr)
+ if err != nil {
+ return remoteAddr
+ }
+
+ return remoteAddr
+}
+
+func WithTLSKeyFunc(keyFunc TLSKeyFunc) Option {
+ return func(rl *RateLimiter) {
+ rl.tlsKeyFunc = keyFunc
+ }
+}
+
// WithEnforce configures if requests are actually rejected, or we just report them as rejected in metrics
func WithEnforce(enforce bool) Option {
return func(rl *RateLimiter) {
@@ -138,6 +164,11 @@ func (rl *RateLimiter) limiter(key string) *rate.Limiter {
// requestAllowed checks if request is within the rate-limit
func (rl *RateLimiter) requestAllowed(r *http.Request) bool {
rateLimitedKey := rl.keyFunc(r)
+
+ return rl.allowed(rateLimitedKey)
+}
+
+func (rl *RateLimiter) allowed(rateLimitedKey string) bool {
limiter := rl.limiter(rateLimitedKey)
// AllowN allows us to use the rl.now function, so we can test this more easily.
diff --git a/internal/ratelimiter/stub_conn_test.go b/internal/ratelimiter/stub_conn_test.go
new file mode 100644
index 00000000..9d7a4a9a
--- /dev/null
+++ b/internal/ratelimiter/stub_conn_test.go
@@ -0,0 +1,14 @@
+package ratelimiter
+
+import (
+ "net"
+)
+
+type stubConn struct {
+ net.Conn
+ remoteAddr net.Addr
+}
+
+func (s stubConn) RemoteAddr() net.Addr {
+ return s.remoteAddr
+}
diff --git a/internal/ratelimiter/tls.go b/internal/ratelimiter/tls.go
new file mode 100644
index 00000000..3bebbc38
--- /dev/null
+++ b/internal/ratelimiter/tls.go
@@ -0,0 +1,49 @@
+package ratelimiter
+
+import (
+ "crypto/tls"
+ "errors"
+ "strconv"
+
+ "github.com/sirupsen/logrus"
+ "gitlab.com/gitlab-org/labkit/log"
+)
+
+var ErrTLSRateLimited = errors.New("too many connections, please retry later")
+
+type GetCertificateFunc func(*tls.ClientHelloInfo) (*tls.Certificate, error)
+
+func (rl *RateLimiter) GetCertificateMiddleware(getCertificate GetCertificateFunc) GetCertificateFunc {
+ if rl.limitPerSecond <= 0.0 {
+ return getCertificate
+ }
+
+ return func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
+ if rl.allowed(rl.tlsKeyFunc(info)) {
+ return getCertificate(info)
+ }
+
+ rl.logRateLimitedTLS(info)
+
+ if rl.blockedCount != nil {
+ rl.blockedCount.WithLabelValues(rl.name, strconv.FormatBool(rl.enforce)).Inc()
+ }
+
+ if !rl.enforce {
+ return getCertificate(info)
+ }
+
+ return nil, ErrTLSRateLimited
+ }
+}
+
+func (rl *RateLimiter) logRateLimitedTLS(info *tls.ClientHelloInfo) {
+ log.WithFields(logrus.Fields{
+ "rate_limiter_name": rl.name,
+ "source_ip": TLSClientIPKey(info),
+ "req_host": info.ServerName,
+ "rate_limiter_limit_per_second": rl.limitPerSecond,
+ "rate_limiter_burst_size": rl.burstSize,
+ "enforced": rl.enforce,
+ }).Info("TLS connection rate-limited")
+}
diff --git a/internal/ratelimiter/tls_test.go b/internal/ratelimiter/tls_test.go
new file mode 100644
index 00000000..6763514b
--- /dev/null
+++ b/internal/ratelimiter/tls_test.go
@@ -0,0 +1,194 @@
+package ratelimiter
+
+import (
+ "crypto/tls"
+ "errors"
+ "net"
+ "strconv"
+ "testing"
+ "time"
+
+ "github.com/prometheus/client_golang/prometheus/testutil"
+ "github.com/sirupsen/logrus"
+ testlog "github.com/sirupsen/logrus/hooks/test"
+ "github.com/stretchr/testify/require"
+)
+
+func TestTLSHostnameKey(t *testing.T) {
+ info := &tls.ClientHelloInfo{ServerName: "group.gitlab.io"}
+
+ require.Equal(t, "group.gitlab.io", TLSHostnameKey(info))
+}
+
+func TestTLSClientIPKey(t *testing.T) {
+ tests := []struct {
+ addr string
+ expected string
+ }{
+ {
+ "10.1.2.3:1234",
+ "10.1.2.3",
+ },
+ {
+ "[2001:db8:3333:4444:5555:6666:7777:8888]:1234",
+ "2001:db8:3333:4444:5555:6666:7777:8888",
+ },
+ }
+
+ for _, tt := range tests {
+ addr, err := net.ResolveTCPAddr("tcp", tt.addr)
+ require.NoError(t, err)
+ conn := stubConn{remoteAddr: addr}
+ info := &tls.ClientHelloInfo{Conn: conn}
+
+ require.Equal(t, tt.expected, TLSClientIPKey(info))
+ }
+}
+
+func TestGetCertificateMiddleware(t *testing.T) {
+ tests := map[string]struct {
+ useHostnameAsKey bool
+ enforced bool
+ limitPerSecond float64
+ burst int
+ successfulReqCnt int
+ }{
+ "ip_limiter": {
+ useHostnameAsKey: false,
+ enforced: true,
+ limitPerSecond: 0.1,
+ burst: 5,
+ successfulReqCnt: 5,
+ },
+ "hostname_limiter": {
+ useHostnameAsKey: true,
+ enforced: true,
+ limitPerSecond: 0.1,
+ burst: 5,
+ successfulReqCnt: 5,
+ },
+ "not_enforced": {
+ useHostnameAsKey: false,
+ enforced: false,
+ limitPerSecond: 0.1,
+ burst: 5,
+ successfulReqCnt: 5,
+ },
+ "disabled": {
+ useHostnameAsKey: false,
+ enforced: true,
+ limitPerSecond: 0,
+ burst: 5,
+ successfulReqCnt: 10,
+ },
+ "slowly_approach_limit": {
+ useHostnameAsKey: false,
+ enforced: true,
+ limitPerSecond: 0.2,
+ burst: 5,
+ successfulReqCnt: 6, // 5 * 0.2 gives another 1 request
+ },
+ }
+
+ expectedCert := &tls.Certificate{}
+ expectedErr := errors.New("expected error")
+
+ getCertificate := func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
+ return expectedCert, expectedErr
+ }
+
+ for name, tt := range tests {
+ t.Run(name, func(t *testing.T) {
+ hook := testlog.NewGlobal()
+ blocked, cachedEntries, cacheReqs := newTestMetrics(t)
+
+ keyFunc := TLSClientIPKey
+ if tt.useHostnameAsKey {
+ keyFunc = TLSHostnameKey
+ }
+
+ rl := New("limit_name",
+ WithCachedEntriesMetric(cachedEntries),
+ WithCachedRequestsMetric(cacheReqs),
+ WithBlockedCountMetric(blocked),
+ WithNow(stubNow()),
+ WithLimitPerSecond(tt.limitPerSecond),
+ WithBurstSize(tt.burst),
+ WithEnforce(tt.enforced),
+ WithTLSKeyFunc(keyFunc))
+
+ middlewareGetCert := rl.GetCertificateMiddleware(getCertificate)
+
+ addr, err := net.ResolveTCPAddr("tcp", "10.1.2.3:12345")
+ require.NoError(t, err)
+ conn := stubConn{remoteAddr: addr}
+ info := &tls.ClientHelloInfo{Conn: conn, ServerName: "group.gitlab.io"}
+
+ for i := 0; i < tt.successfulReqCnt; i++ {
+ cert, err := middlewareGetCert(info)
+ require.Equal(t, expectedCert, cert)
+ require.Equal(t, expectedErr, err)
+ }
+
+ // When rate-limiter disabled altogether
+ if tt.limitPerSecond <= 0 {
+ return
+ }
+
+ cert, err := middlewareGetCert(info)
+ if tt.enforced {
+ require.Nil(t, cert)
+ require.Equal(t, err, ErrTLSRateLimited)
+ } else {
+ require.Equal(t, expectedCert, cert)
+ require.Equal(t, expectedErr, err)
+ }
+
+ require.NotNil(t, hook.LastEntry())
+ require.Equal(t, "TLS connection rate-limited", hook.LastEntry().Message)
+ expectedFields := logrus.Fields{
+ "rate_limiter_name": "limit_name",
+ "source_ip": "10.1.2.3",
+ "req_host": "group.gitlab.io",
+ "rate_limiter_limit_per_second": tt.limitPerSecond,
+ "rate_limiter_burst_size": tt.burst,
+ "enforced": tt.enforced,
+ }
+ require.Equal(t, expectedFields, hook.LastEntry().Data)
+
+ // make another request with different key and expect success
+ if tt.useHostnameAsKey {
+ info.ServerName = "another-group.gitlab.io"
+ } else {
+ addr, err := net.ResolveTCPAddr("tcp", "10.10.20.30:12345")
+ require.NoError(t, err)
+ conn = stubConn{remoteAddr: addr}
+ info.Conn = conn
+ }
+
+ cert, err = middlewareGetCert(info)
+ require.Equal(t, expectedCert, cert)
+ require.Equal(t, expectedErr, err)
+
+ blockedCount := testutil.ToFloat64(blocked.WithLabelValues("limit_name", strconv.FormatBool(tt.enforced)))
+ require.Equal(t, float64(1), blockedCount, "blocked count")
+
+ cachedCount := testutil.ToFloat64(cachedEntries.WithLabelValues("limit_name"))
+ require.Equal(t, float64(2), cachedCount, "cached count") // 1 for first key + 1 for different one
+
+ cacheReqMiss := testutil.ToFloat64(cacheReqs.WithLabelValues("limit_name", "miss"))
+ require.Equal(t, float64(2), cacheReqMiss, "miss count") // 1 for first key + 1 for different one
+ cacheReqHit := testutil.ToFloat64(cacheReqs.WithLabelValues("limit_name", "hit"))
+ require.Equal(t, float64(tt.successfulReqCnt), cacheReqHit, "hit count")
+ })
+ }
+}
+
+func stubNow() func() time.Time {
+ now := time.Now()
+ return func() time.Time {
+ now = now.Add(time.Second)
+
+ return now
+ }
+}
diff --git a/internal/serving/disk/local/serving_test.go b/internal/serving/disk/local/serving_test.go
index 8352bfc9..dbb14138 100644
--- a/internal/serving/disk/local/serving_test.go
+++ b/internal/serving/disk/local/serving_test.go
@@ -79,7 +79,7 @@ func TestDisk_ServeFileHTTP(t *testing.T) {
require.True(t, s.ServeFileHTTP(handler))
resp := w.Result()
- defer resp.Body.Close()
+ testhelpers.Close(t, resp.Body)
require.Equal(t, test.expectedStatus, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
diff --git a/internal/serving/disk/zip/serving_test.go b/internal/serving/disk/zip/serving_test.go
index 2b603b5d..c3b18f2b 100644
--- a/internal/serving/disk/zip/serving_test.go
+++ b/internal/serving/disk/zip/serving_test.go
@@ -209,7 +209,7 @@ func TestZip_ServeFileHTTP(t *testing.T) {
require.True(t, s.ServeFileHTTP(handler))
resp := w.Result()
- defer resp.Body.Close()
+ testhelpers.Close(t, resp.Body)
require.Equal(t, test.expectedStatus, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
diff --git a/internal/source/gitlab/cache/retriever.go b/internal/source/gitlab/cache/retriever.go
index 65c7f1a2..68299f6c 100644
--- a/internal/source/gitlab/cache/retriever.go
+++ b/internal/source/gitlab/cache/retriever.go
@@ -78,6 +78,11 @@ func (r *Retriever) resolveWithBackoff(ctx context.Context, domainName string) <
break
}
+ if errors.Is(lookup.Error, context.Canceled) || errors.Is(lookup.Error, context.DeadlineExceeded) {
+ // do not retry if there's a context error to avoid leaking the goroutine
+ break
+ }
+
time.Sleep(r.maxRetrievalInterval)
}
diff --git a/internal/source/gitlab/gitlab_test.go b/internal/source/gitlab/gitlab_test.go
index d7fbf454..e7f80387 100644
--- a/internal/source/gitlab/gitlab_test.go
+++ b/internal/source/gitlab/gitlab_test.go
@@ -12,9 +12,9 @@ import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
- "gitlab.com/gitlab-org/gitlab-pages/internal/mocks"
"gitlab.com/gitlab-org/gitlab-pages/internal/source/gitlab/api"
"gitlab.com/gitlab-org/gitlab-pages/internal/source/gitlab/client"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/source/gitlab/mock"
)
func TestGetDomain(t *testing.T) {
@@ -171,10 +171,10 @@ func TestResolveLookupPathsOrderDoesNotMatter(t *testing.T) {
}
}
-func NewMockClient(t *testing.T, file string, mockedLookup *api.Lookup) *mocks.MockClientStub {
+func NewMockClient(t *testing.T, file string, mockedLookup *api.Lookup) *mock.MockClientStub {
mockCtrl := gomock.NewController(t)
- mockClient := mocks.NewMockClientStub(mockCtrl)
+ mockClient := mock.NewMockClientStub(mockCtrl)
mockClient.EXPECT().
Resolve(gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, domain string) *api.Lookup {
diff --git a/internal/mocks/client.go b/internal/source/gitlab/mock/client_mock.go
index 22b0757e..6f78b01b 100644
--- a/internal/mocks/client.go
+++ b/internal/source/gitlab/mock/client_mock.go
@@ -1,8 +1,8 @@
// Code generated by MockGen. DO NOT EDIT.
-// Source: internal/mocks/api/client_stub.go
+// Source: internal/source/gitlab/mock/client_stub.go
-// Package mocks is a generated GoMock package.
-package mocks
+// Package mock is a generated GoMock package.
+package mock
import (
context "context"
diff --git a/internal/mocks/api/client_stub.go b/internal/source/gitlab/mock/client_stub.go
index f1a20754..5b201da2 100644
--- a/internal/mocks/api/client_stub.go
+++ b/internal/source/gitlab/mock/client_stub.go
@@ -1,4 +1,4 @@
-package mocks
+package mock
import "gitlab.com/gitlab-org/gitlab-pages/internal/source/gitlab/api"
diff --git a/internal/mocks/source.go b/internal/source/mock/source_mock.go
index 5412310f..6ec884f1 100644
--- a/internal/mocks/source.go
+++ b/internal/source/mock/source_mock.go
@@ -1,8 +1,8 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: internal/source/source.go
-// Package mocks is a generated GoMock package.
-package mocks
+// Package mock is a generated GoMock package.
+package mock
import (
context "context"
diff --git a/internal/testhelpers/testhelpers.go b/internal/testhelpers/testhelpers.go
index 11bf7e65..2f708743 100644
--- a/internal/testhelpers/testhelpers.go
+++ b/internal/testhelpers/testhelpers.go
@@ -95,3 +95,13 @@ func PerformRequest(t *testing.T, handler http.Handler, r *http.Request) (int, s
return res.StatusCode, string(b)
}
+
+// Close will call the close function on a closer as part
+// of the t.Cleanup function.
+func Close(t *testing.T, c io.Closer) {
+ t.Helper()
+
+ t.Cleanup(func() {
+ require.NoError(t, c.Close())
+ })
+}
diff --git a/internal/tls/tls.go b/internal/tls/tls.go
new file mode 100644
index 00000000..6d7397af
--- /dev/null
+++ b/internal/tls/tls.go
@@ -0,0 +1,96 @@
+package tls
+
+import (
+ "crypto/tls"
+
+ "gitlab.com/gitlab-org/gitlab-pages/internal/config"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/feature"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/ratelimiter"
+ "gitlab.com/gitlab-org/gitlab-pages/metrics"
+)
+
+var preferredCipherSuites = []uint16{
+ tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
+ tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
+ tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
+ tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
+ tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
+ tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
+ tls.TLS_AES_128_GCM_SHA256,
+ tls.TLS_AES_256_GCM_SHA384,
+ tls.TLS_CHACHA20_POLY1305_SHA256,
+}
+
+// GetCertificateFunc returns the certificate to be used for given domain
+type GetCertificateFunc func(*tls.ClientHelloInfo) (*tls.Certificate, error)
+
+// GetTLSConfig initializes tls.Config based on config flags
+// getCertificateByServerName obtains certificate based on domain
+func GetTLSConfig(cfg *config.Config, getCertificateByServerName GetCertificateFunc) (*tls.Config, error) {
+ wildcardCertificate, err := tls.X509KeyPair(cfg.General.RootCertificate, cfg.General.RootKey)
+ if err != nil {
+ return nil, err
+ }
+
+ getCertificate := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
+ // Golang calls tls.Config.GetCertificate only if it's set and
+ // 1. ServerName != ""
+ // 2. Or tls.Config.Certificates is empty array
+ // tls.Config.Certificates contain wildcard certificate
+ // We want to implement rate limits via GetCertificate, so we need to call it every time
+ // So we don't set tls.Config.Certificates, but simulate the behavior of golang:
+ // 1. try to get certificate by name
+ // 2. if we can't, fallback to default(wildcard) certificate
+ customCertificate, err := getCertificateByServerName(info)
+
+ if customCertificate != nil || err != nil {
+ return customCertificate, err
+ }
+
+ return &wildcardCertificate, nil
+ }
+
+ TLSDomainRateLimiter := ratelimiter.New(
+ "tls_connections_by_domain",
+ ratelimiter.WithTLSKeyFunc(ratelimiter.TLSHostnameKey),
+ ratelimiter.WithCacheMaxSize(ratelimiter.DefaultDomainCacheSize),
+ ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries),
+ ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests),
+ ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount),
+ ratelimiter.WithLimitPerSecond(cfg.RateLimit.TLSDomainLimitPerSecond),
+ ratelimiter.WithBurstSize(cfg.RateLimit.TLSDomainBurst),
+ ratelimiter.WithEnforce(feature.EnforceDomainTLSRateLimits.Enabled()),
+ )
+
+ TLSSourceIPRateLimiter := ratelimiter.New(
+ "tls_connections_by_source_ip",
+ ratelimiter.WithTLSKeyFunc(ratelimiter.TLSClientIPKey),
+ ratelimiter.WithCacheMaxSize(ratelimiter.DefaultSourceIPCacheSize),
+ ratelimiter.WithCachedEntriesMetric(metrics.RateLimitCachedEntries),
+ ratelimiter.WithCachedRequestsMetric(metrics.RateLimitCacheRequests),
+ ratelimiter.WithBlockedCountMetric(metrics.RateLimitBlockedCount),
+ ratelimiter.WithLimitPerSecond(cfg.RateLimit.TLSSourceIPLimitPerSecond),
+ ratelimiter.WithBurstSize(cfg.RateLimit.TLSSourceIPBurst),
+ ratelimiter.WithEnforce(feature.EnforceIPTLSRateLimits.Enabled()),
+ )
+
+ getCertificate = TLSDomainRateLimiter.GetCertificateMiddleware(getCertificate)
+ getCertificate = TLSSourceIPRateLimiter.GetCertificateMiddleware(getCertificate)
+
+ // set MinVersion to fix gosec: G402
+ tlsConfig := &tls.Config{GetCertificate: getCertificate, MinVersion: tls.VersionTLS12}
+
+ if !cfg.General.InsecureCiphers {
+ configureTLSCiphers(tlsConfig)
+ }
+
+ tlsConfig.MinVersion = cfg.TLS.MinVersion
+ tlsConfig.MaxVersion = cfg.TLS.MaxVersion
+
+ return tlsConfig, nil
+}
+
+func configureTLSCiphers(tlsConfig *tls.Config) {
+ tlsConfig.PreferServerCipherSuites = true
+ tlsConfig.CipherSuites = preferredCipherSuites
+}
diff --git a/internal/config/tls/tls_test.go b/internal/tls/tls_test.go
index 7146d904..e65b2671 100644
--- a/internal/config/tls/tls_test.go
+++ b/internal/tls/tls_test.go
@@ -5,6 +5,8 @@ import (
"testing"
"github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-pages/internal/config"
)
var cert = []byte(`-----BEGIN CERTIFICATE-----
@@ -29,65 +31,46 @@ var getCertificate = func(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {
return nil, nil
}
-func TestInvalidTLSVersions(t *testing.T) {
- tests := map[string]struct {
- tlsMin string
- tlsMax string
- err string
- }{
- "invalid minimum TLS version": {tlsMin: "tls123", tlsMax: "", err: "invalid minimum TLS version: tls123"},
- "invalid maximum TLS version": {tlsMin: "", tlsMax: "tls123", err: "invalid maximum TLS version: tls123"},
- "TLS versions conflict": {tlsMin: "tls1.3", tlsMax: "tls1.2", err: "invalid maximum TLS version: tls1.2; should be at least tls1.3"},
- }
-
- for name, tc := range tests {
- t.Run(name, func(t *testing.T) {
- err := ValidateTLSVersions(tc.tlsMin, tc.tlsMax)
- require.EqualError(t, err, tc.err)
- })
- }
-}
-
-func TestValidTLSVersions(t *testing.T) {
- tests := map[string]struct {
- tlsMin string
- tlsMax string
- }{
- "tls 1.3 only": {tlsMin: "tls1.3", tlsMax: "tls1.3"},
- "tls 1.2 only": {tlsMin: "tls1.2", tlsMax: "tls1.2"},
- "tls 1.3 max": {tlsMax: "tls1.3"},
- "tls 1.2 max": {tlsMax: "tls1.2"},
- "tls 1.3+": {tlsMin: "tls1.3"},
- "tls 1.2+": {tlsMin: "tls1.2"},
- "default": {},
- }
-
- for name, tc := range tests {
- t.Run(name, func(t *testing.T) {
- err := ValidateTLSVersions(tc.tlsMin, tc.tlsMax)
- require.NoError(t, err)
- })
- }
-}
-
func TestInvalidKeyPair(t *testing.T) {
- _, err := Create([]byte(``), []byte(``), getCertificate, false, tls.VersionTLS11, tls.VersionTLS12)
+ cfg := &config.Config{}
+ _, err := GetTLSConfig(cfg, getCertificate)
require.EqualError(t, err, "tls: failed to find any PEM data in certificate input")
}
func TestInsecureCiphers(t *testing.T) {
- tlsConfig, err := Create(cert, key, getCertificate, true, tls.VersionTLS11, tls.VersionTLS12)
+ cfg := &config.Config{
+ General: config.General{
+ RootCertificate: cert,
+ RootKey: key,
+ InsecureCiphers: true,
+ },
+ }
+ tlsConfig, err := GetTLSConfig(cfg, getCertificate)
require.NoError(t, err)
require.False(t, tlsConfig.PreferServerCipherSuites)
require.Empty(t, tlsConfig.CipherSuites)
}
-func TestCreate(t *testing.T) {
- tlsConfig, err := Create(cert, key, getCertificate, false, tls.VersionTLS11, tls.VersionTLS12)
+func TestGetTLSConfig(t *testing.T) {
+ cfg := &config.Config{
+ General: config.General{
+ RootCertificate: cert,
+ RootKey: key,
+ },
+ TLS: config.TLS{
+ MinVersion: tls.VersionTLS11,
+ MaxVersion: tls.VersionTLS12,
+ },
+ }
+ tlsConfig, err := GetTLSConfig(cfg, getCertificate)
require.NoError(t, err)
require.IsType(t, getCertificate, tlsConfig.GetCertificate)
require.True(t, tlsConfig.PreferServerCipherSuites)
require.Equal(t, preferredCipherSuites, tlsConfig.CipherSuites)
require.Equal(t, uint16(tls.VersionTLS11), tlsConfig.MinVersion)
require.Equal(t, uint16(tls.VersionTLS12), tlsConfig.MaxVersion)
+
+ cert, err := tlsConfig.GetCertificate(&tls.ClientHelloInfo{})
+ require.NoError(t, err)
+ require.NotNil(t, cert)
}
diff --git a/internal/urilimiter/urilimiter_test.go b/internal/urilimiter/urilimiter_test.go
index 0ac89ea0..23b84102 100644
--- a/internal/urilimiter/urilimiter_test.go
+++ b/internal/urilimiter/urilimiter_test.go
@@ -8,6 +8,8 @@ import (
"testing"
"github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers"
)
func TestNewMiddleware(t *testing.T) {
@@ -56,7 +58,7 @@ func TestNewMiddleware(t *testing.T) {
middleware.ServeHTTP(ww, rr)
res := ww.Result()
- defer res.Body.Close()
+ testhelpers.Close(t, res.Body)
require.Equal(t, tt.expectedStatus, res.StatusCode)
if tt.expectedStatus == http.StatusOK {
diff --git a/metrics/metrics.go b/metrics/metrics.go
index e0e4ab29..26c7413f 100644
--- a/metrics/metrics.go
+++ b/metrics/metrics.go
@@ -193,60 +193,31 @@ var (
},
)
- // RateLimitSourceIPCacheRequests is the number of cache hits/misses
- RateLimitSourceIPCacheRequests = prometheus.NewCounterVec(
+ // RateLimitCacheRequests is the number of cache hits/misses
+ RateLimitCacheRequests = prometheus.NewCounterVec(
prometheus.CounterOpts{
- Name: "gitlab_pages_rate_limit_source_ip_cache_requests",
+ Name: "gitlab_pages_rate_limit_cache_requests",
Help: "The number of source_ip cache hits/misses in the rate limiter",
},
[]string{"op", "cache"},
)
- // RateLimitSourceIPCachedEntries is the number of entries in the cache
- RateLimitSourceIPCachedEntries = prometheus.NewGaugeVec(
+ // RateLimitCachedEntries is the number of entries in the cache
+ RateLimitCachedEntries = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
- Name: "gitlab_pages_rate_limit_source_ip_cached_entries",
+ Name: "gitlab_pages_rate_limit_cached_entries",
Help: "The number of entries in the cache",
},
[]string{"op"},
)
- // RateLimitSourceIPBlockedCount is the number of requests that have been blocked by the
- // source IP rate limiter
- RateLimitSourceIPBlockedCount = prometheus.NewGaugeVec(
+ // RateLimitBlockedCount is the number of requests that have been blocked
+ RateLimitBlockedCount = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
- Name: "gitlab_pages_rate_limit_source_ip_blocked_count",
- Help: "The number of requests that have been blocked by the IP rate limiter",
+ Name: "gitlab_pages_rate_limit_blocked_count",
+ Help: "The number of requests/connections that have been blocked by rate limiter",
},
- []string{"enforced"},
- )
-
- // RateLimitDomainCacheRequests is the number of cache hits/misses
- RateLimitDomainCacheRequests = prometheus.NewCounterVec(
- prometheus.CounterOpts{
- Name: "gitlab_pages_rate_limit_domain_cache_requests",
- Help: "The number of source_ip cache hits/misses in the rate limiter",
- },
- []string{"op", "cache"},
- )
-
- // RateLimitDomainCachedEntries is the number of entries in the cache
- RateLimitDomainCachedEntries = prometheus.NewGaugeVec(
- prometheus.GaugeOpts{
- Name: "gitlab_pages_rate_limit_domain_cached_entries",
- Help: "The number of entries in the cache",
- },
- []string{"op"},
- )
-
- // RateLimitDomainBlockedCount is the number of requests that have been blocked by the
- // domain rate limiter
- RateLimitDomainBlockedCount = prometheus.NewGaugeVec(
- prometheus.GaugeOpts{
- Name: "gitlab_pages_rate_limit_domain_blocked_count",
- Help: "The number of requests addresses that have been blocked by the domain rate limiter",
- },
- []string{"enforced"},
+ []string{"limit_name", "enforced"},
)
)
@@ -276,8 +247,8 @@ func MustRegister() {
LimitListenerConcurrentConns,
LimitListenerWaitingConns,
PanicRecoveredCount,
- RateLimitSourceIPCacheRequests,
- RateLimitSourceIPCachedEntries,
- RateLimitSourceIPBlockedCount,
+ RateLimitCacheRequests,
+ RateLimitCachedEntries,
+ RateLimitBlockedCount,
)
}
diff --git a/server.go b/server.go
index 0af582ff..b5aecc37 100644
--- a/server.go
+++ b/server.go
@@ -4,11 +4,13 @@ import (
"context"
"crypto/tls"
"fmt"
+ stdlog "log"
"net"
"net/http"
"time"
- proxyproto "github.com/pires/go-proxyproto"
+ "github.com/pires/go-proxyproto"
+ "github.com/sirupsen/logrus"
"gitlab.com/gitlab-org/labkit/log"
"gitlab.com/gitlab-org/gitlab-pages/internal/netutil"
@@ -31,6 +33,7 @@ func (a *theApp) listenAndServe(server *http.Server, addr string, h http.Handler
server.Handler = h
server.TLSConfig = config.tlsConfig
+ server.ErrorLog = stdlog.New(logrus.StandardLogger().Writer(), "", 0)
// ensure http2 is enabled even if TLSConfig is not null
// See https://github.com/golang/go/blob/97cee43c93cfccded197cd281f0a5885cdb605b4/src/net/http/server.go#L2947-L2954
if server.TLSConfig != nil {
diff --git a/test/acceptance/acme_test.go b/test/acceptance/acme_test.go
index 77c4d6c0..d743a4e1 100644
--- a/test/acceptance/acme_test.go
+++ b/test/acceptance/acme_test.go
@@ -7,6 +7,8 @@ import (
"testing"
"github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers"
)
const (
@@ -41,8 +43,8 @@ func TestAcmeChallengesWhenItIsNotConfigured(t *testing.T) {
rsp, err := GetRedirectPage(t, httpListener, "withacmechallenge.domain.com",
test.token)
+ testhelpers.Close(t, rsp.Body)
require.NoError(t, err)
- defer rsp.Body.Close()
require.Equal(t, test.expectedStatus, rsp.StatusCode)
body, err := io.ReadAll(rsp.Body)
require.NoError(t, err)
@@ -82,8 +84,8 @@ func TestAcmeChallengesWhenItIsConfigured(t *testing.T) {
rsp, err := GetRedirectPage(t, httpListener, "withacmechallenge.domain.com",
test.token)
+ testhelpers.Close(t, rsp.Body)
require.NoError(t, err)
- defer rsp.Body.Close()
require.Equal(t, test.expectedStatus, rsp.StatusCode)
body, err := io.ReadAll(rsp.Body)
require.NoError(t, err)
diff --git a/test/acceptance/artifacts_test.go b/test/acceptance/artifacts_test.go
index f087581c..65018c2c 100644
--- a/test/acceptance/artifacts_test.go
+++ b/test/acceptance/artifacts_test.go
@@ -11,12 +11,11 @@ import (
"time"
"github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers"
)
func TestArtifactProxyRequest(t *testing.T) {
- transport := (TestHTTPSClient.Transport).(*http.Transport).Clone()
- transport.ResponseHeaderTimeout = 5 * time.Second
-
content := "<!DOCTYPE html><html><head><title>Title of the document</title></head><body></body></html>"
contentLength := int64(len(content))
testServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -133,7 +132,7 @@ func TestArtifactProxyRequest(t *testing.T) {
resp, err := GetPageFromListener(t, httpListener, tt.host, tt.path)
require.NoError(t, err)
- defer resp.Body.Close()
+ testhelpers.Close(t, resp.Body)
require.Equal(t, tt.status, resp.StatusCode)
require.Equal(t, tt.contentType, resp.Header.Get("Content-Type"))
@@ -150,8 +149,6 @@ func TestArtifactProxyRequest(t *testing.T) {
}
func TestPrivateArtifactProxyRequest(t *testing.T) {
- setupTransport(t)
-
testServer := NewGitlabUnstartedServerStub(t, &stubOpts{})
keyFile, certFile := CreateHTTPSFixtureFiles(t)
@@ -229,7 +226,7 @@ func TestPrivateArtifactProxyRequest(t *testing.T) {
resp, err := GetRedirectPage(t, httpsListener, tt.host, tt.path)
require.NoError(t, err)
- defer resp.Body.Close()
+ testhelpers.Close(t, resp.Body)
require.Equal(t, http.StatusFound, resp.StatusCode)
@@ -245,7 +242,7 @@ func TestPrivateArtifactProxyRequest(t *testing.T) {
resp, err = GetRedirectPage(t, httpsListener, url.Host, url.Path+"?"+url.RawQuery)
require.NoError(t, err)
- defer resp.Body.Close()
+ testhelpers.Close(t, resp.Body)
require.Equal(t, http.StatusFound, resp.StatusCode)
pagesDomainCookie := resp.Header.Get("Set-Cookie")
@@ -255,7 +252,7 @@ func TestPrivateArtifactProxyRequest(t *testing.T) {
state, pagesDomainCookie)
require.NoError(t, err)
- defer authrsp.Body.Close()
+ testhelpers.Close(t, authrsp.Body)
// Will redirect auth callback to correct host
url, err = url.Parse(authrsp.Header.Get("Location"))
@@ -266,7 +263,7 @@ func TestPrivateArtifactProxyRequest(t *testing.T) {
// Request auth callback in project domain
authrsp, err = GetRedirectPageWithCookie(t, httpsListener, url.Host, url.Path+"?"+url.RawQuery, cookie)
require.NoError(t, err)
- defer authrsp.Body.Close()
+ testhelpers.Close(t, authrsp.Body)
// server returns the ticket, user will be redirected to the project page
require.Equal(t, http.StatusFound, authrsp.StatusCode)
@@ -276,7 +273,7 @@ func TestPrivateArtifactProxyRequest(t *testing.T) {
require.Equal(t, tt.status, resp.StatusCode)
require.NoError(t, err)
- defer resp.Body.Close()
+ testhelpers.Close(t, resp.Body)
})
}
}
diff --git a/test/acceptance/auth_test.go b/test/acceptance/auth_test.go
index 18b73161..d7677622 100644
--- a/test/acceptance/auth_test.go
+++ b/test/acceptance/auth_test.go
@@ -8,6 +8,8 @@ import (
"testing"
"github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers"
)
func TestWhenAuthIsDisabledPrivateIsNotAccessible(t *testing.T) {
@@ -33,7 +35,7 @@ func TestWhenAuthIsEnabledPrivateWillRedirectToAuthorize(t *testing.T) {
rsp, err := GetRedirectPage(t, httpsListener, "group.auth.gitlab-example.com", "private.project/")
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Equal(t, http.StatusFound, rsp.StatusCode)
require.Equal(t, 1, len(rsp.Header["Location"]))
@@ -41,7 +43,7 @@ func TestWhenAuthIsEnabledPrivateWillRedirectToAuthorize(t *testing.T) {
require.NoError(t, err)
rsp, err = GetRedirectPage(t, httpsListener, url.Host, url.Path+"?"+url.RawQuery)
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Equal(t, http.StatusFound, rsp.StatusCode)
require.Equal(t, 1, len(rsp.Header["Location"]))
@@ -69,7 +71,7 @@ func TestWhenAuthDeniedWillCauseUnauthorized(t *testing.T) {
rsp, err := GetPageFromListener(t, httpsListener, "projects.gitlab-example.com", "/auth?error=access_denied")
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Equal(t, http.StatusUnauthorized, rsp.StatusCode)
}
@@ -84,13 +86,13 @@ func TestWhenLoginCallbackWithWrongStateShouldFail(t *testing.T) {
rsp, err := GetRedirectPage(t, httpsListener, "group.auth.gitlab-example.com", "private.project/")
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
// Go to auth page with wrong state will cause failure
authrsp, err := GetPageFromListener(t, httpsListener, "projects.gitlab-example.com", "/auth?code=0&state=0")
require.NoError(t, err)
- defer authrsp.Body.Close()
+ testhelpers.Close(t, authrsp.Body)
require.Equal(t, http.StatusUnauthorized, authrsp.StatusCode)
}
@@ -106,7 +108,7 @@ func TestWhenLoginCallbackWithUnencryptedCode(t *testing.T) {
rsp, err := GetRedirectPage(t, httpsListener, "group.auth.gitlab-example.com", "private.project/")
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
cookie := rsp.Header.Get("Set-Cookie")
@@ -122,7 +124,7 @@ func TestWhenLoginCallbackWithUnencryptedCode(t *testing.T) {
url.Query().Get("state"), header)
require.NoError(t, err)
- defer authrsp.Body.Close()
+ testhelpers.Close(t, authrsp.Body)
// Will cause 500 because the code is not encrypted
require.Equal(t, http.StatusInternalServerError, authrsp.StatusCode)
@@ -153,7 +155,7 @@ func TestAccessControlUnderCustomDomain(t *testing.T) {
t.Run(name, func(t *testing.T) {
rsp, err := GetRedirectPage(t, httpListener, tt.domain, tt.path)
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
cookie := rsp.Header.Get("Set-Cookie")
@@ -165,7 +167,7 @@ func TestAccessControlUnderCustomDomain(t *testing.T) {
pagesrsp, err := GetRedirectPage(t, httpListener, url.Host, url.Path+"?"+url.RawQuery)
require.NoError(t, err)
- defer pagesrsp.Body.Close()
+ testhelpers.Close(t, pagesrsp.Body)
pagescookie := pagesrsp.Header.Get("Set-Cookie")
@@ -174,7 +176,7 @@ func TestAccessControlUnderCustomDomain(t *testing.T) {
state, pagescookie)
require.NoError(t, err)
- defer authrsp.Body.Close()
+ testhelpers.Close(t, authrsp.Body)
url, err = url.Parse(authrsp.Header.Get("Location"))
require.NoError(t, err)
@@ -188,7 +190,7 @@ func TestAccessControlUnderCustomDomain(t *testing.T) {
state, cookie)
require.NoError(t, err)
- defer authrsp.Body.Close()
+ testhelpers.Close(t, authrsp.Body)
// Will redirect to the page
cookie = authrsp.Header.Get("Set-Cookie")
@@ -203,7 +205,7 @@ func TestAccessControlUnderCustomDomain(t *testing.T) {
// Fetch page in custom domain
authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, tt.path, cookie)
require.NoError(t, err)
- defer authrsp.Body.Close()
+ testhelpers.Close(t, authrsp.Body)
require.Equal(t, http.StatusOK, authrsp.StatusCode)
})
}
@@ -254,7 +256,7 @@ func TestCustomErrorPageWithAuth(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
rsp, err := GetRedirectPage(t, httpListener, tt.domain, tt.path)
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
cookie := rsp.Header.Get("Set-Cookie")
@@ -266,7 +268,7 @@ func TestCustomErrorPageWithAuth(t *testing.T) {
pagesrsp, err := GetRedirectPage(t, httpListener, url.Host, url.Path+"?"+url.RawQuery)
require.NoError(t, err)
- defer pagesrsp.Body.Close()
+ testhelpers.Close(t, pagesrsp.Body)
pagescookie := pagesrsp.Header.Get("Set-Cookie")
@@ -275,7 +277,7 @@ func TestCustomErrorPageWithAuth(t *testing.T) {
state, pagescookie)
require.NoError(t, err)
- defer authrsp.Body.Close()
+ testhelpers.Close(t, authrsp.Body)
url, err = url.Parse(authrsp.Header.Get("Location"))
require.NoError(t, err)
@@ -292,7 +294,7 @@ func TestCustomErrorPageWithAuth(t *testing.T) {
state, cookie)
require.NoError(t, err)
- defer authrsp.Body.Close()
+ testhelpers.Close(t, authrsp.Body)
// Will redirect to the page
groupCookie := authrsp.Header.Get("Set-Cookie")
@@ -307,7 +309,7 @@ func TestCustomErrorPageWithAuth(t *testing.T) {
// Fetch page in custom domain
anotherResp, err := GetRedirectPageWithCookie(t, httpListener, tt.domain, tt.path, groupCookie)
require.NoError(t, err)
- defer anotherResp.Body.Close()
+ testhelpers.Close(t, anotherResp.Body)
require.Equal(t, http.StatusNotFound, anotherResp.StatusCode)
@@ -328,7 +330,7 @@ func TestAccessControlUnderCustomDomainWithHTTPSProxy(t *testing.T) {
rsp, err := GetProxyRedirectPageWithCookie(t, proxyListener, "private.domain.com", "/", "", true)
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
cookie := rsp.Header.Get("Set-Cookie")
@@ -339,7 +341,7 @@ func TestAccessControlUnderCustomDomainWithHTTPSProxy(t *testing.T) {
require.Equal(t, url.Query().Get("domain"), "https://private.domain.com")
pagesrsp, err := GetProxyRedirectPageWithCookie(t, proxyListener, url.Host, url.Path+"?"+url.RawQuery, "", true)
require.NoError(t, err)
- defer pagesrsp.Body.Close()
+ testhelpers.Close(t, pagesrsp.Body)
pagescookie := pagesrsp.Header.Get("Set-Cookie")
@@ -349,7 +351,7 @@ func TestAccessControlUnderCustomDomainWithHTTPSProxy(t *testing.T) {
pagescookie, true)
require.NoError(t, err)
- defer authrsp.Body.Close()
+ testhelpers.Close(t, authrsp.Body)
url, err = url.Parse(authrsp.Header.Get("Location"))
require.NoError(t, err)
@@ -366,7 +368,7 @@ func TestAccessControlUnderCustomDomainWithHTTPSProxy(t *testing.T) {
"/auth?code="+code+"&state="+state, cookie, true)
require.NoError(t, err)
- defer authrsp.Body.Close()
+ testhelpers.Close(t, authrsp.Body)
// Will redirect to the page
cookie = authrsp.Header.Get("Set-Cookie")
@@ -381,7 +383,7 @@ func TestAccessControlUnderCustomDomainWithHTTPSProxy(t *testing.T) {
authrsp, err = GetProxyRedirectPageWithCookie(t, proxyListener, "private.domain.com", "/",
cookie, true)
require.NoError(t, err)
- defer authrsp.Body.Close()
+ testhelpers.Close(t, authrsp.Body)
require.Equal(t, http.StatusOK, authrsp.StatusCode)
}
@@ -395,7 +397,7 @@ func TestAccessControlGroupDomain404RedirectsAuth(t *testing.T) {
rsp, err := GetRedirectPage(t, httpListener, "group.gitlab-example.com", "/nonexistent/")
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Equal(t, http.StatusFound, rsp.StatusCode)
// Redirects to the projects under gitlab pages domain for authentication flow
url, err := url.Parse(rsp.Header.Get("Location"))
@@ -414,15 +416,13 @@ func TestAccessControlProject404DoesNotRedirect(t *testing.T) {
rsp, err := GetRedirectPage(t, httpListener, "group.gitlab-example.com", "/project/nonexistent/")
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Equal(t, http.StatusNotFound, rsp.StatusCode)
}
type runPagesFunc func(t *testing.T, listeners []ListenSpec, sslCertFile string)
func testAccessControl(t *testing.T, runPages runPagesFunc) {
- setupTransport(t)
-
_, certFile := CreateHTTPSFixtureFiles(t)
tests := map[string]struct {
@@ -488,7 +488,7 @@ func testAccessControl(t *testing.T, runPages runPagesFunc) {
t.Run(tn, func(t *testing.T) {
rsp1, err1 := GetRedirectPage(t, httpsListener, tt.host, tt.path)
require.NoError(t, err1)
- defer rsp1.Body.Close()
+ testhelpers.Close(t, rsp1.Body)
require.Equal(t, http.StatusFound, rsp1.StatusCode)
cookie := rsp1.Header.Get("Set-Cookie")
@@ -502,7 +502,7 @@ func testAccessControl(t *testing.T, runPages runPagesFunc) {
rsp2, err2 := GetRedirectPage(t, httpsListener, loc1.Host, loc1.Path+"?"+loc1.RawQuery)
require.NoError(t, err2)
- defer rsp2.Body.Close()
+ testhelpers.Close(t, rsp2.Body)
require.Equal(t, http.StatusFound, rsp2.StatusCode)
pagesDomainCookie := rsp2.Header.Get("Set-Cookie")
@@ -511,7 +511,7 @@ func testAccessControl(t *testing.T, runPages runPagesFunc) {
authrsp1, err := GetRedirectPageWithCookie(t, httpsListener, "projects.gitlab-example.com", "/auth?code=1&state="+
state, pagesDomainCookie)
require.NoError(t, err)
- defer authrsp1.Body.Close()
+ testhelpers.Close(t, authrsp1.Body)
// Will redirect auth callback to correct host
authLoc, err := url.Parse(authrsp1.Header.Get("Location"))
@@ -522,7 +522,7 @@ func testAccessControl(t *testing.T, runPages runPagesFunc) {
// Request auth callback in project domain
authrsp2, err := GetRedirectPageWithCookie(t, httpsListener, authLoc.Host, authLoc.Path+"?"+authLoc.RawQuery, cookie)
require.NoError(t, err)
- defer authrsp2.Body.Close()
+ testhelpers.Close(t, authrsp2.Body)
// server returns the ticket, user will be redirected to the project page
require.Equal(t, http.StatusFound, authrsp2.StatusCode)
@@ -530,7 +530,7 @@ func testAccessControl(t *testing.T, runPages runPagesFunc) {
rsp3, err3 := GetRedirectPageWithCookie(t, httpsListener, tt.host, tt.path, cookie)
require.NoError(t, err3)
- defer rsp3.Body.Close()
+ testhelpers.Close(t, rsp3.Body)
require.Equal(t, tt.status, rsp3.StatusCode)
@@ -580,7 +580,7 @@ func TestHijackedCode(t *testing.T) {
hackedURL := fmt.Sprintf("/auth?domain=http://%s&state=%s", attackersDomain, "irrelevant")
maliciousResp, err := GetProxyRedirectPageWithCookie(t, proxyListener, "projects.gitlab-example.com", hackedURL, "", true)
require.NoError(t, err)
- defer maliciousResp.Body.Close()
+ testhelpers.Close(t, maliciousResp.Body)
pagesCookie := maliciousResp.Header.Get("Set-Cookie")
@@ -597,7 +597,7 @@ func TestHijackedCode(t *testing.T) {
pagesCookie, true)
require.NoError(t, err)
- defer authrsp.Body.Close()
+ testhelpers.Close(t, authrsp.Body)
/****ATTACKER******/
// Target is redirected to attacker's domain and attacker receives the proper code
@@ -614,7 +614,7 @@ func TestHijackedCode(t *testing.T) {
impersonatingRes, err := GetProxyRedirectPageWithCookie(t, proxyListener, targetDomain,
"/auth?code="+hijackedCode+"&state="+attackerState, attackerCookie, true)
require.NoError(t, err)
- defer impersonatingRes.Body.Close()
+ testhelpers.Close(t, impersonatingRes.Body)
require.Equal(t, impersonatingRes.StatusCode, http.StatusInternalServerError, "should fail to decode code")
}
@@ -626,7 +626,7 @@ func getValidCookieAndState(t *testing.T, domain string) (string, string) {
// visit https://<domain>/
rsp, err := GetProxyRedirectPageWithCookie(t, proxyListener, domain, "/", "", true)
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
cookie := rsp.Header.Get("Set-Cookie")
require.NotEmpty(t, cookie)
diff --git a/test/acceptance/encodings_test.go b/test/acceptance/encodings_test.go
index 18f2c492..c3532bd9 100644
--- a/test/acceptance/encodings_test.go
+++ b/test/acceptance/encodings_test.go
@@ -6,6 +6,8 @@ import (
"testing"
"github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers"
)
func TestMIMETypes(t *testing.T) {
@@ -27,7 +29,7 @@ func TestMIMETypes(t *testing.T) {
t.Run(name, func(t *testing.T) {
rsp, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "project/"+tt.file)
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Equal(t, http.StatusOK, rsp.StatusCode)
mt, _, err := mime.ParseMediaType(rsp.Header.Get("Content-Type"))
@@ -69,7 +71,7 @@ func TestCompressedEncoding(t *testing.T) {
}
rsp, err := GetPageFromListenerWithHeaders(t, httpListener, "group.gitlab-example.com", "index.html", header)
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Equal(t, http.StatusOK, rsp.StatusCode)
require.Equal(t, tt.encoding, rsp.Header.Get("Content-Encoding"))
diff --git a/test/acceptance/helpers_test.go b/test/acceptance/helpers_test.go
index 62a5e344..1b514a85 100644
--- a/test/acceptance/helpers_test.go
+++ b/test/acceptance/helpers_test.go
@@ -34,65 +34,10 @@ import (
// The HTTPS certificate isn't signed by anyone. This http client is set up
// so it can talk to servers using it.
var (
- // The HTTPS certificate isn't signed by anyone. This http client is set up
- // so it can talk to servers using it.
- TestHTTPSClient = &http.Client{
- Transport: &http.Transport{
- TLSClientConfig: &tls.Config{RootCAs: TestCertPool},
- },
- }
-
- // Use HTTP with a very short timeout to repeatedly check for the server to be
- // up. Again, ignore HTTP
- QuickTimeoutHTTPSClient = &http.Client{
- Transport: &http.Transport{
- TLSClientConfig: &tls.Config{RootCAs: TestCertPool},
- ResponseHeaderTimeout: 100 * time.Millisecond,
- },
- }
-
- // Proxyv2 client
- TestProxyv2Client = &http.Client{
- Transport: &http.Transport{
- DialContext: Proxyv2DialContext,
- TLSClientConfig: &tls.Config{RootCAs: TestCertPool},
- },
- }
-
- QuickTimeoutProxyv2Client = &http.Client{
- Transport: &http.Transport{
- DialContext: Proxyv2DialContext,
- TLSClientConfig: &tls.Config{RootCAs: TestCertPool},
- ResponseHeaderTimeout: 100 * time.Millisecond,
- },
- }
-
TestCertPool = x509.NewCertPool()
// Proxyv2 will create a dummy request with src 10.1.1.1:1000
// and dst 20.2.2.2:2000
- Proxyv2DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
- var d net.Dialer
-
- conn, err := d.DialContext(ctx, network, addr)
- if err != nil {
- return nil, err
- }
-
- header := &proxyproto.Header{
- Version: 2,
- Command: proxyproto.PROXY,
- TransportProtocol: proxyproto.TCPv4,
- SourceAddress: net.ParseIP("10.1.1.1"),
- SourcePort: 1000,
- DestinationAddress: net.ParseIP("20.2.2.2"),
- DestinationPort: 2000,
- }
-
- _, err = header.WriteTo(conn)
-
- return conn, err
- }
)
type tWriter struct {
@@ -151,15 +96,84 @@ func supportedListeners() []ListenSpec {
return listeners
}
-func (l ListenSpec) URL(suffix string) string {
- scheme := request.SchemeHTTP
+func (l ListenSpec) Scheme() string {
if l.Type == request.SchemeHTTPS || l.Type == "https-proxyv2" {
- scheme = request.SchemeHTTPS
+ return request.SchemeHTTPS
}
+ return request.SchemeHTTP
+}
+
+func (l ListenSpec) URL(suffix string) string {
suffix = strings.TrimPrefix(suffix, "/")
- return fmt.Sprintf("%s://%s/%s", scheme, l.JoinHostPort(), suffix)
+ return fmt.Sprintf("%s://%s/%s", l.Scheme(), l.JoinHostPort(), suffix)
+}
+
+type dialContext func(ctx context.Context, network, addr string) (net.Conn, error)
+
+func (l ListenSpec) proxyV2DialContext() dialContext {
+ return func(ctx context.Context, network, addr string) (net.Conn, error) {
+ var d net.Dialer
+
+ // bypass DNS resolution by going directly to host and port
+ conn, err := d.DialContext(ctx, network, l.JoinHostPort())
+ if err != nil {
+ return nil, err
+ }
+
+ header := &proxyproto.Header{
+ Version: 2,
+ Command: proxyproto.PROXY,
+ TransportProtocol: proxyproto.TCPv4,
+ SourceAddress: net.ParseIP("10.1.1.1"),
+ SourcePort: 1000,
+ DestinationAddress: net.ParseIP("20.2.2.2"),
+ DestinationPort: 2000,
+ }
+
+ _, err = header.WriteTo(conn)
+
+ return conn, err
+ }
+}
+
+func (l ListenSpec) httpsDialContext() dialContext {
+ return func(ctx context.Context, network, addr string) (net.Conn, error) {
+ var d net.Dialer
+
+ // bypass DNS resolution by going directly to host and port
+ return d.DialContext(ctx, network, l.JoinHostPort())
+ }
+}
+
+func (l ListenSpec) dialContext() dialContext {
+ if l.Type == "https-proxyv2" {
+ return l.proxyV2DialContext()
+ }
+
+ return l.httpsDialContext()
+}
+
+func (l ListenSpec) Client() *http.Client {
+ return &http.Client{
+ Transport: &http.Transport{
+ TLSClientConfig: &tls.Config{RootCAs: TestCertPool},
+ DialContext: l.dialContext(),
+ ResponseHeaderTimeout: 5 * time.Second,
+ },
+ }
+}
+
+// Use a very short timeout to repeatedly check for the server to be up.
+func (l ListenSpec) QuickTimeoutClient() *http.Client {
+ return &http.Client{
+ Transport: &http.Transport{
+ TLSClientConfig: &tls.Config{RootCAs: TestCertPool},
+ DialContext: l.dialContext(),
+ ResponseHeaderTimeout: 100 * time.Millisecond,
+ },
+ }
}
// Returns only once this spec points at a working TCP server
@@ -177,12 +191,7 @@ func (l ListenSpec) WaitUntilRequestSucceeds(done chan struct{}) error {
return err
}
- client := QuickTimeoutHTTPSClient
- if l.Type == "https-proxyv2" {
- client = QuickTimeoutProxyv2Client
- }
-
- response, err := client.Transport.RoundTrip(req)
+ response, err := l.QuickTimeoutClient().Transport.RoundTrip(req)
if err != nil {
time.Sleep(100 * time.Millisecond)
continue
@@ -380,11 +389,7 @@ func GetPageFromListenerWithHeaders(t *testing.T, spec ListenSpec, host, urlSuff
func DoPagesRequest(t *testing.T, spec ListenSpec, req *http.Request) (*http.Response, error) {
t.Logf("curl -X %s -H'Host: %s' %s", req.Method, req.Host, req.URL)
- if spec.Type == "https-proxyv2" {
- return TestProxyv2Client.Do(req)
- }
-
- return TestHTTPSClient.Do(req)
+ return spec.Client().Do(req)
}
func GetRedirectPage(t *testing.T, spec ListenSpec, host, urlsuffix string) (*http.Response, error) {
@@ -419,11 +424,7 @@ func GetRedirectPageWithHeaders(t *testing.T, spec ListenSpec, host, urlsuffix s
req.Host = host
- if spec.Type == "https-proxyv2" {
- return TestProxyv2Client.Transport.RoundTrip(req)
- }
-
- return TestHTTPSClient.Transport.RoundTrip(req)
+ return spec.Client().Transport.RoundTrip(req)
}
func ClientWithConfig(tlsConfig *tls.Config) (*http.Client, func()) {
@@ -602,12 +603,14 @@ func copyFile(dest, src string) error {
return err
}
-func setupTransport(t *testing.T) {
- t.Helper()
+// RequireMetricEqual requests prometheus metrics and makes sure metric is there
+func RequireMetricEqual(t *testing.T, metricsAddress, metricWithValue string) {
+ resp, err := http.Get(fmt.Sprintf("http://%s/metrics", metricsAddress))
+ require.NoError(t, err)
+
+ defer resp.Body.Close()
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
- transport := (TestHTTPSClient.Transport).(*http.Transport)
- defer func(t time.Duration) {
- transport.ResponseHeaderTimeout = t
- }(transport.ResponseHeaderTimeout)
- transport.ResponseHeaderTimeout = 5 * time.Second
+ require.Contains(t, string(body), metricWithValue)
}
diff --git a/test/acceptance/metrics_test.go b/test/acceptance/metrics_test.go
index 32a36638..d1e8a1b6 100644
--- a/test/acceptance/metrics_test.go
+++ b/test/acceptance/metrics_test.go
@@ -6,6 +6,8 @@ import (
"testing"
"github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers"
)
func TestPrometheusMetricsCanBeScraped(t *testing.T) {
@@ -20,13 +22,13 @@ func TestPrometheusMetricsCanBeScraped(t *testing.T) {
res, err := GetPageFromListener(t, httpListener, "zip.gitlab.io",
"/symlink.html")
require.NoError(t, err)
- defer res.Body.Close()
+ testhelpers.Close(t, res.Body)
require.Equal(t, http.StatusOK, res.StatusCode)
resp, err := http.Get("http://127.0.0.1:42345/metrics")
require.NoError(t, err)
- defer resp.Body.Close()
+ testhelpers.Close(t, resp.Body)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
diff --git a/test/acceptance/proxyv2_test.go b/test/acceptance/proxyv2_test.go
index 81a7ff94..45bdcb89 100644
--- a/test/acceptance/proxyv2_test.go
+++ b/test/acceptance/proxyv2_test.go
@@ -7,6 +7,8 @@ import (
"time"
"github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers"
)
func TestProxyv2(t *testing.T) {
@@ -37,7 +39,7 @@ func TestProxyv2(t *testing.T) {
response, err := GetPageFromListener(t, httpsProxyv2Listener, tt.host, tt.urlSuffix)
require.NoError(t, err)
- defer response.Body.Close()
+ testhelpers.Close(t, response.Body)
require.Equal(t, tt.expectedStatusCode, response.StatusCode)
diff --git a/test/acceptance/ratelimiter_test.go b/test/acceptance/ratelimiter_test.go
index 02ba54f5..a97fdfb1 100644
--- a/test/acceptance/ratelimiter_test.go
+++ b/test/acceptance/ratelimiter_test.go
@@ -3,6 +3,7 @@ package acceptance_test
import (
"fmt"
"net/http"
+ "strconv"
"testing"
"time"
@@ -77,7 +78,7 @@ func TestIPRateLimits(t *testing.T) {
}
}
-func TestDomainateLimits(t *testing.T) {
+func TestDomainRateLimits(t *testing.T) {
testhelpers.StubFeatureFlagValue(t, feature.EnforceDomainRateLimits.EnvVariable, true)
for name, tc := range ratelimitedListeners {
@@ -112,6 +113,131 @@ func TestDomainateLimits(t *testing.T) {
}
}
+func TestTLSRateLimits(t *testing.T) {
+ tests := map[string]struct {
+ spec ListenSpec
+ domainLimit bool
+ sourceIP string
+ enforceEnabled bool
+ }{
+ "https_with_domain_limit": {
+ spec: httpsListener,
+ domainLimit: true,
+ sourceIP: "127.0.0.1",
+ enforceEnabled: true,
+ },
+ "https_with_domain_limit_not_enforced": {
+ spec: httpsListener,
+ domainLimit: true,
+ sourceIP: "127.0.0.1",
+ enforceEnabled: false,
+ },
+ "https_with_ip_limit": {
+ spec: httpsListener,
+ sourceIP: "127.0.0.1",
+ enforceEnabled: true,
+ },
+ "https_with_ip_limit_not_enforced": {
+ spec: httpsListener,
+ sourceIP: "127.0.0.1",
+ enforceEnabled: false,
+ },
+ "proxyv2_with_domain_limit": {
+ spec: httpsProxyv2Listener,
+ domainLimit: true,
+ sourceIP: "10.1.1.1",
+ enforceEnabled: true,
+ },
+ "proxyv2_with_domain_limit_not_enforced": {
+ spec: httpsProxyv2Listener,
+ domainLimit: true,
+ sourceIP: "10.1.1.1",
+ enforceEnabled: false,
+ },
+ "proxyv2_with_ip_limit": {
+ spec: httpsProxyv2Listener,
+ sourceIP: "10.1.1.1",
+ enforceEnabled: true,
+ },
+ "proxyv2_with_ip_limit_not_enforced": {
+ spec: httpsProxyv2Listener,
+ sourceIP: "10.1.1.1",
+ enforceEnabled: false,
+ },
+ }
+
+ for name, tt := range tests {
+ t.Run(name, func(t *testing.T) {
+ rateLimit := 5
+
+ options := []processOption{
+ withListeners([]ListenSpec{tt.spec}),
+ withExtraArgument("metrics-address", ":42345"),
+ }
+
+ featureName := feature.EnforceIPTLSRateLimits.EnvVariable
+ limitName := "tls_connections_by_source_ip"
+
+ if tt.domainLimit {
+ options = append(options,
+ withExtraArgument("rate-limit-tls-domain", fmt.Sprint(rateLimit)),
+ withExtraArgument("rate-limit-tls-domain-burst", fmt.Sprint(rateLimit)))
+
+ featureName = feature.EnforceDomainTLSRateLimits.EnvVariable
+ limitName = "tls_connections_by_domain"
+ } else {
+ options = append(options,
+ withExtraArgument("rate-limit-tls-source-ip", fmt.Sprint(rateLimit)),
+ withExtraArgument("rate-limit-tls-source-ip-burst", fmt.Sprint(rateLimit)))
+ }
+
+ testhelpers.StubFeatureFlagValue(t, featureName, tt.enforceEnabled)
+ logBuf := RunPagesProcess(t, options...)
+
+ // when we start the process we make 1 requests to verify that process is up
+ // it gets counted in the rate limit for IP, but host is different
+ if !tt.domainLimit {
+ rateLimit--
+ }
+
+ for i := 0; i < 10; i++ {
+ rsp, err := makeTLSRequest(t, tt.spec)
+
+ if i >= rateLimit {
+ assertLogFound(t, logBuf, []string{
+ "TLS connection rate-limited",
+ "\"req_host\":\"group.gitlab-example.com\"",
+ fmt.Sprintf("\"source_ip\":\"%s\"", tt.sourceIP),
+ "\"enforced\":" + strconv.FormatBool(tt.enforceEnabled)})
+
+ if tt.enforceEnabled {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "remote error: tls: internal error")
+ }
+
+ continue
+ }
+
+ require.NoError(t, err, "request: %d failed", i)
+ require.NoError(t, rsp.Body.Close())
+ require.Equal(t, http.StatusOK, rsp.StatusCode, "request: %d failed", i)
+ }
+ expectedMetric := fmt.Sprintf(
+ "gitlab_pages_rate_limit_blocked_count{enforced=\"%t\",limit_name=\"%s\"} %v",
+ tt.enforceEnabled, limitName, 10-rateLimit)
+
+ RequireMetricEqual(t, "127.0.0.1:42345", expectedMetric)
+ })
+ }
+}
+
+func makeTLSRequest(t *testing.T, spec ListenSpec) (*http.Response, error) {
+ req, err := http.NewRequest("GET", "https://group.gitlab-example.com/project", nil)
+ require.NoError(t, err)
+
+ return spec.Client().Do(req)
+}
+
func assertLogFound(t *testing.T, logBuf *LogCaptureBuffer, expectedLogs []string) {
t.Helper()
diff --git a/test/acceptance/redirects_test.go b/test/acceptance/redirects_test.go
index 6c1158a4..5846d2cd 100644
--- a/test/acceptance/redirects_test.go
+++ b/test/acceptance/redirects_test.go
@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-pages/internal/feature"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers"
)
func TestRedirectStatusPage(t *testing.T) {
@@ -22,7 +23,7 @@ func TestRedirectStatusPage(t *testing.T) {
body, err := io.ReadAll(rsp.Body)
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Contains(t, string(body), "14 rules")
require.Equal(t, http.StatusOK, rsp.StatusCode)
@@ -37,7 +38,7 @@ func TestRedirect(t *testing.T) {
// Test that serving a file still works with redirects enabled
rsp, err := GetRedirectPage(t, httpListener, "group.redirects.gitlab-example.com", "/project-redirects/index.html")
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Equal(t, http.StatusOK, rsp.StatusCode)
@@ -95,7 +96,7 @@ func TestRedirect(t *testing.T) {
t.Run(fmt.Sprintf("%s%s -> %s (%d)", tt.host, tt.path, tt.expectedLocation, tt.expectedStatus), func(t *testing.T) {
rsp, err := GetRedirectPage(t, httpListener, tt.host, tt.path)
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Equal(t, tt.expectedLocation, rsp.Header.Get("Location"))
require.Equal(t, tt.expectedStatus, rsp.StatusCode)
diff --git a/test/acceptance/rewrites_test.go b/test/acceptance/rewrites_test.go
index aa105789..eefb1e82 100644
--- a/test/acceptance/rewrites_test.go
+++ b/test/acceptance/rewrites_test.go
@@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab-pages/internal/feature"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers"
)
func TestRewrites(t *testing.T) {
@@ -47,7 +48,7 @@ func TestRewrites(t *testing.T) {
t.Run(name, func(t *testing.T) {
rsp, err := GetPageFromListener(t, httpListener, tt.host, tt.path)
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
body, err := io.ReadAll(rsp.Body)
require.NoError(t, err)
diff --git a/test/acceptance/serving_test.go b/test/acceptance/serving_test.go
index 8b01f5b2..bab69357 100644
--- a/test/acceptance/serving_test.go
+++ b/test/acceptance/serving_test.go
@@ -10,6 +10,8 @@ import (
"time"
"github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers"
)
func TestUnknownHostReturnsNotFound(t *testing.T) {
@@ -29,7 +31,7 @@ func TestUnknownProjectReturnsNotFound(t *testing.T) {
rsp, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "/nonexistent/")
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Equal(t, http.StatusNotFound, rsp.StatusCode)
}
@@ -38,7 +40,7 @@ func TestGroupDomainReturns200(t *testing.T) {
rsp, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "/")
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Equal(t, http.StatusOK, rsp.StatusCode)
body, err := io.ReadAll(rsp.Body)
@@ -154,7 +156,7 @@ func TestCustom404(t *testing.T) {
rsp, err := GetPageFromListener(t, spec, test.host, test.path)
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Equal(t, http.StatusNotFound, rsp.StatusCode)
page, err := io.ReadAll(rsp.Body)
@@ -171,7 +173,7 @@ func TestCORSWhenDisabled(t *testing.T) {
for _, spec := range supportedListeners() {
for _, method := range []string{http.MethodGet, http.MethodHead, http.MethodOptions} {
rsp := doCrossOriginRequest(t, spec, method, method, spec.URL("project/"))
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Equal(t, http.StatusOK, rsp.StatusCode)
require.Equal(t, "", rsp.Header.Get("Access-Control-Allow-Origin"))
@@ -219,7 +221,7 @@ func TestCORSAllowsMethod(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
for _, spec := range supportedListeners() {
rsp := doCrossOriginRequest(t, spec, tt.method, tt.method, spec.URL("project/"))
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Equal(t, tt.expectedStatus, rsp.StatusCode)
require.Equal(t, tt.expectedOrigin, rsp.Header.Get("Access-Control-Allow-Origin"))
@@ -237,7 +239,7 @@ func TestCustomHeaders(t *testing.T) {
for _, spec := range supportedListeners() {
rsp, err := GetPageFromListener(t, spec, "group.gitlab-example.com:", "project/")
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Equal(t, http.StatusOK, rsp.StatusCode)
require.Equal(t, "Testing1", rsp.Header.Get("X-Test1"))
require.Equal(t, "Testing2", rsp.Header.Get("X-Test2"))
@@ -261,12 +263,12 @@ func TestHttpToHttpsRedirectDisabled(t *testing.T) {
rsp, err := GetRedirectPage(t, httpListener, "group.gitlab-example.com", "project/")
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Equal(t, http.StatusOK, rsp.StatusCode)
rsp, err = GetPageFromListener(t, httpsListener, "group.gitlab-example.com", "project/")
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Equal(t, http.StatusOK, rsp.StatusCode)
}
@@ -275,14 +277,14 @@ func TestHttpToHttpsRedirectEnabled(t *testing.T) {
rsp, err := GetRedirectPage(t, httpListener, "group.gitlab-example.com", "project/")
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Equal(t, http.StatusTemporaryRedirect, rsp.StatusCode)
require.Equal(t, 1, len(rsp.Header["Location"]))
require.Equal(t, "https://group.gitlab-example.com/project/", rsp.Header.Get("Location"))
rsp, err = GetPageFromListener(t, httpsListener, "group.gitlab-example.com", "project/")
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Equal(t, http.StatusOK, rsp.StatusCode)
}
@@ -408,7 +410,7 @@ func TestDomainResolverError(t *testing.T) {
response, err := GetPageFromListener(t, httpListener, domainName, "/my/pages/project/")
require.NoError(t, err)
- defer response.Body.Close()
+ testhelpers.Close(t, response.Body)
require.True(t, opts.getAPICalled(), "api must have been called")
@@ -450,7 +452,7 @@ func TestQueryStringPersistedInSlashRewrite(t *testing.T) {
rsp, err := GetRedirectPage(t, httpsListener, "group.gitlab-example.com", "project?q=test")
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Equal(t, http.StatusFound, rsp.StatusCode)
require.Equal(t, 1, len(rsp.Header["Location"]))
@@ -458,7 +460,7 @@ func TestQueryStringPersistedInSlashRewrite(t *testing.T) {
rsp, err = GetPageFromListener(t, httpsListener, "group.gitlab-example.com", "project/?q=test")
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Equal(t, http.StatusOK, rsp.StatusCode)
}
@@ -488,7 +490,7 @@ func TestServerRepliesWithHeaders(t *testing.T) {
rsp, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "/")
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Equal(t, http.StatusOK, rsp.StatusCode)
@@ -540,7 +542,7 @@ func TestDiskDisabledFailsToServeFileAndLocalContent(t *testing.T) {
t.Run(host, func(t *testing.T) {
rsp, err := GetPageFromListener(t, httpListener, host, suffix)
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Equal(t, http.StatusInternalServerError, rsp.StatusCode)
})
diff --git a/test/acceptance/status_test.go b/test/acceptance/status_test.go
index ef01c692..c48aaff7 100644
--- a/test/acceptance/status_test.go
+++ b/test/acceptance/status_test.go
@@ -5,6 +5,8 @@ import (
"testing"
"github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers"
)
func TestStatusPage(t *testing.T) {
@@ -15,6 +17,6 @@ func TestStatusPage(t *testing.T) {
rsp, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "@statuscheck")
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Equal(t, http.StatusOK, rsp.StatusCode)
}
diff --git a/test/acceptance/unknown_http_method_test.go b/test/acceptance/unknown_http_method_test.go
index dfe9c82f..7e96c5e2 100644
--- a/test/acceptance/unknown_http_method_test.go
+++ b/test/acceptance/unknown_http_method_test.go
@@ -5,6 +5,8 @@ import (
"testing"
"github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers"
)
func TestUnknownHTTPMethod(t *testing.T) {
@@ -18,7 +20,7 @@ func TestUnknownHTTPMethod(t *testing.T) {
resp, err := DoPagesRequest(t, httpListener, req)
require.NoError(t, err)
- defer resp.Body.Close()
+ testhelpers.Close(t, resp.Body)
require.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
}
diff --git a/test/acceptance/zip_test.go b/test/acceptance/zip_test.go
index dcb831e7..5bb2a0d1 100644
--- a/test/acceptance/zip_test.go
+++ b/test/acceptance/zip_test.go
@@ -10,6 +10,8 @@ import (
"time"
"github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers"
)
func TestZipServing(t *testing.T) {
@@ -85,8 +87,7 @@ func TestZipServing(t *testing.T) {
t.Run(name, func(t *testing.T) {
response, err := GetPageFromListener(t, httpListener, tt.host, tt.urlSuffix)
require.NoError(t, err)
- defer response.Body.Close()
-
+ testhelpers.Close(t, response.Body)
require.Equal(t, tt.expectedStatusCode, response.StatusCode)
if tt.expectedStatusCode == http.StatusOK {
@@ -217,7 +218,7 @@ func TestZipServingCache(t *testing.T) {
// send a request to get the ETag
response, err := GetPageFromListener(t, httpListener, tt.host, tt.urlSuffix)
require.NoError(t, err)
- defer response.Body.Close()
+ testhelpers.Close(t, response.Body)
require.Equal(t, http.StatusOK, response.StatusCode)
etag := response.Header.Get("ETag")
@@ -231,7 +232,7 @@ func TestZipServingCache(t *testing.T) {
body, err := io.ReadAll(rsp.Body)
require.NoError(t, err)
- defer rsp.Body.Close()
+ testhelpers.Close(t, rsp.Body)
require.Equal(t, tt.expectedContent, string(body), "content mismatch")
})
}
@@ -308,7 +309,7 @@ func TestZipServingFromDisk(t *testing.T) {
t.Run(name, func(t *testing.T) {
response, err := GetPageFromListener(t, httpListener, tt.host, tt.urlSuffix)
require.NoError(t, err)
- defer response.Body.Close()
+ testhelpers.Close(t, response.Body)
require.Equal(t, tt.expectedStatusCode, response.StatusCode)
@@ -333,7 +334,7 @@ func TestZipServingConfigShortTimeout(t *testing.T) {
response, err := GetPageFromListener(t, httpListener, "zip.gitlab.io", "/")
require.NoError(t, err)
- defer response.Body.Close()
+ testhelpers.Close(t, response.Body)
require.Equal(t, http.StatusInternalServerError, response.StatusCode, "should fail to serve")
}