Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-foss.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'workhorse/internal/redis')
-rw-r--r--workhorse/internal/redis/keywatcher.go198
-rw-r--r--workhorse/internal/redis/keywatcher_test.go162
-rw-r--r--workhorse/internal/redis/redis.go295
-rw-r--r--workhorse/internal/redis/redis_test.go234
4 files changed, 889 insertions, 0 deletions
diff --git a/workhorse/internal/redis/keywatcher.go b/workhorse/internal/redis/keywatcher.go
new file mode 100644
index 00000000000..96e33a64b85
--- /dev/null
+++ b/workhorse/internal/redis/keywatcher.go
@@ -0,0 +1,198 @@
+package redis
+
+import (
+ "fmt"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/gomodule/redigo/redis"
+ "github.com/jpillora/backoff"
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+ "gitlab.com/gitlab-org/labkit/log"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+)
+
+var (
+ keyWatcher = make(map[string][]chan string)
+ keyWatcherMutex sync.Mutex
+ redisReconnectTimeout = backoff.Backoff{
+ //These are the defaults
+ Min: 100 * time.Millisecond,
+ Max: 60 * time.Second,
+ Factor: 2,
+ Jitter: true,
+ }
+ keyWatchers = promauto.NewGauge(
+ prometheus.GaugeOpts{
+ Name: "gitlab_workhorse_keywatcher_keywatchers",
+ Help: "The number of keys that is being watched by gitlab-workhorse",
+ },
+ )
+ totalMessages = promauto.NewCounter(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_keywatcher_total_messages",
+ Help: "How many messages gitlab-workhorse has received in total on pubsub.",
+ },
+ )
+)
+
+const (
+ keySubChannel = "workhorse:notifications"
+)
+
+// KeyChan holds a key and a channel
+type KeyChan struct {
+ Key string
+ Chan chan string
+}
+
+func processInner(conn redis.Conn) error {
+ defer conn.Close()
+ psc := redis.PubSubConn{Conn: conn}
+ if err := psc.Subscribe(keySubChannel); err != nil {
+ return err
+ }
+ defer psc.Unsubscribe(keySubChannel)
+
+ for {
+ switch v := psc.Receive().(type) {
+ case redis.Message:
+ totalMessages.Inc()
+ dataStr := string(v.Data)
+ msg := strings.SplitN(dataStr, "=", 2)
+ if len(msg) != 2 {
+ helper.LogError(nil, fmt.Errorf("keywatcher: invalid notification: %q", dataStr))
+ continue
+ }
+ key, value := msg[0], msg[1]
+ notifyChanWatchers(key, value)
+ case error:
+ helper.LogError(nil, fmt.Errorf("keywatcher: pubsub receive: %v", v))
+ // Intermittent error, return nil so that it doesn't wait before reconnect
+ return nil
+ }
+ }
+}
+
+func dialPubSub(dialer redisDialerFunc) (redis.Conn, error) {
+ conn, err := dialer()
+ if err != nil {
+ return nil, err
+ }
+
+ // Make sure Redis is actually connected
+ conn.Do("PING")
+ if err := conn.Err(); err != nil {
+ conn.Close()
+ return nil, err
+ }
+
+ return conn, nil
+}
+
+// Process redis subscriptions
+//
+// NOTE: There Can Only Be One!
+func Process() {
+ log.Info("keywatcher: starting process loop")
+ for {
+ conn, err := dialPubSub(workerDialFunc)
+ if err != nil {
+ helper.LogError(nil, fmt.Errorf("keywatcher: %v", err))
+ time.Sleep(redisReconnectTimeout.Duration())
+ continue
+ }
+ redisReconnectTimeout.Reset()
+
+ if err = processInner(conn); err != nil {
+ helper.LogError(nil, fmt.Errorf("keywatcher: process loop: %v", err))
+ }
+ }
+}
+
+func notifyChanWatchers(key, value string) {
+ keyWatcherMutex.Lock()
+ defer keyWatcherMutex.Unlock()
+ if chanList, ok := keyWatcher[key]; ok {
+ for _, c := range chanList {
+ c <- value
+ keyWatchers.Dec()
+ }
+ delete(keyWatcher, key)
+ }
+}
+
+func addKeyChan(kc *KeyChan) {
+ keyWatcherMutex.Lock()
+ defer keyWatcherMutex.Unlock()
+ keyWatcher[kc.Key] = append(keyWatcher[kc.Key], kc.Chan)
+ keyWatchers.Inc()
+}
+
+func delKeyChan(kc *KeyChan) {
+ keyWatcherMutex.Lock()
+ defer keyWatcherMutex.Unlock()
+ if chans, ok := keyWatcher[kc.Key]; ok {
+ for i, c := range chans {
+ if kc.Chan == c {
+ keyWatcher[kc.Key] = append(chans[:i], chans[i+1:]...)
+ keyWatchers.Dec()
+ break
+ }
+ }
+ if len(keyWatcher[kc.Key]) == 0 {
+ delete(keyWatcher, kc.Key)
+ }
+ }
+}
+
+// WatchKeyStatus is used to tell how WatchKey returned
+type WatchKeyStatus int
+
+const (
+ // WatchKeyStatusTimeout is returned when the watch timeout provided by the caller was exceeded
+ WatchKeyStatusTimeout WatchKeyStatus = iota
+ // WatchKeyStatusAlreadyChanged is returned when the value passed by the caller was never observed
+ WatchKeyStatusAlreadyChanged
+ // WatchKeyStatusSeenChange is returned when we have seen the value passed by the caller get changed
+ WatchKeyStatusSeenChange
+ // WatchKeyStatusNoChange is returned when the function had to return before observing a change.
+ // Also returned on errors.
+ WatchKeyStatusNoChange
+)
+
+// WatchKey waits for a key to be updated or expired
+func WatchKey(key, value string, timeout time.Duration) (WatchKeyStatus, error) {
+ kw := &KeyChan{
+ Key: key,
+ Chan: make(chan string, 1),
+ }
+
+ addKeyChan(kw)
+ defer delKeyChan(kw)
+
+ currentValue, err := GetString(key)
+ if err != nil {
+ return WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET: %v", err)
+ }
+ if currentValue != value {
+ return WatchKeyStatusAlreadyChanged, nil
+ }
+
+ select {
+ case currentValue := <-kw.Chan:
+ if currentValue == "" {
+ return WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET failed")
+ }
+ if currentValue == value {
+ return WatchKeyStatusNoChange, nil
+ }
+ return WatchKeyStatusSeenChange, nil
+
+ case <-time.After(timeout):
+ return WatchKeyStatusTimeout, nil
+ }
+}
diff --git a/workhorse/internal/redis/keywatcher_test.go b/workhorse/internal/redis/keywatcher_test.go
new file mode 100644
index 00000000000..f1ee77e2194
--- /dev/null
+++ b/workhorse/internal/redis/keywatcher_test.go
@@ -0,0 +1,162 @@
+package redis
+
+import (
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/rafaeljusto/redigomock"
+ "github.com/stretchr/testify/require"
+)
+
+const (
+ runnerKey = "runner:build_queue:10"
+)
+
+func createSubscriptionMessage(key, data string) []interface{} {
+ return []interface{}{
+ []byte("message"),
+ []byte(key),
+ []byte(data),
+ }
+}
+
+func createSubscribeMessage(key string) []interface{} {
+ return []interface{}{
+ []byte("subscribe"),
+ []byte(key),
+ []byte("1"),
+ }
+}
+func createUnsubscribeMessage(key string) []interface{} {
+ return []interface{}{
+ []byte("unsubscribe"),
+ []byte(key),
+ []byte("1"),
+ }
+}
+
+func countWatchers(key string) int {
+ keyWatcherMutex.Lock()
+ defer keyWatcherMutex.Unlock()
+ return len(keyWatcher[key])
+}
+
+func deleteWatchers(key string) {
+ keyWatcherMutex.Lock()
+ defer keyWatcherMutex.Unlock()
+ delete(keyWatcher, key)
+}
+
+// Forces a run of the `Process` loop against a mock PubSubConn.
+func processMessages(numWatchers int, value string) {
+ psc := redigomock.NewConn()
+
+ // Setup the initial subscription message
+ psc.Command("SUBSCRIBE", keySubChannel).Expect(createSubscribeMessage(keySubChannel))
+ psc.Command("UNSUBSCRIBE", keySubChannel).Expect(createUnsubscribeMessage(keySubChannel))
+ psc.AddSubscriptionMessage(createSubscriptionMessage(keySubChannel, runnerKey+"="+value))
+
+ // Wait for all the `WatchKey` calls to be registered
+ for countWatchers(runnerKey) != numWatchers {
+ time.Sleep(time.Millisecond)
+ }
+
+ processInner(psc)
+}
+
+func TestWatchKeySeenChange(t *testing.T) {
+ conn, td := setupMockPool()
+ defer td()
+
+ conn.Command("GET", runnerKey).Expect("something")
+
+ wg := &sync.WaitGroup{}
+ wg.Add(1)
+
+ go func() {
+ val, err := WatchKey(runnerKey, "something", time.Second)
+ require.NoError(t, err, "Expected no error")
+ require.Equal(t, WatchKeyStatusSeenChange, val, "Expected value to change")
+ wg.Done()
+ }()
+
+ processMessages(1, "somethingelse")
+ wg.Wait()
+}
+
+func TestWatchKeyNoChange(t *testing.T) {
+ conn, td := setupMockPool()
+ defer td()
+
+ conn.Command("GET", runnerKey).Expect("something")
+
+ wg := &sync.WaitGroup{}
+ wg.Add(1)
+
+ go func() {
+ val, err := WatchKey(runnerKey, "something", time.Second)
+ require.NoError(t, err, "Expected no error")
+ require.Equal(t, WatchKeyStatusNoChange, val, "Expected notification without change to value")
+ wg.Done()
+ }()
+
+ processMessages(1, "something")
+ wg.Wait()
+}
+
+func TestWatchKeyTimeout(t *testing.T) {
+ conn, td := setupMockPool()
+ defer td()
+
+ conn.Command("GET", runnerKey).Expect("something")
+
+ val, err := WatchKey(runnerKey, "something", time.Millisecond)
+ require.NoError(t, err, "Expected no error")
+ require.Equal(t, WatchKeyStatusTimeout, val, "Expected value to not change")
+
+ // Clean up watchers since Process isn't doing that for us (not running)
+ deleteWatchers(runnerKey)
+}
+
+func TestWatchKeyAlreadyChanged(t *testing.T) {
+ conn, td := setupMockPool()
+ defer td()
+
+ conn.Command("GET", runnerKey).Expect("somethingelse")
+
+ val, err := WatchKey(runnerKey, "something", time.Second)
+ require.NoError(t, err, "Expected no error")
+ require.Equal(t, WatchKeyStatusAlreadyChanged, val, "Expected value to have already changed")
+
+ // Clean up watchers since Process isn't doing that for us (not running)
+ deleteWatchers(runnerKey)
+}
+
+func TestWatchKeyMassivelyParallel(t *testing.T) {
+ runTimes := 100 // 100 parallel watchers
+
+ conn, td := setupMockPool()
+ defer td()
+
+ wg := &sync.WaitGroup{}
+ wg.Add(runTimes)
+
+ getCmd := conn.Command("GET", runnerKey)
+
+ for i := 0; i < runTimes; i++ {
+ getCmd = getCmd.Expect("something")
+ }
+
+ for i := 0; i < runTimes; i++ {
+ go func() {
+ val, err := WatchKey(runnerKey, "something", time.Second)
+ require.NoError(t, err, "Expected no error")
+ require.Equal(t, WatchKeyStatusSeenChange, val, "Expected value to change")
+ wg.Done()
+ }()
+ }
+
+ processMessages(runTimes, "somethingelse")
+ wg.Wait()
+}
diff --git a/workhorse/internal/redis/redis.go b/workhorse/internal/redis/redis.go
new file mode 100644
index 00000000000..0029a2a9e2b
--- /dev/null
+++ b/workhorse/internal/redis/redis.go
@@ -0,0 +1,295 @@
+package redis
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "net/url"
+ "time"
+
+ "github.com/FZambia/sentinel"
+ "github.com/gomodule/redigo/redis"
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+ "gitlab.com/gitlab-org/labkit/log"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+)
+
+var (
+ pool *redis.Pool
+ sntnl *sentinel.Sentinel
+)
+
+const (
+ // Max Idle Connections in the pool.
+ defaultMaxIdle = 1
+ // Max Active Connections in the pool.
+ defaultMaxActive = 1
+ // Timeout for Read operations on the pool. 1 second is technically overkill,
+ // it's just for sanity.
+ defaultReadTimeout = 1 * time.Second
+ // Timeout for Write operations on the pool. 1 second is technically overkill,
+ // it's just for sanity.
+ defaultWriteTimeout = 1 * time.Second
+ // Timeout before killing Idle connections in the pool. 3 minutes seemed good.
+ // If you _actually_ hit this timeout often, you should consider turning of
+ // redis-support since it's not necessary at that point...
+ defaultIdleTimeout = 3 * time.Minute
+ // KeepAlivePeriod is to keep a TCP connection open for an extended period of
+ // time without being killed. This is used both in the pool, and in the
+ // worker-connection.
+ // See https://en.wikipedia.org/wiki/Keepalive#TCP_keepalive for more
+ // information.
+ defaultKeepAlivePeriod = 5 * time.Minute
+)
+
+var (
+ totalConnections = promauto.NewCounter(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_redis_total_connections",
+ Help: "How many connections gitlab-workhorse has opened in total. Can be used to track Redis connection rate for this process",
+ },
+ )
+
+ errorCounter = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gitlab_workhorse_redis_errors",
+ Help: "Counts different types of Redis errors encountered by workhorse, by type and destination (redis, sentinel)",
+ },
+ []string{"type", "dst"},
+ )
+)
+
+func sentinelConn(master string, urls []config.TomlURL) *sentinel.Sentinel {
+ if len(urls) == 0 {
+ return nil
+ }
+ var addrs []string
+ for _, url := range urls {
+ h := url.URL.String()
+ log.WithFields(log.Fields{
+ "scheme": url.URL.Scheme,
+ "host": url.URL.Host,
+ }).Printf("redis: using sentinel")
+ addrs = append(addrs, h)
+ }
+ return &sentinel.Sentinel{
+ Addrs: addrs,
+ MasterName: master,
+ Dial: func(addr string) (redis.Conn, error) {
+ // This timeout is recommended for Sentinel-support according to the guidelines.
+ // https://redis.io/topics/sentinel-clients#redis-service-discovery-via-sentinel
+ // For every address it should try to connect to the Sentinel,
+ // using a short timeout (in the order of a few hundreds of milliseconds).
+ timeout := 500 * time.Millisecond
+ url := helper.URLMustParse(addr)
+
+ var c redis.Conn
+ var err error
+ options := []redis.DialOption{
+ redis.DialConnectTimeout(timeout),
+ redis.DialReadTimeout(timeout),
+ redis.DialWriteTimeout(timeout),
+ }
+
+ if url.Scheme == "redis" || url.Scheme == "rediss" {
+ c, err = redis.DialURL(addr, options...)
+ } else {
+ c, err = redis.Dial("tcp", url.Host, options...)
+ }
+
+ if err != nil {
+ errorCounter.WithLabelValues("dial", "sentinel").Inc()
+ return nil, err
+ }
+ return c, nil
+ },
+ }
+}
+
+var poolDialFunc func() (redis.Conn, error)
+var workerDialFunc func() (redis.Conn, error)
+
+func timeoutDialOptions(cfg *config.RedisConfig) []redis.DialOption {
+ readTimeout := defaultReadTimeout
+ writeTimeout := defaultWriteTimeout
+
+ if cfg != nil {
+ if cfg.ReadTimeout != nil {
+ readTimeout = cfg.ReadTimeout.Duration
+ }
+
+ if cfg.WriteTimeout != nil {
+ writeTimeout = cfg.WriteTimeout.Duration
+ }
+ }
+ return []redis.DialOption{
+ redis.DialReadTimeout(readTimeout),
+ redis.DialWriteTimeout(writeTimeout),
+ }
+}
+
+func dialOptionsBuilder(cfg *config.RedisConfig, setTimeouts bool) []redis.DialOption {
+ var dopts []redis.DialOption
+ if setTimeouts {
+ dopts = timeoutDialOptions(cfg)
+ }
+ if cfg == nil {
+ return dopts
+ }
+ if cfg.Password != "" {
+ dopts = append(dopts, redis.DialPassword(cfg.Password))
+ }
+ if cfg.DB != nil {
+ dopts = append(dopts, redis.DialDatabase(*cfg.DB))
+ }
+ return dopts
+}
+
+func keepAliveDialer(timeout time.Duration) func(string, string) (net.Conn, error) {
+ return func(network, address string) (net.Conn, error) {
+ addr, err := net.ResolveTCPAddr(network, address)
+ if err != nil {
+ return nil, err
+ }
+ tc, err := net.DialTCP(network, nil, addr)
+ if err != nil {
+ return nil, err
+ }
+ if err := tc.SetKeepAlive(true); err != nil {
+ return nil, err
+ }
+ if err := tc.SetKeepAlivePeriod(timeout); err != nil {
+ return nil, err
+ }
+ return tc, nil
+ }
+}
+
+type redisDialerFunc func() (redis.Conn, error)
+
+func sentinelDialer(dopts []redis.DialOption, keepAlivePeriod time.Duration) redisDialerFunc {
+ return func() (redis.Conn, error) {
+ address, err := sntnl.MasterAddr()
+ if err != nil {
+ errorCounter.WithLabelValues("master", "sentinel").Inc()
+ return nil, err
+ }
+ dopts = append(dopts, redis.DialNetDial(keepAliveDialer(keepAlivePeriod)))
+ return redisDial("tcp", address, dopts...)
+ }
+}
+
+func defaultDialer(dopts []redis.DialOption, keepAlivePeriod time.Duration, url url.URL) redisDialerFunc {
+ return func() (redis.Conn, error) {
+ if url.Scheme == "unix" {
+ return redisDial(url.Scheme, url.Path, dopts...)
+ }
+
+ dopts = append(dopts, redis.DialNetDial(keepAliveDialer(keepAlivePeriod)))
+
+ // redis.DialURL only works with redis[s]:// URLs
+ if url.Scheme == "redis" || url.Scheme == "rediss" {
+ return redisURLDial(url, dopts...)
+ }
+
+ return redisDial(url.Scheme, url.Host, dopts...)
+ }
+}
+
+func redisURLDial(url url.URL, options ...redis.DialOption) (redis.Conn, error) {
+ log.WithFields(log.Fields{
+ "scheme": url.Scheme,
+ "address": url.Host,
+ }).Printf("redis: dialing")
+
+ return redis.DialURL(url.String(), options...)
+}
+
+func redisDial(network, address string, options ...redis.DialOption) (redis.Conn, error) {
+ log.WithFields(log.Fields{
+ "network": network,
+ "address": address,
+ }).Printf("redis: dialing")
+
+ return redis.Dial(network, address, options...)
+}
+
+func countDialer(dialer redisDialerFunc) redisDialerFunc {
+ return func() (redis.Conn, error) {
+ c, err := dialer()
+ if err != nil {
+ errorCounter.WithLabelValues("dial", "redis").Inc()
+ } else {
+ totalConnections.Inc()
+ }
+ return c, err
+ }
+}
+
+// DefaultDialFunc should always used. Only exception is for unit-tests.
+func DefaultDialFunc(cfg *config.RedisConfig, setReadTimeout bool) func() (redis.Conn, error) {
+ keepAlivePeriod := defaultKeepAlivePeriod
+ if cfg.KeepAlivePeriod != nil {
+ keepAlivePeriod = cfg.KeepAlivePeriod.Duration
+ }
+ dopts := dialOptionsBuilder(cfg, setReadTimeout)
+ if sntnl != nil {
+ return countDialer(sentinelDialer(dopts, keepAlivePeriod))
+ }
+ return countDialer(defaultDialer(dopts, keepAlivePeriod, cfg.URL.URL))
+}
+
+// Configure redis-connection
+func Configure(cfg *config.RedisConfig, dialFunc func(*config.RedisConfig, bool) func() (redis.Conn, error)) {
+ if cfg == nil {
+ return
+ }
+ maxIdle := defaultMaxIdle
+ if cfg.MaxIdle != nil {
+ maxIdle = *cfg.MaxIdle
+ }
+ maxActive := defaultMaxActive
+ if cfg.MaxActive != nil {
+ maxActive = *cfg.MaxActive
+ }
+ sntnl = sentinelConn(cfg.SentinelMaster, cfg.Sentinel)
+ workerDialFunc = dialFunc(cfg, false)
+ poolDialFunc = dialFunc(cfg, true)
+ pool = &redis.Pool{
+ MaxIdle: maxIdle, // Keep at most X hot connections
+ MaxActive: maxActive, // Keep at most X live connections, 0 means unlimited
+ IdleTimeout: defaultIdleTimeout, // X time until an unused connection is closed
+ Dial: poolDialFunc,
+ Wait: true,
+ }
+ if sntnl != nil {
+ pool.TestOnBorrow = func(c redis.Conn, t time.Time) error {
+ if !sentinel.TestRole(c, "master") {
+ return errors.New("role check failed")
+ }
+ return nil
+ }
+ }
+}
+
+// Get a connection for the Redis-pool
+func Get() redis.Conn {
+ if pool != nil {
+ return pool.Get()
+ }
+ return nil
+}
+
+// GetString fetches the value of a key in Redis as a string
+func GetString(key string) (string, error) {
+ conn := Get()
+ if conn == nil {
+ return "", fmt.Errorf("redis: could not get connection from pool")
+ }
+ defer conn.Close()
+
+ return redis.String(conn.Do("GET", key))
+}
diff --git a/workhorse/internal/redis/redis_test.go b/workhorse/internal/redis/redis_test.go
new file mode 100644
index 00000000000..f4b4120517d
--- /dev/null
+++ b/workhorse/internal/redis/redis_test.go
@@ -0,0 +1,234 @@
+package redis
+
+import (
+ "net"
+ "testing"
+ "time"
+
+ "github.com/gomodule/redigo/redis"
+ "github.com/rafaeljusto/redigomock"
+ "github.com/stretchr/testify/require"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/config"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+)
+
+func mockRedisServer(t *testing.T, connectReceived *bool) string {
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+
+ require.Nil(t, err)
+
+ go func() {
+ defer ln.Close()
+ conn, err := ln.Accept()
+ require.Nil(t, err)
+ *connectReceived = true
+ conn.Write([]byte("OK\n"))
+ }()
+
+ return ln.Addr().String()
+}
+
+// Setup a MockPool for Redis
+//
+// Returns a teardown-function and the mock-connection
+func setupMockPool() (*redigomock.Conn, func()) {
+ conn := redigomock.NewConn()
+ cfg := &config.RedisConfig{URL: config.TomlURL{}}
+ Configure(cfg, func(_ *config.RedisConfig, _ bool) func() (redis.Conn, error) {
+ return func() (redis.Conn, error) {
+ return conn, nil
+ }
+ })
+ return conn, func() {
+ pool = nil
+ }
+}
+
+func TestDefaultDialFunc(t *testing.T) {
+ testCases := []struct {
+ scheme string
+ }{
+ {
+ scheme: "tcp",
+ },
+ {
+ scheme: "redis",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.scheme, func(t *testing.T) {
+ connectReceived := false
+ a := mockRedisServer(t, &connectReceived)
+
+ parsedURL := helper.URLMustParse(tc.scheme + "://" + a)
+ cfg := &config.RedisConfig{URL: config.TomlURL{URL: *parsedURL}}
+
+ dialer := DefaultDialFunc(cfg, true)
+ conn, err := dialer()
+
+ require.Nil(t, err)
+ conn.Receive()
+
+ require.True(t, connectReceived)
+ })
+ }
+}
+
+func TestConfigureNoConfig(t *testing.T) {
+ pool = nil
+ Configure(nil, nil)
+ require.Nil(t, pool, "Pool should be nil")
+}
+
+func TestConfigureMinimalConfig(t *testing.T) {
+ cfg := &config.RedisConfig{URL: config.TomlURL{}, Password: ""}
+ Configure(cfg, DefaultDialFunc)
+
+ require.NotNil(t, pool, "Pool should not be nil")
+ require.Equal(t, 1, pool.MaxIdle)
+ require.Equal(t, 1, pool.MaxActive)
+ require.Equal(t, 3*time.Minute, pool.IdleTimeout)
+
+ pool = nil
+}
+
+func TestConfigureFullConfig(t *testing.T) {
+ i, a := 4, 10
+ r := config.TomlDuration{Duration: 3}
+ cfg := &config.RedisConfig{
+ URL: config.TomlURL{},
+ Password: "",
+ MaxIdle: &i,
+ MaxActive: &a,
+ ReadTimeout: &r,
+ }
+ Configure(cfg, DefaultDialFunc)
+
+ require.NotNil(t, pool, "Pool should not be nil")
+ require.Equal(t, i, pool.MaxIdle)
+ require.Equal(t, a, pool.MaxActive)
+ require.Equal(t, 3*time.Minute, pool.IdleTimeout)
+
+ pool = nil
+}
+
+func TestGetConnFail(t *testing.T) {
+ conn := Get()
+ require.Nil(t, conn, "Expected `conn` to be nil")
+}
+
+func TestGetConnPass(t *testing.T) {
+ _, teardown := setupMockPool()
+ defer teardown()
+ conn := Get()
+ require.NotNil(t, conn, "Expected `conn` to be non-nil")
+}
+
+func TestGetStringPass(t *testing.T) {
+ conn, teardown := setupMockPool()
+ defer teardown()
+ conn.Command("GET", "foobar").Expect("baz")
+ str, err := GetString("foobar")
+
+ require.NoError(t, err, "Expected `err` to be nil")
+ var value string
+ require.IsType(t, value, str, "Expected value to be a string")
+ require.Equal(t, "baz", str, "Expected it to be equal")
+}
+
+func TestGetStringFail(t *testing.T) {
+ _, err := GetString("foobar")
+ require.Error(t, err, "Expected error when not connected to redis")
+}
+
+func TestSentinelConnNoSentinel(t *testing.T) {
+ s := sentinelConn("", []config.TomlURL{})
+
+ require.Nil(t, s, "Sentinel without urls should return nil")
+}
+
+func TestSentinelConnDialURL(t *testing.T) {
+ testCases := []struct {
+ scheme string
+ }{
+ {
+ scheme: "tcp",
+ },
+ {
+ scheme: "redis",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.scheme, func(t *testing.T) {
+ connectReceived := false
+ a := mockRedisServer(t, &connectReceived)
+
+ addrs := []string{tc.scheme + "://" + a}
+ var sentinelUrls []config.TomlURL
+
+ for _, a := range addrs {
+ parsedURL := helper.URLMustParse(a)
+ sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL})
+ }
+
+ s := sentinelConn("foobar", sentinelUrls)
+ require.Equal(t, len(addrs), len(s.Addrs))
+
+ for i := range addrs {
+ require.Equal(t, addrs[i], s.Addrs[i])
+ }
+
+ conn, err := s.Dial(s.Addrs[0])
+
+ require.Nil(t, err)
+ conn.Receive()
+
+ require.True(t, connectReceived)
+ })
+ }
+}
+
+func TestSentinelConnTwoURLs(t *testing.T) {
+ addrs := []string{"tcp://10.0.0.1:12345", "tcp://10.0.0.2:12345"}
+ var sentinelUrls []config.TomlURL
+
+ for _, a := range addrs {
+ parsedURL := helper.URLMustParse(a)
+ sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL})
+ }
+
+ s := sentinelConn("foobar", sentinelUrls)
+ require.Equal(t, len(addrs), len(s.Addrs))
+
+ for i := range addrs {
+ require.Equal(t, addrs[i], s.Addrs[i])
+ }
+}
+
+func TestDialOptionsBuildersPassword(t *testing.T) {
+ dopts := dialOptionsBuilder(&config.RedisConfig{Password: "foo"}, false)
+ require.Equal(t, 1, len(dopts))
+}
+
+func TestDialOptionsBuildersSetTimeouts(t *testing.T) {
+ dopts := dialOptionsBuilder(nil, true)
+ require.Equal(t, 2, len(dopts))
+}
+
+func TestDialOptionsBuildersSetTimeoutsConfig(t *testing.T) {
+ cfg := &config.RedisConfig{
+ ReadTimeout: &config.TomlDuration{Duration: time.Second * time.Duration(15)},
+ WriteTimeout: &config.TomlDuration{Duration: time.Second * time.Duration(15)},
+ }
+ dopts := dialOptionsBuilder(cfg, true)
+ require.Equal(t, 2, len(dopts))
+}
+
+func TestDialOptionsBuildersSelectDB(t *testing.T) {
+ db := 3
+ dopts := dialOptionsBuilder(&config.RedisConfig{DB: &db}, false)
+ require.Equal(t, 1, len(dopts))
+}