Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-pages.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlessio Caiazza <acaiazza@gitlab.com>2019-08-23 15:54:16 +0300
committerAlessio Caiazza <acaiazza@gitlab.com>2019-08-23 15:54:16 +0300
commit119326917932e99dae8fceab3f86382ed4b61a07 (patch)
tree9e7a31d5ecaa3ea6dfc1ca60bd42552df9be6621
parent654c183c4b06c7deb3b947c7f557fe4a48f2e218 (diff)
parentc78ef2c684675b7b0685a78958860558149fae25 (diff)
Merge branch 'an-use-middleware-handlers' into 'master'
Refactor to use pluggable http.Handler middlewares See merge request gitlab-org/gitlab-pages!157
-rw-r--r--app.go181
-rw-r--r--internal/request/request.go25
-rw-r--r--internal/request/request_test.go52
-rw-r--r--server.go4
4 files changed, 207 insertions, 55 deletions
diff --git a/app.go b/app.go
index 3efc22c2..8ff2444e 100644
--- a/app.go
+++ b/app.go
@@ -168,46 +168,105 @@ func (a *theApp) tryAuxiliaryHandlers(w http.ResponseWriter, r *http.Request, ht
return false
}
-func (a *theApp) serveContent(ww http.ResponseWriter, r *http.Request, https bool) {
- w := newLoggingResponseWriter(ww)
- defer w.Log(r)
+// observabilityMiddleware will provide observability (logging, metrics)
+// for each request
+func (a *theApp) observabilityMiddleware(handler http.Handler) http.Handler {
+ return http.HandlerFunc(func(ww http.ResponseWriter, r *http.Request) {
+ w := newLoggingResponseWriter(ww)
+ defer w.Log(r)
- metrics.SessionsActive.Inc()
- defer metrics.SessionsActive.Dec()
+ metrics.SessionsActive.Inc()
+ defer metrics.SessionsActive.Dec()
- host, domain := a.getHostAndDomain(r)
+ handler.ServeHTTP(&w, r)
- if a.AcmeMiddleware.ServeAcmeChallenges(&w, r, domain) {
- return
- }
+ metrics.ProcessedRequests.WithLabelValues(strconv.Itoa(w.status), r.Method).Inc()
+ })
+}
- if a.Auth.TryAuthenticate(&w, r, a.dm, &a.lock) {
- return
- }
+// routingMiddleware will determine the host and domain for the request, for
+// downstream middlewares to use
+func (a *theApp) routingMiddleware(handler http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ host, domain := a.getHostAndDomain(r)
- if a.tryAuxiliaryHandlers(&w, r, https, host, domain) {
- return
- }
+ r = request.WithHostAndDomain(r, host, domain)
+
+ handler.ServeHTTP(w, r)
+ })
+}
- // Only for projects that have access control enabled
- if domain.IsAccessControlEnabled(r) {
- if a.Auth.CheckAuthentication(&w, r, domain.GetID(r)) {
+// customHeadersMiddleware will inject custom headers into the response
+func (a *theApp) customHeadersMiddleware(handler http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ headerConfig.AddCustomHeaders(w, a.CustomHeaders)
+
+ handler.ServeHTTP(w, r)
+ })
+}
+
+// acmeMiddleware will handle ACME challenges
+func (a *theApp) acmeMiddleware(handler http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ domain := request.GetDomain(r)
+
+ if a.AcmeMiddleware.ServeAcmeChallenges(w, r, domain) {
return
}
- }
- // Serve static file, applying CORS headers if necessary
- if a.DisableCrossOriginRequests {
- a.serveFileOrNotFound(domain)(&w, r)
- } else {
- corsHandler.ServeHTTP(&w, r, a.serveFileOrNotFound(domain))
- }
+ handler.ServeHTTP(w, r)
+ })
+}
- metrics.ProcessedRequests.WithLabelValues(strconv.Itoa(w.status), r.Method).Inc()
+// authMiddleware handles authentication requests
+func (a *theApp) authMiddleware(handler http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if a.Auth.TryAuthenticate(w, r, a.dm, &a.lock) {
+ return
+ }
+
+ handler.ServeHTTP(w, r)
+ })
}
-func (a *theApp) serveFileOrNotFound(domain *domain.D) http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
+// auxiliaryMiddleware will handle status updates, not-ready requests and other
+// not static-content responses
+func (a *theApp) auxiliaryMiddleware(handler http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ host := request.GetHost(r)
+ domain := request.GetDomain(r)
+ https := request.IsHTTPS(r)
+
+ if a.tryAuxiliaryHandlers(w, r, https, host, domain) {
+ return
+ }
+
+ handler.ServeHTTP(w, r)
+ })
+}
+
+// accessControlMiddleware will handle authorization
+func (a *theApp) accessControlMiddleware(handler http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ domain := request.GetDomain(r)
+
+ // Only for projects that have access control enabled
+ if domain.IsAccessControlEnabled(r) {
+ // accessControlMiddleware
+ if a.Auth.CheckAuthentication(w, r, domain.GetID(r)) {
+ return
+ }
+ }
+
+ handler.ServeHTTP(w, r)
+ })
+}
+
+// serveFileOrNotFoundHandler will serve static content or
+// return a 404 Not Found response
+func (a *theApp) serveFileOrNotFoundHandler() http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ domain := request.GetDomain(r)
fileServed := domain.ServeFileHTTP(w, r)
if !fileServed {
@@ -226,29 +285,49 @@ func (a *theApp) serveFileOrNotFound(domain *domain.D) http.HandlerFunc {
domain.ServeNotFoundHTTP(w, r)
}
- }
+ })
}
-func (a *theApp) ServeHTTP(ww http.ResponseWriter, r *http.Request) {
- https := r.TLS != nil
- r = request.WithHTTPSFlag(r, https)
+// httpInitialMiddleware sets up HTTP requests
+func (a *theApp) httpInitialMiddleware(handler http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ https := r.TLS != nil
+ r = request.WithHTTPSFlag(r, https)
- headerConfig.AddCustomHeaders(ww, a.CustomHeaders)
-
- a.serveContent(ww, r, https)
+ handler.ServeHTTP(w, r)
+ })
}
-func (a *theApp) ServeProxy(ww http.ResponseWriter, r *http.Request) {
- forwardedProto := r.Header.Get(xForwardedProto)
- https := forwardedProto == xForwardedProtoHTTPS
- r = request.WithHTTPSFlag(r, https)
+// proxyInitialMiddleware sets up proxy requests
+func (a *theApp) proxyInitialMiddleware(handler http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ forwardedProto := r.Header.Get(xForwardedProto)
+ https := forwardedProto == xForwardedProtoHTTPS
- if forwardedHost := r.Header.Get(xForwardedHost); forwardedHost != "" {
- r.Host = forwardedHost
- }
- headerConfig.AddCustomHeaders(ww, a.CustomHeaders)
+ r = request.WithHTTPSFlag(r, https)
+ if forwardedHost := r.Header.Get(xForwardedHost); forwardedHost != "" {
+ r.Host = forwardedHost
+ }
- a.serveContent(ww, r, https)
+ handler.ServeHTTP(w, r)
+ })
+}
+
+func (a *theApp) buildHandlerPipeline() http.Handler {
+ // Handlers should be applied in reverse order
+ handler := a.serveFileOrNotFoundHandler()
+ if !a.DisableCrossOriginRequests {
+ handler = corsHandler.Handler(handler)
+ }
+ handler = a.accessControlMiddleware(handler)
+ handler = a.auxiliaryMiddleware(handler)
+ handler = a.authMiddleware(handler)
+ handler = a.acmeMiddleware(handler)
+ handler = a.customHeadersMiddleware(handler)
+ handler = a.observabilityMiddleware(handler)
+ handler = a.routingMiddleware(handler)
+
+ return handler
}
func (a *theApp) UpdateDomains(dm domain.Map) {
@@ -262,12 +341,18 @@ func (a *theApp) Run() {
limiter := netutil.NewLimiter(a.MaxConns)
+ // Use a common pipeline to use a single instance of each handler,
+ // instead of making two nearly identical pipelines
+ commonHandlerPipeline := a.buildHandlerPipeline()
+ proxyHandler := a.proxyInitialMiddleware(commonHandlerPipeline)
+ httpHandler := a.httpInitialMiddleware(commonHandlerPipeline)
+
// Listen for HTTP
for _, fd := range a.ListenHTTP {
wg.Add(1)
go func(fd uintptr) {
defer wg.Done()
- err := listenAndServe(fd, a.ServeHTTP, a.HTTP2, nil, limiter)
+ err := listenAndServe(fd, httpHandler, a.HTTP2, nil, limiter)
if err != nil {
capturingFatal(err, errortracking.WithField("listener", "http"))
}
@@ -279,7 +364,7 @@ func (a *theApp) Run() {
wg.Add(1)
go func(fd uintptr) {
defer wg.Done()
- err := listenAndServeTLS(fd, a.RootCertificate, a.RootKey, a.ServeHTTP, a.ServeTLS, a.InsecureCiphers, a.TLSMinVersion, a.TLSMaxVersion, a.HTTP2, limiter)
+ err := listenAndServeTLS(fd, a.RootCertificate, a.RootKey, httpHandler, a.ServeTLS, a.InsecureCiphers, a.TLSMinVersion, a.TLSMaxVersion, a.HTTP2, limiter)
if err != nil {
capturingFatal(err, errortracking.WithField("listener", "https"))
}
@@ -291,7 +376,7 @@ func (a *theApp) Run() {
wg.Add(1)
go func(fd uintptr) {
defer wg.Done()
- err := listenAndServe(fd, a.ServeProxy, a.HTTP2, nil, limiter)
+ err := listenAndServe(fd, proxyHandler, a.HTTP2, nil, limiter)
if err != nil {
capturingFatal(err, errortracking.WithField("listener", "http proxy"))
}
@@ -304,7 +389,7 @@ func (a *theApp) Run() {
go func(fd uintptr) {
defer wg.Done()
- handler := promhttp.Handler().ServeHTTP
+ handler := promhttp.Handler()
err := listenAndServe(fd, handler, false, nil, nil)
if err != nil {
capturingFatal(err, errortracking.WithField("listener", "metrics"))
diff --git a/internal/request/request.go b/internal/request/request.go
index dad6af3d..730eb527 100644
--- a/internal/request/request.go
+++ b/internal/request/request.go
@@ -3,12 +3,16 @@ package request
import (
"context"
"net/http"
+
+ "gitlab.com/gitlab-org/gitlab-pages/internal/domain"
)
type ctxKey string
const (
- ctxHTTPSKey ctxKey = "https"
+ ctxHTTPSKey ctxKey = "https"
+ ctxHostKey ctxKey = "host"
+ ctxDomainKey ctxKey = "domain"
)
// WithHTTPSFlag saves https flag in request's context
@@ -22,3 +26,22 @@ func WithHTTPSFlag(r *http.Request, https bool) *http.Request {
func IsHTTPS(r *http.Request) bool {
return r.Context().Value(ctxHTTPSKey).(bool)
}
+
+// WithHostAndDomain saves host name and domain in the request's context
+func WithHostAndDomain(r *http.Request, host string, domain *domain.D) *http.Request {
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, ctxHostKey, host)
+ ctx = context.WithValue(ctx, ctxDomainKey, domain)
+
+ return r.WithContext(ctx)
+}
+
+// GetHost extracts the host from request's context
+func GetHost(r *http.Request) string {
+ return r.Context().Value(ctxHostKey).(string)
+}
+
+// GetDomain extracts the domain from request's context
+func GetDomain(r *http.Request) *domain.D {
+ return r.Context().Value(ctxDomainKey).(*domain.D)
+}
diff --git a/internal/request/request_test.go b/internal/request/request_test.go
index 1f47ee2e..97e40ee4 100644
--- a/internal/request/request_test.go
+++ b/internal/request/request_test.go
@@ -6,19 +6,63 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-pages/internal/domain"
)
func TestWithHTTPSFlag(t *testing.T) {
r, err := http.NewRequest("GET", "/", nil)
require.NoError(t, err)
- assert.Panics(t, func() {
- IsHTTPS(r)
- })
-
httpsRequest := WithHTTPSFlag(r, true)
assert.True(t, IsHTTPS(httpsRequest))
httpRequest := WithHTTPSFlag(r, false)
assert.False(t, IsHTTPS(httpRequest))
}
+
+func TestPanics(t *testing.T) {
+ r, err := http.NewRequest("GET", "/", nil)
+ require.NoError(t, err)
+
+ assert.Panics(t, func() {
+ IsHTTPS(r)
+ })
+
+ assert.Panics(t, func() {
+ GetHost(r)
+ })
+
+ assert.Panics(t, func() {
+ GetDomain(r)
+ })
+}
+
+func TestWithHostAndDomain(t *testing.T) {
+ tests := []struct {
+ name string
+ host string
+ domain *domain.D
+ }{
+ {
+ name: "values",
+ host: "gitlab.com",
+ domain: &domain.D{},
+ },
+ {
+ name: "no_host",
+ host: "",
+ domain: &domain.D{},
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ r, err := http.NewRequest("GET", "/", nil)
+ require.NoError(t, err)
+
+ r = WithHostAndDomain(r, tt.host, tt.domain)
+ assert.Exactly(t, tt.domain, GetDomain(r))
+ assert.Equal(t, tt.host, GetHost(r))
+ })
+ }
+}
diff --git a/server.go b/server.go
index d42fd18f..64f8f5f9 100644
--- a/server.go
+++ b/server.go
@@ -37,7 +37,7 @@ func (ln *keepAliveListener) Accept() (net.Conn, error) {
return conn, nil
}
-func listenAndServe(fd uintptr, handler http.HandlerFunc, useHTTP2 bool, tlsConfig *tls.Config, limiter *netutil.Limiter) error {
+func listenAndServe(fd uintptr, handler http.Handler, useHTTP2 bool, tlsConfig *tls.Config, limiter *netutil.Limiter) error {
// create server
server := &http.Server{Handler: context.ClearHandler(handler), TLSConfig: tlsConfig}
@@ -64,7 +64,7 @@ func listenAndServe(fd uintptr, handler http.HandlerFunc, useHTTP2 bool, tlsConf
return server.Serve(&keepAliveListener{l})
}
-func listenAndServeTLS(fd uintptr, cert, key []byte, handler http.HandlerFunc, getCertificate tlsconfig.GetCertificateFunc, insecureCiphers bool, tlsMinVersion uint16, tlsMaxVersion uint16, useHTTP2 bool, limiter *netutil.Limiter) error {
+func listenAndServeTLS(fd uintptr, cert, key []byte, handler http.Handler, getCertificate tlsconfig.GetCertificateFunc, insecureCiphers bool, tlsMinVersion uint16, tlsMaxVersion uint16, useHTTP2 bool, limiter *netutil.Limiter) error {
tlsConfig, err := tlsconfig.Create(cert, key, getCertificate, insecureCiphers, tlsMinVersion, tlsMaxVersion)
if err != nil {
return err