Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-pages.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMax Wittig <max.wittig@siemens.com>2019-06-28 13:35:29 +0300
committerMax Wittig <max.wittig@siemens.com>2019-07-12 15:37:16 +0300
commit5199c4c8b646f3e66b0f03dd51fbaa704d9fd94f (patch)
treec60baf57319ea68b7b7213c4bba07662ced06d86 /internal/config
parent7d39822ce2221156c3479ceae0e4f8e24c7373b2 (diff)
feat: add flag to define custom response headers
Diffstat (limited to 'internal/config')
-rw-r--r--internal/config/config.go32
-rw-r--r--internal/config/config_test.go126
2 files changed, 158 insertions, 0 deletions
diff --git a/internal/config/config.go b/internal/config/config.go
new file mode 100644
index 00000000..d415b21d
--- /dev/null
+++ b/internal/config/config.go
@@ -0,0 +1,32 @@
+package config
+
+import (
+ "errors"
+ "net/http"
+ "strings"
+)
+
+var errInvalidHeaderParameter = errors.New("invalid syntax specified as header parameter")
+
+// AddCustomHeaders adds a map of Headers to a Response
+func AddCustomHeaders(w http.ResponseWriter, headers http.Header) error {
+ for k, v := range headers {
+ for _, value := range v {
+ w.Header().Add(k, value)
+ }
+ }
+ return nil
+}
+
+// ParseHeaderString parses a string of key values into a map
+func ParseHeaderString(customHeaders []string) (http.Header, error) {
+ headers := http.Header{}
+ for _, keyValueString := range customHeaders {
+ keyValue := strings.SplitN(keyValueString, ":", 2)
+ if len(keyValue) != 2 {
+ return nil, errInvalidHeaderParameter
+ }
+ headers[strings.TrimSpace(keyValue[0])] = append(headers[strings.TrimSpace(keyValue[0])], strings.TrimSpace(keyValue[1]))
+ }
+ return headers, nil
+}
diff --git a/internal/config/config_test.go b/internal/config/config_test.go
new file mode 100644
index 00000000..44afd470
--- /dev/null
+++ b/internal/config/config_test.go
@@ -0,0 +1,126 @@
+package config
+
+import (
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestParseHeaderString(t *testing.T) {
+ tests := []struct {
+ name string
+ headerStrings []string
+ valid bool
+ }{{
+ name: "Normal case",
+ headerStrings: []string{"X-Test-String: Test"},
+ valid: true,
+ },
+ {
+ name: "Whitespace trim case",
+ headerStrings: []string{" X-Test-String : Test "},
+ valid: true,
+ },
+ {
+ name: "Whitespace in key, value case",
+ headerStrings: []string{"My amazing header: This is a test"},
+ valid: true,
+ },
+ {
+ name: "Non-tracking header case",
+ headerStrings: []string{"Tk: N"},
+ valid: true,
+ },
+ {
+ name: "Content security header case",
+ headerStrings: []string{"content-security-policy: default-src 'self'"},
+ valid: true,
+ },
+ {
+ name: "Multiple header strings",
+ headerStrings: []string{"content-security-policy: default-src 'self'", "X-Test-String: Test", "My amazing header : Amazing"},
+ valid: true,
+ },
+ {
+ name: "Multiple invalid cases",
+ headerStrings: []string{"content-security-policy: default-src 'self'", "test-case"},
+ valid: false,
+ },
+ {
+ name: "Not valid case",
+ headerStrings: []string{"Tk= N"},
+ valid: false,
+ },
+ {
+ name: "Not valid case",
+ headerStrings: []string{"X-Test-String Some-Test"},
+ valid: false,
+ },
+ {
+ name: "Valid and not valid case",
+ headerStrings: []string{"content-security-policy: default-src 'self'", "test-case"},
+ valid: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ _, err := ParseHeaderString(tt.headerStrings)
+ if tt.valid {
+ require.NoError(t, err)
+ } else {
+ require.Error(t, err)
+ }
+ })
+ }
+}
+
+func TestAddCustomHeaders(t *testing.T) {
+ tests := []struct {
+ name string
+ headerStrings []string
+ wantHeaders map[string]string
+ }{{
+ name: "Normal case",
+ headerStrings: []string{"X-Test-String: Test"},
+ wantHeaders: map[string]string{"X-Test-String": "Test"},
+ },
+ {
+ name: "Whitespace trim case",
+ headerStrings: []string{" X-Test-String : Test "},
+ wantHeaders: map[string]string{"X-Test-String": "Test"},
+ },
+ {
+ name: "Whitespace in key, value case",
+ headerStrings: []string{"My amazing header: This is a test"},
+ wantHeaders: map[string]string{"My amazing header": "This is a test"},
+ },
+ {
+ name: "Non-tracking header case",
+ headerStrings: []string{"Tk: N"},
+ wantHeaders: map[string]string{"Tk": "N"},
+ },
+ {
+ name: "Content security header case",
+ headerStrings: []string{"content-security-policy: default-src 'self'"},
+ wantHeaders: map[string]string{"content-security-policy": "default-src 'self'"},
+ },
+ {
+ name: "Multiple header strings",
+ headerStrings: []string{"content-security-policy: default-src 'self'", "X-Test-String: Test", "My amazing header : Amazing"},
+ wantHeaders: map[string]string{"content-security-policy": "default-src 'self'", "X-Test-String": "Test", "My amazing header": "Amazing"},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ headers, _ := ParseHeaderString(tt.headerStrings)
+ w := httptest.NewRecorder()
+ AddCustomHeaders(w, headers)
+ for k, v := range tt.wantHeaders {
+ require.Equal(t, v, w.HeaderMap.Get(k), "Expected header %+v, got %+v", v, w.HeaderMap.Get(k))
+ }
+ })
+ }
+}