diff options
author | Alessio Caiazza <acaiazza@gitlab.com> | 2020-02-20 18:28:23 +0300 |
---|---|---|
committer | Alessio Caiazza <acaiazza@gitlab.com> | 2020-02-20 18:28:23 +0300 |
commit | 6f4fcff22c468de4102d908455d5c01aa27b8760 (patch) | |
tree | 0b37d78a0250b6279f5bd786c2455dcd30a56146 | |
parent | 1fe3f552049912c98063d67ca6a0c2da626cb324 (diff) | |
parent | 6dcd9b539eb211034c00a46818cdc32820a95d9a (diff) |
Merge branch 'remove-ctx-https-key' into 'master'
Remove ctxHTTPSKey from the context completely
Closes #219
See merge request gitlab-org/gitlab-pages!245
-rw-r--r-- | app.go | 8 | ||||
-rw-r--r-- | internal/auth/auth_test.go | 34 | ||||
-rw-r--r-- | internal/logging/logging_test.go | 12 | ||||
-rw-r--r-- | internal/request/request.go | 27 | ||||
-rw-r--r-- | internal/request/request_test.go | 80 |
5 files changed, 40 insertions, 121 deletions
@@ -280,10 +280,6 @@ func (a *theApp) httpInitialMiddleware(handler http.Handler) http.Handler { // proxyInitialMiddleware sets up proxy requests func (a *theApp) proxyInitialMiddleware(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - forwardedProto := r.Header.Get(xForwardedProto) - https := forwardedProto == xForwardedProtoHTTPS - - r = request.WithHTTPSFlag(r, https) if forwardedHost := r.Header.Get(xForwardedHost); forwardedHost != "" { r.Host = forwardedHost } @@ -294,16 +290,14 @@ func (a *theApp) proxyInitialMiddleware(handler http.Handler) http.Handler { // setRequestScheme will update r.URL.Scheme if empty based on r.TLS func setRequestScheme(r *http.Request) *http.Request { - https := false if r.URL.Scheme == request.SchemeHTTPS || r.TLS != nil { // make sure is set for non-proxy requests r.URL.Scheme = request.SchemeHTTPS - https = true } else { r.URL.Scheme = request.SchemeHTTP } - return request.WithHTTPSFlag(r, https) + return r } func (a *theApp) buildHandlerPipeline() (http.Handler, error) { diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 92e1e8c7..4a5d63fa 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -54,7 +54,8 @@ func TestTryAuthenticate(t *testing.T) { result := httptest.NewRecorder() reqURL, err := url.Parse("/something/else") require.NoError(t, err) - r := request.WithHTTPSFlag(&http.Request{URL: reqURL}, true) + reqURL.Scheme = request.SchemeHTTPS + r := &http.Request{URL: reqURL} require.Equal(t, false, auth.TryAuthenticate(result, r, source.NewMockSource())) } @@ -65,7 +66,8 @@ func TestTryAuthenticateWithError(t *testing.T) { result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?error=access_denied") require.NoError(t, err) - r := request.WithHTTPSFlag(&http.Request{URL: reqURL}, true) + reqURL.Scheme = request.SchemeHTTPS + r := &http.Request{URL: reqURL} require.Equal(t, true, auth.TryAuthenticate(result, r, source.NewMockSource())) require.Equal(t, 401, result.Code) @@ -78,7 +80,8 @@ func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=invalid") require.NoError(t, err) - r := request.WithHTTPSFlag(&http.Request{URL: reqURL}, true) + reqURL.Scheme = request.SchemeHTTPS + r := &http.Request{URL: reqURL} session, _ := store.Get(r, "gitlab-pages") session.Values["state"] = "state" @@ -115,8 +118,13 @@ func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { "http://pages.gitlab-example.com/auth", apiServer.URL) - r, _ := http.NewRequest("GET", "/auth?code=1&state=state", nil) - r = request.WithHTTPSFlag(r, https) + r, err := http.NewRequest("GET", "/auth?code=1&state=state", nil) + require.NoError(t, err) + if https { + r.URL.Scheme = request.SchemeHTTPS + } else { + r.URL.Scheme = request.SchemeHTTP + } setSessionValues(r, map[interface{}]interface{}{ "uri": "https://pages.gitlab-example.com/project/", @@ -166,7 +174,8 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) { result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) - r := request.WithHTTPSFlag(&http.Request{URL: reqURL}, true) + reqURL.Scheme = request.SchemeHTTPS + r := &http.Request{URL: reqURL} session, _ := store.Get(r, "gitlab-pages") session.Values["access_token"] = "abc" @@ -203,7 +212,8 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) { result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) - r := request.WithHTTPSFlag(&http.Request{URL: reqURL}, true) + reqURL.Scheme = request.SchemeHTTPS + r := &http.Request{URL: reqURL} session, _ := store.Get(r, "gitlab-pages") session.Values["access_token"] = "abc" @@ -242,7 +252,6 @@ func TestCheckAuthenticationWhenInvalidToken(t *testing.T) { reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) r := &http.Request{URL: reqURL} - r = request.WithHTTPSFlag(r, false) session, _ := store.Get(r, "gitlab-pages") session.Values["access_token"] = "abc" @@ -279,7 +288,8 @@ func TestCheckAuthenticationWithoutProject(t *testing.T) { result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) - r := request.WithHTTPSFlag(&http.Request{URL: reqURL}, true) + reqURL.Scheme = request.SchemeHTTPS + r := &http.Request{URL: reqURL} session, _ := store.Get(r, "gitlab-pages") session.Values["access_token"] = "abc" @@ -318,8 +328,6 @@ func TestCheckAuthenticationWithoutProjectWhenInvalidToken(t *testing.T) { reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) r := &http.Request{URL: reqURL} - r = request.WithHTTPSFlag(r, false) - session, _ := store.Get(r, "gitlab-pages") session.Values["access_token"] = "abc" session.Save(r, result) @@ -348,7 +356,6 @@ func TestGetTokenIfExistsWhenTokenExists(t *testing.T) { reqURL, err := url.Parse("/") require.NoError(t, err) r := &http.Request{URL: reqURL} - r = request.WithHTTPSFlag(r, false) session, _ := store.Get(r, "gitlab-pages") session.Values["access_token"] = "abc" @@ -371,7 +378,6 @@ func TestGetTokenIfExistsWhenTokenDoesNotExist(t *testing.T) { reqURL, err := url.Parse("http://pages.gitlab-example.com/test") require.NoError(t, err) r := &http.Request{URL: reqURL, Host: "pages.gitlab-example.com", RequestURI: "/test"} - r = request.WithHTTPSFlag(r, false) session, _ := store.Get(r, "gitlab-pages") session.Save(r, result) @@ -393,7 +399,6 @@ func TestCheckResponseForInvalidTokenWhenInvalidToken(t *testing.T) { reqURL, err := url.Parse("http://pages.gitlab-example.com/test") require.NoError(t, err) r := &http.Request{URL: reqURL, Host: "pages.gitlab-example.com", RequestURI: "/test"} - r = request.WithHTTPSFlag(r, false) resp := &http.Response{StatusCode: http.StatusUnauthorized, Body: ioutil.NopCloser(bytes.NewReader([]byte("{\"error\":\"invalid_token\"}")))} @@ -414,7 +419,6 @@ func TestCheckResponseForInvalidTokenWhenNotInvalidToken(t *testing.T) { reqURL, err := url.Parse("/something") require.NoError(t, err) r := &http.Request{URL: reqURL} - r = request.WithHTTPSFlag(r, false) resp := &http.Response{StatusCode: 200, Body: ioutil.NopCloser(bytes.NewReader([]byte("ok")))} diff --git a/internal/logging/logging_test.go b/internal/logging/logging_test.go index 32bd4603..ec8837b6 100644 --- a/internal/logging/logging_test.go +++ b/internal/logging/logging_test.go @@ -13,25 +13,25 @@ import ( func TestGetExtraLogFields(t *testing.T) { tests := []struct { name string - https bool + scheme string host string domain *domain.Domain }{ { name: "https", - https: true, + scheme: request.SchemeHTTPS, host: "githost.io", domain: &domain.Domain{}, }, { name: "http", - https: false, + scheme: request.SchemeHTTP, host: "githost.io", domain: &domain.Domain{}, }, { name: "no_domain", - https: false, + scheme: request.SchemeHTTP, host: "githost.io", domain: nil, }, @@ -42,11 +42,11 @@ func TestGetExtraLogFields(t *testing.T) { req, err := http.NewRequest("GET", "/", nil) require.NoError(t, err) - req = request.WithHTTPSFlag(req, tt.https) + req.URL.Scheme = tt.scheme req = request.WithHostAndDomain(req, tt.host, tt.domain) got := getExtraLogFields(req) - require.Equal(t, got["pages_https"], tt.https) + require.Equal(t, got["pages_https"], tt.scheme == request.SchemeHTTPS) require.Equal(t, got["pages_host"], tt.host) require.Equal(t, got["pages_project_id"], uint64(0x0)) }) diff --git a/internal/request/request.go b/internal/request/request.go index ba56f45b..cbda16e5 100644 --- a/internal/request/request.go +++ b/internal/request/request.go @@ -5,15 +5,12 @@ import ( "net" "net/http" - log "github.com/sirupsen/logrus" - "gitlab.com/gitlab-org/gitlab-pages/internal/domain" ) type ctxKey string const ( - ctxHTTPSKey ctxKey = "https" ctxHostKey ctxKey = "host" ctxDomainKey ctxKey = "domain" @@ -23,30 +20,10 @@ const ( SchemeHTTPS = "https" ) -// WithHTTPSFlag saves https flag in request's context -func WithHTTPSFlag(r *http.Request, https bool) *http.Request { - ctx := context.WithValue(r.Context(), ctxHTTPSKey, https) - - return r.WithContext(ctx) -} - // IsHTTPS checks whether the request originated from HTTP or HTTPS. -// It reads the ctxHTTPSKey from the context and returns its value -// It also checks that r.URL.Scheme matches the value in ctxHTTPSKey for HTTPS requests -// TODO: remove the ctxHTTPSKey from the context https://gitlab.com/gitlab-org/gitlab-pages/issues/219 +// It checks the value from r.URL.Scheme func IsHTTPS(r *http.Request) bool { - https := r.Context().Value(ctxHTTPSKey).(bool) - - if https != (r.URL.Scheme == SchemeHTTPS) { - log.WithFields(log.Fields{ - "ctxHTTPSKey": https, - "scheme": r.URL.Scheme, - }).Warn("request: r.URL.Scheme does not match value in ctxHTTPSKey") - } - - // Returning the value of ctxHTTPSKey for now, can just switch to r.URL.Scheme == SchemeHTTPS later - // and later can remove IsHTTPS method altogether - return https + return r.URL.Scheme == SchemeHTTPS } // WithHostAndDomain saves host name and domain in the request's context diff --git a/internal/request/request_test.go b/internal/request/request_test.go index 3c0f970c..a9ffb223 100644 --- a/internal/request/request_test.go +++ b/internal/request/request_test.go @@ -5,77 +5,25 @@ import ( "net/http/httptest" "testing" - "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab-pages/internal/domain" - "gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers" ) -func TestWithHTTPSFlag(t *testing.T) { - r, err := http.NewRequest("GET", "/", nil) - require.NoError(t, err) - - httpsRequest := WithHTTPSFlag(r, true) - httpsRequest.URL.Scheme = SchemeHTTPS - require.True(t, IsHTTPS(httpsRequest)) - - httpRequest := WithHTTPSFlag(r, false) - httpsRequest.URL.Scheme = SchemeHTTP - require.False(t, IsHTTPS(httpRequest)) - -} - func TestIsHTTPS(t *testing.T) { - hook := test.NewGlobal() - - tests := []struct { - name string - flag bool - scheme string - wantLogEntries string - }{ - { - name: "https", - flag: true, - scheme: "https", - }, - { - name: "http", - flag: false, - scheme: "http", - }, - { - name: "scheme true but flag is false", - flag: false, - scheme: "https", - wantLogEntries: "request: r.URL.Scheme does not match value in ctxHTTPSKey", - }, - { - name: "scheme false but flag is true", - flag: true, - scheme: "http", - wantLogEntries: "request: r.URL.Scheme does not match value in ctxHTTPSKey", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - hook.Reset() - - r, err := http.NewRequest("GET", "/", nil) - require.NoError(t, err) - r.URL.Scheme = tt.scheme - - httpsRequest := WithHTTPSFlag(r, tt.flag) - - got := IsHTTPS(httpsRequest) - require.Equal(t, tt.flag, got) - - testhelpers.AssertLogContains(t, tt.wantLogEntries, hook.AllEntries()) - }) - } + t.Run("when scheme is http", func(t *testing.T) { + httpRequest, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + httpRequest.URL.Scheme = SchemeHTTP + require.False(t, IsHTTPS(httpRequest)) + }) + t.Run("when scheme is https", func(t *testing.T) { + httpsRequest, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + httpsRequest.URL.Scheme = SchemeHTTPS + require.True(t, IsHTTPS(httpsRequest)) + }) } func TestPanics(t *testing.T) { @@ -83,10 +31,6 @@ func TestPanics(t *testing.T) { require.NoError(t, err) require.Panics(t, func() { - IsHTTPS(r) - }) - - require.Panics(t, func() { GetHost(r) }) |