From b78029c60fe8138fd49aa560b9c4e26279049c5f Mon Sep 17 00:00:00 2001 From: feistel <6742251-feistel@users.noreply.gitlab.com> Date: Sat, 25 Jun 2022 14:06:59 +0200 Subject: Move custom headers parsing into config loading --- internal/config/config.go | 49 +++++++++++- internal/config/config_test.go | 78 ++++++++++++++++++ internal/customheaders/customheaders.go | 42 ---------- internal/customheaders/customheaders_test.go | 113 ++++----------------------- internal/customheaders/middleware.go | 4 + 5 files changed, 144 insertions(+), 142 deletions(-) (limited to 'internal') diff --git a/internal/config/config.go b/internal/config/config.go index 18fe33eb..27581d83 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,14 +1,18 @@ package config import ( + "bufio" "crypto/tls" "encoding/base64" "errors" "fmt" + "net/http" + "net/textproto" "os" "strings" "time" + "github.com/hashicorp/go-multierror" "github.com/namsral/flag" "gitlab.com/gitlab-org/labkit/log" ) @@ -55,7 +59,7 @@ type General struct { ShowVersion bool - CustomHeaders []string + CustomHeaders http.Header } // RateLimit config struct @@ -154,8 +158,10 @@ type Metrics struct { } var ( - errMetricsNoCertificate = errors.New("metrics certificate path must not be empty") - errMetricsNoKey = errors.New("metrics private key path must not be empty") + errDuplicateHeader = errors.New("duplicate header") + errInvalidHeaderParameter = errors.New("invalid syntax specified as header parameter") + errMetricsNoCertificate = errors.New("metrics certificate path must not be empty") + errMetricsNoKey = errors.New("metrics private key path must not be empty") ) func internalGitlabServerFromFlags() string { @@ -231,6 +237,35 @@ func loadMetricsConfig() (metrics Metrics, err error) { return metrics, nil } +func parseHeaderString(customHeaders []string) (http.Header, error) { + headers := make(http.Header, len(customHeaders)) + + var result *multierror.Error + for _, h := range customHeaders { + h = h + "\n\n" + tp := textproto.NewReader(bufio.NewReader(strings.NewReader(h))) + + mimeHeader, err := tp.ReadMIMEHeader() + if err != nil { + result = multierror.Append(result, fmt.Errorf("parsing error %s: %w", h, errInvalidHeaderParameter)) + } + + for key, value := range mimeHeader { + if _, ok := headers[key]; ok { + result = multierror.Append(result, fmt.Errorf("%s already specified with value '%s': %w", key, value, errDuplicateHeader)) + } + + headers[key] = value + } + } + + if result.ErrorOrNil() != nil { + return nil, result + } + + return headers, nil +} + func loadConfig() (*Config, error) { config := &Config{ General: General{ @@ -244,7 +279,6 @@ func loadConfig() (*Config, error) { DisableCrossOriginRequests: *disableCrossOriginRequests, InsecureCiphers: *insecureCiphers, PropagateCorrelationID: *propagateCorrelationID, - CustomHeaders: header.Split(), ShowVersion: *showVersion, }, RateLimit: RateLimit{ @@ -340,6 +374,13 @@ func loadConfig() (*Config, error) { } } + customHeaders, err := parseHeaderString(header.Split()) + if err != nil { + return nil, fmt.Errorf("unable to parse header string: %w", err) + } + + config.General.CustomHeaders = customHeaders + // Populating remaining GitLab settings config.GitLab.PublicServer = *publicGitLabServer diff --git a/internal/config/config_test.go b/internal/config/config_test.go index f120d17d..19dfb957 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -83,3 +83,81 @@ func setupHTTPSFixture(t *testing.T) (dir string, key string, cert string) { return tmpDir, keyfile.Name(), certfile.Name() } + +func TestParseHeaderString(t *testing.T) { + tests := []struct { + name string + headerStrings []string + valid bool + expectedLen int + }{ + { + name: "Normal case", + headerStrings: []string{"X-Test-String: Test"}, + valid: true, + expectedLen: 1, + }, + { + name: "Non-tracking header case", + headerStrings: []string{"Tk: N"}, + valid: true, + expectedLen: 1, + }, + { + name: "Content security header case", + headerStrings: []string{"content-security-policy: default-src 'self'"}, + valid: true, + expectedLen: 1, + }, + { + name: "Multiple header strings", + headerStrings: []string{"content-security-policy: default-src 'self'", "X-Test-String: Test", "My amazing header : Amazing"}, + valid: true, + expectedLen: 3, + }, + { + name: "Multiple invalid cases", + headerStrings: []string{"content-security-policy: default-src 'self'", "test-case"}, + valid: false, + }, + { + name: "Not valid case", + headerStrings: []string{"Tk= N"}, + valid: false, + }, + { + name: "duplicate headers", + headerStrings: []string{"Tk: N", "Tk: M"}, + valid: false, + }, + { + name: "Not valid case", + headerStrings: []string{"X-Test-String Some-Test"}, + valid: false, + }, + { + name: "Valid and not valid case", + headerStrings: []string{"content-security-policy: default-src 'self'", "test-case"}, + valid: false, + }, + { + name: "Multiple headers in single string parsed as one header", + headerStrings: []string{"content-security-policy: default-src 'self',X-Test-String: Test,My amazing header : Amazing"}, + valid: true, + expectedLen: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseHeaderString(tt.headerStrings) + if tt.valid { + require.NoError(t, err) + require.Len(t, got, tt.expectedLen) + return + } + + require.Error(t, err) + }) + } +} diff --git a/internal/customheaders/customheaders.go b/internal/customheaders/customheaders.go index 92f50069..e7f8cb91 100644 --- a/internal/customheaders/customheaders.go +++ b/internal/customheaders/customheaders.go @@ -1,19 +1,7 @@ package customheaders import ( - "bufio" - "errors" - "fmt" "net/http" - "net/textproto" - "strings" - - "github.com/hashicorp/go-multierror" -) - -var ( - errInvalidHeaderParameter = errors.New("invalid syntax specified as header parameter") - errDuplicateHeader = errors.New("duplicate header") ) // AddCustomHeaders adds a map of Headers to a Response @@ -24,33 +12,3 @@ func AddCustomHeaders(w http.ResponseWriter, headers http.Header) { } } } - -// ParseHeaderString parses a string of key values into a map -func ParseHeaderString(customHeaders []string) (http.Header, error) { - headers := make(http.Header, len(customHeaders)) - - var result *multierror.Error - for _, h := range customHeaders { - h = h + "\n\n" - tp := textproto.NewReader(bufio.NewReader(strings.NewReader(h))) - - mimeHeader, err := tp.ReadMIMEHeader() - if err != nil { - result = multierror.Append(result, fmt.Errorf("parsing error %s: %w", h, errInvalidHeaderParameter)) - } - - for key, value := range mimeHeader { - if _, ok := headers[key]; ok { - result = multierror.Append(result, fmt.Errorf("%s already specified with value '%s': %w", key, value, errDuplicateHeader)) - } - - headers[key] = value - } - } - - if result.ErrorOrNil() != nil { - return nil, result - } - - return headers, nil -} diff --git a/internal/customheaders/customheaders_test.go b/internal/customheaders/customheaders_test.go index 857c45e0..f0252137 100644 --- a/internal/customheaders/customheaders_test.go +++ b/internal/customheaders/customheaders_test.go @@ -1,6 +1,7 @@ package customheaders_test import ( + "net/http" "net/http/httptest" "testing" @@ -9,118 +10,38 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/customheaders" ) -func TestParseHeaderString(t *testing.T) { - tests := []struct { - name string - headerStrings []string - valid bool - expectedLen int - }{ - { - name: "Normal case", - headerStrings: []string{"X-Test-String: Test"}, - valid: true, - expectedLen: 1, - }, - { - name: "Non-tracking header case", - headerStrings: []string{"Tk: N"}, - valid: true, - expectedLen: 1, - }, - { - name: "Content security header case", - headerStrings: []string{"content-security-policy: default-src 'self'"}, - valid: true, - expectedLen: 1, - }, - { - name: "Multiple header strings", - headerStrings: []string{"content-security-policy: default-src 'self'", "X-Test-String: Test", "My amazing header : Amazing"}, - valid: true, - expectedLen: 3, - }, - { - name: "Multiple invalid cases", - headerStrings: []string{"content-security-policy: default-src 'self'", "test-case"}, - valid: false, - }, - { - name: "Not valid case", - headerStrings: []string{"Tk= N"}, - valid: false, - }, - { - name: "duplicate headers", - headerStrings: []string{"Tk: N", "Tk: M"}, - valid: false, - }, - { - name: "Not valid case", - headerStrings: []string{"X-Test-String Some-Test"}, - valid: false, - }, - { - name: "Valid and not valid case", - headerStrings: []string{"content-security-policy: default-src 'self'", "test-case"}, - valid: false, - }, - { - name: "Multiple headers in single string parsed as one header", - headerStrings: []string{"content-security-policy: default-src 'self',X-Test-String: Test,My amazing header : Amazing"}, - valid: true, - expectedLen: 1, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := customheaders.ParseHeaderString(tt.headerStrings) - if tt.valid { - require.NoError(t, err) - require.Len(t, got, tt.expectedLen) - return - } - - require.Error(t, err) - }) - } -} - func TestAddCustomHeaders(t *testing.T) { tests := []struct { - name string - headerStrings []string - wantHeaders map[string]string + name string + headers http.Header + wantHeaders map[string]string }{ { - name: "Normal case", - headerStrings: []string{"X-Test-String: Test"}, - wantHeaders: map[string]string{"X-Test-String": "Test"}, + name: "Normal case", + headers: http.Header{"X-Test-String": []string{"Test"}}, + wantHeaders: map[string]string{"X-Test-String": "Test"}, }, { - name: "Non-tracking header case", - headerStrings: []string{"Tk: N"}, - wantHeaders: map[string]string{"Tk": "N"}, + name: "Non-tracking header case", + headers: http.Header{"Tk": []string{"N"}}, + wantHeaders: map[string]string{"Tk": "N"}, }, { - name: "Content security header case", - headerStrings: []string{"content-security-policy: default-src 'self'"}, - wantHeaders: map[string]string{"Content-Security-Policy": "default-src 'self'"}, + name: "Content security header case", + headers: http.Header{"content-security-policy": []string{"default-src 'self'"}}, + wantHeaders: map[string]string{"Content-Security-Policy": "default-src 'self'"}, }, { - name: "Multiple header strings", - headerStrings: []string{"content-security-policy: default-src 'self'", "X-Test-String: Test", "My amazing header: Amazing"}, - wantHeaders: map[string]string{"Content-Security-Policy": "default-src 'self'", "X-Test-String": "Test", "My amazing header": "Amazing"}, + name: "Multiple header strings", + headers: http.Header{"content-security-policy": []string{"default-src 'self'"}, "X-Test-String": []string{"Test"}, "My amazing header": []string{"Amazing"}}, + wantHeaders: map[string]string{"Content-Security-Policy": "default-src 'self'", "X-Test-String": "Test", "My amazing header": "Amazing"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - headers, err := customheaders.ParseHeaderString(tt.headerStrings) - require.NoError(t, err) w := httptest.NewRecorder() - customheaders.AddCustomHeaders(w, headers) + customheaders.AddCustomHeaders(w, tt.headers) rsp := w.Result() for k, v := range tt.wantHeaders { require.Len(t, rsp.Header[k], 1) diff --git a/internal/customheaders/middleware.go b/internal/customheaders/middleware.go index e0e11ad6..964b822a 100644 --- a/internal/customheaders/middleware.go +++ b/internal/customheaders/middleware.go @@ -6,6 +6,10 @@ import ( // NewMiddleware returns middleware which inject custom headers into the response func NewMiddleware(handler http.Handler, headers http.Header) http.Handler { + if len(headers) == 0 { + return handler + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { AddCustomHeaders(w, headers) -- cgit v1.2.3