Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-pages.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVladimir Shushlin <vshushlin@gitlab.com>2021-01-28 14:21:55 +0300
committerVladimir Shushlin <vshushlin@gitlab.com>2021-01-28 14:21:55 +0300
commit2cb80834597f8e0d818bb28b60e3338e0a3e6acb (patch)
treeb31cee15e614dcbaf968fbed9d9b1c170dfb63f7
parentbc757b304ff6c958dfd771f87959b3dad8418c92 (diff)
parent37d15f3a1d1114db84da12920138b9ee89e7596a (diff)
Merge branch '531-fix-header-via-config-file' into 'master'
Define separator for MultiStringFlag Closes #531 See merge request gitlab-org/gitlab-pages!417
-rw-r--r--internal/middleware/headers.go6
-rw-r--r--internal/middleware/headers_test.go35
-rw-r--r--main.go24
-rw-r--r--multi_string_flag.go24
-rw-r--r--multi_string_flag_test.go18
-rw-r--r--test/acceptance/helpers_test.go19
-rw-r--r--test/acceptance/serving_test.go74
7 files changed, 160 insertions, 40 deletions
diff --git a/internal/middleware/headers.go b/internal/middleware/headers.go
index 77b008f3..837dbe3b 100644
--- a/internal/middleware/headers.go
+++ b/internal/middleware/headers.go
@@ -25,7 +25,11 @@ func ParseHeaderString(customHeaders []string) (http.Header, error) {
if len(keyValue) != 2 {
return nil, errInvalidHeaderParameter
}
- headers[strings.TrimSpace(keyValue[0])] = append(headers[strings.TrimSpace(keyValue[0])], strings.TrimSpace(keyValue[1]))
+
+ key := strings.TrimSpace(keyValue[0])
+ value := strings.TrimSpace(keyValue[1])
+
+ headers[key] = append(headers[key], value)
}
return headers, nil
}
diff --git a/internal/middleware/headers_test.go b/internal/middleware/headers_test.go
index 17d31b50..1f3d98c6 100644
--- a/internal/middleware/headers_test.go
+++ b/internal/middleware/headers_test.go
@@ -12,35 +12,43 @@ func TestParseHeaderString(t *testing.T) {
name string
headerStrings []string
valid bool
- }{{
- name: "Normal case",
- headerStrings: []string{"X-Test-String: Test"},
- valid: true,
- },
+ expectedLen int
+ }{
+ {
+ name: "Normal case",
+ headerStrings: []string{"X-Test-String: Test"},
+ valid: true,
+ expectedLen: 1,
+ },
{
name: "Whitespace trim case",
headerStrings: []string{" X-Test-String : Test "},
valid: true,
+ expectedLen: 1,
},
{
name: "Whitespace in key, value case",
headerStrings: []string{"My amazing header: This is a test"},
valid: true,
+ expectedLen: 1,
},
{
name: "Non-tracking header case",
headerStrings: []string{"Tk: N"},
valid: true,
+ expectedLen: 1,
},
{
name: "Content security header case",
headerStrings: []string{"content-security-policy: default-src 'self'"},
valid: true,
+ expectedLen: 1,
},
{
name: "Multiple header strings",
headerStrings: []string{"content-security-policy: default-src 'self'", "X-Test-String: Test", "My amazing header : Amazing"},
valid: true,
+ expectedLen: 3,
},
{
name: "Multiple invalid cases",
@@ -62,16 +70,24 @@ func TestParseHeaderString(t *testing.T) {
headerStrings: []string{"content-security-policy: default-src 'self'", "test-case"},
valid: false,
},
+ {
+ name: "Multiple headers in single string parsed as one header",
+ headerStrings: []string{"content-security-policy: default-src 'self',X-Test-String: Test,My amazing header : Amazing"},
+ valid: true,
+ expectedLen: 1,
+ },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- _, err := ParseHeaderString(tt.headerStrings)
+ got, err := ParseHeaderString(tt.headerStrings)
if tt.valid {
require.NoError(t, err)
- } else {
- require.Error(t, err)
+ require.Len(t, got, tt.expectedLen)
+ return
}
+
+ require.Error(t, err)
})
}
}
@@ -115,7 +131,8 @@ func TestAddCustomHeaders(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- headers, _ := ParseHeaderString(tt.headerStrings)
+ headers, err := ParseHeaderString(tt.headerStrings)
+ require.NoError(t, err)
w := httptest.NewRecorder()
AddCustomHeaders(w, headers)
for k, v := range tt.wantHeaders {
diff --git a/main.go b/main.go
index ae6941ce..04afaccf 100644
--- a/main.go
+++ b/main.go
@@ -33,7 +33,7 @@ func init() {
flag.Var(&listenHTTP, "listen-http", "The address(es) to listen on for HTTP requests")
flag.Var(&listenHTTPS, "listen-https", "The address(es) to listen on for HTTPS requests")
flag.Var(&listenProxy, "listen-proxy", "The address(es) to listen on for proxy requests")
- flag.Var(&ListenHTTPSProxyv2, "listen-https-proxyv2", "The address(es) to listen on for HTTPS PROXYv2 requests (https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt)")
+ flag.Var(&listenHTTPSProxyv2, "listen-https-proxyv2", "The address(es) to listen on for HTTPS PROXYv2 requests (https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt)")
flag.Var(&header, "header", "The additional http header(s) that should be send to the client")
}
@@ -86,12 +86,12 @@ var (
disableCrossOriginRequests = flag.Bool("disable-cross-origin-requests", false, "Disable cross-origin requests")
// See init()
- listenHTTP MultiStringFlag
- listenHTTPS MultiStringFlag
- listenProxy MultiStringFlag
- ListenHTTPSProxyv2 MultiStringFlag
+ listenHTTP = MultiStringFlag{separator: ","}
+ listenHTTPS = MultiStringFlag{separator: ","}
+ listenProxy = MultiStringFlag{separator: ","}
+ listenHTTPSProxyv2 = MultiStringFlag{separator: ","}
- header MultiStringFlag
+ header = MultiStringFlag{separator: ";;"}
)
func gitlabServerFromFlags() string {
@@ -172,7 +172,7 @@ func configFromFlags() appConfig {
// tlsMinVersion and tlsMaxVersion are validated in appMain
config.TLSMinVersion = tlsconfig.AllTLSVersions[*tlsMinVersion]
config.TLSMaxVersion = tlsconfig.AllTLSVersions[*tlsMaxVersion]
- config.CustomHeaders = header
+ config.CustomHeaders = header.Split()
for _, file := range []struct {
contents *[]byte
@@ -274,10 +274,10 @@ func loadConfig() appConfig {
"disable-cross-origin-requests": *disableCrossOriginRequests,
"domain": config.Domain,
"insecure-ciphers": config.InsecureCiphers,
- "listen-http": strings.Join(listenHTTP, ","),
- "listen-https": strings.Join(listenHTTPS, ","),
- "listen-proxy": strings.Join(listenProxy, ","),
- "listen-https-proxyv2": strings.Join(ListenHTTPSProxyv2, ","),
+ "listen-http": listenHTTP,
+ "listen-https": listenHTTPS,
+ "listen-proxy": listenProxy,
+ "listen-https-proxyv2": listenHTTPSProxyv2,
"log-format": *logFormat,
"metrics-address": *metricsAddress,
"pages-domain": *pagesDomain,
@@ -399,7 +399,7 @@ func createAppListeners(config *appConfig) []io.Closer {
config.ListenProxy = append(config.ListenProxy, f.Fd())
}
- for _, addr := range ListenHTTPSProxyv2.Split() {
+ for _, addr := range listenHTTPSProxyv2.Split() {
l, f := createSocket(addr)
closers = append(closers, l, f)
diff --git a/multi_string_flag.go b/multi_string_flag.go
index 699529a0..1be02ef1 100644
--- a/multi_string_flag.go
+++ b/multi_string_flag.go
@@ -7,15 +7,20 @@ import (
var errMultiStringSetEmptyValue = errors.New("value cannot be empty")
+const defaultSeparator = ","
+
// MultiStringFlag implements the flag.Value interface and allows a string flag
// to be specified multiple times on the command line.
//
// e.g.: -listen-http 127.0.0.1:80 -listen-http [::1]:80
-type MultiStringFlag []string
+type MultiStringFlag struct {
+ value []string
+ separator string
+}
// String returns the list of parameters joined with a commas (",")
func (s *MultiStringFlag) String() string {
- return strings.Join(*s, ",")
+ return strings.Join(s.value, s.sep())
}
// Set appends the value to the list of parameters
@@ -23,15 +28,24 @@ func (s *MultiStringFlag) Set(value string) error {
if value == "" {
return errMultiStringSetEmptyValue
}
- *s = append(*s, value)
+
+ s.value = append(s.value, value)
return nil
}
// Split each flag
func (s *MultiStringFlag) Split() (result []string) {
- for _, str := range *s {
- result = append(result, strings.Split(str, ",")...)
+ for _, str := range s.value {
+ result = append(result, strings.Split(str, s.sep())...)
}
return
}
+
+func (s *MultiStringFlag) sep() string {
+ if s.separator == "" {
+ return defaultSeparator
+ }
+
+ return s.separator
+}
diff --git a/multi_string_flag_test.go b/multi_string_flag_test.go
index c09f7225..9c9c7d48 100644
--- a/multi_string_flag_test.go
+++ b/multi_string_flag_test.go
@@ -1,6 +1,7 @@
package main
import (
+ "strings"
"testing"
"github.com/stretchr/testify/require"
@@ -15,7 +16,7 @@ func TestMultiStringFlagAppendsOnSet(t *testing.T) {
require.EqualError(t, iface.Set(""), "value cannot be empty")
- require.Equal(t, MultiStringFlag{"foo", "bar"}, concrete)
+ require.Equal(t, MultiStringFlag{value: []string{"foo", "bar"}}, concrete)
}
func TestMultiStringFlag_Split(t *testing.T) {
@@ -31,19 +32,30 @@ func TestMultiStringFlag_Split(t *testing.T) {
},
{
name: "one_value",
- s: &MultiStringFlag{"value1"}, // -flag "value1"
+ s: &MultiStringFlag{value: []string{"value1"}}, // -flag "value1"
wantResult: []string{"value1"},
},
{
name: "multiple_values",
- s: &MultiStringFlag{"value1", "", "value3"}, // -flag "value1,,value3"
+ s: &MultiStringFlag{value: []string{"value1", "", "value3"}}, // -flag "value1,,value3"
wantResult: []string{"value1", "", "value3"},
},
+ {
+ name: "multiple_values_in_one_string",
+ s: &MultiStringFlag{value: []string{"value1,value2"}}, // -flag "value1,value2"
+ wantResult: []string{"value1", "value2"},
+ },
+ {
+ name: "different_separator",
+ s: &MultiStringFlag{value: []string{"value1", "value2"}, separator: ";"}, // -flag "value1;value2"
+ wantResult: []string{"value1", "value2"},
+ },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotResult := tt.s.Split()
require.ElementsMatch(t, tt.wantResult, gotResult)
+ require.Equal(t, strings.Join(gotResult, tt.s.separator), strings.Join(tt.wantResult, tt.s.separator))
})
}
}
diff --git a/test/acceptance/helpers_test.go b/test/acceptance/helpers_test.go
index aa209240..8d8ca8d6 100644
--- a/test/acceptance/helpers_test.go
+++ b/test/acceptance/helpers_test.go
@@ -594,30 +594,29 @@ func NewGitlabDomainsSourceStub(t *testing.T, opts *stubOpts) *httptest.Server {
return httptest.NewServer(mux)
}
-func newConfigFile(configs ...string) (string, error) {
+func newConfigFile(t *testing.T, configs ...string) string {
+ t.Helper()
+
f, err := ioutil.TempFile(os.TempDir(), "gitlab-pages-config")
- if err != nil {
- return "", err
- }
+ require.NoError(t, err)
defer f.Close()
for _, config := range configs {
_, err := fmt.Fprintf(f, "%s\n", config)
- if err != nil {
- return "", err
- }
+ require.NoError(t, err)
}
- return f.Name(), nil
+ return f.Name()
}
func defaultConfigFileWith(t *testing.T, configs ...string) (string, func()) {
+ t.Helper()
+
configs = append(configs, "auth-client-id=clientID",
"auth-client-secret=clientSecret",
"auth-secret=authSecret")
- name, err := newConfigFile(configs...)
- require.NoError(t, err)
+ name := newConfigFile(t, configs...)
cleanup := func() {
err := os.Remove(name)
diff --git a/test/acceptance/serving_test.go b/test/acceptance/serving_test.go
index 31588046..6afa9560 100644
--- a/test/acceptance/serving_test.go
+++ b/test/acceptance/serving_test.go
@@ -4,8 +4,10 @@ import (
"fmt"
"io/ioutil"
"net/http"
+ "net/textproto"
"os"
"path"
+ "strings"
"testing"
"time"
@@ -645,3 +647,75 @@ func TestQueryStringPersistedInSlashRewrite(t *testing.T) {
defer rsp.Body.Close()
require.Equal(t, http.StatusOK, rsp.StatusCode)
}
+
+func TestServerRepliesWithHeaders(t *testing.T) {
+ skipUnlessEnabled(t)
+
+ tests := map[string]struct {
+ flags []string
+ expectedHeaders map[string][]string
+ }{
+ "single_header": {
+ flags: []string{"X-testing-1: y-value"},
+ expectedHeaders: http.Header{"X-testing-1": {"y-value"}},
+ },
+ "multiple_header": {
+ flags: []string{"X: 1,2", "Y: 3,4"},
+ expectedHeaders: http.Header{"X": {"1,2"}, "Y": {"3,4"}},
+ },
+ }
+
+ for name, test := range tests {
+ testFn := func(envArgs, headerArgs []string) func(*testing.T) {
+ return func(t *testing.T) {
+ teardown := RunPagesProcessWithEnvs(t, true, *pagesBinary, []ListenSpec{httpListener}, "", envArgs, headerArgs...)
+
+ defer teardown()
+
+ rsp, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "/")
+ require.NoError(t, err)
+ defer rsp.Body.Close()
+
+ require.Equal(t, http.StatusOK, rsp.StatusCode)
+
+ for key, value := range test.expectedHeaders {
+ got := headerValues(rsp.Header, key)
+ require.Equal(t, value, got)
+ }
+ }
+ }
+
+ t.Run(name+"/from_single_flag", func(t *testing.T) {
+ args := []string{"-header", strings.Join(test.flags, ";;")}
+ testFn([]string{}, args)
+ })
+
+ t.Run(name+"/from_multiple_flags", func(t *testing.T) {
+ args := make([]string, 0, 2*len(test.flags))
+ for _, arg := range test.flags {
+ args = append(args, "-header", arg)
+ }
+
+ testFn([]string{}, args)
+ })
+
+ t.Run(name+"/from_config_file", func(t *testing.T) {
+ file := newConfigFile(t, "-header="+strings.Join(test.flags, ";;"))
+
+ testFn([]string{}, []string{"-config", file})
+ })
+
+ t.Run(name+"/from_env", func(t *testing.T) {
+ args := []string{"header", strings.Join(test.flags, ";;")}
+ testFn(args, []string{})
+ })
+ }
+}
+
+func headerValues(header http.Header, key string) []string {
+ h := textproto.MIMEHeader(header)
+
+ // NOTE: cannot use header.Values() in Go 1.13 or lower, this is the implementation
+ // from Go 1.15 https://github.com/golang/go/blob/release-branch.go1.15/src/net/textproto/header.go#L46
+ return h[textproto.CanonicalMIMEHeaderKey(key)]
+}