diff options
author | Alessio Caiazza <acaiazza@gitlab.com> | 2019-08-23 15:54:16 +0300 |
---|---|---|
committer | Alessio Caiazza <acaiazza@gitlab.com> | 2019-08-23 15:54:16 +0300 |
commit | 119326917932e99dae8fceab3f86382ed4b61a07 (patch) | |
tree | 9e7a31d5ecaa3ea6dfc1ca60bd42552df9be6621 | |
parent | 654c183c4b06c7deb3b947c7f557fe4a48f2e218 (diff) | |
parent | c78ef2c684675b7b0685a78958860558149fae25 (diff) |
Merge branch 'an-use-middleware-handlers' into 'master'
Refactor to use pluggable http.Handler middlewares
See merge request gitlab-org/gitlab-pages!157
-rw-r--r-- | app.go | 181 | ||||
-rw-r--r-- | internal/request/request.go | 25 | ||||
-rw-r--r-- | internal/request/request_test.go | 52 | ||||
-rw-r--r-- | server.go | 4 |
4 files changed, 207 insertions, 55 deletions
@@ -168,46 +168,105 @@ func (a *theApp) tryAuxiliaryHandlers(w http.ResponseWriter, r *http.Request, ht return false } -func (a *theApp) serveContent(ww http.ResponseWriter, r *http.Request, https bool) { - w := newLoggingResponseWriter(ww) - defer w.Log(r) +// observabilityMiddleware will provide observability (logging, metrics) +// for each request +func (a *theApp) observabilityMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(ww http.ResponseWriter, r *http.Request) { + w := newLoggingResponseWriter(ww) + defer w.Log(r) - metrics.SessionsActive.Inc() - defer metrics.SessionsActive.Dec() + metrics.SessionsActive.Inc() + defer metrics.SessionsActive.Dec() - host, domain := a.getHostAndDomain(r) + handler.ServeHTTP(&w, r) - if a.AcmeMiddleware.ServeAcmeChallenges(&w, r, domain) { - return - } + metrics.ProcessedRequests.WithLabelValues(strconv.Itoa(w.status), r.Method).Inc() + }) +} - if a.Auth.TryAuthenticate(&w, r, a.dm, &a.lock) { - return - } +// routingMiddleware will determine the host and domain for the request, for +// downstream middlewares to use +func (a *theApp) routingMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + host, domain := a.getHostAndDomain(r) - if a.tryAuxiliaryHandlers(&w, r, https, host, domain) { - return - } + r = request.WithHostAndDomain(r, host, domain) + + handler.ServeHTTP(w, r) + }) +} - // Only for projects that have access control enabled - if domain.IsAccessControlEnabled(r) { - if a.Auth.CheckAuthentication(&w, r, domain.GetID(r)) { +// customHeadersMiddleware will inject custom headers into the response +func (a *theApp) customHeadersMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + headerConfig.AddCustomHeaders(w, a.CustomHeaders) + + handler.ServeHTTP(w, r) + }) +} + +// acmeMiddleware will handle ACME challenges +func (a *theApp) acmeMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + domain := request.GetDomain(r) + + if a.AcmeMiddleware.ServeAcmeChallenges(w, r, domain) { return } - } - // Serve static file, applying CORS headers if necessary - if a.DisableCrossOriginRequests { - a.serveFileOrNotFound(domain)(&w, r) - } else { - corsHandler.ServeHTTP(&w, r, a.serveFileOrNotFound(domain)) - } + handler.ServeHTTP(w, r) + }) +} - metrics.ProcessedRequests.WithLabelValues(strconv.Itoa(w.status), r.Method).Inc() +// authMiddleware handles authentication requests +func (a *theApp) authMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if a.Auth.TryAuthenticate(w, r, a.dm, &a.lock) { + return + } + + handler.ServeHTTP(w, r) + }) } -func (a *theApp) serveFileOrNotFound(domain *domain.D) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +// auxiliaryMiddleware will handle status updates, not-ready requests and other +// not static-content responses +func (a *theApp) auxiliaryMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + host := request.GetHost(r) + domain := request.GetDomain(r) + https := request.IsHTTPS(r) + + if a.tryAuxiliaryHandlers(w, r, https, host, domain) { + return + } + + handler.ServeHTTP(w, r) + }) +} + +// accessControlMiddleware will handle authorization +func (a *theApp) accessControlMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + domain := request.GetDomain(r) + + // Only for projects that have access control enabled + if domain.IsAccessControlEnabled(r) { + // accessControlMiddleware + if a.Auth.CheckAuthentication(w, r, domain.GetID(r)) { + return + } + } + + handler.ServeHTTP(w, r) + }) +} + +// serveFileOrNotFoundHandler will serve static content or +// return a 404 Not Found response +func (a *theApp) serveFileOrNotFoundHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + domain := request.GetDomain(r) fileServed := domain.ServeFileHTTP(w, r) if !fileServed { @@ -226,29 +285,49 @@ func (a *theApp) serveFileOrNotFound(domain *domain.D) http.HandlerFunc { domain.ServeNotFoundHTTP(w, r) } - } + }) } -func (a *theApp) ServeHTTP(ww http.ResponseWriter, r *http.Request) { - https := r.TLS != nil - r = request.WithHTTPSFlag(r, https) +// httpInitialMiddleware sets up HTTP requests +func (a *theApp) httpInitialMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + https := r.TLS != nil + r = request.WithHTTPSFlag(r, https) - headerConfig.AddCustomHeaders(ww, a.CustomHeaders) - - a.serveContent(ww, r, https) + handler.ServeHTTP(w, r) + }) } -func (a *theApp) ServeProxy(ww http.ResponseWriter, r *http.Request) { - forwardedProto := r.Header.Get(xForwardedProto) - https := forwardedProto == xForwardedProtoHTTPS - r = request.WithHTTPSFlag(r, https) +// proxyInitialMiddleware sets up proxy requests +func (a *theApp) proxyInitialMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + forwardedProto := r.Header.Get(xForwardedProto) + https := forwardedProto == xForwardedProtoHTTPS - if forwardedHost := r.Header.Get(xForwardedHost); forwardedHost != "" { - r.Host = forwardedHost - } - headerConfig.AddCustomHeaders(ww, a.CustomHeaders) + r = request.WithHTTPSFlag(r, https) + if forwardedHost := r.Header.Get(xForwardedHost); forwardedHost != "" { + r.Host = forwardedHost + } - a.serveContent(ww, r, https) + handler.ServeHTTP(w, r) + }) +} + +func (a *theApp) buildHandlerPipeline() http.Handler { + // Handlers should be applied in reverse order + handler := a.serveFileOrNotFoundHandler() + if !a.DisableCrossOriginRequests { + handler = corsHandler.Handler(handler) + } + handler = a.accessControlMiddleware(handler) + handler = a.auxiliaryMiddleware(handler) + handler = a.authMiddleware(handler) + handler = a.acmeMiddleware(handler) + handler = a.customHeadersMiddleware(handler) + handler = a.observabilityMiddleware(handler) + handler = a.routingMiddleware(handler) + + return handler } func (a *theApp) UpdateDomains(dm domain.Map) { @@ -262,12 +341,18 @@ func (a *theApp) Run() { limiter := netutil.NewLimiter(a.MaxConns) + // Use a common pipeline to use a single instance of each handler, + // instead of making two nearly identical pipelines + commonHandlerPipeline := a.buildHandlerPipeline() + proxyHandler := a.proxyInitialMiddleware(commonHandlerPipeline) + httpHandler := a.httpInitialMiddleware(commonHandlerPipeline) + // Listen for HTTP for _, fd := range a.ListenHTTP { wg.Add(1) go func(fd uintptr) { defer wg.Done() - err := listenAndServe(fd, a.ServeHTTP, a.HTTP2, nil, limiter) + err := listenAndServe(fd, httpHandler, a.HTTP2, nil, limiter) if err != nil { capturingFatal(err, errortracking.WithField("listener", "http")) } @@ -279,7 +364,7 @@ func (a *theApp) Run() { wg.Add(1) go func(fd uintptr) { defer wg.Done() - err := listenAndServeTLS(fd, a.RootCertificate, a.RootKey, a.ServeHTTP, a.ServeTLS, a.InsecureCiphers, a.TLSMinVersion, a.TLSMaxVersion, a.HTTP2, limiter) + err := listenAndServeTLS(fd, a.RootCertificate, a.RootKey, httpHandler, a.ServeTLS, a.InsecureCiphers, a.TLSMinVersion, a.TLSMaxVersion, a.HTTP2, limiter) if err != nil { capturingFatal(err, errortracking.WithField("listener", "https")) } @@ -291,7 +376,7 @@ func (a *theApp) Run() { wg.Add(1) go func(fd uintptr) { defer wg.Done() - err := listenAndServe(fd, a.ServeProxy, a.HTTP2, nil, limiter) + err := listenAndServe(fd, proxyHandler, a.HTTP2, nil, limiter) if err != nil { capturingFatal(err, errortracking.WithField("listener", "http proxy")) } @@ -304,7 +389,7 @@ func (a *theApp) Run() { go func(fd uintptr) { defer wg.Done() - handler := promhttp.Handler().ServeHTTP + handler := promhttp.Handler() err := listenAndServe(fd, handler, false, nil, nil) if err != nil { capturingFatal(err, errortracking.WithField("listener", "metrics")) diff --git a/internal/request/request.go b/internal/request/request.go index dad6af3d..730eb527 100644 --- a/internal/request/request.go +++ b/internal/request/request.go @@ -3,12 +3,16 @@ package request import ( "context" "net/http" + + "gitlab.com/gitlab-org/gitlab-pages/internal/domain" ) type ctxKey string const ( - ctxHTTPSKey ctxKey = "https" + ctxHTTPSKey ctxKey = "https" + ctxHostKey ctxKey = "host" + ctxDomainKey ctxKey = "domain" ) // WithHTTPSFlag saves https flag in request's context @@ -22,3 +26,22 @@ func WithHTTPSFlag(r *http.Request, https bool) *http.Request { func IsHTTPS(r *http.Request) bool { return r.Context().Value(ctxHTTPSKey).(bool) } + +// WithHostAndDomain saves host name and domain in the request's context +func WithHostAndDomain(r *http.Request, host string, domain *domain.D) *http.Request { + ctx := r.Context() + ctx = context.WithValue(ctx, ctxHostKey, host) + ctx = context.WithValue(ctx, ctxDomainKey, domain) + + return r.WithContext(ctx) +} + +// GetHost extracts the host from request's context +func GetHost(r *http.Request) string { + return r.Context().Value(ctxHostKey).(string) +} + +// GetDomain extracts the domain from request's context +func GetDomain(r *http.Request) *domain.D { + return r.Context().Value(ctxDomainKey).(*domain.D) +} diff --git a/internal/request/request_test.go b/internal/request/request_test.go index 1f47ee2e..97e40ee4 100644 --- a/internal/request/request_test.go +++ b/internal/request/request_test.go @@ -6,19 +6,63 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-pages/internal/domain" ) func TestWithHTTPSFlag(t *testing.T) { r, err := http.NewRequest("GET", "/", nil) require.NoError(t, err) - assert.Panics(t, func() { - IsHTTPS(r) - }) - httpsRequest := WithHTTPSFlag(r, true) assert.True(t, IsHTTPS(httpsRequest)) httpRequest := WithHTTPSFlag(r, false) assert.False(t, IsHTTPS(httpRequest)) } + +func TestPanics(t *testing.T) { + r, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + + assert.Panics(t, func() { + IsHTTPS(r) + }) + + assert.Panics(t, func() { + GetHost(r) + }) + + assert.Panics(t, func() { + GetDomain(r) + }) +} + +func TestWithHostAndDomain(t *testing.T) { + tests := []struct { + name string + host string + domain *domain.D + }{ + { + name: "values", + host: "gitlab.com", + domain: &domain.D{}, + }, + { + name: "no_host", + host: "", + domain: &domain.D{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + + r = WithHostAndDomain(r, tt.host, tt.domain) + assert.Exactly(t, tt.domain, GetDomain(r)) + assert.Equal(t, tt.host, GetHost(r)) + }) + } +} @@ -37,7 +37,7 @@ func (ln *keepAliveListener) Accept() (net.Conn, error) { return conn, nil } -func listenAndServe(fd uintptr, handler http.HandlerFunc, useHTTP2 bool, tlsConfig *tls.Config, limiter *netutil.Limiter) error { +func listenAndServe(fd uintptr, handler http.Handler, useHTTP2 bool, tlsConfig *tls.Config, limiter *netutil.Limiter) error { // create server server := &http.Server{Handler: context.ClearHandler(handler), TLSConfig: tlsConfig} @@ -64,7 +64,7 @@ func listenAndServe(fd uintptr, handler http.HandlerFunc, useHTTP2 bool, tlsConf return server.Serve(&keepAliveListener{l}) } -func listenAndServeTLS(fd uintptr, cert, key []byte, handler http.HandlerFunc, getCertificate tlsconfig.GetCertificateFunc, insecureCiphers bool, tlsMinVersion uint16, tlsMaxVersion uint16, useHTTP2 bool, limiter *netutil.Limiter) error { +func listenAndServeTLS(fd uintptr, cert, key []byte, handler http.Handler, getCertificate tlsconfig.GetCertificateFunc, insecureCiphers bool, tlsMinVersion uint16, tlsMaxVersion uint16, useHTTP2 bool, limiter *netutil.Limiter) error { tlsConfig, err := tlsconfig.Create(cert, key, getCertificate, insecureCiphers, tlsMinVersion, tlsMaxVersion) if err != nil { return err |