Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-pages.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJaime Martinez <jmartinez@gitlab.com>2021-01-28 07:02:37 +0300
committerJaime Martinez <jmartinez@gitlab.com>2021-01-28 07:02:37 +0300
commit37d15f3a1d1114db84da12920138b9ee89e7596a (patch)
treeb31cee15e614dcbaf968fbed9d9b1c170dfb63f7
parent09f840543a841b6e9d8e8b17ded6b24e3fb22590 (diff)
Restore splitting and update tests
-rw-r--r--internal/middleware/headers.go2
-rw-r--r--internal/middleware/headers_test.go31
-rw-r--r--multi_string_flag.go20
3 files changed, 32 insertions, 21 deletions
diff --git a/internal/middleware/headers.go b/internal/middleware/headers.go
index 246b1aa9..837dbe3b 100644
--- a/internal/middleware/headers.go
+++ b/internal/middleware/headers.go
@@ -21,7 +21,7 @@ func AddCustomHeaders(w http.ResponseWriter, headers http.Header) {
func ParseHeaderString(customHeaders []string) (http.Header, error) {
headers := http.Header{}
for _, keyValueString := range customHeaders {
- keyValue := strings.SplitN(keyValueString, ":", -1)
+ keyValue := strings.SplitN(keyValueString, ":", 2)
if len(keyValue) != 2 {
return nil, errInvalidHeaderParameter
}
diff --git a/internal/middleware/headers_test.go b/internal/middleware/headers_test.go
index 6c6078ee..1f3d98c6 100644
--- a/internal/middleware/headers_test.go
+++ b/internal/middleware/headers_test.go
@@ -12,35 +12,43 @@ func TestParseHeaderString(t *testing.T) {
name string
headerStrings []string
valid bool
- }{{
- name: "Normal case",
- headerStrings: []string{"X-Test-String: Test"},
- valid: true,
- },
+ expectedLen int
+ }{
+ {
+ name: "Normal case",
+ headerStrings: []string{"X-Test-String: Test"},
+ valid: true,
+ expectedLen: 1,
+ },
{
name: "Whitespace trim case",
headerStrings: []string{" X-Test-String : Test "},
valid: true,
+ expectedLen: 1,
},
{
name: "Whitespace in key, value case",
headerStrings: []string{"My amazing header: This is a test"},
valid: true,
+ expectedLen: 1,
},
{
name: "Non-tracking header case",
headerStrings: []string{"Tk: N"},
valid: true,
+ expectedLen: 1,
},
{
name: "Content security header case",
headerStrings: []string{"content-security-policy: default-src 'self'"},
valid: true,
+ expectedLen: 1,
},
{
name: "Multiple header strings",
headerStrings: []string{"content-security-policy: default-src 'self'", "X-Test-String: Test", "My amazing header : Amazing"},
valid: true,
+ expectedLen: 3,
},
{
name: "Multiple invalid cases",
@@ -63,20 +71,23 @@ func TestParseHeaderString(t *testing.T) {
valid: false,
},
{
- name: "Multiple headers in single string",
+ name: "Multiple headers in single string parsed as one header",
headerStrings: []string{"content-security-policy: default-src 'self',X-Test-String: Test,My amazing header : Amazing"},
- valid: false,
+ valid: true,
+ expectedLen: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- _, err := ParseHeaderString(tt.headerStrings)
+ got, err := ParseHeaderString(tt.headerStrings)
if tt.valid {
require.NoError(t, err)
- } else {
- require.Error(t, err)
+ require.Len(t, got, tt.expectedLen)
+ return
}
+
+ require.Error(t, err)
})
}
}
diff --git a/multi_string_flag.go b/multi_string_flag.go
index ab22b14d..1be02ef1 100644
--- a/multi_string_flag.go
+++ b/multi_string_flag.go
@@ -20,11 +20,7 @@ type MultiStringFlag struct {
// String returns the list of parameters joined with a commas (",")
func (s *MultiStringFlag) String() string {
- if s.separator == "" {
- s.separator = defaultSeparator
- }
-
- return strings.Join(s.value, s.separator)
+ return strings.Join(s.value, s.sep())
}
// Set appends the value to the list of parameters
@@ -39,13 +35,17 @@ func (s *MultiStringFlag) Set(value string) error {
// Split each flag
func (s *MultiStringFlag) Split() (result []string) {
- if s.separator == "" {
- s.separator = defaultSeparator
- }
-
for _, str := range s.value {
- result = append(result, strings.Split(str, s.separator)...)
+ result = append(result, strings.Split(str, s.sep())...)
}
return
}
+
+func (s *MultiStringFlag) sep() string {
+ if s.separator == "" {
+ return defaultSeparator
+ }
+
+ return s.separator
+}