diff options
author | Vladimir Shushlin <v.shushlin@gmail.com> | 2022-02-15 18:29:35 +0300 |
---|---|---|
committer | Vladimir Shushlin <v.shushlin@gmail.com> | 2022-02-17 13:22:44 +0300 |
commit | 231fd81805f478b1180320702a51c884b5bab79b (patch) | |
tree | 48f1588617f782a940df028288ce703d996240d4 | |
parent | ae8fbc5bf6725e6fa5b5a9dec6e4ac5016ec3c6c (diff) |
WIPreject-tls
-rw-r--r-- | app.go | 41 | ||||
-rw-r--r-- | internal/config/config.go | 20 | ||||
-rw-r--r-- | internal/config/flags.go | 2 | ||||
-rw-r--r-- | internal/feature/feature.go | 6 | ||||
-rw-r--r-- | internal/ratelimiter/middleware.go | 5 | ||||
-rw-r--r-- | internal/ratelimiter/ratelimiter.go | 5 | ||||
-rw-r--r-- | internal/ratelimiter/tls.go | 62 | ||||
-rw-r--r-- | metrics/metrics.go | 28 | ||||
-rw-r--r-- | test/acceptance/artifacts_test.go | 5 | ||||
-rw-r--r-- | test/acceptance/auth_test.go | 2 | ||||
-rw-r--r-- | test/acceptance/helpers_test.go | 65 | ||||
-rw-r--r-- | test/acceptance/ratelimiter_test.go | 27 | ||||
-rw-r--r-- | test/acceptance/serving_test.go | 8 | ||||
-rw-r--r-- | test/acceptance/tls_test.go | 8 | ||||
-rw-r--r-- | test/acceptance/unknown_http_method_test.go | 2 |
15 files changed, 218 insertions, 68 deletions
@@ -13,6 +13,9 @@ import ( "syscall" "time" + "gitlab.com/gitlab-org/gitlab-pages/internal/feature" + "gitlab.com/gitlab-org/gitlab-pages/internal/ratelimiter" + ghandlers "github.com/gorilla/handlers" "github.com/hashicorp/go-multierror" "github.com/rs/cors" @@ -51,6 +54,7 @@ var ( type theApp struct { config *cfg.Config source source.Source + tlsConfig *cryptotls.Config Artifact *artifact.Artifact Auth *auth.Auth Handlers *handlers.Handlers @@ -62,7 +66,8 @@ func (a *theApp) isReady() bool { return true } -func (a *theApp) ServeTLS(ch *cryptotls.ClientHelloInfo) (*cryptotls.Certificate, error) { +func (a *theApp) GetCertificate(ch *cryptotls.ClientHelloInfo) (*cryptotls.Certificate, error) { + log.Info("GetCertificate called") if ch.ServerName == "" { return nil, nil } @@ -75,6 +80,31 @@ func (a *theApp) ServeTLS(ch *cryptotls.ClientHelloInfo) (*cryptotls.Certificate return nil, nil } +func (a *theApp) getTLSConfig() (*cryptotls.Config, error) { + if a.tlsConfig != nil { + return a.tlsConfig, nil + } + TLSRateLimiter := ratelimiter.New( + "tls", + ratelimiter.WithCacheMaxSize(ratelimiter.DefaultSourceIPCacheSize), + ratelimiter.WithCachedEntriesMetric(metrics.RateLimitDomainTLSCachedEntries), + ratelimiter.WithCachedRequestsMetric(metrics.RateLimitDomainTLSCacheRequests), + ratelimiter.WithBlockedCountMetric(metrics.RateLimitDomainTLSBlockedCount), + ratelimiter.WithLimitPerSecond(a.config.RateLimit.DomainTLSLimitPerSecond), + ratelimiter.WithBurstSize(a.config.RateLimit.DomainTLSBurst), + ratelimiter.WithEnforce(feature.EnforceDomainTLSRateLimits.Enabled()), + ) + + getCertificate := TLSRateLimiter.GetCertificateMiddleware(a.GetCertificate) + + tlsConfig, err := tls.Create(a.config.General.RootCertificate, a.config.General.RootKey, getCertificate, + a.config.General.InsecureCiphers, a.config.TLS.MinVersion, a.config.TLS.MaxVersion) + + a.tlsConfig = tlsConfig + + return a.tlsConfig, err +} + func (a *theApp) redirectToHTTPS(w http.ResponseWriter, r *http.Request, statusCode int) { u := *r.URL u.Scheme = request.SchemeHTTPS @@ -306,7 +336,7 @@ func (a *theApp) Run() { // Listen for HTTPS for _, addr := range a.config.ListenHTTPSStrings.Split() { - tlsConfig, err := a.TLSConfig() + tlsConfig, err := a.getTLSConfig() if err != nil { log.WithError(err).Fatal("Unable to retrieve tls config") } @@ -334,7 +364,7 @@ func (a *theApp) Run() { // Listen for HTTPS PROXYv2 requests for _, addr := range a.config.ListenHTTPSProxyv2Strings.Split() { - tlsConfig, err := a.TLSConfig() + tlsConfig, err := a.getTLSConfig() if err != nil { log.WithError(err).Fatal("Unable to retrieve tls config") } @@ -478,11 +508,6 @@ func fatal(err error, message string) { log.WithError(err).Fatal(message) } -func (a *theApp) TLSConfig() (*cryptotls.Config, error) { - return tls.Create(a.config.General.RootCertificate, a.config.General.RootKey, a.ServeTLS, - a.config.General.InsecureCiphers, a.config.TLS.MinVersion, a.config.TLS.MaxVersion) -} - // handlePanicMiddleware logs and captures the recover() information from any panic func handlePanicMiddleware(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/internal/config/config.go b/internal/config/config.go index 3bb7b126..bcd0cfc7 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -58,10 +58,12 @@ type General struct { // RateLimit config struct type RateLimit struct { - SourceIPLimitPerSecond float64 - SourceIPBurst int - DomainLimitPerSecond float64 - DomainBurst int + SourceIPLimitPerSecond float64 + SourceIPBurst int + DomainLimitPerSecond float64 + DomainBurst int + DomainTLSLimitPerSecond float64 + DomainTLSBurst int } // ArtifactsServer groups settings related to configuring Artifacts @@ -179,10 +181,12 @@ func loadConfig() (*Config, error) { ShowVersion: *showVersion, }, RateLimit: RateLimit{ - SourceIPLimitPerSecond: *rateLimitSourceIP, - SourceIPBurst: *rateLimitSourceIPBurst, - DomainLimitPerSecond: *rateLimitDomain, - DomainBurst: *rateLimitDomainBurst, + SourceIPLimitPerSecond: *rateLimitSourceIP, + SourceIPBurst: *rateLimitSourceIPBurst, + DomainLimitPerSecond: *rateLimitDomain, + DomainBurst: *rateLimitDomainBurst, + DomainTLSLimitPerSecond: *rateLimitDomainTLS, + DomainTLSBurst: *rateLimitDomainTLSBurst, }, GitLab: GitLab{ ClientHTTPTimeout: *gitlabClientHTTPTimeout, diff --git a/internal/config/flags.go b/internal/config/flags.go index 93228827..bd40f362 100644 --- a/internal/config/flags.go +++ b/internal/config/flags.go @@ -19,6 +19,8 @@ var ( rateLimitSourceIPBurst = flag.Int("rate-limit-source-ip-burst", 100, "Rate limit per source IP maximum burst allowed per second") rateLimitDomain = flag.Float64("rate-limit-domain", 0.0, "Rate limit per domain in number of requests per second, 0 means is disabled") rateLimitDomainBurst = flag.Int("rate-limit-domain-burst", 100, "Rate limit per domain maximum burst allowed per second") + rateLimitDomainTLS = flag.Float64("rate-limit-domain-tls", 0.0, "Rate limit per domain in number new TLS connections per second, 0 means is disabled") + rateLimitDomainTLSBurst = flag.Int("rate-limit-domain-tls-burst", 100, "Rate limit per domain maximum burst of TLS connections allowed per second") artifactsServer = flag.String("artifacts-server", "", "API URL to proxy artifact requests to, e.g.: 'https://gitlab.com/api/v4'") artifactsServerTimeout = flag.Int("artifacts-server-timeout", 10, "Timeout (in seconds) for a proxied request to the artifacts server") pagesStatus = flag.String("pages-status", "", "The url path for a status page, e.g., /@status") diff --git a/internal/feature/feature.go b/internal/feature/feature.go index 81eef9a0..3c3f1ee3 100644 --- a/internal/feature/feature.go +++ b/internal/feature/feature.go @@ -19,6 +19,12 @@ var EnforceDomainRateLimits = Feature{ EnvVariable: "FF_ENFORCE_DOMAIN_RATE_LIMITS", } +// EnforceDomainTLSRateLimits enforces domain rate limits on establishing new TLS connections +// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/655 +var EnforceDomainTLSRateLimits = Feature{ + EnvVariable: "FF_ENFORCE_DOMAIN_TLS_RATE_LIMITS", +} + // RedirectsPlaceholders enables support for placeholders in redirects file // TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/620 var RedirectsPlaceholders = Feature{ diff --git a/internal/ratelimiter/middleware.go b/internal/ratelimiter/middleware.go index af7b0881..2faaac08 100644 --- a/internal/ratelimiter/middleware.go +++ b/internal/ratelimiter/middleware.go @@ -8,7 +8,6 @@ import ( "gitlab.com/gitlab-org/labkit/correlation" "gitlab.com/gitlab-org/labkit/log" - "gitlab.com/gitlab-org/gitlab-pages/internal/feature" "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/request" ) @@ -34,7 +33,7 @@ func (rl *RateLimiter) Middleware(handler http.Handler) http.Handler { rl.logRateLimitedRequest(r) if rl.blockedCount != nil { - rl.blockedCount.WithLabelValues(strconv.FormatBool(feature.EnforceIPRateLimits.Enabled())).Inc() + rl.blockedCount.WithLabelValues(strconv.FormatBool(rl.enforce)).Inc() } if rl.enforce { @@ -59,7 +58,7 @@ func (rl *RateLimiter) logRateLimitedRequest(r *http.Request) { "x_forwarded_proto": r.Header.Get(headerXForwardedProto), "x_forwarded_for": r.Header.Get(headerXForwardedFor), "gitlab_real_ip": r.Header.Get(headerGitLabRealIP), - "rate_limiter_enabled": feature.EnforceIPRateLimits.Enabled(), + "rate_limiter_enabled": rl.enforce, "rate_limiter_limit_per_second": rl.limitPerSecond, "rate_limiter_burst_size": rl.burstSize, }). // TODO: change to Debug with https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629 diff --git a/internal/ratelimiter/ratelimiter.go b/internal/ratelimiter/ratelimiter.go index feeb8cb4..24fc05fe 100644 --- a/internal/ratelimiter/ratelimiter.go +++ b/internal/ratelimiter/ratelimiter.go @@ -138,6 +138,11 @@ func (rl *RateLimiter) limiter(key string) *rate.Limiter { // requestAllowed checks if request is within the rate-limit func (rl *RateLimiter) requestAllowed(r *http.Request) bool { rateLimitedKey := rl.keyFunc(r) + + return rl.allowed(rateLimitedKey) +} + +func (rl *RateLimiter) allowed(rateLimitedKey string) bool { limiter := rl.limiter(rateLimitedKey) // AllowN allows us to use the rl.now function, so we can test this more easily. diff --git a/internal/ratelimiter/tls.go b/internal/ratelimiter/tls.go new file mode 100644 index 00000000..e8d6dc72 --- /dev/null +++ b/internal/ratelimiter/tls.go @@ -0,0 +1,62 @@ +package ratelimiter + +import ( + "crypto/tls" + "errors" + "net" + "strconv" + + "github.com/sirupsen/logrus" + "gitlab.com/gitlab-org/labkit/log" + + tlsconfig "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls" +) + +var TLSRateLimitedError = errors.New("TLS connection is being rate-limited") + +func (rl *RateLimiter) GetCertificateMiddleware(getCertificate tlsconfig.GetCertificateFunc) tlsconfig.GetCertificateFunc { + log.Info("GetCertificateMiddleware set") + return func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { + log.WithFields(logrus.Fields{ + "server_name": hi.ServerName, + }).Info("GetCertificateMiddleware called") + + return getCertificate(hi) + + if rl.allowed(hi.ServerName) { + return getCertificate(hi) + } + + rl.logRateLimitedTLS(hi) + + if rl.blockedCount != nil { + rl.blockedCount.WithLabelValues(strconv.FormatBool(rl.enforce)).Inc() + } + + if !rl.enforce { + return getCertificate(hi) + } + + return nil, TLSRateLimitedError + } +} + +func (rl *RateLimiter) logRateLimitedTLS(hi *tls.ClientHelloInfo) { + log.WithFields(logrus.Fields{ + "rate_limiter_name": rl.name, + "source_ip": getRemoteAddrFromHelloInfo(hi), + "req_host": hi.ServerName, + "rate_limiter_limit_per_second": rl.limitPerSecond, + "rate_limiter_burst_size": rl.burstSize, + }).Info("TLS connection rate-limited") +} + +func getRemoteAddrFromHelloInfo(hi *tls.ClientHelloInfo) string { + remoteAddr := hi.Conn.RemoteAddr().String() + remoteAddr, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + return remoteAddr + } + + return remoteAddr +} diff --git a/metrics/metrics.go b/metrics/metrics.go index e0e4ab29..4954e6d1 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -248,6 +248,34 @@ var ( }, []string{"enforced"}, ) + + // RateLimitDomainTLSCacheRequests is the number of cache hits/misses + RateLimitDomainTLSCacheRequests = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_pages_rate_limit_domain_tls_cache_requests", + Help: "The number of source_ip cache hits/misses in the rate limiter", + }, + []string{"op", "cache"}, + ) + + // RateLimitDomainTLSCachedEntries is the number of entries in the cache + RateLimitDomainTLSCachedEntries = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "gitlab_pages_rate_limit_domain_tls_cached_entries", + Help: "The number of entries in the cache", + }, + []string{"op"}, + ) + + // RateLimitDomainTLSBlockedCount is the number of TLS connections that have been blocked by the + // domain TLS rate limiter + RateLimitDomainTLSBlockedCount = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "gitlab_pages_rate_limit_domain_tls_blocked_count", + Help: "The number of requests addresses that have been blocked by the domain TLS rate limiter", + }, + []string{"enforced"}, + ) ) // MustRegister collectors with the Prometheus client diff --git a/test/acceptance/artifacts_test.go b/test/acceptance/artifacts_test.go index f087581c..653e31cf 100644 --- a/test/acceptance/artifacts_test.go +++ b/test/acceptance/artifacts_test.go @@ -14,9 +14,6 @@ import ( ) func TestArtifactProxyRequest(t *testing.T) { - transport := (TestHTTPSClient.Transport).(*http.Transport).Clone() - transport.ResponseHeaderTimeout = 5 * time.Second - content := "<!DOCTYPE html><html><head><title>Title of the document</title></head><body></body></html>" contentLength := int64(len(content)) testServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -150,8 +147,6 @@ func TestArtifactProxyRequest(t *testing.T) { } func TestPrivateArtifactProxyRequest(t *testing.T) { - setupTransport(t) - testServer := NewGitlabUnstartedServerStub(t, &stubOpts{}) keyFile, certFile := CreateHTTPSFixtureFiles(t) diff --git a/test/acceptance/auth_test.go b/test/acceptance/auth_test.go index 18b73161..c786ccf4 100644 --- a/test/acceptance/auth_test.go +++ b/test/acceptance/auth_test.go @@ -421,8 +421,6 @@ func TestAccessControlProject404DoesNotRedirect(t *testing.T) { type runPagesFunc func(t *testing.T, listeners []ListenSpec, sslCertFile string) func testAccessControl(t *testing.T, runPages runPagesFunc) { - setupTransport(t) - _, certFile := CreateHTTPSFixtureFiles(t) tests := map[string]struct { diff --git a/test/acceptance/helpers_test.go b/test/acceptance/helpers_test.go index 62a5e344..dad8b36e 100644 --- a/test/acceptance/helpers_test.go +++ b/test/acceptance/helpers_test.go @@ -34,14 +34,6 @@ import ( // The HTTPS certificate isn't signed by anyone. This http client is set up // so it can talk to servers using it. var ( - // The HTTPS certificate isn't signed by anyone. This http client is set up - // so it can talk to servers using it. - TestHTTPSClient = &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{RootCAs: TestCertPool}, - }, - } - // Use HTTP with a very short timeout to repeatedly check for the server to be // up. Again, ignore HTTP QuickTimeoutHTTPSClient = &http.Client{ @@ -151,15 +143,40 @@ func supportedListeners() []ListenSpec { return listeners } -func (l ListenSpec) URL(suffix string) string { - scheme := request.SchemeHTTP +func (l ListenSpec) Scheme() string { if l.Type == request.SchemeHTTPS || l.Type == "https-proxyv2" { - scheme = request.SchemeHTTPS + return request.SchemeHTTPS } + return request.SchemeHTTP +} + +func (l ListenSpec) URL(host string, suffix string) string { suffix = strings.TrimPrefix(suffix, "/") - return fmt.Sprintf("%s://%s/%s", scheme, l.JoinHostPort(), suffix) + if host == "" { + host = l.Host + } + + return fmt.Sprintf("%s://%s/%s", l.Scheme(), net.JoinHostPort(host, l.Port), suffix) +} + +func (l ListenSpec) Client() *http.Client { + if l.Type == "https-proxyv2" { + return TestProxyv2Client + } + + return &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{RootCAs: TestCertPool}, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + + return d.DialContext(ctx, network, l.JoinHostPort()) + }, + ResponseHeaderTimeout: 5 * time.Second, + }, + } } // Returns only once this spec points at a working TCP server @@ -172,7 +189,7 @@ func (l ListenSpec) WaitUntilRequestSucceeds(done chan struct{}) error { default: } - req, err := http.NewRequest("GET", l.URL("/"), nil) + req, err := http.NewRequest("GET", l.URL("", "/"), nil) if err != nil { return err } @@ -365,7 +382,7 @@ func GetPageFromListener(t *testing.T, spec ListenSpec, host, urlsuffix string) func GetPageFromListenerWithHeaders(t *testing.T, spec ListenSpec, host, urlSuffix string, header http.Header) (*http.Response, error) { t.Helper() - url := spec.URL(urlSuffix) + url := spec.URL(host, urlSuffix) req, err := http.NewRequest("GET", url, nil) if err != nil { return nil, err @@ -380,11 +397,7 @@ func GetPageFromListenerWithHeaders(t *testing.T, spec ListenSpec, host, urlSuff func DoPagesRequest(t *testing.T, spec ListenSpec, req *http.Request) (*http.Response, error) { t.Logf("curl -X %s -H'Host: %s' %s", req.Method, req.Host, req.URL) - if spec.Type == "https-proxyv2" { - return TestProxyv2Client.Do(req) - } - - return TestHTTPSClient.Do(req) + return spec.Client().Do(req) } func GetRedirectPage(t *testing.T, spec ListenSpec, host, urlsuffix string) (*http.Response, error) { @@ -410,7 +423,7 @@ func GetRedirectPageWithCookie(t *testing.T, spec ListenSpec, host, urlsuffix st } func GetRedirectPageWithHeaders(t *testing.T, spec ListenSpec, host, urlsuffix string, header http.Header) (*http.Response, error) { - url := spec.URL(urlsuffix) + url := spec.URL(host, urlsuffix) req, err := http.NewRequest("GET", url, nil) if err != nil { return nil, err @@ -423,7 +436,7 @@ func GetRedirectPageWithHeaders(t *testing.T, spec ListenSpec, host, urlsuffix s return TestProxyv2Client.Transport.RoundTrip(req) } - return TestHTTPSClient.Transport.RoundTrip(req) + return spec.Client().Transport.RoundTrip(req) } func ClientWithConfig(tlsConfig *tls.Config) (*http.Client, func()) { @@ -601,13 +614,3 @@ func copyFile(dest, src string) error { _, err = io.Copy(destFile, srcFile) return err } - -func setupTransport(t *testing.T) { - t.Helper() - - transport := (TestHTTPSClient.Transport).(*http.Transport) - defer func(t time.Duration) { - transport.ResponseHeaderTimeout = t - }(transport.ResponseHeaderTimeout) - transport.ResponseHeaderTimeout = 5 * time.Second -} diff --git a/test/acceptance/ratelimiter_test.go b/test/acceptance/ratelimiter_test.go index 02ba54f5..c4c87a10 100644 --- a/test/acceptance/ratelimiter_test.go +++ b/test/acceptance/ratelimiter_test.go @@ -77,7 +77,7 @@ func TestIPRateLimits(t *testing.T) { } } -func TestDomainateLimits(t *testing.T) { +func TestDomainRateLimits(t *testing.T) { testhelpers.StubFeatureFlagValue(t, feature.EnforceDomainRateLimits.EnvVariable, true) for name, tc := range ratelimitedListeners { @@ -112,6 +112,31 @@ func TestDomainateLimits(t *testing.T) { } } +func TestDomainTLSRateLimits(t *testing.T) { + testhelpers.StubFeatureFlagValue(t, feature.EnforceDomainTLSRateLimits.EnvVariable, true) + + rateLimit := 5 + + RunPagesProcess(t, + withListeners([]ListenSpec{httpsListener}), + withExtraArgument("rate-limit-domain-tls", fmt.Sprint(rateLimit)), + withExtraArgument("rate-limit-domain-tls-burst", fmt.Sprint(rateLimit)), + ) + + for i := 0; i < 10; i++ { + rsp, err := GetPageFromListener(t, httpsListener, "group.gitlab-example.com", "project/") + require.NoError(t, err) + require.NoError(t, rsp.Body.Close()) + + if i >= rateLimit { + require.Equal(t, http.StatusTooManyRequests, rsp.StatusCode, "group.gitlab-example.com request: %d failed", i) + //assertLogFound(t, logBuf, []string{"request hit rate limit", "\"source_ip\":\"" + tc.clientIP + "\""}) + } else { + require.Equal(t, http.StatusOK, rsp.StatusCode, "request: %d failed", i) + } + } +} + func assertLogFound(t *testing.T, logBuf *LogCaptureBuffer, expectedLogs []string) { t.Helper() diff --git a/test/acceptance/serving_test.go b/test/acceptance/serving_test.go index 8b01f5b2..563ade8f 100644 --- a/test/acceptance/serving_test.go +++ b/test/acceptance/serving_test.go @@ -170,7 +170,7 @@ func TestCORSWhenDisabled(t *testing.T) { for _, spec := range supportedListeners() { for _, method := range []string{http.MethodGet, http.MethodHead, http.MethodOptions} { - rsp := doCrossOriginRequest(t, spec, method, method, spec.URL("project/")) + rsp := doCrossOriginRequest(t, spec, method, method, spec.URL("", "project/")) defer rsp.Body.Close() require.Equal(t, http.StatusOK, rsp.StatusCode) @@ -218,7 +218,7 @@ func TestCORSAllowsMethod(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { for _, spec := range supportedListeners() { - rsp := doCrossOriginRequest(t, spec, tt.method, tt.method, spec.URL("project/")) + rsp := doCrossOriginRequest(t, spec, tt.method, tt.method, spec.URL("", "project/")) defer rsp.Body.Close() require.Equal(t, tt.expectedStatus, rsp.StatusCode) @@ -566,12 +566,10 @@ func TestSlowRequests(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), opts.delay/2) defer cancel() - url := httpListener.URL("/index.html") + url := httpListener.URL("group.gitlab-example.com", "/index.html") req, err := http.NewRequestWithContext(ctx, "GET", url, nil) require.NoError(t, err) - req.Host = "group.gitlab-example.com" - _, err = DoPagesRequest(t, httpListener, req) require.Error(t, err, "cancelling the context should trigger this error") diff --git a/test/acceptance/tls_test.go b/test/acceptance/tls_test.go index 3b4c3a5c..af8f10a2 100644 --- a/test/acceptance/tls_test.go +++ b/test/acceptance/tls_test.go @@ -25,7 +25,7 @@ func TestAcceptsSupportedCiphers(t *testing.T) { client, cleanup := ClientWithConfig(tlsConfig) defer cleanup() - rsp, err := client.Get(httpsListener.URL("/")) + rsp, err := client.Get(httpsListener.URL("", "/")) require.NoError(t, err) t.Cleanup(func() { @@ -51,7 +51,7 @@ func TestRejectsUnsupportedCiphers(t *testing.T) { client, cleanup := ClientWithConfig(tlsConfigWithInsecureCiphersOnly()) defer cleanup() - rsp, err := client.Get(httpsListener.URL("/")) + rsp, err := client.Get(httpsListener.URL("", "/")) require.Nil(t, rsp) require.Error(t, err) } @@ -65,7 +65,7 @@ func TestEnableInsecureCiphers(t *testing.T) { client, cleanup := ClientWithConfig(tlsConfigWithInsecureCiphersOnly()) defer cleanup() - rsp, err := client.Get(httpsListener.URL("/")) + rsp, err := client.Get(httpsListener.URL("", "/")) require.NoError(t, err) t.Cleanup(func() { rsp.Body.Close() @@ -107,7 +107,7 @@ func TestTLSVersions(t *testing.T) { client, cleanup := ClientWithConfig(tlsConfig) defer cleanup() - rsp, err := client.Get(httpsListener.URL("/")) + rsp, err := client.Get(httpsListener.URL("", "/")) if tc.expectError { require.Error(t, err) diff --git a/test/acceptance/unknown_http_method_test.go b/test/acceptance/unknown_http_method_test.go index dfe9c82f..01dfed01 100644 --- a/test/acceptance/unknown_http_method_test.go +++ b/test/acceptance/unknown_http_method_test.go @@ -12,7 +12,7 @@ func TestUnknownHTTPMethod(t *testing.T) { withListeners([]ListenSpec{httpListener}), ) - req, err := http.NewRequest("UNKNOWN", httpListener.URL(""), nil) + req, err := http.NewRequest("UNKNOWN", httpListener.URL("", ""), nil) require.NoError(t, err) req.Host = "" |