| author    | GitLab Bot <gitlab-bot@gitlab.com> | 2023-11-29 18:43:00 +0300 |
| committer | GitLab Bot <gitlab-bot@gitlab.com> | 2023-11-29 18:43:00 +0300 |
| commit    | 693d15dcb2f33c01a442784c13933da3d1b8d52e (patch) | |
| tree      | d00d6dca2b2a6d164d6d2d7c51d57acc32d92b54 | |
| parent    | 94f0f0e4b9fa3f49bf6145100b206c36c0c4eef6 (diff) | |
Add latest changes from gitlab-org/gitlab@16-6-stable-ee
| -rw-r--r-- | workhorse/go.mod                              |   3 |
| -rw-r--r-- | workhorse/go.sum                              |   7 |
| -rw-r--r-- | workhorse/internal/goredis/goredis.go         | 200 |
| -rw-r--r-- | workhorse/internal/goredis/goredis_test.go    | 162 |
| -rw-r--r-- | workhorse/internal/goredis/keywatcher.go      | 236 |
| -rw-r--r-- | workhorse/internal/goredis/keywatcher_test.go | 301 |
| -rw-r--r-- | workhorse/internal/redis/keywatcher.go        |  83 |
| -rw-r--r-- | workhorse/internal/redis/keywatcher_test.go   | 120 |
| -rw-r--r-- | workhorse/internal/redis/redis.go             | 336 |
| -rw-r--r-- | workhorse/internal/redis/redis_test.go        | 222 |
| -rw-r--r-- | workhorse/main.go                             |  41 |
11 files changed, 1390 insertions, 321 deletions
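
The diff below restores the redigo-based `workhorse/internal/redis` package, moves the redis/go-redis v9 implementation into a new `workhorse/internal/goredis` package, and lets `workhorse/main.go` pick between the two at startup via the `GITLAB_WORKHORSE_FF_GO_REDIS_ENABLED` environment variable. The sketch that follows is a condensed restatement of the `main.go` hunk at the end of the diff (error handling and unrelated setup omitted), not additional code from the commit:

```go
// Condensed from the workhorse/main.go hunk below; illustrative only.
var watchKeyFn builds.WatchKeyHandler
var goredisKeyWatcher *goredis.KeyWatcher

if os.Getenv("GITLAB_WORKHORSE_FF_GO_REDIS_ENABLED") == "true" {
	// New path: redis/go-redis v9 client managed by internal/goredis.
	goredisKeyWatcher = goredis.NewKeyWatcher()
	if err := goredis.Configure(cfg.Redis); err != nil {
		log.WithError(err).Error("unable to configure redis client")
	}
	if rdb := goredis.GetRedisClient(); rdb != nil {
		go goredisKeyWatcher.Process(rdb)
	}
	watchKeyFn = goredisKeyWatcher.WatchKey
} else {
	// Default path: gomodule/redigo pool managed by internal/redis.
	keyWatcher := redis.NewKeyWatcher()
	if cfg.Redis != nil {
		redis.Configure(cfg.Redis, redis.DefaultDialFunc)
		go keyWatcher.Process()
	}
	watchKeyFn = keyWatcher.WatchKey
}
```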
diff --git a/workhorse/go.mod b/workhorse/go.mod index 0773904ce21..04f59a5a6f6 100644 --- a/workhorse/go.mod +++ b/workhorse/go.mod @@ -5,6 +5,7 @@ go 1.19 require ( github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.1.0 github.com/BurntSushi/toml v1.3.2 + github.com/FZambia/sentinel v1.1.1 github.com/alecthomas/chroma/v2 v2.9.1 github.com/aws/aws-sdk-go v1.45.20 github.com/disintegration/imaging v1.6.2 @@ -12,12 +13,14 @@ require ( github.com/golang-jwt/jwt/v5 v5.0.0 github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f github.com/golang/protobuf v1.5.3 + github.com/gomodule/redigo v2.0.0+incompatible github.com/gorilla/websocket v1.5.0 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/johannesboyne/gofakes3 v0.0.0-20230914150226-f005f5cc03aa github.com/jpillora/backoff v1.0.0 github.com/mitchellh/copystructure v1.2.0 github.com/prometheus/client_golang v1.17.0 + github.com/rafaeljusto/redigomock/v3 v3.1.2 github.com/redis/go-redis/v9 v9.2.1 github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a github.com/sirupsen/logrus v1.9.3 diff --git a/workhorse/go.sum b/workhorse/go.sum index d35e2948db7..6cf33000fcf 100644 --- a/workhorse/go.sum +++ b/workhorse/go.sum @@ -85,6 +85,8 @@ github.com/DataDog/datadog-go v4.4.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3 github.com/DataDog/gostackparse v0.5.0/go.mod h1:lTfqcJKqS9KnXQGnyQMCugq3u1FP6UZMfWR0aitKFMM= github.com/DataDog/sketches-go v1.0.0 h1:chm5KSXO7kO+ywGWJ0Zs6tdmWU8PBXSbywFVciL6BG4= github.com/DataDog/sketches-go v1.0.0/go.mod h1:O+XkJHWk9w4hDwY2ZUDU31ZC9sNYlYo8DiFsxjYeo1k= +github.com/FZambia/sentinel v1.1.1 h1:0ovTimlR7Ldm+wR15GgO+8C2dt7kkn+tm3PQS+Qk3Ek= +github.com/FZambia/sentinel v1.1.1/go.mod h1:ytL1Am/RLlAoAXG6Kj5LNuw/TRRQrv2rt2FT26vP5gI= github.com/HdrHistogram/hdrhistogram-go v1.1.1 h1:cJXY5VLMHgejurPjZH6Fo9rIwRGLefBGdiaENZALqrg= github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= @@ -229,6 +231,9 @@ github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/gomodule/redigo v1.8.8/go.mod h1:7ArFNvsTjH8GMMzB4uy1snslv2BwmginuMs06a1uzZE= +github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0= +github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.1.1-0.20171103154506-982329095285/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= @@ -387,6 +392,8 @@ github.com/prometheus/procfs v0.11.1 h1:xRC8Iq1yyca5ypa9n1EZnWZkt7dwcoRPQwX/5gwa github.com/prometheus/procfs v0.11.1/go.mod h1:eesXgaPo1q7lBpVMoMy0ZOFTth9hBn4W/y0/p/ScXhY= github.com/prometheus/prometheus v0.46.0 h1:9JSdXnsuT6YsbODEhSQMwxNkGwPExfmzqG73vCMk/Kw= github.com/prometheus/prometheus v0.46.0/go.mod h1:10L5IJE5CEsjee1FnOcVswYXlPIscDWWt3IJ2UDYrz4= +github.com/rafaeljusto/redigomock/v3 v3.1.2 h1:B4Y0XJQiPjpwYmkH55aratKX1VfR+JRqzmDKyZbC99o= +github.com/rafaeljusto/redigomock/v3 v3.1.2/go.mod 
h1:F9zPqz8rMriScZkPtUiLJoLruYcpGo/XXREpeyasREM= github.com/redis/go-redis/v9 v9.2.1 h1:WlYJg71ODF0dVspZZCpYmoF1+U1Jjk9Rwd7pq6QmlCg= github.com/redis/go-redis/v9 v9.2.1/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= diff --git a/workhorse/internal/goredis/goredis.go b/workhorse/internal/goredis/goredis.go new file mode 100644 index 00000000000..5566e5a3434 --- /dev/null +++ b/workhorse/internal/goredis/goredis.go @@ -0,0 +1,200 @@ +package goredis + +import ( + "context" + "errors" + "fmt" + "net" + "time" + + redis "github.com/redis/go-redis/v9" + + "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" + _ "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" + internalredis "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis" +) + +var ( + rdb *redis.Client + // found in https://github.com/redis/go-redis/blob/c7399b6a17d7d3e2a57654528af91349f2468529/sentinel.go#L626 + errSentinelMasterAddr error = errors.New("redis: all sentinels specified in configuration are unreachable") +) + +const ( + // Max Idle Connections in the pool. + defaultMaxIdle = 1 + // Max Active Connections in the pool. + defaultMaxActive = 1 + // Timeout for Read operations on the pool. 1 second is technically overkill, + // it's just for sanity. + defaultReadTimeout = 1 * time.Second + // Timeout for Write operations on the pool. 1 second is technically overkill, + // it's just for sanity. + defaultWriteTimeout = 1 * time.Second + // Timeout before killing Idle connections in the pool. 3 minutes seemed good. + // If you _actually_ hit this timeout often, you should consider turning of + // redis-support since it's not necessary at that point... + defaultIdleTimeout = 3 * time.Minute +) + +// createDialer references https://github.com/redis/go-redis/blob/b1103e3d436b6fe98813ecbbe1f99dc8d59b06c9/options.go#L214 +// it intercepts the error and tracks it via a Prometheus counter +func createDialer(sentinels []string) func(ctx context.Context, network, addr string) (net.Conn, error) { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + var isSentinel bool + for _, sentinelAddr := range sentinels { + if sentinelAddr == addr { + isSentinel = true + break + } + } + + dialTimeout := 5 * time.Second // go-redis default + destination := "redis" + if isSentinel { + // This timeout is recommended for Sentinel-support according to the guidelines. + // https://redis.io/topics/sentinel-clients#redis-service-discovery-via-sentinel + // For every address it should try to connect to the Sentinel, + // using a short timeout (in the order of a few hundreds of milliseconds). 
+ destination = "sentinel" + dialTimeout = 500 * time.Millisecond + } + + netDialer := &net.Dialer{ + Timeout: dialTimeout, + KeepAlive: 5 * time.Minute, + } + + conn, err := netDialer.DialContext(ctx, network, addr) + if err != nil { + internalredis.ErrorCounter.WithLabelValues("dial", destination).Inc() + } else { + if !isSentinel { + internalredis.TotalConnections.Inc() + } + } + + return conn, err + } +} + +// implements the redis.Hook interface for instrumentation +type sentinelInstrumentationHook struct{} + +func (s sentinelInstrumentationHook) DialHook(next redis.DialHook) redis.DialHook { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + conn, err := next(ctx, network, addr) + if err != nil && err.Error() == errSentinelMasterAddr.Error() { + // check for non-dialer error + internalredis.ErrorCounter.WithLabelValues("master", "sentinel").Inc() + } + return conn, err + } +} + +func (s sentinelInstrumentationHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook { + return func(ctx context.Context, cmd redis.Cmder) error { + return next(ctx, cmd) + } +} + +func (s sentinelInstrumentationHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + return next(ctx, cmds) + } +} + +func GetRedisClient() *redis.Client { + return rdb +} + +// Configure redis-connection +func Configure(cfg *config.RedisConfig) error { + if cfg == nil { + return nil + } + + var err error + + if len(cfg.Sentinel) > 0 { + rdb = configureSentinel(cfg) + } else { + rdb, err = configureRedis(cfg) + } + + return err +} + +func configureRedis(cfg *config.RedisConfig) (*redis.Client, error) { + if cfg.URL.Scheme == "tcp" { + cfg.URL.Scheme = "redis" + } + + opt, err := redis.ParseURL(cfg.URL.String()) + if err != nil { + return nil, err + } + + opt.DB = getOrDefault(cfg.DB, 0) + opt.Password = cfg.Password + + opt.PoolSize = getOrDefault(cfg.MaxActive, defaultMaxActive) + opt.MaxIdleConns = getOrDefault(cfg.MaxIdle, defaultMaxIdle) + opt.ConnMaxIdleTime = defaultIdleTimeout + opt.ReadTimeout = defaultReadTimeout + opt.WriteTimeout = defaultWriteTimeout + + opt.Dialer = createDialer([]string{}) + + return redis.NewClient(opt), nil +} + +func configureSentinel(cfg *config.RedisConfig) *redis.Client { + sentinelPassword, sentinels := sentinelOptions(cfg) + client := redis.NewFailoverClient(&redis.FailoverOptions{ + MasterName: cfg.SentinelMaster, + SentinelAddrs: sentinels, + Password: cfg.Password, + SentinelPassword: sentinelPassword, + DB: getOrDefault(cfg.DB, 0), + + PoolSize: getOrDefault(cfg.MaxActive, defaultMaxActive), + MaxIdleConns: getOrDefault(cfg.MaxIdle, defaultMaxIdle), + ConnMaxIdleTime: defaultIdleTimeout, + + ReadTimeout: defaultReadTimeout, + WriteTimeout: defaultWriteTimeout, + + Dialer: createDialer(sentinels), + }) + + client.AddHook(sentinelInstrumentationHook{}) + + return client +} + +// sentinelOptions extracts the sentinel password and addresses in <host>:<port> format +// the order of priority for the passwords is: SentinelPassword -> first password-in-url +func sentinelOptions(cfg *config.RedisConfig) (string, []string) { + sentinels := make([]string, len(cfg.Sentinel)) + sentinelPassword := cfg.SentinelPassword + + for i := range cfg.Sentinel { + sentinelDetails := cfg.Sentinel[i] + sentinels[i] = fmt.Sprintf("%s:%s", sentinelDetails.Hostname(), sentinelDetails.Port()) + + if pw, exist := sentinelDetails.User.Password(); exist && len(sentinelPassword) == 0 { + // sets 
password using the first non-empty password + sentinelPassword = pw + } + } + + return sentinelPassword, sentinels +} + +func getOrDefault(ptr *int, val int) int { + if ptr != nil { + return *ptr + } + return val +} diff --git a/workhorse/internal/goredis/goredis_test.go b/workhorse/internal/goredis/goredis_test.go new file mode 100644 index 00000000000..735b2076b0c --- /dev/null +++ b/workhorse/internal/goredis/goredis_test.go @@ -0,0 +1,162 @@ +package goredis + +import ( + "context" + "net" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" + "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" +) + +func mockRedisServer(t *testing.T, connectReceived *atomic.Value) string { + ln, err := net.Listen("tcp", "127.0.0.1:0") + + require.Nil(t, err) + + go func() { + defer ln.Close() + conn, err := ln.Accept() + require.Nil(t, err) + connectReceived.Store(true) + conn.Write([]byte("OK\n")) + }() + + return ln.Addr().String() +} + +func TestConfigureNoConfig(t *testing.T) { + rdb = nil + Configure(nil) + require.Nil(t, rdb, "rdb client should be nil") +} + +func TestConfigureValidConfigX(t *testing.T) { + testCases := []struct { + scheme string + }{ + { + scheme: "redis", + }, + { + scheme: "tcp", + }, + } + + for _, tc := range testCases { + t.Run(tc.scheme, func(t *testing.T) { + connectReceived := atomic.Value{} + a := mockRedisServer(t, &connectReceived) + + parsedURL := helper.URLMustParse(tc.scheme + "://" + a) + cfg := &config.RedisConfig{URL: config.TomlURL{URL: *parsedURL}} + + Configure(cfg) + + require.NotNil(t, GetRedisClient().Conn(), "Pool should not be nil") + + // goredis initialise connections lazily + rdb.Ping(context.Background()) + require.True(t, connectReceived.Load().(bool)) + + rdb = nil + }) + } +} + +func TestConnectToSentinel(t *testing.T) { + testCases := []struct { + scheme string + }{ + { + scheme: "redis", + }, + { + scheme: "tcp", + }, + } + + for _, tc := range testCases { + t.Run(tc.scheme, func(t *testing.T) { + connectReceived := atomic.Value{} + a := mockRedisServer(t, &connectReceived) + + addrs := []string{tc.scheme + "://" + a} + var sentinelUrls []config.TomlURL + + for _, a := range addrs { + parsedURL := helper.URLMustParse(a) + sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL}) + } + + cfg := &config.RedisConfig{Sentinel: sentinelUrls} + Configure(cfg) + + require.NotNil(t, GetRedisClient().Conn(), "Pool should not be nil") + + // goredis initialise connections lazily + rdb.Ping(context.Background()) + require.True(t, connectReceived.Load().(bool)) + + rdb = nil + }) + } +} + +func TestSentinelOptions(t *testing.T) { + testCases := []struct { + description string + inputSentinelPassword string + inputSentinel []string + password string + sentinels []string + }{ + { + description: "no sentinel passwords", + inputSentinel: []string{"tcp://localhost:26480"}, + sentinels: []string{"localhost:26480"}, + }, + { + description: "specific sentinel password defined", + inputSentinel: []string{"tcp://localhost:26480"}, + inputSentinelPassword: "password1", + sentinels: []string{"localhost:26480"}, + password: "password1", + }, + { + description: "specific sentinel password defined in url", + inputSentinel: []string{"tcp://:password2@localhost:26480", "tcp://:password3@localhost:26481"}, + sentinels: []string{"localhost:26480", "localhost:26481"}, + password: "password2", + }, + { + description: "passwords defined specifically and in url", + inputSentinel: 
[]string{"tcp://:password2@localhost:26480", "tcp://:password3@localhost:26481"}, + sentinels: []string{"localhost:26480", "localhost:26481"}, + inputSentinelPassword: "password1", + password: "password1", + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + sentinelUrls := make([]config.TomlURL, len(tc.inputSentinel)) + + for i, str := range tc.inputSentinel { + parsedURL := helper.URLMustParse(str) + sentinelUrls[i] = config.TomlURL{URL: *parsedURL} + } + + outputPw, outputSentinels := sentinelOptions(&config.RedisConfig{ + Sentinel: sentinelUrls, + SentinelPassword: tc.inputSentinelPassword, + }) + + require.Equal(t, tc.password, outputPw) + require.Equal(t, tc.sentinels, outputSentinels) + }) + } +} diff --git a/workhorse/internal/goredis/keywatcher.go b/workhorse/internal/goredis/keywatcher.go new file mode 100644 index 00000000000..741bfb17652 --- /dev/null +++ b/workhorse/internal/goredis/keywatcher.go @@ -0,0 +1,236 @@ +package goredis + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + "time" + + "github.com/jpillora/backoff" + "github.com/redis/go-redis/v9" + + "gitlab.com/gitlab-org/gitlab/workhorse/internal/log" + internalredis "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis" +) + +type KeyWatcher struct { + mu sync.Mutex + subscribers map[string][]chan string + shutdown chan struct{} + reconnectBackoff backoff.Backoff + redisConn *redis.Client + conn *redis.PubSub +} + +func NewKeyWatcher() *KeyWatcher { + return &KeyWatcher{ + shutdown: make(chan struct{}), + reconnectBackoff: backoff.Backoff{ + Min: 100 * time.Millisecond, + Max: 60 * time.Second, + Factor: 2, + Jitter: true, + }, + } +} + +const channelPrefix = "workhorse:notifications:" + +func countAction(action string) { internalredis.TotalActions.WithLabelValues(action).Add(1) } + +func (kw *KeyWatcher) receivePubSubStream(ctx context.Context, pubsub *redis.PubSub) error { + kw.mu.Lock() + // We must share kw.conn with the goroutines that call SUBSCRIBE and + // UNSUBSCRIBE because Redis pubsub subscriptions are tied to the + // connection. + kw.conn = pubsub + kw.mu.Unlock() + + defer func() { + kw.mu.Lock() + defer kw.mu.Unlock() + kw.conn.Close() + kw.conn = nil + + // Reset kw.subscribers because it is tied to Redis server side state of + // kw.conn and we just closed that connection. + for _, chans := range kw.subscribers { + for _, ch := range chans { + close(ch) + internalredis.KeyWatchers.Dec() + } + } + kw.subscribers = nil + }() + + for { + msg, err := kw.conn.Receive(ctx) + if err != nil { + log.WithError(fmt.Errorf("keywatcher: pubsub receive: %v", err)).Error() + return nil + } + + switch msg := msg.(type) { + case *redis.Subscription: + internalredis.RedisSubscriptions.Set(float64(msg.Count)) + case *redis.Pong: + // Ignore. + case *redis.Message: + internalredis.TotalMessages.Inc() + internalredis.ReceivedBytes.Add(float64(len(msg.Payload))) + if strings.HasPrefix(msg.Channel, channelPrefix) { + kw.notifySubscribers(msg.Channel[len(channelPrefix):], string(msg.Payload)) + } + default: + log.WithError(fmt.Errorf("keywatcher: unknown: %T", msg)).Error() + return nil + } + } +} + +func (kw *KeyWatcher) Process(client *redis.Client) { + log.Info("keywatcher: starting process loop") + + ctx := context.Background() // lint:allow context.Background + kw.mu.Lock() + kw.redisConn = client + kw.mu.Unlock() + + for { + pubsub := client.Subscribe(ctx, []string{}...) 
+ if err := pubsub.Ping(ctx); err != nil { + log.WithError(fmt.Errorf("keywatcher: %v", err)).Error() + time.Sleep(kw.reconnectBackoff.Duration()) + continue + } + + kw.reconnectBackoff.Reset() + + if err := kw.receivePubSubStream(ctx, pubsub); err != nil { + log.WithError(fmt.Errorf("keywatcher: receivePubSubStream: %v", err)).Error() + } + } +} + +func (kw *KeyWatcher) Shutdown() { + log.Info("keywatcher: shutting down") + + kw.mu.Lock() + defer kw.mu.Unlock() + + select { + case <-kw.shutdown: + // already closed + default: + close(kw.shutdown) + } +} + +func (kw *KeyWatcher) notifySubscribers(key, value string) { + kw.mu.Lock() + defer kw.mu.Unlock() + + chanList, ok := kw.subscribers[key] + if !ok { + countAction("drop-message") + return + } + + countAction("deliver-message") + for _, c := range chanList { + select { + case c <- value: + default: + } + } +} + +func (kw *KeyWatcher) addSubscription(ctx context.Context, key string, notify chan string) error { + kw.mu.Lock() + defer kw.mu.Unlock() + + if kw.conn == nil { + // This can happen because CI long polling is disabled in this Workhorse + // process. It can also be that we are waiting for the pubsub connection + // to be established. Either way it is OK to fail fast. + return errors.New("no redis connection") + } + + if len(kw.subscribers[key]) == 0 { + countAction("create-subscription") + if err := kw.conn.Subscribe(ctx, channelPrefix+key); err != nil { + return err + } + } + + if kw.subscribers == nil { + kw.subscribers = make(map[string][]chan string) + } + kw.subscribers[key] = append(kw.subscribers[key], notify) + internalredis.KeyWatchers.Inc() + + return nil +} + +func (kw *KeyWatcher) delSubscription(ctx context.Context, key string, notify chan string) { + kw.mu.Lock() + defer kw.mu.Unlock() + + chans, ok := kw.subscribers[key] + if !ok { + // This can happen if the pubsub connection dropped while we were + // waiting. + return + } + + for i, c := range chans { + if notify == c { + kw.subscribers[key] = append(chans[:i], chans[i+1:]...) 
+ internalredis.KeyWatchers.Dec() + break + } + } + if len(kw.subscribers[key]) == 0 { + delete(kw.subscribers, key) + countAction("delete-subscription") + if kw.conn != nil { + kw.conn.Unsubscribe(ctx, channelPrefix+key) + } + } +} + +func (kw *KeyWatcher) WatchKey(ctx context.Context, key, value string, timeout time.Duration) (internalredis.WatchKeyStatus, error) { + notify := make(chan string, 1) + if err := kw.addSubscription(ctx, key, notify); err != nil { + return internalredis.WatchKeyStatusNoChange, err + } + defer kw.delSubscription(ctx, key, notify) + + currentValue, err := kw.redisConn.Get(ctx, key).Result() + if errors.Is(err, redis.Nil) { + currentValue = "" + } else if err != nil { + return internalredis.WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET: %v", err) + } + if currentValue != value { + return internalredis.WatchKeyStatusAlreadyChanged, nil + } + + select { + case <-kw.shutdown: + log.WithFields(log.Fields{"key": key}).Info("stopping watch due to shutdown") + return internalredis.WatchKeyStatusNoChange, nil + case currentValue := <-notify: + if currentValue == "" { + return internalredis.WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET failed") + } + if currentValue == value { + return internalredis.WatchKeyStatusNoChange, nil + } + return internalredis.WatchKeyStatusSeenChange, nil + case <-time.After(timeout): + return internalredis.WatchKeyStatusTimeout, nil + } +} diff --git a/workhorse/internal/goredis/keywatcher_test.go b/workhorse/internal/goredis/keywatcher_test.go new file mode 100644 index 00000000000..b64262dc9c8 --- /dev/null +++ b/workhorse/internal/goredis/keywatcher_test.go @@ -0,0 +1,301 @@ +package goredis + +import ( + "context" + "os" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" + "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis" +) + +var ctx = context.Background() + +const ( + runnerKey = "runner:build_queue:10" +) + +func initRdb() { + buf, _ := os.ReadFile("../../config.toml") + cfg, _ := config.LoadConfig(string(buf)) + Configure(cfg.Redis) +} + +func (kw *KeyWatcher) countSubscribers(key string) int { + kw.mu.Lock() + defer kw.mu.Unlock() + return len(kw.subscribers[key]) +} + +// Forces a run of the `Process` loop against a mock PubSubConn. +func (kw *KeyWatcher) processMessages(t *testing.T, numWatchers int, value string, ready chan<- struct{}, wg *sync.WaitGroup) { + kw.mu.Lock() + kw.redisConn = rdb + psc := kw.redisConn.Subscribe(ctx, []string{}...) 
+ kw.mu.Unlock() + + errC := make(chan error) + go func() { errC <- kw.receivePubSubStream(ctx, psc) }() + + require.Eventually(t, func() bool { + kw.mu.Lock() + defer kw.mu.Unlock() + return kw.conn != nil + }, time.Second, time.Millisecond) + close(ready) + + require.Eventually(t, func() bool { + return kw.countSubscribers(runnerKey) == numWatchers + }, time.Second, time.Millisecond) + + // send message after listeners are ready + kw.redisConn.Publish(ctx, channelPrefix+runnerKey, value) + + // close subscription after all workers are done + wg.Wait() + kw.mu.Lock() + kw.conn.Close() + kw.mu.Unlock() + + require.NoError(t, <-errC) +} + +type keyChangeTestCase struct { + desc string + returnValue string + isKeyMissing bool + watchValue string + processedValue string + expectedStatus redis.WatchKeyStatus + timeout time.Duration +} + +func TestKeyChangesInstantReturn(t *testing.T) { + initRdb() + + testCases := []keyChangeTestCase{ + // WatchKeyStatusAlreadyChanged + { + desc: "sees change with key existing and changed", + returnValue: "somethingelse", + watchValue: "something", + expectedStatus: redis.WatchKeyStatusAlreadyChanged, + timeout: time.Second, + }, + { + desc: "sees change with key non-existing", + isKeyMissing: true, + watchValue: "something", + processedValue: "somethingelse", + expectedStatus: redis.WatchKeyStatusAlreadyChanged, + timeout: time.Second, + }, + // WatchKeyStatusTimeout + { + desc: "sees timeout with key existing and unchanged", + returnValue: "something", + watchValue: "something", + expectedStatus: redis.WatchKeyStatusTimeout, + timeout: time.Millisecond, + }, + { + desc: "sees timeout with key non-existing and unchanged", + isKeyMissing: true, + watchValue: "", + expectedStatus: redis.WatchKeyStatusTimeout, + timeout: time.Millisecond, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + + // setup + if !tc.isKeyMissing { + rdb.Set(ctx, runnerKey, tc.returnValue, 0) + } + + defer func() { + rdb.FlushDB(ctx) + }() + + kw := NewKeyWatcher() + defer kw.Shutdown() + kw.redisConn = rdb + kw.conn = kw.redisConn.Subscribe(ctx, []string{}...) 
+ + val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, tc.timeout) + + require.NoError(t, err, "Expected no error") + require.Equal(t, tc.expectedStatus, val, "Expected value") + }) + } +} + +func TestKeyChangesWhenWatching(t *testing.T) { + initRdb() + + testCases := []keyChangeTestCase{ + // WatchKeyStatusSeenChange + { + desc: "sees change with key existing", + returnValue: "something", + watchValue: "something", + processedValue: "somethingelse", + expectedStatus: redis.WatchKeyStatusSeenChange, + }, + { + desc: "sees change with key non-existing, when watching empty value", + isKeyMissing: true, + watchValue: "", + processedValue: "something", + expectedStatus: redis.WatchKeyStatusSeenChange, + }, + // WatchKeyStatusNoChange + { + desc: "sees no change with key existing", + returnValue: "something", + watchValue: "something", + processedValue: "something", + expectedStatus: redis.WatchKeyStatusNoChange, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + if !tc.isKeyMissing { + rdb.Set(ctx, runnerKey, tc.returnValue, 0) + } + + kw := NewKeyWatcher() + defer kw.Shutdown() + defer func() { + rdb.FlushDB(ctx) + }() + + wg := &sync.WaitGroup{} + wg.Add(1) + ready := make(chan struct{}) + + go func() { + defer wg.Done() + <-ready + val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, time.Second) + + require.NoError(t, err, "Expected no error") + require.Equal(t, tc.expectedStatus, val, "Expected value") + }() + + kw.processMessages(t, 1, tc.processedValue, ready, wg) + }) + } +} + +func TestKeyChangesParallel(t *testing.T) { + initRdb() + + testCases := []keyChangeTestCase{ + { + desc: "massively parallel, sees change with key existing", + returnValue: "something", + watchValue: "something", + processedValue: "somethingelse", + expectedStatus: redis.WatchKeyStatusSeenChange, + }, + { + desc: "massively parallel, sees change with key existing, watching missing keys", + isKeyMissing: true, + watchValue: "", + processedValue: "somethingelse", + expectedStatus: redis.WatchKeyStatusSeenChange, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + runTimes := 100 + + if !tc.isKeyMissing { + rdb.Set(ctx, runnerKey, tc.returnValue, 0) + } + + defer func() { + rdb.FlushDB(ctx) + }() + + wg := &sync.WaitGroup{} + wg.Add(runTimes) + ready := make(chan struct{}) + + kw := NewKeyWatcher() + defer kw.Shutdown() + + for i := 0; i < runTimes; i++ { + go func() { + defer wg.Done() + <-ready + val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, time.Second) + + require.NoError(t, err, "Expected no error") + require.Equal(t, tc.expectedStatus, val, "Expected value") + }() + } + + kw.processMessages(t, runTimes, tc.processedValue, ready, wg) + }) + } +} + +func TestShutdown(t *testing.T) { + initRdb() + + kw := NewKeyWatcher() + kw.redisConn = rdb + kw.conn = kw.redisConn.Subscribe(ctx, []string{}...) 
+ defer kw.Shutdown() + + rdb.Set(ctx, runnerKey, "something", 0) + + wg := &sync.WaitGroup{} + wg.Add(2) + + go func() { + defer wg.Done() + val, err := kw.WatchKey(ctx, runnerKey, "something", 10*time.Second) + + require.NoError(t, err, "Expected no error") + require.Equal(t, redis.WatchKeyStatusNoChange, val, "Expected value not to change") + }() + + go func() { + defer wg.Done() + require.Eventually(t, func() bool { return kw.countSubscribers(runnerKey) == 1 }, 10*time.Second, time.Millisecond) + + kw.Shutdown() + }() + + wg.Wait() + + require.Eventually(t, func() bool { return kw.countSubscribers(runnerKey) == 0 }, 10*time.Second, time.Millisecond) + + // Adding a key after the shutdown should result in an immediate response + var val redis.WatchKeyStatus + var err error + done := make(chan struct{}) + go func() { + val, err = kw.WatchKey(ctx, runnerKey, "something", 10*time.Second) + close(done) + }() + + select { + case <-done: + require.NoError(t, err, "Expected no error") + require.Equal(t, redis.WatchKeyStatusNoChange, val, "Expected value not to change") + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout waiting for WatchKey") + } +} diff --git a/workhorse/internal/redis/keywatcher.go b/workhorse/internal/redis/keywatcher.go index ddb838121b7..8f1772a9195 100644 --- a/workhorse/internal/redis/keywatcher.go +++ b/workhorse/internal/redis/keywatcher.go @@ -8,10 +8,10 @@ import ( "sync" "time" + "github.com/gomodule/redigo/redis" "github.com/jpillora/backoff" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" - "github.com/redis/go-redis/v9" "gitlab.com/gitlab-org/gitlab/workhorse/internal/log" ) @@ -21,8 +21,7 @@ type KeyWatcher struct { subscribers map[string][]chan string shutdown chan struct{} reconnectBackoff backoff.Backoff - redisConn *redis.Client - conn *redis.PubSub + conn *redis.PubSubConn } func NewKeyWatcher() *KeyWatcher { @@ -75,12 +74,12 @@ const channelPrefix = "workhorse:notifications:" func countAction(action string) { TotalActions.WithLabelValues(action).Add(1) } -func (kw *KeyWatcher) receivePubSubStream(ctx context.Context, pubsub *redis.PubSub) error { +func (kw *KeyWatcher) receivePubSubStream(conn redis.Conn) error { kw.mu.Lock() // We must share kw.conn with the goroutines that call SUBSCRIBE and // UNSUBSCRIBE because Redis pubsub subscriptions are tied to the // connection. - kw.conn = pubsub + kw.conn = &redis.PubSubConn{Conn: conn} kw.mu.Unlock() defer func() { @@ -101,49 +100,51 @@ func (kw *KeyWatcher) receivePubSubStream(ctx context.Context, pubsub *redis.Pub }() for { - msg, err := kw.conn.Receive(ctx) - if err != nil { - log.WithError(fmt.Errorf("keywatcher: pubsub receive: %v", err)).Error() - return nil - } - - switch msg := msg.(type) { - case *redis.Subscription: - RedisSubscriptions.Set(float64(msg.Count)) - case *redis.Pong: - // Ignore. 
- case *redis.Message: + switch v := kw.conn.Receive().(type) { + case redis.Message: TotalMessages.Inc() - ReceivedBytes.Add(float64(len(msg.Payload))) - if strings.HasPrefix(msg.Channel, channelPrefix) { - kw.notifySubscribers(msg.Channel[len(channelPrefix):], string(msg.Payload)) + ReceivedBytes.Add(float64(len(v.Data))) + if strings.HasPrefix(v.Channel, channelPrefix) { + kw.notifySubscribers(v.Channel[len(channelPrefix):], string(v.Data)) } - default: - log.WithError(fmt.Errorf("keywatcher: unknown: %T", msg)).Error() + case redis.Subscription: + RedisSubscriptions.Set(float64(v.Count)) + case error: + log.WithError(fmt.Errorf("keywatcher: pubsub receive: %v", v)).Error() + // Intermittent error, return nil so that it doesn't wait before reconnect return nil } } } -func (kw *KeyWatcher) Process(client *redis.Client) { - log.Info("keywatcher: starting process loop") +func dialPubSub(dialer redisDialerFunc) (redis.Conn, error) { + conn, err := dialer() + if err != nil { + return nil, err + } - ctx := context.Background() // lint:allow context.Background - kw.mu.Lock() - kw.redisConn = client - kw.mu.Unlock() + // Make sure Redis is actually connected + conn.Do("PING") + if err := conn.Err(); err != nil { + conn.Close() + return nil, err + } + + return conn, nil +} +func (kw *KeyWatcher) Process() { + log.Info("keywatcher: starting process loop") for { - pubsub := client.Subscribe(ctx, []string{}...) - if err := pubsub.Ping(ctx); err != nil { + conn, err := dialPubSub(workerDialFunc) + if err != nil { log.WithError(fmt.Errorf("keywatcher: %v", err)).Error() time.Sleep(kw.reconnectBackoff.Duration()) continue } - kw.reconnectBackoff.Reset() - if err := kw.receivePubSubStream(ctx, pubsub); err != nil { + if err = kw.receivePubSubStream(conn); err != nil { log.WithError(fmt.Errorf("keywatcher: receivePubSubStream: %v", err)).Error() } } @@ -182,7 +183,7 @@ func (kw *KeyWatcher) notifySubscribers(key, value string) { } } -func (kw *KeyWatcher) addSubscription(ctx context.Context, key string, notify chan string) error { +func (kw *KeyWatcher) addSubscription(key string, notify chan string) error { kw.mu.Lock() defer kw.mu.Unlock() @@ -195,7 +196,7 @@ func (kw *KeyWatcher) addSubscription(ctx context.Context, key string, notify ch if len(kw.subscribers[key]) == 0 { countAction("create-subscription") - if err := kw.conn.Subscribe(ctx, channelPrefix+key); err != nil { + if err := kw.conn.Subscribe(channelPrefix + key); err != nil { return err } } @@ -209,7 +210,7 @@ func (kw *KeyWatcher) addSubscription(ctx context.Context, key string, notify ch return nil } -func (kw *KeyWatcher) delSubscription(ctx context.Context, key string, notify chan string) { +func (kw *KeyWatcher) delSubscription(key string, notify chan string) { kw.mu.Lock() defer kw.mu.Unlock() @@ -231,7 +232,7 @@ func (kw *KeyWatcher) delSubscription(ctx context.Context, key string, notify ch delete(kw.subscribers, key) countAction("delete-subscription") if kw.conn != nil { - kw.conn.Unsubscribe(ctx, channelPrefix+key) + kw.conn.Unsubscribe(channelPrefix + key) } } } @@ -251,15 +252,15 @@ const ( WatchKeyStatusNoChange ) -func (kw *KeyWatcher) WatchKey(ctx context.Context, key, value string, timeout time.Duration) (WatchKeyStatus, error) { +func (kw *KeyWatcher) WatchKey(_ context.Context, key, value string, timeout time.Duration) (WatchKeyStatus, error) { notify := make(chan string, 1) - if err := kw.addSubscription(ctx, key, notify); err != nil { + if err := kw.addSubscription(key, notify); err != nil { return 
WatchKeyStatusNoChange, err } - defer kw.delSubscription(ctx, key, notify) + defer kw.delSubscription(key, notify) - currentValue, err := kw.redisConn.Get(ctx, key).Result() - if errors.Is(err, redis.Nil) { + currentValue, err := GetString(key) + if errors.Is(err, redis.ErrNil) { currentValue = "" } else if err != nil { return WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET: %v", err) diff --git a/workhorse/internal/redis/keywatcher_test.go b/workhorse/internal/redis/keywatcher_test.go index bca4ca43a64..3abc1bf1107 100644 --- a/workhorse/internal/redis/keywatcher_test.go +++ b/workhorse/internal/redis/keywatcher_test.go @@ -2,14 +2,13 @@ package redis import ( "context" - "os" "sync" "testing" "time" + "github.com/gomodule/redigo/redis" + "github.com/rafaeljusto/redigomock/v3" "github.com/stretchr/testify/require" - - "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" ) var ctx = context.Background() @@ -18,10 +17,27 @@ const ( runnerKey = "runner:build_queue:10" ) -func initRdb() { - buf, _ := os.ReadFile("../../config.toml") - cfg, _ := config.LoadConfig(string(buf)) - Configure(cfg.Redis) +func createSubscriptionMessage(key, data string) []interface{} { + return []interface{}{ + []byte("message"), + []byte(key), + []byte(data), + } +} + +func createSubscribeMessage(key string) []interface{} { + return []interface{}{ + []byte("subscribe"), + []byte(key), + []byte("1"), + } +} +func createUnsubscribeMessage(key string) []interface{} { + return []interface{}{ + []byte("unsubscribe"), + []byte(key), + []byte("1"), + } } func (kw *KeyWatcher) countSubscribers(key string) int { @@ -31,14 +47,17 @@ func (kw *KeyWatcher) countSubscribers(key string) int { } // Forces a run of the `Process` loop against a mock PubSubConn. -func (kw *KeyWatcher) processMessages(t *testing.T, numWatchers int, value string, ready chan<- struct{}, wg *sync.WaitGroup) { - kw.mu.Lock() - kw.redisConn = rdb - psc := kw.redisConn.Subscribe(ctx, []string{}...) 
- kw.mu.Unlock() +func (kw *KeyWatcher) processMessages(t *testing.T, numWatchers int, value string, ready chan<- struct{}) { + psc := redigomock.NewConn() + psc.ReceiveWait = true + + channel := channelPrefix + runnerKey + psc.Command("SUBSCRIBE", channel).Expect(createSubscribeMessage(channel)) + psc.Command("UNSUBSCRIBE", channel).Expect(createUnsubscribeMessage(channel)) + psc.AddSubscriptionMessage(createSubscriptionMessage(channel, value)) errC := make(chan error) - go func() { errC <- kw.receivePubSubStream(ctx, psc) }() + go func() { errC <- kw.receivePubSubStream(psc) }() require.Eventually(t, func() bool { kw.mu.Lock() @@ -50,15 +69,7 @@ func (kw *KeyWatcher) processMessages(t *testing.T, numWatchers int, value strin require.Eventually(t, func() bool { return kw.countSubscribers(runnerKey) == numWatchers }, time.Second, time.Millisecond) - - // send message after listeners are ready - kw.redisConn.Publish(ctx, channelPrefix+runnerKey, value) - - // close subscription after all workers are done - wg.Wait() - kw.mu.Lock() - kw.conn.Close() - kw.mu.Unlock() + close(psc.ReceiveNow) require.NoError(t, <-errC) } @@ -74,8 +85,6 @@ type keyChangeTestCase struct { } func TestKeyChangesInstantReturn(t *testing.T) { - initRdb() - testCases := []keyChangeTestCase{ // WatchKeyStatusAlreadyChanged { @@ -112,20 +121,18 @@ func TestKeyChangesInstantReturn(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { + conn, td := setupMockPool() + defer td() - // setup - if !tc.isKeyMissing { - rdb.Set(ctx, runnerKey, tc.returnValue, 0) + if tc.isKeyMissing { + conn.Command("GET", runnerKey).ExpectError(redis.ErrNil) + } else { + conn.Command("GET", runnerKey).Expect(tc.returnValue) } - defer func() { - rdb.FlushDB(ctx) - }() - kw := NewKeyWatcher() defer kw.Shutdown() - kw.redisConn = rdb - kw.conn = kw.redisConn.Subscribe(ctx, []string{}...) 
+ kw.conn = &redis.PubSubConn{Conn: redigomock.NewConn()} val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, tc.timeout) @@ -136,8 +143,6 @@ func TestKeyChangesInstantReturn(t *testing.T) { } func TestKeyChangesWhenWatching(t *testing.T) { - initRdb() - testCases := []keyChangeTestCase{ // WatchKeyStatusSeenChange { @@ -166,15 +171,17 @@ func TestKeyChangesWhenWatching(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - if !tc.isKeyMissing { - rdb.Set(ctx, runnerKey, tc.returnValue, 0) + conn, td := setupMockPool() + defer td() + + if tc.isKeyMissing { + conn.Command("GET", runnerKey).ExpectError(redis.ErrNil) + } else { + conn.Command("GET", runnerKey).Expect(tc.returnValue) } kw := NewKeyWatcher() defer kw.Shutdown() - defer func() { - rdb.FlushDB(ctx) - }() wg := &sync.WaitGroup{} wg.Add(1) @@ -189,14 +196,13 @@ func TestKeyChangesWhenWatching(t *testing.T) { require.Equal(t, tc.expectedStatus, val, "Expected value") }() - kw.processMessages(t, 1, tc.processedValue, ready, wg) + kw.processMessages(t, 1, tc.processedValue, ready) + wg.Wait() }) } } func TestKeyChangesParallel(t *testing.T) { - initRdb() - testCases := []keyChangeTestCase{ { desc: "massively parallel, sees change with key existing", @@ -218,13 +224,18 @@ func TestKeyChangesParallel(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { runTimes := 100 - if !tc.isKeyMissing { - rdb.Set(ctx, runnerKey, tc.returnValue, 0) - } + conn, td := setupMockPool() + defer td() - defer func() { - rdb.FlushDB(ctx) - }() + getCmd := conn.Command("GET", runnerKey) + + for i := 0; i < runTimes; i++ { + if tc.isKeyMissing { + getCmd = getCmd.ExpectError(redis.ErrNil) + } else { + getCmd = getCmd.Expect(tc.returnValue) + } + } wg := &sync.WaitGroup{} wg.Add(runTimes) @@ -244,20 +255,21 @@ func TestKeyChangesParallel(t *testing.T) { }() } - kw.processMessages(t, runTimes, tc.processedValue, ready, wg) + kw.processMessages(t, runTimes, tc.processedValue, ready) + wg.Wait() }) } } func TestShutdown(t *testing.T) { - initRdb() + conn, td := setupMockPool() + defer td() kw := NewKeyWatcher() - kw.redisConn = rdb - kw.conn = kw.redisConn.Subscribe(ctx, []string{}...) + kw.conn = &redis.PubSubConn{Conn: redigomock.NewConn()} defer kw.Shutdown() - rdb.Set(ctx, runnerKey, "something", 0) + conn.Command("GET", runnerKey).Expect("something") wg := &sync.WaitGroup{} wg.Add(2) diff --git a/workhorse/internal/redis/redis.go b/workhorse/internal/redis/redis.go index b528255d25b..c79e1e56b3a 100644 --- a/workhorse/internal/redis/redis.go +++ b/workhorse/internal/redis/redis.go @@ -1,39 +1,24 @@ package redis import ( - "context" - "errors" "fmt" "net" + "net/url" "time" + "github.com/FZambia/sentinel" + "github.com/gomodule/redigo/redis" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" - redis "github.com/redis/go-redis/v9" + "gitlab.com/gitlab-org/labkit/log" "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" - _ "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" + "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" ) var ( - rdb *redis.Client - // found in https://github.com/redis/go-redis/blob/c7399b6a17d7d3e2a57654528af91349f2468529/sentinel.go#L626 - errSentinelMasterAddr error = errors.New("redis: all sentinels specified in configuration are unreachable") - - TotalConnections = promauto.NewCounter( - prometheus.CounterOpts{ - Name: "gitlab_workhorse_redis_total_connections", - Help: "How many connections gitlab-workhorse has opened in total. 
Can be used to track Redis connection rate for this process", - }, - ) - - ErrorCounter = promauto.NewCounterVec( - prometheus.CounterOpts{ - Name: "gitlab_workhorse_redis_errors", - Help: "Counts different types of Redis errors encountered by workhorse, by type and destination (redis, sentinel)", - }, - []string{"type", "dst"}, - ) + pool *redis.Pool + sntnl *sentinel.Sentinel ) const ( @@ -51,166 +36,241 @@ const ( // If you _actually_ hit this timeout often, you should consider turning of // redis-support since it's not necessary at that point... defaultIdleTimeout = 3 * time.Minute + // KeepAlivePeriod is to keep a TCP connection open for an extended period of + // time without being killed. This is used both in the pool, and in the + // worker-connection. + // See https://en.wikipedia.org/wiki/Keepalive#TCP_keepalive for more + // information. + defaultKeepAlivePeriod = 5 * time.Minute ) -// createDialer references https://github.com/redis/go-redis/blob/b1103e3d436b6fe98813ecbbe1f99dc8d59b06c9/options.go#L214 -// it intercepts the error and tracks it via a Prometheus counter -func createDialer(sentinels []string) func(ctx context.Context, network, addr string) (net.Conn, error) { - return func(ctx context.Context, network, addr string) (net.Conn, error) { - var isSentinel bool - for _, sentinelAddr := range sentinels { - if sentinelAddr == addr { - isSentinel = true - break - } - } +var ( + TotalConnections = promauto.NewCounter( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_redis_total_connections", + Help: "How many connections gitlab-workhorse has opened in total. Can be used to track Redis connection rate for this process", + }, + ) - dialTimeout := 5 * time.Second // go-redis default - destination := "redis" - if isSentinel { + ErrorCounter = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_redis_errors", + Help: "Counts different types of Redis errors encountered by workhorse, by type and destination (redis, sentinel)", + }, + []string{"type", "dst"}, + ) +) + +func sentinelConn(master string, urls []config.TomlURL) *sentinel.Sentinel { + if len(urls) == 0 { + return nil + } + var addrs []string + for _, url := range urls { + h := url.URL.String() + log.WithFields(log.Fields{ + "scheme": url.URL.Scheme, + "host": url.URL.Host, + }).Printf("redis: using sentinel") + addrs = append(addrs, h) + } + return &sentinel.Sentinel{ + Addrs: addrs, + MasterName: master, + Dial: func(addr string) (redis.Conn, error) { // This timeout is recommended for Sentinel-support according to the guidelines. // https://redis.io/topics/sentinel-clients#redis-service-discovery-via-sentinel // For every address it should try to connect to the Sentinel, // using a short timeout (in the order of a few hundreds of milliseconds). - destination = "sentinel" - dialTimeout = 500 * time.Millisecond - } - - netDialer := &net.Dialer{ - Timeout: dialTimeout, - KeepAlive: 5 * time.Minute, - } + timeout := 500 * time.Millisecond + url := helper.URLMustParse(addr) + + var c redis.Conn + var err error + options := []redis.DialOption{ + redis.DialConnectTimeout(timeout), + redis.DialReadTimeout(timeout), + redis.DialWriteTimeout(timeout), + } - conn, err := netDialer.DialContext(ctx, network, addr) - if err != nil { - ErrorCounter.WithLabelValues("dial", destination).Inc() - } else { - if !isSentinel { - TotalConnections.Inc() + if url.Scheme == "redis" || url.Scheme == "rediss" { + c, err = redis.DialURL(addr, options...) + } else { + c, err = redis.Dial("tcp", url.Host, options...) 
} - } - return conn, err + if err != nil { + ErrorCounter.WithLabelValues("dial", "sentinel").Inc() + return nil, err + } + return c, nil + }, } } -// implements the redis.Hook interface for instrumentation -type sentinelInstrumentationHook struct{} - -func (s sentinelInstrumentationHook) DialHook(next redis.DialHook) redis.DialHook { - return func(ctx context.Context, network, addr string) (net.Conn, error) { - conn, err := next(ctx, network, addr) - if err != nil && err.Error() == errSentinelMasterAddr.Error() { - // check for non-dialer error - ErrorCounter.WithLabelValues("master", "sentinel").Inc() - } - return conn, err - } -} +var poolDialFunc func() (redis.Conn, error) +var workerDialFunc func() (redis.Conn, error) -func (s sentinelInstrumentationHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook { - return func(ctx context.Context, cmd redis.Cmder) error { - return next(ctx, cmd) +func timeoutDialOptions(cfg *config.RedisConfig) []redis.DialOption { + return []redis.DialOption{ + redis.DialReadTimeout(defaultReadTimeout), + redis.DialWriteTimeout(defaultWriteTimeout), } } -func (s sentinelInstrumentationHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook { - return func(ctx context.Context, cmds []redis.Cmder) error { - return next(ctx, cmds) +func dialOptionsBuilder(cfg *config.RedisConfig, setTimeouts bool) []redis.DialOption { + var dopts []redis.DialOption + if setTimeouts { + dopts = timeoutDialOptions(cfg) } -} - -func GetRedisClient() *redis.Client { - return rdb -} - -// Configure redis-connection -func Configure(cfg *config.RedisConfig) error { if cfg == nil { - return nil + return dopts } - - var err error - - if len(cfg.Sentinel) > 0 { - rdb = configureSentinel(cfg) - } else { - rdb, err = configureRedis(cfg) + if cfg.Password != "" { + dopts = append(dopts, redis.DialPassword(cfg.Password)) } - - return err + if cfg.DB != nil { + dopts = append(dopts, redis.DialDatabase(*cfg.DB)) + } + return dopts } -func configureRedis(cfg *config.RedisConfig) (*redis.Client, error) { - if cfg.URL.Scheme == "tcp" { - cfg.URL.Scheme = "redis" +func keepAliveDialer(network, address string) (net.Conn, error) { + addr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err } - - opt, err := redis.ParseURL(cfg.URL.String()) + tc, err := net.DialTCP(network, nil, addr) if err != nil { return nil, err } + if err := tc.SetKeepAlive(true); err != nil { + return nil, err + } + if err := tc.SetKeepAlivePeriod(defaultKeepAlivePeriod); err != nil { + return nil, err + } + return tc, nil +} - opt.DB = getOrDefault(cfg.DB, 0) - opt.Password = cfg.Password - - opt.PoolSize = getOrDefault(cfg.MaxActive, defaultMaxActive) - opt.MaxIdleConns = getOrDefault(cfg.MaxIdle, defaultMaxIdle) - opt.ConnMaxIdleTime = defaultIdleTimeout - opt.ReadTimeout = defaultReadTimeout - opt.WriteTimeout = defaultWriteTimeout - - opt.Dialer = createDialer([]string{}) +type redisDialerFunc func() (redis.Conn, error) - return redis.NewClient(opt), nil +func sentinelDialer(dopts []redis.DialOption) redisDialerFunc { + return func() (redis.Conn, error) { + address, err := sntnl.MasterAddr() + if err != nil { + ErrorCounter.WithLabelValues("master", "sentinel").Inc() + return nil, err + } + dopts = append(dopts, redis.DialNetDial(keepAliveDialer)) + conn, err := redisDial("tcp", address, dopts...) 
+ if err != nil { + return nil, err + } + if !sentinel.TestRole(conn, "master") { + conn.Close() + return nil, fmt.Errorf("%s is not redis master", address) + } + return conn, nil + } } -func configureSentinel(cfg *config.RedisConfig) *redis.Client { - sentinelPassword, sentinels := sentinelOptions(cfg) - client := redis.NewFailoverClient(&redis.FailoverOptions{ - MasterName: cfg.SentinelMaster, - SentinelAddrs: sentinels, - Password: cfg.Password, - SentinelPassword: sentinelPassword, - DB: getOrDefault(cfg.DB, 0), +func defaultDialer(dopts []redis.DialOption, url url.URL) redisDialerFunc { + return func() (redis.Conn, error) { + if url.Scheme == "unix" { + return redisDial(url.Scheme, url.Path, dopts...) + } - PoolSize: getOrDefault(cfg.MaxActive, defaultMaxActive), - MaxIdleConns: getOrDefault(cfg.MaxIdle, defaultMaxIdle), - ConnMaxIdleTime: defaultIdleTimeout, + dopts = append(dopts, redis.DialNetDial(keepAliveDialer)) - ReadTimeout: defaultReadTimeout, - WriteTimeout: defaultWriteTimeout, + // redis.DialURL only works with redis[s]:// URLs + if url.Scheme == "redis" || url.Scheme == "rediss" { + return redisURLDial(url, dopts...) + } - Dialer: createDialer(sentinels), - }) + return redisDial(url.Scheme, url.Host, dopts...) + } +} - client.AddHook(sentinelInstrumentationHook{}) +func redisURLDial(url url.URL, options ...redis.DialOption) (redis.Conn, error) { + log.WithFields(log.Fields{ + "scheme": url.Scheme, + "address": url.Host, + }).Printf("redis: dialing") - return client + return redis.DialURL(url.String(), options...) } -// sentinelOptions extracts the sentinel password and addresses in <host>:<port> format -// the order of priority for the passwords is: SentinelPassword -> first password-in-url -func sentinelOptions(cfg *config.RedisConfig) (string, []string) { - sentinels := make([]string, len(cfg.Sentinel)) - sentinelPassword := cfg.SentinelPassword +func redisDial(network, address string, options ...redis.DialOption) (redis.Conn, error) { + log.WithFields(log.Fields{ + "network": network, + "address": address, + }).Printf("redis: dialing") - for i := range cfg.Sentinel { - sentinelDetails := cfg.Sentinel[i] - sentinels[i] = fmt.Sprintf("%s:%s", sentinelDetails.Hostname(), sentinelDetails.Port()) + return redis.Dial(network, address, options...) +} - if pw, exist := sentinelDetails.User.Password(); exist && len(sentinelPassword) == 0 { - // sets password using the first non-empty password - sentinelPassword = pw +func countDialer(dialer redisDialerFunc) redisDialerFunc { + return func() (redis.Conn, error) { + c, err := dialer() + if err != nil { + ErrorCounter.WithLabelValues("dial", "redis").Inc() + } else { + TotalConnections.Inc() } + return c, err + } +} + +// DefaultDialFunc should always used. Only exception is for unit-tests. 
+func DefaultDialFunc(cfg *config.RedisConfig, setReadTimeout bool) func() (redis.Conn, error) { + dopts := dialOptionsBuilder(cfg, setReadTimeout) + if sntnl != nil { + return countDialer(sentinelDialer(dopts)) + } + return countDialer(defaultDialer(dopts, cfg.URL.URL)) +} + +// Configure redis-connection +func Configure(cfg *config.RedisConfig, dialFunc func(*config.RedisConfig, bool) func() (redis.Conn, error)) { + if cfg == nil { + return } + maxIdle := defaultMaxIdle + if cfg.MaxIdle != nil { + maxIdle = *cfg.MaxIdle + } + maxActive := defaultMaxActive + if cfg.MaxActive != nil { + maxActive = *cfg.MaxActive + } + sntnl = sentinelConn(cfg.SentinelMaster, cfg.Sentinel) + workerDialFunc = dialFunc(cfg, false) + poolDialFunc = dialFunc(cfg, true) + pool = &redis.Pool{ + MaxIdle: maxIdle, // Keep at most X hot connections + MaxActive: maxActive, // Keep at most X live connections, 0 means unlimited + IdleTimeout: defaultIdleTimeout, // X time until an unused connection is closed + Dial: poolDialFunc, + Wait: true, + } +} - return sentinelPassword, sentinels +// Get a connection for the Redis-pool +func Get() redis.Conn { + if pool != nil { + return pool.Get() + } + return nil } -func getOrDefault(ptr *int, val int) int { - if ptr != nil { - return *ptr +// GetString fetches the value of a key in Redis as a string +func GetString(key string) (string, error) { + conn := Get() + if conn == nil { + return "", fmt.Errorf("redis: could not get connection from pool") } - return val + defer conn.Close() + + return redis.String(conn.Do("GET", key)) } diff --git a/workhorse/internal/redis/redis_test.go b/workhorse/internal/redis/redis_test.go index 6fd6ecbae11..64b3a842a54 100644 --- a/workhorse/internal/redis/redis_test.go +++ b/workhorse/internal/redis/redis_test.go @@ -1,18 +1,19 @@ package redis import ( - "context" "net" - "sync/atomic" "testing" + "time" + "github.com/gomodule/redigo/redis" + "github.com/rafaeljusto/redigomock/v3" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" ) -func mockRedisServer(t *testing.T, connectReceived *atomic.Value) string { +func mockRedisServer(t *testing.T, connectReceived *bool) string { ln, err := net.Listen("tcp", "127.0.0.1:0") require.Nil(t, err) @@ -21,67 +22,146 @@ func mockRedisServer(t *testing.T, connectReceived *atomic.Value) string { defer ln.Close() conn, err := ln.Accept() require.Nil(t, err) - connectReceived.Store(true) + *connectReceived = true conn.Write([]byte("OK\n")) }() return ln.Addr().String() } -func TestConfigureNoConfig(t *testing.T) { - rdb = nil - Configure(nil) - require.Nil(t, rdb, "rdb client should be nil") +// Setup a MockPool for Redis +// +// Returns a teardown-function and the mock-connection +func setupMockPool() (*redigomock.Conn, func()) { + conn := redigomock.NewConn() + cfg := &config.RedisConfig{URL: config.TomlURL{}} + Configure(cfg, func(_ *config.RedisConfig, _ bool) func() (redis.Conn, error) { + return func() (redis.Conn, error) { + return conn, nil + } + }) + return conn, func() { + pool = nil + } } -func TestConfigureValidConfigX(t *testing.T) { +func TestDefaultDialFunc(t *testing.T) { testCases := []struct { scheme string }{ { - scheme: "redis", + scheme: "tcp", }, { - scheme: "tcp", + scheme: "redis", }, } for _, tc := range testCases { t.Run(tc.scheme, func(t *testing.T) { - connectReceived := atomic.Value{} + connectReceived := false a := mockRedisServer(t, &connectReceived) parsedURL := 
helper.URLMustParse(tc.scheme + "://" + a) cfg := &config.RedisConfig{URL: config.TomlURL{URL: *parsedURL}} - Configure(cfg) + dialer := DefaultDialFunc(cfg, true) + conn, err := dialer() - require.NotNil(t, GetRedisClient().Conn(), "Pool should not be nil") + require.Nil(t, err) + conn.Receive() - // goredis initialise connections lazily - rdb.Ping(context.Background()) - require.True(t, connectReceived.Load().(bool)) - - rdb = nil + require.True(t, connectReceived) }) } } -func TestConnectToSentinel(t *testing.T) { +func TestConfigureNoConfig(t *testing.T) { + pool = nil + Configure(nil, nil) + require.Nil(t, pool, "Pool should be nil") +} + +func TestConfigureMinimalConfig(t *testing.T) { + cfg := &config.RedisConfig{URL: config.TomlURL{}, Password: ""} + Configure(cfg, DefaultDialFunc) + + require.NotNil(t, pool, "Pool should not be nil") + require.Equal(t, 1, pool.MaxIdle) + require.Equal(t, 1, pool.MaxActive) + require.Equal(t, 3*time.Minute, pool.IdleTimeout) + + pool = nil +} + +func TestConfigureFullConfig(t *testing.T) { + i, a := 4, 10 + cfg := &config.RedisConfig{ + URL: config.TomlURL{}, + Password: "", + MaxIdle: &i, + MaxActive: &a, + } + Configure(cfg, DefaultDialFunc) + + require.NotNil(t, pool, "Pool should not be nil") + require.Equal(t, i, pool.MaxIdle) + require.Equal(t, a, pool.MaxActive) + require.Equal(t, 3*time.Minute, pool.IdleTimeout) + + pool = nil +} + +func TestGetConnFail(t *testing.T) { + conn := Get() + require.Nil(t, conn, "Expected `conn` to be nil") +} + +func TestGetConnPass(t *testing.T) { + _, teardown := setupMockPool() + defer teardown() + conn := Get() + require.NotNil(t, conn, "Expected `conn` to be non-nil") +} + +func TestGetStringPass(t *testing.T) { + conn, teardown := setupMockPool() + defer teardown() + conn.Command("GET", "foobar").Expect("baz") + str, err := GetString("foobar") + + require.NoError(t, err, "Expected `err` to be nil") + var value string + require.IsType(t, value, str, "Expected value to be a string") + require.Equal(t, "baz", str, "Expected it to be equal") +} + +func TestGetStringFail(t *testing.T) { + _, err := GetString("foobar") + require.Error(t, err, "Expected error when not connected to redis") +} + +func TestSentinelConnNoSentinel(t *testing.T) { + s := sentinelConn("", []config.TomlURL{}) + + require.Nil(t, s, "Sentinel without urls should return nil") +} + +func TestSentinelConnDialURL(t *testing.T) { testCases := []struct { scheme string }{ { - scheme: "redis", + scheme: "tcp", }, { - scheme: "tcp", + scheme: "redis", }, } for _, tc := range testCases { t.Run(tc.scheme, func(t *testing.T) { - connectReceived := atomic.Value{} + connectReceived := false a := mockRedisServer(t, &connectReceived) addrs := []string{tc.scheme + "://" + a} @@ -92,71 +172,57 @@ func TestConnectToSentinel(t *testing.T) { sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL}) } - cfg := &config.RedisConfig{Sentinel: sentinelUrls} - Configure(cfg) + s := sentinelConn("foobar", sentinelUrls) + require.Equal(t, len(addrs), len(s.Addrs)) + + for i := range addrs { + require.Equal(t, addrs[i], s.Addrs[i]) + } - require.NotNil(t, GetRedisClient().Conn(), "Pool should not be nil") + conn, err := s.Dial(s.Addrs[0]) - // goredis initialise connections lazily - rdb.Ping(context.Background()) - require.True(t, connectReceived.Load().(bool)) + require.Nil(t, err) + conn.Receive() - rdb = nil + require.True(t, connectReceived) }) } } -func TestSentinelOptions(t *testing.T) { - testCases := []struct { - description string - 
inputSentinelPassword string - inputSentinel []string - password string - sentinels []string - }{ - { - description: "no sentinel passwords", - inputSentinel: []string{"tcp://localhost:26480"}, - sentinels: []string{"localhost:26480"}, - }, - { - description: "specific sentinel password defined", - inputSentinel: []string{"tcp://localhost:26480"}, - inputSentinelPassword: "password1", - sentinels: []string{"localhost:26480"}, - password: "password1", - }, - { - description: "specific sentinel password defined in url", - inputSentinel: []string{"tcp://:password2@localhost:26480", "tcp://:password3@localhost:26481"}, - sentinels: []string{"localhost:26480", "localhost:26481"}, - password: "password2", - }, - { - description: "passwords defined specifically and in url", - inputSentinel: []string{"tcp://:password2@localhost:26480", "tcp://:password3@localhost:26481"}, - sentinels: []string{"localhost:26480", "localhost:26481"}, - inputSentinelPassword: "password1", - password: "password1", - }, +func TestSentinelConnTwoURLs(t *testing.T) { + addrs := []string{"tcp://10.0.0.1:12345", "tcp://10.0.0.2:12345"} + var sentinelUrls []config.TomlURL + + for _, a := range addrs { + parsedURL := helper.URLMustParse(a) + sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL}) } - for _, tc := range testCases { - t.Run(tc.description, func(t *testing.T) { - sentinelUrls := make([]config.TomlURL, len(tc.inputSentinel)) + s := sentinelConn("foobar", sentinelUrls) + require.Equal(t, len(addrs), len(s.Addrs)) - for i, str := range tc.inputSentinel { - parsedURL := helper.URLMustParse(str) - sentinelUrls[i] = config.TomlURL{URL: *parsedURL} - } + for i := range addrs { + require.Equal(t, addrs[i], s.Addrs[i]) + } +} - outputPw, outputSentinels := sentinelOptions(&config.RedisConfig{ - Sentinel: sentinelUrls, - SentinelPassword: tc.inputSentinelPassword, - }) +func TestDialOptionsBuildersPassword(t *testing.T) { + dopts := dialOptionsBuilder(&config.RedisConfig{Password: "foo"}, false) + require.Equal(t, 1, len(dopts)) +} - require.Equal(t, tc.password, outputPw) - require.Equal(t, tc.sentinels, outputSentinels) - }) - } +func TestDialOptionsBuildersSetTimeouts(t *testing.T) { + dopts := dialOptionsBuilder(nil, true) + require.Equal(t, 2, len(dopts)) +} + +func TestDialOptionsBuildersSetTimeoutsConfig(t *testing.T) { + dopts := dialOptionsBuilder(nil, true) + require.Equal(t, 2, len(dopts)) +} + +func TestDialOptionsBuildersSelectDB(t *testing.T) { + db := 3 + dopts := dialOptionsBuilder(&config.RedisConfig{DB: &db}, false) + require.Equal(t, 1, len(dopts)) } diff --git a/workhorse/main.go b/workhorse/main.go index 3043ae50a22..9ba213d47d3 100644 --- a/workhorse/main.go +++ b/workhorse/main.go @@ -17,8 +17,10 @@ import ( "gitlab.com/gitlab-org/labkit/monitoring" "gitlab.com/gitlab-org/labkit/tracing" + "gitlab.com/gitlab-org/gitlab/workhorse/internal/builds" "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" "gitlab.com/gitlab-org/gitlab/workhorse/internal/gitaly" + "gitlab.com/gitlab-org/gitlab/workhorse/internal/goredis" "gitlab.com/gitlab-org/gitlab/workhorse/internal/queueing" "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis" "gitlab.com/gitlab-org/gitlab/workhorse/internal/secret" @@ -223,18 +225,34 @@ func run(boot bootConfig, cfg config.Config) error { secret.SetPath(boot.secretPath) - log.Info("Using redis/go-redis") + keyWatcher := redis.NewKeyWatcher() - redisKeyWatcher := redis.NewKeyWatcher() - if err := redis.Configure(cfg.Redis); err != nil { - 
log.WithError(err).Error("unable to configure redis client") - } + var watchKeyFn builds.WatchKeyHandler + var goredisKeyWatcher *goredis.KeyWatcher - if rdb := redis.GetRedisClient(); rdb != nil { - go redisKeyWatcher.Process(rdb) - } + if os.Getenv("GITLAB_WORKHORSE_FF_GO_REDIS_ENABLED") == "true" { + log.Info("Using redis/go-redis") + + goredisKeyWatcher = goredis.NewKeyWatcher() + if err := goredis.Configure(cfg.Redis); err != nil { + log.WithError(err).Error("unable to configure redis client") + } + + if rdb := goredis.GetRedisClient(); rdb != nil { + go goredisKeyWatcher.Process(rdb) + } - watchKeyFn := redisKeyWatcher.WatchKey + watchKeyFn = goredisKeyWatcher.WatchKey + } else { + log.Info("Using gomodule/redigo") + + if cfg.Redis != nil { + redis.Configure(cfg.Redis, redis.DefaultDialFunc) + go keyWatcher.Process() + } + + watchKeyFn = keyWatcher.WatchKey + } if err := cfg.RegisterGoCloudURLOpeners(); err != nil { return fmt.Errorf("register cloud credentials: %v", err) @@ -282,8 +300,11 @@ func run(boot bootConfig, cfg config.Config) error { ctx, cancel := context.WithTimeout(context.Background(), cfg.ShutdownTimeout.Duration) // lint:allow context.Background defer cancel() - redisKeyWatcher.Shutdown() + if goredisKeyWatcher != nil { + goredisKeyWatcher.Shutdown() + } + keyWatcher.Shutdown() return srv.Shutdown(ctx) } } |
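
For readers who want to see the new key-watcher API end to end, here is a minimal, hypothetical caller built only from functions that appear in the diff above (`config.LoadConfig`, `goredis.Configure`, `goredis.GetRedisClient`, `KeyWatcher.Process`/`WatchKey`/`Shutdown`, and the `WatchKeyStatus*` constants from `internal/redis`). The TOML fragment, key name, values, and the short sleep are illustrative assumptions, not part of the commit; real callers live elsewhere in Workhorse:

```go
package main

// Hypothetical, self-contained caller of the new goredis key watcher.
// Anything not shown in the diff above (the TOML snippet, the key name,
// the values, the sleep) is an illustrative assumption.

import (
	"context"
	"fmt"
	"time"

	"gitlab.com/gitlab-org/gitlab/workhorse/internal/config"
	"gitlab.com/gitlab-org/gitlab/workhorse/internal/goredis"
	internalredis "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis"
)

func main() {
	// Assumed config.toml fragment; Sentinel setups would use the
	// Sentinel/SentinelMaster settings instead of a single URL.
	cfg, err := config.LoadConfig(`
[redis]
URL = "redis://localhost:6379"
`)
	if err != nil {
		panic(err)
	}

	// Build the go-redis client and start the pubsub process loop,
	// mirroring what main.go does when the feature flag is enabled.
	if err := goredis.Configure(cfg.Redis); err != nil {
		panic(err)
	}
	kw := goredis.NewKeyWatcher()
	defer kw.Shutdown()
	go kw.Process(goredis.GetRedisClient())

	// Crude wait for the pubsub connection; WatchKey fails fast with
	// "no redis connection" until receivePubSubStream has stored kw.conn.
	time.Sleep(100 * time.Millisecond)

	// Block for up to 30 seconds, or until the value stored at the key
	// no longer equals "pending".
	status, err := kw.WatchKey(context.Background(), "runner:build_queue:10", "pending", 30*time.Second)
	if err != nil {
		panic(err)
	}

	switch status {
	case internalredis.WatchKeyStatusAlreadyChanged, internalredis.WatchKeyStatusSeenChange:
		fmt.Println("key changed")
	case internalredis.WatchKeyStatusTimeout:
		fmt.Println("timed out waiting for a change")
	default:
		fmt.Println("no change seen")
	}
}
```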