diff options
Diffstat (limited to 'internal')
-rw-r--r-- | internal/auth/auth.go | 31 | ||||
-rw-r--r-- | internal/auth/auth_test.go | 33 | ||||
-rw-r--r-- | internal/domain/domain.go | 23 | ||||
-rw-r--r-- | internal/domain/domain_test.go | 12 | ||||
-rw-r--r-- | internal/source/disk/domain_test.go | 4 |
5 files changed, 71 insertions, 32 deletions
diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 768e7f75..eaf3c25d 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -71,6 +71,10 @@ type errorResponse struct { Error string `json:"error"` ErrorDescription string `json:"error_description"` } +type domain interface { + GetProjectID(r *http.Request) uint64 + ServeNotFoundAuthFailed(w http.ResponseWriter, r *http.Request) +} func (a *Auth) getSessionFromStore(r *http.Request) (*sessions.Session, error) { session, err := a.store.Get(r, "gitlab-pages") @@ -436,12 +440,13 @@ func (a *Auth) IsAuthSupported() bool { return a != nil } -func (a *Auth) checkAuthentication(w http.ResponseWriter, r *http.Request, projectID uint64) (contentServed, authFailed bool) { +func (a *Auth) checkAuthentication(w http.ResponseWriter, r *http.Request, domain domain) bool { session := a.checkSessionIsValid(w, r) if session == nil { - return true, false + return true } + projectID := domain.GetProjectID(r) // Access token exists, authorize request var url string if projectID > 0 { @@ -456,14 +461,14 @@ func (a *Auth) checkAuthentication(w http.ResponseWriter, r *http.Request, proje errortracking.Capture(err, errortracking.WithRequest(req)) httperrors.Serve500(w) - return true, false + return true } req.Header.Add("Authorization", "Bearer "+session.Values["access_token"].(string)) resp, err := a.apiClient.Do(req) if err == nil && checkResponseForInvalidToken(resp, session, w, r) { - return true, false + return true } if err != nil || resp.StatusCode != 200 { @@ -471,20 +476,22 @@ func (a *Auth) checkAuthentication(w http.ResponseWriter, r *http.Request, proje logRequest(r).WithError(err).Error("Failed to retrieve info with token") } - return false, true + // call serve404 handler when auth fails + domain.ServeNotFoundAuthFailed(w, r) + return true } - return false, false + return false } // CheckAuthenticationWithoutProject checks if user is authenticated and has a valid token -func (a *Auth) CheckAuthenticationWithoutProject(w http.ResponseWriter, r *http.Request) (contentServed, authFailed bool) { +func (a *Auth) CheckAuthenticationWithoutProject(w http.ResponseWriter, r *http.Request, domain domain) bool { if a == nil { // No auth supported - return false, false + return false } - return a.checkAuthentication(w, r, 0) + return a.checkAuthentication(w, r, domain) } // GetTokenIfExists returns the token if it exists @@ -512,7 +519,7 @@ func (a *Auth) RequireAuth(w http.ResponseWriter, r *http.Request) bool { // CheckAuthentication checks if user is authenticated and has access to the project // will return contentServed = false when authFailed = true -func (a *Auth) CheckAuthentication(w http.ResponseWriter, r *http.Request, projectID uint64) (contentServed, authFailed bool) { +func (a *Auth) CheckAuthentication(w http.ResponseWriter, r *http.Request, domain domain) bool { logRequest(r).Debug("Authenticate request") if a == nil { @@ -520,10 +527,10 @@ func (a *Auth) CheckAuthentication(w http.ResponseWriter, r *http.Request, proje errortracking.Capture(errAuthNotConfigured, errortracking.WithRequest(r)) httperrors.Serve500(w) - return true, false + return true } - return a.checkAuthentication(w, r, projectID) + return a.checkAuthentication(w, r, domain) } // CheckResponseForInvalidToken checks response for invalid token and destroys session if it was invalid diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 711dea78..ffaf5b6f 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -29,6 +29,18 @@ func defaultCookieStore() sessions.Store { return createCookieStore("something-very-secret") } +type domainMock struct { + projectID uint64 +} + +func (dm *domainMock) GetProjectID(r *http.Request) uint64 { + return dm.projectID +} + +func (dm *domainMock) ServeNotFoundAuthFailed(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) +} + // Gorilla's sessions use request context to save session // Which makes session sharable between test code and actually manipulating session // Which leads to negative side effects: we can't test encryption, and cookie params @@ -180,10 +192,9 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) { session, _ := store.Get(r, "gitlab-pages") session.Values["access_token"] = "abc" session.Save(r, result) - - contentServed, authFailed := auth.CheckAuthentication(result, r, 1000) + contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) require.False(t, contentServed) - require.False(t, authFailed) + // content wasn't served so the default response from CheckAuthentication should be 200 require.Equal(t, 200, result.Code) } @@ -222,11 +233,10 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) { session.Values["access_token"] = "abc" session.Save(r, result) - contentServed, authFailed := auth.CheckAuthentication(result, r, 1000) - require.False(t, contentServed) - require.True(t, authFailed) + contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) + require.True(t, contentServed) // content wasn't served so the default response from CheckAuthentication should be 200 - require.Equal(t, 200, result.Code) + require.Equal(t, 404, result.Code) } func TestCheckAuthenticationWhenInvalidToken(t *testing.T) { @@ -263,9 +273,8 @@ func TestCheckAuthenticationWhenInvalidToken(t *testing.T) { session.Values["access_token"] = "abc" session.Save(r, result) - contentServed, authFailed := auth.CheckAuthentication(result, r, 1000) + contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) require.True(t, contentServed) - require.False(t, authFailed) require.Equal(t, 302, result.Code) } @@ -303,9 +312,8 @@ func TestCheckAuthenticationWithoutProject(t *testing.T) { session.Values["access_token"] = "abc" session.Save(r, result) - contentServed, authFailed := auth.CheckAuthenticationWithoutProject(result, r) + contentServed := auth.CheckAuthenticationWithoutProject(result, r, &domainMock{projectID: 0}) require.False(t, contentServed) - require.False(t, authFailed) require.Equal(t, 200, result.Code) } @@ -342,9 +350,8 @@ func TestCheckAuthenticationWithoutProjectWhenInvalidToken(t *testing.T) { session.Values["access_token"] = "abc" session.Save(r, result) - contentServed, authFailed := auth.CheckAuthenticationWithoutProject(result, r) + contentServed := auth.CheckAuthenticationWithoutProject(result, r, &domainMock{projectID: 0}) require.True(t, contentServed) - require.False(t, authFailed) require.Equal(t, 302, result.Code) } diff --git a/internal/domain/domain.go b/internal/domain/domain.go index 2a301121..9fa843a3 100644 --- a/internal/domain/domain.go +++ b/internal/domain/domain.go @@ -169,14 +169,14 @@ func (d *Domain) ServeNotFoundHTTP(w http.ResponseWriter, r *http.Request) { request.ServeNotFoundHTTP(w, r) } -// ServeNamespaceNotFound will try to find a parent namespace domain for a request +// serveNamespaceNotFound will try to find a parent namespace domain for a request // that failed authentication so that we serve the custom namespace error page for // public namespace domains -func (d *Domain) ServeNamespaceNotFound(w http.ResponseWriter, r *http.Request) { - // override the path nd try to resolve the domain name +func (d *Domain) serveNamespaceNotFound(w http.ResponseWriter, r *http.Request) { + // override the path and try to resolve the domain name r.URL.Path = "/" namespaceDomain, err := d.Resolver.Resolve(r) - if err != nil { + if err != nil || namespaceDomain.LookupPath == nil { httperrors.Serve404(w) return } @@ -189,3 +189,18 @@ func (d *Domain) ServeNamespaceNotFound(w http.ResponseWriter, r *http.Request) httperrors.Serve404(w) } + +// ServeNotFoundAuthFailed handler to be called when auth failed so the correct custom +// 404 page is served. +func (d *Domain) ServeNotFoundAuthFailed(w http.ResponseWriter, r *http.Request) { + if d.isUnconfigured() || !d.HasLookupPath(r) { + httperrors.Serve404(w) + return + } + if d.IsNamespaceProject(r) { + d.ServeNotFoundHTTP(w, r) + return + } + + d.serveNamespaceNotFound(w, r) +} diff --git a/internal/domain/domain_test.go b/internal/domain/domain_test.go index e152b031..9e89f0e5 100644 --- a/internal/domain/domain_test.go +++ b/internal/domain/domain_test.go @@ -209,6 +209,16 @@ func TestDomain_ServeNamespaceNotFound(t *testing.T) { }, expectedResponse: "The page you're looking for could not be found.", }, + { + name: "no_parent_namespace_domain", + domain: "group.404.gitlab-example.com", + path: "/unknown", + resolver: &stubbedResolver{ + project: nil, + subpath: "/", + }, + expectedResponse: "The page you're looking for could not be found.", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -218,7 +228,7 @@ func TestDomain_ServeNamespaceNotFound(t *testing.T) { } w := httptest.NewRecorder() r := httptest.NewRequest("GET", fmt.Sprintf("http://%s%s", tt.domain, tt.path), nil) - d.ServeNamespaceNotFound(w, r) + d.serveNamespaceNotFound(w, r) resp := w.Result() require.Equal(t, http.StatusNotFound, resp.StatusCode) diff --git a/internal/source/disk/domain_test.go b/internal/source/disk/domain_test.go index d114ff8a..56297fd1 100644 --- a/internal/source/disk/domain_test.go +++ b/internal/source/disk/domain_test.go @@ -308,8 +308,8 @@ func TestDomain404ServeHTTP(t *testing.T) { }, } - testhelpers.AssertHTTP404(t, serveFileOrNotFound(testDomain), "GET", "http://group.404.test.io/not-existing-file", nil, "Custom 404 group page") - testhelpers.AssertHTTP404(t, serveFileOrNotFound(testDomain), "GET", "http://group.404.test.io/", nil, "Custom 404 group page") + testhelpers.AssertHTTP404(t, serveFileOrNotFound(testDomain), "GET", "http://group.404.test.io/not-existing-file", nil, "Custom domain.404 page") + testhelpers.AssertHTTP404(t, serveFileOrNotFound(testDomain), "GET", "http://group.404.test.io/", nil, "Custom domain.404 page") } func TestPredefined404ServeHTTP(t *testing.T) { |