diff options
-rw-r--r-- | internal/httprange/http_reader.go | 1 | ||||
-rw-r--r-- | internal/httptransport/transport.go | 37 | ||||
-rw-r--r-- | internal/httptransport/transport_test.go | 59 | ||||
-rw-r--r-- | internal/source/gitlab/client/client.go | 4 |
4 files changed, 76 insertions, 25 deletions
diff --git a/internal/httprange/http_reader.go b/internal/httprange/http_reader.go index 5dc0f693..8e632212 100644 --- a/internal/httprange/http_reader.go +++ b/internal/httprange/http_reader.go @@ -62,6 +62,7 @@ var httpClient = &http.Client{ metrics.HTTPRangeTraceDuration, metrics.HTTPRangeRequestDuration, metrics.HTTPRangeRequestsTotal, + 15*time.Second, ), } diff --git a/internal/httptransport/transport.go b/internal/httptransport/transport.go index 7d388d81..fd5473c8 100644 --- a/internal/httptransport/transport.go +++ b/internal/httptransport/transport.go @@ -1,6 +1,7 @@ package httptransport import ( + "context" "crypto/tls" "crypto/x509" "net" @@ -26,11 +27,12 @@ var ( ) type meteredRoundTripper struct { - next http.RoundTripper - name string - tracer *prometheus.HistogramVec - durations *prometheus.HistogramVec - counter *prometheus.CounterVec + next http.RoundTripper + name string + tracer *prometheus.HistogramVec + durations *prometheus.HistogramVec + counter *prometheus.CounterVec + ttfbTimeout time.Duration } func newInternalTransport() *http.Transport { @@ -46,20 +48,21 @@ func newInternalTransport() *http.Transport { // Set more timeouts https://gitlab.com/gitlab-org/gitlab-pages/-/issues/495 TLSHandshakeTimeout: 10 * time.Second, ResponseHeaderTimeout: 15 * time.Second, - ExpectContinueTimeout: 1 * time.Second, + ExpectContinueTimeout: 15 * time.Second, } } // NewTransportWithMetrics will create a custom http.RoundTripper that can be used with an http.Client. // The RoundTripper will report metrics based on the collectors passed. func NewTransportWithMetrics(name string, tracerVec, durationsVec *prometheus. - HistogramVec, counterVec *prometheus.CounterVec) http.RoundTripper { + HistogramVec, counterVec *prometheus.CounterVec, ttfbTimeout time.Duration) http.RoundTripper { return &meteredRoundTripper{ - next: InternalTransport, - name: name, - tracer: tracerVec, - durations: durationsVec, - counter: counterVec, + next: InternalTransport, + name: name, + tracer: tracerVec, + durations: durationsVec, + counter: counterVec, + ttfbTimeout: ttfbTimeout, } } @@ -92,7 +95,15 @@ func loadPool() { func (mrt *meteredRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { start := time.Now() - r = r.WithContext(httptrace.WithClientTrace(r.Context(), mrt.newTracer(start))) + ctx := httptrace.WithClientTrace(r.Context(), mrt.newTracer(start)) + ctx, cancel := context.WithCancel(ctx) + + timer := time.AfterFunc(mrt.ttfbTimeout, func() { + cancel() + }) + defer timer.Stop() + + r = r.WithContext(ctx) resp, err := mrt.next.RoundTrip(r) if err != nil { diff --git a/internal/httptransport/transport_test.go b/internal/httptransport/transport_test.go index 5df0175a..869f21db 100644 --- a/internal/httptransport/transport_test.go +++ b/internal/httptransport/transport_test.go @@ -1,6 +1,8 @@ package httptransport import ( + "context" + "errors" "fmt" "net/http" "net/http/httptest" @@ -43,13 +45,7 @@ func Test_withRoundTripper(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - histVec := prometheus.NewHistogramVec(prometheus.HistogramOpts{ - Name: t.Name(), - }, []string{"status_code"}) - - counterVec := prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: t.Name(), - }, []string{"status_code"}) + histVec, counterVec := newTestMetrics(t) next := &mockRoundTripper{ res: &http.Response{ @@ -78,13 +74,54 @@ func Test_withRoundTripper(t *testing.T) { } } +func TestRoundTripTTFBTimeout(t *testing.T) { + histVec, counterVec := newTestMetrics(t) + + next := &mockRoundTripper{ + res: &http.Response{ + StatusCode: http.StatusOK, + }, + timeout: time.Millisecond, + err: nil, + } + + mtr := &meteredRoundTripper{next: next, durations: histVec, counter: counterVec, ttfbTimeout: time.Nanosecond} + req, err := http.NewRequest("GET", "https://gitlab.com", nil) + require.NoError(t, err) + + res, err := mtr.RoundTrip(req) + require.Nil(t, res) + require.True(t, errors.Is(err, context.Canceled), "context must have been canceled after ttfb timeout") +} + +func newTestMetrics(t *testing.T) (*prometheus.HistogramVec, *prometheus.CounterVec) { + t.Helper() + + histVec := prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Name: t.Name(), + }, []string{"status_code"}) + + counterVec := prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: t.Name(), + }, []string{"status_code"}) + + return histVec, counterVec +} + type mockRoundTripper struct { - res *http.Response - err error + res *http.Response + err error + timeout time.Duration } func (mrt *mockRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { - return mrt.res, mrt.err + time.Sleep(mrt.timeout) + select { + case <-r.Context().Done(): + return nil, r.Context().Err() + default: + return mrt.res, mrt.err + } } func TestInternalTransportShouldHaveCustomConnectionPoolSettings(t *testing.T) { @@ -94,5 +131,5 @@ func TestInternalTransportShouldHaveCustomConnectionPoolSettings(t *testing.T) { require.EqualValues(t, 90*time.Second, InternalTransport.IdleConnTimeout) require.EqualValues(t, 10*time.Second, InternalTransport.TLSHandshakeTimeout) require.EqualValues(t, 15*time.Second, InternalTransport.ResponseHeaderTimeout) - require.EqualValues(t, 1*time.Second, InternalTransport.ExpectContinueTimeout) + require.EqualValues(t, 15*time.Second, InternalTransport.ExpectContinueTimeout) } diff --git a/internal/source/gitlab/client/client.go b/internal/source/gitlab/client/client.go index 0e8235c0..2b80e832 100644 --- a/internal/source/gitlab/client/client.go +++ b/internal/source/gitlab/client/client.go @@ -60,7 +60,9 @@ func NewClient(baseURL string, secretKey []byte, connectionTimeout, jwtTokenExpi "gitlab_internal_api", metrics.DomainsSourceAPITraceDuration, metrics.DomainsSourceAPICallDuration, - metrics.DomainsSourceAPIReqTotal), + metrics.DomainsSourceAPIReqTotal, + 15*time.Second, + ), }, jwtTokenExpiry: jwtTokenExpiry, }, nil |