diff options
author | feistel <6742251-feistel@users.noreply.gitlab.com> | 2021-09-16 19:18:54 +0300 |
---|---|---|
committer | feistel <6742251-feistel@users.noreply.gitlab.com> | 2021-09-16 19:19:13 +0300 |
commit | 08d70aef345f1811f13cc990529f5114ccf3a92e (patch) | |
tree | 28aa7e492067703dd6b214a3d363292fd645ddca | |
parent | fecd9ca44bfc63e82f1cb2fde515b1e348678f7e (diff) |
refactor: move middlewares to corresponding packages
-rw-r--r-- | app.go | 94 | ||||
-rw-r--r-- | internal/acl/middleware.go | 25 | ||||
-rw-r--r-- | internal/acme/middleware.go | 20 | ||||
-rw-r--r-- | internal/auth/middleware.go | 18 | ||||
-rw-r--r-- | internal/customheaders/customheaders.go (renamed from internal/middleware/headers.go) | 2 | ||||
-rw-r--r-- | internal/customheaders/customheaders_test.go (renamed from internal/middleware/headers_test.go) | 13 | ||||
-rw-r--r-- | internal/customheaders/middleware.go | 14 | ||||
-rw-r--r-- | internal/routing/middleware.go | 41 |
8 files changed, 136 insertions, 91 deletions
@@ -21,19 +21,21 @@ import ( labmetrics "gitlab.com/gitlab-org/labkit/metrics" "gitlab.com/gitlab-org/labkit/monitoring" + "gitlab.com/gitlab-org/gitlab-pages/internal/acl" "gitlab.com/gitlab-org/gitlab-pages/internal/acme" "gitlab.com/gitlab-org/gitlab-pages/internal/artifact" "gitlab.com/gitlab-org/gitlab-pages/internal/auth" cfg "gitlab.com/gitlab-org/gitlab-pages/internal/config" "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls" + "gitlab.com/gitlab-org/gitlab-pages/internal/customheaders" "gitlab.com/gitlab-org/gitlab-pages/internal/domain" "gitlab.com/gitlab-org/gitlab-pages/internal/handlers" "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/logging" - "gitlab.com/gitlab-org/gitlab-pages/internal/middleware" "gitlab.com/gitlab-org/gitlab-pages/internal/netutil" "gitlab.com/gitlab-org/gitlab-pages/internal/rejectmethods" "gitlab.com/gitlab-org/gitlab-pages/internal/request" + "gitlab.com/gitlab-org/gitlab-pages/internal/routing" "gitlab.com/gitlab-org/gitlab-pages/internal/serving/disk/zip" "gitlab.com/gitlab-org/gitlab-pages/internal/source" "gitlab.com/gitlab-org/gitlab-pages/internal/source/gitlab" @@ -92,13 +94,6 @@ func (a *theApp) redirectToHTTPS(w http.ResponseWriter, r *http.Request, statusC http.Redirect(w, r, u.String(), statusCode) } -func (a *theApp) getHostAndDomain(r *http.Request) (string, *domain.Domain, error) { - host := request.GetHostWithoutPort(r) - domain, err := a.domain(r.Context(), host) - - return host, domain, err -} - func (a *theApp) domain(ctx context.Context, host string) (*domain.Domain, error) { return a.source.GetDomain(ctx, host) } @@ -154,27 +149,6 @@ func (a *theApp) tryAuxiliaryHandlers(w http.ResponseWriter, r *http.Request, ht return false } -// routingMiddleware will determine the host and domain for the request, for -// downstream middlewares to use -func (a *theApp) routingMiddleware(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // if we could not retrieve a domain from domains source we break the - // middleware chain and simply respond with 502 after logging this - host, d, err := a.getHostAndDomain(r) - if err != nil && !errors.Is(err, domain.ErrDomainDoesNotExist) { - metrics.DomainsSourceFailures.Inc() - logging.LogRequest(r).WithError(err).Error("could not fetch domain information from a source") - - httperrors.Serve502(w) - return - } - - r = request.WithHostAndDomain(r, host, d) - - handler.ServeHTTP(w, r) - }) -} - // healthCheckMiddleware is serving the application status check func (a *theApp) healthCheckMiddleware(handler http.Handler) (http.Handler, error) { healthCheck := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -196,39 +170,6 @@ func (a *theApp) healthCheckMiddleware(handler http.Handler) (http.Handler, erro }), nil } -// customHeadersMiddleware will inject custom headers into the response -func (a *theApp) customHeadersMiddleware(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - middleware.AddCustomHeaders(w, a.CustomHeaders) - - handler.ServeHTTP(w, r) - }) -} - -// acmeMiddleware will handle ACME challenges -func (a *theApp) acmeMiddleware(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - domain := request.GetDomain(r) - - if a.AcmeMiddleware.ServeAcmeChallenges(w, r, domain) { - return - } - - handler.ServeHTTP(w, r) - }) -} - -// authMiddleware handles authentication requests -func (a *theApp) authMiddleware(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if a.Auth.TryAuthenticate(w, r, a.source) { - return - } - - handler.ServeHTTP(w, r) - }) -} - // auxiliaryMiddleware will handle status updates, not-ready requests and other // not static-content responses func (a *theApp) auxiliaryMiddleware(handler http.Handler) http.Handler { @@ -245,23 +186,6 @@ func (a *theApp) auxiliaryMiddleware(handler http.Handler) http.Handler { }) } -// accessControlMiddleware will handle authorization -func (a *theApp) accessControlMiddleware(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - domain := request.GetDomain(r) - - // Only for projects that have access control enabled - if domain.IsAccessControlEnabled(r) { - // accessControlMiddleware - if a.Auth.CheckAuthentication(w, r, domain) { - return - } - } - - handler.ServeHTTP(w, r) - }) -} - // serveFileOrNotFoundHandler will serve static content or // return a 404 Not Found response func (a *theApp) serveFileOrNotFoundHandler() http.Handler { @@ -324,10 +248,10 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) { if !a.config.General.DisableCrossOriginRequests { handler = corsHandler.Handler(handler) } - handler = a.accessControlMiddleware(handler) + handler = acl.NewMiddleware(handler, a.Auth) handler = a.auxiliaryMiddleware(handler) - handler = a.authMiddleware(handler) - handler = a.acmeMiddleware(handler) + handler = auth.NewMiddleware(handler, a.Auth, a.source) + handler = acme.NewMiddleware(handler, a.AcmeMiddleware) handler, err := logging.AccessLogger(handler, a.config.Log.Format) if err != nil { return nil, err @@ -337,7 +261,7 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) { metricsMiddleware := labmetrics.NewHandlerFactory(labmetrics.WithNamespace("gitlab_pages")) handler = metricsMiddleware(handler) - handler = a.routingMiddleware(handler) + handler = routing.NewMiddleware(handler, a.source) // Health Check handler, err = a.healthCheckMiddleware(handler) @@ -346,7 +270,7 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) { } // Custom response headers - handler = a.customHeadersMiddleware(handler) + handler = customheaders.NewMiddleware(handler, a.CustomHeaders) // Correlation ID injection middleware var correlationOpts []correlation.InboundHandlerOption @@ -521,7 +445,7 @@ func runApp(config *cfg.Config) { } if len(config.General.CustomHeaders) != 0 { - customHeaders, err := middleware.ParseHeaderString(config.General.CustomHeaders) + customHeaders, err := customheaders.ParseHeaderString(config.General.CustomHeaders) if err != nil { log.WithError(err).Fatal("Unable to parse header string") } diff --git a/internal/acl/middleware.go b/internal/acl/middleware.go new file mode 100644 index 00000000..119d0549 --- /dev/null +++ b/internal/acl/middleware.go @@ -0,0 +1,25 @@ +package acl + +import ( + "net/http" + + "gitlab.com/gitlab-org/gitlab-pages/internal/auth" + "gitlab.com/gitlab-org/gitlab-pages/internal/request" +) + +// NewMiddleware returns middleware which handle authorization +func NewMiddleware(handler http.Handler, a *auth.Auth) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + domain := request.GetDomain(r) + + // Only for projects that have access control enabled + if domain.IsAccessControlEnabled(r) { + // accessControlMiddleware + if a.CheckAuthentication(w, r, domain) { + return + } + } + + handler.ServeHTTP(w, r) + }) +} diff --git a/internal/acme/middleware.go b/internal/acme/middleware.go new file mode 100644 index 00000000..fb6aa158 --- /dev/null +++ b/internal/acme/middleware.go @@ -0,0 +1,20 @@ +package acme + +import ( + "net/http" + + "gitlab.com/gitlab-org/gitlab-pages/internal/request" +) + +// NewMiddleware returns middleware which handle ACME challenges +func NewMiddleware(handler http.Handler, m *Middleware) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + domain := request.GetDomain(r) + + if m.ServeAcmeChallenges(w, r, domain) { + return + } + + handler.ServeHTTP(w, r) + }) +} diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go new file mode 100644 index 00000000..263e48b8 --- /dev/null +++ b/internal/auth/middleware.go @@ -0,0 +1,18 @@ +package auth + +import ( + "net/http" + + "gitlab.com/gitlab-org/gitlab-pages/internal/source" +) + +// NewMiddleware returns middleware which handles authentication requests +func NewMiddleware(handler http.Handler, a *Auth, s source.Source) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if a.TryAuthenticate(w, r, s) { + return + } + + handler.ServeHTTP(w, r) + }) +} diff --git a/internal/middleware/headers.go b/internal/customheaders/customheaders.go index 837dbe3b..b54df585 100644 --- a/internal/middleware/headers.go +++ b/internal/customheaders/customheaders.go @@ -1,4 +1,4 @@ -package middleware +package customheaders import ( "errors" diff --git a/internal/middleware/headers_test.go b/internal/customheaders/customheaders_test.go index 1f3d98c6..a667f43a 100644 --- a/internal/middleware/headers_test.go +++ b/internal/customheaders/customheaders_test.go @@ -1,10 +1,12 @@ -package middleware +package customheaders_test import ( "net/http/httptest" "testing" "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-pages/internal/customheaders" ) func TestParseHeaderString(t *testing.T) { @@ -80,7 +82,7 @@ func TestParseHeaderString(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := ParseHeaderString(tt.headerStrings) + got, err := customheaders.ParseHeaderString(tt.headerStrings) if tt.valid { require.NoError(t, err) require.Len(t, got, tt.expectedLen) @@ -131,12 +133,13 @@ func TestAddCustomHeaders(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - headers, err := ParseHeaderString(tt.headerStrings) + headers, err := customheaders.ParseHeaderString(tt.headerStrings) require.NoError(t, err) w := httptest.NewRecorder() - AddCustomHeaders(w, headers) + customheaders.AddCustomHeaders(w, headers) + rsp := w.Result() for k, v := range tt.wantHeaders { - require.Equal(t, v, w.HeaderMap.Get(k), "Expected header %+v, got %+v", v, w.HeaderMap.Get(k)) + require.Equal(t, v, rsp.Header.Get(k), "Expected header %+v, got %+v", v, rsp.Header.Get(k)) } }) } diff --git a/internal/customheaders/middleware.go b/internal/customheaders/middleware.go new file mode 100644 index 00000000..e0e11ad6 --- /dev/null +++ b/internal/customheaders/middleware.go @@ -0,0 +1,14 @@ +package customheaders + +import ( + "net/http" +) + +// NewMiddleware returns middleware which inject custom headers into the response +func NewMiddleware(handler http.Handler, headers http.Header) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + AddCustomHeaders(w, headers) + + handler.ServeHTTP(w, r) + }) +} diff --git a/internal/routing/middleware.go b/internal/routing/middleware.go new file mode 100644 index 00000000..5f065c61 --- /dev/null +++ b/internal/routing/middleware.go @@ -0,0 +1,41 @@ +package routing + +import ( + "errors" + "net/http" + + "gitlab.com/gitlab-org/gitlab-pages/internal/domain" + "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" + "gitlab.com/gitlab-org/gitlab-pages/internal/logging" + "gitlab.com/gitlab-org/gitlab-pages/internal/request" + "gitlab.com/gitlab-org/gitlab-pages/internal/source" + "gitlab.com/gitlab-org/gitlab-pages/metrics" +) + +// NewMiddleware returns middleware which determine the host and domain for the request, for +// downstream middlewares to use +func NewMiddleware(handler http.Handler, s source.Source) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // if we could not retrieve a domain from domains source we break the + // middleware chain and simply respond with 502 after logging this + host, d, err := getHostAndDomain(r, s) + if err != nil && !errors.Is(err, domain.ErrDomainDoesNotExist) { + metrics.DomainsSourceFailures.Inc() + logging.LogRequest(r).WithError(err).Error("could not fetch domain information from a source") + + httperrors.Serve502(w) + return + } + + r = request.WithHostAndDomain(r, host, d) + + handler.ServeHTTP(w, r) + }) +} + +func getHostAndDomain(r *http.Request, s source.Source) (string, *domain.Domain, error) { + host := request.GetHostWithoutPort(r) + domain, err := s.GetDomain(r.Context(), host) + + return host, domain, err +} |