diff options
-rw-r--r-- | app.go | 10 | ||||
-rw-r--r-- | internal/domain/request.go | 11 | ||||
-rw-r--r-- | internal/domain/request_test.go | 7 | ||||
-rw-r--r-- | internal/handlers/artifact.go | 13 | ||||
-rw-r--r-- | internal/handlers/handlers.go | 5 | ||||
-rw-r--r-- | internal/handlers/handlers_test.go | 6 | ||||
-rw-r--r-- | internal/routing/middleware.go | 10 |
7 files changed, 30 insertions, 32 deletions
@@ -113,11 +113,7 @@ func (a *theApp) checkAuthAndServeNotFound(domain *domain.Domain, w http.Respons domain.ServeNotFoundAuthFailed(w, r) } -func (a *theApp) tryAuxiliaryHandlers(w http.ResponseWriter, r *http.Request, https bool, host string, domain *domain.Domain) bool { - if a.Handlers.HandleArtifactRequest(host, w, r) { - return true - } - +func (a *theApp) tryAuxiliaryHandlers(w http.ResponseWriter, r *http.Request, https bool, domain *domain.Domain) bool { if _, err := domain.GetLookupPath(r); err != nil { if errors.Is(err, gitlab.ErrDiskDisabled) { errortracking.CaptureErrWithReqAndStackTrace(err, r) @@ -142,11 +138,10 @@ func (a *theApp) tryAuxiliaryHandlers(w http.ResponseWriter, r *http.Request, ht // not static-content responses func (a *theApp) auxiliaryMiddleware(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - host := domain.GetHost(r) domain := domain.FromRequest(r) https := request.IsHTTPS(r) - if a.tryAuxiliaryHandlers(w, r, https, host, domain) { + if a.tryAuxiliaryHandlers(w, r, https, domain) { return } @@ -208,6 +203,7 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) { } handler = a.Auth.AuthorizationMiddleware(handler) handler = a.auxiliaryMiddleware(handler) + handler = handlers.ArtifactMiddleware(handler, a.Handlers) handler = a.Auth.AuthenticationMiddleware(handler, a.source) handler = a.AcmeMiddleware.AcmeMiddleware(handler) diff --git a/internal/domain/request.go b/internal/domain/request.go index a10218ed..c6318922 100644 --- a/internal/domain/request.go +++ b/internal/domain/request.go @@ -8,24 +8,17 @@ import ( type ctxKey string const ( - ctxHostKey ctxKey = "host" ctxDomainKey ctxKey = "domain" ) -// ReqWithHostAndDomain saves host name and domain in the request's context -func ReqWithHostAndDomain(r *http.Request, host string, domain *Domain) *http.Request { +// ReqWithDomain saves domain in the request's context +func ReqWithDomain(r *http.Request, domain *Domain) *http.Request { ctx := r.Context() - ctx = context.WithValue(ctx, ctxHostKey, host) ctx = context.WithValue(ctx, ctxDomainKey, domain) return r.WithContext(ctx) } -// GetHost extracts the host from request's context -func GetHost(r *http.Request) string { - return r.Context().Value(ctxHostKey).(string) -} - // FromRequest extracts the domain from request's context func FromRequest(r *http.Request) *Domain { return r.Context().Value(ctxDomainKey).(*Domain) diff --git a/internal/domain/request_test.go b/internal/domain/request_test.go index c261d1f7..3ed82baa 100644 --- a/internal/domain/request_test.go +++ b/internal/domain/request_test.go @@ -12,10 +12,6 @@ func TestPanics(t *testing.T) { require.NoError(t, err) require.Panics(t, func() { - GetHost(r) - }) - - require.Panics(t, func() { FromRequest(r) }) } @@ -42,9 +38,8 @@ func TestWithHostAndDomain(t *testing.T) { r, err := http.NewRequest("GET", "/", nil) require.NoError(t, err) - r = ReqWithHostAndDomain(r, tt.host, tt.domain) + r = ReqWithDomain(r, tt.domain) require.Exactly(t, tt.domain, FromRequest(r)) - require.Equal(t, tt.host, GetHost(r)) }) } } diff --git a/internal/handlers/artifact.go b/internal/handlers/artifact.go new file mode 100644 index 00000000..994fde33 --- /dev/null +++ b/internal/handlers/artifact.go @@ -0,0 +1,13 @@ +package handlers + +import "net/http" + +func ArtifactMiddleware(handler http.Handler, h *Handlers) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if h.HandleArtifactRequest(w, r) { + return + } + + handler.ServeHTTP(w, r) + }) +} diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 22d49e4b..c91491f0 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -5,6 +5,7 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal" "gitlab.com/gitlab-org/gitlab-pages/internal/logging" + "gitlab.com/gitlab-org/gitlab-pages/internal/request" ) // Handlers take care of handling specific requests @@ -52,7 +53,7 @@ func (a *Handlers) checkIfLoginRequiredOrInvalidToken(w http.ResponseWriter, r * } // HandleArtifactRequest handles all artifact related requests, will return true if request was handled here -func (a *Handlers) HandleArtifactRequest(host string, w http.ResponseWriter, r *http.Request) bool { +func (a *Handlers) HandleArtifactRequest(w http.ResponseWriter, r *http.Request) bool { // In the event a host is prefixed with the artifact prefix an artifact // value is created, and an attempt to proxy the request is made @@ -62,6 +63,8 @@ func (a *Handlers) HandleArtifactRequest(host string, w http.ResponseWriter, r * return true } + host := request.GetHostWithoutPort(r) + // nolint: bodyclose // false positive // a.checkIfLoginRequiredOrInvalidToken returns a response.Body, closing this body is responsibility // of the TryMakeRequest implementation diff --git a/internal/handlers/handlers_test.go b/internal/handlers/handlers_test.go index e6dc95db..aedc077c 100644 --- a/internal/handlers/handlers_test.go +++ b/internal/handlers/handlers_test.go @@ -35,7 +35,7 @@ func TestNotHandleArtifactRequestReturnsFalse(t *testing.T) { require.NoError(t, err) r := &http.Request{URL: reqURL} - require.False(t, handlers.HandleArtifactRequest("host", result, r)) + require.False(t, handlers.HandleArtifactRequest(result, r)) } func TestHandleArtifactRequestedReturnsTrue(t *testing.T) { @@ -58,7 +58,7 @@ func TestHandleArtifactRequestedReturnsTrue(t *testing.T) { result := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/something", nil) - require.True(t, handlers.HandleArtifactRequest("host", result, r)) + require.True(t, handlers.HandleArtifactRequest(result, r)) } func TestNotFoundWithTokenIsNotHandled(t *testing.T) { @@ -186,5 +186,5 @@ func TestHandleArtifactRequestButGetTokenFails(t *testing.T) { result := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/something", nil) - require.True(t, handlers.HandleArtifactRequest("host", result, r)) + require.True(t, handlers.HandleArtifactRequest(result, r)) } diff --git a/internal/routing/middleware.go b/internal/routing/middleware.go index de34ec21..712e22a3 100644 --- a/internal/routing/middleware.go +++ b/internal/routing/middleware.go @@ -18,7 +18,7 @@ func NewMiddleware(handler http.Handler, s source.Source) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // if we could not retrieve a domain from domains source we break the // middleware chain and simply respond with 502 after logging this - host, d, err := getHostAndDomain(r, s) + d, err := getDomain(r, s) if err != nil && !errors.Is(err, domain.ErrDomainDoesNotExist) { metrics.DomainsSourceFailures.Inc() logging.LogRequest(r).WithError(err).Error("could not fetch domain information from a source") @@ -27,15 +27,13 @@ func NewMiddleware(handler http.Handler, s source.Source) http.Handler { return } - r = domain.ReqWithHostAndDomain(r, host, d) + r = domain.ReqWithDomain(r, d) handler.ServeHTTP(w, r) }) } -func getHostAndDomain(r *http.Request, s source.Source) (string, *domain.Domain, error) { +func getDomain(r *http.Request, s source.Source) (*domain.Domain, error) { host := request.GetHostWithoutPort(r) - domain, err := s.GetDomain(r.Context(), host) - - return host, domain, err + return s.GetDomain(r.Context(), host) } |