diff options
author | GitLab Bot <gitlab-bot@gitlab.com> | 2023-11-22 15:10:30 +0300 |
---|---|---|
committer | GitLab Bot <gitlab-bot@gitlab.com> | 2023-11-22 15:10:30 +0300 |
commit | 49203bfa3c7eb607a7561ae7da9b5c52aa49fd77 (patch) | |
tree | f33cd54ec9a45d69a3e58fe93735070d3b718913 /workhorse | |
parent | 3c9a2dd62025043448c9ea9a6df86422874ee4be (diff) |
Add latest changes from gitlab-org/gitlab@master
Diffstat (limited to 'workhorse')
-rw-r--r-- | workhorse/internal/dependencyproxy/dependencyproxy.go | 16 | ||||
-rw-r--r-- | workhorse/internal/dependencyproxy/dependencyproxy_test.go | 36 | ||||
-rw-r--r-- | workhorse/internal/sendurl/sendurl.go | 60 | ||||
-rw-r--r-- | workhorse/internal/sendurl/sendurl_test.go | 66 | ||||
-rw-r--r-- | workhorse/internal/transport/transport.go | 15 |
5 files changed, 166 insertions, 27 deletions
diff --git a/workhorse/internal/dependencyproxy/dependencyproxy.go b/workhorse/internal/dependencyproxy/dependencyproxy.go index dbea3c29aec..76def761668 100644 --- a/workhorse/internal/dependencyproxy/dependencyproxy.go +++ b/workhorse/internal/dependencyproxy/dependencyproxy.go @@ -5,6 +5,8 @@ import ( "fmt" "io" "net/http" + "os" + "time" "gitlab.com/gitlab-org/labkit/log" @@ -13,8 +15,12 @@ import ( "gitlab.com/gitlab-org/gitlab/workhorse/internal/transport" ) +const dialTimeout = 10 * time.Second +const responseHeaderTimeout = 10 * time.Second + +var httpTransport = transport.NewRestrictedTransport(transport.WithDialTimeout(dialTimeout), transport.WithResponseHeaderTimeout(responseHeaderTimeout)) var httpClient = &http.Client{ - Transport: transport.NewRestrictedTransport(), + Transport: httpTransport, } type Injector struct { @@ -70,7 +76,13 @@ func (p *Injector) Inject(w http.ResponseWriter, r *http.Request, sendData strin dependencyResponse, err := p.fetchUrl(r.Context(), params) if err != nil { - fail.Request(w, r, err) + status := http.StatusBadGateway + + if os.IsTimeout(err) { + status = http.StatusGatewayTimeout + } + + fail.Request(w, r, err, fail.WithStatus(status)) return } defer dependencyResponse.Body.Close() diff --git a/workhorse/internal/dependencyproxy/dependencyproxy_test.go b/workhorse/internal/dependencyproxy/dependencyproxy_test.go index 18d08ef162c..8a64517d578 100644 --- a/workhorse/internal/dependencyproxy/dependencyproxy_test.go +++ b/workhorse/internal/dependencyproxy/dependencyproxy_test.go @@ -10,11 +10,13 @@ import ( "strconv" "strings" "testing" + "time" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab/workhorse/internal/api" "gitlab.com/gitlab-org/gitlab/workhorse/internal/testhelper" + "gitlab.com/gitlab-org/gitlab/workhorse/internal/transport" "gitlab.com/gitlab-org/gitlab/workhorse/internal/upload" ) @@ -124,6 +126,7 @@ func TestSuccessfullRequest(t *testing.T) { w.Header().Set("Overridden-Header", overriddenHeader) w.Write(content) })) + defer originResourceServer.Close() uploadHandler := &fakeUploadHandler{ handler: func(w http.ResponseWriter, r *http.Request) { @@ -161,6 +164,7 @@ func TestValidUploadConfiguration(t *testing.T) { w.Header().Set("Content-Type", contentType) w.Write(content) })) + defer originResourceServer.Close() testCases := []struct { desc string @@ -281,6 +285,34 @@ func TestInvalidUploadConfiguration(t *testing.T) { } } +func TestTimeoutConfiguration(t *testing.T) { + originResourceServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(20 * time.Millisecond) + })) + defer originResourceServer.Close() + + injector := NewInjector() + + var oldHttpClient = httpClient + httpClient = &http.Client{ + Transport: transport.NewRestrictedTransport(transport.WithResponseHeaderTimeout(10 * time.Millisecond)), + } + + t.Cleanup(func() { + httpClient = oldHttpClient + }) + + sendData := map[string]string{ + "Url": originResourceServer.URL + "/file", + } + + sendDataJsonString, err := json.Marshal(sendData) + require.NoError(t, err) + + response := makeRequest(injector, string(sendDataJsonString)) + require.Equal(t, http.StatusGatewayTimeout, response.Result().StatusCode) +} + func mergeMap(from map[string]interface{}, into map[string]interface{}) map[string]interface{} { for k, v := range from { into[k] = v @@ -298,8 +330,8 @@ func TestIncorrectSendData(t *testing.T) { func TestIncorrectSendDataUrl(t *testing.T) { response := makeRequest(NewInjector(), `{"Token": "token", "Url": "url"}`) - require.Equal(t, 500, response.Code) - require.Equal(t, "Internal Server Error\n", response.Body.String()) + require.Equal(t, http.StatusBadGateway, response.Code) + require.Equal(t, "Bad Gateway\n", response.Body.String()) } func TestFailedOriginServer(t *testing.T) { diff --git a/workhorse/internal/sendurl/sendurl.go b/workhorse/internal/sendurl/sendurl.go index 9cc1cd352b7..e011f57c6bc 100644 --- a/workhorse/internal/sendurl/sendurl.go +++ b/workhorse/internal/sendurl/sendurl.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "net/http" + "os" "strings" "github.com/prometheus/client_golang/prometheus" @@ -11,6 +12,7 @@ import ( "gitlab.com/gitlab-org/labkit/mask" + "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper/fail" "gitlab.com/gitlab-org/gitlab/workhorse/internal/log" "gitlab.com/gitlab-org/gitlab/workhorse/internal/senddata" @@ -20,11 +22,15 @@ import ( type entry struct{ senddata.Prefix } type entryParams struct { - URL string - AllowRedirects bool - Body string - Header http.Header - Method string + URL string + AllowRedirects bool + DialTimeout config.TomlDuration + ResponseHeaderTimeout config.TomlDuration + ErrorResponseStatus int + TimeoutResponseStatus int + Body string + Header http.Header + Method string } var SendURL = &entry{"send-url:"} @@ -48,10 +54,8 @@ var preserveHeaderKeys = map[string]bool{ "Pragma": true, // Support for HTTP 1.0 proxies } -var httpTransport = transport.NewRestrictedTransport() - -var httpClient = &http.Client{ - Transport: httpTransport, +var httpClientNoRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse } var ( @@ -126,14 +130,19 @@ func (e *entry) Inject(w http.ResponseWriter, r *http.Request, sendData string) // execute new request var resp *http.Response - if params.AllowRedirects { - resp, err = httpClient.Do(newReq) - } else { - resp, err = httpTransport.RoundTrip(newReq) - } + resp, err = newClient(params).Do(newReq) + if err != nil { + status := http.StatusInternalServerError + + if params.TimeoutResponseStatus != 0 && os.IsTimeout(err) { + status = params.TimeoutResponseStatus + } else if params.ErrorResponseStatus != 0 { + status = params.ErrorResponseStatus + } + sendURLRequestsRequestFailed.Inc() - fail.Request(w, r, fmt.Errorf("SendURL: Do request: %v", err)) + fail.Request(w, r, fmt.Errorf("SendURL: Do request: %v", err), fail.WithStatus(status)) return } @@ -160,3 +169,24 @@ func (e *entry) Inject(w http.ResponseWriter, r *http.Request, sendData string) sendURLRequestsSucceeded.Inc() } + +func newClient(params entryParams) *http.Client { + var options []transport.Option + + if params.DialTimeout.Duration != 0 { + options = append(options, transport.WithDialTimeout(params.DialTimeout.Duration)) + } + if params.ResponseHeaderTimeout.Duration != 0 { + options = append(options, transport.WithResponseHeaderTimeout(params.ResponseHeaderTimeout.Duration)) + } + + client := &http.Client{ + Transport: transport.NewRestrictedTransport(options...), + } + + if !params.AllowRedirects { + client.CheckRedirect = httpClientNoRedirect + } + + return client +} diff --git a/workhorse/internal/sendurl/sendurl_test.go b/workhorse/internal/sendurl/sendurl_test.go index ea0c6a43af6..4bebe43a649 100644 --- a/workhorse/internal/sendurl/sendurl_test.go +++ b/workhorse/internal/sendurl/sendurl_test.go @@ -3,12 +3,10 @@ package sendurl import ( "encoding/base64" "encoding/json" - "fmt" "io" "net/http" "net/http/httptest" "os" - "strconv" "testing" "time" @@ -20,13 +18,26 @@ import ( const testData = `123456789012345678901234567890` const testDataEtag = `W/"myetag"` -func testEntryServer(t *testing.T, requestURL string, httpHeaders http.Header, allowRedirects bool) *httptest.ResponseRecorder { +type option struct { + Key string + Value interface{} +} + +func testEntryServer(t *testing.T, requestURL string, httpHeaders http.Header, allowRedirects bool, options ...option) *httptest.ResponseRecorder { requestHandler := func(w http.ResponseWriter, r *http.Request) { require.Equal(t, "GET", r.Method) - url := r.URL.String() + "/file" - jsonParams := fmt.Sprintf(`{"URL":%q,"AllowRedirects":%s}`, - url, strconv.FormatBool(allowRedirects)) + sendData := map[string]interface{}{ + "URL": r.URL.String() + "/file", + "AllowRedirects": allowRedirects, + } + + for _, o := range options { + sendData[o.Key] = o.Value + } + + jsonParams, err := json.Marshal(sendData) + require.NoError(t, err) data := base64.URLEncoding.EncodeToString([]byte(jsonParams)) // The server returns a Content-Disposition @@ -60,6 +71,9 @@ func testEntryServer(t *testing.T, requestURL string, httpHeaders http.Header, a require.Equal(t, "GET", r.Method) http.Redirect(w, r, r.URL.String()+"/download", http.StatusTemporaryRedirect) } + timeoutFile := func(w http.ResponseWriter, r *http.Request) { + time.Sleep(20 * time.Millisecond) + } mux := http.NewServeMux() mux.HandleFunc("/get/request", requestHandler) @@ -68,6 +82,8 @@ func testEntryServer(t *testing.T, requestURL string, httpHeaders http.Header, a mux.HandleFunc("/get/redirect/file", redirectFile) mux.HandleFunc("/get/redirect/file/download", serveFile) mux.HandleFunc("/get/file-not-existing", requestHandler) + mux.HandleFunc("/get/timeout", requestHandler) + mux.HandleFunc("/get/timeout/file", timeoutFile) server := httptest.NewServer(mux) defer server.Close() @@ -199,13 +215,19 @@ func TestDownloadingNonExistingRemoteFileWithSendURL(t *testing.T) { func TestPostRequest(t *testing.T) { body := "any string" - header := map[string][]string{"Authorization": []string{"Bearer token"}} + header := map[string][]string{"Authorization": {"Bearer token"}} postRequestHandler := func(w http.ResponseWriter, r *http.Request) { require.Equal(t, "POST", r.Method) url := r.URL.String() + "/external/url" - jsonParams, err := json.Marshal(entryParams{URL: url, Body: body, Header: header, Method: "POST"}) + sendData := map[string]interface{}{ + "URL": url, + "Body": body, + "Header": header, + "Method": "POST", + } + jsonParams, err := json.Marshal(sendData) require.NoError(t, err) data := base64.URLEncoding.EncodeToString([]byte(jsonParams)) @@ -242,3 +264,31 @@ func TestPostRequest(t *testing.T) { require.NoError(t, err) require.Equal(t, testData, string(result)) } + +func TestTimeout(t *testing.T) { + response := testEntryServer(t, "/get/timeout", nil, false, option{Key: "ResponseHeaderTimeout", Value: "10ms"}) + require.Equal(t, http.StatusInternalServerError, response.Code) +} + +func TestTimeoutWithCustomStatusCode(t *testing.T) { + response := testEntryServer(t, "/get/timeout", nil, false, option{Key: "ResponseHeaderTimeout", Value: "10ms"}, option{Key: "TimeoutResponseStatus", Value: http.StatusTeapot}) + require.Equal(t, http.StatusTeapot, response.Code) +} + +func TestErrorWithCustomStatusCode(t *testing.T) { + sendData := map[string]interface{}{ + "URL": "url", + "ErrorResponseStatus": http.StatusTeapot, + } + + jsonParams, err := json.Marshal(sendData) + require.NoError(t, err) + data := base64.URLEncoding.EncodeToString([]byte(jsonParams)) + + response := httptest.NewRecorder() + request := httptest.NewRequest("GET", "/target", nil) + + SendURL.Inject(response, request, data) + + require.Equal(t, http.StatusTeapot, response.Code) +} diff --git a/workhorse/internal/transport/transport.go b/workhorse/internal/transport/transport.go index f19d332a28a..87fa73985a9 100644 --- a/workhorse/internal/transport/transport.go +++ b/workhorse/internal/transport/transport.go @@ -1,6 +1,7 @@ package transport import ( + "net" "net/http" "time" @@ -41,6 +42,20 @@ func WithDisabledCompression() Option { } } +func WithDialTimeout(timeout time.Duration) Option { + return func(t *http.Transport) { + t.DialContext = (&net.Dialer{ + Timeout: timeout, + }).DialContext + } +} + +func WithResponseHeaderTimeout(timeout time.Duration) Option { + return func(t *http.Transport) { + t.ResponseHeaderTimeout = timeout + } +} + func newRestrictedTransport(options ...Option) http.RoundTripper { t := http.DefaultTransport.(*http.Transport).Clone() |