diff options
author | Vladimir Shushlin <vshushlin@gitlab.com> | 2019-08-21 19:00:52 +0300 |
---|---|---|
committer | Vladimir Shushlin <v.shushlin@gmail.com> | 2019-08-26 12:19:12 +0300 |
commit | 87ff3b1653589fa69a4c934c464b68718db74892 (patch) | |
tree | 56debc2c057303af86a63ca3dc3b08b48443d1ee /internal | |
parent | 1fa5c7b079831a73b55bb874b84a0b53fd4c0f23 (diff) |
Fix https downgrade for pages behind proxy
We can't rely on r.TLS when pages are served behind proxy
So we save https flag to a context for later usage
Right now I'm trying to keep changes to a minimum since
I'm planning to backport this to older versions
That's why https flag is not refactored throughout the codebase
The alternative way would be to use gorilla's proxy headers
I'm planning to refactor to that version later
Diffstat (limited to 'internal')
-rw-r--r-- | internal/auth/auth.go | 5 | ||||
-rw-r--r-- | internal/auth/auth_test.go | 3 | ||||
-rw-r--r-- | internal/request/request.go | 24 | ||||
-rw-r--r-- | internal/request/request_test.go | 24 |
4 files changed, 54 insertions, 2 deletions
diff --git a/internal/auth/auth.go b/internal/auth/auth.go index d6cbdff1..920c5d12 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -22,6 +22,7 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/domain" "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/httptransport" + "gitlab.com/gitlab-org/gitlab-pages/internal/request" "golang.org/x/crypto/hkdf" ) @@ -285,14 +286,14 @@ func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWrit } func getRequestAddress(r *http.Request) string { - if r.TLS != nil { + if request.IsHTTPS(r) { return "https://" + r.Host + r.RequestURI } return "http://" + r.Host + r.RequestURI } func getRequestDomain(r *http.Request) string { - if r.TLS != nil { + if request.IsHTTPS(r) { return "https://" + r.Host } return "http://" + r.Host diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 2fbbb938..00cdbd5b 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab-pages/internal/domain" + "gitlab.com/gitlab-org/gitlab-pages/internal/request" ) func createAuth(t *testing.T) *Auth { @@ -214,6 +215,7 @@ func TestCheckAuthenticationWhenInvalidToken(t *testing.T) { reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) r := &http.Request{URL: reqURL} + r = request.WithHTTPSFlag(r, false) session, _ := store.Get(r, "gitlab-pages") session.Values["access_token"] = "abc" @@ -289,6 +291,7 @@ func TestCheckAuthenticationWithoutProjectWhenInvalidToken(t *testing.T) { reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) r := &http.Request{URL: reqURL} + r = request.WithHTTPSFlag(r, false) session, _ := store.Get(r, "gitlab-pages") session.Values["access_token"] = "abc" diff --git a/internal/request/request.go b/internal/request/request.go new file mode 100644 index 00000000..dad6af3d --- /dev/null +++ b/internal/request/request.go @@ -0,0 +1,24 @@ +package request + +import ( + "context" + "net/http" +) + +type ctxKey string + +const ( + ctxHTTPSKey ctxKey = "https" +) + +// WithHTTPSFlag saves https flag in request's context +func WithHTTPSFlag(r *http.Request, https bool) *http.Request { + ctx := context.WithValue(r.Context(), ctxHTTPSKey, https) + + return r.WithContext(ctx) +} + +// IsHTTPS restores https flag from request's context +func IsHTTPS(r *http.Request) bool { + return r.Context().Value(ctxHTTPSKey).(bool) +} diff --git a/internal/request/request_test.go b/internal/request/request_test.go new file mode 100644 index 00000000..1f47ee2e --- /dev/null +++ b/internal/request/request_test.go @@ -0,0 +1,24 @@ +package request + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestWithHTTPSFlag(t *testing.T) { + r, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + + assert.Panics(t, func() { + IsHTTPS(r) + }) + + httpsRequest := WithHTTPSFlag(r, true) + assert.True(t, IsHTTPS(httpsRequest)) + + httpRequest := WithHTTPSFlag(r, false) + assert.False(t, IsHTTPS(httpRequest)) +} |