diff options
author | Jaime Martinez <jmartinez@gitlab.com> | 2021-10-25 03:50:40 +0300 |
---|---|---|
committer | Jaime Martinez <jmartinez@gitlab.com> | 2021-10-25 03:59:58 +0300 |
commit | abd1612d0abccbbb04188df15b008118d3d62115 (patch) | |
tree | a68586b36b8cae8f23270fc1bdbb74a69e8f7131 | |
parent | 1a1adbaf869ed651aae3671c507f726a64b1d7ca (diff) |
refactor: remove domain from requestmove-domain-out-of-req
-rw-r--r-- | app.go | 8 | ||||
-rw-r--r-- | internal/acme/middleware.go | 6 | ||||
-rw-r--r-- | internal/auth/middleware.go | 4 | ||||
-rw-r--r-- | internal/domain/logging.go | 23 | ||||
-rw-r--r-- | internal/domain/logging_test.go | 9 | ||||
-rw-r--r-- | internal/domain/request.go | 32 | ||||
-rw-r--r-- | internal/domain/request_test.go | 50 | ||||
-rw-r--r-- | internal/logging/logging.go | 23 | ||||
-rw-r--r-- | internal/logging/logging_test.go | 4 | ||||
-rw-r--r-- | internal/mocks/source.go | 8 | ||||
-rw-r--r-- | internal/request/request.go | 27 | ||||
-rw-r--r-- | internal/request/request_test.go | 44 | ||||
-rw-r--r-- | internal/routing/middleware.go | 2 |
13 files changed, 124 insertions, 116 deletions
@@ -174,8 +174,8 @@ func (a *theApp) healthCheckMiddleware(handler http.Handler) (http.Handler, erro // not static-content responses func (a *theApp) auxiliaryMiddleware(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - host := request.GetHost(r) - domain := request.GetDomain(r) + host := domain.GetHost(r) + domain := domain.FromRequest(r) https := request.IsHTTPS(r) if a.tryAuxiliaryHandlers(w, r, https, host, domain) { @@ -193,7 +193,7 @@ func (a *theApp) serveFileOrNotFoundHandler() http.Handler { start := time.Now() defer metrics.ServingTime.Observe(time.Since(start).Seconds()) - domain := request.GetDomain(r) + domain := domain.FromRequest(r) fileServed := domain.ServeFileHTTP(w, r) if !fileServed { @@ -252,7 +252,7 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) { handler = a.auxiliaryMiddleware(handler) handler = a.Auth.AuthenticationMiddleware(handler, a.source) handler = a.AcmeMiddleware.AcmeMiddleware(handler) - handler, err := logging.AccessLogger(handler, a.config.Log.Format) + handler, err := logging.BasicAccessLogger(handler, a.config.Log.Format, domain.LogFields) if err != nil { return nil, err } diff --git a/internal/acme/middleware.go b/internal/acme/middleware.go index fa9d696c..faa2a017 100644 --- a/internal/acme/middleware.go +++ b/internal/acme/middleware.go @@ -2,14 +2,14 @@ package acme import ( "net/http" - - "gitlab.com/gitlab-org/gitlab-pages/internal/request" + // TODO: break this dependency too + d "gitlab.com/gitlab-org/gitlab-pages/internal/domain" ) // AcmeMiddleware handles ACME challenges func (m *Middleware) AcmeMiddleware(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - domain := request.GetDomain(r) + domain := d.FromRequest(r) if m.ServeAcmeChallenges(w, r, domain) { return diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go index a9cb6c56..01559c10 100644 --- a/internal/auth/middleware.go +++ b/internal/auth/middleware.go @@ -3,7 +3,7 @@ package auth import ( "net/http" - "gitlab.com/gitlab-org/gitlab-pages/internal/request" + d "gitlab.com/gitlab-org/gitlab-pages/internal/domain" "gitlab.com/gitlab-org/gitlab-pages/internal/source" ) @@ -21,7 +21,7 @@ func (a *Auth) AuthenticationMiddleware(handler http.Handler, s source.Source) h // AuthorizationMiddleware handles authorization func (a *Auth) AuthorizationMiddleware(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - domain := request.GetDomain(r) + domain := d.FromRequest(r) // Only for projects that have access control enabled if domain.IsAccessControlEnabled(r) { diff --git a/internal/domain/logging.go b/internal/domain/logging.go index ae2b44e3..2ecddf6c 100644 --- a/internal/domain/logging.go +++ b/internal/domain/logging.go @@ -6,19 +6,26 @@ import ( "gitlab.com/gitlab-org/labkit/log" ) -func (d *Domain) LogFields(r *http.Request) log.Fields { +func LogFields(r *http.Request) log.Fields { + logFields := log.Fields{ + "pages_https": r.URL.Scheme == "https", + "pages_host": GetHost(r), + } + + d := FromRequest(r) if d == nil { - return log.Fields{} + return logFields } lp, err := d.GetLookupPath(r) if err != nil { - return log.Fields{"error": err.Error()} + logFields["error"] = err.Error() + return logFields } - return log.Fields{ - "pages_project_serving_type": lp.ServingType, - "pages_project_prefix": lp.Prefix, - "pages_project_id": lp.ProjectID, - } + logFields["pages_project_serving_type"] = lp.ServingType + logFields["pages_project_prefix"] = lp.Prefix + logFields["pages_project_id"] = lp.ProjectID + + return logFields } diff --git a/internal/domain/logging_test.go b/internal/domain/logging_test.go index 2a1902c2..b736758c 100644 --- a/internal/domain/logging_test.go +++ b/internal/domain/logging_test.go @@ -40,17 +40,19 @@ func TestDomainLogFields(t *testing.T) { "nil_domain_returns_empty_fields": { domain: nil, host: "gitlab.io", - expectedFields: log.Fields{}, + expectedFields: log.Fields{"pages_https": false, "pages_host": "gitlab.io"}, }, "unresolved_domain_returns_error": { domain: New("githost.io", "", "", &resolver{err: ErrDomainDoesNotExist}), host: "gitlab.io", - expectedFields: log.Fields{"error": ErrDomainDoesNotExist.Error()}, + expectedFields: log.Fields{"error": ErrDomainDoesNotExist.Error(), "pages_https": false, "pages_host": "gitlab.io"}, }, "domain_with_fields": { domain: domainWithResolver, host: "gitlab.io", expectedFields: log.Fields{ + "pages_https": false, + "pages_host": "gitlab.io", "pages_project_id": uint64(100), "pages_project_prefix": "/prefix", "pages_project_serving_type": "file", @@ -62,7 +64,8 @@ func TestDomainLogFields(t *testing.T) { r, err := http.NewRequest("GET", "/", nil) require.NoError(t, err) - require.Equal(t, tt.expectedFields, tt.domain.LogFields(r)) + r = ReqWithHostAndDomain(r, tt.host, tt.domain) + require.Equal(t, tt.expectedFields, LogFields(r)) }) } } diff --git a/internal/domain/request.go b/internal/domain/request.go new file mode 100644 index 00000000..a10218ed --- /dev/null +++ b/internal/domain/request.go @@ -0,0 +1,32 @@ +package domain + +import ( + "context" + "net/http" +) + +type ctxKey string + +const ( + ctxHostKey ctxKey = "host" + ctxDomainKey ctxKey = "domain" +) + +// ReqWithHostAndDomain saves host name and domain in the request's context +func ReqWithHostAndDomain(r *http.Request, host string, domain *Domain) *http.Request { + ctx := r.Context() + ctx = context.WithValue(ctx, ctxHostKey, host) + ctx = context.WithValue(ctx, ctxDomainKey, domain) + + return r.WithContext(ctx) +} + +// GetHost extracts the host from request's context +func GetHost(r *http.Request) string { + return r.Context().Value(ctxHostKey).(string) +} + +// FromRequest extracts the domain from request's context +func FromRequest(r *http.Request) *Domain { + return r.Context().Value(ctxDomainKey).(*Domain) +} diff --git a/internal/domain/request_test.go b/internal/domain/request_test.go new file mode 100644 index 00000000..c261d1f7 --- /dev/null +++ b/internal/domain/request_test.go @@ -0,0 +1,50 @@ +package domain + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPanics(t *testing.T) { + r, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + + require.Panics(t, func() { + GetHost(r) + }) + + require.Panics(t, func() { + FromRequest(r) + }) +} + +func TestWithHostAndDomain(t *testing.T) { + tests := []struct { + name string + host string + domain *Domain + }{ + { + name: "values", + host: "gitlab.com", + domain: &Domain{}, + }, + { + name: "no_host", + host: "", + domain: &Domain{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + + r = ReqWithHostAndDomain(r, tt.host, tt.domain) + require.Exactly(t, tt.domain, FromRequest(r)) + require.Equal(t, tt.host, GetHost(r)) + }) + } +} diff --git a/internal/logging/logging.go b/internal/logging/logging.go index f485b20b..d2d758db 100644 --- a/internal/logging/logging.go +++ b/internal/logging/logging.go @@ -51,21 +51,6 @@ func getAccessLogger(format string) (*logrus.Logger, error) { return accessLogger, nil } -// getExtraLogFields is used to inject additional fields into the -// HTTP access logger middleware. -func getExtraLogFields(r *http.Request) log.Fields { - logFields := log.Fields{ - "pages_https": request.IsHTTPS(r), - "pages_host": request.GetHost(r), - } - - for name, value := range request.GetDomain(r).LogFields(r) { - logFields[name] = value - } - - return logFields -} - // BasicAccessLogger configures the GitLab pages basic HTTP access logger middleware func BasicAccessLogger(handler http.Handler, format string, extraFields log.ExtraFieldsGeneratorFunc) (http.Handler, error) { accessLogger, err := getAccessLogger(format) @@ -98,10 +83,10 @@ func enrichExtraFields(extraFields log.ExtraFieldsGeneratorFunc) log.ExtraFields } } -// AccessLogger configures the GitLab pages HTTP access logger middleware with extra log fields -func AccessLogger(handler http.Handler, format string) (http.Handler, error) { - return BasicAccessLogger(handler, format, getExtraLogFields) -} +//// AccessLogger configures the GitLab pages HTTP access logger middleware with extra log fields +//func AccessLogger(handler http.Handler, format string) (http.Handler, error) { +// return BasicAccessLogger(handler, format, getExtraLogFields) +//} // LogRequest will inject request host and path to the logged messages func LogRequest(r *http.Request) *logrus.Entry { diff --git a/internal/logging/logging_test.go b/internal/logging/logging_test.go index df8c3013..fe157e88 100644 --- a/internal/logging/logging_test.go +++ b/internal/logging/logging_test.go @@ -72,9 +72,9 @@ func TestGetExtraLogFields(t *testing.T) { require.NoError(t, err) req.URL.Scheme = tt.scheme - req = request.WithHostAndDomain(req, tt.host, domainWithResolver) + req = domain.ReqWithHostAndDomain(req, tt.host, domainWithResolver) - got := getExtraLogFields(req) + got := domain.LogFields(req) require.Equal(t, tt.expectedHTTPS, got["pages_https"]) require.Equal(t, tt.expectedHost, got["pages_host"]) require.Equal(t, tt.expectedProjectID, got["pages_project_id"]) diff --git a/internal/mocks/source.go b/internal/mocks/source.go index c6cc6216..fe8c3198 100644 --- a/internal/mocks/source.go +++ b/internal/mocks/source.go @@ -6,9 +6,11 @@ package mocks import ( context "context" + reflect "reflect" + gomock "github.com/golang/mock/gomock" + domain "gitlab.com/gitlab-org/gitlab-pages/internal/domain" - reflect "reflect" ) // MockSource is a mock of Source interface @@ -37,7 +39,7 @@ func (m *MockSource) EXPECT() *MockSourceMockRecorder { // GetDomain mocks base method func (m *MockSource) GetDomain(arg0 context.Context, arg1 string) (*domain.Domain, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetDomain", arg0, arg1) + ret := m.ctrl.Call(m, "FromRequest", arg0, arg1) ret0, _ := ret[0].(*domain.Domain) ret1, _ := ret[1].(error) return ret0, ret1 @@ -46,5 +48,5 @@ func (m *MockSource) GetDomain(arg0 context.Context, arg1 string) (*domain.Domai // GetDomain indicates an expected call of GetDomain func (mr *MockSourceMockRecorder) GetDomain(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDomain", reflect.TypeOf((*MockSource)(nil).GetDomain), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FromRequest", reflect.TypeOf((*MockSource)(nil).GetDomain), arg0, arg1) } diff --git a/internal/request/request.go b/internal/request/request.go index 77cc4a76..f98b0819 100644 --- a/internal/request/request.go +++ b/internal/request/request.go @@ -1,19 +1,11 @@ package request import ( - "context" "net" "net/http" - - "gitlab.com/gitlab-org/gitlab-pages/internal/domain" ) -type ctxKey string - const ( - ctxHostKey ctxKey = "host" - ctxDomainKey ctxKey = "domain" - // SchemeHTTP name for the HTTP scheme SchemeHTTP = "http" // SchemeHTTPS name for the HTTPS scheme @@ -26,25 +18,6 @@ func IsHTTPS(r *http.Request) bool { return r.URL.Scheme == SchemeHTTPS } -// WithHostAndDomain saves host name and domain in the request's context -func WithHostAndDomain(r *http.Request, host string, domain *domain.Domain) *http.Request { - ctx := r.Context() - ctx = context.WithValue(ctx, ctxHostKey, host) - ctx = context.WithValue(ctx, ctxDomainKey, domain) - - return r.WithContext(ctx) -} - -// GetHost extracts the host from request's context -func GetHost(r *http.Request) string { - return r.Context().Value(ctxHostKey).(string) -} - -// GetDomain extracts the domain from request's context -func GetDomain(r *http.Request) *domain.Domain { - return r.Context().Value(ctxDomainKey).(*domain.Domain) -} - // GetHostWithoutPort returns a host without the port. The host(:port) comes // from a Host: header if it is provided, otherwise it is a server name. func GetHostWithoutPort(r *http.Request) string { diff --git a/internal/request/request_test.go b/internal/request/request_test.go index a9ffb223..0455fc55 100644 --- a/internal/request/request_test.go +++ b/internal/request/request_test.go @@ -6,8 +6,6 @@ import ( "testing" "github.com/stretchr/testify/require" - - "gitlab.com/gitlab-org/gitlab-pages/internal/domain" ) func TestIsHTTPS(t *testing.T) { @@ -26,48 +24,6 @@ func TestIsHTTPS(t *testing.T) { }) } -func TestPanics(t *testing.T) { - r, err := http.NewRequest("GET", "/", nil) - require.NoError(t, err) - - require.Panics(t, func() { - GetHost(r) - }) - - require.Panics(t, func() { - GetDomain(r) - }) -} - -func TestWithHostAndDomain(t *testing.T) { - tests := []struct { - name string - host string - domain *domain.Domain - }{ - { - name: "values", - host: "gitlab.com", - domain: &domain.Domain{}, - }, - { - name: "no_host", - host: "", - domain: &domain.Domain{}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r, err := http.NewRequest("GET", "/", nil) - require.NoError(t, err) - - r = WithHostAndDomain(r, tt.host, tt.domain) - require.Exactly(t, tt.domain, GetDomain(r)) - require.Equal(t, tt.host, GetHost(r)) - }) - } -} - func TestGetHostWithoutPort(t *testing.T) { t.Run("when port component is provided", func(t *testing.T) { request := httptest.NewRequest("GET", "https://example.com:443", nil) diff --git a/internal/routing/middleware.go b/internal/routing/middleware.go index 5f065c61..de34ec21 100644 --- a/internal/routing/middleware.go +++ b/internal/routing/middleware.go @@ -27,7 +27,7 @@ func NewMiddleware(handler http.Handler, s source.Source) http.Handler { return } - r = request.WithHostAndDomain(r, host, d) + r = domain.ReqWithHostAndDomain(r, host, d) handler.ServeHTTP(w, r) }) |