Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-foss.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'workhorse/main_test.go')
-rw-r--r--workhorse/main_test.go860
1 files changed, 860 insertions, 0 deletions
diff --git a/workhorse/main_test.go b/workhorse/main_test.go
new file mode 100644
index 00000000000..16fa8ff10b7
--- /dev/null
+++ b/workhorse/main_test.go
@@ -0,0 +1,860 @@
+package main
+
+import (
+ "bytes"
+ "compress/gzip"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "image/png"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "os"
+ "os/exec"
+ "path"
+ "regexp"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+ "gitlab.com/gitlab-org/labkit/log"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/gitaly"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/secret"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/upstream"
+)
+
+const scratchDir = "testdata/scratch"
+const testRepoRoot = "testdata/data"
+const testDocumentRoot = "testdata/public"
+const testAltDocumentRoot = "testdata/alt-public"
+
+var absDocumentRoot string
+
+const testRepo = "group/test.git"
+const testProject = "group/test"
+
+var checkoutDir = path.Join(scratchDir, "test")
+var cacheDir = path.Join(scratchDir, "cache")
+
+func TestMain(m *testing.M) {
+ if _, err := os.Stat(path.Join(testRepoRoot, testRepo)); os.IsNotExist(err) {
+ log.WithError(err).Fatal("cannot find test repository. Please run 'make prepare-tests'")
+ }
+
+ if err := testhelper.BuildExecutables(); err != nil {
+ log.WithError(err).Fatal()
+ }
+
+ defer gitaly.CloseConnections()
+
+ os.Exit(m.Run())
+}
+
+func TestDeniedClone(t *testing.T) {
+ // Prepare clone directory
+ require.NoError(t, os.RemoveAll(scratchDir))
+
+ // Prepare test server and backend
+ ts := testAuthServer(t, nil, nil, 403, "Access denied")
+ defer ts.Close()
+ ws := startWorkhorseServer(ts.URL)
+ defer ws.Close()
+
+ // Do the git clone
+ cloneCmd := exec.Command("git", "clone", fmt.Sprintf("%s/%s", ws.URL, testRepo), checkoutDir)
+ out, err := cloneCmd.CombinedOutput()
+ t.Log(string(out))
+ require.Error(t, err, "git clone should have failed")
+}
+
+func TestDeniedPush(t *testing.T) {
+ // Prepare the test server and backend
+ ts := testAuthServer(t, nil, nil, 403, "Access denied")
+ defer ts.Close()
+ ws := startWorkhorseServer(ts.URL)
+ defer ws.Close()
+
+ // Perform the git push
+ pushCmd := exec.Command("git", "push", "-v", fmt.Sprintf("%s/%s", ws.URL, testRepo), fmt.Sprintf("master:%s", newBranch()))
+ pushCmd.Dir = checkoutDir
+ out, err := pushCmd.CombinedOutput()
+ t.Log(string(out))
+ require.Error(t, err, "git push should have failed")
+}
+
+func TestRegularProjectsAPI(t *testing.T) {
+ apiResponse := "API RESPONSE"
+
+ ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, _ *http.Request) {
+ _, err := w.Write([]byte(apiResponse))
+ require.NoError(t, err)
+ })
+ defer ts.Close()
+
+ ws := startWorkhorseServer(ts.URL)
+ defer ws.Close()
+
+ for _, resource := range []string{
+ "/api/v3/projects/123/repository/not/special",
+ "/api/v3/projects/foo%2Fbar/repository/not/special",
+ "/api/v3/projects/123/not/special",
+ "/api/v3/projects/foo%2Fbar/not/special",
+ "/api/v3/projects/foo%2Fbar%2Fbaz/repository/not/special",
+ "/api/v3/projects/foo%2Fbar%2Fbaz%2Fqux/repository/not/special",
+ } {
+ resp, body := httpGet(t, ws.URL+resource, nil)
+
+ require.Equal(t, 200, resp.StatusCode, "GET %q: status code", resource)
+ require.Equal(t, apiResponse, body, "GET %q: response body", resource)
+ requireNginxResponseBuffering(t, "", resp, "GET %q: nginx response buffering", resource)
+ }
+}
+
+func TestAllowedXSendfileDownload(t *testing.T) {
+ contentFilename := "my-content"
+ prepareDownloadDir(t)
+
+ allowedXSendfileDownload(t, contentFilename, "foo/uploads/bar")
+}
+
+func TestDeniedXSendfileDownload(t *testing.T) {
+ contentFilename := "my-content"
+ prepareDownloadDir(t)
+
+ deniedXSendfileDownload(t, contentFilename, "foo/uploads/bar")
+}
+
+func TestAllowedStaticFile(t *testing.T) {
+ content := "PUBLIC"
+ require.NoError(t, setupStaticFile("static file.txt", content))
+
+ proxied := false
+ ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
+ proxied = true
+ w.WriteHeader(404)
+ })
+ defer ts.Close()
+ ws := startWorkhorseServer(ts.URL)
+ defer ws.Close()
+
+ for _, resource := range []string{
+ "/static%20file.txt",
+ "/static file.txt",
+ } {
+ resp, body := httpGet(t, ws.URL+resource, nil)
+
+ require.Equal(t, 200, resp.StatusCode, "GET %q: status code", resource)
+ require.Equal(t, content, body, "GET %q: response body", resource)
+ requireNginxResponseBuffering(t, "no", resp, "GET %q: nginx response buffering", resource)
+ require.False(t, proxied, "GET %q: should not have made it to backend", resource)
+ }
+}
+
+func TestStaticFileRelativeURL(t *testing.T) {
+ content := "PUBLIC"
+ require.NoError(t, setupStaticFile("static.txt", content), "create public/static.txt")
+
+ ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), http.HandlerFunc(http.NotFound))
+ defer ts.Close()
+ backendURLString := ts.URL + "/my-relative-url"
+ log.Info(backendURLString)
+ ws := startWorkhorseServer(backendURLString)
+ defer ws.Close()
+
+ resource := "/my-relative-url/static.txt"
+ resp, body := httpGet(t, ws.URL+resource, nil)
+
+ require.Equal(t, 200, resp.StatusCode, "GET %q: status code", resource)
+ require.Equal(t, content, body, "GET %q: response body", resource)
+}
+
+func TestAllowedPublicUploadsFile(t *testing.T) {
+ content := "PRIVATE but allowed"
+ require.NoError(t, setupStaticFile("uploads/static file.txt", content), "create public/uploads/static file.txt")
+
+ proxied := false
+ ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
+ proxied = true
+ w.Header().Add("X-Sendfile", absDocumentRoot+r.URL.Path)
+ w.WriteHeader(200)
+ })
+ defer ts.Close()
+ ws := startWorkhorseServer(ts.URL)
+ defer ws.Close()
+
+ for _, resource := range []string{
+ "/uploads/static%20file.txt",
+ "/uploads/static file.txt",
+ } {
+ resp, body := httpGet(t, ws.URL+resource, nil)
+
+ require.Equal(t, 200, resp.StatusCode, "GET %q: status code", resource)
+ require.Equal(t, content, body, "GET %q: response body", resource)
+ require.True(t, proxied, "GET %q: never made it to backend", resource)
+ }
+}
+
+func TestDeniedPublicUploadsFile(t *testing.T) {
+ content := "PRIVATE"
+ require.NoError(t, setupStaticFile("uploads/static.txt", content), "create public/uploads/static.txt")
+
+ proxied := false
+ ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, _ *http.Request) {
+ proxied = true
+ w.WriteHeader(404)
+ })
+ defer ts.Close()
+ ws := startWorkhorseServer(ts.URL)
+ defer ws.Close()
+
+ for _, resource := range []string{
+ "/uploads/static.txt",
+ "/uploads%2Fstatic.txt",
+ } {
+ resp, body := httpGet(t, ws.URL+resource, nil)
+
+ require.Equal(t, 404, resp.StatusCode, "GET %q: status code", resource)
+ require.Equal(t, "", body, "GET %q: response body", resource)
+ require.True(t, proxied, "GET %q: never made it to backend", resource)
+ }
+}
+
+func TestStaticErrorPage(t *testing.T) {
+ errorPageBody := `<html>
+<body>
+This is a static error page for code 499
+</body>
+</html>
+`
+ require.NoError(t, setupStaticFile("499.html", errorPageBody))
+ ts := testhelper.TestServerWithHandler(nil, func(w http.ResponseWriter, _ *http.Request) {
+ upstreamError := "499"
+ // This is the point of the test: the size of the upstream response body
+ // should be overridden.
+ require.NotEqual(t, len(upstreamError), len(errorPageBody))
+ w.WriteHeader(499)
+ _, err := w.Write([]byte(upstreamError))
+ require.NoError(t, err)
+ })
+ defer ts.Close()
+
+ ws := startWorkhorseServer(ts.URL)
+ defer ws.Close()
+
+ resourcePath := "/error-499"
+ resp, body := httpGet(t, ws.URL+resourcePath, nil)
+
+ require.Equal(t, 499, resp.StatusCode, "GET %q: status code", resourcePath)
+ require.Equal(t, string(errorPageBody), body, "GET %q: response body", resourcePath)
+}
+
+func TestGzipAssets(t *testing.T) {
+ path := "/assets/static.txt"
+ content := "asset"
+ require.NoError(t, setupStaticFile(path, content))
+
+ buf := &bytes.Buffer{}
+ gzipWriter := gzip.NewWriter(buf)
+ _, err := gzipWriter.Write([]byte(content))
+ require.NoError(t, err)
+ require.NoError(t, gzipWriter.Close())
+ contentGzip := buf.String()
+ require.NoError(t, setupStaticFile(path+".gz", contentGzip))
+
+ proxied := false
+ ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
+ proxied = true
+ w.WriteHeader(404)
+ })
+ defer ts.Close()
+ ws := startWorkhorseServer(ts.URL)
+ defer ws.Close()
+
+ testCases := []struct {
+ content string
+ path string
+ acceptEncoding string
+ contentEncoding string
+ }{
+ {content: content, path: path},
+ {content: contentGzip, path: path, acceptEncoding: "gzip", contentEncoding: "gzip"},
+ {content: contentGzip, path: path, acceptEncoding: "gzip, compress, br", contentEncoding: "gzip"},
+ {content: contentGzip, path: path, acceptEncoding: "br;q=1.0, gzip;q=0.8, *;q=0.1", contentEncoding: "gzip"},
+ }
+
+ for _, tc := range testCases {
+ desc := fmt.Sprintf("accept-encoding: %q", tc.acceptEncoding)
+ req, err := http.NewRequest("GET", ws.URL+tc.path, nil)
+ require.NoError(t, err, desc)
+ req.Header.Set("Accept-Encoding", tc.acceptEncoding)
+
+ resp, err := http.DefaultTransport.RoundTrip(req)
+ require.NoError(t, err, desc)
+ defer resp.Body.Close()
+ b, err := ioutil.ReadAll(resp.Body)
+ require.NoError(t, err, desc)
+
+ require.Equal(t, 200, resp.StatusCode, "%s: status code", desc)
+ require.Equal(t, tc.content, string(b), "%s: response body", desc)
+ require.Equal(t, tc.contentEncoding, resp.Header.Get("Content-Encoding"), "%s: response body", desc)
+ require.False(t, proxied, "%s: should not have made it to backend", desc)
+ }
+}
+
+func TestAltDocumentAssets(t *testing.T) {
+ path := "/assets/static.txt"
+ content := "asset"
+ require.NoError(t, setupAltStaticFile(path, content))
+
+ buf := &bytes.Buffer{}
+ gzipWriter := gzip.NewWriter(buf)
+ _, err := gzipWriter.Write([]byte(content))
+ require.NoError(t, err)
+ require.NoError(t, gzipWriter.Close())
+ contentGzip := buf.String()
+ require.NoError(t, setupAltStaticFile(path+".gz", contentGzip))
+
+ proxied := false
+ ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
+ proxied = true
+ w.WriteHeader(404)
+ })
+ defer ts.Close()
+
+ upstreamConfig := newUpstreamConfig(ts.URL)
+ upstreamConfig.AltDocumentRoot = testAltDocumentRoot
+
+ ws := startWorkhorseServerWithConfig(upstreamConfig)
+ defer ws.Close()
+
+ testCases := []struct {
+ desc string
+ path string
+ content string
+ acceptEncoding string
+ contentEncoding string
+ }{
+ {desc: "plaintext asset", path: path, content: content},
+ {desc: "gzip asset available", path: path, content: contentGzip, acceptEncoding: "gzip", contentEncoding: "gzip"},
+ {desc: "non-existent file", path: "/assets/non-existent"},
+ }
+
+ for _, tc := range testCases {
+ req, err := http.NewRequest("GET", ws.URL+tc.path, nil)
+ require.NoError(t, err)
+
+ if tc.acceptEncoding != "" {
+ req.Header.Set("Accept-Encoding", tc.acceptEncoding)
+ }
+
+ resp, err := http.DefaultTransport.RoundTrip(req)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+ b, err := ioutil.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ if tc.content != "" {
+ require.Equal(t, 200, resp.StatusCode, "%s: status code", tc.desc)
+ require.Equal(t, tc.content, string(b), "%s: response body", tc.desc)
+ require.False(t, proxied, "%s: should not have made it to backend", tc.desc)
+
+ if tc.contentEncoding != "" {
+ require.Equal(t, tc.contentEncoding, resp.Header.Get("Content-Encoding"))
+ }
+ } else {
+ require.Equal(t, 404, resp.StatusCode, "%s: status code", tc.desc)
+ }
+ }
+}
+
+var sendDataHeader = "Gitlab-Workhorse-Send-Data"
+
+func sendDataResponder(command string, literalJSON string) *httptest.Server {
+ handler := func(w http.ResponseWriter, r *http.Request) {
+ data := base64.URLEncoding.EncodeToString([]byte(literalJSON))
+ w.Header().Set(sendDataHeader, fmt.Sprintf("%s:%s", command, data))
+
+ // This should never be returned
+ if _, err := fmt.Fprintf(w, "gibberish"); err != nil {
+ panic(err)
+ }
+ }
+
+ return testhelper.TestServerWithHandler(regexp.MustCompile(`.`), handler)
+}
+
+func doSendDataRequest(path string, command, literalJSON string) (*http.Response, []byte, error) {
+ ts := sendDataResponder(command, literalJSON)
+ defer ts.Close()
+
+ ws := startWorkhorseServer(ts.URL)
+ defer ws.Close()
+
+ resp, err := http.Get(ws.URL + path)
+ if err != nil {
+ return nil, nil, err
+ }
+ defer resp.Body.Close()
+
+ bodyData, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ return resp, nil, err
+ }
+
+ headerValue := resp.Header.Get(sendDataHeader)
+ if headerValue != "" {
+ return resp, bodyData, fmt.Errorf("%s header should not be present, but has value %q", sendDataHeader, headerValue)
+ }
+
+ return resp, bodyData, nil
+}
+
+func TestArtifactsGetSingleFile(t *testing.T) {
+ // We manually created this zip file in the gitlab-workhorse Git repository
+ archivePath := `testdata/artifacts-archive.zip`
+ fileName := "myfile"
+ fileContents := "MY FILE"
+ resourcePath := `/namespace/project/builds/123/artifacts/file/` + fileName
+ encodedFilename := base64.StdEncoding.EncodeToString([]byte(fileName))
+ jsonParams := fmt.Sprintf(`{"Archive":"%s","Entry":"%s"}`, archivePath, encodedFilename)
+
+ resp, body, err := doSendDataRequest(resourcePath, "artifacts-entry", jsonParams)
+ require.NoError(t, err)
+
+ require.Equal(t, 200, resp.StatusCode, "GET %q: status code", resourcePath)
+ require.Equal(t, fileContents, string(body), "GET %q: response body", resourcePath)
+ requireNginxResponseBuffering(t, "no", resp, "GET %q: nginx response buffering", resourcePath)
+}
+
+func TestImageResizing(t *testing.T) {
+ imageLocation := `testdata/image.png`
+ requestedWidth := 40
+ imageFormat := "image/png"
+ jsonParams := fmt.Sprintf(`{"Location":"%s","Width":%d, "ContentType":"%s"}`, imageLocation, requestedWidth, imageFormat)
+ resourcePath := "/uploads/-/system/user/avatar/123/avatar.png?width=40"
+
+ resp, body, err := doSendDataRequest(resourcePath, "send-scaled-img", jsonParams)
+ require.NoError(t, err, "send resize request")
+ require.Equal(t, 200, resp.StatusCode, "GET %q: body: %s", resourcePath, body)
+
+ img, err := png.Decode(bytes.NewReader(body))
+ require.NoError(t, err, "decode resized image")
+
+ bounds := img.Bounds()
+ require.Equal(t, requestedWidth, bounds.Size().X, "wrong width after resizing")
+}
+
+func TestSendURLForArtifacts(t *testing.T) {
+ expectedBody := strings.Repeat("CONTENT!", 1024)
+
+ regularHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Length", strconv.Itoa(len(expectedBody)))
+ w.Write([]byte(expectedBody))
+ })
+
+ chunkedHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Transfer-Encoding", "chunked")
+ w.Write([]byte(expectedBody))
+ })
+
+ rawHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ hj, ok := w.(http.Hijacker)
+ require.Equal(t, true, ok)
+
+ conn, buf, err := hj.Hijack()
+ require.NoError(t, err)
+ defer conn.Close()
+ defer buf.Flush()
+
+ fmt.Fprint(buf, "HTTP/1.1 200 OK\r\nContent-Type: application/zip\r\n\r\n")
+ fmt.Fprint(buf, expectedBody)
+ })
+
+ for _, tc := range []struct {
+ name string
+ handler http.Handler
+ transferEncoding []string
+ contentLength int
+ }{
+ {"No content-length, chunked TE", chunkedHandler, []string{"chunked"}, -1}, // Case 3 in https://tools.ietf.org/html/rfc7230#section-3.3.2
+ {"Known content-length, identity TE", regularHandler, nil, len(expectedBody)}, // Case 5 in https://tools.ietf.org/html/rfc7230#section-3.3.2
+ {"No content-length, identity TE", rawHandler, []string{"chunked"}, -1}, // Case 7 in https://tools.ietf.org/html/rfc7230#section-3.3.2
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ server := httptest.NewServer(tc.handler)
+ defer server.Close()
+
+ jsonParams := fmt.Sprintf(`{"URL":%q}`, server.URL)
+
+ resourcePath := `/namespace/project/builds/123/artifacts/file/download`
+ resp, body, err := doSendDataRequest(resourcePath, "send-url", jsonParams)
+ require.NoError(t, err)
+
+ require.Equal(t, http.StatusOK, resp.StatusCode, "GET %q: status code", resourcePath)
+ require.Equal(t, int64(tc.contentLength), resp.ContentLength, "GET %q: Content-Length", resourcePath)
+ require.Equal(t, tc.transferEncoding, resp.TransferEncoding, "GET %q: Transfer-Encoding", resourcePath)
+ require.Equal(t, expectedBody, string(body), "GET %q: response body", resourcePath)
+ requireNginxResponseBuffering(t, "no", resp, "GET %q: nginx response buffering", resourcePath)
+ })
+ }
+}
+
+func TestApiContentTypeBlock(t *testing.T) {
+ wrongResponse := `{"hello":"world"}`
+ ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, _ *http.Request) {
+ w.Header().Set("Content-Type", api.ResponseContentType)
+ _, err := w.Write([]byte(wrongResponse))
+ require.NoError(t, err, "write upstream response")
+ })
+ defer ts.Close()
+
+ ws := startWorkhorseServer(ts.URL)
+ defer ws.Close()
+
+ resourcePath := "/something"
+ resp, body := httpGet(t, ws.URL+resourcePath, nil)
+
+ require.Equal(t, 500, resp.StatusCode, "GET %q: status code", resourcePath)
+ require.NotContains(t, wrongResponse, body, "GET %q: response body", resourcePath)
+}
+
+func TestAPIFalsePositivesAreProxied(t *testing.T) {
+ goodResponse := []byte(`<html></html>`)
+ ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
+ if r.Header.Get(secret.RequestHeader) != "" && r.Method != "GET" {
+ w.WriteHeader(500)
+ w.Write([]byte("non-GET request went through PreAuthorize handler"))
+ } else {
+ w.Header().Set("Content-Type", "text/html")
+ _, err := w.Write(goodResponse)
+ require.NoError(t, err)
+ }
+ })
+ defer ts.Close()
+
+ ws := startWorkhorseServer(ts.URL)
+ defer ws.Close()
+
+ // Each of these cases is a specially-handled path in Workhorse that may
+ // actually be a request to be sent to gitlab-rails.
+ for _, tc := range []struct {
+ method string
+ path string
+ }{
+ {"GET", "/nested/group/project/blob/master/foo.git/info/refs"},
+ {"POST", "/nested/group/project/blob/master/foo.git/git-upload-pack"},
+ {"POST", "/nested/group/project/blob/master/foo.git/git-receive-pack"},
+ {"PUT", "/nested/group/project/blob/master/foo.git/gitlab-lfs/objects/0000000000000000000000000000000000000000000000000000000000000000/0"},
+ {"GET", "/nested/group/project/blob/master/environments/1/terminal.ws"},
+ } {
+ t.Run(tc.method+"_"+tc.path, func(t *testing.T) {
+ req, err := http.NewRequest(tc.method, ws.URL+tc.path, nil)
+ require.NoError(t, err, "Constructing %s %q", tc.method, tc.path)
+ resp, err := http.DefaultClient.Do(req)
+ require.NoError(t, err, "%s %q", tc.method, tc.path)
+ defer resp.Body.Close()
+
+ respBody, err := ioutil.ReadAll(resp.Body)
+ require.NoError(t, err, "%s %q: reading body", tc.method, tc.path)
+
+ require.Equal(t, 200, resp.StatusCode, "%s %q: status code", tc.method, tc.path)
+ testhelper.RequireResponseHeader(t, resp, "Content-Type", "text/html")
+ require.Equal(t, string(goodResponse), string(respBody), "%s %q: response body", tc.method, tc.path)
+ })
+ }
+}
+
+func TestCorrelationIdHeader(t *testing.T) {
+ ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Add("X-Request-Id", "12345678")
+ w.WriteHeader(200)
+ })
+ defer ts.Close()
+ ws := startWorkhorseServer(ts.URL)
+ defer ws.Close()
+
+ for _, resource := range []string{
+ "/api/v3/projects/123/repository/not/special",
+ } {
+ resp, _ := httpGet(t, ws.URL+resource, nil)
+
+ require.Equal(t, 200, resp.StatusCode, "GET %q: status code", resource)
+ requestIds := resp.Header["X-Request-Id"]
+ require.Equal(t, 1, len(requestIds), "GET %q: One X-Request-Id present", resource)
+ }
+}
+
+func TestPropagateCorrelationIdHeader(t *testing.T) {
+ ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Add("X-Request-Id", r.Header.Get("X-Request-Id"))
+ w.WriteHeader(200)
+ })
+ defer ts.Close()
+
+ testCases := []struct {
+ desc string
+ propagateCorrelationID bool
+ }{
+ {
+ desc: "propagateCorrelatedId is true",
+ propagateCorrelationID: true,
+ },
+ {
+ desc: "propagateCorrelatedId is false",
+ propagateCorrelationID: false,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ upstreamConfig := newUpstreamConfig(ts.URL)
+ upstreamConfig.PropagateCorrelationID = tc.propagateCorrelationID
+
+ ws := startWorkhorseServerWithConfig(upstreamConfig)
+ defer ws.Close()
+
+ resource := "/api/v3/projects/123/repository/not/special"
+ propagatedRequestId := "Propagated-RequestId-12345678"
+ resp, _ := httpGet(t, ws.URL+resource, map[string]string{"X-Request-Id": propagatedRequestId})
+ requestIds := resp.Header["X-Request-Id"]
+
+ require.Equal(t, 200, resp.StatusCode, "GET %q: status code", resource)
+ require.Equal(t, 1, len(requestIds), "GET %q: One X-Request-Id present", resource)
+
+ if tc.propagateCorrelationID {
+ require.Contains(t, requestIds, propagatedRequestId, "GET %q: Has X-Request-Id %s present", resource, propagatedRequestId)
+ } else {
+ require.NotContains(t, requestIds, propagatedRequestId, "GET %q: X-Request-Id not propagated")
+ }
+ })
+ }
+}
+
+func setupStaticFile(fpath, content string) error {
+ return setupStaticFileHelper(fpath, content, testDocumentRoot)
+}
+
+func setupAltStaticFile(fpath, content string) error {
+ return setupStaticFileHelper(fpath, content, testAltDocumentRoot)
+}
+
+func setupStaticFileHelper(fpath, content, directory string) error {
+ cwd, err := os.Getwd()
+ if err != nil {
+ return err
+ }
+ absDocumentRoot = path.Join(cwd, directory)
+ if err := os.MkdirAll(path.Join(absDocumentRoot, path.Dir(fpath)), 0755); err != nil {
+ return err
+ }
+ staticFile := path.Join(absDocumentRoot, fpath)
+ return ioutil.WriteFile(staticFile, []byte(content), 0666)
+}
+
+func prepareDownloadDir(t *testing.T) {
+ require.NoError(t, os.RemoveAll(scratchDir))
+ require.NoError(t, os.MkdirAll(scratchDir, 0755))
+}
+
+func newBranch() string {
+ return fmt.Sprintf("branch-%d", time.Now().UnixNano())
+}
+
+func testAuthServer(t *testing.T, url *regexp.Regexp, params url.Values, code int, body interface{}) *httptest.Server {
+ return testhelper.TestServerWithHandler(url, func(w http.ResponseWriter, r *http.Request) {
+ require.NotEmpty(t, r.Header.Get("X-Request-Id"))
+
+ w.Header().Set("Content-Type", api.ResponseContentType)
+
+ logEntry := log.WithFields(log.Fields{
+ "method": r.Method,
+ "url": r.URL,
+ })
+ logEntryWithCode := logEntry.WithField("code", code)
+
+ if params != nil {
+ currentParams := r.URL.Query()
+ for key := range params {
+ if currentParams.Get(key) != params.Get(key) {
+ logEntry.Info("UPSTREAM", "DENY", "invalid auth server params")
+ w.WriteHeader(http.StatusForbidden)
+ return
+ }
+ }
+ }
+
+ // Write pure string
+ if data, ok := body.(string); ok {
+ logEntryWithCode.Info("UPSTREAM")
+
+ w.WriteHeader(code)
+ fmt.Fprint(w, data)
+ return
+ }
+
+ // Write json string
+ data, err := json.Marshal(body)
+ if err != nil {
+ logEntry.WithError(err).Error("UPSTREAM")
+
+ w.WriteHeader(503)
+ fmt.Fprint(w, err)
+ return
+ }
+
+ logEntryWithCode.Info("UPSTREAM")
+
+ w.WriteHeader(code)
+ w.Write(data)
+ })
+}
+
+func newUpstreamConfig(authBackend string) *config.Config {
+ return &config.Config{
+ Version: "123",
+ DocumentRoot: testDocumentRoot,
+ Backend: helper.URLMustParse(authBackend),
+ ImageResizerConfig: config.DefaultImageResizerConfig,
+ }
+}
+
+func startWorkhorseServer(authBackend string) *httptest.Server {
+ return startWorkhorseServerWithConfig(newUpstreamConfig(authBackend))
+}
+
+func startWorkhorseServerWithConfig(cfg *config.Config) *httptest.Server {
+ testhelper.ConfigureSecret()
+ u := upstream.NewUpstream(*cfg, logrus.StandardLogger())
+
+ return httptest.NewServer(u)
+}
+
+func runOrFail(t *testing.T, cmd *exec.Cmd) {
+ out, err := cmd.CombinedOutput()
+ t.Logf("%s", out)
+ require.NoError(t, err)
+}
+
+func gitOkBody(t *testing.T) *api.Response {
+ return &api.Response{
+ GL_ID: "user-123",
+ GL_USERNAME: "username",
+ Repository: gitalypb.Repository{
+ StorageName: "default",
+ RelativePath: "foo/bar.git",
+ },
+ }
+}
+
+func httpGet(t *testing.T, url string, headers map[string]string) (*http.Response, string) {
+ req, err := http.NewRequest("GET", url, nil)
+ require.NoError(t, err)
+
+ for k, v := range headers {
+ req.Header.Set(k, v)
+ }
+
+ resp, err := http.DefaultClient.Do(req)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ b, err := ioutil.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ return resp, string(b)
+}
+
+func httpPost(t *testing.T, url string, headers map[string]string, reqBody io.Reader) *http.Response {
+ req, err := http.NewRequest("POST", url, reqBody)
+ require.NoError(t, err)
+
+ for k, v := range headers {
+ req.Header.Set(k, v)
+ }
+
+ resp, err := http.DefaultClient.Do(req)
+ require.NoError(t, err)
+
+ return resp
+}
+
+func requireNginxResponseBuffering(t *testing.T, expected string, resp *http.Response, msgAndArgs ...interface{}) {
+ actual := resp.Header.Get(helper.NginxResponseBufferHeader)
+ require.Equal(t, expected, actual, msgAndArgs...)
+}
+
+// TestHealthChecksNoStaticHTML verifies that health endpoints pass errors through and don't return the static html error pages
+func TestHealthChecksNoStaticHTML(t *testing.T) {
+ apiResponse := "API RESPONSE"
+ errorPageBody := `<html>
+<body>
+This is a static error page for code 503
+</body>
+</html>
+`
+ require.NoError(t, setupStaticFile("503.html", errorPageBody))
+
+ ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, _ *http.Request) {
+ w.Header().Set("X-Gitlab-Custom-Error", "1")
+ w.WriteHeader(503)
+ _, err := w.Write([]byte(apiResponse))
+ require.NoError(t, err)
+ })
+ defer ts.Close()
+
+ ws := startWorkhorseServer(ts.URL)
+ defer ws.Close()
+
+ for _, resource := range []string{
+ "/-/health",
+ "/-/readiness",
+ "/-/liveness",
+ } {
+ t.Run(resource, func(t *testing.T) {
+ resp, body := httpGet(t, ws.URL+resource, nil)
+
+ require.Equal(t, 503, resp.StatusCode, "status code")
+ require.Equal(t, apiResponse, body, "response body")
+ requireNginxResponseBuffering(t, "", resp, "nginx response buffering")
+ })
+ }
+}
+
+// TestHealthChecksUnreachable verifies that health endpoints return the correct content-type when the upstream is down
+func TestHealthChecksUnreachable(t *testing.T) {
+ ws := startWorkhorseServer("http://127.0.0.1:99999") // This url should point to nothing for the test to be accurate (equivalent to upstream being down)
+ defer ws.Close()
+
+ testCases := []struct {
+ path string
+ content string
+ responseType string
+ }{
+ {path: "/-/health", content: "Bad Gateway\n", responseType: "text/plain; charset=utf-8"},
+ {path: "/-/readiness", content: "{\"error\":\"Bad Gateway\",\"status\":502}\n", responseType: "application/json; charset=utf-8"},
+ {path: "/-/liveness", content: "{\"error\":\"Bad Gateway\",\"status\":502}\n", responseType: "application/json; charset=utf-8"},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.path, func(t *testing.T) {
+ resp, body := httpGet(t, ws.URL+tc.path, nil)
+
+ require.Equal(t, 502, resp.StatusCode, "status code")
+ require.Equal(t, tc.responseType, resp.Header.Get("Content-Type"), "content-type")
+ require.Equal(t, tc.content, body, "response body")
+ requireNginxResponseBuffering(t, "", resp, "nginx response buffering")
+ })
+ }
+}