diff options
Diffstat (limited to 'internal')
-rw-r--r-- | internal/auth/auth_test.go | 23 | ||||
-rw-r--r-- | internal/domain/domain.go | 13 | ||||
-rw-r--r-- | internal/domain/domain_test.go | 7 |
3 files changed, 27 insertions, 16 deletions
diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index ffaf5b6f..39a533b3 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -30,7 +30,8 @@ func defaultCookieStore() sessions.Store { } type domainMock struct { - projectID uint64 + projectID uint64 + notFoundContent string } func (dm *domainMock) GetProjectID(r *http.Request) uint64 { @@ -39,6 +40,7 @@ func (dm *domainMock) GetProjectID(r *http.Request) uint64 { func (dm *domainMock) ServeNotFoundAuthFailed(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) + w.Write([]byte(dm.notFoundContent)) } // Gorilla's sessions use request context to save session @@ -195,7 +197,7 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) { contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) require.False(t, contentServed) - // content wasn't served so the default response from CheckAuthentication should be 200 + // notFoundContent wasn't served so the default response from CheckAuthentication should be 200 require.Equal(t, 200, result.Code) } @@ -223,7 +225,8 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) { "http://pages.gitlab-example.com/auth", apiServer.URL) - result := httptest.NewRecorder() + w := httptest.NewRecorder() + reqURL, err := url.Parse("/auth?code=1&state=state") require.NoError(t, err) reqURL.Scheme = request.SchemeHTTPS @@ -231,12 +234,18 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) { session, _ := store.Get(r, "gitlab-pages") session.Values["access_token"] = "abc" - session.Save(r, result) + session.Save(r, w) - contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000}) + contentServed := auth.CheckAuthentication(w, r, &domainMock{projectID: 1000, notFoundContent: "Generic 404"}) require.True(t, contentServed) - // content wasn't served so the default response from CheckAuthentication should be 200 - require.Equal(t, 404, result.Code) + res := w.Result() + defer res.Body.Close() + + require.Equal(t, 404, res.StatusCode) + + body, err := ioutil.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, string(body), "Generic 404") } func TestCheckAuthenticationWhenInvalidToken(t *testing.T) { diff --git a/internal/domain/domain.go b/internal/domain/domain.go index 9fa843a3..7c1639a3 100644 --- a/internal/domain/domain.go +++ b/internal/domain/domain.go @@ -1,6 +1,7 @@ package domain import ( + "context" "crypto/tls" "errors" "net/http" @@ -173,16 +174,18 @@ func (d *Domain) ServeNotFoundHTTP(w http.ResponseWriter, r *http.Request) { // that failed authentication so that we serve the custom namespace error page for // public namespace domains func (d *Domain) serveNamespaceNotFound(w http.ResponseWriter, r *http.Request) { - // override the path and try to resolve the domain name - r.URL.Path = "/" - namespaceDomain, err := d.Resolver.Resolve(r) + // clone r and override the path and try to resolve the domain name + clonedReq := r.Clone(context.Background()) + clonedReq.URL.Path = "/" + + namespaceDomain, err := d.Resolver.Resolve(clonedReq) if err != nil || namespaceDomain.LookupPath == nil { httperrors.Serve404(w) return } // for namespace domains that have no access control enabled - if namespaceDomain.LookupPath.IsNamespaceProject && !namespaceDomain.LookupPath.HasAccessControl { + if !namespaceDomain.LookupPath.HasAccessControl { namespaceDomain.ServeNotFoundHTTP(w, r) return } @@ -197,7 +200,7 @@ func (d *Domain) ServeNotFoundAuthFailed(w http.ResponseWriter, r *http.Request) httperrors.Serve404(w) return } - if d.IsNamespaceProject(r) { + if d.IsNamespaceProject(r) && !d.GetLookupPath(r).HasAccessControl { d.ServeNotFoundHTTP(w, r) return } diff --git a/internal/domain/domain_test.go b/internal/domain/domain_test.go index 9e89f0e5..fc5611ba 100644 --- a/internal/domain/domain_test.go +++ b/internal/domain/domain_test.go @@ -157,10 +157,7 @@ func chdirInPath(t require.TestingT, path string) func() { } } -func TestDomain_ServeNamespaceNotFound(t *testing.T) { - // defaultNotFound := "The page you're looking for could not be found." - // customNotFound := "Custom error page" - +func TestServeNamespaceNotFound(t *testing.T) { tests := []struct { name string domain string @@ -231,6 +228,8 @@ func TestDomain_ServeNamespaceNotFound(t *testing.T) { d.serveNamespaceNotFound(w, r) resp := w.Result() + defer resp.Body.Close() + require.Equal(t, http.StatusNotFound, resp.StatusCode) body, err := ioutil.ReadAll(resp.Body) require.NoError(t, err) |