diff options
author | Andrew Newdigate <andrew@gitlab.com> | 2019-07-24 01:12:22 +0300 |
---|---|---|
committer | Andrew Newdigate <andrew@gitlab.com> | 2019-08-22 13:50:39 +0300 |
commit | c78ef2c684675b7b0685a78958860558149fae25 (patch) | |
tree | 9e7a31d5ecaa3ea6dfc1ca60bd42552df9be6621 /internal/request | |
parent | 654c183c4b06c7deb3b947c7f557fe4a48f2e218 (diff) |
Refactor to use pluggable http.Handler middlewaresan-use-middleware-handlers
Diffstat (limited to 'internal/request')
-rw-r--r-- | internal/request/request.go | 25 | ||||
-rw-r--r-- | internal/request/request_test.go | 52 |
2 files changed, 72 insertions, 5 deletions
diff --git a/internal/request/request.go b/internal/request/request.go index dad6af3d..730eb527 100644 --- a/internal/request/request.go +++ b/internal/request/request.go @@ -3,12 +3,16 @@ package request import ( "context" "net/http" + + "gitlab.com/gitlab-org/gitlab-pages/internal/domain" ) type ctxKey string const ( - ctxHTTPSKey ctxKey = "https" + ctxHTTPSKey ctxKey = "https" + ctxHostKey ctxKey = "host" + ctxDomainKey ctxKey = "domain" ) // WithHTTPSFlag saves https flag in request's context @@ -22,3 +26,22 @@ func WithHTTPSFlag(r *http.Request, https bool) *http.Request { func IsHTTPS(r *http.Request) bool { return r.Context().Value(ctxHTTPSKey).(bool) } + +// WithHostAndDomain saves host name and domain in the request's context +func WithHostAndDomain(r *http.Request, host string, domain *domain.D) *http.Request { + ctx := r.Context() + ctx = context.WithValue(ctx, ctxHostKey, host) + ctx = context.WithValue(ctx, ctxDomainKey, domain) + + return r.WithContext(ctx) +} + +// GetHost extracts the host from request's context +func GetHost(r *http.Request) string { + return r.Context().Value(ctxHostKey).(string) +} + +// GetDomain extracts the domain from request's context +func GetDomain(r *http.Request) *domain.D { + return r.Context().Value(ctxDomainKey).(*domain.D) +} diff --git a/internal/request/request_test.go b/internal/request/request_test.go index 1f47ee2e..97e40ee4 100644 --- a/internal/request/request_test.go +++ b/internal/request/request_test.go @@ -6,19 +6,63 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-pages/internal/domain" ) func TestWithHTTPSFlag(t *testing.T) { r, err := http.NewRequest("GET", "/", nil) require.NoError(t, err) - assert.Panics(t, func() { - IsHTTPS(r) - }) - httpsRequest := WithHTTPSFlag(r, true) assert.True(t, IsHTTPS(httpsRequest)) httpRequest := WithHTTPSFlag(r, false) assert.False(t, IsHTTPS(httpRequest)) } + +func TestPanics(t *testing.T) { + r, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + + assert.Panics(t, func() { + IsHTTPS(r) + }) + + assert.Panics(t, func() { + GetHost(r) + }) + + assert.Panics(t, func() { + GetDomain(r) + }) +} + +func TestWithHostAndDomain(t *testing.T) { + tests := []struct { + name string + host string + domain *domain.D + }{ + { + name: "values", + host: "gitlab.com", + domain: &domain.D{}, + }, + { + name: "no_host", + host: "", + domain: &domain.D{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r, err := http.NewRequest("GET", "/", nil) + require.NoError(t, err) + + r = WithHostAndDomain(r, tt.host, tt.domain) + assert.Exactly(t, tt.domain, GetDomain(r)) + assert.Equal(t, tt.host, GetHost(r)) + }) + } +} |