Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-pages.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNaman Jagdish Gala <ngala@gitlab.com>2023-12-22 13:13:18 +0300
committerAlessio Caiazza <acaiazza@gitlab.com>2023-12-22 13:13:18 +0300
commitfb14102a5a5cfa14581c326f50c4034eea50f31a (patch)
tree985fb54a4b10cf2c711e9c6061479041e2dea5f3
parent5e29f05994cf60590c3195216a96443a3a9b6367 (diff)
Add project prefix in session cookie path
-rw-r--r--app.go2
-rw-r--r--internal/auth/auth.go48
-rw-r--r--internal/auth/auth_test.go56
-rw-r--r--internal/auth/session.go25
-rw-r--r--internal/domain/domain.go9
-rw-r--r--internal/feature/feature.go6
-rw-r--r--internal/handlers/artifact.go10
-rw-r--r--internal/handlers/handlers.go8
-rw-r--r--internal/handlers/handlers_test.go22
-rw-r--r--internal/handlers/mock/handler_mock.go72
-rw-r--r--internal/interface.go8
11 files changed, 220 insertions, 46 deletions
diff --git a/app.go b/app.go
index 086e4493..bda50e89 100644
--- a/app.go
+++ b/app.go
@@ -155,10 +155,10 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) {
handler := a.serveFileOrNotFoundHandler()
handler = uniquedomain.NewMiddleware(handler)
handler = a.Auth.AuthorizationMiddleware(handler)
- handler = routing.NewMiddleware(handler, a.source)
handler = handlers.ArtifactMiddleware(handler, a.Handlers)
handler = a.Auth.AuthenticationMiddleware(handler, a.source)
+ handler = routing.NewMiddleware(handler, a.source)
handler = handlers.AcmeMiddleware(handler, a.source, a.config.GitLab.PublicServer)
if !a.config.General.DisableCrossOriginRequests {
diff --git a/internal/auth/auth.go b/internal/auth/auth.go
index fe092ae9..7c21cd6b 100644
--- a/internal/auth/auth.go
+++ b/internal/auth/auth.go
@@ -20,7 +20,9 @@ import (
"gitlab.com/gitlab-org/labkit/log"
"golang.org/x/crypto/hkdf"
+ "gitlab.com/gitlab-org/gitlab-pages/internal"
"gitlab.com/gitlab-org/gitlab-pages/internal/errortracking"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/feature"
"gitlab.com/gitlab-org/gitlab-pages/internal/httperrors"
"gitlab.com/gitlab-org/gitlab-pages/internal/httptransport"
"gitlab.com/gitlab-org/gitlab-pages/internal/logging"
@@ -43,6 +45,7 @@ const (
queryParameterErrMsg = "failed to parse domain query parameter"
saveSessionErrMsg = "failed to save the session"
domainQueryParameterErrMsg = "domain query parameter only supports http/https protocol"
+ projectPrefix = "_project_prefix"
)
var (
@@ -81,10 +84,6 @@ type errorResponse struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
}
-type domain interface {
- GetProjectID(r *http.Request) uint64
- ServeNotFoundAuthFailed(w http.ResponseWriter, r *http.Request)
-}
// TryAuthenticate tries to authenticate user and fetch access token if request is a callback to /auth?
func (a *Auth) TryAuthenticate(w http.ResponseWriter, r *http.Request, domains source.Source) bool {
@@ -170,6 +169,20 @@ func (a *Auth) checkAuthenticationResponse(session *hostSession, w http.Response
// Store access token
session.Values["access_token"] = token.AccessToken
+
+ // In final /auth call, updating session path with project prefix.
+ // This will prevent leaking restricted and private projects/subgroups pages under the same top level group
+ // https://gitlab.com/gitlab-org/gitlab-pages/-/issues/1088
+ if feature.ProjectPrefixCookiePath.Enabled() && session.Values[projectPrefix] != nil {
+ session.appendPath(session.Values[projectPrefix].(string))
+
+ logRequest(r).WithField("Prefix Path", session.Values[projectPrefix].(string)).
+ Info("Appending project prefix in session cookie path")
+
+ // If project prefix is useful anywhere, we can avoid deleting it from session.
+ delete(session.Values, projectPrefix)
+ }
+
err = session.Save(r, w)
if err != nil {
logRequest(r).WithError(err).Error(saveSessionErrMsg)
@@ -404,21 +417,21 @@ func (a *Auth) fetchAccessToken(ctx context.Context, code string) (tokenResponse
return token, nil
}
-func (a *Auth) checkSessionIsValid(w http.ResponseWriter, r *http.Request) *hostSession {
+func (a *Auth) checkSessionIsValid(w http.ResponseWriter, r *http.Request, domain internal.Domain) *hostSession {
session, err := a.checkSession(w, r)
if err != nil {
return nil
}
// redirect to /auth?domain=%s&state=%s
- if a.checkTokenExists(session, w, r) {
+ if a.checkTokenExists(session, w, r, domain) {
return nil
}
return session
}
-func (a *Auth) checkTokenExists(session *hostSession, w http.ResponseWriter, r *http.Request) bool {
+func (a *Auth) checkTokenExists(session *hostSession, w http.ResponseWriter, r *http.Request, domain internal.Domain) bool {
// If no access token redirect to OAuth login page
if session.Values["access_token"] == nil {
logRequest(r).Debug("No access token exists, redirecting user to OAuth2 login")
@@ -436,6 +449,15 @@ func (a *Auth) checkTokenExists(session *hostSession, w http.ResponseWriter, r *
// Clear possible proxying
delete(session.Values, "proxy_auth_domain")
+ if feature.ProjectPrefixCookiePath.Enabled() {
+ if prefix := domain.GetProjectPrefix(r); len(prefix) > 1 {
+ session.Values[projectPrefix] = prefix
+ }
+ // After successful authentication, user is redirected to /auth url
+ // To utilise same session, appended /auth in session path
+ session.appendPath("/auth")
+ }
+
err := session.Save(r, w)
if err != nil {
logRequest(r).WithError(err).Error(saveSessionErrMsg)
@@ -483,8 +505,8 @@ func (a *Auth) IsAuthSupported() bool {
return a != nil
}
-func (a *Auth) checkAuthentication(w http.ResponseWriter, r *http.Request, domain domain) bool {
- session := a.checkSessionIsValid(w, r)
+func (a *Auth) checkAuthentication(w http.ResponseWriter, r *http.Request, domain internal.Domain) bool {
+ session := a.checkSessionIsValid(w, r, domain)
if session == nil {
return true
}
@@ -543,7 +565,7 @@ func (a *Auth) checkAuthentication(w http.ResponseWriter, r *http.Request, domai
}
// CheckAuthenticationWithoutProject checks if user is authenticated and has a valid token
-func (a *Auth) CheckAuthenticationWithoutProject(w http.ResponseWriter, r *http.Request, domain domain) bool {
+func (a *Auth) CheckAuthenticationWithoutProject(w http.ResponseWriter, r *http.Request, domain internal.Domain) bool {
if a == nil {
// No auth supported
return false
@@ -571,13 +593,13 @@ func (a *Auth) GetTokenIfExists(w http.ResponseWriter, r *http.Request) (string,
}
// RequireAuth will trigger authentication flow if no token exists
-func (a *Auth) RequireAuth(w http.ResponseWriter, r *http.Request) bool {
- return a.checkSessionIsValid(w, r) == nil
+func (a *Auth) RequireAuth(w http.ResponseWriter, r *http.Request, domain internal.Domain) bool {
+ return a.checkSessionIsValid(w, r, domain) == nil
}
// CheckAuthentication checks if user is authenticated and has access to the project
// will return contentServed = false when authFailed = true
-func (a *Auth) CheckAuthentication(w http.ResponseWriter, r *http.Request, domain domain) bool {
+func (a *Auth) CheckAuthentication(w http.ResponseWriter, r *http.Request, domain internal.Domain) bool {
logRequest(r).Debug("Authenticate request")
if a == nil {
diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go
index 8582a973..4d32b82a 100644
--- a/internal/auth/auth_test.go
+++ b/internal/auth/auth_test.go
@@ -14,6 +14,7 @@ import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/feature"
"gitlab.com/gitlab-org/gitlab-pages/internal/request"
"gitlab.com/gitlab-org/gitlab-pages/internal/source/mock"
"gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers"
@@ -43,6 +44,7 @@ func createTestAuth(t *testing.T, internalServer string, publicServer string) *A
type domainMock struct {
projectID uint64
+ projectPrefix string
notFoundContent string
}
@@ -50,6 +52,10 @@ func (dm *domainMock) GetProjectID(r *http.Request) uint64 {
return dm.projectID
}
+func (dm *domainMock) GetProjectPrefix(r *http.Request) string {
+ return dm.projectPrefix
+}
+
func (dm *domainMock) ServeNotFoundAuthFailed(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte(dm.notFoundContent))
@@ -215,7 +221,26 @@ func TestTryAuthenticateWithNonHttpDomainAndState(t *testing.T) {
require.Equal(t, http.StatusUnauthorized, result.Code)
}
-func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) {
+func TestProjectPrefixInSessionValues(t *testing.T) {
+ t.Setenv(feature.ProjectPrefixCookiePath.EnvVariable, "true")
+ auth := createTestAuth(t, "", "")
+
+ result := httptest.NewRecorder()
+
+ r, err := http.NewRequest("Get", "https://example.com/test", nil)
+ require.NoError(t, err)
+
+ contentServed := auth.CheckAuthentication(result, r, &domainMock{projectID: 1000, projectPrefix: "/test/"})
+ require.True(t, contentServed)
+
+ cookieString := result.Header().Get("Set-Cookie")
+ require.Contains(t, cookieString, "Path=/auth;", "did not set cookie Path=/auth;")
+
+ session, _ := auth.getSessionFromStore(r)
+ require.Equal(t, "/test/", session.Values["_project_prefix"], "did not set project prefix in session values")
+}
+
+func testTryAuthenticateWithCodeAndState(t *testing.T, https bool, projectPrefix string) *httptest.ResponseRecorder {
t.Helper()
apiServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -256,8 +281,9 @@ func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) {
}
setSessionValues(t, r, auth, map[interface{}]interface{}{
- "uri": "https://pages.gitlab-example.com/project/",
- "state": "state",
+ "uri": "https://pages.gitlab-example.com/project/",
+ "state": "state",
+ "_project_prefix": projectPrefix,
})
result := httptest.NewRecorder()
@@ -274,14 +300,34 @@ func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) {
require.Equal(t, "https://pages.gitlab-example.com/project/", result.Header().Get("Location"))
require.Equal(t, 600, res.Cookies()[0].MaxAge)
require.Equal(t, https, res.Cookies()[0].Secure)
+
+ return result
}
func TestTryAuthenticateWithCodeAndStateOverHTTP(t *testing.T) {
- testTryAuthenticateWithCodeAndState(t, false)
+ testTryAuthenticateWithCodeAndState(t, false, "/")
}
func TestTryAuthenticateWithCodeAndStateOverHTTPS(t *testing.T) {
- testTryAuthenticateWithCodeAndState(t, true)
+ testTryAuthenticateWithCodeAndState(t, true, "/")
+}
+
+func TestTryAuthenticateOverHTTPSWithProjectPrefixCookiePathEnabled(t *testing.T) {
+ t.Setenv(feature.ProjectPrefixCookiePath.EnvVariable, "true")
+ prefix := "/project"
+ result := testTryAuthenticateWithCodeAndState(t, true, prefix)
+
+ cookieString := result.Header().Get("Set-Cookie")
+ require.Contains(t, cookieString, "Path="+prefix+";", "did not set Cookie Path="+prefix+";")
+}
+
+func TestTryAuthenticateOverHTTPSWithProjectPrefixCookiePathDisabled(t *testing.T) {
+ t.Setenv(feature.ProjectPrefixCookiePath.EnvVariable, "false")
+ prefix := "/project"
+ result := testTryAuthenticateWithCodeAndState(t, true, prefix)
+
+ cookieString := result.Header().Get("Set-Cookie")
+ require.Contains(t, cookieString, "Path=/;", "did not set Cookie Path=/;")
}
func TestCheckAuthenticationWhenAccess(t *testing.T) {
diff --git a/internal/auth/session.go b/internal/auth/session.go
index c9b4ee8d..c0a3d370 100644
--- a/internal/auth/session.go
+++ b/internal/auth/session.go
@@ -2,21 +2,26 @@ package auth
import (
"net/http"
+ "strings"
"github.com/gorilla/sessions"
"gitlab.com/gitlab-org/labkit/log"
"gitlab.com/gitlab-org/gitlab-pages/internal/errortracking"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/feature"
"gitlab.com/gitlab-org/gitlab-pages/internal/httperrors"
"gitlab.com/gitlab-org/gitlab-pages/internal/request"
)
+const (
+ sessionHostKey = "_session_host"
+ namespaceInPathKey = "_namespace_in_path"
+)
+
type hostSession struct {
*sessions.Session
}
-const sessionHostKey = "_session_host"
-
func (s *hostSession) Save(r *http.Request, w http.ResponseWriter) error {
s.Session.Values[sessionHostKey] = r.Host
@@ -24,11 +29,16 @@ func (s *hostSession) Save(r *http.Request, w http.ResponseWriter) error {
}
func (s *hostSession) getNamespaceInPathFromSession() string {
- namespaceInPath := ""
- if len(s.Options.Path) > 1 && s.Options.Path[0] == '/' {
- namespaceInPath = s.Options.Path[1:]
+ if s.Values[namespaceInPathKey] != nil {
+ return s.Values[namespaceInPathKey].(string)
+ }
+ return ""
+}
+
+func (s *hostSession) appendPath(path string) {
+ if feature.ProjectPrefixCookiePath.Enabled() && len(path) > 0 {
+ s.Options.Path = strings.TrimSuffix(s.Options.Path, "/") + "/" + strings.Trim(path, "/")
}
- return namespaceInPath
}
func (a *Auth) getSessionFromStore(r *http.Request) (*hostSession, error) {
@@ -52,6 +62,9 @@ func (a *Auth) getSessionFromStore(r *http.Request) (*hostSession, error) {
session.Values = make(map[interface{}]interface{})
}
+ if len(namespaceInPath) > 0 {
+ session.Values[namespaceInPathKey] = namespaceInPathKey
+ }
}
return &hostSession{session}, err
diff --git a/internal/domain/domain.go b/internal/domain/domain.go
index 9623ce51..db6e81e6 100644
--- a/internal/domain/domain.go
+++ b/internal/domain/domain.go
@@ -104,6 +104,15 @@ func (d *Domain) GetProjectID(r *http.Request) uint64 {
return 0
}
+// GetProjectPrefix figures out what is the prefix (Ex. /subgroup/project/) of the project user tries to access
+func (d *Domain) GetProjectPrefix(r *http.Request) string {
+ if lookupPath, _ := d.GetLookupPath(r); lookupPath != nil {
+ return lookupPath.Prefix
+ }
+
+ return ""
+}
+
// EnsureCertificate parses the PEM-encoded certificate for the domain
func (d *Domain) EnsureCertificate() (*tls.Certificate, error) {
if d == nil || len(d.CertificateKey) == 0 || len(d.CertificateCert) == 0 {
diff --git a/internal/feature/feature.go b/internal/feature/feature.go
index d5c340ff..daf5f7c9 100644
--- a/internal/feature/feature.go
+++ b/internal/feature/feature.go
@@ -19,6 +19,12 @@ var HandleReadErrors = Feature{
EnvVariable: "FF_HANDLE_READ_ERRORS",
}
+// ProjectPrefixCookiePath enables support for path in session cookie
+var ProjectPrefixCookiePath = Feature{
+ EnvVariable: "FF_ENABLE_PROJECT_PREFIX_COOKIE_PATH",
+ defaultEnabled: false,
+}
+
// Enabled reads the environment variable responsible for the feature flag
// if FF is disabled by default, the environment variable needs to be "true" to explicitly enable it
// if FF is enabled by default, variable needs to be "false" to explicitly disable it
diff --git a/internal/handlers/artifact.go b/internal/handlers/artifact.go
index 994fde33..add12e89 100644
--- a/internal/handlers/artifact.go
+++ b/internal/handlers/artifact.go
@@ -1,10 +1,16 @@
package handlers
-import "net/http"
+import (
+ "net/http"
+
+ domainCfg "gitlab.com/gitlab-org/gitlab-pages/internal/domain"
+)
func ArtifactMiddleware(handler http.Handler, h *Handlers) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if h.HandleArtifactRequest(w, r) {
+ domain := domainCfg.FromRequest(r)
+
+ if h.HandleArtifactRequest(w, r, domain) {
return
}
diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go
index 5b15fba1..6c3b25ab 100644
--- a/internal/handlers/handlers.go
+++ b/internal/handlers/handlers.go
@@ -22,7 +22,7 @@ func New(auth internal.Auth, artifact internal.Artifact) *Handlers {
}
}
-func (a *Handlers) checkIfLoginRequiredOrInvalidToken(w http.ResponseWriter, r *http.Request, token string) func(*http.Response) bool {
+func (a *Handlers) checkIfLoginRequiredOrInvalidToken(w http.ResponseWriter, r *http.Request, token string, domain internal.Domain) func(*http.Response) bool {
return func(resp *http.Response) bool {
// API will return 403 if the project does not have public pipelines (public_builds flag)
if resp.StatusCode == http.StatusNotFound || resp.StatusCode == http.StatusForbidden {
@@ -35,7 +35,7 @@ func (a *Handlers) checkIfLoginRequiredOrInvalidToken(w http.ResponseWriter, r *
logging.LogRequest(r).Debugf("Artifact API response was %d without token, try with authentication", resp.StatusCode)
// Authenticate user
- if a.Auth.RequireAuth(w, r) {
+ if a.Auth.RequireAuth(w, r, domain) {
return true
}
} else {
@@ -52,7 +52,7 @@ func (a *Handlers) checkIfLoginRequiredOrInvalidToken(w http.ResponseWriter, r *
}
// HandleArtifactRequest handles all artifact related requests, will return true if request was handled here
-func (a *Handlers) HandleArtifactRequest(w http.ResponseWriter, r *http.Request) bool {
+func (a *Handlers) HandleArtifactRequest(w http.ResponseWriter, r *http.Request, domain internal.Domain) bool {
// In the event a host is prefixed with the artifact prefix an artifact
// value is created, and an attempt to proxy the request is made
@@ -65,5 +65,5 @@ func (a *Handlers) HandleArtifactRequest(w http.ResponseWriter, r *http.Request)
//nolint: bodyclose // false positive
// a.checkIfLoginRequiredOrInvalidToken returns a response.Body, closing this body is responsibility
// of the TryMakeRequest implementation
- return a.Artifact.TryMakeRequest(w, r, token, a.checkIfLoginRequiredOrInvalidToken(w, r, token))
+ return a.Artifact.TryMakeRequest(w, r, token, a.checkIfLoginRequiredOrInvalidToken(w, r, token, domain))
}
diff --git a/internal/handlers/handlers_test.go b/internal/handlers/handlers_test.go
index 99ba8ca9..a1fa3246 100644
--- a/internal/handlers/handlers_test.go
+++ b/internal/handlers/handlers_test.go
@@ -35,7 +35,7 @@ func TestNotHandleArtifactRequestReturnsFalse(t *testing.T) {
require.NoError(t, err)
r := &http.Request{URL: reqURL}
- require.False(t, handlers.HandleArtifactRequest(result, r))
+ require.False(t, handlers.HandleArtifactRequest(result, r, mock.NewMockDomain(mockCtrl)))
}
func TestHandleArtifactRequestedReturnsTrue(t *testing.T) {
@@ -58,7 +58,7 @@ func TestHandleArtifactRequestedReturnsTrue(t *testing.T) {
result := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/something", nil)
- require.True(t, handlers.HandleArtifactRequest(result, r))
+ require.True(t, handlers.HandleArtifactRequest(result, r, mock.NewMockDomain(mockCtrl)))
}
func TestNotFoundWithTokenIsNotHandled(t *testing.T) {
@@ -74,7 +74,7 @@ func TestNotFoundWithTokenIsNotHandled(t *testing.T) {
reqURL, _ := url.Parse("/")
r := &http.Request{URL: reqURL}
response := &http.Response{StatusCode: http.StatusNotFound}
- handled := handlers.checkIfLoginRequiredOrInvalidToken(w, r, "token")(response)
+ handled := handlers.checkIfLoginRequiredOrInvalidToken(w, r, "token", mock.NewMockDomain(mockCtrl))(response)
require.False(t, handled)
}
@@ -104,7 +104,7 @@ func TestForbiddenWithTokenIsNotHandled(t *testing.T) {
mockAuth := mock.NewMockAuth(mockCtrl)
if tc.Token == "" {
mockAuth.EXPECT().IsAuthSupported().Return(true)
- mockAuth.EXPECT().RequireAuth(gomock.Any(), gomock.Any()).Return(true)
+ mockAuth.EXPECT().RequireAuth(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
} else {
mockAuth.EXPECT().CheckResponseForInvalidToken(gomock.Any(), gomock.Any(), gomock.Any()).
Return(false)
@@ -115,7 +115,7 @@ func TestForbiddenWithTokenIsNotHandled(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/", nil)
response := &http.Response{StatusCode: tc.StatusCode}
- handled := handlers.checkIfLoginRequiredOrInvalidToken(w, r, tc.Token)(response)
+ handled := handlers.checkIfLoginRequiredOrInvalidToken(w, r, tc.Token, mock.NewMockDomain(mockCtrl))(response)
require.Equal(t, tc.Handled, handled)
})
@@ -133,26 +133,28 @@ func TestNotFoundWithoutTokenIsNotHandledWhenNotAuthSupport(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/", nil)
response := &http.Response{StatusCode: http.StatusNotFound}
- handled := handlers.checkIfLoginRequiredOrInvalidToken(w, r, "")(response)
+ handled := handlers.checkIfLoginRequiredOrInvalidToken(w, r, "", mock.NewMockDomain(mockCtrl))(response)
require.False(t, handled)
}
+
func TestNotFoundWithoutTokenIsHandled(t *testing.T) {
mockCtrl := gomock.NewController(t)
mockAuth := mock.NewMockAuth(mockCtrl)
mockAuth.EXPECT().IsAuthSupported().Return(true)
- mockAuth.EXPECT().RequireAuth(gomock.Any(), gomock.Any()).Times(1).Return(true)
+ mockAuth.EXPECT().RequireAuth(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(true)
handlers := New(mockAuth, nil)
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/", nil)
response := &http.Response{StatusCode: http.StatusNotFound}
- handled := handlers.checkIfLoginRequiredOrInvalidToken(w, r, "")(response)
+ handled := handlers.checkIfLoginRequiredOrInvalidToken(w, r, "", mock.NewMockDomain(mockCtrl))(response)
require.True(t, handled)
}
+
func TestInvalidTokenResponseIsHandled(t *testing.T) {
mockCtrl := gomock.NewController(t)
@@ -165,7 +167,7 @@ func TestInvalidTokenResponseIsHandled(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/", nil)
response := &http.Response{StatusCode: http.StatusUnauthorized}
- handled := handlers.checkIfLoginRequiredOrInvalidToken(w, r, "token")(response)
+ handled := handlers.checkIfLoginRequiredOrInvalidToken(w, r, "token", mock.NewMockDomain(mockCtrl))(response)
require.True(t, handled)
}
@@ -186,5 +188,5 @@ func TestHandleArtifactRequestButGetTokenFails(t *testing.T) {
result := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/something", nil)
- require.True(t, handlers.HandleArtifactRequest(result, r))
+ require.True(t, handlers.HandleArtifactRequest(result, r, mock.NewMockDomain(mockCtrl)))
}
diff --git a/internal/handlers/mock/handler_mock.go b/internal/handlers/mock/handler_mock.go
index 089b567c..2906fe72 100644
--- a/internal/handlers/mock/handler_mock.go
+++ b/internal/handlers/mock/handler_mock.go
@@ -9,6 +9,7 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
+ internal "gitlab.com/gitlab-org/gitlab-pages/internal"
)
// MockArtifact is a mock of Artifact interface.
@@ -115,15 +116,78 @@ func (mr *MockAuthMockRecorder) IsAuthSupported() *gomock.Call {
}
// RequireAuth mocks base method.
-func (m *MockAuth) RequireAuth(w http.ResponseWriter, r *http.Request) bool {
+func (m *MockAuth) RequireAuth(w http.ResponseWriter, r *http.Request, domain internal.Domain) bool {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "RequireAuth", w, r)
+ ret := m.ctrl.Call(m, "RequireAuth", w, r, domain)
ret0, _ := ret[0].(bool)
return ret0
}
// RequireAuth indicates an expected call of RequireAuth.
-func (mr *MockAuthMockRecorder) RequireAuth(w, r interface{}) *gomock.Call {
+func (mr *MockAuthMockRecorder) RequireAuth(w, r, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequireAuth", reflect.TypeOf((*MockAuth)(nil).RequireAuth), w, r)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequireAuth", reflect.TypeOf((*MockAuth)(nil).RequireAuth), w, r, domain)
+}
+
+// MockDomain is a mock of Domain interface.
+type MockDomain struct {
+ ctrl *gomock.Controller
+ recorder *MockDomainMockRecorder
+}
+
+// MockDomainMockRecorder is the mock recorder for MockDomain.
+type MockDomainMockRecorder struct {
+ mock *MockDomain
+}
+
+// NewMockDomain creates a new mock instance.
+func NewMockDomain(ctrl *gomock.Controller) *MockDomain {
+ mock := &MockDomain{ctrl: ctrl}
+ mock.recorder = &MockDomainMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockDomain) EXPECT() *MockDomainMockRecorder {
+ return m.recorder
+}
+
+// GetProjectID mocks base method.
+func (m *MockDomain) GetProjectID(r *http.Request) uint64 {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetProjectID", r)
+ ret0, _ := ret[0].(uint64)
+ return ret0
+}
+
+// GetProjectID indicates an expected call of GetProjectID.
+func (mr *MockDomainMockRecorder) GetProjectID(r interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProjectID", reflect.TypeOf((*MockDomain)(nil).GetProjectID), r)
+}
+
+// GetProjectPrefix mocks base method.
+func (m *MockDomain) GetProjectPrefix(r *http.Request) string {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetProjectPrefix", r)
+ ret0, _ := ret[0].(string)
+ return ret0
+}
+
+// GetProjectPrefix indicates an expected call of GetProjectPrefix.
+func (mr *MockDomainMockRecorder) GetProjectPrefix(r interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProjectPrefix", reflect.TypeOf((*MockDomain)(nil).GetProjectPrefix), r)
+}
+
+// ServeNotFoundAuthFailed mocks base method.
+func (m *MockDomain) ServeNotFoundAuthFailed(w http.ResponseWriter, r *http.Request) {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "ServeNotFoundAuthFailed", w, r)
+}
+
+// ServeNotFoundAuthFailed indicates an expected call of ServeNotFoundAuthFailed.
+func (mr *MockDomainMockRecorder) ServeNotFoundAuthFailed(w, r interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServeNotFoundAuthFailed", reflect.TypeOf((*MockDomain)(nil).ServeNotFoundAuthFailed), w, r)
}
diff --git a/internal/interface.go b/internal/interface.go
index 2a77e569..55ed72fa 100644
--- a/internal/interface.go
+++ b/internal/interface.go
@@ -12,7 +12,13 @@ type Artifact interface {
// Auth handles the authentication logic
type Auth interface {
IsAuthSupported() bool
- RequireAuth(w http.ResponseWriter, r *http.Request) bool
+ RequireAuth(w http.ResponseWriter, r *http.Request, domain Domain) bool
GetTokenIfExists(w http.ResponseWriter, r *http.Request) (string, error)
CheckResponseForInvalidToken(w http.ResponseWriter, r *http.Request, resp *http.Response) bool
}
+
+type Domain interface {
+ GetProjectID(r *http.Request) uint64
+ GetProjectPrefix(r *http.Request) string
+ ServeNotFoundAuthFailed(w http.ResponseWriter, r *http.Request)
+}