Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-pages.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--acceptance_test.go2
-rw-r--r--internal/auth/auth.go143
2 files changed, 88 insertions, 57 deletions
diff --git a/acceptance_test.go b/acceptance_test.go
index 98db0203..31f4e3e5 100644
--- a/acceptance_test.go
+++ b/acceptance_test.go
@@ -709,7 +709,7 @@ func TestAccessControlUnderCustomDomain(t *testing.T) {
require.NoError(t, err)
state := url.Query().Get("state")
- assert.Equal(t, url.Query().Get("domain"), "private.domain.com")
+ assert.Equal(t, url.Query().Get("domain"), "http://private.domain.com")
pagesrsp, err := GetRedirectPage(t, httpListener, url.Host, url.Path+"?"+url.RawQuery)
require.NoError(t, err)
diff --git a/internal/auth/auth.go b/internal/auth/auth.go
index f2f21537..f8524405 100644
--- a/internal/auth/auth.go
+++ b/internal/auth/auth.go
@@ -7,6 +7,7 @@ import (
"fmt"
"net"
"net/http"
+ "net/url"
"strings"
"sync"
"time"
@@ -69,24 +70,25 @@ func (a *Auth) getSessionFromStore(r *http.Request) (*sessions.Session, error) {
return store.Get(r, "gitlab-pages")
}
-func (a *Auth) checkSession(w http.ResponseWriter, r *http.Request) bool {
+func (a *Auth) checkSession(w http.ResponseWriter, r *http.Request) (*sessions.Session, error) {
// Create or get session
session, err := a.getSessionFromStore(r)
if err != nil {
// Save cookie again
- session.Save(r, w)
+ err := session.Save(r, w)
+ if err != nil {
+ log.WithError(err).Error("Failed to save the session")
+ httperrors.Serve500(w)
+ return nil, err
+ }
+
http.Redirect(w, r, getRequestAddress(r), 302)
- return true
+ return nil, err
}
- return false
-}
-
-func (a *Auth) getSession(r *http.Request) *sessions.Session {
- session, _ := a.getSessionFromStore(r)
- return session
+ return session, nil
}
// TryAuthenticate tries to authenticate user and fetch access token if request is a callback to auth
@@ -96,12 +98,11 @@ func (a *Auth) TryAuthenticate(w http.ResponseWriter, r *http.Request, dm domain
return false
}
- if a.checkSession(w, r) {
+ session, err := a.checkSession(w, r)
+ if err != nil {
return true
}
- session := a.getSession(r)
-
// Request is for auth
if r.URL.Path != callbackPath {
return false
@@ -123,39 +124,47 @@ func (a *Auth) TryAuthenticate(w http.ResponseWriter, r *http.Request, dm domain
}
if verifyCodeAndStateGiven(r) {
+ a.checkAuthenticationResponse(session, w, r)
+ return true
+ }
- if !validateState(r, session) {
- // State is NOT ok
- log.Debug("Authentication state did not match expected")
-
- httperrors.Serve401(w)
- return true
- }
+ return false
+}
- // Fetch access token with authorization code
- token, err := a.fetchAccessToken(r.URL.Query().Get("code"))
+func (a *Auth) checkAuthenticationResponse(session *sessions.Session, w http.ResponseWriter, r *http.Request) {
- // Fetching token not OK
- if err != nil {
- log.WithError(err).Debug("Fetching access token failed")
+ if !validateState(r, session) {
+ // State is NOT ok
+ log.Debug("Authentication state did not match expected")
- httperrors.Serve503(w)
- return true
- }
+ httperrors.Serve401(w)
+ return
+ }
- // Store access token
- session.Values["access_token"] = token.AccessToken
- session.Save(r, w)
+ // Fetch access token with authorization code
+ token, err := a.fetchAccessToken(r.URL.Query().Get("code"))
- // Redirect back to requested URI
- log.Debug("Authentication was successful, redirecting user back to requested page")
+ // Fetching token not OK
+ if err != nil {
+ log.WithError(err).Debug("Fetching access token failed")
- http.Redirect(w, r, session.Values["uri"].(string), 302)
+ httperrors.Serve503(w)
+ return
+ }
- return true
+ // Store access token
+ session.Values["access_token"] = token.AccessToken
+ err = session.Save(r, w)
+ if err != nil {
+ log.WithError(err).Error("Failed to save the session")
+ httperrors.Serve500(w)
+ return
}
- return false
+ // Redirect back to requested URI
+ log.Debug("Authentication was successful, redirecting user back to requested page")
+
+ http.Redirect(w, r, session.Values["uri"].(string), 302)
}
func (a *Auth) domainAllowed(domain string, dm domain.Map, lock *sync.RWMutex) bool {
@@ -171,20 +180,33 @@ func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWrit
domain := r.URL.Query().Get("domain")
state := r.URL.Query().Get("state")
- if !a.domainAllowed(domain, dm, lock) {
- log.WithField("domain", domain).Debug("Domain is not configured")
+ proxyurl, err := url.Parse(domain)
+ if err != nil {
+ log.WithField("domain", domain).Error("Failed to parse domain query parameter")
+ httperrors.Serve500(w)
+ return true
+ }
+ host, _, err := net.SplitHostPort(proxyurl.Host)
+ if err != nil {
+ host = proxyurl.Host
+ }
+
+ if !a.domainAllowed(host, dm, lock) {
+ log.WithField("domain", host).Debug("Domain is not configured")
httperrors.Serve401(w)
return true
}
log.WithField("domain", domain).Debug("User is authenticating via domain")
- if r.TLS != nil {
- session.Values["proxy_auth_domain"] = "https://" + domain
- } else {
- session.Values["proxy_auth_domain"] = "http://" + domain
+ session.Values["proxy_auth_domain"] = domain
+
+ err = session.Save(r, w)
+ if err != nil {
+ log.WithError(err).Error("Failed to save the session")
+ httperrors.Serve500(w)
+ return true
}
- session.Save(r, w)
url := fmt.Sprintf(authorizeURLTemplate, a.gitLabServer, a.clientID, a.redirectURI, state)
http.Redirect(w, r, url, 302)
@@ -202,7 +224,12 @@ func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWrit
// Clear proxying from session
delete(session.Values, "proxy_auth_domain")
- session.Save(r, w)
+ err := session.Save(r, w)
+ if err != nil {
+ log.WithError(err).Error("Failed to save the session")
+ httperrors.Serve500(w)
+ return true
+ }
// Redirect pages under custom domain
http.Redirect(w, r, proxyDomain+r.URL.Path+"?"+r.URL.RawQuery, 302)
@@ -302,7 +329,12 @@ func (a *Auth) checkTokenExists(session *sessions.Session, w http.ResponseWriter
// Clear possible proxying
delete(session.Values, "proxy_auth_domain")
- session.Save(r, w)
+ err := session.Save(r, w)
+ if err != nil {
+ log.WithError(err).Error("Failed to save the session")
+ httperrors.Serve500(w)
+ return true
+ }
// Because the pages domain might be in public suffix list, we have to
// redirect to pages domain to trigger authorization flow
@@ -314,10 +346,7 @@ func (a *Auth) checkTokenExists(session *sessions.Session, w http.ResponseWriter
}
func (a *Auth) getProxyAddress(r *http.Request, state string) string {
- if r.TLS != nil {
- return fmt.Sprintf(authorizeProxyTemplate, a.redirectURI, r.Host, state)
- }
- return fmt.Sprintf(authorizeProxyTemplate, a.redirectURI, r.Host, state)
+ return fmt.Sprintf(authorizeProxyTemplate, a.redirectURI, getRequestDomain(r), state)
}
func destroySession(session *sessions.Session, w http.ResponseWriter, r *http.Request) {
@@ -325,7 +354,12 @@ func destroySession(session *sessions.Session, w http.ResponseWriter, r *http.Re
// Invalidate access token and redirect back for refreshing and re-authenticating
delete(session.Values, "access_token")
- session.Save(r, w)
+ err := session.Save(r, w)
+ if err != nil {
+ log.WithError(err).Error("Failed to save the session")
+ httperrors.Serve500(w)
+ return
+ }
http.Redirect(w, r, getRequestAddress(r), 302)
}
@@ -346,12 +380,11 @@ func (a *Auth) CheckAuthenticationWithoutProject(w http.ResponseWriter, r *http.
return false
}
- if a.checkSession(w, r) {
+ session, err := a.checkSession(w, r)
+ if err != nil {
return true
}
- session := a.getSession(r)
-
if a.checkTokenExists(session, w, r) {
return true
}
@@ -394,16 +427,14 @@ func (a *Auth) CheckAuthenticationWithoutProject(w http.ResponseWriter, r *http.
func (a *Auth) CheckAuthentication(w http.ResponseWriter, r *http.Request, projectID uint64) bool {
if a == nil {
- log.Warn("Authentication is disabled, falling back to PUBLIC pages")
return false
}
- if a.checkSession(w, r) {
+ session, err := a.checkSession(w, r)
+ if err != nil {
return true
}
- session := a.getSession(r)
-
if a.checkTokenExists(session, w, r) {
return true
}