diff options
author | Jaime Martinez <jmartinez@gitlab.com> | 2021-01-19 02:54:41 +0300 |
---|---|---|
committer | Jaime Martinez <jmartinez@gitlab.com> | 2021-01-28 05:57:44 +0300 |
commit | 8627963a4cf1874e2550889cd687bc30ecdab0c6 (patch) | |
tree | cf92b87bc9ef68aa2b3daca0836677914719ac52 | |
parent | bc757b304ff6c958dfd771f87959b3dad8418c92 (diff) |
Define separator for MultiStringFlag
Allows initializing each MultiStringFlag using its own separator and
defaults to `,` when not specified.
This change makes the `-header` flag use a `;;` separator so that it can
be defined inside a config file.
Fixes https://gitlab.com/gitlab-org/gitlab-pages/-/issues/531.
-rw-r--r-- | internal/middleware/headers.go | 8 | ||||
-rw-r--r-- | internal/middleware/headers_test.go | 8 | ||||
-rw-r--r-- | main.go | 24 | ||||
-rw-r--r-- | multi_string_flag.go | 24 | ||||
-rw-r--r-- | multi_string_flag_test.go | 18 |
5 files changed, 59 insertions, 23 deletions
diff --git a/internal/middleware/headers.go b/internal/middleware/headers.go index 77b008f3..246b1aa9 100644 --- a/internal/middleware/headers.go +++ b/internal/middleware/headers.go @@ -21,11 +21,15 @@ func AddCustomHeaders(w http.ResponseWriter, headers http.Header) { func ParseHeaderString(customHeaders []string) (http.Header, error) { headers := http.Header{} for _, keyValueString := range customHeaders { - keyValue := strings.SplitN(keyValueString, ":", 2) + keyValue := strings.SplitN(keyValueString, ":", -1) if len(keyValue) != 2 { return nil, errInvalidHeaderParameter } - headers[strings.TrimSpace(keyValue[0])] = append(headers[strings.TrimSpace(keyValue[0])], strings.TrimSpace(keyValue[1])) + + key := strings.TrimSpace(keyValue[0]) + value := strings.TrimSpace(keyValue[1]) + + headers[key] = append(headers[key], value) } return headers, nil } diff --git a/internal/middleware/headers_test.go b/internal/middleware/headers_test.go index 17d31b50..6c6078ee 100644 --- a/internal/middleware/headers_test.go +++ b/internal/middleware/headers_test.go @@ -62,6 +62,11 @@ func TestParseHeaderString(t *testing.T) { headerStrings: []string{"content-security-policy: default-src 'self'", "test-case"}, valid: false, }, + { + name: "Multiple headers in single string", + headerStrings: []string{"content-security-policy: default-src 'self',X-Test-String: Test,My amazing header : Amazing"}, + valid: false, + }, } for _, tt := range tests { @@ -115,7 +120,8 @@ func TestAddCustomHeaders(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - headers, _ := ParseHeaderString(tt.headerStrings) + headers, err := ParseHeaderString(tt.headerStrings) + require.NoError(t, err) w := httptest.NewRecorder() AddCustomHeaders(w, headers) for k, v := range tt.wantHeaders { @@ -33,7 +33,7 @@ func init() { flag.Var(&listenHTTP, "listen-http", "The address(es) to listen on for HTTP requests") flag.Var(&listenHTTPS, "listen-https", "The address(es) to listen on for HTTPS requests") flag.Var(&listenProxy, "listen-proxy", "The address(es) to listen on for proxy requests") - flag.Var(&ListenHTTPSProxyv2, "listen-https-proxyv2", "The address(es) to listen on for HTTPS PROXYv2 requests (https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt)") + flag.Var(&listenHTTPSProxyv2, "listen-https-proxyv2", "The address(es) to listen on for HTTPS PROXYv2 requests (https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt)") flag.Var(&header, "header", "The additional http header(s) that should be send to the client") } @@ -86,12 +86,12 @@ var ( disableCrossOriginRequests = flag.Bool("disable-cross-origin-requests", false, "Disable cross-origin requests") // See init() - listenHTTP MultiStringFlag - listenHTTPS MultiStringFlag - listenProxy MultiStringFlag - ListenHTTPSProxyv2 MultiStringFlag + listenHTTP = MultiStringFlag{separator: ","} + listenHTTPS = MultiStringFlag{separator: ","} + listenProxy = MultiStringFlag{separator: ","} + listenHTTPSProxyv2 = MultiStringFlag{separator: ","} - header MultiStringFlag + header = MultiStringFlag{separator: ";;"} ) func gitlabServerFromFlags() string { @@ -172,7 +172,7 @@ func configFromFlags() appConfig { // tlsMinVersion and tlsMaxVersion are validated in appMain config.TLSMinVersion = tlsconfig.AllTLSVersions[*tlsMinVersion] config.TLSMaxVersion = tlsconfig.AllTLSVersions[*tlsMaxVersion] - config.CustomHeaders = header + config.CustomHeaders = header.Split() for _, file := range []struct { contents *[]byte @@ -274,10 +274,10 @@ func loadConfig() appConfig { "disable-cross-origin-requests": *disableCrossOriginRequests, "domain": config.Domain, "insecure-ciphers": config.InsecureCiphers, - "listen-http": strings.Join(listenHTTP, ","), - "listen-https": strings.Join(listenHTTPS, ","), - "listen-proxy": strings.Join(listenProxy, ","), - "listen-https-proxyv2": strings.Join(ListenHTTPSProxyv2, ","), + "listen-http": listenHTTP, + "listen-https": listenHTTPS, + "listen-proxy": listenProxy, + "listen-https-proxyv2": listenHTTPSProxyv2, "log-format": *logFormat, "metrics-address": *metricsAddress, "pages-domain": *pagesDomain, @@ -399,7 +399,7 @@ func createAppListeners(config *appConfig) []io.Closer { config.ListenProxy = append(config.ListenProxy, f.Fd()) } - for _, addr := range ListenHTTPSProxyv2.Split() { + for _, addr := range listenHTTPSProxyv2.Split() { l, f := createSocket(addr) closers = append(closers, l, f) diff --git a/multi_string_flag.go b/multi_string_flag.go index 699529a0..ab22b14d 100644 --- a/multi_string_flag.go +++ b/multi_string_flag.go @@ -7,15 +7,24 @@ import ( var errMultiStringSetEmptyValue = errors.New("value cannot be empty") +const defaultSeparator = "," + // MultiStringFlag implements the flag.Value interface and allows a string flag // to be specified multiple times on the command line. // // e.g.: -listen-http 127.0.0.1:80 -listen-http [::1]:80 -type MultiStringFlag []string +type MultiStringFlag struct { + value []string + separator string +} // String returns the list of parameters joined with a commas (",") func (s *MultiStringFlag) String() string { - return strings.Join(*s, ",") + if s.separator == "" { + s.separator = defaultSeparator + } + + return strings.Join(s.value, s.separator) } // Set appends the value to the list of parameters @@ -23,14 +32,19 @@ func (s *MultiStringFlag) Set(value string) error { if value == "" { return errMultiStringSetEmptyValue } - *s = append(*s, value) + + s.value = append(s.value, value) return nil } // Split each flag func (s *MultiStringFlag) Split() (result []string) { - for _, str := range *s { - result = append(result, strings.Split(str, ",")...) + if s.separator == "" { + s.separator = defaultSeparator + } + + for _, str := range s.value { + result = append(result, strings.Split(str, s.separator)...) } return diff --git a/multi_string_flag_test.go b/multi_string_flag_test.go index c09f7225..9c9c7d48 100644 --- a/multi_string_flag_test.go +++ b/multi_string_flag_test.go @@ -1,6 +1,7 @@ package main import ( + "strings" "testing" "github.com/stretchr/testify/require" @@ -15,7 +16,7 @@ func TestMultiStringFlagAppendsOnSet(t *testing.T) { require.EqualError(t, iface.Set(""), "value cannot be empty") - require.Equal(t, MultiStringFlag{"foo", "bar"}, concrete) + require.Equal(t, MultiStringFlag{value: []string{"foo", "bar"}}, concrete) } func TestMultiStringFlag_Split(t *testing.T) { @@ -31,19 +32,30 @@ func TestMultiStringFlag_Split(t *testing.T) { }, { name: "one_value", - s: &MultiStringFlag{"value1"}, // -flag "value1" + s: &MultiStringFlag{value: []string{"value1"}}, // -flag "value1" wantResult: []string{"value1"}, }, { name: "multiple_values", - s: &MultiStringFlag{"value1", "", "value3"}, // -flag "value1,,value3" + s: &MultiStringFlag{value: []string{"value1", "", "value3"}}, // -flag "value1,,value3" wantResult: []string{"value1", "", "value3"}, }, + { + name: "multiple_values_in_one_string", + s: &MultiStringFlag{value: []string{"value1,value2"}}, // -flag "value1,value2" + wantResult: []string{"value1", "value2"}, + }, + { + name: "different_separator", + s: &MultiStringFlag{value: []string{"value1", "value2"}, separator: ";"}, // -flag "value1;value2" + wantResult: []string{"value1", "value2"}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { gotResult := tt.s.Split() require.ElementsMatch(t, tt.wantResult, gotResult) + require.Equal(t, strings.Join(gotResult, tt.s.separator), strings.Join(tt.wantResult, tt.s.separator)) }) } } |