diff options
author | Tuomo Ala-Vannesluoma <tuomoav@gmail.com> | 2018-08-18 21:49:47 +0300 |
---|---|---|
committer | Tuomo Ala-Vannesluoma <tuomoav@gmail.com> | 2018-08-18 21:55:06 +0300 |
commit | f6edf4e90517c8ba0ffa3190f0b9db537f5f0e1b (patch) | |
tree | 9178fbdfeaef4b6f170754e6cf0229afb828cca0 /internal | |
parent | af8b9cd5df9bf6331b9494149d2e402d30bcea81 (diff) |
Added checks for errors, refactored a bit to avoid method complexity increasing, fixed to work with custom ports and TLS enabled or not
Diffstat (limited to 'internal')
-rw-r--r-- | internal/auth/auth.go | 143 |
1 files changed, 87 insertions, 56 deletions
diff --git a/internal/auth/auth.go b/internal/auth/auth.go index f2f21537..f8524405 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -7,6 +7,7 @@ import ( "fmt" "net" "net/http" + "net/url" "strings" "sync" "time" @@ -69,24 +70,25 @@ func (a *Auth) getSessionFromStore(r *http.Request) (*sessions.Session, error) { return store.Get(r, "gitlab-pages") } -func (a *Auth) checkSession(w http.ResponseWriter, r *http.Request) bool { +func (a *Auth) checkSession(w http.ResponseWriter, r *http.Request) (*sessions.Session, error) { // Create or get session session, err := a.getSessionFromStore(r) if err != nil { // Save cookie again - session.Save(r, w) + err := session.Save(r, w) + if err != nil { + log.WithError(err).Error("Failed to save the session") + httperrors.Serve500(w) + return nil, err + } + http.Redirect(w, r, getRequestAddress(r), 302) - return true + return nil, err } - return false -} - -func (a *Auth) getSession(r *http.Request) *sessions.Session { - session, _ := a.getSessionFromStore(r) - return session + return session, nil } // TryAuthenticate tries to authenticate user and fetch access token if request is a callback to auth @@ -96,12 +98,11 @@ func (a *Auth) TryAuthenticate(w http.ResponseWriter, r *http.Request, dm domain return false } - if a.checkSession(w, r) { + session, err := a.checkSession(w, r) + if err != nil { return true } - session := a.getSession(r) - // Request is for auth if r.URL.Path != callbackPath { return false @@ -123,39 +124,47 @@ func (a *Auth) TryAuthenticate(w http.ResponseWriter, r *http.Request, dm domain } if verifyCodeAndStateGiven(r) { + a.checkAuthenticationResponse(session, w, r) + return true + } - if !validateState(r, session) { - // State is NOT ok - log.Debug("Authentication state did not match expected") - - httperrors.Serve401(w) - return true - } + return false +} - // Fetch access token with authorization code - token, err := a.fetchAccessToken(r.URL.Query().Get("code")) +func (a *Auth) checkAuthenticationResponse(session *sessions.Session, w http.ResponseWriter, r *http.Request) { - // Fetching token not OK - if err != nil { - log.WithError(err).Debug("Fetching access token failed") + if !validateState(r, session) { + // State is NOT ok + log.Debug("Authentication state did not match expected") - httperrors.Serve503(w) - return true - } + httperrors.Serve401(w) + return + } - // Store access token - session.Values["access_token"] = token.AccessToken - session.Save(r, w) + // Fetch access token with authorization code + token, err := a.fetchAccessToken(r.URL.Query().Get("code")) - // Redirect back to requested URI - log.Debug("Authentication was successful, redirecting user back to requested page") + // Fetching token not OK + if err != nil { + log.WithError(err).Debug("Fetching access token failed") - http.Redirect(w, r, session.Values["uri"].(string), 302) + httperrors.Serve503(w) + return + } - return true + // Store access token + session.Values["access_token"] = token.AccessToken + err = session.Save(r, w) + if err != nil { + log.WithError(err).Error("Failed to save the session") + httperrors.Serve500(w) + return } - return false + // Redirect back to requested URI + log.Debug("Authentication was successful, redirecting user back to requested page") + + http.Redirect(w, r, session.Values["uri"].(string), 302) } func (a *Auth) domainAllowed(domain string, dm domain.Map, lock *sync.RWMutex) bool { @@ -171,20 +180,33 @@ func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWrit domain := r.URL.Query().Get("domain") state := r.URL.Query().Get("state") - if !a.domainAllowed(domain, dm, lock) { - log.WithField("domain", domain).Debug("Domain is not configured") + proxyurl, err := url.Parse(domain) + if err != nil { + log.WithField("domain", domain).Error("Failed to parse domain query parameter") + httperrors.Serve500(w) + return true + } + host, _, err := net.SplitHostPort(proxyurl.Host) + if err != nil { + host = proxyurl.Host + } + + if !a.domainAllowed(host, dm, lock) { + log.WithField("domain", host).Debug("Domain is not configured") httperrors.Serve401(w) return true } log.WithField("domain", domain).Debug("User is authenticating via domain") - if r.TLS != nil { - session.Values["proxy_auth_domain"] = "https://" + domain - } else { - session.Values["proxy_auth_domain"] = "http://" + domain + session.Values["proxy_auth_domain"] = domain + + err = session.Save(r, w) + if err != nil { + log.WithError(err).Error("Failed to save the session") + httperrors.Serve500(w) + return true } - session.Save(r, w) url := fmt.Sprintf(authorizeURLTemplate, a.gitLabServer, a.clientID, a.redirectURI, state) http.Redirect(w, r, url, 302) @@ -202,7 +224,12 @@ func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWrit // Clear proxying from session delete(session.Values, "proxy_auth_domain") - session.Save(r, w) + err := session.Save(r, w) + if err != nil { + log.WithError(err).Error("Failed to save the session") + httperrors.Serve500(w) + return true + } // Redirect pages under custom domain http.Redirect(w, r, proxyDomain+r.URL.Path+"?"+r.URL.RawQuery, 302) @@ -302,7 +329,12 @@ func (a *Auth) checkTokenExists(session *sessions.Session, w http.ResponseWriter // Clear possible proxying delete(session.Values, "proxy_auth_domain") - session.Save(r, w) + err := session.Save(r, w) + if err != nil { + log.WithError(err).Error("Failed to save the session") + httperrors.Serve500(w) + return true + } // Because the pages domain might be in public suffix list, we have to // redirect to pages domain to trigger authorization flow @@ -314,10 +346,7 @@ func (a *Auth) checkTokenExists(session *sessions.Session, w http.ResponseWriter } func (a *Auth) getProxyAddress(r *http.Request, state string) string { - if r.TLS != nil { - return fmt.Sprintf(authorizeProxyTemplate, a.redirectURI, r.Host, state) - } - return fmt.Sprintf(authorizeProxyTemplate, a.redirectURI, r.Host, state) + return fmt.Sprintf(authorizeProxyTemplate, a.redirectURI, getRequestDomain(r), state) } func destroySession(session *sessions.Session, w http.ResponseWriter, r *http.Request) { @@ -325,7 +354,12 @@ func destroySession(session *sessions.Session, w http.ResponseWriter, r *http.Re // Invalidate access token and redirect back for refreshing and re-authenticating delete(session.Values, "access_token") - session.Save(r, w) + err := session.Save(r, w) + if err != nil { + log.WithError(err).Error("Failed to save the session") + httperrors.Serve500(w) + return + } http.Redirect(w, r, getRequestAddress(r), 302) } @@ -346,12 +380,11 @@ func (a *Auth) CheckAuthenticationWithoutProject(w http.ResponseWriter, r *http. return false } - if a.checkSession(w, r) { + session, err := a.checkSession(w, r) + if err != nil { return true } - session := a.getSession(r) - if a.checkTokenExists(session, w, r) { return true } @@ -394,16 +427,14 @@ func (a *Auth) CheckAuthenticationWithoutProject(w http.ResponseWriter, r *http. func (a *Auth) CheckAuthentication(w http.ResponseWriter, r *http.Request, projectID uint64) bool { if a == nil { - log.Warn("Authentication is disabled, falling back to PUBLIC pages") return false } - if a.checkSession(w, r) { + session, err := a.checkSession(w, r) + if err != nil { return true } - session := a.getSession(r) - if a.checkTokenExists(session, w, r) { return true } |