diff options
Diffstat (limited to 'workhorse/internal/senddata')
-rw-r--r-- | workhorse/internal/senddata/contentprocessor/contentprocessor.go | 126 | ||||
-rw-r--r-- | workhorse/internal/senddata/contentprocessor/contentprocessor_test.go | 293 | ||||
-rw-r--r-- | workhorse/internal/senddata/injecter.go | 35 | ||||
-rw-r--r-- | workhorse/internal/senddata/senddata.go | 105 | ||||
-rw-r--r-- | workhorse/internal/senddata/writer_test.go | 71 |
5 files changed, 630 insertions, 0 deletions
diff --git a/workhorse/internal/senddata/contentprocessor/contentprocessor.go b/workhorse/internal/senddata/contentprocessor/contentprocessor.go new file mode 100644 index 00000000000..a5cc0fee013 --- /dev/null +++ b/workhorse/internal/senddata/contentprocessor/contentprocessor.go @@ -0,0 +1,126 @@ +package contentprocessor + +import ( + "bytes" + "io" + "net/http" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/headers" +) + +type contentDisposition struct { + rw http.ResponseWriter + buf *bytes.Buffer + wroteHeader bool + flushed bool + active bool + removedResponseHeaders bool + status int + sentStatus bool +} + +// SetContentHeaders buffers the response if Gitlab-Workhorse-Detect-Content-Type +// header is found and set the proper content headers based on the current +// value of content type and disposition +func SetContentHeaders(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cd := &contentDisposition{ + rw: w, + buf: &bytes.Buffer{}, + status: http.StatusOK, + } + + defer cd.flush() + + h.ServeHTTP(cd, r) + }) +} + +func (cd *contentDisposition) Header() http.Header { + return cd.rw.Header() +} + +func (cd *contentDisposition) Write(data []byte) (int, error) { + // Normal write if we don't need to buffer + if cd.isUnbuffered() { + cd.WriteHeader(cd.status) + return cd.rw.Write(data) + } + + // Write the new data into the buffer + n, _ := cd.buf.Write(data) + + // If we have enough data to calculate the content headers then flush the Buffer + var err error + if cd.buf.Len() >= headers.MaxDetectSize { + err = cd.flushBuffer() + } + + return n, err +} + +func (cd *contentDisposition) flushBuffer() error { + if cd.isUnbuffered() { + return nil + } + + cd.flushed = true + + // If the buffer has any content then we calculate the content headers and + // write in the response + if cd.buf.Len() > 0 { + cd.writeContentHeaders() + cd.WriteHeader(cd.status) + _, err := io.Copy(cd.rw, cd.buf) + return err + } + + // If no content is present in the buffer we still need to send the headers + cd.WriteHeader(cd.status) + return nil +} + +func (cd *contentDisposition) writeContentHeaders() { + if cd.wroteHeader { + return + } + + cd.wroteHeader = true + contentType, contentDisposition := headers.SafeContentHeaders(cd.buf.Bytes(), cd.Header().Get(headers.ContentDispositionHeader)) + cd.Header().Set(headers.ContentTypeHeader, contentType) + cd.Header().Set(headers.ContentDispositionHeader, contentDisposition) +} + +func (cd *contentDisposition) WriteHeader(status int) { + if cd.sentStatus { + return + } + + cd.status = status + + if cd.isUnbuffered() { + cd.rw.WriteHeader(cd.status) + cd.sentStatus = true + } +} + +// If we find any response header, then we must calculate the content headers +// If we don't find any, the data is not buffered and it works as +// a usual ResponseWriter +func (cd *contentDisposition) isUnbuffered() bool { + if !cd.removedResponseHeaders { + if headers.IsDetectContentTypeHeaderPresent(cd.rw) { + cd.active = true + } + + cd.removedResponseHeaders = true + // We ensure to clear any response header from the response + headers.RemoveResponseHeaders(cd.rw) + } + + return cd.flushed || !cd.active +} + +func (cd *contentDisposition) flush() { + cd.flushBuffer() +} diff --git a/workhorse/internal/senddata/contentprocessor/contentprocessor_test.go b/workhorse/internal/senddata/contentprocessor/contentprocessor_test.go new file mode 100644 index 00000000000..5e3a74f04f9 --- /dev/null +++ b/workhorse/internal/senddata/contentprocessor/contentprocessor_test.go @@ -0,0 +1,293 @@ +package contentprocessor + +import ( + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/headers" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" + + "github.com/stretchr/testify/require" +) + +func TestFailSetContentTypeAndDisposition(t *testing.T) { + testCaseBody := "Hello world!" + + h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, err := io.WriteString(w, testCaseBody) + require.NoError(t, err) + }) + + resp := makeRequest(t, h, testCaseBody, "") + + require.Equal(t, "", resp.Header.Get(headers.ContentDispositionHeader)) + require.Equal(t, "", resp.Header.Get(headers.ContentTypeHeader)) +} + +func TestSuccessSetContentTypeAndDispositionFeatureEnabled(t *testing.T) { + testCaseBody := "Hello world!" + + resp := makeRequest(t, nil, testCaseBody, "") + + require.Equal(t, "inline", resp.Header.Get(headers.ContentDispositionHeader)) + require.Equal(t, "text/plain; charset=utf-8", resp.Header.Get(headers.ContentTypeHeader)) +} + +func TestSetProperContentTypeAndDisposition(t *testing.T) { + testCases := []struct { + desc string + contentType string + contentDisposition string + body string + }{ + { + desc: "text type", + contentType: "text/plain; charset=utf-8", + contentDisposition: "inline", + body: "Hello world!", + }, + { + desc: "HTML type", + contentType: "text/plain; charset=utf-8", + contentDisposition: "inline", + body: "<html><body>Hello world!</body></html>", + }, + { + desc: "Javascript type", + contentType: "text/plain; charset=utf-8", + contentDisposition: "inline", + body: "<script>alert(\"foo\")</script>", + }, + { + desc: "Image type", + contentType: "image/png", + contentDisposition: "inline", + body: testhelper.LoadFile(t, "testdata/image.png"), + }, + { + desc: "SVG type", + contentType: "image/svg+xml", + contentDisposition: "attachment", + body: testhelper.LoadFile(t, "testdata/image.svg"), + }, + { + desc: "Partial SVG type", + contentType: "image/svg+xml", + contentDisposition: "attachment", + body: "<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" viewBox=\"0 0 330 82\"><title>SVG logo combined with the W3C logo, set horizontally</title><desc>The logo combines three entities displayed horizontall</desc><metadata>", + }, + { + desc: "Application type", + contentType: "application/pdf", + contentDisposition: "attachment", + body: testhelper.LoadFile(t, "testdata/file.pdf"), + }, + { + desc: "Application type pdf with inline disposition", + contentType: "application/pdf", + contentDisposition: "inline", + body: testhelper.LoadFile(t, "testdata/file.pdf"), + }, + { + desc: "Application executable type", + contentType: "application/octet-stream", + contentDisposition: "attachment", + body: testhelper.LoadFile(t, "testdata/file.swf"), + }, + { + desc: "Video type", + contentType: "video/mp4", + contentDisposition: "inline", + body: testhelper.LoadFile(t, "testdata/video.mp4"), + }, + { + desc: "Audio type", + contentType: "audio/mpeg", + contentDisposition: "attachment", + body: testhelper.LoadFile(t, "testdata/audio.mp3"), + }, + { + desc: "JSON type", + contentType: "text/plain; charset=utf-8", + contentDisposition: "inline", + body: "{ \"glossary\": { \"title\": \"example glossary\", \"GlossDiv\": { \"title\": \"S\" } } }", + }, + { + desc: "Forged file with png extension but SWF content", + contentType: "application/octet-stream", + contentDisposition: "attachment", + body: testhelper.LoadFile(t, "testdata/forgedfile.png"), + }, + { + desc: "BMPR file", + contentType: "application/octet-stream", + contentDisposition: "attachment", + body: testhelper.LoadFile(t, "testdata/file.bmpr"), + }, + { + desc: "STL file", + contentType: "application/octet-stream", + contentDisposition: "attachment", + body: testhelper.LoadFile(t, "testdata/file.stl"), + }, + { + desc: "RDoc file", + contentType: "text/plain; charset=utf-8", + contentDisposition: "inline", + body: testhelper.LoadFile(t, "testdata/file.rdoc"), + }, + { + desc: "IPYNB file", + contentType: "text/plain; charset=utf-8", + contentDisposition: "inline", + body: testhelper.LoadFile(t, "testdata/file.ipynb"), + }, + { + desc: "Sketch file", + contentType: "application/zip", + contentDisposition: "attachment", + body: testhelper.LoadFile(t, "testdata/file.sketch"), + }, + { + desc: "PDF file with non-ASCII characters in filename", + contentType: "application/pdf", + contentDisposition: `attachment; filename="file-ä.pdf"; filename*=UTF-8''file-%c3.pdf`, + body: testhelper.LoadFile(t, "testdata/file-ä.pdf"), + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + resp := makeRequest(t, nil, tc.body, tc.contentDisposition) + + require.Equal(t, tc.contentType, resp.Header.Get(headers.ContentTypeHeader)) + require.Equal(t, tc.contentDisposition, resp.Header.Get(headers.ContentDispositionHeader)) + }) + } +} + +func TestFailOverrideContentType(t *testing.T) { + testCase := struct { + contentType string + body string + }{ + contentType: "text/plain; charset=utf-8", + body: "<html><body>Hello world!</body></html>", + } + + h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + // We are pretending to be upstream or an inner layer of the ResponseWriter chain + w.Header().Set(headers.GitlabWorkhorseDetectContentTypeHeader, "true") + w.Header().Set(headers.ContentTypeHeader, "text/html; charset=utf-8") + _, err := io.WriteString(w, testCase.body) + require.NoError(t, err) + }) + + resp := makeRequest(t, h, testCase.body, "") + + require.Equal(t, testCase.contentType, resp.Header.Get(headers.ContentTypeHeader)) +} + +func TestSuccessOverrideContentDispositionFromInlineToAttachment(t *testing.T) { + testCaseBody := "Hello world!" + + h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + // We are pretending to be upstream or an inner layer of the ResponseWriter chain + w.Header().Set(headers.ContentDispositionHeader, "attachment") + w.Header().Set(headers.GitlabWorkhorseDetectContentTypeHeader, "true") + _, err := io.WriteString(w, testCaseBody) + require.NoError(t, err) + }) + + resp := makeRequest(t, h, testCaseBody, "") + + require.Equal(t, "attachment", resp.Header.Get(headers.ContentDispositionHeader)) +} + +func TestInlineContentDispositionForPdfFiles(t *testing.T) { + testCaseBody := testhelper.LoadFile(t, "testdata/file.pdf") + + h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + // We are pretending to be upstream or an inner layer of the ResponseWriter chain + w.Header().Set(headers.ContentDispositionHeader, "inline") + w.Header().Set(headers.GitlabWorkhorseDetectContentTypeHeader, "true") + _, err := io.WriteString(w, testCaseBody) + require.NoError(t, err) + }) + + resp := makeRequest(t, h, testCaseBody, "") + + require.Equal(t, "inline", resp.Header.Get(headers.ContentDispositionHeader)) +} + +func TestFailOverrideContentDispositionFromAttachmentToInline(t *testing.T) { + testCaseBody := testhelper.LoadFile(t, "testdata/image.svg") + + h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + // We are pretending to be upstream or an inner layer of the ResponseWriter chain + w.Header().Set(headers.ContentDispositionHeader, "inline") + w.Header().Set(headers.GitlabWorkhorseDetectContentTypeHeader, "true") + _, err := io.WriteString(w, testCaseBody) + require.NoError(t, err) + }) + + resp := makeRequest(t, h, testCaseBody, "") + + require.Equal(t, "attachment", resp.Header.Get(headers.ContentDispositionHeader)) +} + +func TestHeadersDelete(t *testing.T) { + for _, code := range []int{200, 400} { + recorder := httptest.NewRecorder() + rw := &contentDisposition{rw: recorder} + for _, name := range headers.ResponseHeaders { + rw.Header().Set(name, "foobar") + } + + rw.WriteHeader(code) + + for _, name := range headers.ResponseHeaders { + if header := recorder.Header().Get(name); header != "" { + t.Fatalf("HTTP %d response: expected header to be empty, found %q", code, name) + } + } + } +} + +func TestWriteHeadersCalledOnce(t *testing.T) { + recorder := httptest.NewRecorder() + rw := &contentDisposition{rw: recorder} + rw.WriteHeader(400) + require.Equal(t, 400, rw.status) + require.Equal(t, true, rw.sentStatus) + + rw.WriteHeader(200) + require.Equal(t, 400, rw.status) +} + +func makeRequest(t *testing.T, handler http.HandlerFunc, body string, disposition string) *http.Response { + if handler == nil { + handler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + // We are pretending to be upstream + w.Header().Set(headers.GitlabWorkhorseDetectContentTypeHeader, "true") + w.Header().Set(headers.ContentDispositionHeader, disposition) + _, err := io.WriteString(w, body) + require.NoError(t, err) + }) + } + req, _ := http.NewRequest("GET", "/", nil) + + rw := httptest.NewRecorder() + SetContentHeaders(handler).ServeHTTP(rw, req) + + resp := rw.Result() + respBody, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + + require.Equal(t, body, string(respBody)) + + return resp +} diff --git a/workhorse/internal/senddata/injecter.go b/workhorse/internal/senddata/injecter.go new file mode 100644 index 00000000000..d5739d2a053 --- /dev/null +++ b/workhorse/internal/senddata/injecter.go @@ -0,0 +1,35 @@ +package senddata + +import ( + "encoding/base64" + "encoding/json" + "net/http" + "strings" +) + +type Injecter interface { + Match(string) bool + Inject(http.ResponseWriter, *http.Request, string) + Name() string +} + +type Prefix string + +func (p Prefix) Match(s string) bool { + return strings.HasPrefix(s, string(p)) +} + +func (p Prefix) Unpack(result interface{}, sendData string) error { + jsonBytes, err := base64.URLEncoding.DecodeString(strings.TrimPrefix(sendData, string(p))) + if err != nil { + return err + } + if err := json.Unmarshal([]byte(jsonBytes), result); err != nil { + return err + } + return nil +} + +func (p Prefix) Name() string { + return strings.TrimSuffix(string(p), ":") +} diff --git a/workhorse/internal/senddata/senddata.go b/workhorse/internal/senddata/senddata.go new file mode 100644 index 00000000000..c287d2574fa --- /dev/null +++ b/workhorse/internal/senddata/senddata.go @@ -0,0 +1,105 @@ +package senddata + +import ( + "net/http" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/headers" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/senddata/contentprocessor" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var ( + sendDataResponses = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_senddata_responses", + Help: "How many HTTP responses have been hijacked by a workhorse senddata injecter", + }, + []string{"injecter"}, + ) + sendDataResponseBytes = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_senddata_response_bytes", + Help: "How many bytes have been written by workhorse senddata response injecters", + }, + []string{"injecter"}, + ) +) + +type sendDataResponseWriter struct { + rw http.ResponseWriter + status int + hijacked bool + req *http.Request + injecters []Injecter +} + +func SendData(h http.Handler, injecters ...Injecter) http.Handler { + return contentprocessor.SetContentHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s := sendDataResponseWriter{ + rw: w, + req: r, + injecters: injecters, + } + defer s.flush() + h.ServeHTTP(&s, r) + })) +} + +func (s *sendDataResponseWriter) Header() http.Header { + return s.rw.Header() +} + +func (s *sendDataResponseWriter) Write(data []byte) (int, error) { + if s.status == 0 { + s.WriteHeader(http.StatusOK) + } + if s.hijacked { + return len(data), nil + } + return s.rw.Write(data) +} + +func (s *sendDataResponseWriter) WriteHeader(status int) { + if s.status != 0 { + return + } + s.status = status + + if s.status == http.StatusOK && s.tryInject() { + return + } + + s.rw.WriteHeader(s.status) +} + +func (s *sendDataResponseWriter) tryInject() bool { + if s.hijacked { + return false + } + + header := s.Header().Get(headers.GitlabWorkhorseSendDataHeader) + if header == "" { + return false + } + + for _, injecter := range s.injecters { + if injecter.Match(header) { + s.hijacked = true + helper.DisableResponseBuffering(s.rw) + crw := helper.NewCountingResponseWriter(s.rw) + injecter.Inject(crw, s.req, header) + sendDataResponses.WithLabelValues(injecter.Name()).Inc() + sendDataResponseBytes.WithLabelValues(injecter.Name()).Add(float64(crw.Count())) + return true + } + } + + return false +} + +func (s *sendDataResponseWriter) flush() { + s.WriteHeader(http.StatusOK) +} diff --git a/workhorse/internal/senddata/writer_test.go b/workhorse/internal/senddata/writer_test.go new file mode 100644 index 00000000000..1262acd5472 --- /dev/null +++ b/workhorse/internal/senddata/writer_test.go @@ -0,0 +1,71 @@ +package senddata + +import ( + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/headers" +) + +func TestWriter(t *testing.T) { + upstreamResponse := "hello world" + + testCases := []struct { + desc string + headerValue string + out string + }{ + { + desc: "inject", + headerValue: testInjecterName + ":" + testInjecterName, + out: testInjecterData, + }, + { + desc: "pass", + headerValue: "", + out: upstreamResponse, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + recorder := httptest.NewRecorder() + rw := &sendDataResponseWriter{rw: recorder, injecters: []Injecter{&testInjecter{}}} + + rw.Header().Set(headers.GitlabWorkhorseSendDataHeader, tc.headerValue) + + n, err := rw.Write([]byte(upstreamResponse)) + require.NoError(t, err) + require.Equal(t, len(upstreamResponse), n, "bytes written") + + recorder.Flush() + + body := recorder.Result().Body + data, err := ioutil.ReadAll(body) + require.NoError(t, err) + require.NoError(t, body.Close()) + + require.Equal(t, tc.out, string(data)) + }) + } +} + +const ( + testInjecterName = "test-injecter" + testInjecterData = "hello this is injected data" +) + +type testInjecter struct{} + +func (ti *testInjecter) Inject(w http.ResponseWriter, r *http.Request, sendData string) { + io.WriteString(w, testInjecterData) +} + +func (ti *testInjecter) Match(s string) bool { return strings.HasPrefix(s, testInjecterName+":") } +func (ti *testInjecter) Name() string { return testInjecterName } |