Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-pages.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorfeistel <6742251-feistel@users.noreply.gitlab.com>2022-05-04 00:30:34 +0300
committerfeistel <6742251-feistel@users.noreply.gitlab.com>2022-05-05 13:57:09 +0300
commit5da5ac23e2757995661e3ac71c7a1011689e0ace (patch)
treeb94c3c93c6ea653cbccc003f80f35b9514593629
parent40528a5463f155f70d33dac3573199c0ba3e599d (diff)
Abstract artifact handling to a separate middleware
-rw-r--r--app.go10
-rw-r--r--internal/domain/request.go11
-rw-r--r--internal/domain/request_test.go7
-rw-r--r--internal/handlers/artifact.go13
-rw-r--r--internal/handlers/handlers.go5
-rw-r--r--internal/handlers/handlers_test.go6
-rw-r--r--internal/routing/middleware.go10
7 files changed, 30 insertions, 32 deletions
diff --git a/app.go b/app.go
index 31212460..2fa100fe 100644
--- a/app.go
+++ b/app.go
@@ -113,11 +113,7 @@ func (a *theApp) checkAuthAndServeNotFound(domain *domain.Domain, w http.Respons
domain.ServeNotFoundAuthFailed(w, r)
}
-func (a *theApp) tryAuxiliaryHandlers(w http.ResponseWriter, r *http.Request, https bool, host string, domain *domain.Domain) bool {
- if a.Handlers.HandleArtifactRequest(host, w, r) {
- return true
- }
-
+func (a *theApp) tryAuxiliaryHandlers(w http.ResponseWriter, r *http.Request, https bool, domain *domain.Domain) bool {
if _, err := domain.GetLookupPath(r); err != nil {
if errors.Is(err, gitlab.ErrDiskDisabled) {
errortracking.CaptureErrWithReqAndStackTrace(err, r)
@@ -142,11 +138,10 @@ func (a *theApp) tryAuxiliaryHandlers(w http.ResponseWriter, r *http.Request, ht
// not static-content responses
func (a *theApp) auxiliaryMiddleware(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- host := domain.GetHost(r)
domain := domain.FromRequest(r)
https := request.IsHTTPS(r)
- if a.tryAuxiliaryHandlers(w, r, https, host, domain) {
+ if a.tryAuxiliaryHandlers(w, r, https, domain) {
return
}
@@ -208,6 +203,7 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) {
}
handler = a.Auth.AuthorizationMiddleware(handler)
handler = a.auxiliaryMiddleware(handler)
+ handler = handlers.ArtifactMiddleware(handler, a.Handlers)
handler = a.Auth.AuthenticationMiddleware(handler, a.source)
handler = a.AcmeMiddleware.AcmeMiddleware(handler)
diff --git a/internal/domain/request.go b/internal/domain/request.go
index a10218ed..c6318922 100644
--- a/internal/domain/request.go
+++ b/internal/domain/request.go
@@ -8,24 +8,17 @@ import (
type ctxKey string
const (
- ctxHostKey ctxKey = "host"
ctxDomainKey ctxKey = "domain"
)
-// ReqWithHostAndDomain saves host name and domain in the request's context
-func ReqWithHostAndDomain(r *http.Request, host string, domain *Domain) *http.Request {
+// ReqWithDomain saves domain in the request's context
+func ReqWithDomain(r *http.Request, domain *Domain) *http.Request {
ctx := r.Context()
- ctx = context.WithValue(ctx, ctxHostKey, host)
ctx = context.WithValue(ctx, ctxDomainKey, domain)
return r.WithContext(ctx)
}
-// GetHost extracts the host from request's context
-func GetHost(r *http.Request) string {
- return r.Context().Value(ctxHostKey).(string)
-}
-
// FromRequest extracts the domain from request's context
func FromRequest(r *http.Request) *Domain {
return r.Context().Value(ctxDomainKey).(*Domain)
diff --git a/internal/domain/request_test.go b/internal/domain/request_test.go
index c261d1f7..3ed82baa 100644
--- a/internal/domain/request_test.go
+++ b/internal/domain/request_test.go
@@ -12,10 +12,6 @@ func TestPanics(t *testing.T) {
require.NoError(t, err)
require.Panics(t, func() {
- GetHost(r)
- })
-
- require.Panics(t, func() {
FromRequest(r)
})
}
@@ -42,9 +38,8 @@ func TestWithHostAndDomain(t *testing.T) {
r, err := http.NewRequest("GET", "/", nil)
require.NoError(t, err)
- r = ReqWithHostAndDomain(r, tt.host, tt.domain)
+ r = ReqWithDomain(r, tt.domain)
require.Exactly(t, tt.domain, FromRequest(r))
- require.Equal(t, tt.host, GetHost(r))
})
}
}
diff --git a/internal/handlers/artifact.go b/internal/handlers/artifact.go
new file mode 100644
index 00000000..994fde33
--- /dev/null
+++ b/internal/handlers/artifact.go
@@ -0,0 +1,13 @@
+package handlers
+
+import "net/http"
+
+func ArtifactMiddleware(handler http.Handler, h *Handlers) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if h.HandleArtifactRequest(w, r) {
+ return
+ }
+
+ handler.ServeHTTP(w, r)
+ })
+}
diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go
index 22d49e4b..c91491f0 100644
--- a/internal/handlers/handlers.go
+++ b/internal/handlers/handlers.go
@@ -5,6 +5,7 @@ import (
"gitlab.com/gitlab-org/gitlab-pages/internal"
"gitlab.com/gitlab-org/gitlab-pages/internal/logging"
+ "gitlab.com/gitlab-org/gitlab-pages/internal/request"
)
// Handlers take care of handling specific requests
@@ -52,7 +53,7 @@ func (a *Handlers) checkIfLoginRequiredOrInvalidToken(w http.ResponseWriter, r *
}
// HandleArtifactRequest handles all artifact related requests, will return true if request was handled here
-func (a *Handlers) HandleArtifactRequest(host string, w http.ResponseWriter, r *http.Request) bool {
+func (a *Handlers) HandleArtifactRequest(w http.ResponseWriter, r *http.Request) bool {
// In the event a host is prefixed with the artifact prefix an artifact
// value is created, and an attempt to proxy the request is made
@@ -62,6 +63,8 @@ func (a *Handlers) HandleArtifactRequest(host string, w http.ResponseWriter, r *
return true
}
+ host := request.GetHostWithoutPort(r)
+
// nolint: bodyclose // false positive
// a.checkIfLoginRequiredOrInvalidToken returns a response.Body, closing this body is responsibility
// of the TryMakeRequest implementation
diff --git a/internal/handlers/handlers_test.go b/internal/handlers/handlers_test.go
index e6dc95db..aedc077c 100644
--- a/internal/handlers/handlers_test.go
+++ b/internal/handlers/handlers_test.go
@@ -35,7 +35,7 @@ func TestNotHandleArtifactRequestReturnsFalse(t *testing.T) {
require.NoError(t, err)
r := &http.Request{URL: reqURL}
- require.False(t, handlers.HandleArtifactRequest("host", result, r))
+ require.False(t, handlers.HandleArtifactRequest(result, r))
}
func TestHandleArtifactRequestedReturnsTrue(t *testing.T) {
@@ -58,7 +58,7 @@ func TestHandleArtifactRequestedReturnsTrue(t *testing.T) {
result := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/something", nil)
- require.True(t, handlers.HandleArtifactRequest("host", result, r))
+ require.True(t, handlers.HandleArtifactRequest(result, r))
}
func TestNotFoundWithTokenIsNotHandled(t *testing.T) {
@@ -186,5 +186,5 @@ func TestHandleArtifactRequestButGetTokenFails(t *testing.T) {
result := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/something", nil)
- require.True(t, handlers.HandleArtifactRequest("host", result, r))
+ require.True(t, handlers.HandleArtifactRequest(result, r))
}
diff --git a/internal/routing/middleware.go b/internal/routing/middleware.go
index de34ec21..712e22a3 100644
--- a/internal/routing/middleware.go
+++ b/internal/routing/middleware.go
@@ -18,7 +18,7 @@ func NewMiddleware(handler http.Handler, s source.Source) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// if we could not retrieve a domain from domains source we break the
// middleware chain and simply respond with 502 after logging this
- host, d, err := getHostAndDomain(r, s)
+ d, err := getDomain(r, s)
if err != nil && !errors.Is(err, domain.ErrDomainDoesNotExist) {
metrics.DomainsSourceFailures.Inc()
logging.LogRequest(r).WithError(err).Error("could not fetch domain information from a source")
@@ -27,15 +27,13 @@ func NewMiddleware(handler http.Handler, s source.Source) http.Handler {
return
}
- r = domain.ReqWithHostAndDomain(r, host, d)
+ r = domain.ReqWithDomain(r, d)
handler.ServeHTTP(w, r)
})
}
-func getHostAndDomain(r *http.Request, s source.Source) (string, *domain.Domain, error) {
+func getDomain(r *http.Request, s source.Source) (*domain.Domain, error) {
host := request.GetHostWithoutPort(r)
- domain, err := s.GetDomain(r.Context(), host)
-
- return host, domain, err
+ return s.GetDomain(r.Context(), host)
}