diff options
author | GitLab Bot <gitlab-bot@gitlab.com> | 2023-09-20 14:18:08 +0300 |
---|---|---|
committer | GitLab Bot <gitlab-bot@gitlab.com> | 2023-09-20 14:18:08 +0300 |
commit | 5afcbe03ead9ada87621888a31a62652b10a7e4f (patch) | |
tree | 9918b67a0d0f0bafa6542e839a8be37adf73102d /workhorse | |
parent | c97c0201564848c1f53226fe19d71fdcc472f7d0 (diff) |
Add latest changes from gitlab-org/gitlab@16-4-stable-eev16.4.0-rc42
Diffstat (limited to 'workhorse')
-rw-r--r-- | workhorse/config_test.go | 2 | ||||
-rw-r--r-- | workhorse/gitaly_integration_test.go | 54 | ||||
-rw-r--r-- | workhorse/gitaly_test.go | 17 | ||||
-rw-r--r-- | workhorse/go.mod | 4 | ||||
-rw-r--r-- | workhorse/go.sum | 6 | ||||
-rw-r--r-- | workhorse/internal/config/config.go | 15 | ||||
-rw-r--r-- | workhorse/internal/dependencyproxy/dependencyproxy.go | 73 | ||||
-rw-r--r-- | workhorse/internal/dependencyproxy/dependencyproxy_test.go | 153 | ||||
-rw-r--r-- | workhorse/internal/gitaly/gitaly.go | 35 | ||||
-rw-r--r-- | workhorse/internal/gitaly/gitaly_test.go | 9 | ||||
-rw-r--r-- | workhorse/internal/gitaly/namespace.go | 8 | ||||
-rw-r--r-- | workhorse/internal/goredis/goredis.go | 186 | ||||
-rw-r--r-- | workhorse/internal/goredis/goredis_test.go | 107 | ||||
-rw-r--r-- | workhorse/internal/goredis/keywatcher.go | 236 | ||||
-rw-r--r-- | workhorse/internal/goredis/keywatcher_test.go | 301 | ||||
-rw-r--r-- | workhorse/internal/redis/keywatcher.go | 24 | ||||
-rw-r--r-- | workhorse/internal/redis/redis.go | 12 | ||||
-rw-r--r-- | workhorse/main.go | 37 | ||||
-rw-r--r-- | workhorse/main_test.go | 20 | ||||
-rw-r--r-- | workhorse/sendfile_test.go | 9 | ||||
-rw-r--r-- | workhorse/upload_test.go | 8 |
21 files changed, 1172 insertions, 144 deletions
diff --git a/workhorse/config_test.go b/workhorse/config_test.go index a6a1bdd7187..64f0a24d148 100644 --- a/workhorse/config_test.go +++ b/workhorse/config_test.go @@ -34,6 +34,7 @@ trusted_cidrs_for_propagation = ["10.0.0.1/8"] [redis] password = "redis password" +SentinelPassword = "sentinel password" [object_storage] provider = "test provider" [image_resizer] @@ -68,6 +69,7 @@ key = "/path/to/private/key" // fields in each section; that should happen in the tests of the // internal/config package. require.Equal(t, "redis password", cfg.Redis.Password) + require.Equal(t, "sentinel password", cfg.Redis.SentinelPassword) require.Equal(t, "test provider", cfg.ObjectStorageCredentials.Provider) require.Equal(t, uint32(123), cfg.ImageResizerConfig.MaxScalerProcs, "image resizer max_scaler_procs") require.Equal(t, []string{"127.0.0.1/8", "192.168.0.1/8"}, cfg.TrustedCIDRsForXForwardedFor) diff --git a/workhorse/gitaly_integration_test.go b/workhorse/gitaly_integration_test.go index a7ec0b63b9d..929b9263dfd 100644 --- a/workhorse/gitaly_integration_test.go +++ b/workhorse/gitaly_integration_test.go @@ -20,6 +20,8 @@ import ( "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitaly/v16/proto/go/gitalypb" "gitlab.com/gitlab-org/gitaly/v16/streamio" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "gitlab.com/gitlab-org/gitlab/workhorse/internal/api" "gitlab.com/gitlab-org/gitlab/workhorse/internal/gitaly" @@ -76,27 +78,24 @@ func realGitalyOkBody(t *testing.T, gitalyAddress string) *api.Response { } func ensureGitalyRepository(t *testing.T, apiResponse *api.Response) error { - ctx, namespace, err := gitaly.NewNamespaceClient( - context.Background(), - apiResponse.GitalyServer, - ) - - if err != nil { - return err - } - ctx, repository, err := gitaly.NewRepositoryClient(ctx, apiResponse.GitalyServer) + ctx, repository, err := gitaly.NewRepositoryClient(context.Background(), apiResponse.GitalyServer) if err != nil { return err } // Remove the repository if it already exists, for consistency - rmNsReq := &gitalypb.RemoveNamespaceRequest{ - StorageName: apiResponse.Repository.StorageName, - Name: apiResponse.Repository.RelativePath, - } - _, err = namespace.RemoveNamespace(ctx, rmNsReq) - if err != nil { - return err + if _, err := repository.RepositoryServiceClient.RemoveRepository(ctx, &gitalypb.RemoveRepositoryRequest{ + Repository: &gitalypb.Repository{ + StorageName: apiResponse.Repository.StorageName, + RelativePath: apiResponse.Repository.RelativePath, + }, + }); err != nil { + status, ok := status.FromError(err) + if !ok || !(status.Code() == codes.NotFound && status.Message() == "repository does not exist") { + return fmt.Errorf("remove repository: %w", err) + } + + // Repository didn't exist. } stream, err := repository.CreateRepositoryFromBundle(ctx) @@ -139,13 +138,13 @@ func TestAllowedClone(t *testing.T) { defer ws.Close() // Do the git clone - require.NoError(t, os.RemoveAll(scratchDir)) - cloneCmd := exec.Command("git", "clone", fmt.Sprintf("%s/%s", ws.URL, testRepo), checkoutDir) + tmpDir := t.TempDir() + cloneCmd := exec.Command("git", "clone", fmt.Sprintf("%s/%s", ws.URL, testRepo), tmpDir) runOrFail(t, cloneCmd) // We may have cloned an 'empty' repository, 'git log' will fail in it logCmd := exec.Command("git", "log", "-1", "--oneline") - logCmd.Dir = checkoutDir + logCmd.Dir = tmpDir runOrFail(t, logCmd) }) } @@ -167,13 +166,13 @@ func TestAllowedShallowClone(t *testing.T) { defer ws.Close() // Shallow git clone (depth 1) - require.NoError(t, os.RemoveAll(scratchDir)) - cloneCmd := exec.Command("git", "clone", "--depth", "1", fmt.Sprintf("%s/%s", ws.URL, testRepo), checkoutDir) + tmpDir := t.TempDir() + cloneCmd := exec.Command("git", "clone", "--depth", "1", fmt.Sprintf("%s/%s", ws.URL, testRepo), tmpDir) runOrFail(t, cloneCmd) // We may have cloned an 'empty' repository, 'git log' will fail in it logCmd := exec.Command("git", "log", "-1", "--oneline") - logCmd.Dir = checkoutDir + logCmd.Dir = tmpDir runOrFail(t, logCmd) }) } @@ -194,9 +193,14 @@ func TestAllowedPush(t *testing.T) { ws := startWorkhorseServer(ts.URL) defer ws.Close() + // Do the git clone + tmpDir := t.TempDir() + cloneCmd := exec.Command("git", "clone", fmt.Sprintf("%s/%s", ws.URL, testRepo), tmpDir) + runOrFail(t, cloneCmd) + // Perform the git push pushCmd := exec.Command("git", "push", fmt.Sprintf("%s/%s", ws.URL, testRepo), fmt.Sprintf("master:%s", newBranch())) - pushCmd.Dir = checkoutDir + pushCmd.Dir = tmpDir runOrFail(t, pushCmd) }) } @@ -249,7 +253,7 @@ func TestAllowedGetGitArchive(t *testing.T) { apiResponse := realGitalyOkBody(t, gitalyAddress) require.NoError(t, ensureGitalyRepository(t, apiResponse)) - archivePath := path.Join(scratchDir, "my/path") + archivePath := path.Join(t.TempDir(), "my/path") archivePrefix := "repo-1" msg := serializedProtoMessage("GetArchiveRequest", &gitalypb.GetArchiveRequest{ @@ -296,7 +300,7 @@ func TestAllowedGetGitArchiveOldPayload(t *testing.T) { repo := &apiResponse.Repository require.NoError(t, ensureGitalyRepository(t, apiResponse)) - archivePath := path.Join(scratchDir, "my/path") + archivePath := path.Join(t.TempDir(), "my/path") archivePrefix := "repo-1" jsonParams := fmt.Sprintf( diff --git a/workhorse/gitaly_test.go b/workhorse/gitaly_test.go index 6bbc67228c3..270c40cb4bc 100644 --- a/workhorse/gitaly_test.go +++ b/workhorse/gitaly_test.go @@ -31,9 +31,6 @@ import ( ) func TestFailedCloneNoGitaly(t *testing.T) { - // Prepare clone directory - require.NoError(t, os.RemoveAll(scratchDir)) - authBody := &api.Response{ GL_ID: "user-123", GL_USERNAME: "username", @@ -48,7 +45,7 @@ func TestFailedCloneNoGitaly(t *testing.T) { defer ws.Close() // Do the git clone - cloneCmd := exec.Command("git", "clone", fmt.Sprintf("%s/%s", ws.URL, testRepo), checkoutDir) + cloneCmd := exec.Command("git", "clone", fmt.Sprintf("%s/%s", ws.URL, testRepo), t.TempDir()) out, err := cloneCmd.CombinedOutput() t.Log(string(out)) require.Error(t, err, "git clone should have failed") @@ -632,7 +629,7 @@ func TestGetArchiveProxiedToGitalySuccessfully(t *testing.T) { archivePath string cacheDisabled bool }{ - {archivePath: path.Join(scratchDir, "my/path"), cacheDisabled: false}, + {archivePath: path.Join(t.TempDir(), "my/path"), cacheDisabled: false}, {archivePath: "/var/empty/my/path", cacheDisabled: true}, } @@ -668,7 +665,7 @@ func TestGetArchiveProxiedToGitalyInterruptedStream(t *testing.T) { archivePath := "my/path" archivePrefix := "repo-1" jsonParams := fmt.Sprintf(`{"GitalyServer":{"Address":"%s","Token":""},"GitalyRepository":{"storage_name":"%s","relative_path":"%s"},"ArchivePath":"%s","ArchivePrefix":"%s","CommitId":"%s"}`, - gitalyAddress, repoStorage, repoRelativePath, path.Join(scratchDir, archivePath), archivePrefix, oid) + gitalyAddress, repoStorage, repoRelativePath, path.Join(t.TempDir(), archivePath), archivePrefix, oid) resp, _, err := doSendDataRequest("/archive.tar.gz", "git-archive", jsonParams) require.NoError(t, err) @@ -850,7 +847,13 @@ type combinedServer struct { } func startGitalyServer(t *testing.T, finalMessageCode codes.Code) (*combinedServer, string) { - socketPath := path.Join(scratchDir, fmt.Sprintf("gitaly-%d.sock", rand.Int())) + tmpFolder, err := os.MkdirTemp("", "gitaly") + require.NoError(t, err) + t.Cleanup(func() { + os.RemoveAll(tmpFolder) + }) + + socketPath := path.Join(tmpFolder, fmt.Sprintf("gitaly-%d.sock", rand.Int())) if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) { t.Fatal(err) } diff --git a/workhorse/go.mod b/workhorse/go.mod index 17ae3ce12ec..18699787e6e 100644 --- a/workhorse/go.mod +++ b/workhorse/go.mod @@ -15,13 +15,13 @@ require ( github.com/golang/protobuf v1.5.3 github.com/gomodule/redigo v2.0.0+incompatible github.com/gorilla/websocket v1.5.0 - github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/johannesboyne/gofakes3 v0.0.0-20230506070712-04da935ef877 github.com/jpillora/backoff v1.0.0 github.com/mitchellh/copystructure v1.2.0 github.com/prometheus/client_golang v1.16.0 github.com/rafaeljusto/redigomock/v3 v3.1.2 + github.com/redis/go-redis/v9 v9.1.0 github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a github.com/sirupsen/logrus v1.9.3 github.com/smartystreets/goconvey v1.7.2 @@ -65,6 +65,7 @@ require ( github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/client9/reopen v1.0.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.4.0 // indirect github.com/go-ole/go-ole v1.2.6 // indirect github.com/gogo/protobuf v1.3.2 // indirect @@ -78,6 +79,7 @@ require ( github.com/googleapis/enterprise-certificate-proxy v0.2.5 // indirect github.com/googleapis/gax-go/v2 v2.11.0 // indirect github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 // indirect + github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect github.com/hashicorp/yamux v0.1.1 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/jtolds/gls v4.20.0+incompatible // indirect diff --git a/workhorse/go.sum b/workhorse/go.sum index f3ceee8b5e8..5163d055187 100644 --- a/workhorse/go.sum +++ b/workhorse/go.sum @@ -869,6 +869,8 @@ github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bradfitz/gomemcache v0.0.0-20170208213004-1952afaa557d/go.mod h1:PmM6Mmwb0LSuEubjR8N7PtNe1KxZLtOUHtbeikc5h60= github.com/bshuster-repo/logrus-logstash-hook v0.4.1/go.mod h1:zsTqEiSzDgAa/8GZR7E1qaXrhYNDKBYy5/dWPTIflbk= +github.com/bsm/ginkgo/v2 v2.9.5 h1:rtVBYPs3+TC5iLUVOis1B9tjLTup7Cj5IfzosKtvTJ0= +github.com/bsm/gomega v1.26.0 h1:LhQm+AFcgV2M0WyKroMASzAzCAJVpAxQXv4SaI9a69Y= github.com/buger/jsonparser v0.0.0-20180808090653-f4dd9f5a6b44/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/bugsnag/bugsnag-go v0.0.0-20141110184014-b1d153021fcd/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8= @@ -1068,6 +1070,8 @@ github.com/denverdino/aliyungo v0.0.0-20190125010748-a747050bb1ba/go.mod h1:dV8l github.com/devigned/tab v0.1.1/go.mod h1:XG9mPq0dFghrYvoBF3xdRrJzSTX1b7IQrvaL9mzjeJY= github.com/dgrijalva/jwt-go v0.0.0-20170104182250-a601269ab70c/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= github.com/dgryski/go-sip13 v0.0.0-20200911182023-62edffca9245/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= github.com/digitalocean/godo v1.78.0/go.mod h1:GBmu8MkjZmNARE7IXRPmkbbnocNN8+uBm0xbEVw2LCs= @@ -2047,6 +2051,8 @@ github.com/rafaeljusto/redigomock/v3 v3.1.2 h1:B4Y0XJQiPjpwYmkH55aratKX1VfR+JRqz github.com/rafaeljusto/redigomock/v3 v3.1.2/go.mod h1:F9zPqz8rMriScZkPtUiLJoLruYcpGo/XXREpeyasREM= github.com/rakyll/embedmd v0.0.0-20171029212350-c8060a0752a2/go.mod h1:7jOTMgqac46PZcF54q6l2hkLEG8op93fZu61KmxWDV4= github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= +github.com/redis/go-redis/v9 v9.1.0 h1:137FnGdk+EQdCbye1FW+qOEcY5S+SpY9T0NiuqvtfMY= +github.com/redis/go-redis/v9 v9.1.0/go.mod h1:urWj3He21Dj5k4TK1y59xH8Uj6ATueP8AH1cY3lZl4c= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= diff --git a/workhorse/internal/config/config.go b/workhorse/internal/config/config.go index 687986974a3..3b928d42fe1 100644 --- a/workhorse/internal/config/config.go +++ b/workhorse/internal/config/config.go @@ -83,13 +83,14 @@ type GoogleCredentials struct { } type RedisConfig struct { - URL TomlURL - Sentinel []TomlURL - SentinelMaster string - Password string - DB *int - MaxIdle *int - MaxActive *int + URL TomlURL + Sentinel []TomlURL + SentinelMaster string + SentinelPassword string + Password string + DB *int + MaxIdle *int + MaxActive *int } type ImageResizerConfig struct { diff --git a/workhorse/internal/dependencyproxy/dependencyproxy.go b/workhorse/internal/dependencyproxy/dependencyproxy.go index e170b001806..dbea3c29aec 100644 --- a/workhorse/internal/dependencyproxy/dependencyproxy.go +++ b/workhorse/internal/dependencyproxy/dependencyproxy.go @@ -23,8 +23,15 @@ type Injector struct { } type entryParams struct { - Url string - Header http.Header + Url string + Headers http.Header + UploadConfig uploadConfig +} + +type uploadConfig struct { + Headers http.Header + Method string + Url string } type nullResponseWriter struct { @@ -55,7 +62,13 @@ func (p *Injector) SetUploadHandler(uploadHandler http.Handler) { } func (p *Injector) Inject(w http.ResponseWriter, r *http.Request, sendData string) { - dependencyResponse, err := p.fetchUrl(r.Context(), sendData) + params, err := p.unpackParams(sendData) + if err != nil { + fail.Request(w, r, err) + return + } + + dependencyResponse, err := p.fetchUrl(r.Context(), params) if err != nil { fail.Request(w, r, err) return @@ -70,11 +83,10 @@ func (p *Injector) Inject(w http.ResponseWriter, r *http.Request, sendData strin w.Header().Set("Content-Length", dependencyResponse.Header.Get("Content-Length")) teeReader := io.TeeReader(dependencyResponse.Body, w) - saveFileRequest, err := http.NewRequestWithContext(r.Context(), "POST", r.URL.String()+"/upload", teeReader) + saveFileRequest, err := p.newUploadRequest(r.Context(), params, r, teeReader) if err != nil { fail.Request(w, r, fmt.Errorf("dependency proxy: failed to create request: %w", err)) } - saveFileRequest.Header = r.Header.Clone() // forward headers from dependencyResponse to rails and client for key, values := range dependencyResponse.Header { @@ -100,17 +112,56 @@ func (p *Injector) Inject(w http.ResponseWriter, r *http.Request, sendData strin } } -func (p *Injector) fetchUrl(ctx context.Context, sendData string) (*http.Response, error) { +func (p *Injector) fetchUrl(ctx context.Context, params *entryParams) (*http.Response, error) { + r, err := http.NewRequestWithContext(ctx, "GET", params.Url, nil) + if err != nil { + return nil, fmt.Errorf("dependency proxy: failed to fetch dependency: %w", err) + } + r.Header = params.Headers + + return httpClient.Do(r) +} + +func (p *Injector) newUploadRequest(ctx context.Context, params *entryParams, originalRequest *http.Request, body io.Reader) (*http.Request, error) { + method := p.uploadMethodFrom(params) + uploadUrl := p.uploadUrlFrom(params, originalRequest) + request, err := http.NewRequestWithContext(ctx, method, uploadUrl, body) + if err != nil { + return nil, err + } + + request.Header = originalRequest.Header.Clone() + + for key, values := range params.UploadConfig.Headers { + request.Header.Del(key) + for _, value := range values { + request.Header.Add(key, value) + } + } + + return request, nil +} + +func (p *Injector) unpackParams(sendData string) (*entryParams, error) { var params entryParams if err := p.Unpack(¶ms, sendData); err != nil { return nil, fmt.Errorf("dependency proxy: unpack sendData: %v", err) } - r, err := http.NewRequestWithContext(ctx, "GET", params.Url, nil) - if err != nil { - return nil, fmt.Errorf("dependency proxy: failed to fetch dependency: %v", err) + return ¶ms, nil +} + +func (p *Injector) uploadMethodFrom(params *entryParams) string { + if params.UploadConfig.Method != "" { + return params.UploadConfig.Method } - r.Header = params.Header + return http.MethodPost +} - return httpClient.Do(r) +func (p *Injector) uploadUrlFrom(params *entryParams, originalRequest *http.Request) string { + if params.UploadConfig.Url != "" { + return params.UploadConfig.Url + } + + return originalRequest.URL.String() + "/upload" } diff --git a/workhorse/internal/dependencyproxy/dependencyproxy_test.go b/workhorse/internal/dependencyproxy/dependencyproxy_test.go index d893ddc500f..bee74ce0a9e 100644 --- a/workhorse/internal/dependencyproxy/dependencyproxy_test.go +++ b/workhorse/internal/dependencyproxy/dependencyproxy_test.go @@ -2,6 +2,7 @@ package dependencyproxy import ( "encoding/base64" + "encoding/json" "fmt" "io" "net/http" @@ -149,6 +150,158 @@ func TestSuccessfullRequest(t *testing.T) { require.Equal(t, dockerContentDigest, response.Header().Get("Docker-Content-Digest")) } +func TestValidUploadConfiguration(t *testing.T) { + content := []byte("content") + contentLength := strconv.Itoa(len(content)) + contentType := "text/plain" + testHeader := "test-received-url" + originResourceServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(testHeader, r.URL.Path) + w.Header().Set("Content-Length", contentLength) + w.Header().Set("Content-Type", contentType) + w.Write(content) + })) + + testCases := []struct { + desc string + uploadConfig *uploadConfig + expectedConfig uploadConfig + }{ + { + desc: "with the default values", + expectedConfig: uploadConfig{ + Method: http.MethodPost, + Url: "/target/upload", + }, + }, { + desc: "with overriden method", + uploadConfig: &uploadConfig{ + Method: http.MethodPut, + }, + expectedConfig: uploadConfig{ + Method: http.MethodPut, + Url: "/target/upload", + }, + }, { + desc: "with overriden url", + uploadConfig: &uploadConfig{ + Url: "http://test.org/overriden/upload", + }, + expectedConfig: uploadConfig{ + Method: http.MethodPost, + Url: "http://test.org/overriden/upload", + }, + }, { + desc: "with overriden headers", + uploadConfig: &uploadConfig{ + Headers: map[string][]string{"Private-Token": {"123456789"}}, + }, + expectedConfig: uploadConfig{ + Headers: map[string][]string{"Private-Token": {"123456789"}}, + Method: http.MethodPost, + Url: "/target/upload", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + uploadHandler := &fakeUploadHandler{ + handler: func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, tc.expectedConfig.Url, r.URL.String()) + require.Equal(t, tc.expectedConfig.Method, r.Method) + + if tc.expectedConfig.Headers != nil { + for k, v := range tc.expectedConfig.Headers { + require.Equal(t, v, r.Header[k]) + } + } + + w.WriteHeader(200) + }, + } + + injector := NewInjector() + injector.SetUploadHandler(uploadHandler) + + sendData := map[string]interface{}{ + "Token": "token", + "Url": originResourceServer.URL + `/remote/file`, + } + + if tc.uploadConfig != nil { + sendData["UploadConfig"] = tc.uploadConfig + } + + sendDataJsonString, err := json.Marshal(sendData) + require.NoError(t, err) + + response := makeRequest(injector, string(sendDataJsonString)) + + //checking the response + require.Equal(t, 200, response.Code) + require.Equal(t, string(content), response.Body.String()) + // checking remote file request + require.Equal(t, "/remote/file", response.Header().Get(testHeader)) + }) + } +} + +func TestInvalidUploadConfiguration(t *testing.T) { + baseSendData := map[string]interface{}{ + "Token": "token", + "Url": "http://remote.dev/remote/file", + } + testCases := []struct { + desc string + sendData map[string]interface{} + }{ + { + desc: "with an invalid overriden method", + sendData: mergeMap(baseSendData, map[string]interface{}{ + "UploadConfig": map[string]string{ + "Method": "TEAPOT", + }, + }), + }, { + desc: "with an invalid url", + sendData: mergeMap(baseSendData, map[string]interface{}{ + "UploadConfig": map[string]string{ + "Url": "invalid_url", + }, + }), + }, { + desc: "with an invalid headers", + sendData: mergeMap(baseSendData, map[string]interface{}{ + "UploadConfig": map[string]interface{}{ + "Headers": map[string]string{ + "Private-Token": "not_an_array", + }, + }, + }), + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + sendDataJsonString, err := json.Marshal(tc.sendData) + require.NoError(t, err) + + response := makeRequest(NewInjector(), string(sendDataJsonString)) + + require.Equal(t, 500, response.Code) + require.Equal(t, "Internal Server Error\n", response.Body.String()) + }) + } +} + +func mergeMap(from map[string]interface{}, into map[string]interface{}) map[string]interface{} { + for k, v := range from { + into[k] = v + } + return into +} + func TestIncorrectSendData(t *testing.T) { response := makeRequest(NewInjector(), "") diff --git a/workhorse/internal/gitaly/gitaly.go b/workhorse/internal/gitaly/gitaly.go index d9dbbdbb605..e4fbad17017 100644 --- a/workhorse/internal/gitaly/gitaly.go +++ b/workhorse/internal/gitaly/gitaly.go @@ -7,7 +7,6 @@ import ( "github.com/golang/protobuf/jsonpb" //lint:ignore SA1019 https://gitlab.com/gitlab-org/gitlab/-/issues/324868 "github.com/golang/protobuf/proto" //lint:ignore SA1019 https://gitlab.com/gitlab-org/gitlab/-/issues/324868 - grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -114,16 +113,6 @@ func NewRepositoryClient(ctx context.Context, server api.GitalyServer) (context. return withOutgoingMetadata(ctx, server), &RepositoryClient{grpcClient}, nil } -// NewNamespaceClient is only used by the Gitaly integration tests at present -func NewNamespaceClient(ctx context.Context, server api.GitalyServer) (context.Context, *NamespaceClient, error) { - conn, err := getOrCreateConnection(server) - if err != nil { - return nil, nil, err - } - grpcClient := gitalypb.NewNamespaceServiceClient(conn) - return withOutgoingMetadata(ctx, server), &NamespaceClient{grpcClient}, nil -} - func NewDiffClient(ctx context.Context, server api.GitalyServer) (context.Context, *DiffClient, error) { conn, err := getOrCreateConnection(server) if err != nil { @@ -173,23 +162,19 @@ func CloseConnections() { func newConnection(server api.GitalyServer) (*grpc.ClientConn, error) { connOpts := append(gitalyclient.DefaultDialOpts, grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2(server.Token)), - grpc.WithStreamInterceptor( - grpc_middleware.ChainStreamClient( - grpctracing.StreamClientTracingInterceptor(), - grpc_prometheus.StreamClientInterceptor, - grpccorrelation.StreamClientCorrelationInterceptor( - grpccorrelation.WithClientName("gitlab-workhorse"), - ), + grpc.WithChainStreamInterceptor( + grpctracing.StreamClientTracingInterceptor(), + grpc_prometheus.StreamClientInterceptor, + grpccorrelation.StreamClientCorrelationInterceptor( + grpccorrelation.WithClientName("gitlab-workhorse"), ), ), - grpc.WithUnaryInterceptor( - grpc_middleware.ChainUnaryClient( - grpctracing.UnaryClientTracingInterceptor(), - grpc_prometheus.UnaryClientInterceptor, - grpccorrelation.UnaryClientCorrelationInterceptor( - grpccorrelation.WithClientName("gitlab-workhorse"), - ), + grpc.WithChainUnaryInterceptor( + grpctracing.UnaryClientTracingInterceptor(), + grpc_prometheus.UnaryClientInterceptor, + grpccorrelation.UnaryClientCorrelationInterceptor( + grpccorrelation.WithClientName("gitlab-workhorse"), ), ), diff --git a/workhorse/internal/gitaly/gitaly_test.go b/workhorse/internal/gitaly/gitaly_test.go index 0ea5da20da3..04d3a0a79aa 100644 --- a/workhorse/internal/gitaly/gitaly_test.go +++ b/workhorse/internal/gitaly/gitaly_test.go @@ -46,15 +46,6 @@ func TestNewRepositoryClient(t *testing.T) { testOutgoingMetadata(t, ctx) } -func TestNewNamespaceClient(t *testing.T) { - ctx, _, err := NewNamespaceClient( - context.Background(), - serverFixture(), - ) - require.NoError(t, err) - testOutgoingMetadata(t, ctx) -} - func TestNewDiffClient(t *testing.T) { ctx, _, err := NewDiffClient( context.Background(), diff --git a/workhorse/internal/gitaly/namespace.go b/workhorse/internal/gitaly/namespace.go deleted file mode 100644 index a9bc2d07a7e..00000000000 --- a/workhorse/internal/gitaly/namespace.go +++ /dev/null @@ -1,8 +0,0 @@ -package gitaly - -import "gitlab.com/gitlab-org/gitaly/v16/proto/go/gitalypb" - -// NamespaceClient encapsulates NamespaceService calls -type NamespaceClient struct { - gitalypb.NamespaceServiceClient -} diff --git a/workhorse/internal/goredis/goredis.go b/workhorse/internal/goredis/goredis.go new file mode 100644 index 00000000000..13a9d4cc34f --- /dev/null +++ b/workhorse/internal/goredis/goredis.go @@ -0,0 +1,186 @@ +package goredis + +import ( + "context" + "errors" + "fmt" + "net" + "time" + + redis "github.com/redis/go-redis/v9" + + "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" + _ "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" + internalredis "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis" +) + +var ( + rdb *redis.Client + // found in https://github.com/redis/go-redis/blob/c7399b6a17d7d3e2a57654528af91349f2468529/sentinel.go#L626 + errSentinelMasterAddr error = errors.New("redis: all sentinels specified in configuration are unreachable") +) + +const ( + // Max Idle Connections in the pool. + defaultMaxIdle = 1 + // Max Active Connections in the pool. + defaultMaxActive = 1 + // Timeout for Read operations on the pool. 1 second is technically overkill, + // it's just for sanity. + defaultReadTimeout = 1 * time.Second + // Timeout for Write operations on the pool. 1 second is technically overkill, + // it's just for sanity. + defaultWriteTimeout = 1 * time.Second + // Timeout before killing Idle connections in the pool. 3 minutes seemed good. + // If you _actually_ hit this timeout often, you should consider turning of + // redis-support since it's not necessary at that point... + defaultIdleTimeout = 3 * time.Minute +) + +// createDialer references https://github.com/redis/go-redis/blob/b1103e3d436b6fe98813ecbbe1f99dc8d59b06c9/options.go#L214 +// it intercepts the error and tracks it via a Prometheus counter +func createDialer(sentinels []string) func(ctx context.Context, network, addr string) (net.Conn, error) { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + var isSentinel bool + for _, sentinelAddr := range sentinels { + if sentinelAddr == addr { + isSentinel = true + break + } + } + + dialTimeout := 5 * time.Second // go-redis default + destination := "redis" + if isSentinel { + // This timeout is recommended for Sentinel-support according to the guidelines. + // https://redis.io/topics/sentinel-clients#redis-service-discovery-via-sentinel + // For every address it should try to connect to the Sentinel, + // using a short timeout (in the order of a few hundreds of milliseconds). + destination = "sentinel" + dialTimeout = 500 * time.Millisecond + } + + netDialer := &net.Dialer{ + Timeout: dialTimeout, + KeepAlive: 5 * time.Minute, + } + + conn, err := netDialer.DialContext(ctx, network, addr) + if err != nil { + internalredis.ErrorCounter.WithLabelValues("dial", destination).Inc() + } else { + if !isSentinel { + internalredis.TotalConnections.Inc() + } + } + + return conn, err + } +} + +// implements the redis.Hook interface for instrumentation +type sentinelInstrumentationHook struct{} + +func (s sentinelInstrumentationHook) DialHook(next redis.DialHook) redis.DialHook { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + conn, err := next(ctx, network, addr) + if err != nil && err.Error() == errSentinelMasterAddr.Error() { + // check for non-dialer error + internalredis.ErrorCounter.WithLabelValues("master", "sentinel").Inc() + } + return conn, err + } +} + +func (s sentinelInstrumentationHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook { + return func(ctx context.Context, cmd redis.Cmder) error { + return next(ctx, cmd) + } +} + +func (s sentinelInstrumentationHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + return next(ctx, cmds) + } +} + +func GetRedisClient() *redis.Client { + return rdb +} + +// Configure redis-connection +func Configure(cfg *config.RedisConfig) error { + if cfg == nil { + return nil + } + + var err error + + if len(cfg.Sentinel) > 0 { + rdb = configureSentinel(cfg) + } else { + rdb, err = configureRedis(cfg) + } + + return err +} + +func configureRedis(cfg *config.RedisConfig) (*redis.Client, error) { + if cfg.URL.Scheme == "tcp" { + cfg.URL.Scheme = "redis" + } + + opt, err := redis.ParseURL(cfg.URL.String()) + if err != nil { + return nil, err + } + + opt.DB = getOrDefault(cfg.DB, 0) + opt.Password = cfg.Password + + opt.PoolSize = getOrDefault(cfg.MaxActive, defaultMaxActive) + opt.MaxIdleConns = getOrDefault(cfg.MaxIdle, defaultMaxIdle) + opt.ConnMaxIdleTime = defaultIdleTimeout + opt.ReadTimeout = defaultReadTimeout + opt.WriteTimeout = defaultWriteTimeout + + opt.Dialer = createDialer([]string{}) + + return redis.NewClient(opt), nil +} + +func configureSentinel(cfg *config.RedisConfig) *redis.Client { + sentinels := make([]string, len(cfg.Sentinel)) + for i := range cfg.Sentinel { + sentinelDetails := cfg.Sentinel[i] + sentinels[i] = fmt.Sprintf("%s:%s", sentinelDetails.Hostname(), sentinelDetails.Port()) + } + + client := redis.NewFailoverClient(&redis.FailoverOptions{ + MasterName: cfg.SentinelMaster, + SentinelAddrs: sentinels, + Password: cfg.Password, + SentinelPassword: cfg.SentinelPassword, + DB: getOrDefault(cfg.DB, 0), + + PoolSize: getOrDefault(cfg.MaxActive, defaultMaxActive), + MaxIdleConns: getOrDefault(cfg.MaxIdle, defaultMaxIdle), + ConnMaxIdleTime: defaultIdleTimeout, + + ReadTimeout: defaultReadTimeout, + WriteTimeout: defaultWriteTimeout, + + Dialer: createDialer(sentinels), + }) + + client.AddHook(sentinelInstrumentationHook{}) + + return client +} + +func getOrDefault(ptr *int, val int) int { + if ptr != nil { + return *ptr + } + return val +} diff --git a/workhorse/internal/goredis/goredis_test.go b/workhorse/internal/goredis/goredis_test.go new file mode 100644 index 00000000000..6b281229ea4 --- /dev/null +++ b/workhorse/internal/goredis/goredis_test.go @@ -0,0 +1,107 @@ +package goredis + +import ( + "context" + "net" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" + "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" +) + +func mockRedisServer(t *testing.T, connectReceived *atomic.Value) string { + ln, err := net.Listen("tcp", "127.0.0.1:0") + + require.Nil(t, err) + + go func() { + defer ln.Close() + conn, err := ln.Accept() + require.Nil(t, err) + connectReceived.Store(true) + conn.Write([]byte("OK\n")) + }() + + return ln.Addr().String() +} + +func TestConfigureNoConfig(t *testing.T) { + rdb = nil + Configure(nil) + require.Nil(t, rdb, "rdb client should be nil") +} + +func TestConfigureValidConfigX(t *testing.T) { + testCases := []struct { + scheme string + }{ + { + scheme: "redis", + }, + { + scheme: "tcp", + }, + } + + for _, tc := range testCases { + t.Run(tc.scheme, func(t *testing.T) { + connectReceived := atomic.Value{} + a := mockRedisServer(t, &connectReceived) + + parsedURL := helper.URLMustParse(tc.scheme + "://" + a) + cfg := &config.RedisConfig{URL: config.TomlURL{URL: *parsedURL}} + + Configure(cfg) + + require.NotNil(t, GetRedisClient().Conn(), "Pool should not be nil") + + // goredis initialise connections lazily + rdb.Ping(context.Background()) + require.True(t, connectReceived.Load().(bool)) + + rdb = nil + }) + } +} + +func TestConnectToSentinel(t *testing.T) { + testCases := []struct { + scheme string + }{ + { + scheme: "redis", + }, + { + scheme: "tcp", + }, + } + + for _, tc := range testCases { + t.Run(tc.scheme, func(t *testing.T) { + connectReceived := atomic.Value{} + a := mockRedisServer(t, &connectReceived) + + addrs := []string{tc.scheme + "://" + a} + var sentinelUrls []config.TomlURL + + for _, a := range addrs { + parsedURL := helper.URLMustParse(a) + sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL}) + } + + cfg := &config.RedisConfig{Sentinel: sentinelUrls} + Configure(cfg) + + require.NotNil(t, GetRedisClient().Conn(), "Pool should not be nil") + + // goredis initialise connections lazily + rdb.Ping(context.Background()) + require.True(t, connectReceived.Load().(bool)) + + rdb = nil + }) + } +} diff --git a/workhorse/internal/goredis/keywatcher.go b/workhorse/internal/goredis/keywatcher.go new file mode 100644 index 00000000000..741bfb17652 --- /dev/null +++ b/workhorse/internal/goredis/keywatcher.go @@ -0,0 +1,236 @@ +package goredis + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + "time" + + "github.com/jpillora/backoff" + "github.com/redis/go-redis/v9" + + "gitlab.com/gitlab-org/gitlab/workhorse/internal/log" + internalredis "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis" +) + +type KeyWatcher struct { + mu sync.Mutex + subscribers map[string][]chan string + shutdown chan struct{} + reconnectBackoff backoff.Backoff + redisConn *redis.Client + conn *redis.PubSub +} + +func NewKeyWatcher() *KeyWatcher { + return &KeyWatcher{ + shutdown: make(chan struct{}), + reconnectBackoff: backoff.Backoff{ + Min: 100 * time.Millisecond, + Max: 60 * time.Second, + Factor: 2, + Jitter: true, + }, + } +} + +const channelPrefix = "workhorse:notifications:" + +func countAction(action string) { internalredis.TotalActions.WithLabelValues(action).Add(1) } + +func (kw *KeyWatcher) receivePubSubStream(ctx context.Context, pubsub *redis.PubSub) error { + kw.mu.Lock() + // We must share kw.conn with the goroutines that call SUBSCRIBE and + // UNSUBSCRIBE because Redis pubsub subscriptions are tied to the + // connection. + kw.conn = pubsub + kw.mu.Unlock() + + defer func() { + kw.mu.Lock() + defer kw.mu.Unlock() + kw.conn.Close() + kw.conn = nil + + // Reset kw.subscribers because it is tied to Redis server side state of + // kw.conn and we just closed that connection. + for _, chans := range kw.subscribers { + for _, ch := range chans { + close(ch) + internalredis.KeyWatchers.Dec() + } + } + kw.subscribers = nil + }() + + for { + msg, err := kw.conn.Receive(ctx) + if err != nil { + log.WithError(fmt.Errorf("keywatcher: pubsub receive: %v", err)).Error() + return nil + } + + switch msg := msg.(type) { + case *redis.Subscription: + internalredis.RedisSubscriptions.Set(float64(msg.Count)) + case *redis.Pong: + // Ignore. + case *redis.Message: + internalredis.TotalMessages.Inc() + internalredis.ReceivedBytes.Add(float64(len(msg.Payload))) + if strings.HasPrefix(msg.Channel, channelPrefix) { + kw.notifySubscribers(msg.Channel[len(channelPrefix):], string(msg.Payload)) + } + default: + log.WithError(fmt.Errorf("keywatcher: unknown: %T", msg)).Error() + return nil + } + } +} + +func (kw *KeyWatcher) Process(client *redis.Client) { + log.Info("keywatcher: starting process loop") + + ctx := context.Background() // lint:allow context.Background + kw.mu.Lock() + kw.redisConn = client + kw.mu.Unlock() + + for { + pubsub := client.Subscribe(ctx, []string{}...) + if err := pubsub.Ping(ctx); err != nil { + log.WithError(fmt.Errorf("keywatcher: %v", err)).Error() + time.Sleep(kw.reconnectBackoff.Duration()) + continue + } + + kw.reconnectBackoff.Reset() + + if err := kw.receivePubSubStream(ctx, pubsub); err != nil { + log.WithError(fmt.Errorf("keywatcher: receivePubSubStream: %v", err)).Error() + } + } +} + +func (kw *KeyWatcher) Shutdown() { + log.Info("keywatcher: shutting down") + + kw.mu.Lock() + defer kw.mu.Unlock() + + select { + case <-kw.shutdown: + // already closed + default: + close(kw.shutdown) + } +} + +func (kw *KeyWatcher) notifySubscribers(key, value string) { + kw.mu.Lock() + defer kw.mu.Unlock() + + chanList, ok := kw.subscribers[key] + if !ok { + countAction("drop-message") + return + } + + countAction("deliver-message") + for _, c := range chanList { + select { + case c <- value: + default: + } + } +} + +func (kw *KeyWatcher) addSubscription(ctx context.Context, key string, notify chan string) error { + kw.mu.Lock() + defer kw.mu.Unlock() + + if kw.conn == nil { + // This can happen because CI long polling is disabled in this Workhorse + // process. It can also be that we are waiting for the pubsub connection + // to be established. Either way it is OK to fail fast. + return errors.New("no redis connection") + } + + if len(kw.subscribers[key]) == 0 { + countAction("create-subscription") + if err := kw.conn.Subscribe(ctx, channelPrefix+key); err != nil { + return err + } + } + + if kw.subscribers == nil { + kw.subscribers = make(map[string][]chan string) + } + kw.subscribers[key] = append(kw.subscribers[key], notify) + internalredis.KeyWatchers.Inc() + + return nil +} + +func (kw *KeyWatcher) delSubscription(ctx context.Context, key string, notify chan string) { + kw.mu.Lock() + defer kw.mu.Unlock() + + chans, ok := kw.subscribers[key] + if !ok { + // This can happen if the pubsub connection dropped while we were + // waiting. + return + } + + for i, c := range chans { + if notify == c { + kw.subscribers[key] = append(chans[:i], chans[i+1:]...) + internalredis.KeyWatchers.Dec() + break + } + } + if len(kw.subscribers[key]) == 0 { + delete(kw.subscribers, key) + countAction("delete-subscription") + if kw.conn != nil { + kw.conn.Unsubscribe(ctx, channelPrefix+key) + } + } +} + +func (kw *KeyWatcher) WatchKey(ctx context.Context, key, value string, timeout time.Duration) (internalredis.WatchKeyStatus, error) { + notify := make(chan string, 1) + if err := kw.addSubscription(ctx, key, notify); err != nil { + return internalredis.WatchKeyStatusNoChange, err + } + defer kw.delSubscription(ctx, key, notify) + + currentValue, err := kw.redisConn.Get(ctx, key).Result() + if errors.Is(err, redis.Nil) { + currentValue = "" + } else if err != nil { + return internalredis.WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET: %v", err) + } + if currentValue != value { + return internalredis.WatchKeyStatusAlreadyChanged, nil + } + + select { + case <-kw.shutdown: + log.WithFields(log.Fields{"key": key}).Info("stopping watch due to shutdown") + return internalredis.WatchKeyStatusNoChange, nil + case currentValue := <-notify: + if currentValue == "" { + return internalredis.WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET failed") + } + if currentValue == value { + return internalredis.WatchKeyStatusNoChange, nil + } + return internalredis.WatchKeyStatusSeenChange, nil + case <-time.After(timeout): + return internalredis.WatchKeyStatusTimeout, nil + } +} diff --git a/workhorse/internal/goredis/keywatcher_test.go b/workhorse/internal/goredis/keywatcher_test.go new file mode 100644 index 00000000000..b64262dc9c8 --- /dev/null +++ b/workhorse/internal/goredis/keywatcher_test.go @@ -0,0 +1,301 @@ +package goredis + +import ( + "context" + "os" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" + "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis" +) + +var ctx = context.Background() + +const ( + runnerKey = "runner:build_queue:10" +) + +func initRdb() { + buf, _ := os.ReadFile("../../config.toml") + cfg, _ := config.LoadConfig(string(buf)) + Configure(cfg.Redis) +} + +func (kw *KeyWatcher) countSubscribers(key string) int { + kw.mu.Lock() + defer kw.mu.Unlock() + return len(kw.subscribers[key]) +} + +// Forces a run of the `Process` loop against a mock PubSubConn. +func (kw *KeyWatcher) processMessages(t *testing.T, numWatchers int, value string, ready chan<- struct{}, wg *sync.WaitGroup) { + kw.mu.Lock() + kw.redisConn = rdb + psc := kw.redisConn.Subscribe(ctx, []string{}...) + kw.mu.Unlock() + + errC := make(chan error) + go func() { errC <- kw.receivePubSubStream(ctx, psc) }() + + require.Eventually(t, func() bool { + kw.mu.Lock() + defer kw.mu.Unlock() + return kw.conn != nil + }, time.Second, time.Millisecond) + close(ready) + + require.Eventually(t, func() bool { + return kw.countSubscribers(runnerKey) == numWatchers + }, time.Second, time.Millisecond) + + // send message after listeners are ready + kw.redisConn.Publish(ctx, channelPrefix+runnerKey, value) + + // close subscription after all workers are done + wg.Wait() + kw.mu.Lock() + kw.conn.Close() + kw.mu.Unlock() + + require.NoError(t, <-errC) +} + +type keyChangeTestCase struct { + desc string + returnValue string + isKeyMissing bool + watchValue string + processedValue string + expectedStatus redis.WatchKeyStatus + timeout time.Duration +} + +func TestKeyChangesInstantReturn(t *testing.T) { + initRdb() + + testCases := []keyChangeTestCase{ + // WatchKeyStatusAlreadyChanged + { + desc: "sees change with key existing and changed", + returnValue: "somethingelse", + watchValue: "something", + expectedStatus: redis.WatchKeyStatusAlreadyChanged, + timeout: time.Second, + }, + { + desc: "sees change with key non-existing", + isKeyMissing: true, + watchValue: "something", + processedValue: "somethingelse", + expectedStatus: redis.WatchKeyStatusAlreadyChanged, + timeout: time.Second, + }, + // WatchKeyStatusTimeout + { + desc: "sees timeout with key existing and unchanged", + returnValue: "something", + watchValue: "something", + expectedStatus: redis.WatchKeyStatusTimeout, + timeout: time.Millisecond, + }, + { + desc: "sees timeout with key non-existing and unchanged", + isKeyMissing: true, + watchValue: "", + expectedStatus: redis.WatchKeyStatusTimeout, + timeout: time.Millisecond, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + + // setup + if !tc.isKeyMissing { + rdb.Set(ctx, runnerKey, tc.returnValue, 0) + } + + defer func() { + rdb.FlushDB(ctx) + }() + + kw := NewKeyWatcher() + defer kw.Shutdown() + kw.redisConn = rdb + kw.conn = kw.redisConn.Subscribe(ctx, []string{}...) + + val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, tc.timeout) + + require.NoError(t, err, "Expected no error") + require.Equal(t, tc.expectedStatus, val, "Expected value") + }) + } +} + +func TestKeyChangesWhenWatching(t *testing.T) { + initRdb() + + testCases := []keyChangeTestCase{ + // WatchKeyStatusSeenChange + { + desc: "sees change with key existing", + returnValue: "something", + watchValue: "something", + processedValue: "somethingelse", + expectedStatus: redis.WatchKeyStatusSeenChange, + }, + { + desc: "sees change with key non-existing, when watching empty value", + isKeyMissing: true, + watchValue: "", + processedValue: "something", + expectedStatus: redis.WatchKeyStatusSeenChange, + }, + // WatchKeyStatusNoChange + { + desc: "sees no change with key existing", + returnValue: "something", + watchValue: "something", + processedValue: "something", + expectedStatus: redis.WatchKeyStatusNoChange, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + if !tc.isKeyMissing { + rdb.Set(ctx, runnerKey, tc.returnValue, 0) + } + + kw := NewKeyWatcher() + defer kw.Shutdown() + defer func() { + rdb.FlushDB(ctx) + }() + + wg := &sync.WaitGroup{} + wg.Add(1) + ready := make(chan struct{}) + + go func() { + defer wg.Done() + <-ready + val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, time.Second) + + require.NoError(t, err, "Expected no error") + require.Equal(t, tc.expectedStatus, val, "Expected value") + }() + + kw.processMessages(t, 1, tc.processedValue, ready, wg) + }) + } +} + +func TestKeyChangesParallel(t *testing.T) { + initRdb() + + testCases := []keyChangeTestCase{ + { + desc: "massively parallel, sees change with key existing", + returnValue: "something", + watchValue: "something", + processedValue: "somethingelse", + expectedStatus: redis.WatchKeyStatusSeenChange, + }, + { + desc: "massively parallel, sees change with key existing, watching missing keys", + isKeyMissing: true, + watchValue: "", + processedValue: "somethingelse", + expectedStatus: redis.WatchKeyStatusSeenChange, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + runTimes := 100 + + if !tc.isKeyMissing { + rdb.Set(ctx, runnerKey, tc.returnValue, 0) + } + + defer func() { + rdb.FlushDB(ctx) + }() + + wg := &sync.WaitGroup{} + wg.Add(runTimes) + ready := make(chan struct{}) + + kw := NewKeyWatcher() + defer kw.Shutdown() + + for i := 0; i < runTimes; i++ { + go func() { + defer wg.Done() + <-ready + val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, time.Second) + + require.NoError(t, err, "Expected no error") + require.Equal(t, tc.expectedStatus, val, "Expected value") + }() + } + + kw.processMessages(t, runTimes, tc.processedValue, ready, wg) + }) + } +} + +func TestShutdown(t *testing.T) { + initRdb() + + kw := NewKeyWatcher() + kw.redisConn = rdb + kw.conn = kw.redisConn.Subscribe(ctx, []string{}...) + defer kw.Shutdown() + + rdb.Set(ctx, runnerKey, "something", 0) + + wg := &sync.WaitGroup{} + wg.Add(2) + + go func() { + defer wg.Done() + val, err := kw.WatchKey(ctx, runnerKey, "something", 10*time.Second) + + require.NoError(t, err, "Expected no error") + require.Equal(t, redis.WatchKeyStatusNoChange, val, "Expected value not to change") + }() + + go func() { + defer wg.Done() + require.Eventually(t, func() bool { return kw.countSubscribers(runnerKey) == 1 }, 10*time.Second, time.Millisecond) + + kw.Shutdown() + }() + + wg.Wait() + + require.Eventually(t, func() bool { return kw.countSubscribers(runnerKey) == 0 }, 10*time.Second, time.Millisecond) + + // Adding a key after the shutdown should result in an immediate response + var val redis.WatchKeyStatus + var err error + done := make(chan struct{}) + go func() { + val, err = kw.WatchKey(ctx, runnerKey, "something", 10*time.Second) + close(done) + }() + + select { + case <-done: + require.NoError(t, err, "Expected no error") + require.Equal(t, redis.WatchKeyStatusNoChange, val, "Expected value not to change") + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout waiting for WatchKey") + } +} diff --git a/workhorse/internal/redis/keywatcher.go b/workhorse/internal/redis/keywatcher.go index 2fd0753c3c9..8f1772a9195 100644 --- a/workhorse/internal/redis/keywatcher.go +++ b/workhorse/internal/redis/keywatcher.go @@ -37,32 +37,32 @@ func NewKeyWatcher() *KeyWatcher { } var ( - keyWatchers = promauto.NewGauge( + KeyWatchers = promauto.NewGauge( prometheus.GaugeOpts{ Name: "gitlab_workhorse_keywatcher_keywatchers", Help: "The number of keys that is being watched by gitlab-workhorse", }, ) - redisSubscriptions = promauto.NewGauge( + RedisSubscriptions = promauto.NewGauge( prometheus.GaugeOpts{ Name: "gitlab_workhorse_keywatcher_redis_subscriptions", Help: "Current number of keywatcher Redis pubsub subscriptions", }, ) - totalMessages = promauto.NewCounter( + TotalMessages = promauto.NewCounter( prometheus.CounterOpts{ Name: "gitlab_workhorse_keywatcher_total_messages", Help: "How many messages gitlab-workhorse has received in total on pubsub.", }, ) - totalActions = promauto.NewCounterVec( + TotalActions = promauto.NewCounterVec( prometheus.CounterOpts{ Name: "gitlab_workhorse_keywatcher_actions_total", Help: "Counts of various keywatcher actions", }, []string{"action"}, ) - receivedBytes = promauto.NewCounter( + ReceivedBytes = promauto.NewCounter( prometheus.CounterOpts{ Name: "gitlab_workhorse_keywatcher_received_bytes_total", Help: "How many bytes of messages gitlab-workhorse has received in total on pubsub.", @@ -72,7 +72,7 @@ var ( const channelPrefix = "workhorse:notifications:" -func countAction(action string) { totalActions.WithLabelValues(action).Add(1) } +func countAction(action string) { TotalActions.WithLabelValues(action).Add(1) } func (kw *KeyWatcher) receivePubSubStream(conn redis.Conn) error { kw.mu.Lock() @@ -93,7 +93,7 @@ func (kw *KeyWatcher) receivePubSubStream(conn redis.Conn) error { for _, chans := range kw.subscribers { for _, ch := range chans { close(ch) - keyWatchers.Dec() + KeyWatchers.Dec() } } kw.subscribers = nil @@ -102,13 +102,13 @@ func (kw *KeyWatcher) receivePubSubStream(conn redis.Conn) error { for { switch v := kw.conn.Receive().(type) { case redis.Message: - totalMessages.Inc() - receivedBytes.Add(float64(len(v.Data))) + TotalMessages.Inc() + ReceivedBytes.Add(float64(len(v.Data))) if strings.HasPrefix(v.Channel, channelPrefix) { kw.notifySubscribers(v.Channel[len(channelPrefix):], string(v.Data)) } case redis.Subscription: - redisSubscriptions.Set(float64(v.Count)) + RedisSubscriptions.Set(float64(v.Count)) case error: log.WithError(fmt.Errorf("keywatcher: pubsub receive: %v", v)).Error() // Intermittent error, return nil so that it doesn't wait before reconnect @@ -205,7 +205,7 @@ func (kw *KeyWatcher) addSubscription(key string, notify chan string) error { kw.subscribers = make(map[string][]chan string) } kw.subscribers[key] = append(kw.subscribers[key], notify) - keyWatchers.Inc() + KeyWatchers.Inc() return nil } @@ -224,7 +224,7 @@ func (kw *KeyWatcher) delSubscription(key string, notify chan string) { for i, c := range chans { if notify == c { kw.subscribers[key] = append(chans[:i], chans[i+1:]...) - keyWatchers.Dec() + KeyWatchers.Dec() break } } diff --git a/workhorse/internal/redis/redis.go b/workhorse/internal/redis/redis.go index 03118cfcef6..c79e1e56b3a 100644 --- a/workhorse/internal/redis/redis.go +++ b/workhorse/internal/redis/redis.go @@ -45,14 +45,14 @@ const ( ) var ( - totalConnections = promauto.NewCounter( + TotalConnections = promauto.NewCounter( prometheus.CounterOpts{ Name: "gitlab_workhorse_redis_total_connections", Help: "How many connections gitlab-workhorse has opened in total. Can be used to track Redis connection rate for this process", }, ) - errorCounter = promauto.NewCounterVec( + ErrorCounter = promauto.NewCounterVec( prometheus.CounterOpts{ Name: "gitlab_workhorse_redis_errors", Help: "Counts different types of Redis errors encountered by workhorse, by type and destination (redis, sentinel)", @@ -100,7 +100,7 @@ func sentinelConn(master string, urls []config.TomlURL) *sentinel.Sentinel { } if err != nil { - errorCounter.WithLabelValues("dial", "sentinel").Inc() + ErrorCounter.WithLabelValues("dial", "sentinel").Inc() return nil, err } return c, nil @@ -159,7 +159,7 @@ func sentinelDialer(dopts []redis.DialOption) redisDialerFunc { return func() (redis.Conn, error) { address, err := sntnl.MasterAddr() if err != nil { - errorCounter.WithLabelValues("master", "sentinel").Inc() + ErrorCounter.WithLabelValues("master", "sentinel").Inc() return nil, err } dopts = append(dopts, redis.DialNetDial(keepAliveDialer)) @@ -214,9 +214,9 @@ func countDialer(dialer redisDialerFunc) redisDialerFunc { return func() (redis.Conn, error) { c, err := dialer() if err != nil { - errorCounter.WithLabelValues("dial", "redis").Inc() + ErrorCounter.WithLabelValues("dial", "redis").Inc() } else { - totalConnections.Inc() + TotalConnections.Inc() } return c, err } diff --git a/workhorse/main.go b/workhorse/main.go index ca9b86de528..9ba213d47d3 100644 --- a/workhorse/main.go +++ b/workhorse/main.go @@ -17,8 +17,10 @@ import ( "gitlab.com/gitlab-org/labkit/monitoring" "gitlab.com/gitlab-org/labkit/tracing" + "gitlab.com/gitlab-org/gitlab/workhorse/internal/builds" "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" "gitlab.com/gitlab-org/gitlab/workhorse/internal/gitaly" + "gitlab.com/gitlab-org/gitlab/workhorse/internal/goredis" "gitlab.com/gitlab-org/gitlab/workhorse/internal/queueing" "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis" "gitlab.com/gitlab-org/gitlab/workhorse/internal/secret" @@ -224,9 +226,32 @@ func run(boot bootConfig, cfg config.Config) error { secret.SetPath(boot.secretPath) keyWatcher := redis.NewKeyWatcher() - if cfg.Redis != nil { - redis.Configure(cfg.Redis, redis.DefaultDialFunc) - go keyWatcher.Process() + + var watchKeyFn builds.WatchKeyHandler + var goredisKeyWatcher *goredis.KeyWatcher + + if os.Getenv("GITLAB_WORKHORSE_FF_GO_REDIS_ENABLED") == "true" { + log.Info("Using redis/go-redis") + + goredisKeyWatcher = goredis.NewKeyWatcher() + if err := goredis.Configure(cfg.Redis); err != nil { + log.WithError(err).Error("unable to configure redis client") + } + + if rdb := goredis.GetRedisClient(); rdb != nil { + go goredisKeyWatcher.Process(rdb) + } + + watchKeyFn = goredisKeyWatcher.WatchKey + } else { + log.Info("Using gomodule/redigo") + + if cfg.Redis != nil { + redis.Configure(cfg.Redis, redis.DefaultDialFunc) + go keyWatcher.Process() + } + + watchKeyFn = keyWatcher.WatchKey } if err := cfg.RegisterGoCloudURLOpeners(); err != nil { @@ -241,7 +266,7 @@ func run(boot bootConfig, cfg config.Config) error { gitaly.InitializeSidechannelRegistry(accessLogger) - up := wrapRaven(upstream.NewUpstream(cfg, accessLogger, keyWatcher.WatchKey)) + up := wrapRaven(upstream.NewUpstream(cfg, accessLogger, watchKeyFn)) done := make(chan os.Signal, 1) signal.Notify(done, syscall.SIGINT, syscall.SIGTERM) @@ -275,6 +300,10 @@ func run(boot bootConfig, cfg config.Config) error { ctx, cancel := context.WithTimeout(context.Background(), cfg.ShutdownTimeout.Duration) // lint:allow context.Background defer cancel() + if goredisKeyWatcher != nil { + goredisKeyWatcher.Shutdown() + } + keyWatcher.Shutdown() return srv.Shutdown(ctx) } diff --git a/workhorse/main_test.go b/workhorse/main_test.go index 05834ab5d64..39eaa3ee30b 100644 --- a/workhorse/main_test.go +++ b/workhorse/main_test.go @@ -35,7 +35,6 @@ import ( "gitlab.com/gitlab-org/gitlab/workhorse/internal/upstream" ) -const scratchDir = "testdata/scratch" const testRepoRoot = "testdata/repo" const testDocumentRoot = "testdata/public" const testAltDocumentRoot = "testdata/alt-public" @@ -45,9 +44,6 @@ var absDocumentRoot string const testRepo = "group/test.git" const testProject = "group/test" -var checkoutDir = path.Join(scratchDir, "test") -var cacheDir = path.Join(scratchDir, "cache") - func TestMain(m *testing.M) { if _, err := os.Stat(path.Join(testRepoRoot, testRepo)); os.IsNotExist(err) { log.WithError(err).Fatal("cannot find test repository. Please run 'make prepare-tests'") @@ -64,9 +60,6 @@ func TestMain(m *testing.M) { } func TestDeniedClone(t *testing.T) { - // Prepare clone directory - require.NoError(t, os.RemoveAll(scratchDir)) - // Prepare test server and backend ts := testAuthServer(t, nil, nil, 403, "Access denied") defer ts.Close() @@ -74,7 +67,7 @@ func TestDeniedClone(t *testing.T) { defer ws.Close() // Do the git clone - cloneCmd := exec.Command("git", "clone", fmt.Sprintf("%s/%s", ws.URL, testRepo), checkoutDir) + cloneCmd := exec.Command("git", "clone", fmt.Sprintf("%s/%s", ws.URL, testRepo), t.TempDir()) out, err := cloneCmd.CombinedOutput() t.Log(string(out)) require.Error(t, err, "git clone should have failed") @@ -89,7 +82,7 @@ func TestDeniedPush(t *testing.T) { // Perform the git push pushCmd := exec.Command("git", "push", "-v", fmt.Sprintf("%s/%s", ws.URL, testRepo), fmt.Sprintf("master:%s", newBranch())) - pushCmd.Dir = checkoutDir + pushCmd.Dir = t.TempDir() out, err := pushCmd.CombinedOutput() t.Log(string(out)) require.Error(t, err, "git push should have failed") @@ -125,14 +118,12 @@ func TestRegularProjectsAPI(t *testing.T) { func TestAllowedXSendfileDownload(t *testing.T) { contentFilename := "my-content" - prepareDownloadDir(t) allowedXSendfileDownload(t, contentFilename, "foo/uploads/bar") } func TestDeniedXSendfileDownload(t *testing.T) { contentFilename := "my-content" - prepareDownloadDir(t) deniedXSendfileDownload(t, contentFilename, "foo/uploads/bar") } @@ -721,11 +712,6 @@ func setupAltStaticFile(t *testing.T, fpath, content string) { absDocumentRoot = testhelper.SetupStaticFileHelper(t, fpath, content, testAltDocumentRoot) } -func prepareDownloadDir(t *testing.T) { - require.NoError(t, os.RemoveAll(scratchDir)) - require.NoError(t, os.MkdirAll(scratchDir, 0755)) -} - func newBranch() string { return fmt.Sprintf("branch-%d", time.Now().UnixNano()) } @@ -962,7 +948,7 @@ func TestDependencyProxyInjector(t *testing.T) { w.Header().Set("Gitlab-Workhorse-Send-Data", `send-dependency:`+base64.URLEncoding.EncodeToString([]byte(params))) case "/base/upload/authorize": w.Header().Set("Content-Type", api.ResponseContentType) - _, err := fmt.Fprintf(w, `{"TempPath":"%s"}`, scratchDir) + _, err := fmt.Fprintf(w, `{"TempPath":"%s"}`, t.TempDir()) require.NoError(t, err) case "/base/upload": w.WriteHeader(tc.finalizeStatus) diff --git a/workhorse/sendfile_test.go b/workhorse/sendfile_test.go index f2b6da4eebd..ed8edf01533 100644 --- a/workhorse/sendfile_test.go +++ b/workhorse/sendfile_test.go @@ -20,7 +20,6 @@ func TestDeniedLfsDownload(t *testing.T) { contentFilename := "b68143e6463773b1b6c6fd009a76c32aeec041faff32ba2ed42fd7f708a17f80" url := fmt.Sprintf("gitlab-lfs/objects/%s", contentFilename) - prepareDownloadDir(t) deniedXSendfileDownload(t, contentFilename, url) } @@ -28,14 +27,11 @@ func TestAllowedLfsDownload(t *testing.T) { contentFilename := "b68143e6463773b1b6c6fd009a76c32aeec041faff32ba2ed42fd7f708a17f80" url := fmt.Sprintf("gitlab-lfs/objects/%s", contentFilename) - prepareDownloadDir(t) allowedXSendfileDownload(t, contentFilename, url) } func allowedXSendfileDownload(t *testing.T, contentFilename string, filePath string) { - contentPath := path.Join(cacheDir, contentFilename) - prepareDownloadDir(t) - + contentPath := path.Join(t.TempDir(), contentFilename) // Prepare test server and backend ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.WithFields(log.Fields{"method": r.Method, "url": r.URL}).Info("UPSTREAM") @@ -51,7 +47,6 @@ func allowedXSendfileDownload(t *testing.T, contentFilename string, filePath str ws := startWorkhorseServer(ts.URL) defer ws.Close() - require.NoError(t, os.MkdirAll(cacheDir, 0755)) contentBytes := []byte("content") require.NoError(t, os.WriteFile(contentPath, contentBytes, 0644)) @@ -68,8 +63,6 @@ func allowedXSendfileDownload(t *testing.T, contentFilename string, filePath str } func deniedXSendfileDownload(t *testing.T, contentFilename string, filePath string) { - prepareDownloadDir(t) - // Prepare test server and backend ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.WithFields(log.Fields{"method": r.Method, "url": r.URL}).Info("UPSTREAM") diff --git a/workhorse/upload_test.go b/workhorse/upload_test.go index 62a78dd9464..de7c5c6dd52 100644 --- a/workhorse/upload_test.go +++ b/workhorse/upload_test.go @@ -74,9 +74,9 @@ func uploadTestServer(t *testing.T, allowedHashFunctions []string, authorizeTest var err error if len(allowedHashFunctions) == 0 { - _, err = fmt.Fprintf(w, `{"TempPath":"%s"}`, scratchDir) + _, err = fmt.Fprintf(w, `{"TempPath":"%s"}`, t.TempDir()) } else { - _, err = fmt.Fprintf(w, `{"TempPath":"%s", "UploadHashFunctions": ["%s"]}`, scratchDir, strings.Join(allowedHashFunctions, `","`)) + _, err = fmt.Fprintf(w, `{"TempPath":"%s", "UploadHashFunctions": ["%s"]}`, t.TempDir(), strings.Join(allowedHashFunctions, `","`)) } require.NoError(t, err) @@ -386,7 +386,7 @@ func TestLfsUpload(t *testing.T) { lfsApiResponse := fmt.Sprintf( `{"TempPath":%q, "LfsOid":%q, "LfsSize": %d}`, - scratchDir, oid, len(reqBody), + t.TempDir(), oid, len(reqBody), ) ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) { @@ -512,7 +512,7 @@ func packageUploadTestServer(t *testing.T, method string, resource string, reqBo return testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) { require.Equal(t, r.Method, method) apiResponse := fmt.Sprintf( - `{"TempPath":%q, "Size": %d}`, scratchDir, len(reqBody), + `{"TempPath":%q, "Size": %d}`, t.TempDir(), len(reqBody), ) switch r.RequestURI { case resource + "/authorize": |