From e59ec12bcc06b10860c811bc3810a6320b9855d2 Mon Sep 17 00:00:00 2001 From: Jaime Martinez Date: Mon, 16 Nov 2020 13:56:22 +1100 Subject: Move headers to middleware --- app.go | 6 +- internal/config/headers.go | 32 --------- internal/config/headers_test.go | 126 ------------------------------------ internal/middleware/headers.go | 31 +++++++++ internal/middleware/headers_test.go | 126 ++++++++++++++++++++++++++++++++++++ 5 files changed, 160 insertions(+), 161 deletions(-) delete mode 100644 internal/config/headers.go delete mode 100644 internal/config/headers_test.go create mode 100644 internal/middleware/headers.go create mode 100644 internal/middleware/headers_test.go diff --git a/app.go b/app.go index ca495073..a802f96d 100644 --- a/app.go +++ b/app.go @@ -21,11 +21,11 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/acme" "gitlab.com/gitlab-org/gitlab-pages/internal/artifact" "gitlab.com/gitlab-org/gitlab-pages/internal/auth" - headerConfig "gitlab.com/gitlab-org/gitlab-pages/internal/config" "gitlab.com/gitlab-org/gitlab-pages/internal/domain" "gitlab.com/gitlab-org/gitlab-pages/internal/handlers" "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/logging" + "gitlab.com/gitlab-org/gitlab-pages/internal/middleware" "gitlab.com/gitlab-org/gitlab-pages/internal/netutil" "gitlab.com/gitlab-org/gitlab-pages/internal/request" "gitlab.com/gitlab-org/gitlab-pages/internal/source" @@ -186,7 +186,7 @@ func (a *theApp) healthCheckMiddleware(handler http.Handler) (http.Handler, erro // customHeadersMiddleware will inject custom headers into the response func (a *theApp) customHeadersMiddleware(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - headerConfig.AddCustomHeaders(w, a.CustomHeaders) + middleware.AddCustomHeaders(w, a.CustomHeaders) handler.ServeHTTP(w, r) }) @@ -493,7 +493,7 @@ func runApp(config appConfig) { } if len(config.CustomHeaders) != 0 { - customHeaders, err := headerConfig.ParseHeaderString(config.CustomHeaders) + customHeaders, err := middleware.ParseHeaderString(config.CustomHeaders) if err != nil { log.WithError(err).Fatal("Unable to parse header string") } diff --git a/internal/config/headers.go b/internal/config/headers.go deleted file mode 100644 index d415b21d..00000000 --- a/internal/config/headers.go +++ /dev/null @@ -1,32 +0,0 @@ -package config - -import ( - "errors" - "net/http" - "strings" -) - -var errInvalidHeaderParameter = errors.New("invalid syntax specified as header parameter") - -// AddCustomHeaders adds a map of Headers to a Response -func AddCustomHeaders(w http.ResponseWriter, headers http.Header) error { - for k, v := range headers { - for _, value := range v { - w.Header().Add(k, value) - } - } - return nil -} - -// ParseHeaderString parses a string of key values into a map -func ParseHeaderString(customHeaders []string) (http.Header, error) { - headers := http.Header{} - for _, keyValueString := range customHeaders { - keyValue := strings.SplitN(keyValueString, ":", 2) - if len(keyValue) != 2 { - return nil, errInvalidHeaderParameter - } - headers[strings.TrimSpace(keyValue[0])] = append(headers[strings.TrimSpace(keyValue[0])], strings.TrimSpace(keyValue[1])) - } - return headers, nil -} diff --git a/internal/config/headers_test.go b/internal/config/headers_test.go deleted file mode 100644 index 44afd470..00000000 --- a/internal/config/headers_test.go +++ /dev/null @@ -1,126 +0,0 @@ -package config - -import ( - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestParseHeaderString(t *testing.T) { - tests := []struct { - name string - headerStrings []string - valid bool - }{{ - name: "Normal case", - headerStrings: []string{"X-Test-String: Test"}, - valid: true, - }, - { - name: "Whitespace trim case", - headerStrings: []string{" X-Test-String : Test "}, - valid: true, - }, - { - name: "Whitespace in key, value case", - headerStrings: []string{"My amazing header: This is a test"}, - valid: true, - }, - { - name: "Non-tracking header case", - headerStrings: []string{"Tk: N"}, - valid: true, - }, - { - name: "Content security header case", - headerStrings: []string{"content-security-policy: default-src 'self'"}, - valid: true, - }, - { - name: "Multiple header strings", - headerStrings: []string{"content-security-policy: default-src 'self'", "X-Test-String: Test", "My amazing header : Amazing"}, - valid: true, - }, - { - name: "Multiple invalid cases", - headerStrings: []string{"content-security-policy: default-src 'self'", "test-case"}, - valid: false, - }, - { - name: "Not valid case", - headerStrings: []string{"Tk= N"}, - valid: false, - }, - { - name: "Not valid case", - headerStrings: []string{"X-Test-String Some-Test"}, - valid: false, - }, - { - name: "Valid and not valid case", - headerStrings: []string{"content-security-policy: default-src 'self'", "test-case"}, - valid: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := ParseHeaderString(tt.headerStrings) - if tt.valid { - require.NoError(t, err) - } else { - require.Error(t, err) - } - }) - } -} - -func TestAddCustomHeaders(t *testing.T) { - tests := []struct { - name string - headerStrings []string - wantHeaders map[string]string - }{{ - name: "Normal case", - headerStrings: []string{"X-Test-String: Test"}, - wantHeaders: map[string]string{"X-Test-String": "Test"}, - }, - { - name: "Whitespace trim case", - headerStrings: []string{" X-Test-String : Test "}, - wantHeaders: map[string]string{"X-Test-String": "Test"}, - }, - { - name: "Whitespace in key, value case", - headerStrings: []string{"My amazing header: This is a test"}, - wantHeaders: map[string]string{"My amazing header": "This is a test"}, - }, - { - name: "Non-tracking header case", - headerStrings: []string{"Tk: N"}, - wantHeaders: map[string]string{"Tk": "N"}, - }, - { - name: "Content security header case", - headerStrings: []string{"content-security-policy: default-src 'self'"}, - wantHeaders: map[string]string{"content-security-policy": "default-src 'self'"}, - }, - { - name: "Multiple header strings", - headerStrings: []string{"content-security-policy: default-src 'self'", "X-Test-String: Test", "My amazing header : Amazing"}, - wantHeaders: map[string]string{"content-security-policy": "default-src 'self'", "X-Test-String": "Test", "My amazing header": "Amazing"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - headers, _ := ParseHeaderString(tt.headerStrings) - w := httptest.NewRecorder() - AddCustomHeaders(w, headers) - for k, v := range tt.wantHeaders { - require.Equal(t, v, w.HeaderMap.Get(k), "Expected header %+v, got %+v", v, w.HeaderMap.Get(k)) - } - }) - } -} diff --git a/internal/middleware/headers.go b/internal/middleware/headers.go new file mode 100644 index 00000000..77b008f3 --- /dev/null +++ b/internal/middleware/headers.go @@ -0,0 +1,31 @@ +package middleware + +import ( + "errors" + "net/http" + "strings" +) + +var errInvalidHeaderParameter = errors.New("invalid syntax specified as header parameter") + +// AddCustomHeaders adds a map of Headers to a Response +func AddCustomHeaders(w http.ResponseWriter, headers http.Header) { + for k, v := range headers { + for _, value := range v { + w.Header().Add(k, value) + } + } +} + +// ParseHeaderString parses a string of key values into a map +func ParseHeaderString(customHeaders []string) (http.Header, error) { + headers := http.Header{} + for _, keyValueString := range customHeaders { + keyValue := strings.SplitN(keyValueString, ":", 2) + if len(keyValue) != 2 { + return nil, errInvalidHeaderParameter + } + headers[strings.TrimSpace(keyValue[0])] = append(headers[strings.TrimSpace(keyValue[0])], strings.TrimSpace(keyValue[1])) + } + return headers, nil +} diff --git a/internal/middleware/headers_test.go b/internal/middleware/headers_test.go new file mode 100644 index 00000000..17d31b50 --- /dev/null +++ b/internal/middleware/headers_test.go @@ -0,0 +1,126 @@ +package middleware + +import ( + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseHeaderString(t *testing.T) { + tests := []struct { + name string + headerStrings []string + valid bool + }{{ + name: "Normal case", + headerStrings: []string{"X-Test-String: Test"}, + valid: true, + }, + { + name: "Whitespace trim case", + headerStrings: []string{" X-Test-String : Test "}, + valid: true, + }, + { + name: "Whitespace in key, value case", + headerStrings: []string{"My amazing header: This is a test"}, + valid: true, + }, + { + name: "Non-tracking header case", + headerStrings: []string{"Tk: N"}, + valid: true, + }, + { + name: "Content security header case", + headerStrings: []string{"content-security-policy: default-src 'self'"}, + valid: true, + }, + { + name: "Multiple header strings", + headerStrings: []string{"content-security-policy: default-src 'self'", "X-Test-String: Test", "My amazing header : Amazing"}, + valid: true, + }, + { + name: "Multiple invalid cases", + headerStrings: []string{"content-security-policy: default-src 'self'", "test-case"}, + valid: false, + }, + { + name: "Not valid case", + headerStrings: []string{"Tk= N"}, + valid: false, + }, + { + name: "Not valid case", + headerStrings: []string{"X-Test-String Some-Test"}, + valid: false, + }, + { + name: "Valid and not valid case", + headerStrings: []string{"content-security-policy: default-src 'self'", "test-case"}, + valid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := ParseHeaderString(tt.headerStrings) + if tt.valid { + require.NoError(t, err) + } else { + require.Error(t, err) + } + }) + } +} + +func TestAddCustomHeaders(t *testing.T) { + tests := []struct { + name string + headerStrings []string + wantHeaders map[string]string + }{{ + name: "Normal case", + headerStrings: []string{"X-Test-String: Test"}, + wantHeaders: map[string]string{"X-Test-String": "Test"}, + }, + { + name: "Whitespace trim case", + headerStrings: []string{" X-Test-String : Test "}, + wantHeaders: map[string]string{"X-Test-String": "Test"}, + }, + { + name: "Whitespace in key, value case", + headerStrings: []string{"My amazing header: This is a test"}, + wantHeaders: map[string]string{"My amazing header": "This is a test"}, + }, + { + name: "Non-tracking header case", + headerStrings: []string{"Tk: N"}, + wantHeaders: map[string]string{"Tk": "N"}, + }, + { + name: "Content security header case", + headerStrings: []string{"content-security-policy: default-src 'self'"}, + wantHeaders: map[string]string{"content-security-policy": "default-src 'self'"}, + }, + { + name: "Multiple header strings", + headerStrings: []string{"content-security-policy: default-src 'self'", "X-Test-String: Test", "My amazing header : Amazing"}, + wantHeaders: map[string]string{"content-security-policy": "default-src 'self'", "X-Test-String": "Test", "My amazing header": "Amazing"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + headers, _ := ParseHeaderString(tt.headerStrings) + w := httptest.NewRecorder() + AddCustomHeaders(w, headers) + for k, v := range tt.wantHeaders { + require.Equal(t, v, w.HeaderMap.Get(k), "Expected header %+v, got %+v", v, w.HeaderMap.Get(k)) + } + }) + } +} -- cgit v1.2.3