Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-pages.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJaime Martinez <jmartinez@gitlab.com>2021-10-25 03:50:40 +0300
committerJaime Martinez <jmartinez@gitlab.com>2021-10-25 03:59:58 +0300
commitabd1612d0abccbbb04188df15b008118d3d62115 (patch)
treea68586b36b8cae8f23270fc1bdbb74a69e8f7131
parent1a1adbaf869ed651aae3671c507f726a64b1d7ca (diff)
refactor: remove domain from requestmove-domain-out-of-req
-rw-r--r--app.go8
-rw-r--r--internal/acme/middleware.go6
-rw-r--r--internal/auth/middleware.go4
-rw-r--r--internal/domain/logging.go23
-rw-r--r--internal/domain/logging_test.go9
-rw-r--r--internal/domain/request.go32
-rw-r--r--internal/domain/request_test.go50
-rw-r--r--internal/logging/logging.go23
-rw-r--r--internal/logging/logging_test.go4
-rw-r--r--internal/mocks/source.go8
-rw-r--r--internal/request/request.go27
-rw-r--r--internal/request/request_test.go44
-rw-r--r--internal/routing/middleware.go2
13 files changed, 124 insertions, 116 deletions
diff --git a/app.go b/app.go
index eba3c690..cddf8d49 100644
--- a/app.go
+++ b/app.go
@@ -174,8 +174,8 @@ func (a *theApp) healthCheckMiddleware(handler http.Handler) (http.Handler, erro
// not static-content responses
func (a *theApp) auxiliaryMiddleware(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- host := request.GetHost(r)
- domain := request.GetDomain(r)
+ host := domain.GetHost(r)
+ domain := domain.FromRequest(r)
https := request.IsHTTPS(r)
if a.tryAuxiliaryHandlers(w, r, https, host, domain) {
@@ -193,7 +193,7 @@ func (a *theApp) serveFileOrNotFoundHandler() http.Handler {
start := time.Now()
defer metrics.ServingTime.Observe(time.Since(start).Seconds())
- domain := request.GetDomain(r)
+ domain := domain.FromRequest(r)
fileServed := domain.ServeFileHTTP(w, r)
if !fileServed {
@@ -252,7 +252,7 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) {
handler = a.auxiliaryMiddleware(handler)
handler = a.Auth.AuthenticationMiddleware(handler, a.source)
handler = a.AcmeMiddleware.AcmeMiddleware(handler)
- handler, err := logging.AccessLogger(handler, a.config.Log.Format)
+ handler, err := logging.BasicAccessLogger(handler, a.config.Log.Format, domain.LogFields)
if err != nil {
return nil, err
}
diff --git a/internal/acme/middleware.go b/internal/acme/middleware.go
index fa9d696c..faa2a017 100644
--- a/internal/acme/middleware.go
+++ b/internal/acme/middleware.go
@@ -2,14 +2,14 @@ package acme
import (
"net/http"
-
- "gitlab.com/gitlab-org/gitlab-pages/internal/request"
+ // TODO: break this dependency too
+ d "gitlab.com/gitlab-org/gitlab-pages/internal/domain"
)
// AcmeMiddleware handles ACME challenges
func (m *Middleware) AcmeMiddleware(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- domain := request.GetDomain(r)
+ domain := d.FromRequest(r)
if m.ServeAcmeChallenges(w, r, domain) {
return
diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go
index a9cb6c56..01559c10 100644
--- a/internal/auth/middleware.go
+++ b/internal/auth/middleware.go
@@ -3,7 +3,7 @@ package auth
import (
"net/http"
- "gitlab.com/gitlab-org/gitlab-pages/internal/request"
+ d "gitlab.com/gitlab-org/gitlab-pages/internal/domain"
"gitlab.com/gitlab-org/gitlab-pages/internal/source"
)
@@ -21,7 +21,7 @@ func (a *Auth) AuthenticationMiddleware(handler http.Handler, s source.Source) h
// AuthorizationMiddleware handles authorization
func (a *Auth) AuthorizationMiddleware(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- domain := request.GetDomain(r)
+ domain := d.FromRequest(r)
// Only for projects that have access control enabled
if domain.IsAccessControlEnabled(r) {
diff --git a/internal/domain/logging.go b/internal/domain/logging.go
index ae2b44e3..2ecddf6c 100644
--- a/internal/domain/logging.go
+++ b/internal/domain/logging.go
@@ -6,19 +6,26 @@ import (
"gitlab.com/gitlab-org/labkit/log"
)
-func (d *Domain) LogFields(r *http.Request) log.Fields {
+func LogFields(r *http.Request) log.Fields {
+ logFields := log.Fields{
+ "pages_https": r.URL.Scheme == "https",
+ "pages_host": GetHost(r),
+ }
+
+ d := FromRequest(r)
if d == nil {
- return log.Fields{}
+ return logFields
}
lp, err := d.GetLookupPath(r)
if err != nil {
- return log.Fields{"error": err.Error()}
+ logFields["error"] = err.Error()
+ return logFields
}
- return log.Fields{
- "pages_project_serving_type": lp.ServingType,
- "pages_project_prefix": lp.Prefix,
- "pages_project_id": lp.ProjectID,
- }
+ logFields["pages_project_serving_type"] = lp.ServingType
+ logFields["pages_project_prefix"] = lp.Prefix
+ logFields["pages_project_id"] = lp.ProjectID
+
+ return logFields
}
diff --git a/internal/domain/logging_test.go b/internal/domain/logging_test.go
index 2a1902c2..b736758c 100644
--- a/internal/domain/logging_test.go
+++ b/internal/domain/logging_test.go
@@ -40,17 +40,19 @@ func TestDomainLogFields(t *testing.T) {
"nil_domain_returns_empty_fields": {
domain: nil,
host: "gitlab.io",
- expectedFields: log.Fields{},
+ expectedFields: log.Fields{"pages_https": false, "pages_host": "gitlab.io"},
},
"unresolved_domain_returns_error": {
domain: New("githost.io", "", "", &resolver{err: ErrDomainDoesNotExist}),
host: "gitlab.io",
- expectedFields: log.Fields{"error": ErrDomainDoesNotExist.Error()},
+ expectedFields: log.Fields{"error": ErrDomainDoesNotExist.Error(), "pages_https": false, "pages_host": "gitlab.io"},
},
"domain_with_fields": {
domain: domainWithResolver,
host: "gitlab.io",
expectedFields: log.Fields{
+ "pages_https": false,
+ "pages_host": "gitlab.io",
"pages_project_id": uint64(100),
"pages_project_prefix": "/prefix",
"pages_project_serving_type": "file",
@@ -62,7 +64,8 @@ func TestDomainLogFields(t *testing.T) {
r, err := http.NewRequest("GET", "/", nil)
require.NoError(t, err)
- require.Equal(t, tt.expectedFields, tt.domain.LogFields(r))
+ r = ReqWithHostAndDomain(r, tt.host, tt.domain)
+ require.Equal(t, tt.expectedFields, LogFields(r))
})
}
}
diff --git a/internal/domain/request.go b/internal/domain/request.go
new file mode 100644
index 00000000..a10218ed
--- /dev/null
+++ b/internal/domain/request.go
@@ -0,0 +1,32 @@
+package domain
+
+import (
+ "context"
+ "net/http"
+)
+
+type ctxKey string
+
+const (
+ ctxHostKey ctxKey = "host"
+ ctxDomainKey ctxKey = "domain"
+)
+
+// ReqWithHostAndDomain saves host name and domain in the request's context
+func ReqWithHostAndDomain(r *http.Request, host string, domain *Domain) *http.Request {
+ ctx := r.Context()
+ ctx = context.WithValue(ctx, ctxHostKey, host)
+ ctx = context.WithValue(ctx, ctxDomainKey, domain)
+
+ return r.WithContext(ctx)
+}
+
+// GetHost extracts the host from request's context
+func GetHost(r *http.Request) string {
+ return r.Context().Value(ctxHostKey).(string)
+}
+
+// FromRequest extracts the domain from request's context
+func FromRequest(r *http.Request) *Domain {
+ return r.Context().Value(ctxDomainKey).(*Domain)
+}
diff --git a/internal/domain/request_test.go b/internal/domain/request_test.go
new file mode 100644
index 00000000..c261d1f7
--- /dev/null
+++ b/internal/domain/request_test.go
@@ -0,0 +1,50 @@
+package domain
+
+import (
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestPanics(t *testing.T) {
+ r, err := http.NewRequest("GET", "/", nil)
+ require.NoError(t, err)
+
+ require.Panics(t, func() {
+ GetHost(r)
+ })
+
+ require.Panics(t, func() {
+ FromRequest(r)
+ })
+}
+
+func TestWithHostAndDomain(t *testing.T) {
+ tests := []struct {
+ name string
+ host string
+ domain *Domain
+ }{
+ {
+ name: "values",
+ host: "gitlab.com",
+ domain: &Domain{},
+ },
+ {
+ name: "no_host",
+ host: "",
+ domain: &Domain{},
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ r, err := http.NewRequest("GET", "/", nil)
+ require.NoError(t, err)
+
+ r = ReqWithHostAndDomain(r, tt.host, tt.domain)
+ require.Exactly(t, tt.domain, FromRequest(r))
+ require.Equal(t, tt.host, GetHost(r))
+ })
+ }
+}
diff --git a/internal/logging/logging.go b/internal/logging/logging.go
index f485b20b..d2d758db 100644
--- a/internal/logging/logging.go
+++ b/internal/logging/logging.go
@@ -51,21 +51,6 @@ func getAccessLogger(format string) (*logrus.Logger, error) {
return accessLogger, nil
}
-// getExtraLogFields is used to inject additional fields into the
-// HTTP access logger middleware.
-func getExtraLogFields(r *http.Request) log.Fields {
- logFields := log.Fields{
- "pages_https": request.IsHTTPS(r),
- "pages_host": request.GetHost(r),
- }
-
- for name, value := range request.GetDomain(r).LogFields(r) {
- logFields[name] = value
- }
-
- return logFields
-}
-
// BasicAccessLogger configures the GitLab pages basic HTTP access logger middleware
func BasicAccessLogger(handler http.Handler, format string, extraFields log.ExtraFieldsGeneratorFunc) (http.Handler, error) {
accessLogger, err := getAccessLogger(format)
@@ -98,10 +83,10 @@ func enrichExtraFields(extraFields log.ExtraFieldsGeneratorFunc) log.ExtraFields
}
}
-// AccessLogger configures the GitLab pages HTTP access logger middleware with extra log fields
-func AccessLogger(handler http.Handler, format string) (http.Handler, error) {
- return BasicAccessLogger(handler, format, getExtraLogFields)
-}
+//// AccessLogger configures the GitLab pages HTTP access logger middleware with extra log fields
+//func AccessLogger(handler http.Handler, format string) (http.Handler, error) {
+// return BasicAccessLogger(handler, format, getExtraLogFields)
+//}
// LogRequest will inject request host and path to the logged messages
func LogRequest(r *http.Request) *logrus.Entry {
diff --git a/internal/logging/logging_test.go b/internal/logging/logging_test.go
index df8c3013..fe157e88 100644
--- a/internal/logging/logging_test.go
+++ b/internal/logging/logging_test.go
@@ -72,9 +72,9 @@ func TestGetExtraLogFields(t *testing.T) {
require.NoError(t, err)
req.URL.Scheme = tt.scheme
- req = request.WithHostAndDomain(req, tt.host, domainWithResolver)
+ req = domain.ReqWithHostAndDomain(req, tt.host, domainWithResolver)
- got := getExtraLogFields(req)
+ got := domain.LogFields(req)
require.Equal(t, tt.expectedHTTPS, got["pages_https"])
require.Equal(t, tt.expectedHost, got["pages_host"])
require.Equal(t, tt.expectedProjectID, got["pages_project_id"])
diff --git a/internal/mocks/source.go b/internal/mocks/source.go
index c6cc6216..fe8c3198 100644
--- a/internal/mocks/source.go
+++ b/internal/mocks/source.go
@@ -6,9 +6,11 @@ package mocks
import (
context "context"
+ reflect "reflect"
+
gomock "github.com/golang/mock/gomock"
+
domain "gitlab.com/gitlab-org/gitlab-pages/internal/domain"
- reflect "reflect"
)
// MockSource is a mock of Source interface
@@ -37,7 +39,7 @@ func (m *MockSource) EXPECT() *MockSourceMockRecorder {
// GetDomain mocks base method
func (m *MockSource) GetDomain(arg0 context.Context, arg1 string) (*domain.Domain, error) {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "GetDomain", arg0, arg1)
+ ret := m.ctrl.Call(m, "FromRequest", arg0, arg1)
ret0, _ := ret[0].(*domain.Domain)
ret1, _ := ret[1].(error)
return ret0, ret1
@@ -46,5 +48,5 @@ func (m *MockSource) GetDomain(arg0 context.Context, arg1 string) (*domain.Domai
// GetDomain indicates an expected call of GetDomain
func (mr *MockSourceMockRecorder) GetDomain(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDomain", reflect.TypeOf((*MockSource)(nil).GetDomain), arg0, arg1)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FromRequest", reflect.TypeOf((*MockSource)(nil).GetDomain), arg0, arg1)
}
diff --git a/internal/request/request.go b/internal/request/request.go
index 77cc4a76..f98b0819 100644
--- a/internal/request/request.go
+++ b/internal/request/request.go
@@ -1,19 +1,11 @@
package request
import (
- "context"
"net"
"net/http"
-
- "gitlab.com/gitlab-org/gitlab-pages/internal/domain"
)
-type ctxKey string
-
const (
- ctxHostKey ctxKey = "host"
- ctxDomainKey ctxKey = "domain"
-
// SchemeHTTP name for the HTTP scheme
SchemeHTTP = "http"
// SchemeHTTPS name for the HTTPS scheme
@@ -26,25 +18,6 @@ func IsHTTPS(r *http.Request) bool {
return r.URL.Scheme == SchemeHTTPS
}
-// WithHostAndDomain saves host name and domain in the request's context
-func WithHostAndDomain(r *http.Request, host string, domain *domain.Domain) *http.Request {
- ctx := r.Context()
- ctx = context.WithValue(ctx, ctxHostKey, host)
- ctx = context.WithValue(ctx, ctxDomainKey, domain)
-
- return r.WithContext(ctx)
-}
-
-// GetHost extracts the host from request's context
-func GetHost(r *http.Request) string {
- return r.Context().Value(ctxHostKey).(string)
-}
-
-// GetDomain extracts the domain from request's context
-func GetDomain(r *http.Request) *domain.Domain {
- return r.Context().Value(ctxDomainKey).(*domain.Domain)
-}
-
// GetHostWithoutPort returns a host without the port. The host(:port) comes
// from a Host: header if it is provided, otherwise it is a server name.
func GetHostWithoutPort(r *http.Request) string {
diff --git a/internal/request/request_test.go b/internal/request/request_test.go
index a9ffb223..0455fc55 100644
--- a/internal/request/request_test.go
+++ b/internal/request/request_test.go
@@ -6,8 +6,6 @@ import (
"testing"
"github.com/stretchr/testify/require"
-
- "gitlab.com/gitlab-org/gitlab-pages/internal/domain"
)
func TestIsHTTPS(t *testing.T) {
@@ -26,48 +24,6 @@ func TestIsHTTPS(t *testing.T) {
})
}
-func TestPanics(t *testing.T) {
- r, err := http.NewRequest("GET", "/", nil)
- require.NoError(t, err)
-
- require.Panics(t, func() {
- GetHost(r)
- })
-
- require.Panics(t, func() {
- GetDomain(r)
- })
-}
-
-func TestWithHostAndDomain(t *testing.T) {
- tests := []struct {
- name string
- host string
- domain *domain.Domain
- }{
- {
- name: "values",
- host: "gitlab.com",
- domain: &domain.Domain{},
- },
- {
- name: "no_host",
- host: "",
- domain: &domain.Domain{},
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- r, err := http.NewRequest("GET", "/", nil)
- require.NoError(t, err)
-
- r = WithHostAndDomain(r, tt.host, tt.domain)
- require.Exactly(t, tt.domain, GetDomain(r))
- require.Equal(t, tt.host, GetHost(r))
- })
- }
-}
-
func TestGetHostWithoutPort(t *testing.T) {
t.Run("when port component is provided", func(t *testing.T) {
request := httptest.NewRequest("GET", "https://example.com:443", nil)
diff --git a/internal/routing/middleware.go b/internal/routing/middleware.go
index 5f065c61..de34ec21 100644
--- a/internal/routing/middleware.go
+++ b/internal/routing/middleware.go
@@ -27,7 +27,7 @@ func NewMiddleware(handler http.Handler, s source.Source) http.Handler {
return
}
- r = request.WithHostAndDomain(r, host, d)
+ r = domain.ReqWithHostAndDomain(r, host, d)
handler.ServeHTTP(w, r)
})