diff options
author | GitLab Bot <gitlab-bot@gitlab.com> | 2023-02-02 15:07:33 +0300 |
---|---|---|
committer | GitLab Bot <gitlab-bot@gitlab.com> | 2023-02-02 15:07:33 +0300 |
commit | ae9f43a2c4bda0ee7dae59ea9a7d412068f6f7ff (patch) | |
tree | 0617cf9d21ee7b8cf0ba7c120781475050c6a7a6 /workhorse/internal/badgateway | |
parent | 4fbfae83afa1ea64ba4969bff2b459b4562944e4 (diff) |
Add latest changes from gitlab-org/gitlab@master
Diffstat (limited to 'workhorse/internal/badgateway')
-rw-r--r-- | workhorse/internal/badgateway/roundtripper.go | 10 | ||||
-rw-r--r-- | workhorse/internal/badgateway/roundtripper_test.go | 35 |
2 files changed, 43 insertions, 2 deletions
diff --git a/workhorse/internal/badgateway/roundtripper.go b/workhorse/internal/badgateway/roundtripper.go index cc982b092a7..ce4e9e6a177 100644 --- a/workhorse/internal/badgateway/roundtripper.go +++ b/workhorse/internal/badgateway/roundtripper.go @@ -2,6 +2,7 @@ package badgateway import ( "bytes" + "context" _ "embed" "encoding/base64" "fmt" @@ -47,9 +48,14 @@ func (t *roundTripper) RoundTrip(r *http.Request) (*http.Response, error) { fields := log.Fields{"duration_ms": int64(time.Since(start).Seconds() * 1000)} log.WithRequest(r).WithFields(fields).WithError(&sentryError{fmt.Errorf("badgateway: failed to receive response: %v", err)}).Error() + code := http.StatusBadGateway + if r.Context().Err() == context.Canceled { + code = 499 // Code used by NGINX when client disconnects + } + injectedResponse := &http.Response{ - StatusCode: http.StatusBadGateway, - Status: http.StatusText(http.StatusBadGateway), + StatusCode: code, + Status: http.StatusText(code), Request: r, ProtoMajor: r.ProtoMajor, diff --git a/workhorse/internal/badgateway/roundtripper_test.go b/workhorse/internal/badgateway/roundtripper_test.go index b59cb8d2c5b..ed2de452f80 100644 --- a/workhorse/internal/badgateway/roundtripper_test.go +++ b/workhorse/internal/badgateway/roundtripper_test.go @@ -1,9 +1,11 @@ package badgateway import ( + "context" "errors" "io" "net/http" + "net/http/httptest" "testing" "github.com/stretchr/testify/require" @@ -54,3 +56,36 @@ func TestErrorPage502(t *testing.T) { }) } } + +func TestClientDisconnect499(t *testing.T) { + serverSync := make(chan struct{}) + ts := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) { + serverSync <- struct{}{} + <-serverSync + })) + defer func() { + close(serverSync) + ts.Close() + }() + + clientResponse := make(chan *http.Response) + clientContext, clientCancel := context.WithCancel(context.Background()) + + go func() { + req, err := http.NewRequestWithContext(clientContext, "GET", ts.URL, nil) + require.NoError(t, err, "build request") + + rt := NewRoundTripper(false, http.DefaultTransport) + response, err := rt.RoundTrip(req) + require.NoError(t, err, "perform roundtrip") + require.NoError(t, response.Body.Close()) + + clientResponse <- response + }() + + <-serverSync + + clientCancel() + response := <-clientResponse + require.Equal(t, 499, response.StatusCode, "response status") +} |