diff options
Diffstat (limited to 'internal/auth/auth_test.go')
-rw-r--r-- | internal/auth/auth_test.go | 158 |
1 files changed, 81 insertions, 77 deletions
diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 4b035132..9f38877a 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -8,12 +8,10 @@ import ( "net/http" "net/http/httptest" "net/url" - "strings" "testing" "time" "github.com/golang/mock/gomock" - "github.com/gorilla/sessions" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab-pages/internal/request" @@ -57,17 +55,18 @@ func (dm *domainMock) ServeNotFoundAuthFailed(w http.ResponseWriter, r *http.Req // Which leads to negative side effects: we can't test encryption, and cookie params // like max-age and secure are not being properly set // To avoid that we use fake request, and set only session cookie without copying context -func setSessionValues(t *testing.T, r *http.Request, store sessions.Store, values map[interface{}]interface{}) { +func setSessionValues(t *testing.T, r *http.Request, auth *Auth, values map[interface{}]interface{}) { t.Helper() - tmpRequest, err := http.NewRequest("GET", "/", nil) + tmpRequest, err := http.NewRequest("GET", "http://"+r.Host, nil) require.NoError(t, err) result := httptest.NewRecorder() - session, _ := store.Get(tmpRequest, "gitlab-pages") + session, _ := auth.getSessionFromStore(tmpRequest) session.Values = values - session.Save(tmpRequest, result) + err = session.Save(tmpRequest, result) + require.NoError(t, err) res := result.Result() testhelpers.Close(t, res.Body) @@ -113,16 +112,13 @@ func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { auth := createTestAuth(t, "", "") result := httptest.NewRecorder() - reqURL, err := url.Parse("/auth?code=1&state=invalid") - require.NoError(t, err) - reqURL.Scheme = request.SchemeHTTPS - r := &http.Request{URL: reqURL} - session, err := auth.store.Get(r, "gitlab-pages") + r, err := http.NewRequest("Get", "https://example.com/auth?code=1&state=invalid", nil) require.NoError(t, err) - session.Values["state"] = "state" - session.Save(r, result) + setSessionValues(t, r, auth, map[interface{}]interface{}{ + "state": "state", + }) mockCtrl := gomock.NewController(t) @@ -135,19 +131,15 @@ func TestTryAuthenticateRemoveTokenFromRedirect(t *testing.T) { auth := createTestAuth(t, "", "") result := httptest.NewRecorder() - reqURL, err := url.Parse("/auth?code=1&state=state&token=secret") - require.NoError(t, err) - require.Equal(t, reqURL.Query().Get("token"), "secret", "token is present before redirecting") - reqURL.Scheme = request.SchemeHTTPS - r := &http.Request{URL: reqURL} - - session, err := auth.store.Get(r, "gitlab-pages") + r, err := http.NewRequest("Get", "https://example.com/auth?code=1&state=state&token=secret", nil) + require.Equal(t, r.URL.Query().Get("token"), "secret", "token is present before redirecting") require.NoError(t, err) - session.Values["state"] = "state" - session.Values["proxy_auth_domain"] = "https://domain.com" - session.Save(r, result) + setSessionValues(t, r, auth, map[interface{}]interface{}{ + "state": "state", + "proxy_auth_domain": "https://domain.com", + }) mockCtrl := gomock.NewController(t) @@ -203,15 +195,15 @@ func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { auth := createTestAuth(t, apiServer.URL, "") - domain := apiServer.URL + host := "http://example.com" if https { - domain = strings.Replace(apiServer.URL, "http://", "https://", -1) + host = "https://example.com" } - code, err := auth.EncryptAndSignCode(domain, "1") + code, err := auth.EncryptAndSignCode(host, "1") require.NoError(t, err) - r, err := http.NewRequest("GET", "/auth?code="+code+"&state=state", nil) + r, err := http.NewRequest("GET", host+"/auth?code="+code+"&state=state", nil) require.NoError(t, err) if https { r.URL.Scheme = request.SchemeHTTPS @@ -219,9 +211,7 @@ func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { r.URL.Scheme = request.SchemeHTTP } - r.Host = strings.TrimPrefix(apiServer.URL, "http://") - - setSessionValues(t, r, auth.store, map[interface{}]interface{}{ + setSessionValues(t, r, auth, map[interface{}]interface{}{ "uri": "https://pages.gitlab-example.com/project/", "state": "state", }) @@ -269,16 +259,10 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) { auth := createTestAuth(t, apiServer.URL, "") result := httptest.NewRecorder() - reqURL, err := url.Parse("/auth?code=1&state=state") - require.NoError(t, err) - reqURL.Scheme = request.SchemeHTTPS - r := &http.Request{URL: reqURL} - - session, err := auth.store.Get(r, "gitlab-pages") + r, err := http.NewRequest("Get", "https://example.com/", nil) require.NoError(t, err) - session.Values["access_token"] = "abc" - session.Save(r, result) + setSessionValues(t, r, auth, map[interface{}]interface{}{"access_token": "abc"}) contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) require.False(t, contentServed) @@ -306,16 +290,12 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) { w := httptest.NewRecorder() - reqURL, err := url.Parse("/auth?code=1&state=state") - require.NoError(t, err) - reqURL.Scheme = request.SchemeHTTPS - r := &http.Request{URL: reqURL} - - session, err := auth.store.Get(r, "gitlab-pages") + r, err := http.NewRequest("Get", "https://example.com/auth?code=1&state=state", nil) require.NoError(t, err) - session.Values["access_token"] = "abc" - session.Save(r, w) + setSessionValues(t, r, auth, map[interface{}]interface{}{ + "access_token": "abc", + }) contentServed := auth.CheckAuthentication(w, r, &domainMock{projectID: 1000, notFoundContent: "Generic 404"}) require.True(t, contentServed) @@ -329,6 +309,43 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) { require.Equal(t, string(body), "Generic 404") } +func TestCheckAuthenticationWithSessionFromDifferentHost(t *testing.T) { + apiServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v4/projects/1000/pages_access": + require.Equal(t, "Bearer abc", r.Header.Get("Authorization")) + w.WriteHeader(http.StatusOK) + default: + t.Logf("Unexpected r.URL.RawPath: %q", r.URL.Path) + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusNotFound) + } + })) + + apiServer.Start() + defer apiServer.Close() + + auth := createTestAuth(t, apiServer.URL, "") + + result := httptest.NewRecorder() + r, err := http.NewRequest("Get", "https://different.com/", nil) + require.NoError(t, err) + setSessionValues(t, r, auth, map[interface{}]interface{}{"access_token": "abc"}) + + r, err = http.NewRequest("Get", "https://example.com/", nil) + require.NoError(t, err) + contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) + require.True(t, contentServed) + + // should redirect to auth + require.Equal(t, http.StatusFound, result.Code) + redirectURL, err := url.Parse(result.Header().Get("Location")) + require.NoError(t, err) + require.Equal(t, "pages.gitlab-example.com", redirectURL.Host) + require.Equal(t, "/auth", redirectURL.Path) + require.Equal(t, "https://example.com", redirectURL.Query().Get("domain")) +} + func TestCheckAuthenticationWhenInvalidToken(t *testing.T) { apiServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { @@ -349,16 +366,13 @@ func TestCheckAuthenticationWhenInvalidToken(t *testing.T) { auth := createTestAuth(t, apiServer.URL, "") result := httptest.NewRecorder() - reqURL, err := url.Parse("/auth?code=1&state=state") - require.NoError(t, err) - r := &http.Request{URL: reqURL} - session, err := auth.store.Get(r, "gitlab-pages") + r, err := http.NewRequest("Get", "https://example.com", nil) require.NoError(t, err) - session.Values["access_token"] = "abc" - err = session.Save(r, result) - require.NoError(t, err) + setSessionValues(t, r, auth, map[interface{}]interface{}{ + "access_token": "abc", + }) contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) require.True(t, contentServed) @@ -384,16 +398,13 @@ func TestCheckAuthenticationWithoutProject(t *testing.T) { auth := createTestAuth(t, apiServer.URL, "") result := httptest.NewRecorder() - reqURL, err := url.Parse("/auth?code=1&state=state") - require.NoError(t, err) - reqURL.Scheme = request.SchemeHTTPS - r := &http.Request{URL: reqURL} - session, err := auth.store.Get(r, "gitlab-pages") + r, err := http.NewRequest("Get", "https://example.com/auth?code=1&state=state", nil) require.NoError(t, err) - session.Values["access_token"] = "abc" - session.Save(r, result) + setSessionValues(t, r, auth, map[interface{}]interface{}{ + "access_token": "abc", + }) contentServed := auth.CheckAuthenticationWithoutProject(result, r, &domainMock{projectID: 0}) require.False(t, contentServed) @@ -420,15 +431,13 @@ func TestCheckAuthenticationWithoutProjectWhenInvalidToken(t *testing.T) { auth := createTestAuth(t, apiServer.URL, "") result := httptest.NewRecorder() - reqURL, err := url.Parse("/auth?code=1&state=state") - require.NoError(t, err) - r := &http.Request{URL: reqURL} - session, err := auth.store.Get(r, "gitlab-pages") + r, err := http.NewRequest("Get", "https://example.com/", nil) require.NoError(t, err) - session.Values["access_token"] = "abc" - session.Save(r, result) + setSessionValues(t, r, auth, map[interface{}]interface{}{ + "access_token": "abc", + }) contentServed := auth.CheckAuthenticationWithoutProject(result, r, &domainMock{projectID: 0}) require.True(t, contentServed) @@ -453,15 +462,13 @@ func TestGetTokenIfExistsWhenTokenExists(t *testing.T) { auth := createTestAuth(t, "", "") result := httptest.NewRecorder() - reqURL, err := url.Parse("/") - require.NoError(t, err) - r := &http.Request{URL: reqURL} - session, err := auth.store.Get(r, "gitlab-pages") + r, err := http.NewRequest("Get", "https://example.com", nil) require.NoError(t, err) - session.Values["access_token"] = "abc" - session.Save(r, result) + setSessionValues(t, r, auth, map[interface{}]interface{}{ + "access_token": "abc", + }) token, err := auth.GetTokenIfExists(result, r) require.NoError(t, err) @@ -472,14 +479,11 @@ func TestGetTokenIfExistsWhenTokenDoesNotExist(t *testing.T) { auth := createTestAuth(t, "", "") result := httptest.NewRecorder() - reqURL, err := url.Parse("http://pages.gitlab-example.com/test") - require.NoError(t, err) - r := &http.Request{URL: reqURL, Host: "pages.gitlab-example.com", RequestURI: "/test"} - session, err := auth.store.Get(r, "gitlab-pages") + r, err := http.NewRequest("Get", "https://example.com", nil) require.NoError(t, err) - session.Save(r, result) + setSessionValues(t, r, auth, map[interface{}]interface{}{}) token, err := auth.GetTokenIfExists(result, r) require.Equal(t, "", token) |