From 6dcd9b539eb211034c00a46818cdc32820a95d9a Mon Sep 17 00:00:00 2001 From: Jaime Martinez Date: Wed, 19 Feb 2020 12:27:47 +1100 Subject: Remove request.WithHTTPSFlag and set directly in tests --- internal/auth/auth_test.go | 34 +++++++++++++---------- internal/logging/logging_test.go | 12 ++++---- internal/request/request.go | 14 ---------- internal/request/request_test.go | 60 ++++++++-------------------------------- 4 files changed, 37 insertions(+), 83 deletions(-) (limited to 'internal') diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 92e1e8c7..4a5d63fa 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -54,7 +54,8 @@ func TestTryAuthenticate(t *testing.T) { result := httptest.NewRecorder() reqURL, err := url.Parse("/something/else") require.NoError(t, err) - r := request.WithHTTPSFlag(&http.Request{URL: reqURL}, true) + reqURL.Scheme = request.SchemeHTTPS + r := &http.Request{URL: reqURL} require.Equal(t, false, auth.TryAuthenticate(result, r, source.NewMockSource())) } @@ -65,7 +66,8 @@ func TestTryAuthenticateWithError(t *testing.T) { result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?error=access_denied") require.NoError(t, err) - r := request.WithHTTPSFlag(&http.Request{URL: reqURL}, true) + reqURL.Scheme = request.SchemeHTTPS + r := &http.Request{URL: reqURL} require.Equal(t, true, auth.TryAuthenticate(result, r, source.NewMockSource())) require.Equal(t, 401, result.Code) @@ -78,7 +80,8 @@ func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=invalid") require.NoError(t, err) - r := request.WithHTTPSFlag(&http.Request{URL: reqURL}, true) + reqURL.Scheme = request.SchemeHTTPS + r := &http.Request{URL: reqURL} session, _ := store.Get(r, "gitlab-pages") session.Values["state"] = "state" @@ -115,8 +118,13 @@ func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { "http://pages.gitlab-example.com/auth", apiServer.URL) - r, _ := http.NewRequest("GET", "/auth?code=1&state=state", nil) - r = request.WithHTTPSFlag(r, https) + r, err := http.NewRequest("GET", "/auth?code=1&state=state", nil) + require.NoError(t, err) + if https { + r.URL.Scheme = request.SchemeHTTPS + } else { + r.URL.Scheme = request.SchemeHTTP + } setSessionValues(r, map[interface{}]interface{}{ "uri": "https://pages.gitlab-example.com/project/", @@ -166,7 +174,8 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) { result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) - r := request.WithHTTPSFlag(&http.Request{URL: reqURL}, true) + reqURL.Scheme = request.SchemeHTTPS + r := &http.Request{URL: reqURL} session, _ := store.Get(r, "gitlab-pages") session.Values["access_token"] = "abc" @@ -203,7 +212,8 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) { result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) - r := request.WithHTTPSFlag(&http.Request{URL: reqURL}, true) + reqURL.Scheme = request.SchemeHTTPS + r := &http.Request{URL: reqURL} session, _ := store.Get(r, "gitlab-pages") session.Values["access_token"] = "abc" @@ -242,7 +252,6 @@ func TestCheckAuthenticationWhenInvalidToken(t *testing.T) { reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) r := &http.Request{URL: reqURL} - r = request.WithHTTPSFlag(r, false) session, _ := store.Get(r, "gitlab-pages") session.Values["access_token"] = "abc" @@ -279,7 +288,8 @@ func TestCheckAuthenticationWithoutProject(t *testing.T) { result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) - r := request.WithHTTPSFlag(&http.Request{URL: reqURL}, true) + reqURL.Scheme = request.SchemeHTTPS + r := &http.Request{URL: reqURL} session, _ := store.Get(r, "gitlab-pages") session.Values["access_token"] = "abc" @@ -318,8 +328,6 @@ func TestCheckAuthenticationWithoutProjectWhenInvalidToken(t *testing.T) { reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) r := &http.Request{URL: reqURL} - r = request.WithHTTPSFlag(r, false) - session, _ := store.Get(r, "gitlab-pages") session.Values["access_token"] = "abc" session.Save(r, result) @@ -348,7 +356,6 @@ func TestGetTokenIfExistsWhenTokenExists(t *testing.T) { reqURL, err := url.Parse("/") require.NoError(t, err) r := &http.Request{URL: reqURL} - r = request.WithHTTPSFlag(r, false) session, _ := store.Get(r, "gitlab-pages") session.Values["access_token"] = "abc" @@ -371,7 +378,6 @@ func TestGetTokenIfExistsWhenTokenDoesNotExist(t *testing.T) { reqURL, err := url.Parse("http://pages.gitlab-example.com/test") require.NoError(t, err) r := &http.Request{URL: reqURL, Host: "pages.gitlab-example.com", RequestURI: "/test"} - r = request.WithHTTPSFlag(r, false) session, _ := store.Get(r, "gitlab-pages") session.Save(r, result) @@ -393,7 +399,6 @@ func TestCheckResponseForInvalidTokenWhenInvalidToken(t *testing.T) { reqURL, err := url.Parse("http://pages.gitlab-example.com/test") require.NoError(t, err) r := &http.Request{URL: reqURL, Host: "pages.gitlab-example.com", RequestURI: "/test"} - r = request.WithHTTPSFlag(r, false) resp := &http.Response{StatusCode: http.StatusUnauthorized, Body: ioutil.NopCloser(bytes.NewReader([]byte("{\"error\":\"invalid_token\"}")))} @@ -414,7 +419,6 @@ func TestCheckResponseForInvalidTokenWhenNotInvalidToken(t *testing.T) { reqURL, err := url.Parse("/something") require.NoError(t, err) r := &http.Request{URL: reqURL} - r = request.WithHTTPSFlag(r, false) resp := &http.Response{StatusCode: 200, Body: ioutil.NopCloser(bytes.NewReader([]byte("ok")))} diff --git a/internal/logging/logging_test.go b/internal/logging/logging_test.go index 32bd4603..ec8837b6 100644 --- a/internal/logging/logging_test.go +++ b/internal/logging/logging_test.go @@ -13,25 +13,25 @@ import ( func TestGetExtraLogFields(t *testing.T) { tests := []struct { name string - https bool + scheme string host string domain *domain.Domain }{ { name: "https", - https: true, + scheme: request.SchemeHTTPS, host: "githost.io", domain: &domain.Domain{}, }, { name: "http", - https: false, + scheme: request.SchemeHTTP, host: "githost.io", domain: &domain.Domain{}, }, { name: "no_domain", - https: false, + scheme: request.SchemeHTTP, host: "githost.io", domain: nil, }, @@ -42,11 +42,11 @@ func TestGetExtraLogFields(t *testing.T) { req, err := http.NewRequest("GET", "/", nil) require.NoError(t, err) - req = request.WithHTTPSFlag(req, tt.https) + req.URL.Scheme = tt.scheme req = request.WithHostAndDomain(req, tt.host, tt.domain) got := getExtraLogFields(req) - require.Equal(t, got["pages_https"], tt.https) + require.Equal(t, got["pages_https"], tt.scheme == request.SchemeHTTPS) require.Equal(t, got["pages_host"], tt.host) require.Equal(t, got["pages_project_id"], uint64(0x0)) }) diff --git a/internal/request/request.go b/internal/request/request.go index 4e5c553d..cbda16e5 100644 --- a/internal/request/request.go +++ b/internal/request/request.go @@ -20,20 +20,6 @@ const ( SchemeHTTPS = "https" ) -// WithHTTPSFlag saves https flag in request's context -func WithHTTPSFlag(r *http.Request, https bool) *http.Request { - // scheme should already be set but leaving this for testing scenarios that set this value explicitly - if r.URL.Scheme == "" { - if https { - r.URL.Scheme = SchemeHTTPS - } else { - r.URL.Scheme = SchemeHTTP - } - } - - return r -} - // IsHTTPS checks whether the request originated from HTTP or HTTPS. // It checks the value from r.URL.Scheme func IsHTTPS(r *http.Request) bool { diff --git a/internal/request/request_test.go b/internal/request/request_test.go index 86601554..a9ffb223 100644 --- a/internal/request/request_test.go +++ b/internal/request/request_test.go @@ -5,61 +5,25 @@ import ( "net/http/httptest" "testing" - "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab-pages/internal/domain" ) -func TestWithHTTPSFlag(t *testing.T) { - r, err := http.NewRequest("GET", "/", nil) - require.NoError(t, err) - - httpsRequest := WithHTTPSFlag(r, true) - httpsRequest.URL.Scheme = SchemeHTTPS - require.True(t, IsHTTPS(httpsRequest)) - - httpRequest := WithHTTPSFlag(r, false) - httpsRequest.URL.Scheme = SchemeHTTP - require.False(t, IsHTTPS(httpRequest)) - -} - func TestIsHTTPS(t *testing.T) { - hook := test.NewGlobal() - - tests := []struct { - name string - flag bool - scheme string - }{ - { - name: "https", - flag: true, - scheme: "https", - }, - { - name: "http", - flag: false, - scheme: "http", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - hook.Reset() - - r, err := http.NewRequest("GET", "/", nil) - require.NoError(t, err) - r.URL.Scheme = tt.scheme - - httpsRequest := WithHTTPSFlag(r, tt.flag) - - got := IsHTTPS(httpsRequest) - require.Equal(t, tt.flag, got) - }) - } + t.Run("when scheme is http", func(t *testing.T) { + httpRequest, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + httpRequest.URL.Scheme = SchemeHTTP + require.False(t, IsHTTPS(httpRequest)) + }) + t.Run("when scheme is https", func(t *testing.T) { + httpsRequest, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + httpsRequest.URL.Scheme = SchemeHTTPS + require.True(t, IsHTTPS(httpsRequest)) + }) } func TestPanics(t *testing.T) { -- cgit v1.2.3