diff options
-rw-r--r-- | app.go | 23 | ||||
-rw-r--r-- | internal/config/config.go | 49 | ||||
-rw-r--r-- | internal/config/config_test.go | 78 | ||||
-rw-r--r-- | internal/customheaders/customheaders.go | 42 | ||||
-rw-r--r-- | internal/customheaders/customheaders_test.go | 113 | ||||
-rw-r--r-- | internal/customheaders/middleware.go | 4 | ||||
-rw-r--r-- | internal/vfs/zip/archive.go | 15 | ||||
-rw-r--r-- | test/acceptance/artifacts_test.go | 37 | ||||
-rw-r--r-- | test/acceptance/config_test.go | 9 | ||||
-rw-r--r-- | test/acceptance/helpers_test.go | 64 | ||||
-rw-r--r-- | test/acceptance/ratelimiter_test.go | 22 | ||||
-rw-r--r-- | test/acceptance/redirects_test.go | 6 | ||||
-rw-r--r-- | test/acceptance/rewrites_test.go | 3 | ||||
-rw-r--r-- | test/acceptance/serving_test.go | 16 | ||||
-rw-r--r-- | test/acceptance/status_test.go | 3 | ||||
-rw-r--r-- | test/acceptance/stub_test.go | 14 | ||||
-rw-r--r-- | test/gitlabstub/cmd/server/main.go | 25 | ||||
-rw-r--r-- | test/gitlabstub/option.go | 17 | ||||
-rw-r--r-- | test/gitlabstub/server.go | 5 |
19 files changed, 286 insertions, 259 deletions
@@ -49,13 +49,12 @@ var ( ) type theApp struct { - config *cfg.Config - source source.Source - tlsConfig *cryptotls.Config - Artifact *artifact.Artifact - Auth *auth.Auth - Handlers *handlers.Handlers - CustomHeaders http.Header + config *cfg.Config + source source.Source + tlsConfig *cryptotls.Config + Artifact *artifact.Artifact + Auth *auth.Auth + Handlers *handlers.Handlers } func (a *theApp) GetCertificate(ch *cryptotls.ClientHelloInfo) (*cryptotls.Certificate, error) { @@ -154,7 +153,7 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) { handler = health.NewMiddleware(handler, a.config.General.StatusPath) // Custom response headers - handler = customheaders.NewMiddleware(handler, a.CustomHeaders) + handler = customheaders.NewMiddleware(handler, a.config.General.CustomHeaders) // Correlation ID injection middleware var correlationOpts []correlation.InboundHandlerOption @@ -372,14 +371,6 @@ func runApp(config *cfg.Config) error { a.Handlers = handlers.New(a.Auth, a.Artifact) - if len(config.General.CustomHeaders) != 0 { - customHeaders, err := customheaders.ParseHeaderString(config.General.CustomHeaders) - if err != nil { - return fmt.Errorf("unable to parse header string: %w", err) - } - a.CustomHeaders = customHeaders - } - if err := mimedb.LoadTypes(); err != nil { log.WithError(err).Warn("Loading extended MIME database failed") } diff --git a/internal/config/config.go b/internal/config/config.go index 354dcc98..50146011 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,14 +1,18 @@ package config import ( + "bufio" "crypto/tls" "encoding/base64" "errors" "fmt" + "net/http" + "net/textproto" "os" "strings" "time" + "github.com/hashicorp/go-multierror" "github.com/namsral/flag" "gitlab.com/gitlab-org/labkit/log" ) @@ -56,7 +60,7 @@ type General struct { ShowVersion bool - CustomHeaders []string + CustomHeaders http.Header } // RateLimit config struct @@ -162,8 +166,10 @@ type Metrics struct { } var ( - errMetricsNoCertificate = errors.New("metrics certificate path must not be empty") - errMetricsNoKey = errors.New("metrics private key path must not be empty") + errDuplicateHeader = errors.New("duplicate header") + errInvalidHeaderParameter = errors.New("invalid syntax specified as header parameter") + errMetricsNoCertificate = errors.New("metrics certificate path must not be empty") + errMetricsNoKey = errors.New("metrics private key path must not be empty") ) func internalGitlabServerFromFlags() string { @@ -239,6 +245,35 @@ func loadMetricsConfig() (metrics Metrics, err error) { return metrics, nil } +func parseHeaderString(customHeaders []string) (http.Header, error) { + headers := make(http.Header, len(customHeaders)) + + var result *multierror.Error + for _, h := range customHeaders { + h = h + "\n\n" + tp := textproto.NewReader(bufio.NewReader(strings.NewReader(h))) + + mimeHeader, err := tp.ReadMIMEHeader() + if err != nil { + result = multierror.Append(result, fmt.Errorf("parsing error %s: %w", h, errInvalidHeaderParameter)) + } + + for key, value := range mimeHeader { + if _, ok := headers[key]; ok { + result = multierror.Append(result, fmt.Errorf("%s already specified with value '%s': %w", key, value, errDuplicateHeader)) + } + + headers[key] = value + } + } + + if result.ErrorOrNil() != nil { + return nil, result + } + + return headers, nil +} + func loadConfig() (*Config, error) { config := &Config{ General: General{ @@ -252,7 +287,6 @@ func loadConfig() (*Config, error) { DisableCrossOriginRequests: *disableCrossOriginRequests, InsecureCiphers: *insecureCiphers, PropagateCorrelationID: *propagateCorrelationID, - CustomHeaders: header.Split(), ShowVersion: *showVersion, }, RateLimit: RateLimit{ @@ -353,6 +387,13 @@ func loadConfig() (*Config, error) { } } + customHeaders, err := parseHeaderString(header.Split()) + if err != nil { + return nil, fmt.Errorf("unable to parse header string: %w", err) + } + + config.General.CustomHeaders = customHeaders + // Populating remaining GitLab settings config.GitLab.PublicServer = *publicGitLabServer diff --git a/internal/config/config_test.go b/internal/config/config_test.go index f120d17d..19dfb957 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -83,3 +83,81 @@ func setupHTTPSFixture(t *testing.T) (dir string, key string, cert string) { return tmpDir, keyfile.Name(), certfile.Name() } + +func TestParseHeaderString(t *testing.T) { + tests := []struct { + name string + headerStrings []string + valid bool + expectedLen int + }{ + { + name: "Normal case", + headerStrings: []string{"X-Test-String: Test"}, + valid: true, + expectedLen: 1, + }, + { + name: "Non-tracking header case", + headerStrings: []string{"Tk: N"}, + valid: true, + expectedLen: 1, + }, + { + name: "Content security header case", + headerStrings: []string{"content-security-policy: default-src 'self'"}, + valid: true, + expectedLen: 1, + }, + { + name: "Multiple header strings", + headerStrings: []string{"content-security-policy: default-src 'self'", "X-Test-String: Test", "My amazing header : Amazing"}, + valid: true, + expectedLen: 3, + }, + { + name: "Multiple invalid cases", + headerStrings: []string{"content-security-policy: default-src 'self'", "test-case"}, + valid: false, + }, + { + name: "Not valid case", + headerStrings: []string{"Tk= N"}, + valid: false, + }, + { + name: "duplicate headers", + headerStrings: []string{"Tk: N", "Tk: M"}, + valid: false, + }, + { + name: "Not valid case", + headerStrings: []string{"X-Test-String Some-Test"}, + valid: false, + }, + { + name: "Valid and not valid case", + headerStrings: []string{"content-security-policy: default-src 'self'", "test-case"}, + valid: false, + }, + { + name: "Multiple headers in single string parsed as one header", + headerStrings: []string{"content-security-policy: default-src 'self',X-Test-String: Test,My amazing header : Amazing"}, + valid: true, + expectedLen: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseHeaderString(tt.headerStrings) + if tt.valid { + require.NoError(t, err) + require.Len(t, got, tt.expectedLen) + return + } + + require.Error(t, err) + }) + } +} diff --git a/internal/customheaders/customheaders.go b/internal/customheaders/customheaders.go index 92f50069..e7f8cb91 100644 --- a/internal/customheaders/customheaders.go +++ b/internal/customheaders/customheaders.go @@ -1,19 +1,7 @@ package customheaders import ( - "bufio" - "errors" - "fmt" "net/http" - "net/textproto" - "strings" - - "github.com/hashicorp/go-multierror" -) - -var ( - errInvalidHeaderParameter = errors.New("invalid syntax specified as header parameter") - errDuplicateHeader = errors.New("duplicate header") ) // AddCustomHeaders adds a map of Headers to a Response @@ -24,33 +12,3 @@ func AddCustomHeaders(w http.ResponseWriter, headers http.Header) { } } } - -// ParseHeaderString parses a string of key values into a map -func ParseHeaderString(customHeaders []string) (http.Header, error) { - headers := make(http.Header, len(customHeaders)) - - var result *multierror.Error - for _, h := range customHeaders { - h = h + "\n\n" - tp := textproto.NewReader(bufio.NewReader(strings.NewReader(h))) - - mimeHeader, err := tp.ReadMIMEHeader() - if err != nil { - result = multierror.Append(result, fmt.Errorf("parsing error %s: %w", h, errInvalidHeaderParameter)) - } - - for key, value := range mimeHeader { - if _, ok := headers[key]; ok { - result = multierror.Append(result, fmt.Errorf("%s already specified with value '%s': %w", key, value, errDuplicateHeader)) - } - - headers[key] = value - } - } - - if result.ErrorOrNil() != nil { - return nil, result - } - - return headers, nil -} diff --git a/internal/customheaders/customheaders_test.go b/internal/customheaders/customheaders_test.go index 857c45e0..f0252137 100644 --- a/internal/customheaders/customheaders_test.go +++ b/internal/customheaders/customheaders_test.go @@ -1,6 +1,7 @@ package customheaders_test import ( + "net/http" "net/http/httptest" "testing" @@ -9,118 +10,38 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/customheaders" ) -func TestParseHeaderString(t *testing.T) { - tests := []struct { - name string - headerStrings []string - valid bool - expectedLen int - }{ - { - name: "Normal case", - headerStrings: []string{"X-Test-String: Test"}, - valid: true, - expectedLen: 1, - }, - { - name: "Non-tracking header case", - headerStrings: []string{"Tk: N"}, - valid: true, - expectedLen: 1, - }, - { - name: "Content security header case", - headerStrings: []string{"content-security-policy: default-src 'self'"}, - valid: true, - expectedLen: 1, - }, - { - name: "Multiple header strings", - headerStrings: []string{"content-security-policy: default-src 'self'", "X-Test-String: Test", "My amazing header : Amazing"}, - valid: true, - expectedLen: 3, - }, - { - name: "Multiple invalid cases", - headerStrings: []string{"content-security-policy: default-src 'self'", "test-case"}, - valid: false, - }, - { - name: "Not valid case", - headerStrings: []string{"Tk= N"}, - valid: false, - }, - { - name: "duplicate headers", - headerStrings: []string{"Tk: N", "Tk: M"}, - valid: false, - }, - { - name: "Not valid case", - headerStrings: []string{"X-Test-String Some-Test"}, - valid: false, - }, - { - name: "Valid and not valid case", - headerStrings: []string{"content-security-policy: default-src 'self'", "test-case"}, - valid: false, - }, - { - name: "Multiple headers in single string parsed as one header", - headerStrings: []string{"content-security-policy: default-src 'self',X-Test-String: Test,My amazing header : Amazing"}, - valid: true, - expectedLen: 1, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := customheaders.ParseHeaderString(tt.headerStrings) - if tt.valid { - require.NoError(t, err) - require.Len(t, got, tt.expectedLen) - return - } - - require.Error(t, err) - }) - } -} - func TestAddCustomHeaders(t *testing.T) { tests := []struct { - name string - headerStrings []string - wantHeaders map[string]string + name string + headers http.Header + wantHeaders map[string]string }{ { - name: "Normal case", - headerStrings: []string{"X-Test-String: Test"}, - wantHeaders: map[string]string{"X-Test-String": "Test"}, + name: "Normal case", + headers: http.Header{"X-Test-String": []string{"Test"}}, + wantHeaders: map[string]string{"X-Test-String": "Test"}, }, { - name: "Non-tracking header case", - headerStrings: []string{"Tk: N"}, - wantHeaders: map[string]string{"Tk": "N"}, + name: "Non-tracking header case", + headers: http.Header{"Tk": []string{"N"}}, + wantHeaders: map[string]string{"Tk": "N"}, }, { - name: "Content security header case", - headerStrings: []string{"content-security-policy: default-src 'self'"}, - wantHeaders: map[string]string{"Content-Security-Policy": "default-src 'self'"}, + name: "Content security header case", + headers: http.Header{"content-security-policy": []string{"default-src 'self'"}}, + wantHeaders: map[string]string{"Content-Security-Policy": "default-src 'self'"}, }, { - name: "Multiple header strings", - headerStrings: []string{"content-security-policy: default-src 'self'", "X-Test-String: Test", "My amazing header: Amazing"}, - wantHeaders: map[string]string{"Content-Security-Policy": "default-src 'self'", "X-Test-String": "Test", "My amazing header": "Amazing"}, + name: "Multiple header strings", + headers: http.Header{"content-security-policy": []string{"default-src 'self'"}, "X-Test-String": []string{"Test"}, "My amazing header": []string{"Amazing"}}, + wantHeaders: map[string]string{"Content-Security-Policy": "default-src 'self'", "X-Test-String": "Test", "My amazing header": "Amazing"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - headers, err := customheaders.ParseHeaderString(tt.headerStrings) - require.NoError(t, err) w := httptest.NewRecorder() - customheaders.AddCustomHeaders(w, headers) + customheaders.AddCustomHeaders(w, tt.headers) rsp := w.Result() for k, v := range tt.wantHeaders { require.Len(t, rsp.Header[k], 1) diff --git a/internal/customheaders/middleware.go b/internal/customheaders/middleware.go index e0e11ad6..964b822a 100644 --- a/internal/customheaders/middleware.go +++ b/internal/customheaders/middleware.go @@ -6,6 +6,10 @@ import ( // NewMiddleware returns middleware which inject custom headers into the response func NewMiddleware(handler http.Handler, headers http.Header) http.Handler { + if len(headers) == 0 { + return handler + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { AddCustomHeaders(w, headers) diff --git a/internal/vfs/zip/archive.go b/internal/vfs/zip/archive.go index 9f01794b..2b04310d 100644 --- a/internal/vfs/zip/archive.go +++ b/internal/vfs/zip/archive.go @@ -56,7 +56,6 @@ type zipArchive struct { resource *httprange.Resource reader *httprange.RangedReader - archive *zip.Reader err error files map[string]*zip.File @@ -128,13 +127,15 @@ func (a *zipArchive) readArchive(url string) { return } + var archive *zip.Reader + // load all archive files into memory using a cached ranged reader a.reader = httprange.NewRangedReader(a.resource) a.reader.WithCachedReader(ctx, func() { - a.archive, a.err = zip.NewReader(a.reader, a.resource.Size) + archive, a.err = zip.NewReader(a.reader, a.resource.Size) }) - if a.archive == nil || a.err != nil { + if archive == nil || a.err != nil { log.WithFields(log.Fields{ "archive_url": url, }).WithError(a.err).Infoln("loading zip archive files into memory failed") @@ -143,7 +144,7 @@ func (a *zipArchive) readArchive(url string) { } // TODO: Improve preprocessing of zip archives https://gitlab.com/gitlab-org/gitlab-pages/-/issues/432 - for _, file := range a.archive.File { + for _, file := range archive.File { if !strings.HasPrefix(file.Name, dirPrefix) { continue } @@ -164,8 +165,10 @@ func (a *zipArchive) readArchive(url string) { a.addPathDirectory(file.Name) } - // recycle memory - a.archive.File = nil + // Each file stores a pointer to the zip.reader. + // The file slice is not used so we null it out + // to reduce memory consumption. + archive.File = nil fileCount := float64(len(a.files)) metrics.ZipOpened.WithLabelValues("ok").Inc() diff --git a/test/acceptance/artifacts_test.go b/test/acceptance/artifacts_test.go index e56a1390..c88f48bf 100644 --- a/test/acceptance/artifacts_test.go +++ b/test/acceptance/artifacts_test.go @@ -120,10 +120,11 @@ func TestArtifactProxyRequest(t *testing.T) { args := []string{"-artifacts-server=" + artifactServerURL, "-artifacts-server-timeout=1"} + t.Setenv("SSL_CERT_FILE", certFile) + RunPagesProcess(t, withListeners([]ListenSpec{httpListener}), withArguments(args), - withEnv([]string{"SSL_CERT_FILE=" + certFile}), ) for _, tt := range tests { @@ -150,20 +151,6 @@ func TestArtifactProxyRequest(t *testing.T) { } func TestPrivateArtifactProxyRequest(t *testing.T) { - testServer, err := gitlabstub.NewUnstartedServer() - require.NoError(t, err) - - keyFile, certFile := CreateHTTPSFixtureFiles(t) - cert, err := tls.LoadX509KeyPair(certFile, keyFile) - require.NoError(t, err) - - testServer.TLS = &tls.Config{Certificates: []tls.Certificate{cert}} - testServer.StartTLS() - - t.Cleanup(func() { - testServer.Close() - }) - tests := []struct { name string host string @@ -202,23 +189,23 @@ func TestPrivateArtifactProxyRequest(t *testing.T) { }, } - // Ensure the IP address is used in the URL, as we're relying on IP SANs to - // validate - artifactServerURL := testServer.URL + "/api/v4" - t.Log("Artifact server URL", artifactServerURL) + configFile := defaultConfigFileWith(t) + + keyFile, certFile := CreateHTTPSFixtureFiles(t) + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + require.NoError(t, err) - configFile := defaultConfigFileWith(t, - "gitlab-server="+testServer.URL, - "artifacts-server="+artifactServerURL, - "auth-redirect-uri=https://projects.gitlab-example.com/auth", - "artifacts-server-timeout=1") + t.Setenv("SSL_CERT_FILE", certFile) RunPagesProcess(t, withListeners([]ListenSpec{httpsListener}), withArguments([]string{ "-config=" + configFile, }), - withEnv([]string{"SSL_CERT_FILE=" + certFile}), + withPublicServer, + withExtraArgument("auth-redirect-uri", "https://projects.gitlab-example.com/auth"), + withExtraArgument("artifacts-server-timeout", "1"), + withStubOptions(gitlabstub.WithCertificate(cert)), ) for _, tt := range tests { diff --git a/test/acceptance/config_test.go b/test/acceptance/config_test.go index 95be6e17..baa35f6e 100644 --- a/test/acceptance/config_test.go +++ b/test/acceptance/config_test.go @@ -11,12 +11,11 @@ import ( ) func TestEnvironmentVariablesConfig(t *testing.T) { - envVarValue := "LISTEN_HTTP=" + net.JoinHostPort(httpListener.Host, httpListener.Port) + t.Setenv("LISTEN_HTTP", net.JoinHostPort(httpListener.Host, httpListener.Port)) RunPagesProcess(t, withoutWait, withListeners([]ListenSpec{}), // explicitly disable listeners for this test - withEnv([]string{envVarValue}), ) require.NoError(t, httpListener.WaitUntilRequestSucceeds(nil)) @@ -28,12 +27,11 @@ func TestEnvironmentVariablesConfig(t *testing.T) { } func TestMixedConfigSources(t *testing.T) { - envVarValue := "LISTEN_HTTP=" + net.JoinHostPort(httpListener.Host, httpListener.Port) + t.Setenv("LISTEN_HTTP", net.JoinHostPort(httpListener.Host, httpListener.Port)) RunPagesProcess(t, withoutWait, withListeners([]ListenSpec{httpsListener}), - withEnv([]string{envVarValue}), ) for _, listener := range []ListenSpec{httpListener, httpsListener} { @@ -48,12 +46,11 @@ func TestMixedConfigSources(t *testing.T) { func TestMultipleListenersFromEnvironmentVariables(t *testing.T) { listenSpecs := []ListenSpec{{"http", "127.0.0.1", "37001"}, {"http", "127.0.0.1", "37002"}} - envVarValue := fmt.Sprintf("LISTEN_HTTP=%s,%s", net.JoinHostPort("127.0.0.1", "37001"), net.JoinHostPort("127.0.0.1", "37002")) + t.Setenv("LISTEN_HTTP", fmt.Sprintf("%s,%s", net.JoinHostPort("127.0.0.1", "37001"), net.JoinHostPort("127.0.0.1", "37002"))) RunPagesProcess(t, withoutWait, withListeners([]ListenSpec{}), // explicitly disable listeners for this test - withEnv([]string{envVarValue}), ) for _, listener := range listenSpecs { diff --git a/test/acceptance/helpers_test.go b/test/acceptance/helpers_test.go index 4c365c5e..f8638e36 100644 --- a/test/acceptance/helpers_test.go +++ b/test/acceptance/helpers_test.go @@ -190,33 +190,36 @@ func (l ListenSpec) QuickTimeoutClient() *http.Client { // Returns only once this spec points at a working TCP server func (l ListenSpec) WaitUntilRequestSucceeds(done chan struct{}) error { timeout := 5 * time.Second - for start := time.Now(); time.Since(start) < timeout; { + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + for { select { case <-done: return fmt.Errorf("server has shut down already") - default: - } + case <-ctx.Done(): + return fmt.Errorf("ctx done: %w for listener %v", ctx.Err(), l) + case <-ticker.C: + req, err := http.NewRequestWithContext(ctx, http.MethodGet, l.URL("/@healthcheck"), nil) + if err != nil { + return err + } - req, err := http.NewRequest("GET", l.URL("/"), nil) - if err != nil { - return err - } + response, err := l.QuickTimeoutClient().Transport.RoundTrip(req) - response, err := l.QuickTimeoutClient().Transport.RoundTrip(req) - if err != nil { - time.Sleep(100 * time.Millisecond) - continue - } - response.Body.Close() + if err == nil { + response.Body.Close() - if code := response.StatusCode; code >= 200 && code < 500 { - return nil + if code := response.StatusCode; code >= 200 && code < 500 { + return nil + } + } } - - time.Sleep(100 * time.Millisecond) } - - return fmt.Errorf("timed out after %v waiting for listener %v", timeout, l) } func (l ListenSpec) JoinHostPort() string { @@ -247,7 +250,12 @@ func RunPagesProcess(t *testing.T, opts ...processOption) *LogCaptureBuffer { source, err := gitlabstub.NewUnstartedServer(processCfg.gitlabStubOpts...) require.NoError(t, err) - source.Start() + + if source.TLS != nil { + source.StartTLS() + } else { + source.Start() + } gitLabAPISecretKey := CreateGitLabAPISecretKeyFixtureFile(t) processCfg.extraArgs = append( @@ -257,7 +265,11 @@ func RunPagesProcess(t *testing.T, opts ...processOption) *LogCaptureBuffer { "-api-secret-key", gitLabAPISecretKey, ) - logBuf, cleanup := runPagesProcess(t, processCfg.wait, processCfg.pagesBinary, processCfg.listeners, "", processCfg.envs, processCfg.extraArgs...) + if processCfg.publicServer { + processCfg.extraArgs = append(processCfg.extraArgs, "-gitlab-server", source.URL) + } + + logBuf, cleanup := runPagesProcess(t, processCfg.wait, processCfg.pagesBinary, processCfg.listeners, "", processCfg.extraArgs...) t.Cleanup(func() { source.Close() @@ -269,12 +281,13 @@ func RunPagesProcess(t *testing.T, opts ...processOption) *LogCaptureBuffer { } func RunPagesProcessWithSSLCertFile(t *testing.T, listeners []ListenSpec, sslCertFile string) { + t.Setenv("SSL_CERT_FILE", sslCertFile) + RunPagesProcess(t, withListeners(listeners), withArguments([]string{ "-config=" + defaultAuthConfig(t), }), - withEnv([]string{"SSL_CERT_FILE=" + sslCertFile}), ) } @@ -282,6 +295,8 @@ func RunPagesProcessWithSSLCertDir(t *testing.T, listeners []ListenSpec, sslCert // Create temporary cert dir sslCertDir := t.TempDir() + t.Setenv("SSL_CERT_DIR", sslCertDir) + // Copy sslCertFile into temp cert dir err := copyFile(sslCertDir+"/"+path.Base(sslCertFile), sslCertFile) require.NoError(t, err) @@ -291,11 +306,10 @@ func RunPagesProcessWithSSLCertDir(t *testing.T, listeners []ListenSpec, sslCert withArguments([]string{ "-config=" + defaultAuthConfig(t), }), - withEnv([]string{"SSL_CERT_DIR=" + sslCertDir}), ) } -func runPagesProcess(t *testing.T, wait bool, pagesBinary string, listeners []ListenSpec, promPort string, extraEnv []string, extraArgs ...string) (*LogCaptureBuffer, func()) { +func runPagesProcess(t *testing.T, wait bool, pagesBinary string, listeners []ListenSpec, promPort string, extraArgs ...string) (*LogCaptureBuffer, func()) { t.Helper() _, err := os.Stat(pagesBinary) @@ -306,7 +320,6 @@ func runPagesProcess(t *testing.T, wait bool, pagesBinary string, listeners []Li args := getPagesArgs(t, listeners, promPort, extraArgs) cmd := exec.Command(pagesBinary, args...) - cmd.Env = append(os.Environ(), extraEnv...) cmd.Stdout = out cmd.Stderr = out require.NoError(t, cmd.Start()) @@ -339,6 +352,7 @@ func runPagesProcess(t *testing.T, wait bool, pagesBinary string, listeners []Li func getPagesArgs(t *testing.T, listeners []ListenSpec, promPort string, extraArgs []string) (args []string) { var hasHTTPS bool args = append(args, "-log-verbose=true") + args = append(args, "-pages-status=/@healthcheck") for _, spec := range listeners { if spec.Type == "unix" { diff --git a/test/acceptance/ratelimiter_test.go b/test/acceptance/ratelimiter_test.go index ef49aa19..85e0f6aa 100644 --- a/test/acceptance/ratelimiter_test.go +++ b/test/acceptance/ratelimiter_test.go @@ -16,19 +16,14 @@ var ratelimitedListeners = map[string]struct { listener ListenSpec header http.Header clientIP string - // We perform requests to server while we're waiting for it to boot up, - // successful request gets counted in IP rate limit - includeWaitRequest bool }{ "http_listener": { - listener: httpListener, - clientIP: "127.0.0.1", - includeWaitRequest: true, + listener: httpListener, + clientIP: "127.0.0.1", }, "https_listener": { - listener: httpsListener, - clientIP: "127.0.0.1", - includeWaitRequest: true, + listener: httpsListener, + clientIP: "127.0.0.1", }, "proxy_listener": { listener: proxyListener, @@ -39,9 +34,8 @@ var ratelimitedListeners = map[string]struct { clientIP: "172.16.123.1", }, "proxyv2_listener": { - listener: httpsProxyv2Listener, - clientIP: "10.1.1.1", - includeWaitRequest: true, + listener: httpsProxyv2Listener, + clientIP: "10.1.1.1", }, } @@ -57,10 +51,6 @@ func TestIPRateLimits(t *testing.T) { withExtraArgument("rate-limit-source-ip-burst", fmt.Sprint(rateLimit)), ) - if tc.includeWaitRequest { - rateLimit-- // we've already used one of requests while checking if server is up - } - for i := 0; i < 10; i++ { rsp, err := GetPageFromListenerWithHeaders(t, tc.listener, "group.gitlab-example.com", "project/", tc.header) require.NoError(t, err) diff --git a/test/acceptance/redirects_test.go b/test/acceptance/redirects_test.go index 5846d2cd..a2bdde53 100644 --- a/test/acceptance/redirects_test.go +++ b/test/acceptance/redirects_test.go @@ -13,9 +13,10 @@ import ( ) func TestRedirectStatusPage(t *testing.T) { + t.Setenv(feature.RedirectsPlaceholders.EnvVariable, "true") + RunPagesProcess(t, withListeners([]ListenSpec{httpListener}), - withEnv([]string{feature.RedirectsPlaceholders.EnvVariable + "=true"}), ) rsp, err := GetPageFromListener(t, httpListener, "group.redirects.gitlab-example.com", "/project-redirects/_redirects") @@ -30,9 +31,10 @@ func TestRedirectStatusPage(t *testing.T) { } func TestRedirect(t *testing.T) { + t.Setenv(feature.RedirectsPlaceholders.EnvVariable, "true") + RunPagesProcess(t, withListeners([]ListenSpec{httpListener}), - withEnv([]string{feature.RedirectsPlaceholders.EnvVariable + "=true"}), ) // Test that serving a file still works with redirects enabled diff --git a/test/acceptance/rewrites_test.go b/test/acceptance/rewrites_test.go index eefb1e82..cb11c470 100644 --- a/test/acceptance/rewrites_test.go +++ b/test/acceptance/rewrites_test.go @@ -12,9 +12,10 @@ import ( ) func TestRewrites(t *testing.T) { + t.Setenv(feature.RedirectsPlaceholders.EnvVariable, "true") + RunPagesProcess(t, withListeners([]ListenSpec{httpListener}), - withEnv([]string{feature.RedirectsPlaceholders.EnvVariable + "=true"}), ) tests := map[string]struct { diff --git a/test/acceptance/serving_test.go b/test/acceptance/serving_test.go index 410e5ca0..00289c58 100644 --- a/test/acceptance/serving_test.go +++ b/test/acceptance/serving_test.go @@ -486,11 +486,14 @@ func TestServerRepliesWithHeaders(t *testing.T) { } for name, test := range tests { - testFn := func(envArgs, headerArgs []string) func(*testing.T) { + testFn := func(headerEnv string, headerArgs []string) func(*testing.T) { return func(t *testing.T) { + if headerEnv != "" { + t.Setenv("HEADER", headerEnv) + } + RunPagesProcess(t, withListeners([]ListenSpec{httpListener}), - withEnv(envArgs), withArguments(headerArgs), ) @@ -509,7 +512,7 @@ func TestServerRepliesWithHeaders(t *testing.T) { t.Run(name+"/from_single_flag", func(t *testing.T) { args := []string{"-header", strings.Join(test.flags, ";;")} - testFn([]string{}, args) + testFn("", args) }) t.Run(name+"/from_multiple_flags", func(t *testing.T) { @@ -518,18 +521,17 @@ func TestServerRepliesWithHeaders(t *testing.T) { args = append(args, "-header", arg) } - testFn([]string{}, args) + testFn("", args) }) t.Run(name+"/from_config_file", func(t *testing.T) { file := newConfigFile(t, "-header="+strings.Join(test.flags, ";;")) - testFn([]string{}, []string{"-config", file}) + testFn("", []string{"-config", file}) }) t.Run(name+"/from_env", func(t *testing.T) { - args := []string{"header", strings.Join(test.flags, ";;")} - testFn(args, []string{}) + testFn(strings.Join(test.flags, ";;"), []string{}) }) } } diff --git a/test/acceptance/status_test.go b/test/acceptance/status_test.go index c48aaff7..a57ae043 100644 --- a/test/acceptance/status_test.go +++ b/test/acceptance/status_test.go @@ -12,10 +12,9 @@ import ( func TestStatusPage(t *testing.T) { RunPagesProcess(t, withListeners([]ListenSpec{httpListener}), - withExtraArgument("pages-status", "/@statuscheck"), ) - rsp, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "@statuscheck") + rsp, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "@healthcheck") require.NoError(t, err) testhelpers.Close(t, rsp.Body) require.Equal(t, http.StatusOK, rsp.StatusCode) diff --git a/test/acceptance/stub_test.go b/test/acceptance/stub_test.go index 3d54b4d4..3ab6def7 100644 --- a/test/acceptance/stub_test.go +++ b/test/acceptance/stub_test.go @@ -15,7 +15,6 @@ var defaultProcessConfig = processConfig{ wait: true, pagesBinary: *pagesBinary, listeners: supportedListeners(), - envs: []string{}, extraArgs: []string{}, gitlabStubOpts: []gitlabstub.Option{}, } @@ -24,9 +23,9 @@ type processConfig struct { wait bool pagesBinary string listeners []ListenSpec - envs []string extraArgs []string gitlabStubOpts []gitlabstub.Option + publicServer bool } type processOption func(*processConfig) @@ -41,23 +40,22 @@ func withListeners(listeners []ListenSpec) processOption { } } -func withEnv(envs []string) processOption { - return func(config *processConfig) { - config.envs = append(config.envs, envs...) - } -} - func withExtraArgument(key, value string) processOption { return func(config *processConfig) { config.extraArgs = append(config.extraArgs, fmt.Sprintf("-%s=%s", key, value)) } } + func withArguments(args []string) processOption { return func(config *processConfig) { config.extraArgs = append(config.extraArgs, args...) } } +func withPublicServer(config *processConfig) { + config.publicServer = true +} + func withStubOptions(opts ...gitlabstub.Option) processOption { return func(config *processConfig) { config.gitlabStubOpts = opts diff --git a/test/gitlabstub/cmd/server/main.go b/test/gitlabstub/cmd/server/main.go index 3e33daaa..9820b722 100644 --- a/test/gitlabstub/cmd/server/main.go +++ b/test/gitlabstub/cmd/server/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "crypto/tls" "errors" "flag" "log" @@ -16,11 +17,25 @@ import ( var ( pagesRoot = flag.String("pages-root", "shared/pages", "The directory where pages are stored") + keyFile = flag.String("key-file", "", "Path to file certificate") + certFile = flag.String("cert-file", "", "Path to file certificate") ) func main() { flag.Parse() + var opts []gitlabstub.Option + + if *keyFile != "" && *certFile != "" { + log.Printf("Loading key pair: (%s) - (%s)", *certFile, *keyFile) + cert, err := tls.LoadX509KeyPair(*certFile, *keyFile) + if err != nil { + log.Fatalf("error loading certificate: %v", err) + } + + opts = append(opts, gitlabstub.WithCertificate(cert)) + } + if err := os.Chdir(*pagesRoot); err != nil { log.Fatalf("error chdir in %s: %v", *pagesRoot, err) } @@ -30,12 +45,18 @@ func main() { log.Fatalf("error getting current dir: %v", err) } - server, err := gitlabstub.NewUnstartedServer(gitlabstub.WithPagesRoot(wd)) + opts = append(opts, gitlabstub.WithPagesRoot(wd)) + + server, err := gitlabstub.NewUnstartedServer(opts...) if err != nil { log.Fatalf("error starting the server: %v", err) } - server.Start() + if server.TLS != nil { + server.StartTLS() + } else { + server.Start() + } log.Printf("listening on %s\n", server.URL) diff --git a/test/gitlabstub/option.go b/test/gitlabstub/option.go index d55abec2..366aeb5d 100644 --- a/test/gitlabstub/option.go +++ b/test/gitlabstub/option.go @@ -1,6 +1,7 @@ package gitlabstub import ( + "crypto/tls" "net/http" "time" ) @@ -9,10 +10,17 @@ type config struct { pagesHandler http.HandlerFunc pagesRoot string delay time.Duration + tlsConfig *tls.Config } type Option func(*config) +func defaultTLSConfig() *tls.Config { + return &tls.Config{ + MinVersion: tls.VersionTLS12, + } +} + func WithPagesHandler(ph http.HandlerFunc) Option { return func(sc *config) { sc.pagesHandler = ph @@ -30,3 +38,12 @@ func WithDelay(delay time.Duration) Option { sc.delay = delay } } + +func WithCertificate(cert tls.Certificate) Option { + return func(c *config) { + if c.tlsConfig == nil { + c.tlsConfig = defaultTLSConfig() + } + c.tlsConfig.Certificates = append(c.tlsConfig.Certificates, cert) + } +} diff --git a/test/gitlabstub/server.go b/test/gitlabstub/server.go index 5cf3dacf..74c75067 100644 --- a/test/gitlabstub/server.go +++ b/test/gitlabstub/server.go @@ -39,5 +39,8 @@ func NewUnstartedServer(opts ...Option) (*httptest.Server, error) { router.PathPrefix("/").HandlerFunc(handleAccessControlArtifactRequests) - return httptest.NewUnstartedServer(router), nil + s := httptest.NewUnstartedServer(router) + s.TLS = conf.tlsConfig + + return s, nil } |