gitlab.com/gitlab-org/gitlab-foss.git
Diffstat (limited to 'workhorse')
-rw-r--r--  workhorse/config_test.go                              |   2
-rw-r--r--  workhorse/gitaly_integration_test.go                  |  54
-rw-r--r--  workhorse/gitaly_test.go                              |  17
-rw-r--r--  workhorse/go.mod                                      |   4
-rw-r--r--  workhorse/go.sum                                      |   6
-rw-r--r--  workhorse/internal/config/config.go                   |  15
-rw-r--r--  workhorse/internal/dependencyproxy/dependencyproxy.go |  73
-rw-r--r--  workhorse/internal/dependencyproxy/dependencyproxy_test.go | 153
-rw-r--r--  workhorse/internal/gitaly/gitaly.go                   |  35
-rw-r--r--  workhorse/internal/gitaly/gitaly_test.go              |   9
-rw-r--r--  workhorse/internal/gitaly/namespace.go                |   8
-rw-r--r--  workhorse/internal/goredis/goredis.go                 | 186
-rw-r--r--  workhorse/internal/goredis/goredis_test.go            | 107
-rw-r--r--  workhorse/internal/goredis/keywatcher.go              | 236
-rw-r--r--  workhorse/internal/goredis/keywatcher_test.go         | 301
-rw-r--r--  workhorse/internal/redis/keywatcher.go                |  24
-rw-r--r--  workhorse/internal/redis/redis.go                     |  12
-rw-r--r--  workhorse/main.go                                     |  37
-rw-r--r--  workhorse/main_test.go                                |  20
-rw-r--r--  workhorse/sendfile_test.go                            |   9
-rw-r--r--  workhorse/upload_test.go                              |   8
21 files changed, 1172 insertions(+), 144 deletions(-)
diff --git a/workhorse/config_test.go b/workhorse/config_test.go
index a6a1bdd7187..64f0a24d148 100644
--- a/workhorse/config_test.go
+++ b/workhorse/config_test.go
@@ -34,6 +34,7 @@ trusted_cidrs_for_propagation = ["10.0.0.1/8"]
[redis]
password = "redis password"
+SentinelPassword = "sentinel password"
[object_storage]
provider = "test provider"
[image_resizer]
@@ -68,6 +69,7 @@ key = "/path/to/private/key"
// fields in each section; that should happen in the tests of the
// internal/config package.
require.Equal(t, "redis password", cfg.Redis.Password)
+ require.Equal(t, "sentinel password", cfg.Redis.SentinelPassword)
require.Equal(t, "test provider", cfg.ObjectStorageCredentials.Provider)
require.Equal(t, uint32(123), cfg.ImageResizerConfig.MaxScalerProcs, "image resizer max_scaler_procs")
require.Equal(t, []string{"127.0.0.1/8", "192.168.0.1/8"}, cfg.TrustedCIDRsForXForwardedFor)
diff --git a/workhorse/gitaly_integration_test.go b/workhorse/gitaly_integration_test.go
index a7ec0b63b9d..929b9263dfd 100644
--- a/workhorse/gitaly_integration_test.go
+++ b/workhorse/gitaly_integration_test.go
@@ -20,6 +20,8 @@ import (
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitaly/v16/proto/go/gitalypb"
"gitlab.com/gitlab-org/gitaly/v16/streamio"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/api"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/gitaly"
@@ -76,27 +78,24 @@ func realGitalyOkBody(t *testing.T, gitalyAddress string) *api.Response {
}
func ensureGitalyRepository(t *testing.T, apiResponse *api.Response) error {
- ctx, namespace, err := gitaly.NewNamespaceClient(
- context.Background(),
- apiResponse.GitalyServer,
- )
-
- if err != nil {
- return err
- }
- ctx, repository, err := gitaly.NewRepositoryClient(ctx, apiResponse.GitalyServer)
+ ctx, repository, err := gitaly.NewRepositoryClient(context.Background(), apiResponse.GitalyServer)
if err != nil {
return err
}
// Remove the repository if it already exists, for consistency
- rmNsReq := &gitalypb.RemoveNamespaceRequest{
- StorageName: apiResponse.Repository.StorageName,
- Name: apiResponse.Repository.RelativePath,
- }
- _, err = namespace.RemoveNamespace(ctx, rmNsReq)
- if err != nil {
- return err
+ if _, err := repository.RepositoryServiceClient.RemoveRepository(ctx, &gitalypb.RemoveRepositoryRequest{
+ Repository: &gitalypb.Repository{
+ StorageName: apiResponse.Repository.StorageName,
+ RelativePath: apiResponse.Repository.RelativePath,
+ },
+ }); err != nil {
+ status, ok := status.FromError(err)
+ if !ok || !(status.Code() == codes.NotFound && status.Message() == "repository does not exist") {
+ return fmt.Errorf("remove repository: %w", err)
+ }
+
+ // Repository didn't exist.
}
stream, err := repository.CreateRepositoryFromBundle(ctx)
@@ -139,13 +138,13 @@ func TestAllowedClone(t *testing.T) {
defer ws.Close()
// Do the git clone
- require.NoError(t, os.RemoveAll(scratchDir))
- cloneCmd := exec.Command("git", "clone", fmt.Sprintf("%s/%s", ws.URL, testRepo), checkoutDir)
+ tmpDir := t.TempDir()
+ cloneCmd := exec.Command("git", "clone", fmt.Sprintf("%s/%s", ws.URL, testRepo), tmpDir)
runOrFail(t, cloneCmd)
// We may have cloned an 'empty' repository, 'git log' will fail in it
logCmd := exec.Command("git", "log", "-1", "--oneline")
- logCmd.Dir = checkoutDir
+ logCmd.Dir = tmpDir
runOrFail(t, logCmd)
})
}
@@ -167,13 +166,13 @@ func TestAllowedShallowClone(t *testing.T) {
defer ws.Close()
// Shallow git clone (depth 1)
- require.NoError(t, os.RemoveAll(scratchDir))
- cloneCmd := exec.Command("git", "clone", "--depth", "1", fmt.Sprintf("%s/%s", ws.URL, testRepo), checkoutDir)
+ tmpDir := t.TempDir()
+ cloneCmd := exec.Command("git", "clone", "--depth", "1", fmt.Sprintf("%s/%s", ws.URL, testRepo), tmpDir)
runOrFail(t, cloneCmd)
// We may have cloned an 'empty' repository, 'git log' will fail in it
logCmd := exec.Command("git", "log", "-1", "--oneline")
- logCmd.Dir = checkoutDir
+ logCmd.Dir = tmpDir
runOrFail(t, logCmd)
})
}
@@ -194,9 +193,14 @@ func TestAllowedPush(t *testing.T) {
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
+ // Do the git clone
+ tmpDir := t.TempDir()
+ cloneCmd := exec.Command("git", "clone", fmt.Sprintf("%s/%s", ws.URL, testRepo), tmpDir)
+ runOrFail(t, cloneCmd)
+
// Perform the git push
pushCmd := exec.Command("git", "push", fmt.Sprintf("%s/%s", ws.URL, testRepo), fmt.Sprintf("master:%s", newBranch()))
- pushCmd.Dir = checkoutDir
+ pushCmd.Dir = tmpDir
runOrFail(t, pushCmd)
})
}
@@ -249,7 +253,7 @@ func TestAllowedGetGitArchive(t *testing.T) {
apiResponse := realGitalyOkBody(t, gitalyAddress)
require.NoError(t, ensureGitalyRepository(t, apiResponse))
- archivePath := path.Join(scratchDir, "my/path")
+ archivePath := path.Join(t.TempDir(), "my/path")
archivePrefix := "repo-1"
msg := serializedProtoMessage("GetArchiveRequest", &gitalypb.GetArchiveRequest{
@@ -296,7 +300,7 @@ func TestAllowedGetGitArchiveOldPayload(t *testing.T) {
repo := &apiResponse.Repository
require.NoError(t, ensureGitalyRepository(t, apiResponse))
- archivePath := path.Join(scratchDir, "my/path")
+ archivePath := path.Join(t.TempDir(), "my/path")
archivePrefix := "repo-1"
jsonParams := fmt.Sprintf(
diff --git a/workhorse/gitaly_test.go b/workhorse/gitaly_test.go
index 6bbc67228c3..270c40cb4bc 100644
--- a/workhorse/gitaly_test.go
+++ b/workhorse/gitaly_test.go
@@ -31,9 +31,6 @@ import (
)
func TestFailedCloneNoGitaly(t *testing.T) {
- // Prepare clone directory
- require.NoError(t, os.RemoveAll(scratchDir))
-
authBody := &api.Response{
GL_ID: "user-123",
GL_USERNAME: "username",
@@ -48,7 +45,7 @@ func TestFailedCloneNoGitaly(t *testing.T) {
defer ws.Close()
// Do the git clone
- cloneCmd := exec.Command("git", "clone", fmt.Sprintf("%s/%s", ws.URL, testRepo), checkoutDir)
+ cloneCmd := exec.Command("git", "clone", fmt.Sprintf("%s/%s", ws.URL, testRepo), t.TempDir())
out, err := cloneCmd.CombinedOutput()
t.Log(string(out))
require.Error(t, err, "git clone should have failed")
@@ -632,7 +629,7 @@ func TestGetArchiveProxiedToGitalySuccessfully(t *testing.T) {
archivePath string
cacheDisabled bool
}{
- {archivePath: path.Join(scratchDir, "my/path"), cacheDisabled: false},
+ {archivePath: path.Join(t.TempDir(), "my/path"), cacheDisabled: false},
{archivePath: "/var/empty/my/path", cacheDisabled: true},
}
@@ -668,7 +665,7 @@ func TestGetArchiveProxiedToGitalyInterruptedStream(t *testing.T) {
archivePath := "my/path"
archivePrefix := "repo-1"
jsonParams := fmt.Sprintf(`{"GitalyServer":{"Address":"%s","Token":""},"GitalyRepository":{"storage_name":"%s","relative_path":"%s"},"ArchivePath":"%s","ArchivePrefix":"%s","CommitId":"%s"}`,
- gitalyAddress, repoStorage, repoRelativePath, path.Join(scratchDir, archivePath), archivePrefix, oid)
+ gitalyAddress, repoStorage, repoRelativePath, path.Join(t.TempDir(), archivePath), archivePrefix, oid)
resp, _, err := doSendDataRequest("/archive.tar.gz", "git-archive", jsonParams)
require.NoError(t, err)
@@ -850,7 +847,13 @@ type combinedServer struct {
}
func startGitalyServer(t *testing.T, finalMessageCode codes.Code) (*combinedServer, string) {
- socketPath := path.Join(scratchDir, fmt.Sprintf("gitaly-%d.sock", rand.Int()))
+ tmpFolder, err := os.MkdirTemp("", "gitaly")
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ os.RemoveAll(tmpFolder)
+ })
+
+ socketPath := path.Join(tmpFolder, fmt.Sprintf("gitaly-%d.sock", rand.Int()))
if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) {
t.Fatal(err)
}
diff --git a/workhorse/go.mod b/workhorse/go.mod
index 17ae3ce12ec..18699787e6e 100644
--- a/workhorse/go.mod
+++ b/workhorse/go.mod
@@ -15,13 +15,13 @@ require (
github.com/golang/protobuf v1.5.3
github.com/gomodule/redigo v2.0.0+incompatible
github.com/gorilla/websocket v1.5.0
- github.com/grpc-ecosystem/go-grpc-middleware v1.4.0
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
github.com/johannesboyne/gofakes3 v0.0.0-20230506070712-04da935ef877
github.com/jpillora/backoff v1.0.0
github.com/mitchellh/copystructure v1.2.0
github.com/prometheus/client_golang v1.16.0
github.com/rafaeljusto/redigomock/v3 v3.1.2
+ github.com/redis/go-redis/v9 v9.1.0
github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a
github.com/sirupsen/logrus v1.9.3
github.com/smartystreets/goconvey v1.7.2
@@ -65,6 +65,7 @@ require (
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/client9/reopen v1.0.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
+ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.4.0 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
@@ -78,6 +79,7 @@ require (
github.com/googleapis/enterprise-certificate-proxy v0.2.5 // indirect
github.com/googleapis/gax-go/v2 v2.11.0 // indirect
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 // indirect
+ github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect
github.com/hashicorp/yamux v0.1.1 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/jtolds/gls v4.20.0+incompatible // indirect
diff --git a/workhorse/go.sum b/workhorse/go.sum
index f3ceee8b5e8..5163d055187 100644
--- a/workhorse/go.sum
+++ b/workhorse/go.sum
@@ -869,6 +869,8 @@ github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl
github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bradfitz/gomemcache v0.0.0-20170208213004-1952afaa557d/go.mod h1:PmM6Mmwb0LSuEubjR8N7PtNe1KxZLtOUHtbeikc5h60=
github.com/bshuster-repo/logrus-logstash-hook v0.4.1/go.mod h1:zsTqEiSzDgAa/8GZR7E1qaXrhYNDKBYy5/dWPTIflbk=
+github.com/bsm/ginkgo/v2 v2.9.5 h1:rtVBYPs3+TC5iLUVOis1B9tjLTup7Cj5IfzosKtvTJ0=
+github.com/bsm/gomega v1.26.0 h1:LhQm+AFcgV2M0WyKroMASzAzCAJVpAxQXv4SaI9a69Y=
github.com/buger/jsonparser v0.0.0-20180808090653-f4dd9f5a6b44/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s=
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
github.com/bugsnag/bugsnag-go v0.0.0-20141110184014-b1d153021fcd/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8=
@@ -1068,6 +1070,8 @@ github.com/denverdino/aliyungo v0.0.0-20190125010748-a747050bb1ba/go.mod h1:dV8l
github.com/devigned/tab v0.1.1/go.mod h1:XG9mPq0dFghrYvoBF3xdRrJzSTX1b7IQrvaL9mzjeJY=
github.com/dgrijalva/jwt-go v0.0.0-20170104182250-a601269ab70c/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
+github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
+github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no=
github.com/dgryski/go-sip13 v0.0.0-20200911182023-62edffca9245/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no=
github.com/digitalocean/godo v1.78.0/go.mod h1:GBmu8MkjZmNARE7IXRPmkbbnocNN8+uBm0xbEVw2LCs=
@@ -2047,6 +2051,8 @@ github.com/rafaeljusto/redigomock/v3 v3.1.2 h1:B4Y0XJQiPjpwYmkH55aratKX1VfR+JRqz
github.com/rafaeljusto/redigomock/v3 v3.1.2/go.mod h1:F9zPqz8rMriScZkPtUiLJoLruYcpGo/XXREpeyasREM=
github.com/rakyll/embedmd v0.0.0-20171029212350-c8060a0752a2/go.mod h1:7jOTMgqac46PZcF54q6l2hkLEG8op93fZu61KmxWDV4=
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
+github.com/redis/go-redis/v9 v9.1.0 h1:137FnGdk+EQdCbye1FW+qOEcY5S+SpY9T0NiuqvtfMY=
+github.com/redis/go-redis/v9 v9.1.0/go.mod h1:urWj3He21Dj5k4TK1y59xH8Uj6ATueP8AH1cY3lZl4c=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
diff --git a/workhorse/internal/config/config.go b/workhorse/internal/config/config.go
index 687986974a3..3b928d42fe1 100644
--- a/workhorse/internal/config/config.go
+++ b/workhorse/internal/config/config.go
@@ -83,13 +83,14 @@ type GoogleCredentials struct {
}
type RedisConfig struct {
- URL TomlURL
- Sentinel []TomlURL
- SentinelMaster string
- Password string
- DB *int
- MaxIdle *int
- MaxActive *int
+ URL TomlURL
+ Sentinel []TomlURL
+ SentinelMaster string
+ SentinelPassword string
+ Password string
+ DB *int
+ MaxIdle *int
+ MaxActive *int
}
type ImageResizerConfig struct {
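The new SentinelPassword field flows through the existing TOML loader alongside Password. A minimal sketch of reading it via config.LoadConfig, the same helper the tests in this change use; the hostnames and values below are illustrative, not part of the change:

package main

import (
	"fmt"
	"log"

	"gitlab.com/gitlab-org/gitlab/workhorse/internal/config"
)

func main() {
	toml := `
[redis]
URL = "redis://localhost:6379"
Sentinel = ["tcp://sentinel1:26379", "tcp://sentinel2:26379"]
SentinelMaster = "mymaster"
SentinelPassword = "sentinel password"
Password = "redis password"
`
	cfg, err := config.LoadConfig(toml)
	if err != nil {
		log.Fatal(err)
	}
	// The new field sits next to the existing Password setting.
	fmt.Println(cfg.Redis.SentinelPassword)
}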
diff --git a/workhorse/internal/dependencyproxy/dependencyproxy.go b/workhorse/internal/dependencyproxy/dependencyproxy.go
index e170b001806..dbea3c29aec 100644
--- a/workhorse/internal/dependencyproxy/dependencyproxy.go
+++ b/workhorse/internal/dependencyproxy/dependencyproxy.go
@@ -23,8 +23,15 @@ type Injector struct {
}
type entryParams struct {
- Url string
- Header http.Header
+ Url string
+ Headers http.Header
+ UploadConfig uploadConfig
+}
+
+type uploadConfig struct {
+ Headers http.Header
+ Method string
+ Url string
}
type nullResponseWriter struct {
@@ -55,7 +62,13 @@ func (p *Injector) SetUploadHandler(uploadHandler http.Handler) {
}
func (p *Injector) Inject(w http.ResponseWriter, r *http.Request, sendData string) {
- dependencyResponse, err := p.fetchUrl(r.Context(), sendData)
+ params, err := p.unpackParams(sendData)
+ if err != nil {
+ fail.Request(w, r, err)
+ return
+ }
+
+ dependencyResponse, err := p.fetchUrl(r.Context(), params)
if err != nil {
fail.Request(w, r, err)
return
@@ -70,11 +83,10 @@ func (p *Injector) Inject(w http.ResponseWriter, r *http.Request, sendData strin
w.Header().Set("Content-Length", dependencyResponse.Header.Get("Content-Length"))
teeReader := io.TeeReader(dependencyResponse.Body, w)
- saveFileRequest, err := http.NewRequestWithContext(r.Context(), "POST", r.URL.String()+"/upload", teeReader)
+ saveFileRequest, err := p.newUploadRequest(r.Context(), params, r, teeReader)
if err != nil {
fail.Request(w, r, fmt.Errorf("dependency proxy: failed to create request: %w", err))
}
- saveFileRequest.Header = r.Header.Clone()
// forward headers from dependencyResponse to rails and client
for key, values := range dependencyResponse.Header {
@@ -100,17 +112,56 @@ func (p *Injector) Inject(w http.ResponseWriter, r *http.Request, sendData strin
}
}
-func (p *Injector) fetchUrl(ctx context.Context, sendData string) (*http.Response, error) {
+func (p *Injector) fetchUrl(ctx context.Context, params *entryParams) (*http.Response, error) {
+ r, err := http.NewRequestWithContext(ctx, "GET", params.Url, nil)
+ if err != nil {
+ return nil, fmt.Errorf("dependency proxy: failed to fetch dependency: %w", err)
+ }
+ r.Header = params.Headers
+
+ return httpClient.Do(r)
+}
+
+func (p *Injector) newUploadRequest(ctx context.Context, params *entryParams, originalRequest *http.Request, body io.Reader) (*http.Request, error) {
+ method := p.uploadMethodFrom(params)
+ uploadUrl := p.uploadUrlFrom(params, originalRequest)
+ request, err := http.NewRequestWithContext(ctx, method, uploadUrl, body)
+ if err != nil {
+ return nil, err
+ }
+
+ request.Header = originalRequest.Header.Clone()
+
+ for key, values := range params.UploadConfig.Headers {
+ request.Header.Del(key)
+ for _, value := range values {
+ request.Header.Add(key, value)
+ }
+ }
+
+ return request, nil
+}
+
+func (p *Injector) unpackParams(sendData string) (*entryParams, error) {
var params entryParams
if err := p.Unpack(&params, sendData); err != nil {
return nil, fmt.Errorf("dependency proxy: unpack sendData: %v", err)
}
- r, err := http.NewRequestWithContext(ctx, "GET", params.Url, nil)
- if err != nil {
- return nil, fmt.Errorf("dependency proxy: failed to fetch dependency: %v", err)
+ return &params, nil
+}
+
+func (p *Injector) uploadMethodFrom(params *entryParams) string {
+ if params.UploadConfig.Method != "" {
+ return params.UploadConfig.Method
}
- r.Header = params.Header
+ return http.MethodPost
+}
- return httpClient.Do(r)
+func (p *Injector) uploadUrlFrom(params *entryParams, originalRequest *http.Request) string {
+ if params.UploadConfig.Url != "" {
+ return params.UploadConfig.Url
+ }
+
+ return originalRequest.URL.String() + "/upload"
}
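The send-dependency injection now accepts an optional UploadConfig block alongside Url and Headers. A sketch of the payload shape, using the field names from entryParams and uploadConfig above; the URLs and token values are placeholders, and omitted UploadConfig fields fall back to POST and "<original request URL>/upload":

package main

import (
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

func main() {
	// Shape mirrors entryParams/uploadConfig; values are placeholders.
	sendData := map[string]interface{}{
		"Url":     "https://registry.example.com/remote/file",
		"Headers": map[string][]string{"Authorization": {"Bearer placeholder"}},
		"UploadConfig": map[string]interface{}{
			"Method":  http.MethodPut,                             // default: POST
			"Url":     "https://gitlab.example.com/target/upload", // default: "<original URL>/upload"
			"Headers": map[string][]string{"Private-Token": {"123456789"}},
		},
	}

	payload, err := json.Marshal(sendData)
	if err != nil {
		log.Fatal(err)
	}
	// Rails base64-encodes a payload like this into the
	// Gitlab-Workhorse-Send-Data header (see main_test.go below).
	fmt.Println(string(payload))
}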
diff --git a/workhorse/internal/dependencyproxy/dependencyproxy_test.go b/workhorse/internal/dependencyproxy/dependencyproxy_test.go
index d893ddc500f..bee74ce0a9e 100644
--- a/workhorse/internal/dependencyproxy/dependencyproxy_test.go
+++ b/workhorse/internal/dependencyproxy/dependencyproxy_test.go
@@ -2,6 +2,7 @@ package dependencyproxy
import (
"encoding/base64"
+ "encoding/json"
"fmt"
"io"
"net/http"
@@ -149,6 +150,158 @@ func TestSuccessfullRequest(t *testing.T) {
require.Equal(t, dockerContentDigest, response.Header().Get("Docker-Content-Digest"))
}
+func TestValidUploadConfiguration(t *testing.T) {
+ content := []byte("content")
+ contentLength := strconv.Itoa(len(content))
+ contentType := "text/plain"
+ testHeader := "test-received-url"
+ originResourceServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set(testHeader, r.URL.Path)
+ w.Header().Set("Content-Length", contentLength)
+ w.Header().Set("Content-Type", contentType)
+ w.Write(content)
+ }))
+
+ testCases := []struct {
+ desc string
+ uploadConfig *uploadConfig
+ expectedConfig uploadConfig
+ }{
+ {
+ desc: "with the default values",
+ expectedConfig: uploadConfig{
+ Method: http.MethodPost,
+ Url: "/target/upload",
+ },
+ }, {
+ desc: "with overridden method",
+ uploadConfig: &uploadConfig{
+ Method: http.MethodPut,
+ },
+ expectedConfig: uploadConfig{
+ Method: http.MethodPut,
+ Url: "/target/upload",
+ },
+ }, {
+ desc: "with overridden url",
+ uploadConfig: &uploadConfig{
+ Url: "http://test.org/overriden/upload",
+ },
+ expectedConfig: uploadConfig{
+ Method: http.MethodPost,
+ Url: "http://test.org/overriden/upload",
+ },
+ }, {
+ desc: "with overridden headers",
+ uploadConfig: &uploadConfig{
+ Headers: map[string][]string{"Private-Token": {"123456789"}},
+ },
+ expectedConfig: uploadConfig{
+ Headers: map[string][]string{"Private-Token": {"123456789"}},
+ Method: http.MethodPost,
+ Url: "/target/upload",
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ uploadHandler := &fakeUploadHandler{
+ handler: func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, tc.expectedConfig.Url, r.URL.String())
+ require.Equal(t, tc.expectedConfig.Method, r.Method)
+
+ if tc.expectedConfig.Headers != nil {
+ for k, v := range tc.expectedConfig.Headers {
+ require.Equal(t, v, r.Header[k])
+ }
+ }
+
+ w.WriteHeader(200)
+ },
+ }
+
+ injector := NewInjector()
+ injector.SetUploadHandler(uploadHandler)
+
+ sendData := map[string]interface{}{
+ "Token": "token",
+ "Url": originResourceServer.URL + `/remote/file`,
+ }
+
+ if tc.uploadConfig != nil {
+ sendData["UploadConfig"] = tc.uploadConfig
+ }
+
+ sendDataJsonString, err := json.Marshal(sendData)
+ require.NoError(t, err)
+
+ response := makeRequest(injector, string(sendDataJsonString))
+
+ // checking the response
+ require.Equal(t, 200, response.Code)
+ require.Equal(t, string(content), response.Body.String())
+ // checking remote file request
+ require.Equal(t, "/remote/file", response.Header().Get(testHeader))
+ })
+ }
+}
+
+func TestInvalidUploadConfiguration(t *testing.T) {
+ baseSendData := map[string]interface{}{
+ "Token": "token",
+ "Url": "http://remote.dev/remote/file",
+ }
+ testCases := []struct {
+ desc string
+ sendData map[string]interface{}
+ }{
+ {
+ desc: "with an invalid overridden method",
+ sendData: mergeMap(baseSendData, map[string]interface{}{
+ "UploadConfig": map[string]string{
+ "Method": "TEAPOT",
+ },
+ }),
+ }, {
+ desc: "with an invalid url",
+ sendData: mergeMap(baseSendData, map[string]interface{}{
+ "UploadConfig": map[string]string{
+ "Url": "invalid_url",
+ },
+ }),
+ }, {
+ desc: "with invalid headers",
+ sendData: mergeMap(baseSendData, map[string]interface{}{
+ "UploadConfig": map[string]interface{}{
+ "Headers": map[string]string{
+ "Private-Token": "not_an_array",
+ },
+ },
+ }),
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ sendDataJsonString, err := json.Marshal(tc.sendData)
+ require.NoError(t, err)
+
+ response := makeRequest(NewInjector(), string(sendDataJsonString))
+
+ require.Equal(t, 500, response.Code)
+ require.Equal(t, "Internal Server Error\n", response.Body.String())
+ })
+ }
+}
+
+func mergeMap(from map[string]interface{}, into map[string]interface{}) map[string]interface{} {
+ for k, v := range from {
+ into[k] = v
+ }
+ return into
+}
+
func TestIncorrectSendData(t *testing.T) {
response := makeRequest(NewInjector(), "")
diff --git a/workhorse/internal/gitaly/gitaly.go b/workhorse/internal/gitaly/gitaly.go
index d9dbbdbb605..e4fbad17017 100644
--- a/workhorse/internal/gitaly/gitaly.go
+++ b/workhorse/internal/gitaly/gitaly.go
@@ -7,7 +7,6 @@ import (
"github.com/golang/protobuf/jsonpb" //lint:ignore SA1019 https://gitlab.com/gitlab-org/gitlab/-/issues/324868
"github.com/golang/protobuf/proto" //lint:ignore SA1019 https://gitlab.com/gitlab-org/gitlab/-/issues/324868
- grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
@@ -114,16 +113,6 @@ func NewRepositoryClient(ctx context.Context, server api.GitalyServer) (context.
return withOutgoingMetadata(ctx, server), &RepositoryClient{grpcClient}, nil
}
-// NewNamespaceClient is only used by the Gitaly integration tests at present
-func NewNamespaceClient(ctx context.Context, server api.GitalyServer) (context.Context, *NamespaceClient, error) {
- conn, err := getOrCreateConnection(server)
- if err != nil {
- return nil, nil, err
- }
- grpcClient := gitalypb.NewNamespaceServiceClient(conn)
- return withOutgoingMetadata(ctx, server), &NamespaceClient{grpcClient}, nil
-}
-
func NewDiffClient(ctx context.Context, server api.GitalyServer) (context.Context, *DiffClient, error) {
conn, err := getOrCreateConnection(server)
if err != nil {
@@ -173,23 +162,19 @@ func CloseConnections() {
func newConnection(server api.GitalyServer) (*grpc.ClientConn, error) {
connOpts := append(gitalyclient.DefaultDialOpts,
grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2(server.Token)),
- grpc.WithStreamInterceptor(
- grpc_middleware.ChainStreamClient(
- grpctracing.StreamClientTracingInterceptor(),
- grpc_prometheus.StreamClientInterceptor,
- grpccorrelation.StreamClientCorrelationInterceptor(
- grpccorrelation.WithClientName("gitlab-workhorse"),
- ),
+ grpc.WithChainStreamInterceptor(
+ grpctracing.StreamClientTracingInterceptor(),
+ grpc_prometheus.StreamClientInterceptor,
+ grpccorrelation.StreamClientCorrelationInterceptor(
+ grpccorrelation.WithClientName("gitlab-workhorse"),
),
),
- grpc.WithUnaryInterceptor(
- grpc_middleware.ChainUnaryClient(
- grpctracing.UnaryClientTracingInterceptor(),
- grpc_prometheus.UnaryClientInterceptor,
- grpccorrelation.UnaryClientCorrelationInterceptor(
- grpccorrelation.WithClientName("gitlab-workhorse"),
- ),
+ grpc.WithChainUnaryInterceptor(
+ grpctracing.UnaryClientTracingInterceptor(),
+ grpc_prometheus.UnaryClientInterceptor,
+ grpccorrelation.UnaryClientCorrelationInterceptor(
+ grpccorrelation.WithClientName("gitlab-workhorse"),
),
),
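The dial options now rely on grpc-go's built-in chaining (grpc.WithChainStreamInterceptor and grpc.WithChainUnaryInterceptor) instead of the go-grpc-middleware chain helpers, which is what lets that module drop to an indirect dependency in go.mod above. A generic sketch of the equivalence; the interceptor and target address are placeholders, not Workhorse code:

package main

import (
	"context"
	"log"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

// logUnary is a placeholder client interceptor used only to demonstrate chaining.
func logUnary(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
	log.Printf("unary call: %s", method)
	return invoker(ctx, method, req, reply, cc, opts...)
}

func main() {
	// Before: grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient(a, b))
	// After:  grpc.WithChainUnaryInterceptor(a, b)
	conn, err := grpc.Dial("localhost:9999",
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithChainUnaryInterceptor(logUnary, logUnary),
	)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()
}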
diff --git a/workhorse/internal/gitaly/gitaly_test.go b/workhorse/internal/gitaly/gitaly_test.go
index 0ea5da20da3..04d3a0a79aa 100644
--- a/workhorse/internal/gitaly/gitaly_test.go
+++ b/workhorse/internal/gitaly/gitaly_test.go
@@ -46,15 +46,6 @@ func TestNewRepositoryClient(t *testing.T) {
testOutgoingMetadata(t, ctx)
}
-func TestNewNamespaceClient(t *testing.T) {
- ctx, _, err := NewNamespaceClient(
- context.Background(),
- serverFixture(),
- )
- require.NoError(t, err)
- testOutgoingMetadata(t, ctx)
-}
-
func TestNewDiffClient(t *testing.T) {
ctx, _, err := NewDiffClient(
context.Background(),
diff --git a/workhorse/internal/gitaly/namespace.go b/workhorse/internal/gitaly/namespace.go
deleted file mode 100644
index a9bc2d07a7e..00000000000
--- a/workhorse/internal/gitaly/namespace.go
+++ /dev/null
@@ -1,8 +0,0 @@
-package gitaly
-
-import "gitlab.com/gitlab-org/gitaly/v16/proto/go/gitalypb"
-
-// NamespaceClient encapsulates NamespaceService calls
-type NamespaceClient struct {
- gitalypb.NamespaceServiceClient
-}
diff --git a/workhorse/internal/goredis/goredis.go b/workhorse/internal/goredis/goredis.go
new file mode 100644
index 00000000000..13a9d4cc34f
--- /dev/null
+++ b/workhorse/internal/goredis/goredis.go
@@ -0,0 +1,186 @@
+package goredis
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "time"
+
+ redis "github.com/redis/go-redis/v9"
+
+ "gitlab.com/gitlab-org/gitlab/workhorse/internal/config"
+ _ "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper"
+ internalredis "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis"
+)
+
+var (
+ rdb *redis.Client
+ // found in https://github.com/redis/go-redis/blob/c7399b6a17d7d3e2a57654528af91349f2468529/sentinel.go#L626
+ errSentinelMasterAddr error = errors.New("redis: all sentinels specified in configuration are unreachable")
+)
+
+const (
+ // Max Idle Connections in the pool.
+ defaultMaxIdle = 1
+ // Max Active Connections in the pool.
+ defaultMaxActive = 1
+ // Timeout for Read operations on the pool. 1 second is technically overkill,
+ // it's just for sanity.
+ defaultReadTimeout = 1 * time.Second
+ // Timeout for Write operations on the pool. 1 second is technically overkill,
+ // it's just for sanity.
+ defaultWriteTimeout = 1 * time.Second
+ // Timeout before killing Idle connections in the pool. 3 minutes seemed good.
+ // If you _actually_ hit this timeout often, you should consider turning of
+ // redis-support since it's not necessary at that point...
+ defaultIdleTimeout = 3 * time.Minute
+)
+
+// createDialer references https://github.com/redis/go-redis/blob/b1103e3d436b6fe98813ecbbe1f99dc8d59b06c9/options.go#L214
+// it intercepts the error and tracks it via a Prometheus counter
+func createDialer(sentinels []string) func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return func(ctx context.Context, network, addr string) (net.Conn, error) {
+ var isSentinel bool
+ for _, sentinelAddr := range sentinels {
+ if sentinelAddr == addr {
+ isSentinel = true
+ break
+ }
+ }
+
+ dialTimeout := 5 * time.Second // go-redis default
+ destination := "redis"
+ if isSentinel {
+ // This timeout is recommended for Sentinel-support according to the guidelines.
+ // https://redis.io/topics/sentinel-clients#redis-service-discovery-via-sentinel
+ // For every address it should try to connect to the Sentinel,
+ // using a short timeout (in the order of a few hundreds of milliseconds).
+ destination = "sentinel"
+ dialTimeout = 500 * time.Millisecond
+ }
+
+ netDialer := &net.Dialer{
+ Timeout: dialTimeout,
+ KeepAlive: 5 * time.Minute,
+ }
+
+ conn, err := netDialer.DialContext(ctx, network, addr)
+ if err != nil {
+ internalredis.ErrorCounter.WithLabelValues("dial", destination).Inc()
+ } else {
+ if !isSentinel {
+ internalredis.TotalConnections.Inc()
+ }
+ }
+
+ return conn, err
+ }
+}
+
+// implements the redis.Hook interface for instrumentation
+type sentinelInstrumentationHook struct{}
+
+func (s sentinelInstrumentationHook) DialHook(next redis.DialHook) redis.DialHook {
+ return func(ctx context.Context, network, addr string) (net.Conn, error) {
+ conn, err := next(ctx, network, addr)
+ if err != nil && err.Error() == errSentinelMasterAddr.Error() {
+ // check for non-dialer error
+ internalredis.ErrorCounter.WithLabelValues("master", "sentinel").Inc()
+ }
+ return conn, err
+ }
+}
+
+func (s sentinelInstrumentationHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook {
+ return func(ctx context.Context, cmd redis.Cmder) error {
+ return next(ctx, cmd)
+ }
+}
+
+func (s sentinelInstrumentationHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook {
+ return func(ctx context.Context, cmds []redis.Cmder) error {
+ return next(ctx, cmds)
+ }
+}
+
+func GetRedisClient() *redis.Client {
+ return rdb
+}
+
+// Configure redis-connection
+func Configure(cfg *config.RedisConfig) error {
+ if cfg == nil {
+ return nil
+ }
+
+ var err error
+
+ if len(cfg.Sentinel) > 0 {
+ rdb = configureSentinel(cfg)
+ } else {
+ rdb, err = configureRedis(cfg)
+ }
+
+ return err
+}
+
+func configureRedis(cfg *config.RedisConfig) (*redis.Client, error) {
+ if cfg.URL.Scheme == "tcp" {
+ cfg.URL.Scheme = "redis"
+ }
+
+ opt, err := redis.ParseURL(cfg.URL.String())
+ if err != nil {
+ return nil, err
+ }
+
+ opt.DB = getOrDefault(cfg.DB, 0)
+ opt.Password = cfg.Password
+
+ opt.PoolSize = getOrDefault(cfg.MaxActive, defaultMaxActive)
+ opt.MaxIdleConns = getOrDefault(cfg.MaxIdle, defaultMaxIdle)
+ opt.ConnMaxIdleTime = defaultIdleTimeout
+ opt.ReadTimeout = defaultReadTimeout
+ opt.WriteTimeout = defaultWriteTimeout
+
+ opt.Dialer = createDialer([]string{})
+
+ return redis.NewClient(opt), nil
+}
+
+func configureSentinel(cfg *config.RedisConfig) *redis.Client {
+ sentinels := make([]string, len(cfg.Sentinel))
+ for i := range cfg.Sentinel {
+ sentinelDetails := cfg.Sentinel[i]
+ sentinels[i] = fmt.Sprintf("%s:%s", sentinelDetails.Hostname(), sentinelDetails.Port())
+ }
+
+ client := redis.NewFailoverClient(&redis.FailoverOptions{
+ MasterName: cfg.SentinelMaster,
+ SentinelAddrs: sentinels,
+ Password: cfg.Password,
+ SentinelPassword: cfg.SentinelPassword,
+ DB: getOrDefault(cfg.DB, 0),
+
+ PoolSize: getOrDefault(cfg.MaxActive, defaultMaxActive),
+ MaxIdleConns: getOrDefault(cfg.MaxIdle, defaultMaxIdle),
+ ConnMaxIdleTime: defaultIdleTimeout,
+
+ ReadTimeout: defaultReadTimeout,
+ WriteTimeout: defaultWriteTimeout,
+
+ Dialer: createDialer(sentinels),
+ })
+
+ client.AddHook(sentinelInstrumentationHook{})
+
+ return client
+}
+
+func getOrDefault(ptr *int, val int) int {
+ if ptr != nil {
+ return *ptr
+ }
+ return val
+}
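A minimal usage sketch of the new goredis package, assuming a Redis instance reachable at localhost:6379: Configure builds either a plain client or a Sentinel-backed failover client depending on the config, and GetRedisClient returns the shared *redis.Client. Connections are dialed lazily, which the tests below also rely on:

package main

import (
	"context"
	"log"
	"net/url"

	"gitlab.com/gitlab-org/gitlab/workhorse/internal/config"
	"gitlab.com/gitlab-org/gitlab/workhorse/internal/goredis"
)

func main() {
	u, err := url.Parse("redis://localhost:6379") // assumed local Redis
	if err != nil {
		log.Fatal(err)
	}

	cfg := &config.RedisConfig{URL: config.TomlURL{URL: *u}}
	if err := goredis.Configure(cfg); err != nil {
		log.Fatal(err)
	}

	// go-redis dials lazily, so Ping is what actually exercises createDialer.
	if err := goredis.GetRedisClient().Ping(context.Background()).Err(); err != nil {
		log.Fatal(err)
	}
	log.Print("connected")
}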
diff --git a/workhorse/internal/goredis/goredis_test.go b/workhorse/internal/goredis/goredis_test.go
new file mode 100644
index 00000000000..6b281229ea4
--- /dev/null
+++ b/workhorse/internal/goredis/goredis_test.go
@@ -0,0 +1,107 @@
+package goredis
+
+import (
+ "context"
+ "net"
+ "sync/atomic"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab/workhorse/internal/config"
+ "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper"
+)
+
+func mockRedisServer(t *testing.T, connectReceived *atomic.Value) string {
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+
+ require.Nil(t, err)
+
+ go func() {
+ defer ln.Close()
+ conn, err := ln.Accept()
+ require.Nil(t, err)
+ connectReceived.Store(true)
+ conn.Write([]byte("OK\n"))
+ }()
+
+ return ln.Addr().String()
+}
+
+func TestConfigureNoConfig(t *testing.T) {
+ rdb = nil
+ Configure(nil)
+ require.Nil(t, rdb, "rdb client should be nil")
+}
+
+func TestConfigureValidConfigX(t *testing.T) {
+ testCases := []struct {
+ scheme string
+ }{
+ {
+ scheme: "redis",
+ },
+ {
+ scheme: "tcp",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.scheme, func(t *testing.T) {
+ connectReceived := atomic.Value{}
+ a := mockRedisServer(t, &connectReceived)
+
+ parsedURL := helper.URLMustParse(tc.scheme + "://" + a)
+ cfg := &config.RedisConfig{URL: config.TomlURL{URL: *parsedURL}}
+
+ Configure(cfg)
+
+ require.NotNil(t, GetRedisClient().Conn(), "Pool should not be nil")
+
+ // goredis initialises connections lazily
+ rdb.Ping(context.Background())
+ require.True(t, connectReceived.Load().(bool))
+
+ rdb = nil
+ })
+ }
+}
+
+func TestConnectToSentinel(t *testing.T) {
+ testCases := []struct {
+ scheme string
+ }{
+ {
+ scheme: "redis",
+ },
+ {
+ scheme: "tcp",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.scheme, func(t *testing.T) {
+ connectReceived := atomic.Value{}
+ a := mockRedisServer(t, &connectReceived)
+
+ addrs := []string{tc.scheme + "://" + a}
+ var sentinelUrls []config.TomlURL
+
+ for _, a := range addrs {
+ parsedURL := helper.URLMustParse(a)
+ sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL})
+ }
+
+ cfg := &config.RedisConfig{Sentinel: sentinelUrls}
+ Configure(cfg)
+
+ require.NotNil(t, GetRedisClient().Conn(), "Pool should not be nil")
+
+ // goredis initialises connections lazily
+ rdb.Ping(context.Background())
+ require.True(t, connectReceived.Load().(bool))
+
+ rdb = nil
+ })
+ }
+}
diff --git a/workhorse/internal/goredis/keywatcher.go b/workhorse/internal/goredis/keywatcher.go
new file mode 100644
index 00000000000..741bfb17652
--- /dev/null
+++ b/workhorse/internal/goredis/keywatcher.go
@@ -0,0 +1,236 @@
+package goredis
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/jpillora/backoff"
+ "github.com/redis/go-redis/v9"
+
+ "gitlab.com/gitlab-org/gitlab/workhorse/internal/log"
+ internalredis "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis"
+)
+
+type KeyWatcher struct {
+ mu sync.Mutex
+ subscribers map[string][]chan string
+ shutdown chan struct{}
+ reconnectBackoff backoff.Backoff
+ redisConn *redis.Client
+ conn *redis.PubSub
+}
+
+func NewKeyWatcher() *KeyWatcher {
+ return &KeyWatcher{
+ shutdown: make(chan struct{}),
+ reconnectBackoff: backoff.Backoff{
+ Min: 100 * time.Millisecond,
+ Max: 60 * time.Second,
+ Factor: 2,
+ Jitter: true,
+ },
+ }
+}
+
+const channelPrefix = "workhorse:notifications:"
+
+func countAction(action string) { internalredis.TotalActions.WithLabelValues(action).Add(1) }
+
+func (kw *KeyWatcher) receivePubSubStream(ctx context.Context, pubsub *redis.PubSub) error {
+ kw.mu.Lock()
+ // We must share kw.conn with the goroutines that call SUBSCRIBE and
+ // UNSUBSCRIBE because Redis pubsub subscriptions are tied to the
+ // connection.
+ kw.conn = pubsub
+ kw.mu.Unlock()
+
+ defer func() {
+ kw.mu.Lock()
+ defer kw.mu.Unlock()
+ kw.conn.Close()
+ kw.conn = nil
+
+ // Reset kw.subscribers because it is tied to Redis server side state of
+ // kw.conn and we just closed that connection.
+ for _, chans := range kw.subscribers {
+ for _, ch := range chans {
+ close(ch)
+ internalredis.KeyWatchers.Dec()
+ }
+ }
+ kw.subscribers = nil
+ }()
+
+ for {
+ msg, err := kw.conn.Receive(ctx)
+ if err != nil {
+ log.WithError(fmt.Errorf("keywatcher: pubsub receive: %v", err)).Error()
+ return nil
+ }
+
+ switch msg := msg.(type) {
+ case *redis.Subscription:
+ internalredis.RedisSubscriptions.Set(float64(msg.Count))
+ case *redis.Pong:
+ // Ignore.
+ case *redis.Message:
+ internalredis.TotalMessages.Inc()
+ internalredis.ReceivedBytes.Add(float64(len(msg.Payload)))
+ if strings.HasPrefix(msg.Channel, channelPrefix) {
+ kw.notifySubscribers(msg.Channel[len(channelPrefix):], string(msg.Payload))
+ }
+ default:
+ log.WithError(fmt.Errorf("keywatcher: unknown: %T", msg)).Error()
+ return nil
+ }
+ }
+}
+
+func (kw *KeyWatcher) Process(client *redis.Client) {
+ log.Info("keywatcher: starting process loop")
+
+ ctx := context.Background() // lint:allow context.Background
+ kw.mu.Lock()
+ kw.redisConn = client
+ kw.mu.Unlock()
+
+ for {
+ pubsub := client.Subscribe(ctx, []string{}...)
+ if err := pubsub.Ping(ctx); err != nil {
+ log.WithError(fmt.Errorf("keywatcher: %v", err)).Error()
+ time.Sleep(kw.reconnectBackoff.Duration())
+ continue
+ }
+
+ kw.reconnectBackoff.Reset()
+
+ if err := kw.receivePubSubStream(ctx, pubsub); err != nil {
+ log.WithError(fmt.Errorf("keywatcher: receivePubSubStream: %v", err)).Error()
+ }
+ }
+}
+
+func (kw *KeyWatcher) Shutdown() {
+ log.Info("keywatcher: shutting down")
+
+ kw.mu.Lock()
+ defer kw.mu.Unlock()
+
+ select {
+ case <-kw.shutdown:
+ // already closed
+ default:
+ close(kw.shutdown)
+ }
+}
+
+func (kw *KeyWatcher) notifySubscribers(key, value string) {
+ kw.mu.Lock()
+ defer kw.mu.Unlock()
+
+ chanList, ok := kw.subscribers[key]
+ if !ok {
+ countAction("drop-message")
+ return
+ }
+
+ countAction("deliver-message")
+ for _, c := range chanList {
+ select {
+ case c <- value:
+ default:
+ }
+ }
+}
+
+func (kw *KeyWatcher) addSubscription(ctx context.Context, key string, notify chan string) error {
+ kw.mu.Lock()
+ defer kw.mu.Unlock()
+
+ if kw.conn == nil {
+ // This can happen because CI long polling is disabled in this Workhorse
+ // process. It can also be that we are waiting for the pubsub connection
+ // to be established. Either way it is OK to fail fast.
+ return errors.New("no redis connection")
+ }
+
+ if len(kw.subscribers[key]) == 0 {
+ countAction("create-subscription")
+ if err := kw.conn.Subscribe(ctx, channelPrefix+key); err != nil {
+ return err
+ }
+ }
+
+ if kw.subscribers == nil {
+ kw.subscribers = make(map[string][]chan string)
+ }
+ kw.subscribers[key] = append(kw.subscribers[key], notify)
+ internalredis.KeyWatchers.Inc()
+
+ return nil
+}
+
+func (kw *KeyWatcher) delSubscription(ctx context.Context, key string, notify chan string) {
+ kw.mu.Lock()
+ defer kw.mu.Unlock()
+
+ chans, ok := kw.subscribers[key]
+ if !ok {
+ // This can happen if the pubsub connection dropped while we were
+ // waiting.
+ return
+ }
+
+ for i, c := range chans {
+ if notify == c {
+ kw.subscribers[key] = append(chans[:i], chans[i+1:]...)
+ internalredis.KeyWatchers.Dec()
+ break
+ }
+ }
+ if len(kw.subscribers[key]) == 0 {
+ delete(kw.subscribers, key)
+ countAction("delete-subscription")
+ if kw.conn != nil {
+ kw.conn.Unsubscribe(ctx, channelPrefix+key)
+ }
+ }
+}
+
+func (kw *KeyWatcher) WatchKey(ctx context.Context, key, value string, timeout time.Duration) (internalredis.WatchKeyStatus, error) {
+ notify := make(chan string, 1)
+ if err := kw.addSubscription(ctx, key, notify); err != nil {
+ return internalredis.WatchKeyStatusNoChange, err
+ }
+ defer kw.delSubscription(ctx, key, notify)
+
+ currentValue, err := kw.redisConn.Get(ctx, key).Result()
+ if errors.Is(err, redis.Nil) {
+ currentValue = ""
+ } else if err != nil {
+ return internalredis.WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET: %v", err)
+ }
+ if currentValue != value {
+ return internalredis.WatchKeyStatusAlreadyChanged, nil
+ }
+
+ select {
+ case <-kw.shutdown:
+ log.WithFields(log.Fields{"key": key}).Info("stopping watch due to shutdown")
+ return internalredis.WatchKeyStatusNoChange, nil
+ case currentValue := <-notify:
+ if currentValue == "" {
+ return internalredis.WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET failed")
+ }
+ if currentValue == value {
+ return internalredis.WatchKeyStatusNoChange, nil
+ }
+ return internalredis.WatchKeyStatusSeenChange, nil
+ case <-time.After(timeout):
+ return internalredis.WatchKeyStatusTimeout, nil
+ }
+}
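A sketch of how the new key watcher is wired together; main.go below adopts the same pattern behind a feature flag. Process runs the pubsub receive loop against a configured client, and WatchKey blocks until the watched key's value changes, the timeout expires, or Shutdown is called. The Redis URL, watched value, and sleep are illustrative; the key matches the one used in the tests below:

package main

import (
	"context"
	"log"
	"net/url"
	"time"

	"gitlab.com/gitlab-org/gitlab/workhorse/internal/config"
	"gitlab.com/gitlab-org/gitlab/workhorse/internal/goredis"
)

func main() {
	u, err := url.Parse("redis://localhost:6379") // assumed local Redis
	if err != nil {
		log.Fatal(err)
	}
	if err := goredis.Configure(&config.RedisConfig{URL: config.TomlURL{URL: *u}}); err != nil {
		log.Fatal(err)
	}

	kw := goredis.NewKeyWatcher()
	defer kw.Shutdown()

	if rdb := goredis.GetRedisClient(); rdb != nil {
		go kw.Process(rdb) // pubsub loop; reconnects with backoff on errors
	}
	time.Sleep(100 * time.Millisecond) // sketch only: give the pubsub loop a moment to come up

	status, err := kw.WatchKey(context.Background(), "runner:build_queue:10", "pending", 30*time.Second)
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("watch finished with status %v", status)
}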
diff --git a/workhorse/internal/goredis/keywatcher_test.go b/workhorse/internal/goredis/keywatcher_test.go
new file mode 100644
index 00000000000..b64262dc9c8
--- /dev/null
+++ b/workhorse/internal/goredis/keywatcher_test.go
@@ -0,0 +1,301 @@
+package goredis
+
+import (
+ "context"
+ "os"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab/workhorse/internal/config"
+ "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis"
+)
+
+var ctx = context.Background()
+
+const (
+ runnerKey = "runner:build_queue:10"
+)
+
+func initRdb() {
+ buf, _ := os.ReadFile("../../config.toml")
+ cfg, _ := config.LoadConfig(string(buf))
+ Configure(cfg.Redis)
+}
+
+func (kw *KeyWatcher) countSubscribers(key string) int {
+ kw.mu.Lock()
+ defer kw.mu.Unlock()
+ return len(kw.subscribers[key])
+}
+
+// Forces a run of the `Process` loop against a mock PubSubConn.
+func (kw *KeyWatcher) processMessages(t *testing.T, numWatchers int, value string, ready chan<- struct{}, wg *sync.WaitGroup) {
+ kw.mu.Lock()
+ kw.redisConn = rdb
+ psc := kw.redisConn.Subscribe(ctx, []string{}...)
+ kw.mu.Unlock()
+
+ errC := make(chan error)
+ go func() { errC <- kw.receivePubSubStream(ctx, psc) }()
+
+ require.Eventually(t, func() bool {
+ kw.mu.Lock()
+ defer kw.mu.Unlock()
+ return kw.conn != nil
+ }, time.Second, time.Millisecond)
+ close(ready)
+
+ require.Eventually(t, func() bool {
+ return kw.countSubscribers(runnerKey) == numWatchers
+ }, time.Second, time.Millisecond)
+
+ // send message after listeners are ready
+ kw.redisConn.Publish(ctx, channelPrefix+runnerKey, value)
+
+ // close subscription after all workers are done
+ wg.Wait()
+ kw.mu.Lock()
+ kw.conn.Close()
+ kw.mu.Unlock()
+
+ require.NoError(t, <-errC)
+}
+
+type keyChangeTestCase struct {
+ desc string
+ returnValue string
+ isKeyMissing bool
+ watchValue string
+ processedValue string
+ expectedStatus redis.WatchKeyStatus
+ timeout time.Duration
+}
+
+func TestKeyChangesInstantReturn(t *testing.T) {
+ initRdb()
+
+ testCases := []keyChangeTestCase{
+ // WatchKeyStatusAlreadyChanged
+ {
+ desc: "sees change with key existing and changed",
+ returnValue: "somethingelse",
+ watchValue: "something",
+ expectedStatus: redis.WatchKeyStatusAlreadyChanged,
+ timeout: time.Second,
+ },
+ {
+ desc: "sees change with key non-existing",
+ isKeyMissing: true,
+ watchValue: "something",
+ processedValue: "somethingelse",
+ expectedStatus: redis.WatchKeyStatusAlreadyChanged,
+ timeout: time.Second,
+ },
+ // WatchKeyStatusTimeout
+ {
+ desc: "sees timeout with key existing and unchanged",
+ returnValue: "something",
+ watchValue: "something",
+ expectedStatus: redis.WatchKeyStatusTimeout,
+ timeout: time.Millisecond,
+ },
+ {
+ desc: "sees timeout with key non-existing and unchanged",
+ isKeyMissing: true,
+ watchValue: "",
+ expectedStatus: redis.WatchKeyStatusTimeout,
+ timeout: time.Millisecond,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+
+ // setup
+ if !tc.isKeyMissing {
+ rdb.Set(ctx, runnerKey, tc.returnValue, 0)
+ }
+
+ defer func() {
+ rdb.FlushDB(ctx)
+ }()
+
+ kw := NewKeyWatcher()
+ defer kw.Shutdown()
+ kw.redisConn = rdb
+ kw.conn = kw.redisConn.Subscribe(ctx, []string{}...)
+
+ val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, tc.timeout)
+
+ require.NoError(t, err, "Expected no error")
+ require.Equal(t, tc.expectedStatus, val, "Expected value")
+ })
+ }
+}
+
+func TestKeyChangesWhenWatching(t *testing.T) {
+ initRdb()
+
+ testCases := []keyChangeTestCase{
+ // WatchKeyStatusSeenChange
+ {
+ desc: "sees change with key existing",
+ returnValue: "something",
+ watchValue: "something",
+ processedValue: "somethingelse",
+ expectedStatus: redis.WatchKeyStatusSeenChange,
+ },
+ {
+ desc: "sees change with key non-existing, when watching empty value",
+ isKeyMissing: true,
+ watchValue: "",
+ processedValue: "something",
+ expectedStatus: redis.WatchKeyStatusSeenChange,
+ },
+ // WatchKeyStatusNoChange
+ {
+ desc: "sees no change with key existing",
+ returnValue: "something",
+ watchValue: "something",
+ processedValue: "something",
+ expectedStatus: redis.WatchKeyStatusNoChange,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ if !tc.isKeyMissing {
+ rdb.Set(ctx, runnerKey, tc.returnValue, 0)
+ }
+
+ kw := NewKeyWatcher()
+ defer kw.Shutdown()
+ defer func() {
+ rdb.FlushDB(ctx)
+ }()
+
+ wg := &sync.WaitGroup{}
+ wg.Add(1)
+ ready := make(chan struct{})
+
+ go func() {
+ defer wg.Done()
+ <-ready
+ val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, time.Second)
+
+ require.NoError(t, err, "Expected no error")
+ require.Equal(t, tc.expectedStatus, val, "Expected value")
+ }()
+
+ kw.processMessages(t, 1, tc.processedValue, ready, wg)
+ })
+ }
+}
+
+func TestKeyChangesParallel(t *testing.T) {
+ initRdb()
+
+ testCases := []keyChangeTestCase{
+ {
+ desc: "massively parallel, sees change with key existing",
+ returnValue: "something",
+ watchValue: "something",
+ processedValue: "somethingelse",
+ expectedStatus: redis.WatchKeyStatusSeenChange,
+ },
+ {
+ desc: "massively parallel, sees change with key existing, watching missing keys",
+ isKeyMissing: true,
+ watchValue: "",
+ processedValue: "somethingelse",
+ expectedStatus: redis.WatchKeyStatusSeenChange,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ runTimes := 100
+
+ if !tc.isKeyMissing {
+ rdb.Set(ctx, runnerKey, tc.returnValue, 0)
+ }
+
+ defer func() {
+ rdb.FlushDB(ctx)
+ }()
+
+ wg := &sync.WaitGroup{}
+ wg.Add(runTimes)
+ ready := make(chan struct{})
+
+ kw := NewKeyWatcher()
+ defer kw.Shutdown()
+
+ for i := 0; i < runTimes; i++ {
+ go func() {
+ defer wg.Done()
+ <-ready
+ val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, time.Second)
+
+ require.NoError(t, err, "Expected no error")
+ require.Equal(t, tc.expectedStatus, val, "Expected value")
+ }()
+ }
+
+ kw.processMessages(t, runTimes, tc.processedValue, ready, wg)
+ })
+ }
+}
+
+func TestShutdown(t *testing.T) {
+ initRdb()
+
+ kw := NewKeyWatcher()
+ kw.redisConn = rdb
+ kw.conn = kw.redisConn.Subscribe(ctx, []string{}...)
+ defer kw.Shutdown()
+
+ rdb.Set(ctx, runnerKey, "something", 0)
+
+ wg := &sync.WaitGroup{}
+ wg.Add(2)
+
+ go func() {
+ defer wg.Done()
+ val, err := kw.WatchKey(ctx, runnerKey, "something", 10*time.Second)
+
+ require.NoError(t, err, "Expected no error")
+ require.Equal(t, redis.WatchKeyStatusNoChange, val, "Expected value not to change")
+ }()
+
+ go func() {
+ defer wg.Done()
+ require.Eventually(t, func() bool { return kw.countSubscribers(runnerKey) == 1 }, 10*time.Second, time.Millisecond)
+
+ kw.Shutdown()
+ }()
+
+ wg.Wait()
+
+ require.Eventually(t, func() bool { return kw.countSubscribers(runnerKey) == 0 }, 10*time.Second, time.Millisecond)
+
+ // Adding a key after the shutdown should result in an immediate response
+ var val redis.WatchKeyStatus
+ var err error
+ done := make(chan struct{})
+ go func() {
+ val, err = kw.WatchKey(ctx, runnerKey, "something", 10*time.Second)
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ require.NoError(t, err, "Expected no error")
+ require.Equal(t, redis.WatchKeyStatusNoChange, val, "Expected value not to change")
+ case <-time.After(100 * time.Millisecond):
+ t.Fatal("timeout waiting for WatchKey")
+ }
+}
diff --git a/workhorse/internal/redis/keywatcher.go b/workhorse/internal/redis/keywatcher.go
index 2fd0753c3c9..8f1772a9195 100644
--- a/workhorse/internal/redis/keywatcher.go
+++ b/workhorse/internal/redis/keywatcher.go
@@ -37,32 +37,32 @@ func NewKeyWatcher() *KeyWatcher {
}
var (
- keyWatchers = promauto.NewGauge(
+ KeyWatchers = promauto.NewGauge(
prometheus.GaugeOpts{
Name: "gitlab_workhorse_keywatcher_keywatchers",
Help: "The number of keys that is being watched by gitlab-workhorse",
},
)
- redisSubscriptions = promauto.NewGauge(
+ RedisSubscriptions = promauto.NewGauge(
prometheus.GaugeOpts{
Name: "gitlab_workhorse_keywatcher_redis_subscriptions",
Help: "Current number of keywatcher Redis pubsub subscriptions",
},
)
- totalMessages = promauto.NewCounter(
+ TotalMessages = promauto.NewCounter(
prometheus.CounterOpts{
Name: "gitlab_workhorse_keywatcher_total_messages",
Help: "How many messages gitlab-workhorse has received in total on pubsub.",
},
)
- totalActions = promauto.NewCounterVec(
+ TotalActions = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "gitlab_workhorse_keywatcher_actions_total",
Help: "Counts of various keywatcher actions",
},
[]string{"action"},
)
- receivedBytes = promauto.NewCounter(
+ ReceivedBytes = promauto.NewCounter(
prometheus.CounterOpts{
Name: "gitlab_workhorse_keywatcher_received_bytes_total",
Help: "How many bytes of messages gitlab-workhorse has received in total on pubsub.",
@@ -72,7 +72,7 @@ var (
const channelPrefix = "workhorse:notifications:"
-func countAction(action string) { totalActions.WithLabelValues(action).Add(1) }
+func countAction(action string) { TotalActions.WithLabelValues(action).Add(1) }
func (kw *KeyWatcher) receivePubSubStream(conn redis.Conn) error {
kw.mu.Lock()
@@ -93,7 +93,7 @@ func (kw *KeyWatcher) receivePubSubStream(conn redis.Conn) error {
for _, chans := range kw.subscribers {
for _, ch := range chans {
close(ch)
- keyWatchers.Dec()
+ KeyWatchers.Dec()
}
}
kw.subscribers = nil
@@ -102,13 +102,13 @@ func (kw *KeyWatcher) receivePubSubStream(conn redis.Conn) error {
for {
switch v := kw.conn.Receive().(type) {
case redis.Message:
- totalMessages.Inc()
- receivedBytes.Add(float64(len(v.Data)))
+ TotalMessages.Inc()
+ ReceivedBytes.Add(float64(len(v.Data)))
if strings.HasPrefix(v.Channel, channelPrefix) {
kw.notifySubscribers(v.Channel[len(channelPrefix):], string(v.Data))
}
case redis.Subscription:
- redisSubscriptions.Set(float64(v.Count))
+ RedisSubscriptions.Set(float64(v.Count))
case error:
log.WithError(fmt.Errorf("keywatcher: pubsub receive: %v", v)).Error()
// Intermittent error, return nil so that it doesn't wait before reconnect
@@ -205,7 +205,7 @@ func (kw *KeyWatcher) addSubscription(key string, notify chan string) error {
kw.subscribers = make(map[string][]chan string)
}
kw.subscribers[key] = append(kw.subscribers[key], notify)
- keyWatchers.Inc()
+ KeyWatchers.Inc()
return nil
}
@@ -224,7 +224,7 @@ func (kw *KeyWatcher) delSubscription(key string, notify chan string) {
for i, c := range chans {
if notify == c {
kw.subscribers[key] = append(chans[:i], chans[i+1:]...)
- keyWatchers.Dec()
+ KeyWatchers.Dec()
break
}
}
diff --git a/workhorse/internal/redis/redis.go b/workhorse/internal/redis/redis.go
index 03118cfcef6..c79e1e56b3a 100644
--- a/workhorse/internal/redis/redis.go
+++ b/workhorse/internal/redis/redis.go
@@ -45,14 +45,14 @@ const (
)
var (
- totalConnections = promauto.NewCounter(
+ TotalConnections = promauto.NewCounter(
prometheus.CounterOpts{
Name: "gitlab_workhorse_redis_total_connections",
Help: "How many connections gitlab-workhorse has opened in total. Can be used to track Redis connection rate for this process",
},
)
- errorCounter = promauto.NewCounterVec(
+ ErrorCounter = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "gitlab_workhorse_redis_errors",
Help: "Counts different types of Redis errors encountered by workhorse, by type and destination (redis, sentinel)",
@@ -100,7 +100,7 @@ func sentinelConn(master string, urls []config.TomlURL) *sentinel.Sentinel {
}
if err != nil {
- errorCounter.WithLabelValues("dial", "sentinel").Inc()
+ ErrorCounter.WithLabelValues("dial", "sentinel").Inc()
return nil, err
}
return c, nil
@@ -159,7 +159,7 @@ func sentinelDialer(dopts []redis.DialOption) redisDialerFunc {
return func() (redis.Conn, error) {
address, err := sntnl.MasterAddr()
if err != nil {
- errorCounter.WithLabelValues("master", "sentinel").Inc()
+ ErrorCounter.WithLabelValues("master", "sentinel").Inc()
return nil, err
}
dopts = append(dopts, redis.DialNetDial(keepAliveDialer))
@@ -214,9 +214,9 @@ func countDialer(dialer redisDialerFunc) redisDialerFunc {
return func() (redis.Conn, error) {
c, err := dialer()
if err != nil {
- errorCounter.WithLabelValues("dial", "redis").Inc()
+ ErrorCounter.WithLabelValues("dial", "redis").Inc()
} else {
- totalConnections.Inc()
+ TotalConnections.Inc()
}
return c, err
}
diff --git a/workhorse/main.go b/workhorse/main.go
index ca9b86de528..9ba213d47d3 100644
--- a/workhorse/main.go
+++ b/workhorse/main.go
@@ -17,8 +17,10 @@ import (
"gitlab.com/gitlab-org/labkit/monitoring"
"gitlab.com/gitlab-org/labkit/tracing"
+ "gitlab.com/gitlab-org/gitlab/workhorse/internal/builds"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/config"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/gitaly"
+ "gitlab.com/gitlab-org/gitlab/workhorse/internal/goredis"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/queueing"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/redis"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/secret"
@@ -224,9 +226,32 @@ func run(boot bootConfig, cfg config.Config) error {
secret.SetPath(boot.secretPath)
keyWatcher := redis.NewKeyWatcher()
- if cfg.Redis != nil {
- redis.Configure(cfg.Redis, redis.DefaultDialFunc)
- go keyWatcher.Process()
+
+ var watchKeyFn builds.WatchKeyHandler
+ var goredisKeyWatcher *goredis.KeyWatcher
+
+ if os.Getenv("GITLAB_WORKHORSE_FF_GO_REDIS_ENABLED") == "true" {
+ log.Info("Using redis/go-redis")
+
+ goredisKeyWatcher = goredis.NewKeyWatcher()
+ if err := goredis.Configure(cfg.Redis); err != nil {
+ log.WithError(err).Error("unable to configure redis client")
+ }
+
+ if rdb := goredis.GetRedisClient(); rdb != nil {
+ go goredisKeyWatcher.Process(rdb)
+ }
+
+ watchKeyFn = goredisKeyWatcher.WatchKey
+ } else {
+ log.Info("Using gomodule/redigo")
+
+ if cfg.Redis != nil {
+ redis.Configure(cfg.Redis, redis.DefaultDialFunc)
+ go keyWatcher.Process()
+ }
+
+ watchKeyFn = keyWatcher.WatchKey
}
if err := cfg.RegisterGoCloudURLOpeners(); err != nil {
@@ -241,7 +266,7 @@ func run(boot bootConfig, cfg config.Config) error {
gitaly.InitializeSidechannelRegistry(accessLogger)
- up := wrapRaven(upstream.NewUpstream(cfg, accessLogger, keyWatcher.WatchKey))
+ up := wrapRaven(upstream.NewUpstream(cfg, accessLogger, watchKeyFn))
done := make(chan os.Signal, 1)
signal.Notify(done, syscall.SIGINT, syscall.SIGTERM)
@@ -275,6 +300,10 @@ func run(boot bootConfig, cfg config.Config) error {
ctx, cancel := context.WithTimeout(context.Background(), cfg.ShutdownTimeout.Duration) // lint:allow context.Background
defer cancel()
+ if goredisKeyWatcher != nil {
+ goredisKeyWatcher.Shutdown()
+ }
+
keyWatcher.Shutdown()
return srv.Shutdown(ctx)
}
diff --git a/workhorse/main_test.go b/workhorse/main_test.go
index 05834ab5d64..39eaa3ee30b 100644
--- a/workhorse/main_test.go
+++ b/workhorse/main_test.go
@@ -35,7 +35,6 @@ import (
"gitlab.com/gitlab-org/gitlab/workhorse/internal/upstream"
)
-const scratchDir = "testdata/scratch"
const testRepoRoot = "testdata/repo"
const testDocumentRoot = "testdata/public"
const testAltDocumentRoot = "testdata/alt-public"
@@ -45,9 +44,6 @@ var absDocumentRoot string
const testRepo = "group/test.git"
const testProject = "group/test"
-var checkoutDir = path.Join(scratchDir, "test")
-var cacheDir = path.Join(scratchDir, "cache")
-
func TestMain(m *testing.M) {
if _, err := os.Stat(path.Join(testRepoRoot, testRepo)); os.IsNotExist(err) {
log.WithError(err).Fatal("cannot find test repository. Please run 'make prepare-tests'")
@@ -64,9 +60,6 @@ func TestMain(m *testing.M) {
}
func TestDeniedClone(t *testing.T) {
- // Prepare clone directory
- require.NoError(t, os.RemoveAll(scratchDir))
-
// Prepare test server and backend
ts := testAuthServer(t, nil, nil, 403, "Access denied")
defer ts.Close()
@@ -74,7 +67,7 @@ func TestDeniedClone(t *testing.T) {
defer ws.Close()
// Do the git clone
- cloneCmd := exec.Command("git", "clone", fmt.Sprintf("%s/%s", ws.URL, testRepo), checkoutDir)
+ cloneCmd := exec.Command("git", "clone", fmt.Sprintf("%s/%s", ws.URL, testRepo), t.TempDir())
out, err := cloneCmd.CombinedOutput()
t.Log(string(out))
require.Error(t, err, "git clone should have failed")
@@ -89,7 +82,7 @@ func TestDeniedPush(t *testing.T) {
// Perform the git push
pushCmd := exec.Command("git", "push", "-v", fmt.Sprintf("%s/%s", ws.URL, testRepo), fmt.Sprintf("master:%s", newBranch()))
- pushCmd.Dir = checkoutDir
+ pushCmd.Dir = t.TempDir()
out, err := pushCmd.CombinedOutput()
t.Log(string(out))
require.Error(t, err, "git push should have failed")
@@ -125,14 +118,12 @@ func TestRegularProjectsAPI(t *testing.T) {
func TestAllowedXSendfileDownload(t *testing.T) {
contentFilename := "my-content"
- prepareDownloadDir(t)
allowedXSendfileDownload(t, contentFilename, "foo/uploads/bar")
}
func TestDeniedXSendfileDownload(t *testing.T) {
contentFilename := "my-content"
- prepareDownloadDir(t)
deniedXSendfileDownload(t, contentFilename, "foo/uploads/bar")
}
@@ -721,11 +712,6 @@ func setupAltStaticFile(t *testing.T, fpath, content string) {
absDocumentRoot = testhelper.SetupStaticFileHelper(t, fpath, content, testAltDocumentRoot)
}
-func prepareDownloadDir(t *testing.T) {
- require.NoError(t, os.RemoveAll(scratchDir))
- require.NoError(t, os.MkdirAll(scratchDir, 0755))
-}
-
func newBranch() string {
return fmt.Sprintf("branch-%d", time.Now().UnixNano())
}
@@ -962,7 +948,7 @@ func TestDependencyProxyInjector(t *testing.T) {
w.Header().Set("Gitlab-Workhorse-Send-Data", `send-dependency:`+base64.URLEncoding.EncodeToString([]byte(params)))
case "/base/upload/authorize":
w.Header().Set("Content-Type", api.ResponseContentType)
- _, err := fmt.Fprintf(w, `{"TempPath":"%s"}`, scratchDir)
+ _, err := fmt.Fprintf(w, `{"TempPath":"%s"}`, t.TempDir())
require.NoError(t, err)
case "/base/upload":
w.WriteHeader(tc.finalizeStatus)
diff --git a/workhorse/sendfile_test.go b/workhorse/sendfile_test.go
index f2b6da4eebd..ed8edf01533 100644
--- a/workhorse/sendfile_test.go
+++ b/workhorse/sendfile_test.go
@@ -20,7 +20,6 @@ func TestDeniedLfsDownload(t *testing.T) {
contentFilename := "b68143e6463773b1b6c6fd009a76c32aeec041faff32ba2ed42fd7f708a17f80"
url := fmt.Sprintf("gitlab-lfs/objects/%s", contentFilename)
- prepareDownloadDir(t)
deniedXSendfileDownload(t, contentFilename, url)
}
@@ -28,14 +27,11 @@ func TestAllowedLfsDownload(t *testing.T) {
contentFilename := "b68143e6463773b1b6c6fd009a76c32aeec041faff32ba2ed42fd7f708a17f80"
url := fmt.Sprintf("gitlab-lfs/objects/%s", contentFilename)
- prepareDownloadDir(t)
allowedXSendfileDownload(t, contentFilename, url)
}
func allowedXSendfileDownload(t *testing.T, contentFilename string, filePath string) {
- contentPath := path.Join(cacheDir, contentFilename)
- prepareDownloadDir(t)
-
+ contentPath := path.Join(t.TempDir(), contentFilename)
// Prepare test server and backend
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.WithFields(log.Fields{"method": r.Method, "url": r.URL}).Info("UPSTREAM")
@@ -51,7 +47,6 @@ func allowedXSendfileDownload(t *testing.T, contentFilename string, filePath str
ws := startWorkhorseServer(ts.URL)
defer ws.Close()
- require.NoError(t, os.MkdirAll(cacheDir, 0755))
contentBytes := []byte("content")
require.NoError(t, os.WriteFile(contentPath, contentBytes, 0644))
@@ -68,8 +63,6 @@ func allowedXSendfileDownload(t *testing.T, contentFilename string, filePath str
}
func deniedXSendfileDownload(t *testing.T, contentFilename string, filePath string) {
- prepareDownloadDir(t)
-
// Prepare test server and backend
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.WithFields(log.Fields{"method": r.Method, "url": r.URL}).Info("UPSTREAM")
diff --git a/workhorse/upload_test.go b/workhorse/upload_test.go
index 62a78dd9464..de7c5c6dd52 100644
--- a/workhorse/upload_test.go
+++ b/workhorse/upload_test.go
@@ -74,9 +74,9 @@ func uploadTestServer(t *testing.T, allowedHashFunctions []string, authorizeTest
var err error
if len(allowedHashFunctions) == 0 {
- _, err = fmt.Fprintf(w, `{"TempPath":"%s"}`, scratchDir)
+ _, err = fmt.Fprintf(w, `{"TempPath":"%s"}`, t.TempDir())
} else {
- _, err = fmt.Fprintf(w, `{"TempPath":"%s", "UploadHashFunctions": ["%s"]}`, scratchDir, strings.Join(allowedHashFunctions, `","`))
+ _, err = fmt.Fprintf(w, `{"TempPath":"%s", "UploadHashFunctions": ["%s"]}`, t.TempDir(), strings.Join(allowedHashFunctions, `","`))
}
require.NoError(t, err)
@@ -386,7 +386,7 @@ func TestLfsUpload(t *testing.T) {
lfsApiResponse := fmt.Sprintf(
`{"TempPath":%q, "LfsOid":%q, "LfsSize": %d}`,
- scratchDir, oid, len(reqBody),
+ t.TempDir(), oid, len(reqBody),
)
ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
@@ -512,7 +512,7 @@ func packageUploadTestServer(t *testing.T, method string, resource string, reqBo
return testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, r.Method, method)
apiResponse := fmt.Sprintf(
- `{"TempPath":%q, "Size": %d}`, scratchDir, len(reqBody),
+ `{"TempPath":%q, "Size": %d}`, t.TempDir(), len(reqBody),
)
switch r.RequestURI {
case resource + "/authorize":