diff options
-rw-r--r-- | acceptance_test.go | 2 | ||||
-rw-r--r-- | internal/auth/auth.go | 143 |
2 files changed, 88 insertions, 57 deletions
diff --git a/acceptance_test.go b/acceptance_test.go index 98db0203..31f4e3e5 100644 --- a/acceptance_test.go +++ b/acceptance_test.go @@ -709,7 +709,7 @@ func TestAccessControlUnderCustomDomain(t *testing.T) { require.NoError(t, err) state := url.Query().Get("state") - assert.Equal(t, url.Query().Get("domain"), "private.domain.com") + assert.Equal(t, url.Query().Get("domain"), "http://private.domain.com") pagesrsp, err := GetRedirectPage(t, httpListener, url.Host, url.Path+"?"+url.RawQuery) require.NoError(t, err) diff --git a/internal/auth/auth.go b/internal/auth/auth.go index f2f21537..f8524405 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -7,6 +7,7 @@ import ( "fmt" "net" "net/http" + "net/url" "strings" "sync" "time" @@ -69,24 +70,25 @@ func (a *Auth) getSessionFromStore(r *http.Request) (*sessions.Session, error) { return store.Get(r, "gitlab-pages") } -func (a *Auth) checkSession(w http.ResponseWriter, r *http.Request) bool { +func (a *Auth) checkSession(w http.ResponseWriter, r *http.Request) (*sessions.Session, error) { // Create or get session session, err := a.getSessionFromStore(r) if err != nil { // Save cookie again - session.Save(r, w) + err := session.Save(r, w) + if err != nil { + log.WithError(err).Error("Failed to save the session") + httperrors.Serve500(w) + return nil, err + } + http.Redirect(w, r, getRequestAddress(r), 302) - return true + return nil, err } - return false -} - -func (a *Auth) getSession(r *http.Request) *sessions.Session { - session, _ := a.getSessionFromStore(r) - return session + return session, nil } // TryAuthenticate tries to authenticate user and fetch access token if request is a callback to auth @@ -96,12 +98,11 @@ func (a *Auth) TryAuthenticate(w http.ResponseWriter, r *http.Request, dm domain return false } - if a.checkSession(w, r) { + session, err := a.checkSession(w, r) + if err != nil { return true } - session := a.getSession(r) - // Request is for auth if r.URL.Path != callbackPath { return false @@ -123,39 +124,47 @@ func (a *Auth) TryAuthenticate(w http.ResponseWriter, r *http.Request, dm domain } if verifyCodeAndStateGiven(r) { + a.checkAuthenticationResponse(session, w, r) + return true + } - if !validateState(r, session) { - // State is NOT ok - log.Debug("Authentication state did not match expected") - - httperrors.Serve401(w) - return true - } + return false +} - // Fetch access token with authorization code - token, err := a.fetchAccessToken(r.URL.Query().Get("code")) +func (a *Auth) checkAuthenticationResponse(session *sessions.Session, w http.ResponseWriter, r *http.Request) { - // Fetching token not OK - if err != nil { - log.WithError(err).Debug("Fetching access token failed") + if !validateState(r, session) { + // State is NOT ok + log.Debug("Authentication state did not match expected") - httperrors.Serve503(w) - return true - } + httperrors.Serve401(w) + return + } - // Store access token - session.Values["access_token"] = token.AccessToken - session.Save(r, w) + // Fetch access token with authorization code + token, err := a.fetchAccessToken(r.URL.Query().Get("code")) - // Redirect back to requested URI - log.Debug("Authentication was successful, redirecting user back to requested page") + // Fetching token not OK + if err != nil { + log.WithError(err).Debug("Fetching access token failed") - http.Redirect(w, r, session.Values["uri"].(string), 302) + httperrors.Serve503(w) + return + } - return true + // Store access token + session.Values["access_token"] = token.AccessToken + err = session.Save(r, w) + if err != nil { + log.WithError(err).Error("Failed to save the session") + httperrors.Serve500(w) + return } - return false + // Redirect back to requested URI + log.Debug("Authentication was successful, redirecting user back to requested page") + + http.Redirect(w, r, session.Values["uri"].(string), 302) } func (a *Auth) domainAllowed(domain string, dm domain.Map, lock *sync.RWMutex) bool { @@ -171,20 +180,33 @@ func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWrit domain := r.URL.Query().Get("domain") state := r.URL.Query().Get("state") - if !a.domainAllowed(domain, dm, lock) { - log.WithField("domain", domain).Debug("Domain is not configured") + proxyurl, err := url.Parse(domain) + if err != nil { + log.WithField("domain", domain).Error("Failed to parse domain query parameter") + httperrors.Serve500(w) + return true + } + host, _, err := net.SplitHostPort(proxyurl.Host) + if err != nil { + host = proxyurl.Host + } + + if !a.domainAllowed(host, dm, lock) { + log.WithField("domain", host).Debug("Domain is not configured") httperrors.Serve401(w) return true } log.WithField("domain", domain).Debug("User is authenticating via domain") - if r.TLS != nil { - session.Values["proxy_auth_domain"] = "https://" + domain - } else { - session.Values["proxy_auth_domain"] = "http://" + domain + session.Values["proxy_auth_domain"] = domain + + err = session.Save(r, w) + if err != nil { + log.WithError(err).Error("Failed to save the session") + httperrors.Serve500(w) + return true } - session.Save(r, w) url := fmt.Sprintf(authorizeURLTemplate, a.gitLabServer, a.clientID, a.redirectURI, state) http.Redirect(w, r, url, 302) @@ -202,7 +224,12 @@ func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWrit // Clear proxying from session delete(session.Values, "proxy_auth_domain") - session.Save(r, w) + err := session.Save(r, w) + if err != nil { + log.WithError(err).Error("Failed to save the session") + httperrors.Serve500(w) + return true + } // Redirect pages under custom domain http.Redirect(w, r, proxyDomain+r.URL.Path+"?"+r.URL.RawQuery, 302) @@ -302,7 +329,12 @@ func (a *Auth) checkTokenExists(session *sessions.Session, w http.ResponseWriter // Clear possible proxying delete(session.Values, "proxy_auth_domain") - session.Save(r, w) + err := session.Save(r, w) + if err != nil { + log.WithError(err).Error("Failed to save the session") + httperrors.Serve500(w) + return true + } // Because the pages domain might be in public suffix list, we have to // redirect to pages domain to trigger authorization flow @@ -314,10 +346,7 @@ func (a *Auth) checkTokenExists(session *sessions.Session, w http.ResponseWriter } func (a *Auth) getProxyAddress(r *http.Request, state string) string { - if r.TLS != nil { - return fmt.Sprintf(authorizeProxyTemplate, a.redirectURI, r.Host, state) - } - return fmt.Sprintf(authorizeProxyTemplate, a.redirectURI, r.Host, state) + return fmt.Sprintf(authorizeProxyTemplate, a.redirectURI, getRequestDomain(r), state) } func destroySession(session *sessions.Session, w http.ResponseWriter, r *http.Request) { @@ -325,7 +354,12 @@ func destroySession(session *sessions.Session, w http.ResponseWriter, r *http.Re // Invalidate access token and redirect back for refreshing and re-authenticating delete(session.Values, "access_token") - session.Save(r, w) + err := session.Save(r, w) + if err != nil { + log.WithError(err).Error("Failed to save the session") + httperrors.Serve500(w) + return + } http.Redirect(w, r, getRequestAddress(r), 302) } @@ -346,12 +380,11 @@ func (a *Auth) CheckAuthenticationWithoutProject(w http.ResponseWriter, r *http. return false } - if a.checkSession(w, r) { + session, err := a.checkSession(w, r) + if err != nil { return true } - session := a.getSession(r) - if a.checkTokenExists(session, w, r) { return true } @@ -394,16 +427,14 @@ func (a *Auth) CheckAuthenticationWithoutProject(w http.ResponseWriter, r *http. func (a *Auth) CheckAuthentication(w http.ResponseWriter, r *http.Request, projectID uint64) bool { if a == nil { - log.Warn("Authentication is disabled, falling back to PUBLIC pages") return false } - if a.checkSession(w, r) { + session, err := a.checkSession(w, r) + if err != nil { return true } - session := a.getSession(r) - if a.checkTokenExists(session, w, r) { return true } |