diff options
author | Max Wittig <max.wittig@siemens.com> | 2019-06-28 13:35:29 +0300 |
---|---|---|
committer | Max Wittig <max.wittig@siemens.com> | 2019-07-12 15:37:16 +0300 |
commit | 5199c4c8b646f3e66b0f03dd51fbaa704d9fd94f (patch) | |
tree | c60baf57319ea68b7b7213c4bba07662ced06d86 /internal/config | |
parent | 7d39822ce2221156c3479ceae0e4f8e24c7373b2 (diff) |
feat: add flag to define custom response headers
Diffstat (limited to 'internal/config')
-rw-r--r-- | internal/config/config.go | 32 | ||||
-rw-r--r-- | internal/config/config_test.go | 126 |
2 files changed, 158 insertions, 0 deletions
diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 00000000..d415b21d --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,32 @@ +package config + +import ( + "errors" + "net/http" + "strings" +) + +var errInvalidHeaderParameter = errors.New("invalid syntax specified as header parameter") + +// AddCustomHeaders adds a map of Headers to a Response +func AddCustomHeaders(w http.ResponseWriter, headers http.Header) error { + for k, v := range headers { + for _, value := range v { + w.Header().Add(k, value) + } + } + return nil +} + +// ParseHeaderString parses a string of key values into a map +func ParseHeaderString(customHeaders []string) (http.Header, error) { + headers := http.Header{} + for _, keyValueString := range customHeaders { + keyValue := strings.SplitN(keyValueString, ":", 2) + if len(keyValue) != 2 { + return nil, errInvalidHeaderParameter + } + headers[strings.TrimSpace(keyValue[0])] = append(headers[strings.TrimSpace(keyValue[0])], strings.TrimSpace(keyValue[1])) + } + return headers, nil +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 00000000..44afd470 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,126 @@ +package config + +import ( + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseHeaderString(t *testing.T) { + tests := []struct { + name string + headerStrings []string + valid bool + }{{ + name: "Normal case", + headerStrings: []string{"X-Test-String: Test"}, + valid: true, + }, + { + name: "Whitespace trim case", + headerStrings: []string{" X-Test-String : Test "}, + valid: true, + }, + { + name: "Whitespace in key, value case", + headerStrings: []string{"My amazing header: This is a test"}, + valid: true, + }, + { + name: "Non-tracking header case", + headerStrings: []string{"Tk: N"}, + valid: true, + }, + { + name: "Content security header case", + headerStrings: []string{"content-security-policy: default-src 'self'"}, + valid: true, + }, + { + name: "Multiple header strings", + headerStrings: []string{"content-security-policy: default-src 'self'", "X-Test-String: Test", "My amazing header : Amazing"}, + valid: true, + }, + { + name: "Multiple invalid cases", + headerStrings: []string{"content-security-policy: default-src 'self'", "test-case"}, + valid: false, + }, + { + name: "Not valid case", + headerStrings: []string{"Tk= N"}, + valid: false, + }, + { + name: "Not valid case", + headerStrings: []string{"X-Test-String Some-Test"}, + valid: false, + }, + { + name: "Valid and not valid case", + headerStrings: []string{"content-security-policy: default-src 'self'", "test-case"}, + valid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := ParseHeaderString(tt.headerStrings) + if tt.valid { + require.NoError(t, err) + } else { + require.Error(t, err) + } + }) + } +} + +func TestAddCustomHeaders(t *testing.T) { + tests := []struct { + name string + headerStrings []string + wantHeaders map[string]string + }{{ + name: "Normal case", + headerStrings: []string{"X-Test-String: Test"}, + wantHeaders: map[string]string{"X-Test-String": "Test"}, + }, + { + name: "Whitespace trim case", + headerStrings: []string{" X-Test-String : Test "}, + wantHeaders: map[string]string{"X-Test-String": "Test"}, + }, + { + name: "Whitespace in key, value case", + headerStrings: []string{"My amazing header: This is a test"}, + wantHeaders: map[string]string{"My amazing header": "This is a test"}, + }, + { + name: "Non-tracking header case", + headerStrings: []string{"Tk: N"}, + wantHeaders: map[string]string{"Tk": "N"}, + }, + { + name: "Content security header case", + headerStrings: []string{"content-security-policy: default-src 'self'"}, + wantHeaders: map[string]string{"content-security-policy": "default-src 'self'"}, + }, + { + name: "Multiple header strings", + headerStrings: []string{"content-security-policy: default-src 'self'", "X-Test-String: Test", "My amazing header : Amazing"}, + wantHeaders: map[string]string{"content-security-policy": "default-src 'self'", "X-Test-String": "Test", "My amazing header": "Amazing"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + headers, _ := ParseHeaderString(tt.headerStrings) + w := httptest.NewRecorder() + AddCustomHeaders(w, headers) + for k, v := range tt.wantHeaders { + require.Equal(t, v, w.HeaderMap.Get(k), "Expected header %+v, got %+v", v, w.HeaderMap.Get(k)) + } + }) + } +} |