diff options
author | GitLab Bot <gitlab-bot@gitlab.com> | 2023-11-14 11:41:52 +0300 |
---|---|---|
committer | GitLab Bot <gitlab-bot@gitlab.com> | 2023-11-14 11:41:52 +0300 |
commit | 585826cb22ecea5998a2c2a4675735c94bdeedac (patch) | |
tree | 5b05f0b30d33cef48963609e8a18a4dff260eab3 /workhorse | |
parent | df221d036e5d0c6c0ee4d55b9c97f481ee05dee8 (diff) |
Add latest changes from gitlab-org/gitlab@16-6-stable-ee (tag: v16.6.0-rc42)
Diffstat (limited to 'workhorse')
-rw-r--r-- | workhorse/go.mod | 3 | ||||
-rw-r--r-- | workhorse/go.sum | 7 | ||||
-rw-r--r-- | workhorse/internal/goredis/goredis.go | 186 | ||||
-rw-r--r-- | workhorse/internal/goredis/goredis_test.go | 107 | ||||
-rw-r--r-- | workhorse/internal/goredis/keywatcher.go | 236 | ||||
-rw-r--r-- | workhorse/internal/goredis/keywatcher_test.go | 301 | ||||
-rw-r--r-- | workhorse/internal/redis/keywatcher.go | 83 | ||||
-rw-r--r-- | workhorse/internal/redis/keywatcher_test.go | 120 | ||||
-rw-r--r-- | workhorse/internal/redis/redis.go | 336 | ||||
-rw-r--r-- | workhorse/internal/redis/redis_test.go | 222 | ||||
-rw-r--r-- | workhorse/main.go | 41 |
11 files changed, 321 insertions, 1321 deletions
diff --git a/workhorse/go.mod b/workhorse/go.mod index 04f59a5a6f6..0773904ce21 100644 --- a/workhorse/go.mod +++ b/workhorse/go.mod @@ -5,7 +5,6 @@ go 1.19 require ( github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.1.0 github.com/BurntSushi/toml v1.3.2 - github.com/FZambia/sentinel v1.1.1 github.com/alecthomas/chroma/v2 v2.9.1 github.com/aws/aws-sdk-go v1.45.20 github.com/disintegration/imaging v1.6.2 @@ -13,14 +12,12 @@ require ( github.com/golang-jwt/jwt/v5 v5.0.0 github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f github.com/golang/protobuf v1.5.3 - github.com/gomodule/redigo v2.0.0+incompatible github.com/gorilla/websocket v1.5.0 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/johannesboyne/gofakes3 v0.0.0-20230914150226-f005f5cc03aa github.com/jpillora/backoff v1.0.0 github.com/mitchellh/copystructure v1.2.0 github.com/prometheus/client_golang v1.17.0 - github.com/rafaeljusto/redigomock/v3 v3.1.2 github.com/redis/go-redis/v9 v9.2.1 github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a github.com/sirupsen/logrus v1.9.3 diff --git a/workhorse/go.sum b/workhorse/go.sum index 6cf33000fcf..d35e2948db7 100644 --- a/workhorse/go.sum +++ b/workhorse/go.sum @@ -85,8 +85,6 @@ github.com/DataDog/datadog-go v4.4.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3 github.com/DataDog/gostackparse v0.5.0/go.mod h1:lTfqcJKqS9KnXQGnyQMCugq3u1FP6UZMfWR0aitKFMM= github.com/DataDog/sketches-go v1.0.0 h1:chm5KSXO7kO+ywGWJ0Zs6tdmWU8PBXSbywFVciL6BG4= github.com/DataDog/sketches-go v1.0.0/go.mod h1:O+XkJHWk9w4hDwY2ZUDU31ZC9sNYlYo8DiFsxjYeo1k= -github.com/FZambia/sentinel v1.1.1 h1:0ovTimlR7Ldm+wR15GgO+8C2dt7kkn+tm3PQS+Qk3Ek= -github.com/FZambia/sentinel v1.1.1/go.mod h1:ytL1Am/RLlAoAXG6Kj5LNuw/TRRQrv2rt2FT26vP5gI= github.com/HdrHistogram/hdrhistogram-go v1.1.1 h1:cJXY5VLMHgejurPjZH6Fo9rIwRGLefBGdiaENZALqrg= github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= github.com/Microsoft/go-winio v0.6.1/go.mod 
h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= @@ -231,9 +229,6 @@ github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/gomodule/redigo v1.8.8/go.mod h1:7ArFNvsTjH8GMMzB4uy1snslv2BwmginuMs06a1uzZE= -github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0= -github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.1.1-0.20171103154506-982329095285/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= @@ -392,8 +387,6 @@ github.com/prometheus/procfs v0.11.1 h1:xRC8Iq1yyca5ypa9n1EZnWZkt7dwcoRPQwX/5gwa github.com/prometheus/procfs v0.11.1/go.mod h1:eesXgaPo1q7lBpVMoMy0ZOFTth9hBn4W/y0/p/ScXhY= github.com/prometheus/prometheus v0.46.0 h1:9JSdXnsuT6YsbODEhSQMwxNkGwPExfmzqG73vCMk/Kw= github.com/prometheus/prometheus v0.46.0/go.mod h1:10L5IJE5CEsjee1FnOcVswYXlPIscDWWt3IJ2UDYrz4= -github.com/rafaeljusto/redigomock/v3 v3.1.2 h1:B4Y0XJQiPjpwYmkH55aratKX1VfR+JRqzmDKyZbC99o= -github.com/rafaeljusto/redigomock/v3 v3.1.2/go.mod h1:F9zPqz8rMriScZkPtUiLJoLruYcpGo/XXREpeyasREM= github.com/redis/go-redis/v9 v9.2.1 h1:WlYJg71ODF0dVspZZCpYmoF1+U1Jjk9Rwd7pq6QmlCg= github.com/redis/go-redis/v9 v9.2.1/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= diff --git a/workhorse/internal/goredis/goredis.go b/workhorse/internal/goredis/goredis.go deleted file mode 100644 
index 13a9d4cc34f..00000000000 --- a/workhorse/internal/goredis/goredis.go +++ /dev/null @@ -1,186 +0,0 @@ -package goredis - -import ( - "context" - "errors" - "fmt" - "net" - "time" - - redis "github.com/redis/go-redis/v9" - - "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" - _ "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" - internalredis "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis" -) - -var ( - rdb *redis.Client - // found in https://github.com/redis/go-redis/blob/c7399b6a17d7d3e2a57654528af91349f2468529/sentinel.go#L626 - errSentinelMasterAddr error = errors.New("redis: all sentinels specified in configuration are unreachable") -) - -const ( - // Max Idle Connections in the pool. - defaultMaxIdle = 1 - // Max Active Connections in the pool. - defaultMaxActive = 1 - // Timeout for Read operations on the pool. 1 second is technically overkill, - // it's just for sanity. - defaultReadTimeout = 1 * time.Second - // Timeout for Write operations on the pool. 1 second is technically overkill, - // it's just for sanity. - defaultWriteTimeout = 1 * time.Second - // Timeout before killing Idle connections in the pool. 3 minutes seemed good. - // If you _actually_ hit this timeout often, you should consider turning of - // redis-support since it's not necessary at that point... 
- defaultIdleTimeout = 3 * time.Minute -) - -// createDialer references https://github.com/redis/go-redis/blob/b1103e3d436b6fe98813ecbbe1f99dc8d59b06c9/options.go#L214 -// it intercepts the error and tracks it via a Prometheus counter -func createDialer(sentinels []string) func(ctx context.Context, network, addr string) (net.Conn, error) { - return func(ctx context.Context, network, addr string) (net.Conn, error) { - var isSentinel bool - for _, sentinelAddr := range sentinels { - if sentinelAddr == addr { - isSentinel = true - break - } - } - - dialTimeout := 5 * time.Second // go-redis default - destination := "redis" - if isSentinel { - // This timeout is recommended for Sentinel-support according to the guidelines. - // https://redis.io/topics/sentinel-clients#redis-service-discovery-via-sentinel - // For every address it should try to connect to the Sentinel, - // using a short timeout (in the order of a few hundreds of milliseconds). - destination = "sentinel" - dialTimeout = 500 * time.Millisecond - } - - netDialer := &net.Dialer{ - Timeout: dialTimeout, - KeepAlive: 5 * time.Minute, - } - - conn, err := netDialer.DialContext(ctx, network, addr) - if err != nil { - internalredis.ErrorCounter.WithLabelValues("dial", destination).Inc() - } else { - if !isSentinel { - internalredis.TotalConnections.Inc() - } - } - - return conn, err - } -} - -// implements the redis.Hook interface for instrumentation -type sentinelInstrumentationHook struct{} - -func (s sentinelInstrumentationHook) DialHook(next redis.DialHook) redis.DialHook { - return func(ctx context.Context, network, addr string) (net.Conn, error) { - conn, err := next(ctx, network, addr) - if err != nil && err.Error() == errSentinelMasterAddr.Error() { - // check for non-dialer error - internalredis.ErrorCounter.WithLabelValues("master", "sentinel").Inc() - } - return conn, err - } -} - -func (s sentinelInstrumentationHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook { - return func(ctx 
context.Context, cmd redis.Cmder) error { - return next(ctx, cmd) - } -} - -func (s sentinelInstrumentationHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook { - return func(ctx context.Context, cmds []redis.Cmder) error { - return next(ctx, cmds) - } -} - -func GetRedisClient() *redis.Client { - return rdb -} - -// Configure redis-connection -func Configure(cfg *config.RedisConfig) error { - if cfg == nil { - return nil - } - - var err error - - if len(cfg.Sentinel) > 0 { - rdb = configureSentinel(cfg) - } else { - rdb, err = configureRedis(cfg) - } - - return err -} - -func configureRedis(cfg *config.RedisConfig) (*redis.Client, error) { - if cfg.URL.Scheme == "tcp" { - cfg.URL.Scheme = "redis" - } - - opt, err := redis.ParseURL(cfg.URL.String()) - if err != nil { - return nil, err - } - - opt.DB = getOrDefault(cfg.DB, 0) - opt.Password = cfg.Password - - opt.PoolSize = getOrDefault(cfg.MaxActive, defaultMaxActive) - opt.MaxIdleConns = getOrDefault(cfg.MaxIdle, defaultMaxIdle) - opt.ConnMaxIdleTime = defaultIdleTimeout - opt.ReadTimeout = defaultReadTimeout - opt.WriteTimeout = defaultWriteTimeout - - opt.Dialer = createDialer([]string{}) - - return redis.NewClient(opt), nil -} - -func configureSentinel(cfg *config.RedisConfig) *redis.Client { - sentinels := make([]string, len(cfg.Sentinel)) - for i := range cfg.Sentinel { - sentinelDetails := cfg.Sentinel[i] - sentinels[i] = fmt.Sprintf("%s:%s", sentinelDetails.Hostname(), sentinelDetails.Port()) - } - - client := redis.NewFailoverClient(&redis.FailoverOptions{ - MasterName: cfg.SentinelMaster, - SentinelAddrs: sentinels, - Password: cfg.Password, - SentinelPassword: cfg.SentinelPassword, - DB: getOrDefault(cfg.DB, 0), - - PoolSize: getOrDefault(cfg.MaxActive, defaultMaxActive), - MaxIdleConns: getOrDefault(cfg.MaxIdle, defaultMaxIdle), - ConnMaxIdleTime: defaultIdleTimeout, - - ReadTimeout: defaultReadTimeout, - WriteTimeout: defaultWriteTimeout, - - Dialer: 
createDialer(sentinels), - }) - - client.AddHook(sentinelInstrumentationHook{}) - - return client -} - -func getOrDefault(ptr *int, val int) int { - if ptr != nil { - return *ptr - } - return val -} diff --git a/workhorse/internal/goredis/goredis_test.go b/workhorse/internal/goredis/goredis_test.go deleted file mode 100644 index 6b281229ea4..00000000000 --- a/workhorse/internal/goredis/goredis_test.go +++ /dev/null @@ -1,107 +0,0 @@ -package goredis - -import ( - "context" - "net" - "sync/atomic" - "testing" - - "github.com/stretchr/testify/require" - - "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" - "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" -) - -func mockRedisServer(t *testing.T, connectReceived *atomic.Value) string { - ln, err := net.Listen("tcp", "127.0.0.1:0") - - require.Nil(t, err) - - go func() { - defer ln.Close() - conn, err := ln.Accept() - require.Nil(t, err) - connectReceived.Store(true) - conn.Write([]byte("OK\n")) - }() - - return ln.Addr().String() -} - -func TestConfigureNoConfig(t *testing.T) { - rdb = nil - Configure(nil) - require.Nil(t, rdb, "rdb client should be nil") -} - -func TestConfigureValidConfigX(t *testing.T) { - testCases := []struct { - scheme string - }{ - { - scheme: "redis", - }, - { - scheme: "tcp", - }, - } - - for _, tc := range testCases { - t.Run(tc.scheme, func(t *testing.T) { - connectReceived := atomic.Value{} - a := mockRedisServer(t, &connectReceived) - - parsedURL := helper.URLMustParse(tc.scheme + "://" + a) - cfg := &config.RedisConfig{URL: config.TomlURL{URL: *parsedURL}} - - Configure(cfg) - - require.NotNil(t, GetRedisClient().Conn(), "Pool should not be nil") - - // goredis initialise connections lazily - rdb.Ping(context.Background()) - require.True(t, connectReceived.Load().(bool)) - - rdb = nil - }) - } -} - -func TestConnectToSentinel(t *testing.T) { - testCases := []struct { - scheme string - }{ - { - scheme: "redis", - }, - { - scheme: "tcp", - }, - } - - for _, tc := range 
testCases { - t.Run(tc.scheme, func(t *testing.T) { - connectReceived := atomic.Value{} - a := mockRedisServer(t, &connectReceived) - - addrs := []string{tc.scheme + "://" + a} - var sentinelUrls []config.TomlURL - - for _, a := range addrs { - parsedURL := helper.URLMustParse(a) - sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL}) - } - - cfg := &config.RedisConfig{Sentinel: sentinelUrls} - Configure(cfg) - - require.NotNil(t, GetRedisClient().Conn(), "Pool should not be nil") - - // goredis initialise connections lazily - rdb.Ping(context.Background()) - require.True(t, connectReceived.Load().(bool)) - - rdb = nil - }) - } -} diff --git a/workhorse/internal/goredis/keywatcher.go b/workhorse/internal/goredis/keywatcher.go deleted file mode 100644 index 741bfb17652..00000000000 --- a/workhorse/internal/goredis/keywatcher.go +++ /dev/null @@ -1,236 +0,0 @@ -package goredis - -import ( - "context" - "errors" - "fmt" - "strings" - "sync" - "time" - - "github.com/jpillora/backoff" - "github.com/redis/go-redis/v9" - - "gitlab.com/gitlab-org/gitlab/workhorse/internal/log" - internalredis "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis" -) - -type KeyWatcher struct { - mu sync.Mutex - subscribers map[string][]chan string - shutdown chan struct{} - reconnectBackoff backoff.Backoff - redisConn *redis.Client - conn *redis.PubSub -} - -func NewKeyWatcher() *KeyWatcher { - return &KeyWatcher{ - shutdown: make(chan struct{}), - reconnectBackoff: backoff.Backoff{ - Min: 100 * time.Millisecond, - Max: 60 * time.Second, - Factor: 2, - Jitter: true, - }, - } -} - -const channelPrefix = "workhorse:notifications:" - -func countAction(action string) { internalredis.TotalActions.WithLabelValues(action).Add(1) } - -func (kw *KeyWatcher) receivePubSubStream(ctx context.Context, pubsub *redis.PubSub) error { - kw.mu.Lock() - // We must share kw.conn with the goroutines that call SUBSCRIBE and - // UNSUBSCRIBE because Redis pubsub subscriptions are tied to the - // 
connection. - kw.conn = pubsub - kw.mu.Unlock() - - defer func() { - kw.mu.Lock() - defer kw.mu.Unlock() - kw.conn.Close() - kw.conn = nil - - // Reset kw.subscribers because it is tied to Redis server side state of - // kw.conn and we just closed that connection. - for _, chans := range kw.subscribers { - for _, ch := range chans { - close(ch) - internalredis.KeyWatchers.Dec() - } - } - kw.subscribers = nil - }() - - for { - msg, err := kw.conn.Receive(ctx) - if err != nil { - log.WithError(fmt.Errorf("keywatcher: pubsub receive: %v", err)).Error() - return nil - } - - switch msg := msg.(type) { - case *redis.Subscription: - internalredis.RedisSubscriptions.Set(float64(msg.Count)) - case *redis.Pong: - // Ignore. - case *redis.Message: - internalredis.TotalMessages.Inc() - internalredis.ReceivedBytes.Add(float64(len(msg.Payload))) - if strings.HasPrefix(msg.Channel, channelPrefix) { - kw.notifySubscribers(msg.Channel[len(channelPrefix):], string(msg.Payload)) - } - default: - log.WithError(fmt.Errorf("keywatcher: unknown: %T", msg)).Error() - return nil - } - } -} - -func (kw *KeyWatcher) Process(client *redis.Client) { - log.Info("keywatcher: starting process loop") - - ctx := context.Background() // lint:allow context.Background - kw.mu.Lock() - kw.redisConn = client - kw.mu.Unlock() - - for { - pubsub := client.Subscribe(ctx, []string{}...) 
- if err := pubsub.Ping(ctx); err != nil { - log.WithError(fmt.Errorf("keywatcher: %v", err)).Error() - time.Sleep(kw.reconnectBackoff.Duration()) - continue - } - - kw.reconnectBackoff.Reset() - - if err := kw.receivePubSubStream(ctx, pubsub); err != nil { - log.WithError(fmt.Errorf("keywatcher: receivePubSubStream: %v", err)).Error() - } - } -} - -func (kw *KeyWatcher) Shutdown() { - log.Info("keywatcher: shutting down") - - kw.mu.Lock() - defer kw.mu.Unlock() - - select { - case <-kw.shutdown: - // already closed - default: - close(kw.shutdown) - } -} - -func (kw *KeyWatcher) notifySubscribers(key, value string) { - kw.mu.Lock() - defer kw.mu.Unlock() - - chanList, ok := kw.subscribers[key] - if !ok { - countAction("drop-message") - return - } - - countAction("deliver-message") - for _, c := range chanList { - select { - case c <- value: - default: - } - } -} - -func (kw *KeyWatcher) addSubscription(ctx context.Context, key string, notify chan string) error { - kw.mu.Lock() - defer kw.mu.Unlock() - - if kw.conn == nil { - // This can happen because CI long polling is disabled in this Workhorse - // process. It can also be that we are waiting for the pubsub connection - // to be established. Either way it is OK to fail fast. - return errors.New("no redis connection") - } - - if len(kw.subscribers[key]) == 0 { - countAction("create-subscription") - if err := kw.conn.Subscribe(ctx, channelPrefix+key); err != nil { - return err - } - } - - if kw.subscribers == nil { - kw.subscribers = make(map[string][]chan string) - } - kw.subscribers[key] = append(kw.subscribers[key], notify) - internalredis.KeyWatchers.Inc() - - return nil -} - -func (kw *KeyWatcher) delSubscription(ctx context.Context, key string, notify chan string) { - kw.mu.Lock() - defer kw.mu.Unlock() - - chans, ok := kw.subscribers[key] - if !ok { - // This can happen if the pubsub connection dropped while we were - // waiting. 
- return - } - - for i, c := range chans { - if notify == c { - kw.subscribers[key] = append(chans[:i], chans[i+1:]...) - internalredis.KeyWatchers.Dec() - break - } - } - if len(kw.subscribers[key]) == 0 { - delete(kw.subscribers, key) - countAction("delete-subscription") - if kw.conn != nil { - kw.conn.Unsubscribe(ctx, channelPrefix+key) - } - } -} - -func (kw *KeyWatcher) WatchKey(ctx context.Context, key, value string, timeout time.Duration) (internalredis.WatchKeyStatus, error) { - notify := make(chan string, 1) - if err := kw.addSubscription(ctx, key, notify); err != nil { - return internalredis.WatchKeyStatusNoChange, err - } - defer kw.delSubscription(ctx, key, notify) - - currentValue, err := kw.redisConn.Get(ctx, key).Result() - if errors.Is(err, redis.Nil) { - currentValue = "" - } else if err != nil { - return internalredis.WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET: %v", err) - } - if currentValue != value { - return internalredis.WatchKeyStatusAlreadyChanged, nil - } - - select { - case <-kw.shutdown: - log.WithFields(log.Fields{"key": key}).Info("stopping watch due to shutdown") - return internalredis.WatchKeyStatusNoChange, nil - case currentValue := <-notify: - if currentValue == "" { - return internalredis.WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET failed") - } - if currentValue == value { - return internalredis.WatchKeyStatusNoChange, nil - } - return internalredis.WatchKeyStatusSeenChange, nil - case <-time.After(timeout): - return internalredis.WatchKeyStatusTimeout, nil - } -} diff --git a/workhorse/internal/goredis/keywatcher_test.go b/workhorse/internal/goredis/keywatcher_test.go deleted file mode 100644 index b64262dc9c8..00000000000 --- a/workhorse/internal/goredis/keywatcher_test.go +++ /dev/null @@ -1,301 +0,0 @@ -package goredis - -import ( - "context" - "os" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/require" - - "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" - 
"gitlab.com/gitlab-org/gitlab/workhorse/internal/redis" -) - -var ctx = context.Background() - -const ( - runnerKey = "runner:build_queue:10" -) - -func initRdb() { - buf, _ := os.ReadFile("../../config.toml") - cfg, _ := config.LoadConfig(string(buf)) - Configure(cfg.Redis) -} - -func (kw *KeyWatcher) countSubscribers(key string) int { - kw.mu.Lock() - defer kw.mu.Unlock() - return len(kw.subscribers[key]) -} - -// Forces a run of the `Process` loop against a mock PubSubConn. -func (kw *KeyWatcher) processMessages(t *testing.T, numWatchers int, value string, ready chan<- struct{}, wg *sync.WaitGroup) { - kw.mu.Lock() - kw.redisConn = rdb - psc := kw.redisConn.Subscribe(ctx, []string{}...) - kw.mu.Unlock() - - errC := make(chan error) - go func() { errC <- kw.receivePubSubStream(ctx, psc) }() - - require.Eventually(t, func() bool { - kw.mu.Lock() - defer kw.mu.Unlock() - return kw.conn != nil - }, time.Second, time.Millisecond) - close(ready) - - require.Eventually(t, func() bool { - return kw.countSubscribers(runnerKey) == numWatchers - }, time.Second, time.Millisecond) - - // send message after listeners are ready - kw.redisConn.Publish(ctx, channelPrefix+runnerKey, value) - - // close subscription after all workers are done - wg.Wait() - kw.mu.Lock() - kw.conn.Close() - kw.mu.Unlock() - - require.NoError(t, <-errC) -} - -type keyChangeTestCase struct { - desc string - returnValue string - isKeyMissing bool - watchValue string - processedValue string - expectedStatus redis.WatchKeyStatus - timeout time.Duration -} - -func TestKeyChangesInstantReturn(t *testing.T) { - initRdb() - - testCases := []keyChangeTestCase{ - // WatchKeyStatusAlreadyChanged - { - desc: "sees change with key existing and changed", - returnValue: "somethingelse", - watchValue: "something", - expectedStatus: redis.WatchKeyStatusAlreadyChanged, - timeout: time.Second, - }, - { - desc: "sees change with key non-existing", - isKeyMissing: true, - watchValue: "something", - processedValue: 
"somethingelse", - expectedStatus: redis.WatchKeyStatusAlreadyChanged, - timeout: time.Second, - }, - // WatchKeyStatusTimeout - { - desc: "sees timeout with key existing and unchanged", - returnValue: "something", - watchValue: "something", - expectedStatus: redis.WatchKeyStatusTimeout, - timeout: time.Millisecond, - }, - { - desc: "sees timeout with key non-existing and unchanged", - isKeyMissing: true, - watchValue: "", - expectedStatus: redis.WatchKeyStatusTimeout, - timeout: time.Millisecond, - }, - } - - for _, tc := range testCases { - t.Run(tc.desc, func(t *testing.T) { - - // setup - if !tc.isKeyMissing { - rdb.Set(ctx, runnerKey, tc.returnValue, 0) - } - - defer func() { - rdb.FlushDB(ctx) - }() - - kw := NewKeyWatcher() - defer kw.Shutdown() - kw.redisConn = rdb - kw.conn = kw.redisConn.Subscribe(ctx, []string{}...) - - val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, tc.timeout) - - require.NoError(t, err, "Expected no error") - require.Equal(t, tc.expectedStatus, val, "Expected value") - }) - } -} - -func TestKeyChangesWhenWatching(t *testing.T) { - initRdb() - - testCases := []keyChangeTestCase{ - // WatchKeyStatusSeenChange - { - desc: "sees change with key existing", - returnValue: "something", - watchValue: "something", - processedValue: "somethingelse", - expectedStatus: redis.WatchKeyStatusSeenChange, - }, - { - desc: "sees change with key non-existing, when watching empty value", - isKeyMissing: true, - watchValue: "", - processedValue: "something", - expectedStatus: redis.WatchKeyStatusSeenChange, - }, - // WatchKeyStatusNoChange - { - desc: "sees no change with key existing", - returnValue: "something", - watchValue: "something", - processedValue: "something", - expectedStatus: redis.WatchKeyStatusNoChange, - }, - } - - for _, tc := range testCases { - t.Run(tc.desc, func(t *testing.T) { - if !tc.isKeyMissing { - rdb.Set(ctx, runnerKey, tc.returnValue, 0) - } - - kw := NewKeyWatcher() - defer kw.Shutdown() - defer func() { - 
rdb.FlushDB(ctx) - }() - - wg := &sync.WaitGroup{} - wg.Add(1) - ready := make(chan struct{}) - - go func() { - defer wg.Done() - <-ready - val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, time.Second) - - require.NoError(t, err, "Expected no error") - require.Equal(t, tc.expectedStatus, val, "Expected value") - }() - - kw.processMessages(t, 1, tc.processedValue, ready, wg) - }) - } -} - -func TestKeyChangesParallel(t *testing.T) { - initRdb() - - testCases := []keyChangeTestCase{ - { - desc: "massively parallel, sees change with key existing", - returnValue: "something", - watchValue: "something", - processedValue: "somethingelse", - expectedStatus: redis.WatchKeyStatusSeenChange, - }, - { - desc: "massively parallel, sees change with key existing, watching missing keys", - isKeyMissing: true, - watchValue: "", - processedValue: "somethingelse", - expectedStatus: redis.WatchKeyStatusSeenChange, - }, - } - - for _, tc := range testCases { - t.Run(tc.desc, func(t *testing.T) { - runTimes := 100 - - if !tc.isKeyMissing { - rdb.Set(ctx, runnerKey, tc.returnValue, 0) - } - - defer func() { - rdb.FlushDB(ctx) - }() - - wg := &sync.WaitGroup{} - wg.Add(runTimes) - ready := make(chan struct{}) - - kw := NewKeyWatcher() - defer kw.Shutdown() - - for i := 0; i < runTimes; i++ { - go func() { - defer wg.Done() - <-ready - val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, time.Second) - - require.NoError(t, err, "Expected no error") - require.Equal(t, tc.expectedStatus, val, "Expected value") - }() - } - - kw.processMessages(t, runTimes, tc.processedValue, ready, wg) - }) - } -} - -func TestShutdown(t *testing.T) { - initRdb() - - kw := NewKeyWatcher() - kw.redisConn = rdb - kw.conn = kw.redisConn.Subscribe(ctx, []string{}...) 
- defer kw.Shutdown() - - rdb.Set(ctx, runnerKey, "something", 0) - - wg := &sync.WaitGroup{} - wg.Add(2) - - go func() { - defer wg.Done() - val, err := kw.WatchKey(ctx, runnerKey, "something", 10*time.Second) - - require.NoError(t, err, "Expected no error") - require.Equal(t, redis.WatchKeyStatusNoChange, val, "Expected value not to change") - }() - - go func() { - defer wg.Done() - require.Eventually(t, func() bool { return kw.countSubscribers(runnerKey) == 1 }, 10*time.Second, time.Millisecond) - - kw.Shutdown() - }() - - wg.Wait() - - require.Eventually(t, func() bool { return kw.countSubscribers(runnerKey) == 0 }, 10*time.Second, time.Millisecond) - - // Adding a key after the shutdown should result in an immediate response - var val redis.WatchKeyStatus - var err error - done := make(chan struct{}) - go func() { - val, err = kw.WatchKey(ctx, runnerKey, "something", 10*time.Second) - close(done) - }() - - select { - case <-done: - require.NoError(t, err, "Expected no error") - require.Equal(t, redis.WatchKeyStatusNoChange, val, "Expected value not to change") - case <-time.After(100 * time.Millisecond): - t.Fatal("timeout waiting for WatchKey") - } -} diff --git a/workhorse/internal/redis/keywatcher.go b/workhorse/internal/redis/keywatcher.go index 8f1772a9195..ddb838121b7 100644 --- a/workhorse/internal/redis/keywatcher.go +++ b/workhorse/internal/redis/keywatcher.go @@ -8,10 +8,10 @@ import ( "sync" "time" - "github.com/gomodule/redigo/redis" "github.com/jpillora/backoff" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/redis/go-redis/v9" "gitlab.com/gitlab-org/gitlab/workhorse/internal/log" ) @@ -21,7 +21,8 @@ type KeyWatcher struct { subscribers map[string][]chan string shutdown chan struct{} reconnectBackoff backoff.Backoff - conn *redis.PubSubConn + redisConn *redis.Client + conn *redis.PubSub } func NewKeyWatcher() *KeyWatcher { @@ -74,12 +75,12 @@ const channelPrefix = 
"workhorse:notifications:" func countAction(action string) { TotalActions.WithLabelValues(action).Add(1) } -func (kw *KeyWatcher) receivePubSubStream(conn redis.Conn) error { +func (kw *KeyWatcher) receivePubSubStream(ctx context.Context, pubsub *redis.PubSub) error { kw.mu.Lock() // We must share kw.conn with the goroutines that call SUBSCRIBE and // UNSUBSCRIBE because Redis pubsub subscriptions are tied to the // connection. - kw.conn = &redis.PubSubConn{Conn: conn} + kw.conn = pubsub kw.mu.Unlock() defer func() { @@ -100,51 +101,49 @@ func (kw *KeyWatcher) receivePubSubStream(conn redis.Conn) error { }() for { - switch v := kw.conn.Receive().(type) { - case redis.Message: + msg, err := kw.conn.Receive(ctx) + if err != nil { + log.WithError(fmt.Errorf("keywatcher: pubsub receive: %v", err)).Error() + return nil + } + + switch msg := msg.(type) { + case *redis.Subscription: + RedisSubscriptions.Set(float64(msg.Count)) + case *redis.Pong: + // Ignore. + case *redis.Message: TotalMessages.Inc() - ReceivedBytes.Add(float64(len(v.Data))) - if strings.HasPrefix(v.Channel, channelPrefix) { - kw.notifySubscribers(v.Channel[len(channelPrefix):], string(v.Data)) + ReceivedBytes.Add(float64(len(msg.Payload))) + if strings.HasPrefix(msg.Channel, channelPrefix) { + kw.notifySubscribers(msg.Channel[len(channelPrefix):], string(msg.Payload)) } - case redis.Subscription: - RedisSubscriptions.Set(float64(v.Count)) - case error: - log.WithError(fmt.Errorf("keywatcher: pubsub receive: %v", v)).Error() - // Intermittent error, return nil so that it doesn't wait before reconnect + default: + log.WithError(fmt.Errorf("keywatcher: unknown: %T", msg)).Error() return nil } } } -func dialPubSub(dialer redisDialerFunc) (redis.Conn, error) { - conn, err := dialer() - if err != nil { - return nil, err - } - - // Make sure Redis is actually connected - conn.Do("PING") - if err := conn.Err(); err != nil { - conn.Close() - return nil, err - } +func (kw *KeyWatcher) Process(client 
*redis.Client) { + log.Info("keywatcher: starting process loop") - return conn, nil -} + ctx := context.Background() // lint:allow context.Background + kw.mu.Lock() + kw.redisConn = client + kw.mu.Unlock() -func (kw *KeyWatcher) Process() { - log.Info("keywatcher: starting process loop") for { - conn, err := dialPubSub(workerDialFunc) - if err != nil { + pubsub := client.Subscribe(ctx, []string{}...) + if err := pubsub.Ping(ctx); err != nil { log.WithError(fmt.Errorf("keywatcher: %v", err)).Error() time.Sleep(kw.reconnectBackoff.Duration()) continue } + kw.reconnectBackoff.Reset() - if err = kw.receivePubSubStream(conn); err != nil { + if err := kw.receivePubSubStream(ctx, pubsub); err != nil { log.WithError(fmt.Errorf("keywatcher: receivePubSubStream: %v", err)).Error() } } @@ -183,7 +182,7 @@ func (kw *KeyWatcher) notifySubscribers(key, value string) { } } -func (kw *KeyWatcher) addSubscription(key string, notify chan string) error { +func (kw *KeyWatcher) addSubscription(ctx context.Context, key string, notify chan string) error { kw.mu.Lock() defer kw.mu.Unlock() @@ -196,7 +195,7 @@ func (kw *KeyWatcher) addSubscription(key string, notify chan string) error { if len(kw.subscribers[key]) == 0 { countAction("create-subscription") - if err := kw.conn.Subscribe(channelPrefix + key); err != nil { + if err := kw.conn.Subscribe(ctx, channelPrefix+key); err != nil { return err } } @@ -210,7 +209,7 @@ func (kw *KeyWatcher) addSubscription(key string, notify chan string) error { return nil } -func (kw *KeyWatcher) delSubscription(key string, notify chan string) { +func (kw *KeyWatcher) delSubscription(ctx context.Context, key string, notify chan string) { kw.mu.Lock() defer kw.mu.Unlock() @@ -232,7 +231,7 @@ func (kw *KeyWatcher) delSubscription(key string, notify chan string) { delete(kw.subscribers, key) countAction("delete-subscription") if kw.conn != nil { - kw.conn.Unsubscribe(channelPrefix + key) + kw.conn.Unsubscribe(ctx, channelPrefix+key) } } } @@ -252,15 
+251,15 @@ const ( WatchKeyStatusNoChange ) -func (kw *KeyWatcher) WatchKey(_ context.Context, key, value string, timeout time.Duration) (WatchKeyStatus, error) { +func (kw *KeyWatcher) WatchKey(ctx context.Context, key, value string, timeout time.Duration) (WatchKeyStatus, error) { notify := make(chan string, 1) - if err := kw.addSubscription(key, notify); err != nil { + if err := kw.addSubscription(ctx, key, notify); err != nil { return WatchKeyStatusNoChange, err } - defer kw.delSubscription(key, notify) + defer kw.delSubscription(ctx, key, notify) - currentValue, err := GetString(key) - if errors.Is(err, redis.ErrNil) { + currentValue, err := kw.redisConn.Get(ctx, key).Result() + if errors.Is(err, redis.Nil) { currentValue = "" } else if err != nil { return WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET: %v", err) diff --git a/workhorse/internal/redis/keywatcher_test.go b/workhorse/internal/redis/keywatcher_test.go index 3abc1bf1107..bca4ca43a64 100644 --- a/workhorse/internal/redis/keywatcher_test.go +++ b/workhorse/internal/redis/keywatcher_test.go @@ -2,13 +2,14 @@ package redis import ( "context" + "os" "sync" "testing" "time" - "github.com/gomodule/redigo/redis" - "github.com/rafaeljusto/redigomock/v3" "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" ) var ctx = context.Background() @@ -17,27 +18,10 @@ const ( runnerKey = "runner:build_queue:10" ) -func createSubscriptionMessage(key, data string) []interface{} { - return []interface{}{ - []byte("message"), - []byte(key), - []byte(data), - } -} - -func createSubscribeMessage(key string) []interface{} { - return []interface{}{ - []byte("subscribe"), - []byte(key), - []byte("1"), - } -} -func createUnsubscribeMessage(key string) []interface{} { - return []interface{}{ - []byte("unsubscribe"), - []byte(key), - []byte("1"), - } +func initRdb() { + buf, _ := os.ReadFile("../../config.toml") + cfg, _ := config.LoadConfig(string(buf)) + 
Configure(cfg.Redis) } func (kw *KeyWatcher) countSubscribers(key string) int { @@ -47,17 +31,14 @@ func (kw *KeyWatcher) countSubscribers(key string) int { } // Forces a run of the `Process` loop against a mock PubSubConn. -func (kw *KeyWatcher) processMessages(t *testing.T, numWatchers int, value string, ready chan<- struct{}) { - psc := redigomock.NewConn() - psc.ReceiveWait = true - - channel := channelPrefix + runnerKey - psc.Command("SUBSCRIBE", channel).Expect(createSubscribeMessage(channel)) - psc.Command("UNSUBSCRIBE", channel).Expect(createUnsubscribeMessage(channel)) - psc.AddSubscriptionMessage(createSubscriptionMessage(channel, value)) +func (kw *KeyWatcher) processMessages(t *testing.T, numWatchers int, value string, ready chan<- struct{}, wg *sync.WaitGroup) { + kw.mu.Lock() + kw.redisConn = rdb + psc := kw.redisConn.Subscribe(ctx, []string{}...) + kw.mu.Unlock() errC := make(chan error) - go func() { errC <- kw.receivePubSubStream(psc) }() + go func() { errC <- kw.receivePubSubStream(ctx, psc) }() require.Eventually(t, func() bool { kw.mu.Lock() @@ -69,7 +50,15 @@ func (kw *KeyWatcher) processMessages(t *testing.T, numWatchers int, value strin require.Eventually(t, func() bool { return kw.countSubscribers(runnerKey) == numWatchers }, time.Second, time.Millisecond) - close(psc.ReceiveNow) + + // send message after listeners are ready + kw.redisConn.Publish(ctx, channelPrefix+runnerKey, value) + + // close subscription after all workers are done + wg.Wait() + kw.mu.Lock() + kw.conn.Close() + kw.mu.Unlock() require.NoError(t, <-errC) } @@ -85,6 +74,8 @@ type keyChangeTestCase struct { } func TestKeyChangesInstantReturn(t *testing.T) { + initRdb() + testCases := []keyChangeTestCase{ // WatchKeyStatusAlreadyChanged { @@ -121,18 +112,20 @@ func TestKeyChangesInstantReturn(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - conn, td := setupMockPool() - defer td() - if tc.isKeyMissing { - conn.Command("GET", 
runnerKey).ExpectError(redis.ErrNil) - } else { - conn.Command("GET", runnerKey).Expect(tc.returnValue) + // setup + if !tc.isKeyMissing { + rdb.Set(ctx, runnerKey, tc.returnValue, 0) } + defer func() { + rdb.FlushDB(ctx) + }() + kw := NewKeyWatcher() defer kw.Shutdown() - kw.conn = &redis.PubSubConn{Conn: redigomock.NewConn()} + kw.redisConn = rdb + kw.conn = kw.redisConn.Subscribe(ctx, []string{}...) val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, tc.timeout) @@ -143,6 +136,8 @@ func TestKeyChangesInstantReturn(t *testing.T) { } func TestKeyChangesWhenWatching(t *testing.T) { + initRdb() + testCases := []keyChangeTestCase{ // WatchKeyStatusSeenChange { @@ -171,17 +166,15 @@ func TestKeyChangesWhenWatching(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - conn, td := setupMockPool() - defer td() - - if tc.isKeyMissing { - conn.Command("GET", runnerKey).ExpectError(redis.ErrNil) - } else { - conn.Command("GET", runnerKey).Expect(tc.returnValue) + if !tc.isKeyMissing { + rdb.Set(ctx, runnerKey, tc.returnValue, 0) } kw := NewKeyWatcher() defer kw.Shutdown() + defer func() { + rdb.FlushDB(ctx) + }() wg := &sync.WaitGroup{} wg.Add(1) @@ -196,13 +189,14 @@ func TestKeyChangesWhenWatching(t *testing.T) { require.Equal(t, tc.expectedStatus, val, "Expected value") }() - kw.processMessages(t, 1, tc.processedValue, ready) - wg.Wait() + kw.processMessages(t, 1, tc.processedValue, ready, wg) }) } } func TestKeyChangesParallel(t *testing.T) { + initRdb() + testCases := []keyChangeTestCase{ { desc: "massively parallel, sees change with key existing", @@ -224,19 +218,14 @@ func TestKeyChangesParallel(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { runTimes := 100 - conn, td := setupMockPool() - defer td() - - getCmd := conn.Command("GET", runnerKey) - - for i := 0; i < runTimes; i++ { - if tc.isKeyMissing { - getCmd = getCmd.ExpectError(redis.ErrNil) - } else { - getCmd = getCmd.Expect(tc.returnValue) - } + if !tc.isKeyMissing { + 
rdb.Set(ctx, runnerKey, tc.returnValue, 0) } + defer func() { + rdb.FlushDB(ctx) + }() + wg := &sync.WaitGroup{} wg.Add(runTimes) ready := make(chan struct{}) @@ -255,21 +244,20 @@ func TestKeyChangesParallel(t *testing.T) { }() } - kw.processMessages(t, runTimes, tc.processedValue, ready) - wg.Wait() + kw.processMessages(t, runTimes, tc.processedValue, ready, wg) }) } } func TestShutdown(t *testing.T) { - conn, td := setupMockPool() - defer td() + initRdb() kw := NewKeyWatcher() - kw.conn = &redis.PubSubConn{Conn: redigomock.NewConn()} + kw.redisConn = rdb + kw.conn = kw.redisConn.Subscribe(ctx, []string{}...) defer kw.Shutdown() - conn.Command("GET", runnerKey).Expect("something") + rdb.Set(ctx, runnerKey, "something", 0) wg := &sync.WaitGroup{} wg.Add(2) diff --git a/workhorse/internal/redis/redis.go b/workhorse/internal/redis/redis.go index c79e1e56b3a..b528255d25b 100644 --- a/workhorse/internal/redis/redis.go +++ b/workhorse/internal/redis/redis.go @@ -1,24 +1,39 @@ package redis import ( + "context" + "errors" "fmt" "net" - "net/url" "time" - "github.com/FZambia/sentinel" - "github.com/gomodule/redigo/redis" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" - "gitlab.com/gitlab-org/labkit/log" + redis "github.com/redis/go-redis/v9" "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" - "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" + _ "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" ) var ( - pool *redis.Pool - sntnl *sentinel.Sentinel + rdb *redis.Client + // found in https://github.com/redis/go-redis/blob/c7399b6a17d7d3e2a57654528af91349f2468529/sentinel.go#L626 + errSentinelMasterAddr error = errors.New("redis: all sentinels specified in configuration are unreachable") + + TotalConnections = promauto.NewCounter( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_redis_total_connections", + Help: "How many connections gitlab-workhorse has opened in total. 
Can be used to track Redis connection rate for this process", + }, + ) + + ErrorCounter = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_redis_errors", + Help: "Counts different types of Redis errors encountered by workhorse, by type and destination (redis, sentinel)", + }, + []string{"type", "dst"}, + ) ) const ( @@ -36,241 +51,166 @@ const ( // If you _actually_ hit this timeout often, you should consider turning of // redis-support since it's not necessary at that point... defaultIdleTimeout = 3 * time.Minute - // KeepAlivePeriod is to keep a TCP connection open for an extended period of - // time without being killed. This is used both in the pool, and in the - // worker-connection. - // See https://en.wikipedia.org/wiki/Keepalive#TCP_keepalive for more - // information. - defaultKeepAlivePeriod = 5 * time.Minute ) -var ( - TotalConnections = promauto.NewCounter( - prometheus.CounterOpts{ - Name: "gitlab_workhorse_redis_total_connections", - Help: "How many connections gitlab-workhorse has opened in total. 
Can be used to track Redis connection rate for this process", - }, - ) - - ErrorCounter = promauto.NewCounterVec( - prometheus.CounterOpts{ - Name: "gitlab_workhorse_redis_errors", - Help: "Counts different types of Redis errors encountered by workhorse, by type and destination (redis, sentinel)", - }, - []string{"type", "dst"}, - ) -) +// createDialer references https://github.com/redis/go-redis/blob/b1103e3d436b6fe98813ecbbe1f99dc8d59b06c9/options.go#L214 +// it intercepts the error and tracks it via a Prometheus counter +func createDialer(sentinels []string) func(ctx context.Context, network, addr string) (net.Conn, error) { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + var isSentinel bool + for _, sentinelAddr := range sentinels { + if sentinelAddr == addr { + isSentinel = true + break + } + } -func sentinelConn(master string, urls []config.TomlURL) *sentinel.Sentinel { - if len(urls) == 0 { - return nil - } - var addrs []string - for _, url := range urls { - h := url.URL.String() - log.WithFields(log.Fields{ - "scheme": url.URL.Scheme, - "host": url.URL.Host, - }).Printf("redis: using sentinel") - addrs = append(addrs, h) - } - return &sentinel.Sentinel{ - Addrs: addrs, - MasterName: master, - Dial: func(addr string) (redis.Conn, error) { + dialTimeout := 5 * time.Second // go-redis default + destination := "redis" + if isSentinel { // This timeout is recommended for Sentinel-support according to the guidelines. // https://redis.io/topics/sentinel-clients#redis-service-discovery-via-sentinel // For every address it should try to connect to the Sentinel, // using a short timeout (in the order of a few hundreds of milliseconds). 
- timeout := 500 * time.Millisecond - url := helper.URLMustParse(addr) - - var c redis.Conn - var err error - options := []redis.DialOption{ - redis.DialConnectTimeout(timeout), - redis.DialReadTimeout(timeout), - redis.DialWriteTimeout(timeout), - } + destination = "sentinel" + dialTimeout = 500 * time.Millisecond + } - if url.Scheme == "redis" || url.Scheme == "rediss" { - c, err = redis.DialURL(addr, options...) - } else { - c, err = redis.Dial("tcp", url.Host, options...) - } + netDialer := &net.Dialer{ + Timeout: dialTimeout, + KeepAlive: 5 * time.Minute, + } - if err != nil { - ErrorCounter.WithLabelValues("dial", "sentinel").Inc() - return nil, err + conn, err := netDialer.DialContext(ctx, network, addr) + if err != nil { + ErrorCounter.WithLabelValues("dial", destination).Inc() + } else { + if !isSentinel { + TotalConnections.Inc() } - return c, nil - }, + } + + return conn, err } } -var poolDialFunc func() (redis.Conn, error) -var workerDialFunc func() (redis.Conn, error) +// implements the redis.Hook interface for instrumentation +type sentinelInstrumentationHook struct{} -func timeoutDialOptions(cfg *config.RedisConfig) []redis.DialOption { - return []redis.DialOption{ - redis.DialReadTimeout(defaultReadTimeout), - redis.DialWriteTimeout(defaultWriteTimeout), +func (s sentinelInstrumentationHook) DialHook(next redis.DialHook) redis.DialHook { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + conn, err := next(ctx, network, addr) + if err != nil && err.Error() == errSentinelMasterAddr.Error() { + // check for non-dialer error + ErrorCounter.WithLabelValues("master", "sentinel").Inc() + } + return conn, err } } -func dialOptionsBuilder(cfg *config.RedisConfig, setTimeouts bool) []redis.DialOption { - var dopts []redis.DialOption - if setTimeouts { - dopts = timeoutDialOptions(cfg) +func (s sentinelInstrumentationHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook { + return func(ctx context.Context, cmd redis.Cmder) 
error { + return next(ctx, cmd) } - if cfg == nil { - return dopts +} + +func (s sentinelInstrumentationHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + return next(ctx, cmds) } - if cfg.Password != "" { - dopts = append(dopts, redis.DialPassword(cfg.Password)) +} + +func GetRedisClient() *redis.Client { + return rdb +} + +// Configure redis-connection +func Configure(cfg *config.RedisConfig) error { + if cfg == nil { + return nil } - if cfg.DB != nil { - dopts = append(dopts, redis.DialDatabase(*cfg.DB)) + + var err error + + if len(cfg.Sentinel) > 0 { + rdb = configureSentinel(cfg) + } else { + rdb, err = configureRedis(cfg) } - return dopts + + return err } -func keepAliveDialer(network, address string) (net.Conn, error) { - addr, err := net.ResolveTCPAddr(network, address) - if err != nil { - return nil, err +func configureRedis(cfg *config.RedisConfig) (*redis.Client, error) { + if cfg.URL.Scheme == "tcp" { + cfg.URL.Scheme = "redis" } - tc, err := net.DialTCP(network, nil, addr) + + opt, err := redis.ParseURL(cfg.URL.String()) if err != nil { return nil, err } - if err := tc.SetKeepAlive(true); err != nil { - return nil, err - } - if err := tc.SetKeepAlivePeriod(defaultKeepAlivePeriod); err != nil { - return nil, err - } - return tc, nil -} -type redisDialerFunc func() (redis.Conn, error) + opt.DB = getOrDefault(cfg.DB, 0) + opt.Password = cfg.Password -func sentinelDialer(dopts []redis.DialOption) redisDialerFunc { - return func() (redis.Conn, error) { - address, err := sntnl.MasterAddr() - if err != nil { - ErrorCounter.WithLabelValues("master", "sentinel").Inc() - return nil, err - } - dopts = append(dopts, redis.DialNetDial(keepAliveDialer)) - conn, err := redisDial("tcp", address, dopts...) 
- if err != nil { - return nil, err - } - if !sentinel.TestRole(conn, "master") { - conn.Close() - return nil, fmt.Errorf("%s is not redis master", address) - } - return conn, nil - } -} + opt.PoolSize = getOrDefault(cfg.MaxActive, defaultMaxActive) + opt.MaxIdleConns = getOrDefault(cfg.MaxIdle, defaultMaxIdle) + opt.ConnMaxIdleTime = defaultIdleTimeout + opt.ReadTimeout = defaultReadTimeout + opt.WriteTimeout = defaultWriteTimeout -func defaultDialer(dopts []redis.DialOption, url url.URL) redisDialerFunc { - return func() (redis.Conn, error) { - if url.Scheme == "unix" { - return redisDial(url.Scheme, url.Path, dopts...) - } + opt.Dialer = createDialer([]string{}) - dopts = append(dopts, redis.DialNetDial(keepAliveDialer)) + return redis.NewClient(opt), nil +} - // redis.DialURL only works with redis[s]:// URLs - if url.Scheme == "redis" || url.Scheme == "rediss" { - return redisURLDial(url, dopts...) - } +func configureSentinel(cfg *config.RedisConfig) *redis.Client { + sentinelPassword, sentinels := sentinelOptions(cfg) + client := redis.NewFailoverClient(&redis.FailoverOptions{ + MasterName: cfg.SentinelMaster, + SentinelAddrs: sentinels, + Password: cfg.Password, + SentinelPassword: sentinelPassword, + DB: getOrDefault(cfg.DB, 0), - return redisDial(url.Scheme, url.Host, dopts...) - } -} + PoolSize: getOrDefault(cfg.MaxActive, defaultMaxActive), + MaxIdleConns: getOrDefault(cfg.MaxIdle, defaultMaxIdle), + ConnMaxIdleTime: defaultIdleTimeout, -func redisURLDial(url url.URL, options ...redis.DialOption) (redis.Conn, error) { - log.WithFields(log.Fields{ - "scheme": url.Scheme, - "address": url.Host, - }).Printf("redis: dialing") + ReadTimeout: defaultReadTimeout, + WriteTimeout: defaultWriteTimeout, - return redis.DialURL(url.String(), options...) 
-} + Dialer: createDialer(sentinels), + }) -func redisDial(network, address string, options ...redis.DialOption) (redis.Conn, error) { - log.WithFields(log.Fields{ - "network": network, - "address": address, - }).Printf("redis: dialing") + client.AddHook(sentinelInstrumentationHook{}) - return redis.Dial(network, address, options...) + return client } -func countDialer(dialer redisDialerFunc) redisDialerFunc { - return func() (redis.Conn, error) { - c, err := dialer() - if err != nil { - ErrorCounter.WithLabelValues("dial", "redis").Inc() - } else { - TotalConnections.Inc() - } - return c, err - } -} +// sentinelOptions extracts the sentinel password and addresses in <host>:<port> format +// the order of priority for the passwords is: SentinelPassword -> first password-in-url +func sentinelOptions(cfg *config.RedisConfig) (string, []string) { + sentinels := make([]string, len(cfg.Sentinel)) + sentinelPassword := cfg.SentinelPassword -// DefaultDialFunc should always used. Only exception is for unit-tests. 
-func DefaultDialFunc(cfg *config.RedisConfig, setReadTimeout bool) func() (redis.Conn, error) { - dopts := dialOptionsBuilder(cfg, setReadTimeout) - if sntnl != nil { - return countDialer(sentinelDialer(dopts)) - } - return countDialer(defaultDialer(dopts, cfg.URL.URL)) -} + for i := range cfg.Sentinel { + sentinelDetails := cfg.Sentinel[i] + sentinels[i] = fmt.Sprintf("%s:%s", sentinelDetails.Hostname(), sentinelDetails.Port()) -// Configure redis-connection -func Configure(cfg *config.RedisConfig, dialFunc func(*config.RedisConfig, bool) func() (redis.Conn, error)) { - if cfg == nil { - return - } - maxIdle := defaultMaxIdle - if cfg.MaxIdle != nil { - maxIdle = *cfg.MaxIdle - } - maxActive := defaultMaxActive - if cfg.MaxActive != nil { - maxActive = *cfg.MaxActive - } - sntnl = sentinelConn(cfg.SentinelMaster, cfg.Sentinel) - workerDialFunc = dialFunc(cfg, false) - poolDialFunc = dialFunc(cfg, true) - pool = &redis.Pool{ - MaxIdle: maxIdle, // Keep at most X hot connections - MaxActive: maxActive, // Keep at most X live connections, 0 means unlimited - IdleTimeout: defaultIdleTimeout, // X time until an unused connection is closed - Dial: poolDialFunc, - Wait: true, + if pw, exist := sentinelDetails.User.Password(); exist && len(sentinelPassword) == 0 { + // sets password using the first non-empty password + sentinelPassword = pw + } } -} -// Get a connection for the Redis-pool -func Get() redis.Conn { - if pool != nil { - return pool.Get() - } - return nil + return sentinelPassword, sentinels } -// GetString fetches the value of a key in Redis as a string -func GetString(key string) (string, error) { - conn := Get() - if conn == nil { - return "", fmt.Errorf("redis: could not get connection from pool") +func getOrDefault(ptr *int, val int) int { + if ptr != nil { + return *ptr } - defer conn.Close() - - return redis.String(conn.Do("GET", key)) + return val } diff --git a/workhorse/internal/redis/redis_test.go b/workhorse/internal/redis/redis_test.go index 
64b3a842a54..6fd6ecbae11 100644 --- a/workhorse/internal/redis/redis_test.go +++ b/workhorse/internal/redis/redis_test.go @@ -1,19 +1,18 @@ package redis import ( + "context" "net" + "sync/atomic" "testing" - "time" - "github.com/gomodule/redigo/redis" - "github.com/rafaeljusto/redigomock/v3" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" ) -func mockRedisServer(t *testing.T, connectReceived *bool) string { +func mockRedisServer(t *testing.T, connectReceived *atomic.Value) string { ln, err := net.Listen("tcp", "127.0.0.1:0") require.Nil(t, err) @@ -22,146 +21,67 @@ func mockRedisServer(t *testing.T, connectReceived *bool) string { defer ln.Close() conn, err := ln.Accept() require.Nil(t, err) - *connectReceived = true + connectReceived.Store(true) conn.Write([]byte("OK\n")) }() return ln.Addr().String() } -// Setup a MockPool for Redis -// -// Returns a teardown-function and the mock-connection -func setupMockPool() (*redigomock.Conn, func()) { - conn := redigomock.NewConn() - cfg := &config.RedisConfig{URL: config.TomlURL{}} - Configure(cfg, func(_ *config.RedisConfig, _ bool) func() (redis.Conn, error) { - return func() (redis.Conn, error) { - return conn, nil - } - }) - return conn, func() { - pool = nil - } +func TestConfigureNoConfig(t *testing.T) { + rdb = nil + Configure(nil) + require.Nil(t, rdb, "rdb client should be nil") } -func TestDefaultDialFunc(t *testing.T) { +func TestConfigureValidConfigX(t *testing.T) { testCases := []struct { scheme string }{ { - scheme: "tcp", + scheme: "redis", }, { - scheme: "redis", + scheme: "tcp", }, } for _, tc := range testCases { t.Run(tc.scheme, func(t *testing.T) { - connectReceived := false + connectReceived := atomic.Value{} a := mockRedisServer(t, &connectReceived) parsedURL := helper.URLMustParse(tc.scheme + "://" + a) cfg := &config.RedisConfig{URL: config.TomlURL{URL: *parsedURL}} - dialer := 
DefaultDialFunc(cfg, true) - conn, err := dialer() - - require.Nil(t, err) - conn.Receive() - - require.True(t, connectReceived) - }) - } -} - -func TestConfigureNoConfig(t *testing.T) { - pool = nil - Configure(nil, nil) - require.Nil(t, pool, "Pool should be nil") -} - -func TestConfigureMinimalConfig(t *testing.T) { - cfg := &config.RedisConfig{URL: config.TomlURL{}, Password: ""} - Configure(cfg, DefaultDialFunc) + Configure(cfg) - require.NotNil(t, pool, "Pool should not be nil") - require.Equal(t, 1, pool.MaxIdle) - require.Equal(t, 1, pool.MaxActive) - require.Equal(t, 3*time.Minute, pool.IdleTimeout) + require.NotNil(t, GetRedisClient().Conn(), "Pool should not be nil") - pool = nil -} + // goredis initialise connections lazily + rdb.Ping(context.Background()) + require.True(t, connectReceived.Load().(bool)) -func TestConfigureFullConfig(t *testing.T) { - i, a := 4, 10 - cfg := &config.RedisConfig{ - URL: config.TomlURL{}, - Password: "", - MaxIdle: &i, - MaxActive: &a, + rdb = nil + }) } - Configure(cfg, DefaultDialFunc) - - require.NotNil(t, pool, "Pool should not be nil") - require.Equal(t, i, pool.MaxIdle) - require.Equal(t, a, pool.MaxActive) - require.Equal(t, 3*time.Minute, pool.IdleTimeout) - - pool = nil -} - -func TestGetConnFail(t *testing.T) { - conn := Get() - require.Nil(t, conn, "Expected `conn` to be nil") -} - -func TestGetConnPass(t *testing.T) { - _, teardown := setupMockPool() - defer teardown() - conn := Get() - require.NotNil(t, conn, "Expected `conn` to be non-nil") } -func TestGetStringPass(t *testing.T) { - conn, teardown := setupMockPool() - defer teardown() - conn.Command("GET", "foobar").Expect("baz") - str, err := GetString("foobar") - - require.NoError(t, err, "Expected `err` to be nil") - var value string - require.IsType(t, value, str, "Expected value to be a string") - require.Equal(t, "baz", str, "Expected it to be equal") -} - -func TestGetStringFail(t *testing.T) { - _, err := GetString("foobar") - require.Error(t, err, 
"Expected error when not connected to redis") -} - -func TestSentinelConnNoSentinel(t *testing.T) { - s := sentinelConn("", []config.TomlURL{}) - - require.Nil(t, s, "Sentinel without urls should return nil") -} - -func TestSentinelConnDialURL(t *testing.T) { +func TestConnectToSentinel(t *testing.T) { testCases := []struct { scheme string }{ { - scheme: "tcp", + scheme: "redis", }, { - scheme: "redis", + scheme: "tcp", }, } for _, tc := range testCases { t.Run(tc.scheme, func(t *testing.T) { - connectReceived := false + connectReceived := atomic.Value{} a := mockRedisServer(t, &connectReceived) addrs := []string{tc.scheme + "://" + a} @@ -172,57 +92,71 @@ func TestSentinelConnDialURL(t *testing.T) { sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL}) } - s := sentinelConn("foobar", sentinelUrls) - require.Equal(t, len(addrs), len(s.Addrs)) - - for i := range addrs { - require.Equal(t, addrs[i], s.Addrs[i]) - } + cfg := &config.RedisConfig{Sentinel: sentinelUrls} + Configure(cfg) - conn, err := s.Dial(s.Addrs[0]) + require.NotNil(t, GetRedisClient().Conn(), "Pool should not be nil") - require.Nil(t, err) - conn.Receive() + // goredis initialise connections lazily + rdb.Ping(context.Background()) + require.True(t, connectReceived.Load().(bool)) - require.True(t, connectReceived) + rdb = nil }) } } -func TestSentinelConnTwoURLs(t *testing.T) { - addrs := []string{"tcp://10.0.0.1:12345", "tcp://10.0.0.2:12345"} - var sentinelUrls []config.TomlURL - - for _, a := range addrs { - parsedURL := helper.URLMustParse(a) - sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL}) - } - - s := sentinelConn("foobar", sentinelUrls) - require.Equal(t, len(addrs), len(s.Addrs)) - - for i := range addrs { - require.Equal(t, addrs[i], s.Addrs[i]) +func TestSentinelOptions(t *testing.T) { + testCases := []struct { + description string + inputSentinelPassword string + inputSentinel []string + password string + sentinels []string + }{ + { + description: "no 
sentinel passwords", + inputSentinel: []string{"tcp://localhost:26480"}, + sentinels: []string{"localhost:26480"}, + }, + { + description: "specific sentinel password defined", + inputSentinel: []string{"tcp://localhost:26480"}, + inputSentinelPassword: "password1", + sentinels: []string{"localhost:26480"}, + password: "password1", + }, + { + description: "specific sentinel password defined in url", + inputSentinel: []string{"tcp://:password2@localhost:26480", "tcp://:password3@localhost:26481"}, + sentinels: []string{"localhost:26480", "localhost:26481"}, + password: "password2", + }, + { + description: "passwords defined specifically and in url", + inputSentinel: []string{"tcp://:password2@localhost:26480", "tcp://:password3@localhost:26481"}, + sentinels: []string{"localhost:26480", "localhost:26481"}, + inputSentinelPassword: "password1", + password: "password1", + }, } -} -func TestDialOptionsBuildersPassword(t *testing.T) { - dopts := dialOptionsBuilder(&config.RedisConfig{Password: "foo"}, false) - require.Equal(t, 1, len(dopts)) -} + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + sentinelUrls := make([]config.TomlURL, len(tc.inputSentinel)) -func TestDialOptionsBuildersSetTimeouts(t *testing.T) { - dopts := dialOptionsBuilder(nil, true) - require.Equal(t, 2, len(dopts)) -} + for i, str := range tc.inputSentinel { + parsedURL := helper.URLMustParse(str) + sentinelUrls[i] = config.TomlURL{URL: *parsedURL} + } -func TestDialOptionsBuildersSetTimeoutsConfig(t *testing.T) { - dopts := dialOptionsBuilder(nil, true) - require.Equal(t, 2, len(dopts)) -} + outputPw, outputSentinels := sentinelOptions(&config.RedisConfig{ + Sentinel: sentinelUrls, + SentinelPassword: tc.inputSentinelPassword, + }) -func TestDialOptionsBuildersSelectDB(t *testing.T) { - db := 3 - dopts := dialOptionsBuilder(&config.RedisConfig{DB: &db}, false) - require.Equal(t, 1, len(dopts)) + require.Equal(t, tc.password, outputPw) + require.Equal(t, tc.sentinels, 
outputSentinels) + }) + } } diff --git a/workhorse/main.go b/workhorse/main.go index 9ba213d47d3..3043ae50a22 100644 --- a/workhorse/main.go +++ b/workhorse/main.go @@ -17,10 +17,8 @@ import ( "gitlab.com/gitlab-org/labkit/monitoring" "gitlab.com/gitlab-org/labkit/tracing" - "gitlab.com/gitlab-org/gitlab/workhorse/internal/builds" "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" "gitlab.com/gitlab-org/gitlab/workhorse/internal/gitaly" - "gitlab.com/gitlab-org/gitlab/workhorse/internal/goredis" "gitlab.com/gitlab-org/gitlab/workhorse/internal/queueing" "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis" "gitlab.com/gitlab-org/gitlab/workhorse/internal/secret" @@ -225,35 +223,19 @@ func run(boot bootConfig, cfg config.Config) error { secret.SetPath(boot.secretPath) - keyWatcher := redis.NewKeyWatcher() + log.Info("Using redis/go-redis") - var watchKeyFn builds.WatchKeyHandler - var goredisKeyWatcher *goredis.KeyWatcher - - if os.Getenv("GITLAB_WORKHORSE_FF_GO_REDIS_ENABLED") == "true" { - log.Info("Using redis/go-redis") - - goredisKeyWatcher = goredis.NewKeyWatcher() - if err := goredis.Configure(cfg.Redis); err != nil { - log.WithError(err).Error("unable to configure redis client") - } - - if rdb := goredis.GetRedisClient(); rdb != nil { - go goredisKeyWatcher.Process(rdb) - } - - watchKeyFn = goredisKeyWatcher.WatchKey - } else { - log.Info("Using gomodule/redigo") - - if cfg.Redis != nil { - redis.Configure(cfg.Redis, redis.DefaultDialFunc) - go keyWatcher.Process() - } + redisKeyWatcher := redis.NewKeyWatcher() + if err := redis.Configure(cfg.Redis); err != nil { + log.WithError(err).Error("unable to configure redis client") + } - watchKeyFn = keyWatcher.WatchKey + if rdb := redis.GetRedisClient(); rdb != nil { + go redisKeyWatcher.Process(rdb) } + watchKeyFn := redisKeyWatcher.WatchKey + if err := cfg.RegisterGoCloudURLOpeners(); err != nil { return fmt.Errorf("register cloud credentials: %v", err) } @@ -300,11 +282,8 @@ func run(boot bootConfig, 
cfg config.Config) error { ctx, cancel := context.WithTimeout(context.Background(), cfg.ShutdownTimeout.Duration) // lint:allow context.Background defer cancel() - if goredisKeyWatcher != nil { - goredisKeyWatcher.Shutdown() - } + redisKeyWatcher.Shutdown() - keyWatcher.Shutdown() return srv.Shutdown(ctx) } } |