diff options
Diffstat (limited to 'workhorse/internal/helper')
-rw-r--r-- | workhorse/internal/helper/context_reader.go | 40 | ||||
-rw-r--r-- | workhorse/internal/helper/context_reader_test.go | 83 | ||||
-rw-r--r-- | workhorse/internal/helper/countingresponsewriter.go | 56 | ||||
-rw-r--r-- | workhorse/internal/helper/countingresponsewriter_test.go | 50 | ||||
-rw-r--r-- | workhorse/internal/helper/helpers.go | 217 | ||||
-rw-r--r-- | workhorse/internal/helper/helpers_test.go | 258 | ||||
-rw-r--r-- | workhorse/internal/helper/raven.go | 58 | ||||
-rw-r--r-- | workhorse/internal/helper/tempfile.go | 35 | ||||
-rw-r--r-- | workhorse/internal/helper/writeafterreader.go | 144 | ||||
-rw-r--r-- | workhorse/internal/helper/writeafterreader_test.go | 115 |
10 files changed, 1056 insertions, 0 deletions
diff --git a/workhorse/internal/helper/context_reader.go b/workhorse/internal/helper/context_reader.go new file mode 100644 index 00000000000..a4764043147 --- /dev/null +++ b/workhorse/internal/helper/context_reader.go @@ -0,0 +1,40 @@ +package helper + +import ( + "context" + "io" +) + +type ContextReader struct { + ctx context.Context + underlyingReader io.Reader +} + +func NewContextReader(ctx context.Context, underlyingReader io.Reader) *ContextReader { + return &ContextReader{ + ctx: ctx, + underlyingReader: underlyingReader, + } +} + +func (r *ContextReader) Read(b []byte) (int, error) { + if r.canceled() { + return 0, r.err() + } + + n, err := r.underlyingReader.Read(b) + + if r.canceled() { + err = r.err() + } + + return n, err +} + +func (r *ContextReader) canceled() bool { + return r.err() != nil +} + +func (r *ContextReader) err() error { + return r.ctx.Err() +} diff --git a/workhorse/internal/helper/context_reader_test.go b/workhorse/internal/helper/context_reader_test.go new file mode 100644 index 00000000000..257ec4e35f2 --- /dev/null +++ b/workhorse/internal/helper/context_reader_test.go @@ -0,0 +1,83 @@ +package helper + +import ( + "context" + "io" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type fakeReader struct { + n int + err error +} + +func (f *fakeReader) Read(b []byte) (int, error) { + return f.n, f.err +} + +type fakeContextWithTimeout struct { + n int + threshold int +} + +func (*fakeContextWithTimeout) Deadline() (deadline time.Time, ok bool) { + return +} + +func (*fakeContextWithTimeout) Done() <-chan struct{} { + return nil +} + +func (*fakeContextWithTimeout) Value(key interface{}) interface{} { + return nil +} + +func (f *fakeContextWithTimeout) Err() error { + f.n++ + if f.n > f.threshold { + return context.DeadlineExceeded + } + + return nil +} + +func TestContextReaderRead(t *testing.T) { + underlyingReader := &fakeReader{n: 1, err: io.EOF} + + for _, tc := range []struct { + desc string + ctx *fakeContextWithTimeout + expectedN int + expectedErr error + }{ + { + desc: "Before and after read deadline checks are fine", + ctx: &fakeContextWithTimeout{n: 0, threshold: 2}, + expectedN: underlyingReader.n, + expectedErr: underlyingReader.err, + }, + { + desc: "Before read deadline check fails", + ctx: &fakeContextWithTimeout{n: 0, threshold: 0}, + expectedN: 0, + expectedErr: context.DeadlineExceeded, + }, + { + desc: "After read deadline check fails", + ctx: &fakeContextWithTimeout{n: 0, threshold: 1}, + expectedN: underlyingReader.n, + expectedErr: context.DeadlineExceeded, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + cr := NewContextReader(tc.ctx, underlyingReader) + + n, err := cr.Read(nil) + require.Equal(t, tc.expectedN, n) + require.Equal(t, tc.expectedErr, err) + }) + } +} diff --git a/workhorse/internal/helper/countingresponsewriter.go b/workhorse/internal/helper/countingresponsewriter.go new file mode 100644 index 00000000000..a79d51d4c6a --- /dev/null +++ b/workhorse/internal/helper/countingresponsewriter.go @@ -0,0 +1,56 @@ +package helper + +import ( + "net/http" +) + +type CountingResponseWriter interface { + http.ResponseWriter + Count() int64 + Status() int +} + +type countingResponseWriter struct { + rw http.ResponseWriter + status int + count int64 +} + +func NewCountingResponseWriter(rw http.ResponseWriter) CountingResponseWriter { + return &countingResponseWriter{rw: rw} +} + +func (c *countingResponseWriter) Header() http.Header { + return c.rw.Header() +} + +func (c *countingResponseWriter) Write(data []byte) (int, error) { + if c.status == 0 { + c.WriteHeader(http.StatusOK) + } + + n, err := c.rw.Write(data) + c.count += int64(n) + return n, err +} + +func (c *countingResponseWriter) WriteHeader(status int) { + if c.status != 0 { + return + } + + c.status = status + c.rw.WriteHeader(status) +} + +// Count returns the number of bytes written to the ResponseWriter. This +// function is not thread-safe. +func (c *countingResponseWriter) Count() int64 { + return c.count +} + +// Status returns the first HTTP status value that was written to the +// ResponseWriter. This function is not thread-safe. +func (c *countingResponseWriter) Status() int { + return c.status +} diff --git a/workhorse/internal/helper/countingresponsewriter_test.go b/workhorse/internal/helper/countingresponsewriter_test.go new file mode 100644 index 00000000000..f9f2f4ced5b --- /dev/null +++ b/workhorse/internal/helper/countingresponsewriter_test.go @@ -0,0 +1,50 @@ +package helper + +import ( + "bytes" + "io" + "net/http" + "testing" + "testing/iotest" + + "github.com/stretchr/testify/require" +) + +type testResponseWriter struct { + data []byte +} + +func (*testResponseWriter) WriteHeader(int) {} +func (*testResponseWriter) Header() http.Header { return nil } + +func (trw *testResponseWriter) Write(p []byte) (int, error) { + trw.data = append(trw.data, p...) + return len(p), nil +} + +func TestCountingResponseWriterStatus(t *testing.T) { + crw := NewCountingResponseWriter(&testResponseWriter{}) + crw.WriteHeader(123) + crw.WriteHeader(456) + require.Equal(t, 123, crw.Status()) +} + +func TestCountingResponseWriterCount(t *testing.T) { + crw := NewCountingResponseWriter(&testResponseWriter{}) + for _, n := range []int{1, 2, 4, 8, 16, 32} { + _, err := crw.Write(bytes.Repeat([]byte{'.'}, n)) + require.NoError(t, err) + } + require.Equal(t, int64(63), crw.Count()) +} + +func TestCountingResponseWriterWrite(t *testing.T) { + trw := &testResponseWriter{} + crw := NewCountingResponseWriter(trw) + + testData := []byte("test data") + _, err := io.Copy(crw, iotest.OneByteReader(bytes.NewReader(testData))) + require.NoError(t, err) + + require.Equal(t, string(testData), string(trw.data)) +} diff --git a/workhorse/internal/helper/helpers.go b/workhorse/internal/helper/helpers.go new file mode 100644 index 00000000000..5f1e5fc51b3 --- /dev/null +++ b/workhorse/internal/helper/helpers.go @@ -0,0 +1,217 @@ +package helper + +import ( + "bytes" + "errors" + "io/ioutil" + "mime" + "net" + "net/http" + "net/url" + "os" + "os/exec" + "strings" + "syscall" + + "github.com/sebest/xff" + "gitlab.com/gitlab-org/labkit/log" + "gitlab.com/gitlab-org/labkit/mask" +) + +const NginxResponseBufferHeader = "X-Accel-Buffering" + +func LogError(r *http.Request, err error) { + LogErrorWithFields(r, err, nil) +} + +func LogErrorWithFields(r *http.Request, err error, fields log.Fields) { + if err != nil { + captureRavenError(r, err, fields) + } + + printError(r, err, fields) +} + +func CaptureAndFail(w http.ResponseWriter, r *http.Request, err error, msg string, code int) { + http.Error(w, msg, code) + LogError(r, err) +} + +func CaptureAndFailWithFields(w http.ResponseWriter, r *http.Request, err error, msg string, code int, fields log.Fields) { + http.Error(w, msg, code) + LogErrorWithFields(r, err, fields) +} + +func Fail500(w http.ResponseWriter, r *http.Request, err error) { + CaptureAndFail(w, r, err, "Internal server error", http.StatusInternalServerError) +} + +func Fail500WithFields(w http.ResponseWriter, r *http.Request, err error, fields log.Fields) { + CaptureAndFailWithFields(w, r, err, "Internal server error", http.StatusInternalServerError, fields) +} + +func RequestEntityTooLarge(w http.ResponseWriter, r *http.Request, err error) { + CaptureAndFail(w, r, err, "Request Entity Too Large", http.StatusRequestEntityTooLarge) +} + +func printError(r *http.Request, err error, fields log.Fields) { + if r != nil { + entry := log.WithContextFields(r.Context(), log.Fields{ + "method": r.Method, + "uri": mask.URL(r.RequestURI), + }) + entry.WithFields(fields).WithError(err).Error("error") + } else { + log.WithFields(fields).WithError(err).Error("unknown error") + } +} + +func SetNoCacheHeaders(header http.Header) { + header.Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate") + header.Set("Pragma", "no-cache") + header.Set("Expires", "Fri, 01 Jan 1990 00:00:00 GMT") +} + +func OpenFile(path string) (file *os.File, fi os.FileInfo, err error) { + file, err = os.Open(path) + if err != nil { + return + } + + defer func() { + if err != nil { + file.Close() + } + }() + + fi, err = file.Stat() + if err != nil { + return + } + + // The os.Open can also open directories + if fi.IsDir() { + err = &os.PathError{ + Op: "open", + Path: path, + Err: errors.New("path is directory"), + } + return + } + + return +} + +func URLMustParse(s string) *url.URL { + u, err := url.Parse(s) + if err != nil { + log.WithError(err).WithField("url", s).Fatal("urlMustParse") + } + return u +} + +func HTTPError(w http.ResponseWriter, r *http.Request, error string, code int) { + if r.ProtoAtLeast(1, 1) { + // Force client to disconnect if we render request error + w.Header().Set("Connection", "close") + } + + http.Error(w, error, code) +} + +func HeaderClone(h http.Header) http.Header { + h2 := make(http.Header, len(h)) + for k, vv := range h { + vv2 := make([]string, len(vv)) + copy(vv2, vv) + h2[k] = vv2 + } + return h2 +} + +func CleanUpProcessGroup(cmd *exec.Cmd) { + if cmd == nil { + return + } + + process := cmd.Process + if process != nil && process.Pid > 0 { + // Send SIGTERM to the process group of cmd + syscall.Kill(-process.Pid, syscall.SIGTERM) + } + + // reap our child process + cmd.Wait() +} + +func ExitStatus(err error) (int, bool) { + exitError, ok := err.(*exec.ExitError) + if !ok { + return 0, false + } + + waitStatus, ok := exitError.Sys().(syscall.WaitStatus) + if !ok { + return 0, false + } + + return waitStatus.ExitStatus(), true +} + +func DisableResponseBuffering(w http.ResponseWriter) { + w.Header().Set(NginxResponseBufferHeader, "no") +} + +func AllowResponseBuffering(w http.ResponseWriter) { + w.Header().Del(NginxResponseBufferHeader) +} + +func FixRemoteAddr(r *http.Request) { + // Unix domain sockets have a remote addr of @. This will make the + // xff package lookup the X-Forwarded-For address if available. + if r.RemoteAddr == "@" { + r.RemoteAddr = "127.0.0.1:0" + } + r.RemoteAddr = xff.GetRemoteAddr(r) +} + +func SetForwardedFor(newHeaders *http.Header, originalRequest *http.Request) { + if clientIP, _, err := net.SplitHostPort(originalRequest.RemoteAddr); err == nil { + var header string + + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. + if prior, ok := originalRequest.Header["X-Forwarded-For"]; ok { + header = strings.Join(prior, ", ") + ", " + clientIP + } else { + header = clientIP + } + newHeaders.Set("X-Forwarded-For", header) + } +} + +func IsContentType(expected, actual string) bool { + parsed, _, err := mime.ParseMediaType(actual) + return err == nil && parsed == expected +} + +func IsApplicationJson(r *http.Request) bool { + contentType := r.Header.Get("Content-Type") + return IsContentType("application/json", contentType) +} + +func ReadRequestBody(w http.ResponseWriter, r *http.Request, maxBodySize int64) ([]byte, error) { + limitedBody := http.MaxBytesReader(w, r.Body, maxBodySize) + defer limitedBody.Close() + + return ioutil.ReadAll(limitedBody) +} + +func CloneRequestWithNewBody(r *http.Request, body []byte) *http.Request { + newReq := *r + newReq.Body = ioutil.NopCloser(bytes.NewReader(body)) + newReq.Header = HeaderClone(r.Header) + newReq.ContentLength = int64(len(body)) + return &newReq +} diff --git a/workhorse/internal/helper/helpers_test.go b/workhorse/internal/helper/helpers_test.go new file mode 100644 index 00000000000..6a895aded03 --- /dev/null +++ b/workhorse/internal/helper/helpers_test.go @@ -0,0 +1,258 @@ +package helper + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" +) + +func TestFixRemoteAddr(t *testing.T) { + testCases := []struct { + initial string + forwarded string + expected string + }{ + {initial: "@", forwarded: "", expected: "127.0.0.1:0"}, + {initial: "@", forwarded: "18.245.0.1", expected: "18.245.0.1:0"}, + {initial: "@", forwarded: "127.0.0.1", expected: "127.0.0.1:0"}, + {initial: "@", forwarded: "192.168.0.1", expected: "127.0.0.1:0"}, + {initial: "192.168.1.1:0", forwarded: "", expected: "192.168.1.1:0"}, + {initial: "192.168.1.1:0", forwarded: "18.245.0.1", expected: "18.245.0.1:0"}, + } + + for _, tc := range testCases { + req, err := http.NewRequest("POST", "unix:///tmp/test.socket/info/refs", nil) + require.NoError(t, err) + + req.RemoteAddr = tc.initial + + if tc.forwarded != "" { + req.Header.Add("X-Forwarded-For", tc.forwarded) + } + + FixRemoteAddr(req) + + require.Equal(t, tc.expected, req.RemoteAddr) + } +} + +func TestSetForwardedForGeneratesHeader(t *testing.T) { + testCases := []struct { + remoteAddr string + previousForwardedFor []string + expected string + }{ + { + "8.8.8.8:3000", + nil, + "8.8.8.8", + }, + { + "8.8.8.8:3000", + []string{"138.124.33.63, 151.146.211.237"}, + "138.124.33.63, 151.146.211.237, 8.8.8.8", + }, + { + "8.8.8.8:3000", + []string{"8.154.76.107", "115.206.118.179"}, + "8.154.76.107, 115.206.118.179, 8.8.8.8", + }, + } + for _, tc := range testCases { + headers := http.Header{} + originalRequest := http.Request{ + RemoteAddr: tc.remoteAddr, + } + + if tc.previousForwardedFor != nil { + originalRequest.Header = http.Header{ + "X-Forwarded-For": tc.previousForwardedFor, + } + } + + SetForwardedFor(&headers, &originalRequest) + + result := headers.Get("X-Forwarded-For") + if result != tc.expected { + t.Fatalf("Expected %v, got %v", tc.expected, result) + } + } +} + +func TestReadRequestBody(t *testing.T) { + data := []byte("123456") + rw := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(data)) + + result, err := ReadRequestBody(rw, req, 1000) + require.NoError(t, err) + require.Equal(t, data, result) +} + +func TestReadRequestBodyLimit(t *testing.T) { + data := []byte("123456") + rw := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(data)) + + _, err := ReadRequestBody(rw, req, 2) + require.Error(t, err) +} + +func TestCloneRequestWithBody(t *testing.T) { + input := []byte("test") + newInput := []byte("new body") + req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(input)) + newReq := CloneRequestWithNewBody(req, newInput) + + require.NotEqual(t, req, newReq) + require.NotEqual(t, req.Body, newReq.Body) + require.NotEqual(t, len(newInput), newReq.ContentLength) + + var buffer bytes.Buffer + io.Copy(&buffer, newReq.Body) + require.Equal(t, newInput, buffer.Bytes()) +} + +func TestApplicationJson(t *testing.T) { + req, _ := http.NewRequest("POST", "/test", nil) + req.Header.Set("Content-Type", "application/json") + + require.True(t, IsApplicationJson(req), "expected to match 'application/json' as 'application/json'") + + req.Header.Set("Content-Type", "application/json; charset=utf-8") + require.True(t, IsApplicationJson(req), "expected to match 'application/json; charset=utf-8' as 'application/json'") + + req.Header.Set("Content-Type", "text/plain") + require.False(t, IsApplicationJson(req), "expected not to match 'text/plain' as 'application/json'") +} + +func TestFail500WorksWithNils(t *testing.T) { + body := bytes.NewBuffer(nil) + w := httptest.NewRecorder() + w.Body = body + + Fail500(w, nil, nil) + + require.Equal(t, http.StatusInternalServerError, w.Code) + require.Equal(t, "Internal server error\n", body.String()) +} + +func TestLogError(t *testing.T) { + tests := []struct { + name string + method string + uri string + err error + logMatchers []string + }{ + { + name: "nil_request", + err: fmt.Errorf("crash"), + logMatchers: []string{ + `level=error msg="unknown error" error=crash`, + }, + }, + { + name: "nil_request_nil_error", + err: nil, + logMatchers: []string{ + `level=error msg="unknown error" error="<nil>"`, + }, + }, + { + name: "basic_url", + method: "GET", + uri: "http://localhost:3000/", + err: fmt.Errorf("error"), + logMatchers: []string{ + `level=error msg=error correlation_id= error=error method=GET uri="http://localhost:3000/"`, + }, + }, + { + name: "secret_url", + method: "GET", + uri: "http://localhost:3000/path?certificate=123&sharedSecret=123&import_url=the_url&my_password_string=password", + err: fmt.Errorf("error"), + logMatchers: []string{ + `level=error msg=error correlation_id= error=error method=GET uri="http://localhost:3000/path\?certificate=\[FILTERED\]&sharedSecret=\[FILTERED\]&import_url=\[FILTERED\]&my_password_string=\[FILTERED\]"`, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := &bytes.Buffer{} + + oldOut := logrus.StandardLogger().Out + logrus.StandardLogger().Out = buf + defer func() { + logrus.StandardLogger().Out = oldOut + }() + + var r *http.Request + if tt.uri != "" { + r = httptest.NewRequest(tt.method, tt.uri, nil) + } + + LogError(r, tt.err) + + logString := buf.String() + for _, v := range tt.logMatchers { + require.Regexp(t, v, logString) + } + }) + } +} + +func TestLogErrorWithFields(t *testing.T) { + tests := []struct { + name string + request *http.Request + err error + fields map[string]interface{} + logMatcher string + }{ + { + name: "nil_request", + err: fmt.Errorf("crash"), + fields: map[string]interface{}{"extra_one": 123}, + logMatcher: `level=error msg="unknown error" error=crash extra_one=123`, + }, + { + name: "nil_request_nil_error", + err: nil, + fields: map[string]interface{}{"extra_one": 123, "extra_two": "test"}, + logMatcher: `level=error msg="unknown error" error="<nil>" extra_one=123 extra_two=test`, + }, + { + name: "basic_url", + request: httptest.NewRequest("GET", "http://localhost:3000/", nil), + err: fmt.Errorf("error"), + fields: map[string]interface{}{"extra_one": 123, "extra_two": "test"}, + logMatcher: `level=error msg=error correlation_id= error=error extra_one=123 extra_two=test method=GET uri="http://localhost:3000/`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := &bytes.Buffer{} + + oldOut := logrus.StandardLogger().Out + logrus.StandardLogger().Out = buf + defer func() { + logrus.StandardLogger().Out = oldOut + }() + + LogErrorWithFields(tt.request, tt.err, tt.fields) + + logString := buf.String() + require.Contains(t, logString, tt.logMatcher) + }) + } +} diff --git a/workhorse/internal/helper/raven.go b/workhorse/internal/helper/raven.go new file mode 100644 index 00000000000..ea1d0e1f6cc --- /dev/null +++ b/workhorse/internal/helper/raven.go @@ -0,0 +1,58 @@ +package helper + +import ( + "net/http" + "reflect" + + raven "github.com/getsentry/raven-go" + + //lint:ignore SA1019 this was recently deprecated. Update workhorse to use labkit errortracking package. + correlation "gitlab.com/gitlab-org/labkit/correlation/raven" + + "gitlab.com/gitlab-org/labkit/log" +) + +var ravenHeaderBlacklist = []string{ + "Authorization", + "Private-Token", +} + +func captureRavenError(r *http.Request, err error, fields log.Fields) { + client := raven.DefaultClient + extra := raven.Extra{} + + for k, v := range fields { + extra[k] = v + } + + interfaces := []raven.Interface{} + if r != nil { + CleanHeadersForRaven(r) + interfaces = append(interfaces, raven.NewHttp(r)) + + //lint:ignore SA1019 this was recently deprecated. Update workhorse to use labkit errortracking package. + extra = correlation.SetExtra(r.Context(), extra) + } + + exception := &raven.Exception{ + Stacktrace: raven.NewStacktrace(2, 3, nil), + Value: err.Error(), + Type: reflect.TypeOf(err).String(), + } + interfaces = append(interfaces, exception) + + packet := raven.NewPacketWithExtra(err.Error(), extra, interfaces...) + client.Capture(packet, nil) +} + +func CleanHeadersForRaven(r *http.Request) { + if r == nil { + return + } + + for _, key := range ravenHeaderBlacklist { + if r.Header.Get(key) != "" { + r.Header.Set(key, "[redacted]") + } + } +} diff --git a/workhorse/internal/helper/tempfile.go b/workhorse/internal/helper/tempfile.go new file mode 100644 index 00000000000..d8fc0d44698 --- /dev/null +++ b/workhorse/internal/helper/tempfile.go @@ -0,0 +1,35 @@ +package helper + +import ( + "io" + "io/ioutil" + "os" +) + +func ReadAllTempfile(r io.Reader) (tempfile *os.File, err error) { + tempfile, err = ioutil.TempFile("", "gitlab-workhorse-read-all-tempfile") + if err != nil { + return nil, err + } + + defer func() { + // Avoid leaking an open file if the function returns with an error + if err != nil { + tempfile.Close() + } + }() + + if err := os.Remove(tempfile.Name()); err != nil { + return nil, err + } + + if _, err := io.Copy(tempfile, r); err != nil { + return nil, err + } + + if _, err := tempfile.Seek(0, 0); err != nil { + return nil, err + } + + return tempfile, nil +} diff --git a/workhorse/internal/helper/writeafterreader.go b/workhorse/internal/helper/writeafterreader.go new file mode 100644 index 00000000000..d583ae4a9b8 --- /dev/null +++ b/workhorse/internal/helper/writeafterreader.go @@ -0,0 +1,144 @@ +package helper + +import ( + "fmt" + "io" + "io/ioutil" + "os" + "sync" +) + +type WriteFlusher interface { + io.Writer + Flush() error +} + +// Couple r and w so that until r has been drained (before r.Read() has +// returned some error), all writes to w are sent to a tempfile first. +// The caller must call Flush() on the returned WriteFlusher to ensure +// all data is propagated to w. +func NewWriteAfterReader(r io.Reader, w io.Writer) (io.Reader, WriteFlusher) { + br := &busyReader{Reader: r} + return br, &coupledWriter{Writer: w, busyReader: br} +} + +type busyReader struct { + io.Reader + + error + errorMutex sync.RWMutex +} + +func (r *busyReader) Read(p []byte) (int, error) { + if err := r.getError(); err != nil { + return 0, err + } + + n, err := r.Reader.Read(p) + if err != nil { + if err != io.EOF { + err = fmt.Errorf("busyReader: %v", err) + } + r.setError(err) + } + return n, err +} + +func (r *busyReader) IsBusy() bool { + return r.getError() == nil +} + +func (r *busyReader) getError() error { + r.errorMutex.RLock() + defer r.errorMutex.RUnlock() + return r.error +} + +func (r *busyReader) setError(err error) { + if err == nil { + panic("busyReader: attempt to reset error to nil") + } + r.errorMutex.Lock() + defer r.errorMutex.Unlock() + r.error = err +} + +type coupledWriter struct { + io.Writer + *busyReader + + tempfile *os.File + tempfileMutex sync.Mutex + + writeError error +} + +func (w *coupledWriter) Write(data []byte) (int, error) { + if w.writeError != nil { + return 0, w.writeError + } + + if w.busyReader.IsBusy() { + n, err := w.tempfileWrite(data) + if err != nil { + w.writeError = fmt.Errorf("coupledWriter: %v", err) + } + return n, w.writeError + } + + if err := w.Flush(); err != nil { + w.writeError = fmt.Errorf("coupledWriter: %v", err) + return 0, w.writeError + } + + return w.Writer.Write(data) +} + +func (w *coupledWriter) Flush() error { + w.tempfileMutex.Lock() + defer w.tempfileMutex.Unlock() + + tempfile := w.tempfile + if tempfile == nil { + return nil + } + + w.tempfile = nil + defer tempfile.Close() + + if _, err := tempfile.Seek(0, 0); err != nil { + return err + } + if _, err := io.Copy(w.Writer, tempfile); err != nil { + return err + } + return nil +} + +func (w *coupledWriter) tempfileWrite(data []byte) (int, error) { + w.tempfileMutex.Lock() + defer w.tempfileMutex.Unlock() + + if w.tempfile == nil { + tempfile, err := w.newTempfile() + if err != nil { + return 0, err + } + w.tempfile = tempfile + } + + return w.tempfile.Write(data) +} + +func (*coupledWriter) newTempfile() (tempfile *os.File, err error) { + tempfile, err = ioutil.TempFile("", "gitlab-workhorse-coupledWriter") + if err != nil { + return nil, err + } + if err := os.Remove(tempfile.Name()); err != nil { + tempfile.Close() + return nil, err + } + + return tempfile, nil +} diff --git a/workhorse/internal/helper/writeafterreader_test.go b/workhorse/internal/helper/writeafterreader_test.go new file mode 100644 index 00000000000..67cb3e6e542 --- /dev/null +++ b/workhorse/internal/helper/writeafterreader_test.go @@ -0,0 +1,115 @@ +package helper + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "testing" + "testing/iotest" +) + +func TestBusyReader(t *testing.T) { + testData := "test data" + r := testReader(testData) + br, _ := NewWriteAfterReader(r, &bytes.Buffer{}) + + result, err := ioutil.ReadAll(br) + if err != nil { + t.Fatal(err) + } + + if string(result) != testData { + t.Fatalf("expected %q, got %q", testData, result) + } +} + +func TestFirstWriteAfterReadDone(t *testing.T) { + writeRecorder := &bytes.Buffer{} + br, cw := NewWriteAfterReader(&bytes.Buffer{}, writeRecorder) + if _, err := io.Copy(ioutil.Discard, br); err != nil { + t.Fatalf("copy from busyreader: %v", err) + } + testData := "test data" + if _, err := io.Copy(cw, testReader(testData)); err != nil { + t.Fatalf("copy test data: %v", err) + } + if err := cw.Flush(); err != nil { + t.Fatalf("flush error: %v", err) + } + if result := writeRecorder.String(); result != testData { + t.Fatalf("expected %q, got %q", testData, result) + } +} + +func TestWriteDelay(t *testing.T) { + writeRecorder := &bytes.Buffer{} + w := &complainingWriter{Writer: writeRecorder} + br, cw := NewWriteAfterReader(&bytes.Buffer{}, w) + + testData1 := "1 test" + if _, err := io.Copy(cw, testReader(testData1)); err != nil { + t.Fatalf("error on first copy: %v", err) + } + + // Unblock the coupled writer by draining the reader + if _, err := io.Copy(ioutil.Discard, br); err != nil { + t.Fatalf("copy from busyreader: %v", err) + } + // Now it is no longer an error if 'w' receives a Write() + w.CheerUp() + + testData2 := "2 experiment" + if _, err := io.Copy(cw, testReader(testData2)); err != nil { + t.Fatalf("error on second copy: %v", err) + } + + if err := cw.Flush(); err != nil { + t.Fatalf("flush error: %v", err) + } + + expected := testData1 + testData2 + if result := writeRecorder.String(); result != expected { + t.Fatalf("total write: expected %q, got %q", expected, result) + } +} + +func TestComplainingWriterSanity(t *testing.T) { + recorder := &bytes.Buffer{} + w := &complainingWriter{Writer: recorder} + + testData := "test data" + if _, err := io.Copy(w, testReader(testData)); err == nil { + t.Error("error expected, none received") + } + + w.CheerUp() + if _, err := io.Copy(w, testReader(testData)); err != nil { + t.Errorf("copy after CheerUp: %v", err) + } + + if result := recorder.String(); result != testData { + t.Errorf("expected %q, got %q", testData, result) + } +} + +func testReader(data string) io.Reader { + return iotest.OneByteReader(bytes.NewBuffer([]byte(data))) +} + +type complainingWriter struct { + happy bool + io.Writer +} + +func (comp *complainingWriter) Write(data []byte) (int, error) { + if comp.happy { + return comp.Writer.Write(data) + } + + return 0, fmt.Errorf("I am unhappy about you wanting to write %q", data) +} + +func (comp *complainingWriter) CheerUp() { + comp.happy = true +} |