diff options
Diffstat (limited to 'internal/auth')
-rw-r--r-- | internal/auth/auth.go | 31 | ||||
-rw-r--r-- | internal/auth/auth_test.go | 33 |
2 files changed, 39 insertions, 25 deletions
diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 768e7f75..eaf3c25d 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -71,6 +71,10 @@ type errorResponse struct { Error string `json:"error"` ErrorDescription string `json:"error_description"` } +type domain interface { + GetProjectID(r *http.Request) uint64 + ServeNotFoundAuthFailed(w http.ResponseWriter, r *http.Request) +} func (a *Auth) getSessionFromStore(r *http.Request) (*sessions.Session, error) { session, err := a.store.Get(r, "gitlab-pages") @@ -436,12 +440,13 @@ func (a *Auth) IsAuthSupported() bool { return a != nil } -func (a *Auth) checkAuthentication(w http.ResponseWriter, r *http.Request, projectID uint64) (contentServed, authFailed bool) { +func (a *Auth) checkAuthentication(w http.ResponseWriter, r *http.Request, domain domain) bool { session := a.checkSessionIsValid(w, r) if session == nil { - return true, false + return true } + projectID := domain.GetProjectID(r) // Access token exists, authorize request var url string if projectID > 0 { @@ -456,14 +461,14 @@ func (a *Auth) checkAuthentication(w http.ResponseWriter, r *http.Request, proje errortracking.Capture(err, errortracking.WithRequest(req)) httperrors.Serve500(w) - return true, false + return true } req.Header.Add("Authorization", "Bearer "+session.Values["access_token"].(string)) resp, err := a.apiClient.Do(req) if err == nil && checkResponseForInvalidToken(resp, session, w, r) { - return true, false + return true } if err != nil || resp.StatusCode != 200 { @@ -471,20 +476,22 @@ func (a *Auth) checkAuthentication(w http.ResponseWriter, r *http.Request, proje logRequest(r).WithError(err).Error("Failed to retrieve info with token") } - return false, true + // call serve404 handler when auth fails + domain.ServeNotFoundAuthFailed(w, r) + return true } - return false, false + return false } // CheckAuthenticationWithoutProject checks if user is authenticated and has a valid token -func (a *Auth) CheckAuthenticationWithoutProject(w http.ResponseWriter, r *http.Request) (contentServed, authFailed bool) { +func (a *Auth) CheckAuthenticationWithoutProject(w http.ResponseWriter, r *http.Request, domain domain) bool { if a == nil { // No auth supported - return false, false + return false } - return a.checkAuthentication(w, r, 0) + return a.checkAuthentication(w, r, domain) } // GetTokenIfExists returns the token if it exists @@ -512,7 +519,7 @@ func (a *Auth) RequireAuth(w http.ResponseWriter, r *http.Request) bool { // CheckAuthentication checks if user is authenticated and has access to the project // will return contentServed = false when authFailed = true -func (a *Auth) CheckAuthentication(w http.ResponseWriter, r *http.Request, projectID uint64) (contentServed, authFailed bool) { +func (a *Auth) CheckAuthentication(w http.ResponseWriter, r *http.Request, domain domain) bool { logRequest(r).Debug("Authenticate request") if a == nil { @@ -520,10 +527,10 @@ func (a *Auth) CheckAuthentication(w http.ResponseWriter, r *http.Request, proje errortracking.Capture(errAuthNotConfigured, errortracking.WithRequest(r)) httperrors.Serve500(w) - return true, false + return true } - return a.checkAuthentication(w, r, projectID) + return a.checkAuthentication(w, r, domain) } // CheckResponseForInvalidToken checks response for invalid token and destroys session if it was invalid diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 711dea78..ffaf5b6f 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -29,6 +29,18 @@ func defaultCookieStore() sessions.Store { return createCookieStore("something-very-secret") } +type domainMock struct { + projectID uint64 +} + +func (dm *domainMock) GetProjectID(r *http.Request) uint64 { + return dm.projectID +} + +func (dm *domainMock) ServeNotFoundAuthFailed(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) +} + // Gorilla's sessions use request context to save session // Which makes session sharable between test code and actually manipulating session // Which leads to negative side effects: we can't test encryption, and cookie params @@ -180,10 +192,9 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) { session, _ := store.Get(r, "gitlab-pages") session.Values["access_token"] = "abc" session.Save(r, result) - - contentServed, authFailed := auth.CheckAuthentication(result, r, 1000) + contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) require.False(t, contentServed) - require.False(t, authFailed) + // content wasn't served so the default response from CheckAuthentication should be 200 require.Equal(t, 200, result.Code) } @@ -222,11 +233,10 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) { session.Values["access_token"] = "abc" session.Save(r, result) - contentServed, authFailed := auth.CheckAuthentication(result, r, 1000) - require.False(t, contentServed) - require.True(t, authFailed) + contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) + require.True(t, contentServed) // content wasn't served so the default response from CheckAuthentication should be 200 - require.Equal(t, 200, result.Code) + require.Equal(t, 404, result.Code) } func TestCheckAuthenticationWhenInvalidToken(t *testing.T) { @@ -263,9 +273,8 @@ func TestCheckAuthenticationWhenInvalidToken(t *testing.T) { session.Values["access_token"] = "abc" session.Save(r, result) - contentServed, authFailed := auth.CheckAuthentication(result, r, 1000) + contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) require.True(t, contentServed) - require.False(t, authFailed) require.Equal(t, 302, result.Code) } @@ -303,9 +312,8 @@ func TestCheckAuthenticationWithoutProject(t *testing.T) { session.Values["access_token"] = "abc" session.Save(r, result) - contentServed, authFailed := auth.CheckAuthenticationWithoutProject(result, r) + contentServed := auth.CheckAuthenticationWithoutProject(result, r, &domainMock{projectID: 0}) require.False(t, contentServed) - require.False(t, authFailed) require.Equal(t, 200, result.Code) } @@ -342,9 +350,8 @@ func TestCheckAuthenticationWithoutProjectWhenInvalidToken(t *testing.T) { session.Values["access_token"] = "abc" session.Save(r, result) - contentServed, authFailed := auth.CheckAuthenticationWithoutProject(result, r) + contentServed := auth.CheckAuthenticationWithoutProject(result, r, &domainMock{projectID: 0}) require.True(t, contentServed) - require.False(t, authFailed) require.Equal(t, 302, result.Code) } |