gitlab.com/gitlab-org/gitlab-foss.git
author     GitLab Bot <gitlab-bot@gitlab.com>  2023-11-29 18:43:00 +0300
committer  GitLab Bot <gitlab-bot@gitlab.com>  2023-11-29 18:43:00 +0300
commit     693d15dcb2f33c01a442784c13933da3d1b8d52e (patch)
tree       d00d6dca2b2a6d164d6d2d7c51d57acc32d92b54
parent     94f0f0e4b9fa3f49bf6145100b206c36c0c4eef6 (diff)
Add latest changes from gitlab-org/gitlab@16-6-stable-ee
-rw-r--r--  workhorse/go.mod                                 3
-rw-r--r--  workhorse/go.sum                                 7
-rw-r--r--  workhorse/internal/goredis/goredis.go          200
-rw-r--r--  workhorse/internal/goredis/goredis_test.go     162
-rw-r--r--  workhorse/internal/goredis/keywatcher.go       236
-rw-r--r--  workhorse/internal/goredis/keywatcher_test.go  301
-rw-r--r--  workhorse/internal/redis/keywatcher.go          83
-rw-r--r--  workhorse/internal/redis/keywatcher_test.go    120
-rw-r--r--  workhorse/internal/redis/redis.go              336
-rw-r--r--  workhorse/internal/redis/redis_test.go         222
-rw-r--r--  workhorse/main.go                               41
11 files changed, 1390 insertions(+), 321 deletions(-)
diff --git a/workhorse/go.mod b/workhorse/go.mod
index 0773904ce21..04f59a5a6f6 100644
--- a/workhorse/go.mod
+++ b/workhorse/go.mod
@@ -5,6 +5,7 @@ go 1.19
require (
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.1.0
github.com/BurntSushi/toml v1.3.2
+ github.com/FZambia/sentinel v1.1.1
github.com/alecthomas/chroma/v2 v2.9.1
github.com/aws/aws-sdk-go v1.45.20
github.com/disintegration/imaging v1.6.2
@@ -12,12 +13,14 @@ require (
github.com/golang-jwt/jwt/v5 v5.0.0
github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f
github.com/golang/protobuf v1.5.3
+ github.com/gomodule/redigo v2.0.0+incompatible
github.com/gorilla/websocket v1.5.0
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
github.com/johannesboyne/gofakes3 v0.0.0-20230914150226-f005f5cc03aa
github.com/jpillora/backoff v1.0.0
github.com/mitchellh/copystructure v1.2.0
github.com/prometheus/client_golang v1.17.0
+ github.com/rafaeljusto/redigomock/v3 v3.1.2
github.com/redis/go-redis/v9 v9.2.1
github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a
github.com/sirupsen/logrus v1.9.3
diff --git a/workhorse/go.sum b/workhorse/go.sum
index d35e2948db7..6cf33000fcf 100644
--- a/workhorse/go.sum
+++ b/workhorse/go.sum
@@ -85,6 +85,8 @@ github.com/DataDog/datadog-go v4.4.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3
github.com/DataDog/gostackparse v0.5.0/go.mod h1:lTfqcJKqS9KnXQGnyQMCugq3u1FP6UZMfWR0aitKFMM=
github.com/DataDog/sketches-go v1.0.0 h1:chm5KSXO7kO+ywGWJ0Zs6tdmWU8PBXSbywFVciL6BG4=
github.com/DataDog/sketches-go v1.0.0/go.mod h1:O+XkJHWk9w4hDwY2ZUDU31ZC9sNYlYo8DiFsxjYeo1k=
+github.com/FZambia/sentinel v1.1.1 h1:0ovTimlR7Ldm+wR15GgO+8C2dt7kkn+tm3PQS+Qk3Ek=
+github.com/FZambia/sentinel v1.1.1/go.mod h1:ytL1Am/RLlAoAXG6Kj5LNuw/TRRQrv2rt2FT26vP5gI=
github.com/HdrHistogram/hdrhistogram-go v1.1.1 h1:cJXY5VLMHgejurPjZH6Fo9rIwRGLefBGdiaENZALqrg=
github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow=
github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM=
@@ -229,6 +231,9 @@ github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
+github.com/gomodule/redigo v1.8.8/go.mod h1:7ArFNvsTjH8GMMzB4uy1snslv2BwmginuMs06a1uzZE=
+github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0=
+github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/go-cmp v0.1.1-0.20171103154506-982329095285/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
@@ -387,6 +392,8 @@ github.com/prometheus/procfs v0.11.1 h1:xRC8Iq1yyca5ypa9n1EZnWZkt7dwcoRPQwX/5gwa
github.com/prometheus/procfs v0.11.1/go.mod h1:eesXgaPo1q7lBpVMoMy0ZOFTth9hBn4W/y0/p/ScXhY=
github.com/prometheus/prometheus v0.46.0 h1:9JSdXnsuT6YsbODEhSQMwxNkGwPExfmzqG73vCMk/Kw=
github.com/prometheus/prometheus v0.46.0/go.mod h1:10L5IJE5CEsjee1FnOcVswYXlPIscDWWt3IJ2UDYrz4=
+github.com/rafaeljusto/redigomock/v3 v3.1.2 h1:B4Y0XJQiPjpwYmkH55aratKX1VfR+JRqzmDKyZbC99o=
+github.com/rafaeljusto/redigomock/v3 v3.1.2/go.mod h1:F9zPqz8rMriScZkPtUiLJoLruYcpGo/XXREpeyasREM=
github.com/redis/go-redis/v9 v9.2.1 h1:WlYJg71ODF0dVspZZCpYmoF1+U1Jjk9Rwd7pq6QmlCg=
github.com/redis/go-redis/v9 v9.2.1/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
diff --git a/workhorse/internal/goredis/goredis.go b/workhorse/internal/goredis/goredis.go
new file mode 100644
index 00000000000..5566e5a3434
--- /dev/null
+++ b/workhorse/internal/goredis/goredis.go
@@ -0,0 +1,200 @@
+package goredis
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "time"
+
+ redis "github.com/redis/go-redis/v9"
+
+ "gitlab.com/gitlab-org/gitlab/workhorse/internal/config"
+ _ "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper"
+ internalredis "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis"
+)
+
+var (
+ rdb *redis.Client
+ // found in https://github.com/redis/go-redis/blob/c7399b6a17d7d3e2a57654528af91349f2468529/sentinel.go#L626
+ errSentinelMasterAddr error = errors.New("redis: all sentinels specified in configuration are unreachable")
+)
+
+const (
+ // Max Idle Connections in the pool.
+ defaultMaxIdle = 1
+ // Max Active Connections in the pool.
+ defaultMaxActive = 1
+ // Timeout for Read operations on the pool. 1 second is technically overkill,
+ // it's just for sanity.
+ defaultReadTimeout = 1 * time.Second
+ // Timeout for Write operations on the pool. 1 second is technically overkill,
+ // it's just for sanity.
+ defaultWriteTimeout = 1 * time.Second
+ // Timeout before killing Idle connections in the pool. 3 minutes seemed good.
+ // If you _actually_ hit this timeout often, you should consider turning off
+ // redis-support since it's not necessary at that point...
+ defaultIdleTimeout = 3 * time.Minute
+)
+
+// createDialer references https://github.com/redis/go-redis/blob/b1103e3d436b6fe98813ecbbe1f99dc8d59b06c9/options.go#L214
+// it intercepts the error and tracks it via a Prometheus counter
+func createDialer(sentinels []string) func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return func(ctx context.Context, network, addr string) (net.Conn, error) {
+ var isSentinel bool
+ for _, sentinelAddr := range sentinels {
+ if sentinelAddr == addr {
+ isSentinel = true
+ break
+ }
+ }
+
+ dialTimeout := 5 * time.Second // go-redis default
+ destination := "redis"
+ if isSentinel {
+ // This timeout is recommended for Sentinel-support according to the guidelines.
+ // https://redis.io/topics/sentinel-clients#redis-service-discovery-via-sentinel
+ // For every address it should try to connect to the Sentinel,
+ // using a short timeout (in the order of a few hundreds of milliseconds).
+ destination = "sentinel"
+ dialTimeout = 500 * time.Millisecond
+ }
+
+ netDialer := &net.Dialer{
+ Timeout: dialTimeout,
+ KeepAlive: 5 * time.Minute,
+ }
+
+ conn, err := netDialer.DialContext(ctx, network, addr)
+ if err != nil {
+ internalredis.ErrorCounter.WithLabelValues("dial", destination).Inc()
+ } else {
+ if !isSentinel {
+ internalredis.TotalConnections.Inc()
+ }
+ }
+
+ return conn, err
+ }
+}
+
+// implements the redis.Hook interface for instrumentation
+type sentinelInstrumentationHook struct{}
+
+func (s sentinelInstrumentationHook) DialHook(next redis.DialHook) redis.DialHook {
+ return func(ctx context.Context, network, addr string) (net.Conn, error) {
+ conn, err := next(ctx, network, addr)
+ if err != nil && err.Error() == errSentinelMasterAddr.Error() {
+ // check for non-dialer error
+ internalredis.ErrorCounter.WithLabelValues("master", "sentinel").Inc()
+ }
+ return conn, err
+ }
+}
+
+func (s sentinelInstrumentationHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook {
+ return func(ctx context.Context, cmd redis.Cmder) error {
+ return next(ctx, cmd)
+ }
+}
+
+func (s sentinelInstrumentationHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook {
+ return func(ctx context.Context, cmds []redis.Cmder) error {
+ return next(ctx, cmds)
+ }
+}
+
+func GetRedisClient() *redis.Client {
+ return rdb
+}
+
+// Configure redis-connection
+func Configure(cfg *config.RedisConfig) error {
+ if cfg == nil {
+ return nil
+ }
+
+ var err error
+
+ if len(cfg.Sentinel) > 0 {
+ rdb = configureSentinel(cfg)
+ } else {
+ rdb, err = configureRedis(cfg)
+ }
+
+ return err
+}
+
+func configureRedis(cfg *config.RedisConfig) (*redis.Client, error) {
+ if cfg.URL.Scheme == "tcp" {
+ cfg.URL.Scheme = "redis"
+ }
+
+ opt, err := redis.ParseURL(cfg.URL.String())
+ if err != nil {
+ return nil, err
+ }
+
+ opt.DB = getOrDefault(cfg.DB, 0)
+ opt.Password = cfg.Password
+
+ opt.PoolSize = getOrDefault(cfg.MaxActive, defaultMaxActive)
+ opt.MaxIdleConns = getOrDefault(cfg.MaxIdle, defaultMaxIdle)
+ opt.ConnMaxIdleTime = defaultIdleTimeout
+ opt.ReadTimeout = defaultReadTimeout
+ opt.WriteTimeout = defaultWriteTimeout
+
+ opt.Dialer = createDialer([]string{})
+
+ return redis.NewClient(opt), nil
+}
+
+func configureSentinel(cfg *config.RedisConfig) *redis.Client {
+ sentinelPassword, sentinels := sentinelOptions(cfg)
+ client := redis.NewFailoverClient(&redis.FailoverOptions{
+ MasterName: cfg.SentinelMaster,
+ SentinelAddrs: sentinels,
+ Password: cfg.Password,
+ SentinelPassword: sentinelPassword,
+ DB: getOrDefault(cfg.DB, 0),
+
+ PoolSize: getOrDefault(cfg.MaxActive, defaultMaxActive),
+ MaxIdleConns: getOrDefault(cfg.MaxIdle, defaultMaxIdle),
+ ConnMaxIdleTime: defaultIdleTimeout,
+
+ ReadTimeout: defaultReadTimeout,
+ WriteTimeout: defaultWriteTimeout,
+
+ Dialer: createDialer(sentinels),
+ })
+
+ client.AddHook(sentinelInstrumentationHook{})
+
+ return client
+}
+
+// sentinelOptions extracts the sentinel password and addresses in <host>:<port> format
+// the order of priority for the passwords is: SentinelPassword -> first password-in-url
+func sentinelOptions(cfg *config.RedisConfig) (string, []string) {
+ sentinels := make([]string, len(cfg.Sentinel))
+ sentinelPassword := cfg.SentinelPassword
+
+ for i := range cfg.Sentinel {
+ sentinelDetails := cfg.Sentinel[i]
+ sentinels[i] = fmt.Sprintf("%s:%s", sentinelDetails.Hostname(), sentinelDetails.Port())
+
+ if pw, exist := sentinelDetails.User.Password(); exist && len(sentinelPassword) == 0 {
+ // sets password using the first non-empty password
+ sentinelPassword = pw
+ }
+ }
+
+ return sentinelPassword, sentinels
+}
+
+func getOrDefault(ptr *int, val int) int {
+ if ptr != nil {
+ return *ptr
+ }
+ return val
+}
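
A quick orientation sketch: the package above selects Sentinel failover when cfg.Sentinel is non-empty and a plain client otherwise. A minimal single-instance caller, assuming the same helper/config types the tests below use (real config loading elided), could look like this:

    package main

    import (
        "context"
        "log"

        "gitlab.com/gitlab-org/gitlab/workhorse/internal/config"
        "gitlab.com/gitlab-org/gitlab/workhorse/internal/goredis"
        "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper"
    )

    func main() {
        // No Sentinel entries, so Configure takes the configureRedis path.
        u := helper.URLMustParse("redis://localhost:6379")
        cfg := &config.RedisConfig{URL: config.TomlURL{URL: *u}}

        if err := goredis.Configure(cfg); err != nil {
            log.Fatalf("configure redis: %v", err)
        }

        // go-redis dials lazily; Ping forces a connection so errors surface early.
        if err := goredis.GetRedisClient().Ping(context.Background()).Err(); err != nil {
            log.Fatalf("ping redis: %v", err)
        }
    }
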
diff --git a/workhorse/internal/goredis/goredis_test.go b/workhorse/internal/goredis/goredis_test.go
new file mode 100644
index 00000000000..735b2076b0c
--- /dev/null
+++ b/workhorse/internal/goredis/goredis_test.go
@@ -0,0 +1,162 @@
+package goredis
+
+import (
+ "context"
+ "net"
+ "sync/atomic"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab/workhorse/internal/config"
+ "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper"
+)
+
+func mockRedisServer(t *testing.T, connectReceived *atomic.Value) string {
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+
+ require.Nil(t, err)
+
+ go func() {
+ defer ln.Close()
+ conn, err := ln.Accept()
+ require.Nil(t, err)
+ connectReceived.Store(true)
+ conn.Write([]byte("OK\n"))
+ }()
+
+ return ln.Addr().String()
+}
+
+func TestConfigureNoConfig(t *testing.T) {
+ rdb = nil
+ Configure(nil)
+ require.Nil(t, rdb, "rdb client should be nil")
+}
+
+func TestConfigureValidConfigX(t *testing.T) {
+ testCases := []struct {
+ scheme string
+ }{
+ {
+ scheme: "redis",
+ },
+ {
+ scheme: "tcp",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.scheme, func(t *testing.T) {
+ connectReceived := atomic.Value{}
+ a := mockRedisServer(t, &connectReceived)
+
+ parsedURL := helper.URLMustParse(tc.scheme + "://" + a)
+ cfg := &config.RedisConfig{URL: config.TomlURL{URL: *parsedURL}}
+
+ Configure(cfg)
+
+ require.NotNil(t, GetRedisClient().Conn(), "Pool should not be nil")
+
+ // goredis initialises connections lazily
+ rdb.Ping(context.Background())
+ require.True(t, connectReceived.Load().(bool))
+
+ rdb = nil
+ })
+ }
+}
+
+func TestConnectToSentinel(t *testing.T) {
+ testCases := []struct {
+ scheme string
+ }{
+ {
+ scheme: "redis",
+ },
+ {
+ scheme: "tcp",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.scheme, func(t *testing.T) {
+ connectReceived := atomic.Value{}
+ a := mockRedisServer(t, &connectReceived)
+
+ addrs := []string{tc.scheme + "://" + a}
+ var sentinelUrls []config.TomlURL
+
+ for _, a := range addrs {
+ parsedURL := helper.URLMustParse(a)
+ sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL})
+ }
+
+ cfg := &config.RedisConfig{Sentinel: sentinelUrls}
+ Configure(cfg)
+
+ require.NotNil(t, GetRedisClient().Conn(), "Pool should not be nil")
+
+ // goredis initialises connections lazily
+ rdb.Ping(context.Background())
+ require.True(t, connectReceived.Load().(bool))
+
+ rdb = nil
+ })
+ }
+}
+
+func TestSentinelOptions(t *testing.T) {
+ testCases := []struct {
+ description string
+ inputSentinelPassword string
+ inputSentinel []string
+ password string
+ sentinels []string
+ }{
+ {
+ description: "no sentinel passwords",
+ inputSentinel: []string{"tcp://localhost:26480"},
+ sentinels: []string{"localhost:26480"},
+ },
+ {
+ description: "specific sentinel password defined",
+ inputSentinel: []string{"tcp://localhost:26480"},
+ inputSentinelPassword: "password1",
+ sentinels: []string{"localhost:26480"},
+ password: "password1",
+ },
+ {
+ description: "specific sentinel password defined in url",
+ inputSentinel: []string{"tcp://:password2@localhost:26480", "tcp://:password3@localhost:26481"},
+ sentinels: []string{"localhost:26480", "localhost:26481"},
+ password: "password2",
+ },
+ {
+ description: "passwords defined specifically and in url",
+ inputSentinel: []string{"tcp://:password2@localhost:26480", "tcp://:password3@localhost:26481"},
+ sentinels: []string{"localhost:26480", "localhost:26481"},
+ inputSentinelPassword: "password1",
+ password: "password1",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.description, func(t *testing.T) {
+ sentinelUrls := make([]config.TomlURL, len(tc.inputSentinel))
+
+ for i, str := range tc.inputSentinel {
+ parsedURL := helper.URLMustParse(str)
+ sentinelUrls[i] = config.TomlURL{URL: *parsedURL}
+ }
+
+ outputPw, outputSentinels := sentinelOptions(&config.RedisConfig{
+ Sentinel: sentinelUrls,
+ SentinelPassword: tc.inputSentinelPassword,
+ })
+
+ require.Equal(t, tc.password, outputPw)
+ require.Equal(t, tc.sentinels, outputSentinels)
+ })
+ }
+}
diff --git a/workhorse/internal/goredis/keywatcher.go b/workhorse/internal/goredis/keywatcher.go
new file mode 100644
index 00000000000..741bfb17652
--- /dev/null
+++ b/workhorse/internal/goredis/keywatcher.go
@@ -0,0 +1,236 @@
+package goredis
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/jpillora/backoff"
+ "github.com/redis/go-redis/v9"
+
+ "gitlab.com/gitlab-org/gitlab/workhorse/internal/log"
+ internalredis "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis"
+)
+
+type KeyWatcher struct {
+ mu sync.Mutex
+ subscribers map[string][]chan string
+ shutdown chan struct{}
+ reconnectBackoff backoff.Backoff
+ redisConn *redis.Client
+ conn *redis.PubSub
+}
+
+func NewKeyWatcher() *KeyWatcher {
+ return &KeyWatcher{
+ shutdown: make(chan struct{}),
+ reconnectBackoff: backoff.Backoff{
+ Min: 100 * time.Millisecond,
+ Max: 60 * time.Second,
+ Factor: 2,
+ Jitter: true,
+ },
+ }
+}
+
+const channelPrefix = "workhorse:notifications:"
+
+func countAction(action string) { internalredis.TotalActions.WithLabelValues(action).Add(1) }
+
+func (kw *KeyWatcher) receivePubSubStream(ctx context.Context, pubsub *redis.PubSub) error {
+ kw.mu.Lock()
+ // We must share kw.conn with the goroutines that call SUBSCRIBE and
+ // UNSUBSCRIBE because Redis pubsub subscriptions are tied to the
+ // connection.
+ kw.conn = pubsub
+ kw.mu.Unlock()
+
+ defer func() {
+ kw.mu.Lock()
+ defer kw.mu.Unlock()
+ kw.conn.Close()
+ kw.conn = nil
+
+ // Reset kw.subscribers because it is tied to Redis server side state of
+ // kw.conn and we just closed that connection.
+ for _, chans := range kw.subscribers {
+ for _, ch := range chans {
+ close(ch)
+ internalredis.KeyWatchers.Dec()
+ }
+ }
+ kw.subscribers = nil
+ }()
+
+ for {
+ msg, err := kw.conn.Receive(ctx)
+ if err != nil {
+ log.WithError(fmt.Errorf("keywatcher: pubsub receive: %v", err)).Error()
+ return nil
+ }
+
+ switch msg := msg.(type) {
+ case *redis.Subscription:
+ internalredis.RedisSubscriptions.Set(float64(msg.Count))
+ case *redis.Pong:
+ // Ignore.
+ case *redis.Message:
+ internalredis.TotalMessages.Inc()
+ internalredis.ReceivedBytes.Add(float64(len(msg.Payload)))
+ if strings.HasPrefix(msg.Channel, channelPrefix) {
+ kw.notifySubscribers(msg.Channel[len(channelPrefix):], string(msg.Payload))
+ }
+ default:
+ log.WithError(fmt.Errorf("keywatcher: unknown: %T", msg)).Error()
+ return nil
+ }
+ }
+}
+
+func (kw *KeyWatcher) Process(client *redis.Client) {
+ log.Info("keywatcher: starting process loop")
+
+ ctx := context.Background() // lint:allow context.Background
+ kw.mu.Lock()
+ kw.redisConn = client
+ kw.mu.Unlock()
+
+ for {
+ pubsub := client.Subscribe(ctx, []string{}...)
+ if err := pubsub.Ping(ctx); err != nil {
+ log.WithError(fmt.Errorf("keywatcher: %v", err)).Error()
+ time.Sleep(kw.reconnectBackoff.Duration())
+ continue
+ }
+
+ kw.reconnectBackoff.Reset()
+
+ if err := kw.receivePubSubStream(ctx, pubsub); err != nil {
+ log.WithError(fmt.Errorf("keywatcher: receivePubSubStream: %v", err)).Error()
+ }
+ }
+}
+
+func (kw *KeyWatcher) Shutdown() {
+ log.Info("keywatcher: shutting down")
+
+ kw.mu.Lock()
+ defer kw.mu.Unlock()
+
+ select {
+ case <-kw.shutdown:
+ // already closed
+ default:
+ close(kw.shutdown)
+ }
+}
+
+func (kw *KeyWatcher) notifySubscribers(key, value string) {
+ kw.mu.Lock()
+ defer kw.mu.Unlock()
+
+ chanList, ok := kw.subscribers[key]
+ if !ok {
+ countAction("drop-message")
+ return
+ }
+
+ countAction("deliver-message")
+ for _, c := range chanList {
+ select {
+ case c <- value:
+ default:
+ }
+ }
+}
+
+func (kw *KeyWatcher) addSubscription(ctx context.Context, key string, notify chan string) error {
+ kw.mu.Lock()
+ defer kw.mu.Unlock()
+
+ if kw.conn == nil {
+ // This can happen because CI long polling is disabled in this Workhorse
+ // process. It can also be that we are waiting for the pubsub connection
+ // to be established. Either way it is OK to fail fast.
+ return errors.New("no redis connection")
+ }
+
+ if len(kw.subscribers[key]) == 0 {
+ countAction("create-subscription")
+ if err := kw.conn.Subscribe(ctx, channelPrefix+key); err != nil {
+ return err
+ }
+ }
+
+ if kw.subscribers == nil {
+ kw.subscribers = make(map[string][]chan string)
+ }
+ kw.subscribers[key] = append(kw.subscribers[key], notify)
+ internalredis.KeyWatchers.Inc()
+
+ return nil
+}
+
+func (kw *KeyWatcher) delSubscription(ctx context.Context, key string, notify chan string) {
+ kw.mu.Lock()
+ defer kw.mu.Unlock()
+
+ chans, ok := kw.subscribers[key]
+ if !ok {
+ // This can happen if the pubsub connection dropped while we were
+ // waiting.
+ return
+ }
+
+ for i, c := range chans {
+ if notify == c {
+ kw.subscribers[key] = append(chans[:i], chans[i+1:]...)
+ internalredis.KeyWatchers.Dec()
+ break
+ }
+ }
+ if len(kw.subscribers[key]) == 0 {
+ delete(kw.subscribers, key)
+ countAction("delete-subscription")
+ if kw.conn != nil {
+ kw.conn.Unsubscribe(ctx, channelPrefix+key)
+ }
+ }
+}
+
+func (kw *KeyWatcher) WatchKey(ctx context.Context, key, value string, timeout time.Duration) (internalredis.WatchKeyStatus, error) {
+ notify := make(chan string, 1)
+ if err := kw.addSubscription(ctx, key, notify); err != nil {
+ return internalredis.WatchKeyStatusNoChange, err
+ }
+ defer kw.delSubscription(ctx, key, notify)
+
+ currentValue, err := kw.redisConn.Get(ctx, key).Result()
+ if errors.Is(err, redis.Nil) {
+ currentValue = ""
+ } else if err != nil {
+ return internalredis.WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET: %v", err)
+ }
+ if currentValue != value {
+ return internalredis.WatchKeyStatusAlreadyChanged, nil
+ }
+
+ select {
+ case <-kw.shutdown:
+ log.WithFields(log.Fields{"key": key}).Info("stopping watch due to shutdown")
+ return internalredis.WatchKeyStatusNoChange, nil
+ case currentValue := <-notify:
+ if currentValue == "" {
+ return internalredis.WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET failed")
+ }
+ if currentValue == value {
+ return internalredis.WatchKeyStatusNoChange, nil
+ }
+ return internalredis.WatchKeyStatusSeenChange, nil
+ case <-time.After(timeout):
+ return internalredis.WatchKeyStatusTimeout, nil
+ }
+}
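
The WatchKey flow above subscribes before reading the current value, so a change that lands between the GET and the first pubsub message is still observed. A hedged sketch of a caller (key name borrowed from the tests; it assumes a Process loop is already running, as main.go arranges below):

    package example

    import (
        "context"
        "time"

        "gitlab.com/gitlab-org/gitlab/workhorse/internal/goredis"
        internalredis "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis"
    )

    // waitForQueueChange blocks until the key changes, the timeout fires,
    // or the watcher shuts down; it reports whether the caller should re-poll.
    func waitForQueueChange(ctx context.Context, kw *goredis.KeyWatcher, lastValue string) (bool, error) {
        status, err := kw.WatchKey(ctx, "runner:build_queue:10", lastValue, 30*time.Second)
        if err != nil {
            return false, err
        }
        return status == internalredis.WatchKeyStatusSeenChange ||
            status == internalredis.WatchKeyStatusAlreadyChanged, nil
    }
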
diff --git a/workhorse/internal/goredis/keywatcher_test.go b/workhorse/internal/goredis/keywatcher_test.go
new file mode 100644
index 00000000000..b64262dc9c8
--- /dev/null
+++ b/workhorse/internal/goredis/keywatcher_test.go
@@ -0,0 +1,301 @@
+package goredis
+
+import (
+ "context"
+ "os"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab/workhorse/internal/config"
+ "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis"
+)
+
+var ctx = context.Background()
+
+const (
+ runnerKey = "runner:build_queue:10"
+)
+
+func initRdb() {
+ buf, _ := os.ReadFile("../../config.toml")
+ cfg, _ := config.LoadConfig(string(buf))
+ Configure(cfg.Redis)
+}
+
+func (kw *KeyWatcher) countSubscribers(key string) int {
+ kw.mu.Lock()
+ defer kw.mu.Unlock()
+ return len(kw.subscribers[key])
+}
+
+// Forces a run of the `Process` loop against a mock PubSubConn.
+func (kw *KeyWatcher) processMessages(t *testing.T, numWatchers int, value string, ready chan<- struct{}, wg *sync.WaitGroup) {
+ kw.mu.Lock()
+ kw.redisConn = rdb
+ psc := kw.redisConn.Subscribe(ctx, []string{}...)
+ kw.mu.Unlock()
+
+ errC := make(chan error)
+ go func() { errC <- kw.receivePubSubStream(ctx, psc) }()
+
+ require.Eventually(t, func() bool {
+ kw.mu.Lock()
+ defer kw.mu.Unlock()
+ return kw.conn != nil
+ }, time.Second, time.Millisecond)
+ close(ready)
+
+ require.Eventually(t, func() bool {
+ return kw.countSubscribers(runnerKey) == numWatchers
+ }, time.Second, time.Millisecond)
+
+ // send message after listeners are ready
+ kw.redisConn.Publish(ctx, channelPrefix+runnerKey, value)
+
+ // close subscription after all workers are done
+ wg.Wait()
+ kw.mu.Lock()
+ kw.conn.Close()
+ kw.mu.Unlock()
+
+ require.NoError(t, <-errC)
+}
+
+type keyChangeTestCase struct {
+ desc string
+ returnValue string
+ isKeyMissing bool
+ watchValue string
+ processedValue string
+ expectedStatus redis.WatchKeyStatus
+ timeout time.Duration
+}
+
+func TestKeyChangesInstantReturn(t *testing.T) {
+ initRdb()
+
+ testCases := []keyChangeTestCase{
+ // WatchKeyStatusAlreadyChanged
+ {
+ desc: "sees change with key existing and changed",
+ returnValue: "somethingelse",
+ watchValue: "something",
+ expectedStatus: redis.WatchKeyStatusAlreadyChanged,
+ timeout: time.Second,
+ },
+ {
+ desc: "sees change with key non-existing",
+ isKeyMissing: true,
+ watchValue: "something",
+ processedValue: "somethingelse",
+ expectedStatus: redis.WatchKeyStatusAlreadyChanged,
+ timeout: time.Second,
+ },
+ // WatchKeyStatusTimeout
+ {
+ desc: "sees timeout with key existing and unchanged",
+ returnValue: "something",
+ watchValue: "something",
+ expectedStatus: redis.WatchKeyStatusTimeout,
+ timeout: time.Millisecond,
+ },
+ {
+ desc: "sees timeout with key non-existing and unchanged",
+ isKeyMissing: true,
+ watchValue: "",
+ expectedStatus: redis.WatchKeyStatusTimeout,
+ timeout: time.Millisecond,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+
+ // setup
+ if !tc.isKeyMissing {
+ rdb.Set(ctx, runnerKey, tc.returnValue, 0)
+ }
+
+ defer func() {
+ rdb.FlushDB(ctx)
+ }()
+
+ kw := NewKeyWatcher()
+ defer kw.Shutdown()
+ kw.redisConn = rdb
+ kw.conn = kw.redisConn.Subscribe(ctx, []string{}...)
+
+ val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, tc.timeout)
+
+ require.NoError(t, err, "Expected no error")
+ require.Equal(t, tc.expectedStatus, val, "Expected value")
+ })
+ }
+}
+
+func TestKeyChangesWhenWatching(t *testing.T) {
+ initRdb()
+
+ testCases := []keyChangeTestCase{
+ // WatchKeyStatusSeenChange
+ {
+ desc: "sees change with key existing",
+ returnValue: "something",
+ watchValue: "something",
+ processedValue: "somethingelse",
+ expectedStatus: redis.WatchKeyStatusSeenChange,
+ },
+ {
+ desc: "sees change with key non-existing, when watching empty value",
+ isKeyMissing: true,
+ watchValue: "",
+ processedValue: "something",
+ expectedStatus: redis.WatchKeyStatusSeenChange,
+ },
+ // WatchKeyStatusNoChange
+ {
+ desc: "sees no change with key existing",
+ returnValue: "something",
+ watchValue: "something",
+ processedValue: "something",
+ expectedStatus: redis.WatchKeyStatusNoChange,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ if !tc.isKeyMissing {
+ rdb.Set(ctx, runnerKey, tc.returnValue, 0)
+ }
+
+ kw := NewKeyWatcher()
+ defer kw.Shutdown()
+ defer func() {
+ rdb.FlushDB(ctx)
+ }()
+
+ wg := &sync.WaitGroup{}
+ wg.Add(1)
+ ready := make(chan struct{})
+
+ go func() {
+ defer wg.Done()
+ <-ready
+ val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, time.Second)
+
+ require.NoError(t, err, "Expected no error")
+ require.Equal(t, tc.expectedStatus, val, "Expected value")
+ }()
+
+ kw.processMessages(t, 1, tc.processedValue, ready, wg)
+ })
+ }
+}
+
+func TestKeyChangesParallel(t *testing.T) {
+ initRdb()
+
+ testCases := []keyChangeTestCase{
+ {
+ desc: "massively parallel, sees change with key existing",
+ returnValue: "something",
+ watchValue: "something",
+ processedValue: "somethingelse",
+ expectedStatus: redis.WatchKeyStatusSeenChange,
+ },
+ {
+ desc: "massively parallel, sees change with key existing, watching missing keys",
+ isKeyMissing: true,
+ watchValue: "",
+ processedValue: "somethingelse",
+ expectedStatus: redis.WatchKeyStatusSeenChange,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ runTimes := 100
+
+ if !tc.isKeyMissing {
+ rdb.Set(ctx, runnerKey, tc.returnValue, 0)
+ }
+
+ defer func() {
+ rdb.FlushDB(ctx)
+ }()
+
+ wg := &sync.WaitGroup{}
+ wg.Add(runTimes)
+ ready := make(chan struct{})
+
+ kw := NewKeyWatcher()
+ defer kw.Shutdown()
+
+ for i := 0; i < runTimes; i++ {
+ go func() {
+ defer wg.Done()
+ <-ready
+ val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, time.Second)
+
+ require.NoError(t, err, "Expected no error")
+ require.Equal(t, tc.expectedStatus, val, "Expected value")
+ }()
+ }
+
+ kw.processMessages(t, runTimes, tc.processedValue, ready, wg)
+ })
+ }
+}
+
+func TestShutdown(t *testing.T) {
+ initRdb()
+
+ kw := NewKeyWatcher()
+ kw.redisConn = rdb
+ kw.conn = kw.redisConn.Subscribe(ctx, []string{}...)
+ defer kw.Shutdown()
+
+ rdb.Set(ctx, runnerKey, "something", 0)
+
+ wg := &sync.WaitGroup{}
+ wg.Add(2)
+
+ go func() {
+ defer wg.Done()
+ val, err := kw.WatchKey(ctx, runnerKey, "something", 10*time.Second)
+
+ require.NoError(t, err, "Expected no error")
+ require.Equal(t, redis.WatchKeyStatusNoChange, val, "Expected value not to change")
+ }()
+
+ go func() {
+ defer wg.Done()
+ require.Eventually(t, func() bool { return kw.countSubscribers(runnerKey) == 1 }, 10*time.Second, time.Millisecond)
+
+ kw.Shutdown()
+ }()
+
+ wg.Wait()
+
+ require.Eventually(t, func() bool { return kw.countSubscribers(runnerKey) == 0 }, 10*time.Second, time.Millisecond)
+
+ // Adding a key after the shutdown should result in an immediate response
+ var val redis.WatchKeyStatus
+ var err error
+ done := make(chan struct{})
+ go func() {
+ val, err = kw.WatchKey(ctx, runnerKey, "something", 10*time.Second)
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ require.NoError(t, err, "Expected no error")
+ require.Equal(t, redis.WatchKeyStatusNoChange, val, "Expected value not to change")
+ case <-time.After(100 * time.Millisecond):
+ t.Fatal("timeout waiting for WatchKey")
+ }
+}
diff --git a/workhorse/internal/redis/keywatcher.go b/workhorse/internal/redis/keywatcher.go
index ddb838121b7..8f1772a9195 100644
--- a/workhorse/internal/redis/keywatcher.go
+++ b/workhorse/internal/redis/keywatcher.go
@@ -8,10 +8,10 @@ import (
"sync"
"time"
+ "github.com/gomodule/redigo/redis"
"github.com/jpillora/backoff"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
- "github.com/redis/go-redis/v9"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/log"
)
@@ -21,8 +21,7 @@ type KeyWatcher struct {
subscribers map[string][]chan string
shutdown chan struct{}
reconnectBackoff backoff.Backoff
- redisConn *redis.Client
- conn *redis.PubSub
+ conn *redis.PubSubConn
}
func NewKeyWatcher() *KeyWatcher {
@@ -75,12 +74,12 @@ const channelPrefix = "workhorse:notifications:"
func countAction(action string) { TotalActions.WithLabelValues(action).Add(1) }
-func (kw *KeyWatcher) receivePubSubStream(ctx context.Context, pubsub *redis.PubSub) error {
+func (kw *KeyWatcher) receivePubSubStream(conn redis.Conn) error {
kw.mu.Lock()
// We must share kw.conn with the goroutines that call SUBSCRIBE and
// UNSUBSCRIBE because Redis pubsub subscriptions are tied to the
// connection.
- kw.conn = pubsub
+ kw.conn = &redis.PubSubConn{Conn: conn}
kw.mu.Unlock()
defer func() {
@@ -101,49 +100,51 @@ func (kw *KeyWatcher) receivePubSubStream(ctx context.Context, pubsub *redis.Pub
}()
for {
- msg, err := kw.conn.Receive(ctx)
- if err != nil {
- log.WithError(fmt.Errorf("keywatcher: pubsub receive: %v", err)).Error()
- return nil
- }
-
- switch msg := msg.(type) {
- case *redis.Subscription:
- RedisSubscriptions.Set(float64(msg.Count))
- case *redis.Pong:
- // Ignore.
- case *redis.Message:
+ switch v := kw.conn.Receive().(type) {
+ case redis.Message:
TotalMessages.Inc()
- ReceivedBytes.Add(float64(len(msg.Payload)))
- if strings.HasPrefix(msg.Channel, channelPrefix) {
- kw.notifySubscribers(msg.Channel[len(channelPrefix):], string(msg.Payload))
+ ReceivedBytes.Add(float64(len(v.Data)))
+ if strings.HasPrefix(v.Channel, channelPrefix) {
+ kw.notifySubscribers(v.Channel[len(channelPrefix):], string(v.Data))
}
- default:
- log.WithError(fmt.Errorf("keywatcher: unknown: %T", msg)).Error()
+ case redis.Subscription:
+ RedisSubscriptions.Set(float64(v.Count))
+ case error:
+ log.WithError(fmt.Errorf("keywatcher: pubsub receive: %v", v)).Error()
+ // Intermittent error, return nil so that it doesn't wait before reconnect
return nil
}
}
}
-func (kw *KeyWatcher) Process(client *redis.Client) {
- log.Info("keywatcher: starting process loop")
+func dialPubSub(dialer redisDialerFunc) (redis.Conn, error) {
+ conn, err := dialer()
+ if err != nil {
+ return nil, err
+ }
- ctx := context.Background() // lint:allow context.Background
- kw.mu.Lock()
- kw.redisConn = client
- kw.mu.Unlock()
+ // Make sure Redis is actually connected
+ conn.Do("PING")
+ if err := conn.Err(); err != nil {
+ conn.Close()
+ return nil, err
+ }
+
+ return conn, nil
+}
+func (kw *KeyWatcher) Process() {
+ log.Info("keywatcher: starting process loop")
for {
- pubsub := client.Subscribe(ctx, []string{}...)
- if err := pubsub.Ping(ctx); err != nil {
+ conn, err := dialPubSub(workerDialFunc)
+ if err != nil {
log.WithError(fmt.Errorf("keywatcher: %v", err)).Error()
time.Sleep(kw.reconnectBackoff.Duration())
continue
}
-
kw.reconnectBackoff.Reset()
- if err := kw.receivePubSubStream(ctx, pubsub); err != nil {
+ if err = kw.receivePubSubStream(conn); err != nil {
log.WithError(fmt.Errorf("keywatcher: receivePubSubStream: %v", err)).Error()
}
}
@@ -182,7 +183,7 @@ func (kw *KeyWatcher) notifySubscribers(key, value string) {
}
}
-func (kw *KeyWatcher) addSubscription(ctx context.Context, key string, notify chan string) error {
+func (kw *KeyWatcher) addSubscription(key string, notify chan string) error {
kw.mu.Lock()
defer kw.mu.Unlock()
@@ -195,7 +196,7 @@ func (kw *KeyWatcher) addSubscription(ctx context.Context, key string, notify ch
if len(kw.subscribers[key]) == 0 {
countAction("create-subscription")
- if err := kw.conn.Subscribe(ctx, channelPrefix+key); err != nil {
+ if err := kw.conn.Subscribe(channelPrefix + key); err != nil {
return err
}
}
@@ -209,7 +210,7 @@ func (kw *KeyWatcher) addSubscription(ctx context.Context, key string, notify ch
return nil
}
-func (kw *KeyWatcher) delSubscription(ctx context.Context, key string, notify chan string) {
+func (kw *KeyWatcher) delSubscription(key string, notify chan string) {
kw.mu.Lock()
defer kw.mu.Unlock()
@@ -231,7 +232,7 @@ func (kw *KeyWatcher) delSubscription(ctx context.Context, key string, notify ch
delete(kw.subscribers, key)
countAction("delete-subscription")
if kw.conn != nil {
- kw.conn.Unsubscribe(ctx, channelPrefix+key)
+ kw.conn.Unsubscribe(channelPrefix + key)
}
}
}
@@ -251,15 +252,15 @@ const (
WatchKeyStatusNoChange
)
-func (kw *KeyWatcher) WatchKey(ctx context.Context, key, value string, timeout time.Duration) (WatchKeyStatus, error) {
+func (kw *KeyWatcher) WatchKey(_ context.Context, key, value string, timeout time.Duration) (WatchKeyStatus, error) {
notify := make(chan string, 1)
- if err := kw.addSubscription(ctx, key, notify); err != nil {
+ if err := kw.addSubscription(key, notify); err != nil {
return WatchKeyStatusNoChange, err
}
- defer kw.delSubscription(ctx, key, notify)
+ defer kw.delSubscription(key, notify)
- currentValue, err := kw.redisConn.Get(ctx, key).Result()
- if errors.Is(err, redis.Nil) {
+ currentValue, err := GetString(key)
+ if errors.Is(err, redis.ErrNil) {
currentValue = ""
} else if err != nil {
return WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET: %v", err)
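
For readers less familiar with redigo, the type switch the hunk above restores is the library's standard receive idiom. A self-contained illustration (generic redigo usage, not workhorse code; assumes a reachable Redis connection):

    package example

    import (
        "fmt"

        "github.com/gomodule/redigo/redis"
    )

    // receiveOnce subscribes to one channel and returns after the first message.
    func receiveOnce(conn redis.Conn, channel string) error {
        psc := redis.PubSubConn{Conn: conn}
        if err := psc.Subscribe(channel); err != nil {
            return err
        }
        for {
            // Receive yields redis.Message, redis.Subscription, redis.Pong, or error.
            switch v := psc.Receive().(type) {
            case redis.Message:
                fmt.Printf("%s: %s\n", v.Channel, v.Data)
                return nil
            case redis.Subscription:
                // v.Count tracks the number of live (un)subscriptions.
            case error:
                return v
            }
        }
    }
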
diff --git a/workhorse/internal/redis/keywatcher_test.go b/workhorse/internal/redis/keywatcher_test.go
index bca4ca43a64..3abc1bf1107 100644
--- a/workhorse/internal/redis/keywatcher_test.go
+++ b/workhorse/internal/redis/keywatcher_test.go
@@ -2,14 +2,13 @@ package redis
import (
"context"
- "os"
"sync"
"testing"
"time"
+ "github.com/gomodule/redigo/redis"
+ "github.com/rafaeljusto/redigomock/v3"
"github.com/stretchr/testify/require"
-
- "gitlab.com/gitlab-org/gitlab/workhorse/internal/config"
)
var ctx = context.Background()
@@ -18,10 +17,27 @@ const (
runnerKey = "runner:build_queue:10"
)
-func initRdb() {
- buf, _ := os.ReadFile("../../config.toml")
- cfg, _ := config.LoadConfig(string(buf))
- Configure(cfg.Redis)
+func createSubscriptionMessage(key, data string) []interface{} {
+ return []interface{}{
+ []byte("message"),
+ []byte(key),
+ []byte(data),
+ }
+}
+
+func createSubscribeMessage(key string) []interface{} {
+ return []interface{}{
+ []byte("subscribe"),
+ []byte(key),
+ []byte("1"),
+ }
+}
+func createUnsubscribeMessage(key string) []interface{} {
+ return []interface{}{
+ []byte("unsubscribe"),
+ []byte(key),
+ []byte("1"),
+ }
}
func (kw *KeyWatcher) countSubscribers(key string) int {
@@ -31,14 +47,17 @@ func (kw *KeyWatcher) countSubscribers(key string) int {
}
// Forces a run of the `Process` loop against a mock PubSubConn.
-func (kw *KeyWatcher) processMessages(t *testing.T, numWatchers int, value string, ready chan<- struct{}, wg *sync.WaitGroup) {
- kw.mu.Lock()
- kw.redisConn = rdb
- psc := kw.redisConn.Subscribe(ctx, []string{}...)
- kw.mu.Unlock()
+func (kw *KeyWatcher) processMessages(t *testing.T, numWatchers int, value string, ready chan<- struct{}) {
+ psc := redigomock.NewConn()
+ psc.ReceiveWait = true
+
+ channel := channelPrefix + runnerKey
+ psc.Command("SUBSCRIBE", channel).Expect(createSubscribeMessage(channel))
+ psc.Command("UNSUBSCRIBE", channel).Expect(createUnsubscribeMessage(channel))
+ psc.AddSubscriptionMessage(createSubscriptionMessage(channel, value))
errC := make(chan error)
- go func() { errC <- kw.receivePubSubStream(ctx, psc) }()
+ go func() { errC <- kw.receivePubSubStream(psc) }()
require.Eventually(t, func() bool {
kw.mu.Lock()
@@ -50,15 +69,7 @@ func (kw *KeyWatcher) processMessages(t *testing.T, numWatchers int, value strin
require.Eventually(t, func() bool {
return kw.countSubscribers(runnerKey) == numWatchers
}, time.Second, time.Millisecond)
-
- // send message after listeners are ready
- kw.redisConn.Publish(ctx, channelPrefix+runnerKey, value)
-
- // close subscription after all workers are done
- wg.Wait()
- kw.mu.Lock()
- kw.conn.Close()
- kw.mu.Unlock()
+ close(psc.ReceiveNow)
require.NoError(t, <-errC)
}
@@ -74,8 +85,6 @@ type keyChangeTestCase struct {
}
func TestKeyChangesInstantReturn(t *testing.T) {
- initRdb()
-
testCases := []keyChangeTestCase{
// WatchKeyStatusAlreadyChanged
{
@@ -112,20 +121,18 @@ func TestKeyChangesInstantReturn(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
+ conn, td := setupMockPool()
+ defer td()
- // setup
- if !tc.isKeyMissing {
- rdb.Set(ctx, runnerKey, tc.returnValue, 0)
+ if tc.isKeyMissing {
+ conn.Command("GET", runnerKey).ExpectError(redis.ErrNil)
+ } else {
+ conn.Command("GET", runnerKey).Expect(tc.returnValue)
}
- defer func() {
- rdb.FlushDB(ctx)
- }()
-
kw := NewKeyWatcher()
defer kw.Shutdown()
- kw.redisConn = rdb
- kw.conn = kw.redisConn.Subscribe(ctx, []string{}...)
+ kw.conn = &redis.PubSubConn{Conn: redigomock.NewConn()}
val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, tc.timeout)
@@ -136,8 +143,6 @@ func TestKeyChangesInstantReturn(t *testing.T) {
}
func TestKeyChangesWhenWatching(t *testing.T) {
- initRdb()
-
testCases := []keyChangeTestCase{
// WatchKeyStatusSeenChange
{
@@ -166,15 +171,17 @@ func TestKeyChangesWhenWatching(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
- if !tc.isKeyMissing {
- rdb.Set(ctx, runnerKey, tc.returnValue, 0)
+ conn, td := setupMockPool()
+ defer td()
+
+ if tc.isKeyMissing {
+ conn.Command("GET", runnerKey).ExpectError(redis.ErrNil)
+ } else {
+ conn.Command("GET", runnerKey).Expect(tc.returnValue)
}
kw := NewKeyWatcher()
defer kw.Shutdown()
- defer func() {
- rdb.FlushDB(ctx)
- }()
wg := &sync.WaitGroup{}
wg.Add(1)
@@ -189,14 +196,13 @@ func TestKeyChangesWhenWatching(t *testing.T) {
require.Equal(t, tc.expectedStatus, val, "Expected value")
}()
- kw.processMessages(t, 1, tc.processedValue, ready, wg)
+ kw.processMessages(t, 1, tc.processedValue, ready)
+ wg.Wait()
})
}
}
func TestKeyChangesParallel(t *testing.T) {
- initRdb()
-
testCases := []keyChangeTestCase{
{
desc: "massively parallel, sees change with key existing",
@@ -218,13 +224,18 @@ func TestKeyChangesParallel(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
runTimes := 100
- if !tc.isKeyMissing {
- rdb.Set(ctx, runnerKey, tc.returnValue, 0)
- }
+ conn, td := setupMockPool()
+ defer td()
- defer func() {
- rdb.FlushDB(ctx)
- }()
+ getCmd := conn.Command("GET", runnerKey)
+
+ for i := 0; i < runTimes; i++ {
+ if tc.isKeyMissing {
+ getCmd = getCmd.ExpectError(redis.ErrNil)
+ } else {
+ getCmd = getCmd.Expect(tc.returnValue)
+ }
+ }
wg := &sync.WaitGroup{}
wg.Add(runTimes)
@@ -244,20 +255,21 @@ func TestKeyChangesParallel(t *testing.T) {
}()
}
- kw.processMessages(t, runTimes, tc.processedValue, ready, wg)
+ kw.processMessages(t, runTimes, tc.processedValue, ready)
+ wg.Wait()
})
}
}
func TestShutdown(t *testing.T) {
- initRdb()
+ conn, td := setupMockPool()
+ defer td()
kw := NewKeyWatcher()
- kw.redisConn = rdb
- kw.conn = kw.redisConn.Subscribe(ctx, []string{}...)
+ kw.conn = &redis.PubSubConn{Conn: redigomock.NewConn()}
defer kw.Shutdown()
- rdb.Set(ctx, runnerKey, "something", 0)
+ conn.Command("GET", runnerKey).Expect("something")
wg := &sync.WaitGroup{}
wg.Add(2)
diff --git a/workhorse/internal/redis/redis.go b/workhorse/internal/redis/redis.go
index b528255d25b..c79e1e56b3a 100644
--- a/workhorse/internal/redis/redis.go
+++ b/workhorse/internal/redis/redis.go
@@ -1,39 +1,24 @@
package redis
import (
- "context"
- "errors"
"fmt"
"net"
+ "net/url"
"time"
+ "github.com/FZambia/sentinel"
+ "github.com/gomodule/redigo/redis"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
- redis "github.com/redis/go-redis/v9"
+ "gitlab.com/gitlab-org/labkit/log"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/config"
- _ "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper"
+ "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper"
)
var (
- rdb *redis.Client
- // found in https://github.com/redis/go-redis/blob/c7399b6a17d7d3e2a57654528af91349f2468529/sentinel.go#L626
- errSentinelMasterAddr error = errors.New("redis: all sentinels specified in configuration are unreachable")
-
- TotalConnections = promauto.NewCounter(
- prometheus.CounterOpts{
- Name: "gitlab_workhorse_redis_total_connections",
- Help: "How many connections gitlab-workhorse has opened in total. Can be used to track Redis connection rate for this process",
- },
- )
-
- ErrorCounter = promauto.NewCounterVec(
- prometheus.CounterOpts{
- Name: "gitlab_workhorse_redis_errors",
- Help: "Counts different types of Redis errors encountered by workhorse, by type and destination (redis, sentinel)",
- },
- []string{"type", "dst"},
- )
+ pool *redis.Pool
+ sntnl *sentinel.Sentinel
)
const (
@@ -51,166 +36,241 @@ const (
// If you _actually_ hit this timeout often, you should consider turning off
// redis-support since it's not necessary at that point...
defaultIdleTimeout = 3 * time.Minute
+ // KeepAlivePeriod is to keep a TCP connection open for an extended period of
+ // time without being killed. This is used both in the pool, and in the
+ // worker-connection.
+ // See https://en.wikipedia.org/wiki/Keepalive#TCP_keepalive for more
+ // information.
+ defaultKeepAlivePeriod = 5 * time.Minute
)
-// createDialer references https://github.com/redis/go-redis/blob/b1103e3d436b6fe98813ecbbe1f99dc8d59b06c9/options.go#L214
-// it intercepts the error and tracks it via a Prometheus counter
-func createDialer(sentinels []string) func(ctx context.Context, network, addr string) (net.Conn, error) {
- return func(ctx context.Context, network, addr string) (net.Conn, error) {
- var isSentinel bool
- for _, sentinelAddr := range sentinels {
- if sentinelAddr == addr {
- isSentinel = true
- break
- }
- }
+var (
+ TotalConnections = promauto.NewCounter(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_redis_total_connections",
+ Help: "How many connections gitlab-workhorse has opened in total. Can be used to track Redis connection rate for this process",
+ },
+ )
- dialTimeout := 5 * time.Second // go-redis default
- destination := "redis"
- if isSentinel {
+ ErrorCounter = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_redis_errors",
+ Help: "Counts different types of Redis errors encountered by workhorse, by type and destination (redis, sentinel)",
+ },
+ []string{"type", "dst"},
+ )
+)
+
+func sentinelConn(master string, urls []config.TomlURL) *sentinel.Sentinel {
+ if len(urls) == 0 {
+ return nil
+ }
+ var addrs []string
+ for _, url := range urls {
+ h := url.URL.String()
+ log.WithFields(log.Fields{
+ "scheme": url.URL.Scheme,
+ "host": url.URL.Host,
+ }).Printf("redis: using sentinel")
+ addrs = append(addrs, h)
+ }
+ return &sentinel.Sentinel{
+ Addrs: addrs,
+ MasterName: master,
+ Dial: func(addr string) (redis.Conn, error) {
// This timeout is recommended for Sentinel-support according to the guidelines.
// https://redis.io/topics/sentinel-clients#redis-service-discovery-via-sentinel
// For every address it should try to connect to the Sentinel,
// using a short timeout (in the order of a few hundreds of milliseconds).
- destination = "sentinel"
- dialTimeout = 500 * time.Millisecond
- }
-
- netDialer := &net.Dialer{
- Timeout: dialTimeout,
- KeepAlive: 5 * time.Minute,
- }
+ timeout := 500 * time.Millisecond
+ url := helper.URLMustParse(addr)
+
+ var c redis.Conn
+ var err error
+ options := []redis.DialOption{
+ redis.DialConnectTimeout(timeout),
+ redis.DialReadTimeout(timeout),
+ redis.DialWriteTimeout(timeout),
+ }
- conn, err := netDialer.DialContext(ctx, network, addr)
- if err != nil {
- ErrorCounter.WithLabelValues("dial", destination).Inc()
- } else {
- if !isSentinel {
- TotalConnections.Inc()
+ if url.Scheme == "redis" || url.Scheme == "rediss" {
+ c, err = redis.DialURL(addr, options...)
+ } else {
+ c, err = redis.Dial("tcp", url.Host, options...)
}
- }
- return conn, err
+ if err != nil {
+ ErrorCounter.WithLabelValues("dial", "sentinel").Inc()
+ return nil, err
+ }
+ return c, nil
+ },
}
}
-// implements the redis.Hook interface for instrumentation
-type sentinelInstrumentationHook struct{}
-
-func (s sentinelInstrumentationHook) DialHook(next redis.DialHook) redis.DialHook {
- return func(ctx context.Context, network, addr string) (net.Conn, error) {
- conn, err := next(ctx, network, addr)
- if err != nil && err.Error() == errSentinelMasterAddr.Error() {
- // check for non-dialer error
- ErrorCounter.WithLabelValues("master", "sentinel").Inc()
- }
- return conn, err
- }
-}
+var poolDialFunc func() (redis.Conn, error)
+var workerDialFunc func() (redis.Conn, error)
-func (s sentinelInstrumentationHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook {
- return func(ctx context.Context, cmd redis.Cmder) error {
- return next(ctx, cmd)
+func timeoutDialOptions(cfg *config.RedisConfig) []redis.DialOption {
+ return []redis.DialOption{
+ redis.DialReadTimeout(defaultReadTimeout),
+ redis.DialWriteTimeout(defaultWriteTimeout),
}
}
-func (s sentinelInstrumentationHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook {
- return func(ctx context.Context, cmds []redis.Cmder) error {
- return next(ctx, cmds)
+func dialOptionsBuilder(cfg *config.RedisConfig, setTimeouts bool) []redis.DialOption {
+ var dopts []redis.DialOption
+ if setTimeouts {
+ dopts = timeoutDialOptions(cfg)
}
-}
-
-func GetRedisClient() *redis.Client {
- return rdb
-}
-
-// Configure redis-connection
-func Configure(cfg *config.RedisConfig) error {
if cfg == nil {
- return nil
+ return dopts
}
-
- var err error
-
- if len(cfg.Sentinel) > 0 {
- rdb = configureSentinel(cfg)
- } else {
- rdb, err = configureRedis(cfg)
+ if cfg.Password != "" {
+ dopts = append(dopts, redis.DialPassword(cfg.Password))
}
-
- return err
+ if cfg.DB != nil {
+ dopts = append(dopts, redis.DialDatabase(*cfg.DB))
+ }
+ return dopts
}
-func configureRedis(cfg *config.RedisConfig) (*redis.Client, error) {
- if cfg.URL.Scheme == "tcp" {
- cfg.URL.Scheme = "redis"
+func keepAliveDialer(network, address string) (net.Conn, error) {
+ addr, err := net.ResolveTCPAddr(network, address)
+ if err != nil {
+ return nil, err
}
-
- opt, err := redis.ParseURL(cfg.URL.String())
+ tc, err := net.DialTCP(network, nil, addr)
if err != nil {
return nil, err
}
+ if err := tc.SetKeepAlive(true); err != nil {
+ return nil, err
+ }
+ if err := tc.SetKeepAlivePeriod(defaultKeepAlivePeriod); err != nil {
+ return nil, err
+ }
+ return tc, nil
+}
- opt.DB = getOrDefault(cfg.DB, 0)
- opt.Password = cfg.Password
-
- opt.PoolSize = getOrDefault(cfg.MaxActive, defaultMaxActive)
- opt.MaxIdleConns = getOrDefault(cfg.MaxIdle, defaultMaxIdle)
- opt.ConnMaxIdleTime = defaultIdleTimeout
- opt.ReadTimeout = defaultReadTimeout
- opt.WriteTimeout = defaultWriteTimeout
-
- opt.Dialer = createDialer([]string{})
+type redisDialerFunc func() (redis.Conn, error)
- return redis.NewClient(opt), nil
+func sentinelDialer(dopts []redis.DialOption) redisDialerFunc {
+ return func() (redis.Conn, error) {
+ address, err := sntnl.MasterAddr()
+ if err != nil {
+ ErrorCounter.WithLabelValues("master", "sentinel").Inc()
+ return nil, err
+ }
+ dopts = append(dopts, redis.DialNetDial(keepAliveDialer))
+ conn, err := redisDial("tcp", address, dopts...)
+ if err != nil {
+ return nil, err
+ }
+ if !sentinel.TestRole(conn, "master") {
+ conn.Close()
+ return nil, fmt.Errorf("%s is not redis master", address)
+ }
+ return conn, nil
+ }
}
-func configureSentinel(cfg *config.RedisConfig) *redis.Client {
- sentinelPassword, sentinels := sentinelOptions(cfg)
- client := redis.NewFailoverClient(&redis.FailoverOptions{
- MasterName: cfg.SentinelMaster,
- SentinelAddrs: sentinels,
- Password: cfg.Password,
- SentinelPassword: sentinelPassword,
- DB: getOrDefault(cfg.DB, 0),
+func defaultDialer(dopts []redis.DialOption, url url.URL) redisDialerFunc {
+ return func() (redis.Conn, error) {
+ if url.Scheme == "unix" {
+ return redisDial(url.Scheme, url.Path, dopts...)
+ }
- PoolSize: getOrDefault(cfg.MaxActive, defaultMaxActive),
- MaxIdleConns: getOrDefault(cfg.MaxIdle, defaultMaxIdle),
- ConnMaxIdleTime: defaultIdleTimeout,
+ dopts = append(dopts, redis.DialNetDial(keepAliveDialer))
- ReadTimeout: defaultReadTimeout,
- WriteTimeout: defaultWriteTimeout,
+ // redis.DialURL only works with redis[s]:// URLs
+ if url.Scheme == "redis" || url.Scheme == "rediss" {
+ return redisURLDial(url, dopts...)
+ }
- Dialer: createDialer(sentinels),
- })
+ return redisDial(url.Scheme, url.Host, dopts...)
+ }
+}
- client.AddHook(sentinelInstrumentationHook{})
+func redisURLDial(url url.URL, options ...redis.DialOption) (redis.Conn, error) {
+ log.WithFields(log.Fields{
+ "scheme": url.Scheme,
+ "address": url.Host,
+ }).Printf("redis: dialing")
- return client
+ return redis.DialURL(url.String(), options...)
}
-// sentinelOptions extracts the sentinel password and addresses in <host>:<port> format
-// the order of priority for the passwords is: SentinelPassword -> first password-in-url
-func sentinelOptions(cfg *config.RedisConfig) (string, []string) {
- sentinels := make([]string, len(cfg.Sentinel))
- sentinelPassword := cfg.SentinelPassword
+func redisDial(network, address string, options ...redis.DialOption) (redis.Conn, error) {
+ log.WithFields(log.Fields{
+ "network": network,
+ "address": address,
+ }).Printf("redis: dialing")
- for i := range cfg.Sentinel {
- sentinelDetails := cfg.Sentinel[i]
- sentinels[i] = fmt.Sprintf("%s:%s", sentinelDetails.Hostname(), sentinelDetails.Port())
+ return redis.Dial(network, address, options...)
+}
- if pw, exist := sentinelDetails.User.Password(); exist && len(sentinelPassword) == 0 {
- // sets password using the first non-empty password
- sentinelPassword = pw
+func countDialer(dialer redisDialerFunc) redisDialerFunc {
+ return func() (redis.Conn, error) {
+ c, err := dialer()
+ if err != nil {
+ ErrorCounter.WithLabelValues("dial", "redis").Inc()
+ } else {
+ TotalConnections.Inc()
}
+ return c, err
+ }
+}
+
+// DefaultDialFunc should always be used. The only exception is for unit-tests.
+func DefaultDialFunc(cfg *config.RedisConfig, setReadTimeout bool) func() (redis.Conn, error) {
+ dopts := dialOptionsBuilder(cfg, setReadTimeout)
+ if sntnl != nil {
+ return countDialer(sentinelDialer(dopts))
+ }
+ return countDialer(defaultDialer(dopts, cfg.URL.URL))
+}
+
+// Configure redis-connection
+func Configure(cfg *config.RedisConfig, dialFunc func(*config.RedisConfig, bool) func() (redis.Conn, error)) {
+ if cfg == nil {
+ return
}
+ maxIdle := defaultMaxIdle
+ if cfg.MaxIdle != nil {
+ maxIdle = *cfg.MaxIdle
+ }
+ maxActive := defaultMaxActive
+ if cfg.MaxActive != nil {
+ maxActive = *cfg.MaxActive
+ }
+ sntnl = sentinelConn(cfg.SentinelMaster, cfg.Sentinel)
+ workerDialFunc = dialFunc(cfg, false)
+ poolDialFunc = dialFunc(cfg, true)
+ pool = &redis.Pool{
+ MaxIdle: maxIdle, // Keep at most X hot connections
+ MaxActive: maxActive, // Keep at most X live connections, 0 means unlimited
+ IdleTimeout: defaultIdleTimeout, // X time until an unused connection is closed
+ Dial: poolDialFunc,
+ Wait: true,
+ }
+}
- return sentinelPassword, sentinels
+// Get a connection for the Redis-pool
+func Get() redis.Conn {
+ if pool != nil {
+ return pool.Get()
+ }
+ return nil
}
-func getOrDefault(ptr *int, val int) int {
- if ptr != nil {
- return *ptr
+// GetString fetches the value of a key in Redis as a string
+func GetString(key string) (string, error) {
+ conn := Get()
+ if conn == nil {
+ return "", fmt.Errorf("redis: could not get connection from pool")
}
- return val
+ defer conn.Close()
+
+ return redis.String(conn.Do("GET", key))
}
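
With the pool restored, reads go through GetString, which borrows a pooled connection per call. A short sketch of a consumer, assuming Configure has already run with DefaultDialFunc (as main.go does below):

    package example

    import (
        "fmt"

        "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis"
    )

    func currentQueueValue() (string, error) {
        // GetString handles pool checkout and returns the connection via defer.
        v, err := redis.GetString("runner:build_queue:10")
        if err != nil {
            return "", fmt.Errorf("read queue key: %w", err)
        }
        return v, nil
    }
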
diff --git a/workhorse/internal/redis/redis_test.go b/workhorse/internal/redis/redis_test.go
index 6fd6ecbae11..64b3a842a54 100644
--- a/workhorse/internal/redis/redis_test.go
+++ b/workhorse/internal/redis/redis_test.go
@@ -1,18 +1,19 @@
package redis
import (
- "context"
"net"
- "sync/atomic"
"testing"
+ "time"
+ "github.com/gomodule/redigo/redis"
+ "github.com/rafaeljusto/redigomock/v3"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/config"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/helper"
)
-func mockRedisServer(t *testing.T, connectReceived *atomic.Value) string {
+func mockRedisServer(t *testing.T, connectReceived *bool) string {
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.Nil(t, err)
@@ -21,67 +22,146 @@ func mockRedisServer(t *testing.T, connectReceived *atomic.Value) string {
defer ln.Close()
conn, err := ln.Accept()
require.Nil(t, err)
- connectReceived.Store(true)
+ *connectReceived = true
conn.Write([]byte("OK\n"))
}()
return ln.Addr().String()
}
-func TestConfigureNoConfig(t *testing.T) {
- rdb = nil
- Configure(nil)
- require.Nil(t, rdb, "rdb client should be nil")
+// Setup a MockPool for Redis
+//
+// Returns a teardown-function and the mock-connection
+func setupMockPool() (*redigomock.Conn, func()) {
+ conn := redigomock.NewConn()
+ cfg := &config.RedisConfig{URL: config.TomlURL{}}
+ Configure(cfg, func(_ *config.RedisConfig, _ bool) func() (redis.Conn, error) {
+ return func() (redis.Conn, error) {
+ return conn, nil
+ }
+ })
+ return conn, func() {
+ pool = nil
+ }
}
-func TestConfigureValidConfigX(t *testing.T) {
+func TestDefaultDialFunc(t *testing.T) {
testCases := []struct {
scheme string
}{
{
- scheme: "redis",
+ scheme: "tcp",
},
{
- scheme: "tcp",
+ scheme: "redis",
},
}
for _, tc := range testCases {
t.Run(tc.scheme, func(t *testing.T) {
- connectReceived := atomic.Value{}
+ connectReceived := false
a := mockRedisServer(t, &connectReceived)
parsedURL := helper.URLMustParse(tc.scheme + "://" + a)
cfg := &config.RedisConfig{URL: config.TomlURL{URL: *parsedURL}}
- Configure(cfg)
+ dialer := DefaultDialFunc(cfg, true)
+ conn, err := dialer()
- require.NotNil(t, GetRedisClient().Conn(), "Pool should not be nil")
+ require.Nil(t, err)
+ conn.Receive()
- // goredis initialise connections lazily
- rdb.Ping(context.Background())
- require.True(t, connectReceived.Load().(bool))
-
- rdb = nil
+ require.True(t, connectReceived)
})
}
}
-func TestConnectToSentinel(t *testing.T) {
+func TestConfigureNoConfig(t *testing.T) {
+ pool = nil
+ Configure(nil, nil)
+ require.Nil(t, pool, "Pool should be nil")
+}
+
+func TestConfigureMinimalConfig(t *testing.T) {
+ cfg := &config.RedisConfig{URL: config.TomlURL{}, Password: ""}
+ Configure(cfg, DefaultDialFunc)
+
+ require.NotNil(t, pool, "Pool should not be nil")
+ require.Equal(t, 1, pool.MaxIdle)
+ require.Equal(t, 1, pool.MaxActive)
+ require.Equal(t, 3*time.Minute, pool.IdleTimeout)
+
+ pool = nil
+}
+
+func TestConfigureFullConfig(t *testing.T) {
+ i, a := 4, 10
+ cfg := &config.RedisConfig{
+ URL: config.TomlURL{},
+ Password: "",
+ MaxIdle: &i,
+ MaxActive: &a,
+ }
+ Configure(cfg, DefaultDialFunc)
+
+ require.NotNil(t, pool, "Pool should not be nil")
+ require.Equal(t, i, pool.MaxIdle)
+ require.Equal(t, a, pool.MaxActive)
+ require.Equal(t, 3*time.Minute, pool.IdleTimeout)
+
+ pool = nil
+}
+
+func TestGetConnFail(t *testing.T) {
+ conn := Get()
+ require.Nil(t, conn, "Expected `conn` to be nil")
+}
+
+func TestGetConnPass(t *testing.T) {
+ _, teardown := setupMockPool()
+ defer teardown()
+ conn := Get()
+ require.NotNil(t, conn, "Expected `conn` to be non-nil")
+}
+
+func TestGetStringPass(t *testing.T) {
+ conn, teardown := setupMockPool()
+ defer teardown()
+ conn.Command("GET", "foobar").Expect("baz")
+ str, err := GetString("foobar")
+
+ require.NoError(t, err, "Expected `err` to be nil")
+ var value string
+ require.IsType(t, value, str, "Expected value to be a string")
+ require.Equal(t, "baz", str, "Expected it to be equal")
+}
+
+func TestGetStringFail(t *testing.T) {
+ _, err := GetString("foobar")
+ require.Error(t, err, "Expected error when not connected to redis")
+}
+
+func TestSentinelConnNoSentinel(t *testing.T) {
+ s := sentinelConn("", []config.TomlURL{})
+
+ require.Nil(t, s, "Sentinel without urls should return nil")
+}
+
+func TestSentinelConnDialURL(t *testing.T) {
testCases := []struct {
scheme string
}{
{
- scheme: "redis",
+ scheme: "tcp",
},
{
- scheme: "tcp",
+ scheme: "redis",
},
}
for _, tc := range testCases {
t.Run(tc.scheme, func(t *testing.T) {
- connectReceived := atomic.Value{}
+ connectReceived := false
a := mockRedisServer(t, &connectReceived)
addrs := []string{tc.scheme + "://" + a}
@@ -92,71 +172,57 @@ func TestConnectToSentinel(t *testing.T) {
sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL})
}
- cfg := &config.RedisConfig{Sentinel: sentinelUrls}
- Configure(cfg)
+ s := sentinelConn("foobar", sentinelUrls)
+ require.Equal(t, len(addrs), len(s.Addrs))
+
+ for i := range addrs {
+ require.Equal(t, addrs[i], s.Addrs[i])
+ }
- require.NotNil(t, GetRedisClient().Conn(), "Pool should not be nil")
+ conn, err := s.Dial(s.Addrs[0])
- // goredis initialise connections lazily
- rdb.Ping(context.Background())
- require.True(t, connectReceived.Load().(bool))
+ require.Nil(t, err)
+ conn.Receive()
- rdb = nil
+ require.True(t, connectReceived)
})
}
}
-func TestSentinelOptions(t *testing.T) {
- testCases := []struct {
- description string
- inputSentinelPassword string
- inputSentinel []string
- password string
- sentinels []string
- }{
- {
- description: "no sentinel passwords",
- inputSentinel: []string{"tcp://localhost:26480"},
- sentinels: []string{"localhost:26480"},
- },
- {
- description: "specific sentinel password defined",
- inputSentinel: []string{"tcp://localhost:26480"},
- inputSentinelPassword: "password1",
- sentinels: []string{"localhost:26480"},
- password: "password1",
- },
- {
- description: "specific sentinel password defined in url",
- inputSentinel: []string{"tcp://:password2@localhost:26480", "tcp://:password3@localhost:26481"},
- sentinels: []string{"localhost:26480", "localhost:26481"},
- password: "password2",
- },
- {
- description: "passwords defined specifically and in url",
- inputSentinel: []string{"tcp://:password2@localhost:26480", "tcp://:password3@localhost:26481"},
- sentinels: []string{"localhost:26480", "localhost:26481"},
- inputSentinelPassword: "password1",
- password: "password1",
- },
+func TestSentinelConnTwoURLs(t *testing.T) {
+ addrs := []string{"tcp://10.0.0.1:12345", "tcp://10.0.0.2:12345"}
+ var sentinelUrls []config.TomlURL
+
+ for _, a := range addrs {
+ parsedURL := helper.URLMustParse(a)
+ sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL})
}
- for _, tc := range testCases {
- t.Run(tc.description, func(t *testing.T) {
- sentinelUrls := make([]config.TomlURL, len(tc.inputSentinel))
+ s := sentinelConn("foobar", sentinelUrls)
+ require.Equal(t, len(addrs), len(s.Addrs))
- for i, str := range tc.inputSentinel {
- parsedURL := helper.URLMustParse(str)
- sentinelUrls[i] = config.TomlURL{URL: *parsedURL}
- }
+ for i := range addrs {
+ require.Equal(t, addrs[i], s.Addrs[i])
+ }
+}
- outputPw, outputSentinels := sentinelOptions(&config.RedisConfig{
- Sentinel: sentinelUrls,
- SentinelPassword: tc.inputSentinelPassword,
- })
+func TestDialOptionsBuildersPassword(t *testing.T) {
+ dopts := dialOptionsBuilder(&config.RedisConfig{Password: "foo"}, false)
+ require.Equal(t, 1, len(dopts))
+}
- require.Equal(t, tc.password, outputPw)
- require.Equal(t, tc.sentinels, outputSentinels)
- })
- }
+func TestDialOptionsBuildersSetTimeouts(t *testing.T) {
+ dopts := dialOptionsBuilder(nil, true)
+ require.Equal(t, 2, len(dopts))
+}
+
+func TestDialOptionsBuildersSetTimeoutsConfig(t *testing.T) {
+ dopts := dialOptionsBuilder(nil, true)
+ require.Equal(t, 2, len(dopts))
+}
+
+func TestDialOptionsBuildersSelectDB(t *testing.T) {
+ db := 3
+ dopts := dialOptionsBuilder(&config.RedisConfig{DB: &db}, false)
+ require.Equal(t, 1, len(dopts))
}
diff --git a/workhorse/main.go b/workhorse/main.go
index 3043ae50a22..9ba213d47d3 100644
--- a/workhorse/main.go
+++ b/workhorse/main.go
@@ -17,8 +17,10 @@ import (
"gitlab.com/gitlab-org/labkit/monitoring"
"gitlab.com/gitlab-org/labkit/tracing"
+ "gitlab.com/gitlab-org/gitlab/workhorse/internal/builds"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/config"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/gitaly"
+ "gitlab.com/gitlab-org/gitlab/workhorse/internal/goredis"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/queueing"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/redis"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/secret"
@@ -223,18 +225,34 @@ func run(boot bootConfig, cfg config.Config) error {
secret.SetPath(boot.secretPath)
- log.Info("Using redis/go-redis")
+ keyWatcher := redis.NewKeyWatcher()
- redisKeyWatcher := redis.NewKeyWatcher()
- if err := redis.Configure(cfg.Redis); err != nil {
- log.WithError(err).Error("unable to configure redis client")
- }
+ var watchKeyFn builds.WatchKeyHandler
+ var goredisKeyWatcher *goredis.KeyWatcher
- if rdb := redis.GetRedisClient(); rdb != nil {
- go redisKeyWatcher.Process(rdb)
- }
+ if os.Getenv("GITLAB_WORKHORSE_FF_GO_REDIS_ENABLED") == "true" {
+ log.Info("Using redis/go-redis")
+
+ goredisKeyWatcher = goredis.NewKeyWatcher()
+ if err := goredis.Configure(cfg.Redis); err != nil {
+ log.WithError(err).Error("unable to configure redis client")
+ }
+
+ if rdb := goredis.GetRedisClient(); rdb != nil {
+ go goredisKeyWatcher.Process(rdb)
+ }
- watchKeyFn := redisKeyWatcher.WatchKey
+ watchKeyFn = goredisKeyWatcher.WatchKey
+ } else {
+ log.Info("Using gomodule/redigo")
+
+ if cfg.Redis != nil {
+ redis.Configure(cfg.Redis, redis.DefaultDialFunc)
+ go keyWatcher.Process()
+ }
+
+ watchKeyFn = keyWatcher.WatchKey
+ }
if err := cfg.RegisterGoCloudURLOpeners(); err != nil {
return fmt.Errorf("register cloud credentials: %v", err)
@@ -282,8 +300,11 @@ func run(boot bootConfig, cfg config.Config) error {
ctx, cancel := context.WithTimeout(context.Background(), cfg.ShutdownTimeout.Duration) // lint:allow context.Background
defer cancel()
- redisKeyWatcher.Shutdown()
+ if goredisKeyWatcher != nil {
+ goredisKeyWatcher.Shutdown()
+ }
+ keyWatcher.Shutdown()
return srv.Shutdown(ctx)
}
}
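
The net effect of the main.go change: the redigo path stays the default, and the go-redis client sits behind an environment-variable feature flag checked at startup. Condensed into a standalone sketch (same variable as above; the helper name is illustrative):

    package example

    import "os"

    // useGoRedis reports whether the experimental go-redis client is enabled;
    // otherwise workhorse keeps the long-standing redigo implementation.
    func useGoRedis() bool {
        return os.Getenv("GITLAB_WORKHORSE_FF_GO_REDIS_ENABLED") == "true"
    }

Note that both watchers participate in shutdown: keyWatcher.Shutdown() always runs, while goredisKeyWatcher.Shutdown() runs only when the flag created it.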