diff options
author | Jaime Martinez <jmartinez@gitlab.com> | 2021-01-28 07:02:37 +0300 |
---|---|---|
committer | Jaime Martinez <jmartinez@gitlab.com> | 2021-01-28 07:02:37 +0300 |
commit | 37d15f3a1d1114db84da12920138b9ee89e7596a (patch) | |
tree | b31cee15e614dcbaf968fbed9d9b1c170dfb63f7 | |
parent | 09f840543a841b6e9d8e8b17ded6b24e3fb22590 (diff) |
Restore splitting and update tests
-rw-r--r-- | internal/middleware/headers.go | 2 | ||||
-rw-r--r-- | internal/middleware/headers_test.go | 31 | ||||
-rw-r--r-- | multi_string_flag.go | 20 |
3 files changed, 32 insertions, 21 deletions
diff --git a/internal/middleware/headers.go b/internal/middleware/headers.go index 246b1aa9..837dbe3b 100644 --- a/internal/middleware/headers.go +++ b/internal/middleware/headers.go @@ -21,7 +21,7 @@ func AddCustomHeaders(w http.ResponseWriter, headers http.Header) { func ParseHeaderString(customHeaders []string) (http.Header, error) { headers := http.Header{} for _, keyValueString := range customHeaders { - keyValue := strings.SplitN(keyValueString, ":", -1) + keyValue := strings.SplitN(keyValueString, ":", 2) if len(keyValue) != 2 { return nil, errInvalidHeaderParameter } diff --git a/internal/middleware/headers_test.go b/internal/middleware/headers_test.go index 6c6078ee..1f3d98c6 100644 --- a/internal/middleware/headers_test.go +++ b/internal/middleware/headers_test.go @@ -12,35 +12,43 @@ func TestParseHeaderString(t *testing.T) { name string headerStrings []string valid bool - }{{ - name: "Normal case", - headerStrings: []string{"X-Test-String: Test"}, - valid: true, - }, + expectedLen int + }{ + { + name: "Normal case", + headerStrings: []string{"X-Test-String: Test"}, + valid: true, + expectedLen: 1, + }, { name: "Whitespace trim case", headerStrings: []string{" X-Test-String : Test "}, valid: true, + expectedLen: 1, }, { name: "Whitespace in key, value case", headerStrings: []string{"My amazing header: This is a test"}, valid: true, + expectedLen: 1, }, { name: "Non-tracking header case", headerStrings: []string{"Tk: N"}, valid: true, + expectedLen: 1, }, { name: "Content security header case", headerStrings: []string{"content-security-policy: default-src 'self'"}, valid: true, + expectedLen: 1, }, { name: "Multiple header strings", headerStrings: []string{"content-security-policy: default-src 'self'", "X-Test-String: Test", "My amazing header : Amazing"}, valid: true, + expectedLen: 3, }, { name: "Multiple invalid cases", @@ -63,20 +71,23 @@ func TestParseHeaderString(t *testing.T) { valid: false, }, { - name: "Multiple headers in single string", + name: "Multiple headers in single string parsed as one header", headerStrings: []string{"content-security-policy: default-src 'self',X-Test-String: Test,My amazing header : Amazing"}, - valid: false, + valid: true, + expectedLen: 1, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := ParseHeaderString(tt.headerStrings) + got, err := ParseHeaderString(tt.headerStrings) if tt.valid { require.NoError(t, err) - } else { - require.Error(t, err) + require.Len(t, got, tt.expectedLen) + return } + + require.Error(t, err) }) } } diff --git a/multi_string_flag.go b/multi_string_flag.go index ab22b14d..1be02ef1 100644 --- a/multi_string_flag.go +++ b/multi_string_flag.go @@ -20,11 +20,7 @@ type MultiStringFlag struct { // String returns the list of parameters joined with a commas (",") func (s *MultiStringFlag) String() string { - if s.separator == "" { - s.separator = defaultSeparator - } - - return strings.Join(s.value, s.separator) + return strings.Join(s.value, s.sep()) } // Set appends the value to the list of parameters @@ -39,13 +35,17 @@ func (s *MultiStringFlag) Set(value string) error { // Split each flag func (s *MultiStringFlag) Split() (result []string) { - if s.separator == "" { - s.separator = defaultSeparator - } - for _, str := range s.value { - result = append(result, strings.Split(str, s.separator)...) + result = append(result, strings.Split(str, s.sep())...) } return } + +func (s *MultiStringFlag) sep() string { + if s.separator == "" { + return defaultSeparator + } + + return s.separator +} |