diff options
-rw-r--r-- | README.md | 11 | ||||
-rw-r--r-- | acceptance_test.go | 13 | ||||
-rw-r--r-- | app.go | 13 | ||||
-rw-r--r-- | app_config.go | 12 | ||||
-rw-r--r-- | internal/config/config.go | 32 | ||||
-rw-r--r-- | internal/config/config_test.go | 126 | ||||
-rw-r--r-- | main.go | 47 |
7 files changed, 226 insertions, 28 deletions
@@ -234,6 +234,17 @@ values are `ssl3`, `tls1.0`, `tls1.1`, `tls1.2`, and `tls1.3` (if supported). Wh is used GitLab Pages will add `tls13=1` to `GODEBUG` to enable TLS 1.3. See https://golang.org/src/crypto/tls/tls.go for more. +### Custom headers + +To specify custom headers that should be send with every request on GitLab pages use the `-header` argument. + +You can add as many headers as you like. + +Example: +```sh +./gitlab-pages -header "Content-Security-Policy: default-src 'self' *.example.com" -header "X-Test: Testing" ... +``` + ### Configuration The daemon can be configured with any combination of these methods: diff --git a/acceptance_test.go b/acceptance_test.go index eaa32318..28ebdeb8 100644 --- a/acceptance_test.go +++ b/acceptance_test.go @@ -241,6 +241,19 @@ func TestCORSForbidsPOST(t *testing.T) { } } +func TestCustomHeaders(t *testing.T) { + teardown := RunPagesProcess(t, *pagesBinary, listeners, "", "-header", "X-Test1:Testing1", "-header", "X-Test2:Testing2") + defer teardown() + + for _, spec := range listeners { + rsp, err := GetPageFromListener(t, spec, "group.gitlab-example.com:", "project/") + require.NoError(t, err) + assert.Equal(t, http.StatusOK, rsp.StatusCode) + assert.Equal(t, "Testing1", rsp.Header.Get("X-Test1")) + assert.Equal(t, "Testing2", rsp.Header.Get("X-Test2")) + } +} + func doCrossOriginRequest(t *testing.T, method, reqMethod, url string) *http.Response { req, err := http.NewRequest(method, url, nil) require.NoError(t, err) @@ -23,6 +23,7 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/admin" "gitlab.com/gitlab-org/gitlab-pages/internal/artifact" "gitlab.com/gitlab-org/gitlab-pages/internal/auth" + headerConfig "gitlab.com/gitlab-org/gitlab-pages/internal/config" "gitlab.com/gitlab-org/gitlab-pages/internal/domain" "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/netutil" @@ -51,6 +52,7 @@ type theApp struct { Artifact *artifact.Artifact Auth *auth.Auth AcmeMiddleware *acme.Middleware + CustomHeaders http.Header } func (a *theApp) isReady() bool { @@ -229,6 +231,8 @@ func (a *theApp) serveFileOrNotFound(domain *domain.D) http.HandlerFunc { func (a *theApp) ServeHTTP(ww http.ResponseWriter, r *http.Request) { https := r.TLS != nil + headerConfig.AddCustomHeaders(ww, a.CustomHeaders) + a.serveContent(ww, r, https) } @@ -239,6 +243,7 @@ func (a *theApp) ServeProxy(ww http.ResponseWriter, r *http.Request) { if forwardedHost := r.Header.Get(xForwardedHost); forwardedHost != "" { r.Host = forwardedHost } + headerConfig.AddCustomHeaders(ww, a.CustomHeaders) a.serveContent(ww, r, https) } @@ -380,6 +385,14 @@ func runApp(config appConfig) { a.AcmeMiddleware = &acme.Middleware{GitlabURL: config.GitLabServer} } + if len(config.CustomHeaders) != 0 { + customHeaders, err := headerConfig.ParseHeaderString(config.CustomHeaders) + if err != nil { + log.Fatal(err) + } + a.CustomHeaders = customHeaders + } + configureLogging(config.LogFormat, config.LogVerbose) if err := mimedb.LoadTypes(); err != nil { diff --git a/app_config.go b/app_config.go index 8b9eb303..00469db6 100644 --- a/app_config.go +++ b/app_config.go @@ -30,12 +30,12 @@ type appConfig struct { LogFormat string LogVerbose bool - StoreSecret string - GitLabServer string - ClientID string - ClientSecret string - RedirectURI string - + StoreSecret string + GitLabServer string + ClientID string + ClientSecret string + RedirectURI string SentryDSN string SentryEnvironment string + CustomHeaders []string } diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 00000000..d415b21d --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,32 @@ +package config + +import ( + "errors" + "net/http" + "strings" +) + +var errInvalidHeaderParameter = errors.New("invalid syntax specified as header parameter") + +// AddCustomHeaders adds a map of Headers to a Response +func AddCustomHeaders(w http.ResponseWriter, headers http.Header) error { + for k, v := range headers { + for _, value := range v { + w.Header().Add(k, value) + } + } + return nil +} + +// ParseHeaderString parses a string of key values into a map +func ParseHeaderString(customHeaders []string) (http.Header, error) { + headers := http.Header{} + for _, keyValueString := range customHeaders { + keyValue := strings.SplitN(keyValueString, ":", 2) + if len(keyValue) != 2 { + return nil, errInvalidHeaderParameter + } + headers[strings.TrimSpace(keyValue[0])] = append(headers[strings.TrimSpace(keyValue[0])], strings.TrimSpace(keyValue[1])) + } + return headers, nil +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 00000000..44afd470 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,126 @@ +package config + +import ( + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseHeaderString(t *testing.T) { + tests := []struct { + name string + headerStrings []string + valid bool + }{{ + name: "Normal case", + headerStrings: []string{"X-Test-String: Test"}, + valid: true, + }, + { + name: "Whitespace trim case", + headerStrings: []string{" X-Test-String : Test "}, + valid: true, + }, + { + name: "Whitespace in key, value case", + headerStrings: []string{"My amazing header: This is a test"}, + valid: true, + }, + { + name: "Non-tracking header case", + headerStrings: []string{"Tk: N"}, + valid: true, + }, + { + name: "Content security header case", + headerStrings: []string{"content-security-policy: default-src 'self'"}, + valid: true, + }, + { + name: "Multiple header strings", + headerStrings: []string{"content-security-policy: default-src 'self'", "X-Test-String: Test", "My amazing header : Amazing"}, + valid: true, + }, + { + name: "Multiple invalid cases", + headerStrings: []string{"content-security-policy: default-src 'self'", "test-case"}, + valid: false, + }, + { + name: "Not valid case", + headerStrings: []string{"Tk= N"}, + valid: false, + }, + { + name: "Not valid case", + headerStrings: []string{"X-Test-String Some-Test"}, + valid: false, + }, + { + name: "Valid and not valid case", + headerStrings: []string{"content-security-policy: default-src 'self'", "test-case"}, + valid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := ParseHeaderString(tt.headerStrings) + if tt.valid { + require.NoError(t, err) + } else { + require.Error(t, err) + } + }) + } +} + +func TestAddCustomHeaders(t *testing.T) { + tests := []struct { + name string + headerStrings []string + wantHeaders map[string]string + }{{ + name: "Normal case", + headerStrings: []string{"X-Test-String: Test"}, + wantHeaders: map[string]string{"X-Test-String": "Test"}, + }, + { + name: "Whitespace trim case", + headerStrings: []string{" X-Test-String : Test "}, + wantHeaders: map[string]string{"X-Test-String": "Test"}, + }, + { + name: "Whitespace in key, value case", + headerStrings: []string{"My amazing header: This is a test"}, + wantHeaders: map[string]string{"My amazing header": "This is a test"}, + }, + { + name: "Non-tracking header case", + headerStrings: []string{"Tk: N"}, + wantHeaders: map[string]string{"Tk": "N"}, + }, + { + name: "Content security header case", + headerStrings: []string{"content-security-policy: default-src 'self'"}, + wantHeaders: map[string]string{"content-security-policy": "default-src 'self'"}, + }, + { + name: "Multiple header strings", + headerStrings: []string{"content-security-policy: default-src 'self'", "X-Test-String: Test", "My amazing header : Amazing"}, + wantHeaders: map[string]string{"content-security-policy": "default-src 'self'", "X-Test-String": "Test", "My amazing header": "Amazing"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + headers, _ := ParseHeaderString(tt.headerStrings) + w := httptest.NewRecorder() + AddCustomHeaders(w, headers) + for k, v := range tt.wantHeaders { + require.Equal(t, v, w.HeaderMap.Get(k), "Expected header %+v, got %+v", v, w.HeaderMap.Get(k)) + } + }) + } +} @@ -26,6 +26,7 @@ func init() { flag.Var(&listenHTTP, "listen-http", "The address(es) to listen on for HTTP requests") flag.Var(&listenHTTPS, "listen-https", "The address(es) to listen on for HTTPS requests") flag.Var(&listenProxy, "listen-proxy", "The address(es) to listen on for proxy requests") + flag.Var(&header, "header", "The additional http header(s) that should be send to the client") } var ( @@ -68,6 +69,8 @@ var ( listenHTTP MultiStringFlag listenHTTPS MultiStringFlag listenProxy MultiStringFlag + + header MultiStringFlag ) var ( @@ -95,6 +98,26 @@ func gitlabServerFromFlags() string { return host.FromString(url.Host) } +func setArtifactsServer(artifactsServer string, artifactsServerTimeout int, config *appConfig) { + u, err := url.Parse(artifactsServer) + if err != nil { + log.Fatal(err) + } + // url.Parse ensures that the Scheme arttribute is always lower case. + if u.Scheme != "http" && u.Scheme != "https" { + errortracking.Capture(err) + log.Fatal(errArtifactSchemaUnsupported) + } + + if artifactsServerTimeout < 1 { + errortracking.Capture(err) + log.Fatal(errArtifactsServerTimeoutValue) + } + + config.ArtifactsServerTimeout = artifactsServerTimeout + config.ArtifactsServer = artifactsServer +} + func configFromFlags() appConfig { var config appConfig @@ -110,6 +133,7 @@ func configFromFlags() appConfig { // tlsMinVersion and tlsMaxVersion are validated in appMain config.TLSMinVersion = tlsconfig.AllTLSVersions[*tlsMinVersion] config.TLSMaxVersion = tlsconfig.AllTLSVersions[*tlsMaxVersion] + config.CustomHeaders = header for _, file := range []struct { contents *[]byte @@ -126,29 +150,8 @@ func configFromFlags() appConfig { } } - if *artifactsServerTimeout < 1 { - errortracking.Capture(errArtifactsServerTimeoutValue) - log.Fatal(errArtifactsServerTimeoutValue) - } - if *artifactsServer != "" { - u, err := url.Parse(*artifactsServer) - if err != nil { - log.Fatal(err) - } - // url.Parse ensures that the Scheme arttribute is always lower case. - if u.Scheme != "http" && u.Scheme != "https" { - errortracking.Capture(err) - log.Fatal(errArtifactSchemaUnsupported) - } - - if *artifactsServerTimeout < 1 { - errortracking.Capture(err) - log.Fatal(errArtifactsServerTimeoutValue) - } - - config.ArtifactsServerTimeout = *artifactsServerTimeout - config.ArtifactsServer = *artifactsServer + setArtifactsServer(*artifactsServer, *artifactsServerTimeout, &config) } config.GitLabServer = gitlabServerFromFlags() |