diff options
Diffstat (limited to 'workhorse/internal/builds/register.go')
-rw-r--r-- | workhorse/internal/builds/register.go | 31 |
1 files changed, 27 insertions, 4 deletions
diff --git a/workhorse/internal/builds/register.go b/workhorse/internal/builds/register.go index f28ad75e1d8..0a2fe47ed7e 100644 --- a/workhorse/internal/builds/register.go +++ b/workhorse/internal/builds/register.go @@ -1,8 +1,10 @@ package builds import ( + "bytes" "encoding/json" "errors" + "io" "net/http" "time" @@ -10,6 +12,7 @@ import ( "github.com/prometheus/client_golang/prometheus/promauto" "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" + "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper/fail" "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis" ) @@ -63,11 +66,18 @@ func readRunnerBody(w http.ResponseWriter, r *http.Request) ([]byte, error) { registerHandlerOpenAtReading.Inc() defer registerHandlerOpenAtReading.Dec() - return helper.ReadRequestBody(w, r, maxRegisterBodySize) + return readRequestBody(w, r, maxRegisterBodySize) +} + +func readRequestBody(w http.ResponseWriter, r *http.Request, maxBodySize int64) ([]byte, error) { + limitedBody := http.MaxBytesReader(w, r.Body, maxBodySize) + defer limitedBody.Close() + + return io.ReadAll(limitedBody) } func readRunnerRequest(r *http.Request, body []byte) (*runnerRequest, error) { - if !helper.IsApplicationJson(r) { + if !isApplicationJson(r) { return nil, errors.New("invalid content-type received") } @@ -80,6 +90,11 @@ func readRunnerRequest(r *http.Request, body []byte) (*runnerRequest, error) { return &runnerRequest, nil } +func isApplicationJson(r *http.Request) bool { + contentType := r.Header.Get("Content-Type") + return helper.IsContentType("application/json", contentType) +} + func proxyRegisterRequest(h http.Handler, w http.ResponseWriter, r *http.Request) { registerHandlerOpenAtProxying.Inc() defer registerHandlerOpenAtProxying.Dec() @@ -105,11 +120,12 @@ func RegisterHandler(h http.Handler, watchHandler WatchKeyHandler, pollingDurati requestBody, err := readRunnerBody(w, r) if err != nil { registerHandlerBodyReadErrors.Inc() - helper.RequestEntityTooLarge(w, r, &largeBodyError{err}) + fail.Request(w, r, &largeBodyError{err}, + fail.WithStatus(http.StatusRequestEntityTooLarge)) return } - newRequest := helper.CloneRequestWithNewBody(r, requestBody) + newRequest := cloneRequestWithNewBody(r, requestBody) runnerRequest, err := readRunnerRequest(r, requestBody) if err != nil { @@ -161,3 +177,10 @@ func RegisterHandler(h http.Handler, watchHandler WatchKeyHandler, pollingDurati } }) } + +func cloneRequestWithNewBody(r *http.Request, body []byte) *http.Request { + newReq := r.Clone(r.Context()) + newReq.Body = io.NopCloser(bytes.NewReader(body)) + newReq.ContentLength = int64(len(body)) + return newReq +} |