diff options
-rw-r--r-- | internal/auth/auth.go | 51 | ||||
-rw-r--r-- | internal/auth/auth_test.go | 158 | ||||
-rw-r--r-- | internal/auth/session.go | 62 | ||||
-rw-r--r-- | test/acceptance/auth_test.go | 66 |
4 files changed, 192 insertions, 145 deletions
diff --git a/internal/auth/auth.go b/internal/auth/auth.go index a6e1f7e7..21398e56 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -85,41 +85,6 @@ type domain interface { ServeNotFoundAuthFailed(w http.ResponseWriter, r *http.Request) } -func (a *Auth) getSessionFromStore(r *http.Request) (*sessions.Session, error) { - session, err := a.store.Get(r, "gitlab-pages") - - if session != nil { - // Cookie just for this domain - session.Options.Path = "/" - session.Options.HttpOnly = true - session.Options.Secure = request.IsHTTPS(r) - session.Options.MaxAge = authSessionMaxAge - } - - return session, err -} - -func (a *Auth) checkSession(w http.ResponseWriter, r *http.Request) (*sessions.Session, error) { - // Create or get session - session, errsession := a.getSessionFromStore(r) - - if errsession != nil { - // Save cookie again - errsave := session.Save(r, w) - if errsave != nil { - logRequest(r).WithError(errsave).Error(saveSessionErrMsg) - errortracking.CaptureErrWithReqAndStackTrace(errsave, r) - httperrors.Serve500(w) - return nil, errsave - } - - http.Redirect(w, r, getRequestAddress(r), http.StatusFound) - return nil, errsession - } - - return session, nil -} - // TryAuthenticate tries to authenticate user and fetch access token if request is a callback to /auth? func (a *Auth) TryAuthenticate(w http.ResponseWriter, r *http.Request, domains source.Source) bool { if a == nil { @@ -159,7 +124,7 @@ func (a *Auth) TryAuthenticate(w http.ResponseWriter, r *http.Request, domains s return false } -func (a *Auth) checkAuthenticationResponse(session *sessions.Session, w http.ResponseWriter, r *http.Request) { +func (a *Auth) checkAuthenticationResponse(session *hostSession, w http.ResponseWriter, r *http.Request) { if !validateState(r, session) { // State is NOT ok logRequest(r).Warn("Authentication state did not match expected") @@ -228,7 +193,7 @@ func (a *Auth) domainAllowed(ctx context.Context, name string, domains source.So return (domain != nil && err == nil) } -func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWriter, r *http.Request, domains source.Source) bool { +func (a *Auth) handleProxyingAuth(session *hostSession, w http.ResponseWriter, r *http.Request, domains source.Source) bool { // handle auth callback e.g. https://gitlab.io/auth?domain=domain&state=state if shouldProxyAuthToGitlab(r) { domain := r.URL.Query().Get("domain") @@ -345,11 +310,11 @@ func shouldProxyAuthToGitlab(r *http.Request) bool { return r.URL.Query().Get("domain") != "" && r.URL.Query().Get("state") != "" } -func shouldProxyCallbackToCustomDomain(session *sessions.Session) bool { +func shouldProxyCallbackToCustomDomain(session *hostSession) bool { return session.Values["proxy_auth_domain"] != nil } -func validateState(r *http.Request, session *sessions.Session) bool { +func validateState(r *http.Request, session *hostSession) bool { state := r.URL.Query().Get("state") if state == "" { // No state param @@ -414,7 +379,7 @@ func (a *Auth) fetchAccessToken(ctx context.Context, code string) (tokenResponse return token, nil } -func (a *Auth) checkSessionIsValid(w http.ResponseWriter, r *http.Request) *sessions.Session { +func (a *Auth) checkSessionIsValid(w http.ResponseWriter, r *http.Request) *hostSession { session, err := a.checkSession(w, r) if err != nil { return nil @@ -428,7 +393,7 @@ func (a *Auth) checkSessionIsValid(w http.ResponseWriter, r *http.Request) *sess return session } -func (a *Auth) checkTokenExists(session *sessions.Session, w http.ResponseWriter, r *http.Request) bool { +func (a *Auth) checkTokenExists(session *hostSession, w http.ResponseWriter, r *http.Request) bool { // If no access token redirect to OAuth login page if session.Values["access_token"] == nil { logRequest(r).Debug("No access token exists, redirecting user to OAuth2 login") @@ -463,7 +428,7 @@ func (a *Auth) getProxyAddress(r *http.Request, state string) string { return fmt.Sprintf(authorizeProxyTemplate, a.redirectURI, getRequestDomain(r), state) } -func destroySession(session *sessions.Session, w http.ResponseWriter, r *http.Request) { +func destroySession(session *hostSession, w http.ResponseWriter, r *http.Request) { logRequest(r).Debug("Destroying session") // Invalidate access token and redirect back for refreshing and re-authenticating @@ -609,7 +574,7 @@ func (a *Auth) CheckResponseForInvalidToken(w http.ResponseWriter, r *http.Reque return false } -func checkResponseForInvalidToken(resp *http.Response, session *sessions.Session, w http.ResponseWriter, r *http.Request) bool { +func checkResponseForInvalidToken(resp *http.Response, session *hostSession, w http.ResponseWriter, r *http.Request) bool { if resp.StatusCode == http.StatusUnauthorized { errResp := errorResponse{} diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 4b035132..9f38877a 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -8,12 +8,10 @@ import ( "net/http" "net/http/httptest" "net/url" - "strings" "testing" "time" "github.com/golang/mock/gomock" - "github.com/gorilla/sessions" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab-pages/internal/request" @@ -57,17 +55,18 @@ func (dm *domainMock) ServeNotFoundAuthFailed(w http.ResponseWriter, r *http.Req // Which leads to negative side effects: we can't test encryption, and cookie params // like max-age and secure are not being properly set // To avoid that we use fake request, and set only session cookie without copying context -func setSessionValues(t *testing.T, r *http.Request, store sessions.Store, values map[interface{}]interface{}) { +func setSessionValues(t *testing.T, r *http.Request, auth *Auth, values map[interface{}]interface{}) { t.Helper() - tmpRequest, err := http.NewRequest("GET", "/", nil) + tmpRequest, err := http.NewRequest("GET", "http://"+r.Host, nil) require.NoError(t, err) result := httptest.NewRecorder() - session, _ := store.Get(tmpRequest, "gitlab-pages") + session, _ := auth.getSessionFromStore(tmpRequest) session.Values = values - session.Save(tmpRequest, result) + err = session.Save(tmpRequest, result) + require.NoError(t, err) res := result.Result() testhelpers.Close(t, res.Body) @@ -113,16 +112,13 @@ func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { auth := createTestAuth(t, "", "") result := httptest.NewRecorder() - reqURL, err := url.Parse("/auth?code=1&state=invalid") - require.NoError(t, err) - reqURL.Scheme = request.SchemeHTTPS - r := &http.Request{URL: reqURL} - session, err := auth.store.Get(r, "gitlab-pages") + r, err := http.NewRequest("Get", "https://example.com/auth?code=1&state=invalid", nil) require.NoError(t, err) - session.Values["state"] = "state" - session.Save(r, result) + setSessionValues(t, r, auth, map[interface{}]interface{}{ + "state": "state", + }) mockCtrl := gomock.NewController(t) @@ -135,19 +131,15 @@ func TestTryAuthenticateRemoveTokenFromRedirect(t *testing.T) { auth := createTestAuth(t, "", "") result := httptest.NewRecorder() - reqURL, err := url.Parse("/auth?code=1&state=state&token=secret") - require.NoError(t, err) - require.Equal(t, reqURL.Query().Get("token"), "secret", "token is present before redirecting") - reqURL.Scheme = request.SchemeHTTPS - r := &http.Request{URL: reqURL} - - session, err := auth.store.Get(r, "gitlab-pages") + r, err := http.NewRequest("Get", "https://example.com/auth?code=1&state=state&token=secret", nil) + require.Equal(t, r.URL.Query().Get("token"), "secret", "token is present before redirecting") require.NoError(t, err) - session.Values["state"] = "state" - session.Values["proxy_auth_domain"] = "https://domain.com" - session.Save(r, result) + setSessionValues(t, r, auth, map[interface{}]interface{}{ + "state": "state", + "proxy_auth_domain": "https://domain.com", + }) mockCtrl := gomock.NewController(t) @@ -203,15 +195,15 @@ func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { auth := createTestAuth(t, apiServer.URL, "") - domain := apiServer.URL + host := "http://example.com" if https { - domain = strings.Replace(apiServer.URL, "http://", "https://", -1) + host = "https://example.com" } - code, err := auth.EncryptAndSignCode(domain, "1") + code, err := auth.EncryptAndSignCode(host, "1") require.NoError(t, err) - r, err := http.NewRequest("GET", "/auth?code="+code+"&state=state", nil) + r, err := http.NewRequest("GET", host+"/auth?code="+code+"&state=state", nil) require.NoError(t, err) if https { r.URL.Scheme = request.SchemeHTTPS @@ -219,9 +211,7 @@ func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { r.URL.Scheme = request.SchemeHTTP } - r.Host = strings.TrimPrefix(apiServer.URL, "http://") - - setSessionValues(t, r, auth.store, map[interface{}]interface{}{ + setSessionValues(t, r, auth, map[interface{}]interface{}{ "uri": "https://pages.gitlab-example.com/project/", "state": "state", }) @@ -269,16 +259,10 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) { auth := createTestAuth(t, apiServer.URL, "") result := httptest.NewRecorder() - reqURL, err := url.Parse("/auth?code=1&state=state") - require.NoError(t, err) - reqURL.Scheme = request.SchemeHTTPS - r := &http.Request{URL: reqURL} - - session, err := auth.store.Get(r, "gitlab-pages") + r, err := http.NewRequest("Get", "https://example.com/", nil) require.NoError(t, err) - session.Values["access_token"] = "abc" - session.Save(r, result) + setSessionValues(t, r, auth, map[interface{}]interface{}{"access_token": "abc"}) contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) require.False(t, contentServed) @@ -306,16 +290,12 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) { w := httptest.NewRecorder() - reqURL, err := url.Parse("/auth?code=1&state=state") - require.NoError(t, err) - reqURL.Scheme = request.SchemeHTTPS - r := &http.Request{URL: reqURL} - - session, err := auth.store.Get(r, "gitlab-pages") + r, err := http.NewRequest("Get", "https://example.com/auth?code=1&state=state", nil) require.NoError(t, err) - session.Values["access_token"] = "abc" - session.Save(r, w) + setSessionValues(t, r, auth, map[interface{}]interface{}{ + "access_token": "abc", + }) contentServed := auth.CheckAuthentication(w, r, &domainMock{projectID: 1000, notFoundContent: "Generic 404"}) require.True(t, contentServed) @@ -329,6 +309,43 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) { require.Equal(t, string(body), "Generic 404") } +func TestCheckAuthenticationWithSessionFromDifferentHost(t *testing.T) { + apiServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v4/projects/1000/pages_access": + require.Equal(t, "Bearer abc", r.Header.Get("Authorization")) + w.WriteHeader(http.StatusOK) + default: + t.Logf("Unexpected r.URL.RawPath: %q", r.URL.Path) + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusNotFound) + } + })) + + apiServer.Start() + defer apiServer.Close() + + auth := createTestAuth(t, apiServer.URL, "") + + result := httptest.NewRecorder() + r, err := http.NewRequest("Get", "https://different.com/", nil) + require.NoError(t, err) + setSessionValues(t, r, auth, map[interface{}]interface{}{"access_token": "abc"}) + + r, err = http.NewRequest("Get", "https://example.com/", nil) + require.NoError(t, err) + contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) + require.True(t, contentServed) + + // should redirect to auth + require.Equal(t, http.StatusFound, result.Code) + redirectURL, err := url.Parse(result.Header().Get("Location")) + require.NoError(t, err) + require.Equal(t, "pages.gitlab-example.com", redirectURL.Host) + require.Equal(t, "/auth", redirectURL.Path) + require.Equal(t, "https://example.com", redirectURL.Query().Get("domain")) +} + func TestCheckAuthenticationWhenInvalidToken(t *testing.T) { apiServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { @@ -349,16 +366,13 @@ func TestCheckAuthenticationWhenInvalidToken(t *testing.T) { auth := createTestAuth(t, apiServer.URL, "") result := httptest.NewRecorder() - reqURL, err := url.Parse("/auth?code=1&state=state") - require.NoError(t, err) - r := &http.Request{URL: reqURL} - session, err := auth.store.Get(r, "gitlab-pages") + r, err := http.NewRequest("Get", "https://example.com", nil) require.NoError(t, err) - session.Values["access_token"] = "abc" - err = session.Save(r, result) - require.NoError(t, err) + setSessionValues(t, r, auth, map[interface{}]interface{}{ + "access_token": "abc", + }) contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) require.True(t, contentServed) @@ -384,16 +398,13 @@ func TestCheckAuthenticationWithoutProject(t *testing.T) { auth := createTestAuth(t, apiServer.URL, "") result := httptest.NewRecorder() - reqURL, err := url.Parse("/auth?code=1&state=state") - require.NoError(t, err) - reqURL.Scheme = request.SchemeHTTPS - r := &http.Request{URL: reqURL} - session, err := auth.store.Get(r, "gitlab-pages") + r, err := http.NewRequest("Get", "https://example.com/auth?code=1&state=state", nil) require.NoError(t, err) - session.Values["access_token"] = "abc" - session.Save(r, result) + setSessionValues(t, r, auth, map[interface{}]interface{}{ + "access_token": "abc", + }) contentServed := auth.CheckAuthenticationWithoutProject(result, r, &domainMock{projectID: 0}) require.False(t, contentServed) @@ -420,15 +431,13 @@ func TestCheckAuthenticationWithoutProjectWhenInvalidToken(t *testing.T) { auth := createTestAuth(t, apiServer.URL, "") result := httptest.NewRecorder() - reqURL, err := url.Parse("/auth?code=1&state=state") - require.NoError(t, err) - r := &http.Request{URL: reqURL} - session, err := auth.store.Get(r, "gitlab-pages") + r, err := http.NewRequest("Get", "https://example.com/", nil) require.NoError(t, err) - session.Values["access_token"] = "abc" - session.Save(r, result) + setSessionValues(t, r, auth, map[interface{}]interface{}{ + "access_token": "abc", + }) contentServed := auth.CheckAuthenticationWithoutProject(result, r, &domainMock{projectID: 0}) require.True(t, contentServed) @@ -453,15 +462,13 @@ func TestGetTokenIfExistsWhenTokenExists(t *testing.T) { auth := createTestAuth(t, "", "") result := httptest.NewRecorder() - reqURL, err := url.Parse("/") - require.NoError(t, err) - r := &http.Request{URL: reqURL} - session, err := auth.store.Get(r, "gitlab-pages") + r, err := http.NewRequest("Get", "https://example.com", nil) require.NoError(t, err) - session.Values["access_token"] = "abc" - session.Save(r, result) + setSessionValues(t, r, auth, map[interface{}]interface{}{ + "access_token": "abc", + }) token, err := auth.GetTokenIfExists(result, r) require.NoError(t, err) @@ -472,14 +479,11 @@ func TestGetTokenIfExistsWhenTokenDoesNotExist(t *testing.T) { auth := createTestAuth(t, "", "") result := httptest.NewRecorder() - reqURL, err := url.Parse("http://pages.gitlab-example.com/test") - require.NoError(t, err) - r := &http.Request{URL: reqURL, Host: "pages.gitlab-example.com", RequestURI: "/test"} - session, err := auth.store.Get(r, "gitlab-pages") + r, err := http.NewRequest("Get", "https://example.com", nil) require.NoError(t, err) - session.Save(r, result) + setSessionValues(t, r, auth, map[interface{}]interface{}{}) token, err := auth.GetTokenIfExists(result, r) require.Equal(t, "", token) diff --git a/internal/auth/session.go b/internal/auth/session.go new file mode 100644 index 00000000..d6402bf9 --- /dev/null +++ b/internal/auth/session.go @@ -0,0 +1,62 @@ +package auth + +import ( + "net/http" + + "github.com/gorilla/sessions" + + "gitlab.com/gitlab-org/gitlab-pages/internal/errortracking" + "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" + "gitlab.com/gitlab-org/gitlab-pages/internal/request" +) + +type hostSession struct { + *sessions.Session +} + +const sessionHostKey = "_session_host" + +func (s *hostSession) Save(r *http.Request, w http.ResponseWriter) error { + s.Session.Values[sessionHostKey] = r.Host + + return s.Session.Save(r, w) +} + +func (a *Auth) getSessionFromStore(r *http.Request) (*hostSession, error) { + session, err := a.store.Get(r, "gitlab-pages") + + if session != nil { + // Cookie just for this domain + session.Options.Path = "/" + session.Options.HttpOnly = true + session.Options.Secure = request.IsHTTPS(r) + session.Options.MaxAge = authSessionMaxAge + + if session.Values[sessionHostKey] == nil || session.Values[sessionHostKey] != r.Host { + session.Values = make(map[interface{}]interface{}) + } + } + + return &hostSession{session}, err +} + +func (a *Auth) checkSession(w http.ResponseWriter, r *http.Request) (*hostSession, error) { + // Create or get session + session, errsession := a.getSessionFromStore(r) + + if errsession != nil { + // Save cookie again + errsave := session.Save(r, w) + if errsave != nil { + logRequest(r).WithError(errsave).Error(saveSessionErrMsg) + errortracking.CaptureErrWithReqAndStackTrace(errsave, r) + httperrors.Serve500(w) + return nil, errsave + } + + http.Redirect(w, r, getRequestAddress(r), http.StatusFound) + return nil, errsession + } + + return session, nil +} diff --git a/test/acceptance/auth_test.go b/test/acceptance/auth_test.go index d7677622..dbc7b900 100644 --- a/test/acceptance/auth_test.go +++ b/test/acceptance/auth_test.go @@ -120,7 +120,7 @@ func TestWhenLoginCallbackWithUnencryptedCode(t *testing.T) { } // Go to auth page with correct state will cause fetching the token - authrsp, err := GetPageFromListenerWithHeaders(t, httpsListener, "projects.gitlab-example.com", "/auth?code=1&state="+ + authrsp, err := GetPageFromListenerWithHeaders(t, httpsListener, "group.auth.gitlab-example.com", "/auth?code=1&state="+ url.Query().Get("state"), header) require.NoError(t, err) @@ -153,60 +153,76 @@ func TestAccessControlUnderCustomDomain(t *testing.T) { } for name, tt := range tests { t.Run(name, func(t *testing.T) { + // visit to custom domain rsp, err := GetRedirectPage(t, httpListener, tt.domain, tt.path) require.NoError(t, err) testhelpers.Close(t, rsp.Body) - cookie := rsp.Header.Get("Set-Cookie") + domainCookie := rsp.Header.Get("Set-Cookie") - url, err := url.Parse(rsp.Header.Get("Location")) + projectProxyURL, err := url.Parse(rsp.Header.Get("Location")) require.NoError(t, err) - state := url.Query().Get("state") - require.Equal(t, "http://"+tt.domain, url.Query().Get("domain")) + state := projectProxyURL.Query().Get("state") + require.Equal(t, "http://"+tt.domain, projectProxyURL.Query().Get("domain")) - pagesrsp, err := GetRedirectPage(t, httpListener, url.Host, url.Path+"?"+url.RawQuery) + // visit projects.gitlab-example.com?state=something + projectsProxyRsp, err := GetRedirectPage(t, httpListener, + projectProxyURL.Host, projectProxyURL.Path+"?"+projectProxyURL.RawQuery) require.NoError(t, err) - testhelpers.Close(t, pagesrsp.Body) + testhelpers.Close(t, projectsProxyRsp.Body) - pagescookie := pagesrsp.Header.Get("Set-Cookie") + projectsCookie := projectsProxyRsp.Header.Get("Set-Cookie") - // Go to auth page with correct state will cause fetching the token - authrsp, err := GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code=1&state="+ - state, pagescookie) + // visit projects.gitlab-example.com?state=something&code=1 + authRsp, err := GetRedirectPageWithCookie(t, httpListener, projectProxyURL.Host, "/auth?code=1&state="+ + state, projectsCookie) require.NoError(t, err) - testhelpers.Close(t, authrsp.Body) + testhelpers.Close(t, authRsp.Body) - url, err = url.Parse(authrsp.Header.Get("Location")) + backDomainURL, err := projectProxyURL.Parse(authRsp.Header.Get("Location")) require.NoError(t, err) // Will redirect to custom domain - require.Equal(t, tt.domain, url.Host) - code := url.Query().Get("code") + require.Equal(t, tt.domain, backDomainURL.Host) + code := backDomainURL.Query().Get("code") require.NotEqual(t, "1", code) - authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code="+code+"&state="+ - state, cookie) + // visit domain.com/auth?code&state will set the cookie and redirect back to original page + selfRedirectRsp, err := GetRedirectPageWithCookie(t, httpListener, tt.domain, "/auth?code="+code+"&state="+ + state, domainCookie) require.NoError(t, err) - testhelpers.Close(t, authrsp.Body) + testhelpers.Close(t, selfRedirectRsp.Body) // Will redirect to the page - cookie = authrsp.Header.Get("Set-Cookie") - require.Equal(t, http.StatusFound, authrsp.StatusCode) + domainCookie = selfRedirectRsp.Header.Get("Set-Cookie") + require.Equal(t, http.StatusFound, selfRedirectRsp.StatusCode) - url, err = url.Parse(authrsp.Header.Get("Location")) + selfRedirectURL, err := projectProxyURL.Parse(selfRedirectRsp.Header.Get("Location")) require.NoError(t, err) // Will redirect to custom domain - require.Equal(t, "http://"+tt.domain+"/"+tt.path, url.String()) + require.Equal(t, "http://"+tt.domain+"/"+tt.path, selfRedirectURL.String()) // Fetch page in custom domain - authrsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, tt.path, cookie) + authRsp, err = GetRedirectPageWithCookie(t, httpListener, tt.domain, tt.path, domainCookie) require.NoError(t, err) - testhelpers.Close(t, authrsp.Body) - require.Equal(t, http.StatusOK, authrsp.StatusCode) + testhelpers.Close(t, authRsp.Body) + require.Equal(t, http.StatusOK, authRsp.StatusCode) + + // Try to fetch page from another domain + // it should restart the auth process ignoring already existing cookie + secondAuthRsp, err := GetRedirectPageWithCookie(t, httpListener, "group.auth.gitlab-example.com", "/private.project/", domainCookie) + require.NoError(t, err) + testhelpers.Close(t, authRsp.Body) + + secondAuthURL, err := url.Parse(secondAuthRsp.Header.Get("Location")) + require.NoError(t, err) + require.Equal(t, "projects.gitlab-example.com", secondAuthURL.Host) + require.Equal(t, "/auth", secondAuthURL.Path) + require.Equal(t, "http://group.auth.gitlab-example.com", secondAuthURL.Query().Get("domain")) }) } } |