Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-pages.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJaime Martinez <jmartinez@gitlab.com>2020-06-09 09:20:31 +0300
committerJaime Martinez <jmartinez@gitlab.com>2020-07-06 02:13:51 +0300
commit2a23f2fb9bca74302dcdc40def50c748da4a5e06 (patch)
tree31698c64ca1b9b8dc370aa42d5015c63f5ca7fcb /internal
parent8e4dff76f1015bf10bdaedc295f726e80958bba1 (diff)
Move serving 404 logic to domain package
Simplify responsibilities of auth package and reduce complexity of app.go deciding which content to serve.
Diffstat (limited to 'internal')
-rw-r--r--internal/auth/auth.go31
-rw-r--r--internal/auth/auth_test.go33
-rw-r--r--internal/domain/domain.go23
-rw-r--r--internal/domain/domain_test.go12
-rw-r--r--internal/source/disk/domain_test.go4
5 files changed, 71 insertions, 32 deletions
diff --git a/internal/auth/auth.go b/internal/auth/auth.go
index 768e7f75..eaf3c25d 100644
--- a/internal/auth/auth.go
+++ b/internal/auth/auth.go
@@ -71,6 +71,10 @@ type errorResponse struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
}
+type domain interface {
+ GetProjectID(r *http.Request) uint64
+ ServeNotFoundAuthFailed(w http.ResponseWriter, r *http.Request)
+}
func (a *Auth) getSessionFromStore(r *http.Request) (*sessions.Session, error) {
session, err := a.store.Get(r, "gitlab-pages")
@@ -436,12 +440,13 @@ func (a *Auth) IsAuthSupported() bool {
return a != nil
}
-func (a *Auth) checkAuthentication(w http.ResponseWriter, r *http.Request, projectID uint64) (contentServed, authFailed bool) {
+func (a *Auth) checkAuthentication(w http.ResponseWriter, r *http.Request, domain domain) bool {
session := a.checkSessionIsValid(w, r)
if session == nil {
- return true, false
+ return true
}
+ projectID := domain.GetProjectID(r)
// Access token exists, authorize request
var url string
if projectID > 0 {
@@ -456,14 +461,14 @@ func (a *Auth) checkAuthentication(w http.ResponseWriter, r *http.Request, proje
errortracking.Capture(err, errortracking.WithRequest(req))
httperrors.Serve500(w)
- return true, false
+ return true
}
req.Header.Add("Authorization", "Bearer "+session.Values["access_token"].(string))
resp, err := a.apiClient.Do(req)
if err == nil && checkResponseForInvalidToken(resp, session, w, r) {
- return true, false
+ return true
}
if err != nil || resp.StatusCode != 200 {
@@ -471,20 +476,22 @@ func (a *Auth) checkAuthentication(w http.ResponseWriter, r *http.Request, proje
logRequest(r).WithError(err).Error("Failed to retrieve info with token")
}
- return false, true
+ // call serve404 handler when auth fails
+ domain.ServeNotFoundAuthFailed(w, r)
+ return true
}
- return false, false
+ return false
}
// CheckAuthenticationWithoutProject checks if user is authenticated and has a valid token
-func (a *Auth) CheckAuthenticationWithoutProject(w http.ResponseWriter, r *http.Request) (contentServed, authFailed bool) {
+func (a *Auth) CheckAuthenticationWithoutProject(w http.ResponseWriter, r *http.Request, domain domain) bool {
if a == nil {
// No auth supported
- return false, false
+ return false
}
- return a.checkAuthentication(w, r, 0)
+ return a.checkAuthentication(w, r, domain)
}
// GetTokenIfExists returns the token if it exists
@@ -512,7 +519,7 @@ func (a *Auth) RequireAuth(w http.ResponseWriter, r *http.Request) bool {
// CheckAuthentication checks if user is authenticated and has access to the project
// will return contentServed = false when authFailed = true
-func (a *Auth) CheckAuthentication(w http.ResponseWriter, r *http.Request, projectID uint64) (contentServed, authFailed bool) {
+func (a *Auth) CheckAuthentication(w http.ResponseWriter, r *http.Request, domain domain) bool {
logRequest(r).Debug("Authenticate request")
if a == nil {
@@ -520,10 +527,10 @@ func (a *Auth) CheckAuthentication(w http.ResponseWriter, r *http.Request, proje
errortracking.Capture(errAuthNotConfigured, errortracking.WithRequest(r))
httperrors.Serve500(w)
- return true, false
+ return true
}
- return a.checkAuthentication(w, r, projectID)
+ return a.checkAuthentication(w, r, domain)
}
// CheckResponseForInvalidToken checks response for invalid token and destroys session if it was invalid
diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go
index 711dea78..ffaf5b6f 100644
--- a/internal/auth/auth_test.go
+++ b/internal/auth/auth_test.go
@@ -29,6 +29,18 @@ func defaultCookieStore() sessions.Store {
return createCookieStore("something-very-secret")
}
+type domainMock struct {
+ projectID uint64
+}
+
+func (dm *domainMock) GetProjectID(r *http.Request) uint64 {
+ return dm.projectID
+}
+
+func (dm *domainMock) ServeNotFoundAuthFailed(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusNotFound)
+}
+
// Gorilla's sessions use request context to save session
// Which makes session sharable between test code and actually manipulating session
// Which leads to negative side effects: we can't test encryption, and cookie params
@@ -180,10 +192,9 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) {
session, _ := store.Get(r, "gitlab-pages")
session.Values["access_token"] = "abc"
session.Save(r, result)
-
- contentServed, authFailed := auth.CheckAuthentication(result, r, 1000)
+ contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000})
require.False(t, contentServed)
- require.False(t, authFailed)
+
// content wasn't served so the default response from CheckAuthentication should be 200
require.Equal(t, 200, result.Code)
}
@@ -222,11 +233,10 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) {
session.Values["access_token"] = "abc"
session.Save(r, result)
- contentServed, authFailed := auth.CheckAuthentication(result, r, 1000)
- require.False(t, contentServed)
- require.True(t, authFailed)
+ contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000})
+ require.True(t, contentServed)
// content wasn't served so the default response from CheckAuthentication should be 200
- require.Equal(t, 200, result.Code)
+ require.Equal(t, 404, result.Code)
}
func TestCheckAuthenticationWhenInvalidToken(t *testing.T) {
@@ -263,9 +273,8 @@ func TestCheckAuthenticationWhenInvalidToken(t *testing.T) {
session.Values["access_token"] = "abc"
session.Save(r, result)
- contentServed, authFailed := auth.CheckAuthentication(result, r, 1000)
+ contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000})
require.True(t, contentServed)
- require.False(t, authFailed)
require.Equal(t, 302, result.Code)
}
@@ -303,9 +312,8 @@ func TestCheckAuthenticationWithoutProject(t *testing.T) {
session.Values["access_token"] = "abc"
session.Save(r, result)
- contentServed, authFailed := auth.CheckAuthenticationWithoutProject(result, r)
+ contentServed := auth.CheckAuthenticationWithoutProject(result, r, &domainMock{projectID: 0})
require.False(t, contentServed)
- require.False(t, authFailed)
require.Equal(t, 200, result.Code)
}
@@ -342,9 +350,8 @@ func TestCheckAuthenticationWithoutProjectWhenInvalidToken(t *testing.T) {
session.Values["access_token"] = "abc"
session.Save(r, result)
- contentServed, authFailed := auth.CheckAuthenticationWithoutProject(result, r)
+ contentServed := auth.CheckAuthenticationWithoutProject(result, r, &domainMock{projectID: 0})
require.True(t, contentServed)
- require.False(t, authFailed)
require.Equal(t, 302, result.Code)
}
diff --git a/internal/domain/domain.go b/internal/domain/domain.go
index 2a301121..9fa843a3 100644
--- a/internal/domain/domain.go
+++ b/internal/domain/domain.go
@@ -169,14 +169,14 @@ func (d *Domain) ServeNotFoundHTTP(w http.ResponseWriter, r *http.Request) {
request.ServeNotFoundHTTP(w, r)
}
-// ServeNamespaceNotFound will try to find a parent namespace domain for a request
+// serveNamespaceNotFound will try to find a parent namespace domain for a request
// that failed authentication so that we serve the custom namespace error page for
// public namespace domains
-func (d *Domain) ServeNamespaceNotFound(w http.ResponseWriter, r *http.Request) {
- // override the path nd try to resolve the domain name
+func (d *Domain) serveNamespaceNotFound(w http.ResponseWriter, r *http.Request) {
+ // override the path and try to resolve the domain name
r.URL.Path = "/"
namespaceDomain, err := d.Resolver.Resolve(r)
- if err != nil {
+ if err != nil || namespaceDomain.LookupPath == nil {
httperrors.Serve404(w)
return
}
@@ -189,3 +189,18 @@ func (d *Domain) ServeNamespaceNotFound(w http.ResponseWriter, r *http.Request)
httperrors.Serve404(w)
}
+
+// ServeNotFoundAuthFailed handler to be called when auth failed so the correct custom
+// 404 page is served.
+func (d *Domain) ServeNotFoundAuthFailed(w http.ResponseWriter, r *http.Request) {
+ if d.isUnconfigured() || !d.HasLookupPath(r) {
+ httperrors.Serve404(w)
+ return
+ }
+ if d.IsNamespaceProject(r) {
+ d.ServeNotFoundHTTP(w, r)
+ return
+ }
+
+ d.serveNamespaceNotFound(w, r)
+}
diff --git a/internal/domain/domain_test.go b/internal/domain/domain_test.go
index e152b031..9e89f0e5 100644
--- a/internal/domain/domain_test.go
+++ b/internal/domain/domain_test.go
@@ -209,6 +209,16 @@ func TestDomain_ServeNamespaceNotFound(t *testing.T) {
},
expectedResponse: "The page you're looking for could not be found.",
},
+ {
+ name: "no_parent_namespace_domain",
+ domain: "group.404.gitlab-example.com",
+ path: "/unknown",
+ resolver: &stubbedResolver{
+ project: nil,
+ subpath: "/",
+ },
+ expectedResponse: "The page you're looking for could not be found.",
+ },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -218,7 +228,7 @@ func TestDomain_ServeNamespaceNotFound(t *testing.T) {
}
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", fmt.Sprintf("http://%s%s", tt.domain, tt.path), nil)
- d.ServeNamespaceNotFound(w, r)
+ d.serveNamespaceNotFound(w, r)
resp := w.Result()
require.Equal(t, http.StatusNotFound, resp.StatusCode)
diff --git a/internal/source/disk/domain_test.go b/internal/source/disk/domain_test.go
index d114ff8a..56297fd1 100644
--- a/internal/source/disk/domain_test.go
+++ b/internal/source/disk/domain_test.go
@@ -308,8 +308,8 @@ func TestDomain404ServeHTTP(t *testing.T) {
},
}
- testhelpers.AssertHTTP404(t, serveFileOrNotFound(testDomain), "GET", "http://group.404.test.io/not-existing-file", nil, "Custom 404 group page")
- testhelpers.AssertHTTP404(t, serveFileOrNotFound(testDomain), "GET", "http://group.404.test.io/", nil, "Custom 404 group page")
+ testhelpers.AssertHTTP404(t, serveFileOrNotFound(testDomain), "GET", "http://group.404.test.io/not-existing-file", nil, "Custom domain.404 page")
+ testhelpers.AssertHTTP404(t, serveFileOrNotFound(testDomain), "GET", "http://group.404.test.io/", nil, "Custom domain.404 page")
}
func TestPredefined404ServeHTTP(t *testing.T) {