Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-pages.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/app.go
diff options
context:
space:
mode:
authorAndrew Newdigate <andrew@gitlab.com>2019-07-24 01:12:22 +0300
committerAndrew Newdigate <andrew@gitlab.com>2019-08-22 13:50:39 +0300
commitc78ef2c684675b7b0685a78958860558149fae25 (patch)
tree9e7a31d5ecaa3ea6dfc1ca60bd42552df9be6621 /app.go
parent654c183c4b06c7deb3b947c7f557fe4a48f2e218 (diff)
Refactor to use pluggable http.Handler middlewaresan-use-middleware-handlers
Diffstat (limited to 'app.go')
-rw-r--r--app.go181
1 files changed, 133 insertions, 48 deletions
diff --git a/app.go b/app.go
index 3efc22c2..8ff2444e 100644
--- a/app.go
+++ b/app.go
@@ -168,46 +168,105 @@ func (a *theApp) tryAuxiliaryHandlers(w http.ResponseWriter, r *http.Request, ht
return false
}
-func (a *theApp) serveContent(ww http.ResponseWriter, r *http.Request, https bool) {
- w := newLoggingResponseWriter(ww)
- defer w.Log(r)
+// observabilityMiddleware will provide observability (logging, metrics)
+// for each request
+func (a *theApp) observabilityMiddleware(handler http.Handler) http.Handler {
+ return http.HandlerFunc(func(ww http.ResponseWriter, r *http.Request) {
+ w := newLoggingResponseWriter(ww)
+ defer w.Log(r)
- metrics.SessionsActive.Inc()
- defer metrics.SessionsActive.Dec()
+ metrics.SessionsActive.Inc()
+ defer metrics.SessionsActive.Dec()
- host, domain := a.getHostAndDomain(r)
+ handler.ServeHTTP(&w, r)
- if a.AcmeMiddleware.ServeAcmeChallenges(&w, r, domain) {
- return
- }
+ metrics.ProcessedRequests.WithLabelValues(strconv.Itoa(w.status), r.Method).Inc()
+ })
+}
- if a.Auth.TryAuthenticate(&w, r, a.dm, &a.lock) {
- return
- }
+// routingMiddleware will determine the host and domain for the request, for
+// downstream middlewares to use
+func (a *theApp) routingMiddleware(handler http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ host, domain := a.getHostAndDomain(r)
- if a.tryAuxiliaryHandlers(&w, r, https, host, domain) {
- return
- }
+ r = request.WithHostAndDomain(r, host, domain)
+
+ handler.ServeHTTP(w, r)
+ })
+}
- // Only for projects that have access control enabled
- if domain.IsAccessControlEnabled(r) {
- if a.Auth.CheckAuthentication(&w, r, domain.GetID(r)) {
+// customHeadersMiddleware will inject custom headers into the response
+func (a *theApp) customHeadersMiddleware(handler http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ headerConfig.AddCustomHeaders(w, a.CustomHeaders)
+
+ handler.ServeHTTP(w, r)
+ })
+}
+
+// acmeMiddleware will handle ACME challenges
+func (a *theApp) acmeMiddleware(handler http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ domain := request.GetDomain(r)
+
+ if a.AcmeMiddleware.ServeAcmeChallenges(w, r, domain) {
return
}
- }
- // Serve static file, applying CORS headers if necessary
- if a.DisableCrossOriginRequests {
- a.serveFileOrNotFound(domain)(&w, r)
- } else {
- corsHandler.ServeHTTP(&w, r, a.serveFileOrNotFound(domain))
- }
+ handler.ServeHTTP(w, r)
+ })
+}
- metrics.ProcessedRequests.WithLabelValues(strconv.Itoa(w.status), r.Method).Inc()
+// authMiddleware handles authentication requests
+func (a *theApp) authMiddleware(handler http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if a.Auth.TryAuthenticate(w, r, a.dm, &a.lock) {
+ return
+ }
+
+ handler.ServeHTTP(w, r)
+ })
}
-func (a *theApp) serveFileOrNotFound(domain *domain.D) http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
+// auxiliaryMiddleware will handle status updates, not-ready requests and other
+// not static-content responses
+func (a *theApp) auxiliaryMiddleware(handler http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ host := request.GetHost(r)
+ domain := request.GetDomain(r)
+ https := request.IsHTTPS(r)
+
+ if a.tryAuxiliaryHandlers(w, r, https, host, domain) {
+ return
+ }
+
+ handler.ServeHTTP(w, r)
+ })
+}
+
+// accessControlMiddleware will handle authorization
+func (a *theApp) accessControlMiddleware(handler http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ domain := request.GetDomain(r)
+
+ // Only for projects that have access control enabled
+ if domain.IsAccessControlEnabled(r) {
+ // accessControlMiddleware
+ if a.Auth.CheckAuthentication(w, r, domain.GetID(r)) {
+ return
+ }
+ }
+
+ handler.ServeHTTP(w, r)
+ })
+}
+
+// serveFileOrNotFoundHandler will serve static content or
+// return a 404 Not Found response
+func (a *theApp) serveFileOrNotFoundHandler() http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ domain := request.GetDomain(r)
fileServed := domain.ServeFileHTTP(w, r)
if !fileServed {
@@ -226,29 +285,49 @@ func (a *theApp) serveFileOrNotFound(domain *domain.D) http.HandlerFunc {
domain.ServeNotFoundHTTP(w, r)
}
- }
+ })
}
-func (a *theApp) ServeHTTP(ww http.ResponseWriter, r *http.Request) {
- https := r.TLS != nil
- r = request.WithHTTPSFlag(r, https)
+// httpInitialMiddleware sets up HTTP requests
+func (a *theApp) httpInitialMiddleware(handler http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ https := r.TLS != nil
+ r = request.WithHTTPSFlag(r, https)
- headerConfig.AddCustomHeaders(ww, a.CustomHeaders)
-
- a.serveContent(ww, r, https)
+ handler.ServeHTTP(w, r)
+ })
}
-func (a *theApp) ServeProxy(ww http.ResponseWriter, r *http.Request) {
- forwardedProto := r.Header.Get(xForwardedProto)
- https := forwardedProto == xForwardedProtoHTTPS
- r = request.WithHTTPSFlag(r, https)
+// proxyInitialMiddleware sets up proxy requests
+func (a *theApp) proxyInitialMiddleware(handler http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ forwardedProto := r.Header.Get(xForwardedProto)
+ https := forwardedProto == xForwardedProtoHTTPS
- if forwardedHost := r.Header.Get(xForwardedHost); forwardedHost != "" {
- r.Host = forwardedHost
- }
- headerConfig.AddCustomHeaders(ww, a.CustomHeaders)
+ r = request.WithHTTPSFlag(r, https)
+ if forwardedHost := r.Header.Get(xForwardedHost); forwardedHost != "" {
+ r.Host = forwardedHost
+ }
- a.serveContent(ww, r, https)
+ handler.ServeHTTP(w, r)
+ })
+}
+
+func (a *theApp) buildHandlerPipeline() http.Handler {
+ // Handlers should be applied in reverse order
+ handler := a.serveFileOrNotFoundHandler()
+ if !a.DisableCrossOriginRequests {
+ handler = corsHandler.Handler(handler)
+ }
+ handler = a.accessControlMiddleware(handler)
+ handler = a.auxiliaryMiddleware(handler)
+ handler = a.authMiddleware(handler)
+ handler = a.acmeMiddleware(handler)
+ handler = a.customHeadersMiddleware(handler)
+ handler = a.observabilityMiddleware(handler)
+ handler = a.routingMiddleware(handler)
+
+ return handler
}
func (a *theApp) UpdateDomains(dm domain.Map) {
@@ -262,12 +341,18 @@ func (a *theApp) Run() {
limiter := netutil.NewLimiter(a.MaxConns)
+ // Use a common pipeline to use a single instance of each handler,
+ // instead of making two nearly identical pipelines
+ commonHandlerPipeline := a.buildHandlerPipeline()
+ proxyHandler := a.proxyInitialMiddleware(commonHandlerPipeline)
+ httpHandler := a.httpInitialMiddleware(commonHandlerPipeline)
+
// Listen for HTTP
for _, fd := range a.ListenHTTP {
wg.Add(1)
go func(fd uintptr) {
defer wg.Done()
- err := listenAndServe(fd, a.ServeHTTP, a.HTTP2, nil, limiter)
+ err := listenAndServe(fd, httpHandler, a.HTTP2, nil, limiter)
if err != nil {
capturingFatal(err, errortracking.WithField("listener", "http"))
}
@@ -279,7 +364,7 @@ func (a *theApp) Run() {
wg.Add(1)
go func(fd uintptr) {
defer wg.Done()
- err := listenAndServeTLS(fd, a.RootCertificate, a.RootKey, a.ServeHTTP, a.ServeTLS, a.InsecureCiphers, a.TLSMinVersion, a.TLSMaxVersion, a.HTTP2, limiter)
+ err := listenAndServeTLS(fd, a.RootCertificate, a.RootKey, httpHandler, a.ServeTLS, a.InsecureCiphers, a.TLSMinVersion, a.TLSMaxVersion, a.HTTP2, limiter)
if err != nil {
capturingFatal(err, errortracking.WithField("listener", "https"))
}
@@ -291,7 +376,7 @@ func (a *theApp) Run() {
wg.Add(1)
go func(fd uintptr) {
defer wg.Done()
- err := listenAndServe(fd, a.ServeProxy, a.HTTP2, nil, limiter)
+ err := listenAndServe(fd, proxyHandler, a.HTTP2, nil, limiter)
if err != nil {
capturingFatal(err, errortracking.WithField("listener", "http proxy"))
}
@@ -304,7 +389,7 @@ func (a *theApp) Run() {
go func(fd uintptr) {
defer wg.Done()
- handler := promhttp.Handler().ServeHTTP
+ handler := promhttp.Handler()
err := listenAndServe(fd, handler, false, nil, nil)
if err != nil {
capturingFatal(err, errortracking.WithField("listener", "metrics"))