diff options
Diffstat (limited to 'workhorse/internal/redis/redis.go')
-rw-r--r-- | workhorse/internal/redis/redis.go | 336 |
1 files changed, 138 insertions, 198 deletions
diff --git a/workhorse/internal/redis/redis.go b/workhorse/internal/redis/redis.go index c79e1e56b3a..b528255d25b 100644 --- a/workhorse/internal/redis/redis.go +++ b/workhorse/internal/redis/redis.go @@ -1,24 +1,39 @@ package redis import ( + "context" + "errors" "fmt" "net" - "net/url" "time" - "github.com/FZambia/sentinel" - "github.com/gomodule/redigo/redis" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" - "gitlab.com/gitlab-org/labkit/log" + redis "github.com/redis/go-redis/v9" "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" - "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" + _ "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" ) var ( - pool *redis.Pool - sntnl *sentinel.Sentinel + rdb *redis.Client + // found in https://github.com/redis/go-redis/blob/c7399b6a17d7d3e2a57654528af91349f2468529/sentinel.go#L626 + errSentinelMasterAddr error = errors.New("redis: all sentinels specified in configuration are unreachable") + + TotalConnections = promauto.NewCounter( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_redis_total_connections", + Help: "How many connections gitlab-workhorse has opened in total. Can be used to track Redis connection rate for this process", + }, + ) + + ErrorCounter = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_redis_errors", + Help: "Counts different types of Redis errors encountered by workhorse, by type and destination (redis, sentinel)", + }, + []string{"type", "dst"}, + ) ) const ( @@ -36,241 +51,166 @@ const ( // If you _actually_ hit this timeout often, you should consider turning of // redis-support since it's not necessary at that point... defaultIdleTimeout = 3 * time.Minute - // KeepAlivePeriod is to keep a TCP connection open for an extended period of - // time without being killed. This is used both in the pool, and in the - // worker-connection. - // See https://en.wikipedia.org/wiki/Keepalive#TCP_keepalive for more - // information. - defaultKeepAlivePeriod = 5 * time.Minute ) -var ( - TotalConnections = promauto.NewCounter( - prometheus.CounterOpts{ - Name: "gitlab_workhorse_redis_total_connections", - Help: "How many connections gitlab-workhorse has opened in total. Can be used to track Redis connection rate for this process", - }, - ) - - ErrorCounter = promauto.NewCounterVec( - prometheus.CounterOpts{ - Name: "gitlab_workhorse_redis_errors", - Help: "Counts different types of Redis errors encountered by workhorse, by type and destination (redis, sentinel)", - }, - []string{"type", "dst"}, - ) -) +// createDialer references https://github.com/redis/go-redis/blob/b1103e3d436b6fe98813ecbbe1f99dc8d59b06c9/options.go#L214 +// it intercepts the error and tracks it via a Prometheus counter +func createDialer(sentinels []string) func(ctx context.Context, network, addr string) (net.Conn, error) { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + var isSentinel bool + for _, sentinelAddr := range sentinels { + if sentinelAddr == addr { + isSentinel = true + break + } + } -func sentinelConn(master string, urls []config.TomlURL) *sentinel.Sentinel { - if len(urls) == 0 { - return nil - } - var addrs []string - for _, url := range urls { - h := url.URL.String() - log.WithFields(log.Fields{ - "scheme": url.URL.Scheme, - "host": url.URL.Host, - }).Printf("redis: using sentinel") - addrs = append(addrs, h) - } - return &sentinel.Sentinel{ - Addrs: addrs, - MasterName: master, - Dial: func(addr string) (redis.Conn, error) { + dialTimeout := 5 * time.Second // go-redis default + destination := "redis" + if isSentinel { // This timeout is recommended for Sentinel-support according to the guidelines. // https://redis.io/topics/sentinel-clients#redis-service-discovery-via-sentinel // For every address it should try to connect to the Sentinel, // using a short timeout (in the order of a few hundreds of milliseconds). - timeout := 500 * time.Millisecond - url := helper.URLMustParse(addr) - - var c redis.Conn - var err error - options := []redis.DialOption{ - redis.DialConnectTimeout(timeout), - redis.DialReadTimeout(timeout), - redis.DialWriteTimeout(timeout), - } + destination = "sentinel" + dialTimeout = 500 * time.Millisecond + } - if url.Scheme == "redis" || url.Scheme == "rediss" { - c, err = redis.DialURL(addr, options...) - } else { - c, err = redis.Dial("tcp", url.Host, options...) - } + netDialer := &net.Dialer{ + Timeout: dialTimeout, + KeepAlive: 5 * time.Minute, + } - if err != nil { - ErrorCounter.WithLabelValues("dial", "sentinel").Inc() - return nil, err + conn, err := netDialer.DialContext(ctx, network, addr) + if err != nil { + ErrorCounter.WithLabelValues("dial", destination).Inc() + } else { + if !isSentinel { + TotalConnections.Inc() } - return c, nil - }, + } + + return conn, err } } -var poolDialFunc func() (redis.Conn, error) -var workerDialFunc func() (redis.Conn, error) +// implements the redis.Hook interface for instrumentation +type sentinelInstrumentationHook struct{} -func timeoutDialOptions(cfg *config.RedisConfig) []redis.DialOption { - return []redis.DialOption{ - redis.DialReadTimeout(defaultReadTimeout), - redis.DialWriteTimeout(defaultWriteTimeout), +func (s sentinelInstrumentationHook) DialHook(next redis.DialHook) redis.DialHook { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + conn, err := next(ctx, network, addr) + if err != nil && err.Error() == errSentinelMasterAddr.Error() { + // check for non-dialer error + ErrorCounter.WithLabelValues("master", "sentinel").Inc() + } + return conn, err } } -func dialOptionsBuilder(cfg *config.RedisConfig, setTimeouts bool) []redis.DialOption { - var dopts []redis.DialOption - if setTimeouts { - dopts = timeoutDialOptions(cfg) +func (s sentinelInstrumentationHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook { + return func(ctx context.Context, cmd redis.Cmder) error { + return next(ctx, cmd) } - if cfg == nil { - return dopts +} + +func (s sentinelInstrumentationHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + return next(ctx, cmds) } - if cfg.Password != "" { - dopts = append(dopts, redis.DialPassword(cfg.Password)) +} + +func GetRedisClient() *redis.Client { + return rdb +} + +// Configure redis-connection +func Configure(cfg *config.RedisConfig) error { + if cfg == nil { + return nil } - if cfg.DB != nil { - dopts = append(dopts, redis.DialDatabase(*cfg.DB)) + + var err error + + if len(cfg.Sentinel) > 0 { + rdb = configureSentinel(cfg) + } else { + rdb, err = configureRedis(cfg) } - return dopts + + return err } -func keepAliveDialer(network, address string) (net.Conn, error) { - addr, err := net.ResolveTCPAddr(network, address) - if err != nil { - return nil, err +func configureRedis(cfg *config.RedisConfig) (*redis.Client, error) { + if cfg.URL.Scheme == "tcp" { + cfg.URL.Scheme = "redis" } - tc, err := net.DialTCP(network, nil, addr) + + opt, err := redis.ParseURL(cfg.URL.String()) if err != nil { return nil, err } - if err := tc.SetKeepAlive(true); err != nil { - return nil, err - } - if err := tc.SetKeepAlivePeriod(defaultKeepAlivePeriod); err != nil { - return nil, err - } - return tc, nil -} -type redisDialerFunc func() (redis.Conn, error) + opt.DB = getOrDefault(cfg.DB, 0) + opt.Password = cfg.Password -func sentinelDialer(dopts []redis.DialOption) redisDialerFunc { - return func() (redis.Conn, error) { - address, err := sntnl.MasterAddr() - if err != nil { - ErrorCounter.WithLabelValues("master", "sentinel").Inc() - return nil, err - } - dopts = append(dopts, redis.DialNetDial(keepAliveDialer)) - conn, err := redisDial("tcp", address, dopts...) - if err != nil { - return nil, err - } - if !sentinel.TestRole(conn, "master") { - conn.Close() - return nil, fmt.Errorf("%s is not redis master", address) - } - return conn, nil - } -} + opt.PoolSize = getOrDefault(cfg.MaxActive, defaultMaxActive) + opt.MaxIdleConns = getOrDefault(cfg.MaxIdle, defaultMaxIdle) + opt.ConnMaxIdleTime = defaultIdleTimeout + opt.ReadTimeout = defaultReadTimeout + opt.WriteTimeout = defaultWriteTimeout -func defaultDialer(dopts []redis.DialOption, url url.URL) redisDialerFunc { - return func() (redis.Conn, error) { - if url.Scheme == "unix" { - return redisDial(url.Scheme, url.Path, dopts...) - } + opt.Dialer = createDialer([]string{}) - dopts = append(dopts, redis.DialNetDial(keepAliveDialer)) + return redis.NewClient(opt), nil +} - // redis.DialURL only works with redis[s]:// URLs - if url.Scheme == "redis" || url.Scheme == "rediss" { - return redisURLDial(url, dopts...) - } +func configureSentinel(cfg *config.RedisConfig) *redis.Client { + sentinelPassword, sentinels := sentinelOptions(cfg) + client := redis.NewFailoverClient(&redis.FailoverOptions{ + MasterName: cfg.SentinelMaster, + SentinelAddrs: sentinels, + Password: cfg.Password, + SentinelPassword: sentinelPassword, + DB: getOrDefault(cfg.DB, 0), - return redisDial(url.Scheme, url.Host, dopts...) - } -} + PoolSize: getOrDefault(cfg.MaxActive, defaultMaxActive), + MaxIdleConns: getOrDefault(cfg.MaxIdle, defaultMaxIdle), + ConnMaxIdleTime: defaultIdleTimeout, -func redisURLDial(url url.URL, options ...redis.DialOption) (redis.Conn, error) { - log.WithFields(log.Fields{ - "scheme": url.Scheme, - "address": url.Host, - }).Printf("redis: dialing") + ReadTimeout: defaultReadTimeout, + WriteTimeout: defaultWriteTimeout, - return redis.DialURL(url.String(), options...) -} + Dialer: createDialer(sentinels), + }) -func redisDial(network, address string, options ...redis.DialOption) (redis.Conn, error) { - log.WithFields(log.Fields{ - "network": network, - "address": address, - }).Printf("redis: dialing") + client.AddHook(sentinelInstrumentationHook{}) - return redis.Dial(network, address, options...) + return client } -func countDialer(dialer redisDialerFunc) redisDialerFunc { - return func() (redis.Conn, error) { - c, err := dialer() - if err != nil { - ErrorCounter.WithLabelValues("dial", "redis").Inc() - } else { - TotalConnections.Inc() - } - return c, err - } -} +// sentinelOptions extracts the sentinel password and addresses in <host>:<port> format +// the order of priority for the passwords is: SentinelPassword -> first password-in-url +func sentinelOptions(cfg *config.RedisConfig) (string, []string) { + sentinels := make([]string, len(cfg.Sentinel)) + sentinelPassword := cfg.SentinelPassword -// DefaultDialFunc should always used. Only exception is for unit-tests. -func DefaultDialFunc(cfg *config.RedisConfig, setReadTimeout bool) func() (redis.Conn, error) { - dopts := dialOptionsBuilder(cfg, setReadTimeout) - if sntnl != nil { - return countDialer(sentinelDialer(dopts)) - } - return countDialer(defaultDialer(dopts, cfg.URL.URL)) -} + for i := range cfg.Sentinel { + sentinelDetails := cfg.Sentinel[i] + sentinels[i] = fmt.Sprintf("%s:%s", sentinelDetails.Hostname(), sentinelDetails.Port()) -// Configure redis-connection -func Configure(cfg *config.RedisConfig, dialFunc func(*config.RedisConfig, bool) func() (redis.Conn, error)) { - if cfg == nil { - return - } - maxIdle := defaultMaxIdle - if cfg.MaxIdle != nil { - maxIdle = *cfg.MaxIdle - } - maxActive := defaultMaxActive - if cfg.MaxActive != nil { - maxActive = *cfg.MaxActive - } - sntnl = sentinelConn(cfg.SentinelMaster, cfg.Sentinel) - workerDialFunc = dialFunc(cfg, false) - poolDialFunc = dialFunc(cfg, true) - pool = &redis.Pool{ - MaxIdle: maxIdle, // Keep at most X hot connections - MaxActive: maxActive, // Keep at most X live connections, 0 means unlimited - IdleTimeout: defaultIdleTimeout, // X time until an unused connection is closed - Dial: poolDialFunc, - Wait: true, + if pw, exist := sentinelDetails.User.Password(); exist && len(sentinelPassword) == 0 { + // sets password using the first non-empty password + sentinelPassword = pw + } } -} -// Get a connection for the Redis-pool -func Get() redis.Conn { - if pool != nil { - return pool.Get() - } - return nil + return sentinelPassword, sentinels } -// GetString fetches the value of a key in Redis as a string -func GetString(key string) (string, error) { - conn := Get() - if conn == nil { - return "", fmt.Errorf("redis: could not get connection from pool") +func getOrDefault(ptr *int, val int) int { + if ptr != nil { + return *ptr } - defer conn.Close() - - return redis.String(conn.Do("GET", key)) + return val } |