gitlab.com/gitlab-org/gitlab-foss.git
Diffstat (limited to 'workhorse/internal/helper')
-rw-r--r--  workhorse/internal/helper/context_reader.go                40
-rw-r--r--  workhorse/internal/helper/context_reader_test.go           83
-rw-r--r--  workhorse/internal/helper/countingresponsewriter.go        56
-rw-r--r--  workhorse/internal/helper/countingresponsewriter_test.go   50
-rw-r--r--  workhorse/internal/helper/helpers.go                      217
-rw-r--r--  workhorse/internal/helper/helpers_test.go                 258
-rw-r--r--  workhorse/internal/helper/raven.go                         58
-rw-r--r--  workhorse/internal/helper/tempfile.go                      35
-rw-r--r--  workhorse/internal/helper/writeafterreader.go             144
-rw-r--r--  workhorse/internal/helper/writeafterreader_test.go        115
10 files changed, 1056 insertions(+), 0 deletions(-)
diff --git a/workhorse/internal/helper/context_reader.go b/workhorse/internal/helper/context_reader.go
new file mode 100644
index 00000000000..a4764043147
--- /dev/null
+++ b/workhorse/internal/helper/context_reader.go
@@ -0,0 +1,40 @@
+package helper
+
+import (
+ "context"
+ "io"
+)
+
+type ContextReader struct {
+ ctx context.Context
+ underlyingReader io.Reader
+}
+
+func NewContextReader(ctx context.Context, underlyingReader io.Reader) *ContextReader {
+ return &ContextReader{
+ ctx: ctx,
+ underlyingReader: underlyingReader,
+ }
+}
+
+func (r *ContextReader) Read(b []byte) (int, error) {
+ if r.canceled() {
+ return 0, r.err()
+ }
+
+ n, err := r.underlyingReader.Read(b)
+
+ if r.canceled() {
+ err = r.err()
+ }
+
+ return n, err
+}
+
+func (r *ContextReader) canceled() bool {
+ return r.err() != nil
+}
+
+func (r *ContextReader) err() error {
+ return r.ctx.Err()
+}
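
A minimal usage sketch of NewContextReader (not taken from this change; the handler, route, and import path are illustrative, since helper is an internal package). Each Read checks the request context, so a drain loop surfaces ctx.Err() once the client cancels or the deadline passes, instead of reading to completion.

package main

import (
	"io"
	"io/ioutil"
	"log"
	"net/http"

	// Illustrative import path; helper lives under workhorse/internal.
	"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
)

// uploadHandler is hypothetical; it drains the request body through a
// ContextReader so every Read first checks the request context.
func uploadHandler(w http.ResponseWriter, r *http.Request) {
	body := helper.NewContextReader(r.Context(), r.Body)

	if _, err := io.Copy(ioutil.Discard, body); err != nil {
		// err is ctx.Err() if the client went away or the deadline passed.
		helper.Fail500(w, r, err)
		return
	}
	w.WriteHeader(http.StatusOK)
}

func main() {
	http.Handle("/upload", http.HandlerFunc(uploadHandler))
	log.Fatal(http.ListenAndServe("127.0.0.1:8181", nil))
}
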
diff --git a/workhorse/internal/helper/context_reader_test.go b/workhorse/internal/helper/context_reader_test.go
new file mode 100644
index 00000000000..257ec4e35f2
--- /dev/null
+++ b/workhorse/internal/helper/context_reader_test.go
@@ -0,0 +1,83 @@
+package helper
+
+import (
+ "context"
+ "io"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+type fakeReader struct {
+ n int
+ err error
+}
+
+func (f *fakeReader) Read(b []byte) (int, error) {
+ return f.n, f.err
+}
+
+type fakeContextWithTimeout struct {
+ n int
+ threshold int
+}
+
+func (*fakeContextWithTimeout) Deadline() (deadline time.Time, ok bool) {
+ return
+}
+
+func (*fakeContextWithTimeout) Done() <-chan struct{} {
+ return nil
+}
+
+func (*fakeContextWithTimeout) Value(key interface{}) interface{} {
+ return nil
+}
+
+func (f *fakeContextWithTimeout) Err() error {
+ f.n++
+ if f.n > f.threshold {
+ return context.DeadlineExceeded
+ }
+
+ return nil
+}
+
+func TestContextReaderRead(t *testing.T) {
+ underlyingReader := &fakeReader{n: 1, err: io.EOF}
+
+ for _, tc := range []struct {
+ desc string
+ ctx *fakeContextWithTimeout
+ expectedN int
+ expectedErr error
+ }{
+ {
+ desc: "Before and after read deadline checks are fine",
+ ctx: &fakeContextWithTimeout{n: 0, threshold: 2},
+ expectedN: underlyingReader.n,
+ expectedErr: underlyingReader.err,
+ },
+ {
+ desc: "Before read deadline check fails",
+ ctx: &fakeContextWithTimeout{n: 0, threshold: 0},
+ expectedN: 0,
+ expectedErr: context.DeadlineExceeded,
+ },
+ {
+ desc: "After read deadline check fails",
+ ctx: &fakeContextWithTimeout{n: 0, threshold: 1},
+ expectedN: underlyingReader.n,
+ expectedErr: context.DeadlineExceeded,
+ },
+ } {
+ t.Run(tc.desc, func(t *testing.T) {
+ cr := NewContextReader(tc.ctx, underlyingReader)
+
+ n, err := cr.Read(nil)
+ require.Equal(t, tc.expectedN, n)
+ require.Equal(t, tc.expectedErr, err)
+ })
+ }
+}
diff --git a/workhorse/internal/helper/countingresponsewriter.go b/workhorse/internal/helper/countingresponsewriter.go
new file mode 100644
index 00000000000..a79d51d4c6a
--- /dev/null
+++ b/workhorse/internal/helper/countingresponsewriter.go
@@ -0,0 +1,56 @@
+package helper
+
+import (
+ "net/http"
+)
+
+type CountingResponseWriter interface {
+ http.ResponseWriter
+ Count() int64
+ Status() int
+}
+
+type countingResponseWriter struct {
+ rw http.ResponseWriter
+ status int
+ count int64
+}
+
+func NewCountingResponseWriter(rw http.ResponseWriter) CountingResponseWriter {
+ return &countingResponseWriter{rw: rw}
+}
+
+func (c *countingResponseWriter) Header() http.Header {
+ return c.rw.Header()
+}
+
+func (c *countingResponseWriter) Write(data []byte) (int, error) {
+ if c.status == 0 {
+ c.WriteHeader(http.StatusOK)
+ }
+
+ n, err := c.rw.Write(data)
+ c.count += int64(n)
+ return n, err
+}
+
+func (c *countingResponseWriter) WriteHeader(status int) {
+ if c.status != 0 {
+ return
+ }
+
+ c.status = status
+ c.rw.WriteHeader(status)
+}
+
+// Count returns the number of bytes written to the ResponseWriter. This
+// function is not thread-safe.
+func (c *countingResponseWriter) Count() int64 {
+ return c.count
+}
+
+// Status returns the first HTTP status value that was written to the
+// ResponseWriter. This function is not thread-safe.
+func (c *countingResponseWriter) Status() int {
+ return c.status
+}
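
A sketch of how the wrapper might sit in logging middleware; the middleware, address, and import path are assumptions, not part of this change. Because only Header, Write, and WriteHeader are forwarded, optional interfaces such as http.Flusher are not visible through the wrapper.

package main

import (
	"log"
	"net/http"

	// Illustrative import path; helper lives under workhorse/internal.
	"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
)

// withAccessLog is hypothetical middleware: it wraps the ResponseWriter so
// the first status code and byte count can be logged after the handler runs.
func withAccessLog(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		crw := helper.NewCountingResponseWriter(w)
		next.ServeHTTP(crw, r)
		log.Printf("%s %s status=%d written=%d", r.Method, r.URL.Path, crw.Status(), crw.Count())
	})
}

func main() {
	ok := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("hello\n")) // implicit 200 recorded by the wrapper
	})
	log.Fatal(http.ListenAndServe("127.0.0.1:8181", withAccessLog(ok)))
}
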
diff --git a/workhorse/internal/helper/countingresponsewriter_test.go b/workhorse/internal/helper/countingresponsewriter_test.go
new file mode 100644
index 00000000000..f9f2f4ced5b
--- /dev/null
+++ b/workhorse/internal/helper/countingresponsewriter_test.go
@@ -0,0 +1,50 @@
+package helper
+
+import (
+ "bytes"
+ "io"
+ "net/http"
+ "testing"
+ "testing/iotest"
+
+ "github.com/stretchr/testify/require"
+)
+
+type testResponseWriter struct {
+ data []byte
+}
+
+func (*testResponseWriter) WriteHeader(int) {}
+func (*testResponseWriter) Header() http.Header { return nil }
+
+func (trw *testResponseWriter) Write(p []byte) (int, error) {
+ trw.data = append(trw.data, p...)
+ return len(p), nil
+}
+
+func TestCountingResponseWriterStatus(t *testing.T) {
+ crw := NewCountingResponseWriter(&testResponseWriter{})
+ crw.WriteHeader(123)
+ crw.WriteHeader(456)
+ require.Equal(t, 123, crw.Status())
+}
+
+func TestCountingResponseWriterCount(t *testing.T) {
+ crw := NewCountingResponseWriter(&testResponseWriter{})
+ for _, n := range []int{1, 2, 4, 8, 16, 32} {
+ _, err := crw.Write(bytes.Repeat([]byte{'.'}, n))
+ require.NoError(t, err)
+ }
+ require.Equal(t, int64(63), crw.Count())
+}
+
+func TestCountingResponseWriterWrite(t *testing.T) {
+ trw := &testResponseWriter{}
+ crw := NewCountingResponseWriter(trw)
+
+ testData := []byte("test data")
+ _, err := io.Copy(crw, iotest.OneByteReader(bytes.NewReader(testData)))
+ require.NoError(t, err)
+
+ require.Equal(t, string(testData), string(trw.data))
+}
diff --git a/workhorse/internal/helper/helpers.go b/workhorse/internal/helper/helpers.go
new file mode 100644
index 00000000000..5f1e5fc51b3
--- /dev/null
+++ b/workhorse/internal/helper/helpers.go
@@ -0,0 +1,217 @@
+package helper
+
+import (
+ "bytes"
+ "errors"
+ "io/ioutil"
+ "mime"
+ "net"
+ "net/http"
+ "net/url"
+ "os"
+ "os/exec"
+ "strings"
+ "syscall"
+
+ "github.com/sebest/xff"
+ "gitlab.com/gitlab-org/labkit/log"
+ "gitlab.com/gitlab-org/labkit/mask"
+)
+
+const NginxResponseBufferHeader = "X-Accel-Buffering"
+
+func LogError(r *http.Request, err error) {
+ LogErrorWithFields(r, err, nil)
+}
+
+func LogErrorWithFields(r *http.Request, err error, fields log.Fields) {
+ if err != nil {
+ captureRavenError(r, err, fields)
+ }
+
+ printError(r, err, fields)
+}
+
+func CaptureAndFail(w http.ResponseWriter, r *http.Request, err error, msg string, code int) {
+ http.Error(w, msg, code)
+ LogError(r, err)
+}
+
+func CaptureAndFailWithFields(w http.ResponseWriter, r *http.Request, err error, msg string, code int, fields log.Fields) {
+ http.Error(w, msg, code)
+ LogErrorWithFields(r, err, fields)
+}
+
+func Fail500(w http.ResponseWriter, r *http.Request, err error) {
+ CaptureAndFail(w, r, err, "Internal server error", http.StatusInternalServerError)
+}
+
+func Fail500WithFields(w http.ResponseWriter, r *http.Request, err error, fields log.Fields) {
+ CaptureAndFailWithFields(w, r, err, "Internal server error", http.StatusInternalServerError, fields)
+}
+
+func RequestEntityTooLarge(w http.ResponseWriter, r *http.Request, err error) {
+ CaptureAndFail(w, r, err, "Request Entity Too Large", http.StatusRequestEntityTooLarge)
+}
+
+func printError(r *http.Request, err error, fields log.Fields) {
+ if r != nil {
+ entry := log.WithContextFields(r.Context(), log.Fields{
+ "method": r.Method,
+ "uri": mask.URL(r.RequestURI),
+ })
+ entry.WithFields(fields).WithError(err).Error("error")
+ } else {
+ log.WithFields(fields).WithError(err).Error("unknown error")
+ }
+}
+
+func SetNoCacheHeaders(header http.Header) {
+ header.Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate")
+ header.Set("Pragma", "no-cache")
+ header.Set("Expires", "Fri, 01 Jan 1990 00:00:00 GMT")
+}
+
+func OpenFile(path string) (file *os.File, fi os.FileInfo, err error) {
+ file, err = os.Open(path)
+ if err != nil {
+ return
+ }
+
+ defer func() {
+ if err != nil {
+ file.Close()
+ }
+ }()
+
+ fi, err = file.Stat()
+ if err != nil {
+ return
+ }
+
+ // os.Open can also open directories; treat those as an error here
+ if fi.IsDir() {
+ err = &os.PathError{
+ Op: "open",
+ Path: path,
+ Err: errors.New("path is directory"),
+ }
+ return
+ }
+
+ return
+}
+
+func URLMustParse(s string) *url.URL {
+ u, err := url.Parse(s)
+ if err != nil {
+ log.WithError(err).WithField("url", s).Fatal("urlMustParse")
+ }
+ return u
+}
+
+func HTTPError(w http.ResponseWriter, r *http.Request, error string, code int) {
+ if r.ProtoAtLeast(1, 1) {
+ // Force client to disconnect if we render request error
+ w.Header().Set("Connection", "close")
+ }
+
+ http.Error(w, error, code)
+}
+
+func HeaderClone(h http.Header) http.Header {
+ h2 := make(http.Header, len(h))
+ for k, vv := range h {
+ vv2 := make([]string, len(vv))
+ copy(vv2, vv)
+ h2[k] = vv2
+ }
+ return h2
+}
+
+func CleanUpProcessGroup(cmd *exec.Cmd) {
+ if cmd == nil {
+ return
+ }
+
+ process := cmd.Process
+ if process != nil && process.Pid > 0 {
+ // Send SIGTERM to the process group of cmd
+ syscall.Kill(-process.Pid, syscall.SIGTERM)
+ }
+
+ // reap our child process
+ cmd.Wait()
+}
+
+func ExitStatus(err error) (int, bool) {
+ exitError, ok := err.(*exec.ExitError)
+ if !ok {
+ return 0, false
+ }
+
+ waitStatus, ok := exitError.Sys().(syscall.WaitStatus)
+ if !ok {
+ return 0, false
+ }
+
+ return waitStatus.ExitStatus(), true
+}
+
+func DisableResponseBuffering(w http.ResponseWriter) {
+ w.Header().Set(NginxResponseBufferHeader, "no")
+}
+
+func AllowResponseBuffering(w http.ResponseWriter) {
+ w.Header().Del(NginxResponseBufferHeader)
+}
+
+func FixRemoteAddr(r *http.Request) {
+ // Unix domain sockets have a remote addr of @. This will make the
+ // xff package lookup the X-Forwarded-For address if available.
+ if r.RemoteAddr == "@" {
+ r.RemoteAddr = "127.0.0.1:0"
+ }
+ r.RemoteAddr = xff.GetRemoteAddr(r)
+}
+
+func SetForwardedFor(newHeaders *http.Header, originalRequest *http.Request) {
+ if clientIP, _, err := net.SplitHostPort(originalRequest.RemoteAddr); err == nil {
+ var header string
+
+ // If we aren't the first proxy, retain prior X-Forwarded-For
+ // information as a comma+space separated list and fold multiple
+ // headers into one.
+ if prior, ok := originalRequest.Header["X-Forwarded-For"]; ok {
+ header = strings.Join(prior, ", ") + ", " + clientIP
+ } else {
+ header = clientIP
+ }
+ newHeaders.Set("X-Forwarded-For", header)
+ }
+}
+
+func IsContentType(expected, actual string) bool {
+ parsed, _, err := mime.ParseMediaType(actual)
+ return err == nil && parsed == expected
+}
+
+func IsApplicationJson(r *http.Request) bool {
+ contentType := r.Header.Get("Content-Type")
+ return IsContentType("application/json", contentType)
+}
+
+func ReadRequestBody(w http.ResponseWriter, r *http.Request, maxBodySize int64) ([]byte, error) {
+ limitedBody := http.MaxBytesReader(w, r.Body, maxBodySize)
+ defer limitedBody.Close()
+
+ return ioutil.ReadAll(limitedBody)
+}
+
+func CloneRequestWithNewBody(r *http.Request, body []byte) *http.Request {
+ newReq := *r
+ newReq.Body = ioutil.NopCloser(bytes.NewReader(body))
+ newReq.Header = HeaderClone(r.Header)
+ newReq.ContentLength = int64(len(body))
+ return &newReq
+}
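
A sketch of the call pattern CleanUpProcessGroup and ExitStatus appear designed for; the command and import path are illustrative. One assumption is spelled out in the code: syscall.Kill(-pid, SIGTERM) only reaches the whole subtree when the child was started as its own process group leader, hence Setpgid.

package main

import (
	"log"
	"os/exec"
	"syscall"

	// Illustrative import path; helper lives under workhorse/internal.
	"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
)

func main() {
	cmd := exec.Command("sh", "-c", "exit 3")
	// Put the child in its own process group so the SIGTERM that
	// CleanUpProcessGroup sends to -pid reaches it and any grandchildren.
	cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}

	if err := cmd.Start(); err != nil {
		log.Fatal(err)
	}
	// Terminate and reap the whole group once we are done with cmd.
	defer helper.CleanUpProcessGroup(cmd)

	if err := cmd.Wait(); err != nil {
		if code, ok := helper.ExitStatus(err); ok {
			log.Printf("command exited with status %d", code)
			return
		}
		log.Printf("wait: %v", err)
	}
}
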
diff --git a/workhorse/internal/helper/helpers_test.go b/workhorse/internal/helper/helpers_test.go
new file mode 100644
index 00000000000..6a895aded03
--- /dev/null
+++ b/workhorse/internal/helper/helpers_test.go
@@ -0,0 +1,258 @@
+package helper
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/require"
+)
+
+func TestFixRemoteAddr(t *testing.T) {
+ testCases := []struct {
+ initial string
+ forwarded string
+ expected string
+ }{
+ {initial: "@", forwarded: "", expected: "127.0.0.1:0"},
+ {initial: "@", forwarded: "18.245.0.1", expected: "18.245.0.1:0"},
+ {initial: "@", forwarded: "127.0.0.1", expected: "127.0.0.1:0"},
+ {initial: "@", forwarded: "192.168.0.1", expected: "127.0.0.1:0"},
+ {initial: "192.168.1.1:0", forwarded: "", expected: "192.168.1.1:0"},
+ {initial: "192.168.1.1:0", forwarded: "18.245.0.1", expected: "18.245.0.1:0"},
+ }
+
+ for _, tc := range testCases {
+ req, err := http.NewRequest("POST", "unix:///tmp/test.socket/info/refs", nil)
+ require.NoError(t, err)
+
+ req.RemoteAddr = tc.initial
+
+ if tc.forwarded != "" {
+ req.Header.Add("X-Forwarded-For", tc.forwarded)
+ }
+
+ FixRemoteAddr(req)
+
+ require.Equal(t, tc.expected, req.RemoteAddr)
+ }
+}
+
+func TestSetForwardedForGeneratesHeader(t *testing.T) {
+ testCases := []struct {
+ remoteAddr string
+ previousForwardedFor []string
+ expected string
+ }{
+ {
+ "8.8.8.8:3000",
+ nil,
+ "8.8.8.8",
+ },
+ {
+ "8.8.8.8:3000",
+ []string{"138.124.33.63, 151.146.211.237"},
+ "138.124.33.63, 151.146.211.237, 8.8.8.8",
+ },
+ {
+ "8.8.8.8:3000",
+ []string{"8.154.76.107", "115.206.118.179"},
+ "8.154.76.107, 115.206.118.179, 8.8.8.8",
+ },
+ }
+ for _, tc := range testCases {
+ headers := http.Header{}
+ originalRequest := http.Request{
+ RemoteAddr: tc.remoteAddr,
+ }
+
+ if tc.previousForwardedFor != nil {
+ originalRequest.Header = http.Header{
+ "X-Forwarded-For": tc.previousForwardedFor,
+ }
+ }
+
+ SetForwardedFor(&headers, &originalRequest)
+
+ result := headers.Get("X-Forwarded-For")
+ if result != tc.expected {
+ t.Fatalf("Expected %v, got %v", tc.expected, result)
+ }
+ }
+}
+
+func TestReadRequestBody(t *testing.T) {
+ data := []byte("123456")
+ rw := httptest.NewRecorder()
+ req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(data))
+
+ result, err := ReadRequestBody(rw, req, 1000)
+ require.NoError(t, err)
+ require.Equal(t, data, result)
+}
+
+func TestReadRequestBodyLimit(t *testing.T) {
+ data := []byte("123456")
+ rw := httptest.NewRecorder()
+ req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(data))
+
+ _, err := ReadRequestBody(rw, req, 2)
+ require.Error(t, err)
+}
+
+func TestCloneRequestWithBody(t *testing.T) {
+ input := []byte("test")
+ newInput := []byte("new body")
+ req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(input))
+ newReq := CloneRequestWithNewBody(req, newInput)
+
+ require.NotEqual(t, req, newReq)
+ require.NotEqual(t, req.Body, newReq.Body)
+ require.Equal(t, int64(len(newInput)), newReq.ContentLength)
+
+ var buffer bytes.Buffer
+ io.Copy(&buffer, newReq.Body)
+ require.Equal(t, newInput, buffer.Bytes())
+}
+
+func TestApplicationJson(t *testing.T) {
+ req, _ := http.NewRequest("POST", "/test", nil)
+ req.Header.Set("Content-Type", "application/json")
+
+ require.True(t, IsApplicationJson(req), "expected to match 'application/json' as 'application/json'")
+
+ req.Header.Set("Content-Type", "application/json; charset=utf-8")
+ require.True(t, IsApplicationJson(req), "expected to match 'application/json; charset=utf-8' as 'application/json'")
+
+ req.Header.Set("Content-Type", "text/plain")
+ require.False(t, IsApplicationJson(req), "expected not to match 'text/plain' as 'application/json'")
+}
+
+func TestFail500WorksWithNils(t *testing.T) {
+ body := bytes.NewBuffer(nil)
+ w := httptest.NewRecorder()
+ w.Body = body
+
+ Fail500(w, nil, nil)
+
+ require.Equal(t, http.StatusInternalServerError, w.Code)
+ require.Equal(t, "Internal server error\n", body.String())
+}
+
+func TestLogError(t *testing.T) {
+ tests := []struct {
+ name string
+ method string
+ uri string
+ err error
+ logMatchers []string
+ }{
+ {
+ name: "nil_request",
+ err: fmt.Errorf("crash"),
+ logMatchers: []string{
+ `level=error msg="unknown error" error=crash`,
+ },
+ },
+ {
+ name: "nil_request_nil_error",
+ err: nil,
+ logMatchers: []string{
+ `level=error msg="unknown error" error="<nil>"`,
+ },
+ },
+ {
+ name: "basic_url",
+ method: "GET",
+ uri: "http://localhost:3000/",
+ err: fmt.Errorf("error"),
+ logMatchers: []string{
+ `level=error msg=error correlation_id= error=error method=GET uri="http://localhost:3000/"`,
+ },
+ },
+ {
+ name: "secret_url",
+ method: "GET",
+ uri: "http://localhost:3000/path?certificate=123&sharedSecret=123&import_url=the_url&my_password_string=password",
+ err: fmt.Errorf("error"),
+ logMatchers: []string{
+ `level=error msg=error correlation_id= error=error method=GET uri="http://localhost:3000/path\?certificate=\[FILTERED\]&sharedSecret=\[FILTERED\]&import_url=\[FILTERED\]&my_password_string=\[FILTERED\]"`,
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ buf := &bytes.Buffer{}
+
+ oldOut := logrus.StandardLogger().Out
+ logrus.StandardLogger().Out = buf
+ defer func() {
+ logrus.StandardLogger().Out = oldOut
+ }()
+
+ var r *http.Request
+ if tt.uri != "" {
+ r = httptest.NewRequest(tt.method, tt.uri, nil)
+ }
+
+ LogError(r, tt.err)
+
+ logString := buf.String()
+ for _, v := range tt.logMatchers {
+ require.Regexp(t, v, logString)
+ }
+ })
+ }
+}
+
+func TestLogErrorWithFields(t *testing.T) {
+ tests := []struct {
+ name string
+ request *http.Request
+ err error
+ fields map[string]interface{}
+ logMatcher string
+ }{
+ {
+ name: "nil_request",
+ err: fmt.Errorf("crash"),
+ fields: map[string]interface{}{"extra_one": 123},
+ logMatcher: `level=error msg="unknown error" error=crash extra_one=123`,
+ },
+ {
+ name: "nil_request_nil_error",
+ err: nil,
+ fields: map[string]interface{}{"extra_one": 123, "extra_two": "test"},
+ logMatcher: `level=error msg="unknown error" error="<nil>" extra_one=123 extra_two=test`,
+ },
+ {
+ name: "basic_url",
+ request: httptest.NewRequest("GET", "http://localhost:3000/", nil),
+ err: fmt.Errorf("error"),
+ fields: map[string]interface{}{"extra_one": 123, "extra_two": "test"},
+ logMatcher: `level=error msg=error correlation_id= error=error extra_one=123 extra_two=test method=GET uri="http://localhost:3000/`,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ buf := &bytes.Buffer{}
+
+ oldOut := logrus.StandardLogger().Out
+ logrus.StandardLogger().Out = buf
+ defer func() {
+ logrus.StandardLogger().Out = oldOut
+ }()
+
+ LogErrorWithFields(tt.request, tt.err, tt.fields)
+
+ logString := buf.String()
+ require.Contains(t, logString, tt.logMatcher)
+ })
+ }
+}
diff --git a/workhorse/internal/helper/raven.go b/workhorse/internal/helper/raven.go
new file mode 100644
index 00000000000..ea1d0e1f6cc
--- /dev/null
+++ b/workhorse/internal/helper/raven.go
@@ -0,0 +1,58 @@
+package helper
+
+import (
+ "net/http"
+ "reflect"
+
+ raven "github.com/getsentry/raven-go"
+
+ //lint:ignore SA1019 this was recently deprecated. Update workhorse to use labkit errortracking package.
+ correlation "gitlab.com/gitlab-org/labkit/correlation/raven"
+
+ "gitlab.com/gitlab-org/labkit/log"
+)
+
+var ravenHeaderBlacklist = []string{
+ "Authorization",
+ "Private-Token",
+}
+
+func captureRavenError(r *http.Request, err error, fields log.Fields) {
+ client := raven.DefaultClient
+ extra := raven.Extra{}
+
+ for k, v := range fields {
+ extra[k] = v
+ }
+
+ interfaces := []raven.Interface{}
+ if r != nil {
+ CleanHeadersForRaven(r)
+ interfaces = append(interfaces, raven.NewHttp(r))
+
+ //lint:ignore SA1019 this was recently deprecated. Update workhorse to use labkit errortracking package.
+ extra = correlation.SetExtra(r.Context(), extra)
+ }
+
+ exception := &raven.Exception{
+ Stacktrace: raven.NewStacktrace(2, 3, nil),
+ Value: err.Error(),
+ Type: reflect.TypeOf(err).String(),
+ }
+ interfaces = append(interfaces, exception)
+
+ packet := raven.NewPacketWithExtra(err.Error(), extra, interfaces...)
+ client.Capture(packet, nil)
+}
+
+func CleanHeadersForRaven(r *http.Request) {
+ if r == nil {
+ return
+ }
+
+ for _, key := range ravenHeaderBlacklist {
+ if r.Header.Get(key) != "" {
+ r.Header.Set(key, "[redacted]")
+ }
+ }
+}
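
A small sketch of the header scrubbing, with an illustrative import path and a made-up token: blacklisted headers are overwritten in place, which is why captureRavenError calls CleanHeadersForRaven before attaching the request to the report.

package main

import (
	"fmt"
	"net/http"

	// Illustrative import path; helper lives under workhorse/internal.
	"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
)

func main() {
	req, _ := http.NewRequest("GET", "https://gitlab.example.com/api/v4/user", nil)
	req.Header.Set("Private-Token", "glpat-not-a-real-token") // made-up value

	// captureRavenError runs this before raven.NewHttp(r), so credentials
	// never leave the process.
	helper.CleanHeadersForRaven(req)

	fmt.Println(req.Header.Get("Private-Token")) // "[redacted]"
}
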
diff --git a/workhorse/internal/helper/tempfile.go b/workhorse/internal/helper/tempfile.go
new file mode 100644
index 00000000000..d8fc0d44698
--- /dev/null
+++ b/workhorse/internal/helper/tempfile.go
@@ -0,0 +1,35 @@
+package helper
+
+import (
+ "io"
+ "io/ioutil"
+ "os"
+)
+
+func ReadAllTempfile(r io.Reader) (tempfile *os.File, err error) {
+ tempfile, err = ioutil.TempFile("", "gitlab-workhorse-read-all-tempfile")
+ if err != nil {
+ return nil, err
+ }
+
+ defer func() {
+ // Avoid leaking an open file if the function returns with an error
+ if err != nil {
+ tempfile.Close()
+ }
+ }()
+
+ if err := os.Remove(tempfile.Name()); err != nil {
+ return nil, err
+ }
+
+ if _, err := io.Copy(tempfile, r); err != nil {
+ return nil, err
+ }
+
+ if _, err := tempfile.Seek(0, 0); err != nil {
+ return nil, err
+ }
+
+ return tempfile, nil
+}
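
A sketch of what unlinking the tempfile up front buys the caller (illustrative import path and data): the returned *os.File has no name on disk, so it acts as an anonymous, seekable spool whose space is reclaimed as soon as the handle is closed.

package main

import (
	"fmt"
	"io/ioutil"
	"log"
	"strings"

	// Illustrative import path; helper lives under workhorse/internal.
	"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
)

func main() {
	// Spool a stream to disk; the tempfile is already os.Remove'd, so only
	// this handle keeps the data alive.
	tmp, err := helper.ReadAllTempfile(strings.NewReader("spooled request body"))
	if err != nil {
		log.Fatal(err)
	}
	defer tmp.Close() // closing releases the disk space

	// The handle has been rewound to offset 0, ready to be read again.
	data, err := ioutil.ReadAll(tmp)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("%s\n", data)
}
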
diff --git a/workhorse/internal/helper/writeafterreader.go b/workhorse/internal/helper/writeafterreader.go
new file mode 100644
index 00000000000..d583ae4a9b8
--- /dev/null
+++ b/workhorse/internal/helper/writeafterreader.go
@@ -0,0 +1,144 @@
+package helper
+
+import (
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "sync"
+)
+
+type WriteFlusher interface {
+ io.Writer
+ Flush() error
+}
+
+// Couple r and w so that, until r has been drained (i.e. until r.Read()
+// has returned an error), all writes to w are sent to a tempfile first.
+// The caller must call Flush() on the returned WriteFlusher to ensure
+// all data is propagated to w.
+func NewWriteAfterReader(r io.Reader, w io.Writer) (io.Reader, WriteFlusher) {
+ br := &busyReader{Reader: r}
+ return br, &coupledWriter{Writer: w, busyReader: br}
+}
+
+type busyReader struct {
+ io.Reader
+
+ error
+ errorMutex sync.RWMutex
+}
+
+func (r *busyReader) Read(p []byte) (int, error) {
+ if err := r.getError(); err != nil {
+ return 0, err
+ }
+
+ n, err := r.Reader.Read(p)
+ if err != nil {
+ if err != io.EOF {
+ err = fmt.Errorf("busyReader: %v", err)
+ }
+ r.setError(err)
+ }
+ return n, err
+}
+
+func (r *busyReader) IsBusy() bool {
+ return r.getError() == nil
+}
+
+func (r *busyReader) getError() error {
+ r.errorMutex.RLock()
+ defer r.errorMutex.RUnlock()
+ return r.error
+}
+
+func (r *busyReader) setError(err error) {
+ if err == nil {
+ panic("busyReader: attempt to reset error to nil")
+ }
+ r.errorMutex.Lock()
+ defer r.errorMutex.Unlock()
+ r.error = err
+}
+
+type coupledWriter struct {
+ io.Writer
+ *busyReader
+
+ tempfile *os.File
+ tempfileMutex sync.Mutex
+
+ writeError error
+}
+
+func (w *coupledWriter) Write(data []byte) (int, error) {
+ if w.writeError != nil {
+ return 0, w.writeError
+ }
+
+ if w.busyReader.IsBusy() {
+ n, err := w.tempfileWrite(data)
+ if err != nil {
+ w.writeError = fmt.Errorf("coupledWriter: %v", err)
+ }
+ return n, w.writeError
+ }
+
+ if err := w.Flush(); err != nil {
+ w.writeError = fmt.Errorf("coupledWriter: %v", err)
+ return 0, w.writeError
+ }
+
+ return w.Writer.Write(data)
+}
+
+func (w *coupledWriter) Flush() error {
+ w.tempfileMutex.Lock()
+ defer w.tempfileMutex.Unlock()
+
+ tempfile := w.tempfile
+ if tempfile == nil {
+ return nil
+ }
+
+ w.tempfile = nil
+ defer tempfile.Close()
+
+ if _, err := tempfile.Seek(0, 0); err != nil {
+ return err
+ }
+ if _, err := io.Copy(w.Writer, tempfile); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (w *coupledWriter) tempfileWrite(data []byte) (int, error) {
+ w.tempfileMutex.Lock()
+ defer w.tempfileMutex.Unlock()
+
+ if w.tempfile == nil {
+ tempfile, err := w.newTempfile()
+ if err != nil {
+ return 0, err
+ }
+ w.tempfile = tempfile
+ }
+
+ return w.tempfile.Write(data)
+}
+
+func (*coupledWriter) newTempfile() (tempfile *os.File, err error) {
+ tempfile, err = ioutil.TempFile("", "gitlab-workhorse-coupledWriter")
+ if err != nil {
+ return nil, err
+ }
+ if err := os.Remove(tempfile.Name()); err != nil {
+ tempfile.Close()
+ return nil, err
+ }
+
+ return tempfile, nil
+}
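
A sketch of the coupling described in the comment above NewWriteAfterReader, with illustrative data and import path: writes made while the reader is still busy are spooled to a tempfile, and Flush replays them to the real writer once the reader has been drained.

package main

import (
	"bytes"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"strings"

	// Illustrative import path; helper lives under workhorse/internal.
	"gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
)

func main() {
	upstream := strings.NewReader("request body")
	downstream := &bytes.Buffer{}

	r, w := helper.NewWriteAfterReader(upstream, downstream)

	// The reader has not been drained yet, so this write is spooled to a
	// tempfile instead of reaching downstream.
	fmt.Fprint(w, "early response bytes")

	// Drain the reader; after this, writes may go straight to downstream.
	if _, err := io.Copy(ioutil.Discard, r); err != nil {
		log.Fatal(err)
	}

	// Flush replays the spooled bytes into downstream.
	if err := w.Flush(); err != nil {
		log.Fatal(err)
	}
	fmt.Println(downstream.String()) // "early response bytes"
}
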
diff --git a/workhorse/internal/helper/writeafterreader_test.go b/workhorse/internal/helper/writeafterreader_test.go
new file mode 100644
index 00000000000..67cb3e6e542
--- /dev/null
+++ b/workhorse/internal/helper/writeafterreader_test.go
@@ -0,0 +1,115 @@
+package helper
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "testing"
+ "testing/iotest"
+)
+
+func TestBusyReader(t *testing.T) {
+ testData := "test data"
+ r := testReader(testData)
+ br, _ := NewWriteAfterReader(r, &bytes.Buffer{})
+
+ result, err := ioutil.ReadAll(br)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if string(result) != testData {
+ t.Fatalf("expected %q, got %q", testData, result)
+ }
+}
+
+func TestFirstWriteAfterReadDone(t *testing.T) {
+ writeRecorder := &bytes.Buffer{}
+ br, cw := NewWriteAfterReader(&bytes.Buffer{}, writeRecorder)
+ if _, err := io.Copy(ioutil.Discard, br); err != nil {
+ t.Fatalf("copy from busyreader: %v", err)
+ }
+ testData := "test data"
+ if _, err := io.Copy(cw, testReader(testData)); err != nil {
+ t.Fatalf("copy test data: %v", err)
+ }
+ if err := cw.Flush(); err != nil {
+ t.Fatalf("flush error: %v", err)
+ }
+ if result := writeRecorder.String(); result != testData {
+ t.Fatalf("expected %q, got %q", testData, result)
+ }
+}
+
+func TestWriteDelay(t *testing.T) {
+ writeRecorder := &bytes.Buffer{}
+ w := &complainingWriter{Writer: writeRecorder}
+ br, cw := NewWriteAfterReader(&bytes.Buffer{}, w)
+
+ testData1 := "1 test"
+ if _, err := io.Copy(cw, testReader(testData1)); err != nil {
+ t.Fatalf("error on first copy: %v", err)
+ }
+
+ // Unblock the coupled writer by draining the reader
+ if _, err := io.Copy(ioutil.Discard, br); err != nil {
+ t.Fatalf("copy from busyreader: %v", err)
+ }
+ // Now it is no longer an error if 'w' receives a Write()
+ w.CheerUp()
+
+ testData2 := "2 experiment"
+ if _, err := io.Copy(cw, testReader(testData2)); err != nil {
+ t.Fatalf("error on second copy: %v", err)
+ }
+
+ if err := cw.Flush(); err != nil {
+ t.Fatalf("flush error: %v", err)
+ }
+
+ expected := testData1 + testData2
+ if result := writeRecorder.String(); result != expected {
+ t.Fatalf("total write: expected %q, got %q", expected, result)
+ }
+}
+
+func TestComplainingWriterSanity(t *testing.T) {
+ recorder := &bytes.Buffer{}
+ w := &complainingWriter{Writer: recorder}
+
+ testData := "test data"
+ if _, err := io.Copy(w, testReader(testData)); err == nil {
+ t.Error("error expected, none received")
+ }
+
+ w.CheerUp()
+ if _, err := io.Copy(w, testReader(testData)); err != nil {
+ t.Errorf("copy after CheerUp: %v", err)
+ }
+
+ if result := recorder.String(); result != testData {
+ t.Errorf("expected %q, got %q", testData, result)
+ }
+}
+
+func testReader(data string) io.Reader {
+ return iotest.OneByteReader(bytes.NewBuffer([]byte(data)))
+}
+
+type complainingWriter struct {
+ happy bool
+ io.Writer
+}
+
+func (comp *complainingWriter) Write(data []byte) (int, error) {
+ if comp.happy {
+ return comp.Writer.Write(data)
+ }
+
+ return 0, fmt.Errorf("I am unhappy about you wanting to write %q", data)
+}
+
+func (comp *complainingWriter) CheerUp() {
+ comp.happy = true
+}