diff options
-rw-r--r-- | internal/metadata/featureflag/cache.go | 121 | ||||
-rw-r--r-- | internal/metadata/featureflag/cache_test.go | 181 | ||||
-rw-r--r-- | internal/metadata/featureflag/featureflag.go | 8 | ||||
-rw-r--r-- | internal/metadata/featureflag/featureflag_test.go | 31 |
4 files changed, 341 insertions, 0 deletions
diff --git a/internal/metadata/featureflag/cache.go b/internal/metadata/featureflag/cache.go new file mode 100644 index 000000000..49a7d461d --- /dev/null +++ b/internal/metadata/featureflag/cache.go @@ -0,0 +1,121 @@ +package featureflag + +import ( + "context" + "sync" + "time" + + "github.com/sirupsen/logrus" + "gitlab.com/gitlab-org/gitaly/v15/internal/helper/tick" +) + +// Cache is an abstraction over feature flag management. +type Cache interface { + // Get returns true/false as a first return parameter is the feature flag is + // enabled/disabled. The second return parameter is true only if the feature + // flag was evaluated and false if it is not know (not in the Cache). + Get(context.Context, string) (bool, bool) +} + +// NoopCache is a dummy implementation of the FlagCache that is used as a stub. +type NoopCache struct{} + +// Get always returns the same result (false, false) to the caller. +func (NoopCache) Get(context.Context, string) (bool, bool) { + return false, false +} + +var ( + // flagCache is a storage of the feature flags evaluated somewhere and + // used to prevent potential performance troubles with triggering + // feature flag fetching too much often from the Provider. + flagCache Cache = NoopCache{} + flagCacheMtx sync.Mutex +) + +// SetCache sets a Cache for the feature flags. +func SetCache(new Cache) { + flagCacheMtx.Lock() + defer flagCacheMtx.Unlock() + flagCache = new +} + +// GetCache returns current Cache of the feature flags. +func GetCache() Cache { + flagCacheMtx.Lock() + defer flagCacheMtx.Unlock() + return flagCache +} + +// Provider is an abstraction that is able to return a set of feature flags +// in their current state. +type Provider interface { + // GetAll returns all known feature flags and their state. + GetAll(ctx context.Context) (map[string]bool, error) +} + +// RefreshableCache is a periodically refreshable cache for storing feature flags. +// To start auto-refresh the RefreshLoop method needs to be called. +type RefreshableCache struct { + mtx sync.RWMutex + logger logrus.FieldLogger + flags map[string]bool + provider Provider +} + +// NewRefreshableCache returns a new instance of the RefreshableCache that is already initialized. +func NewRefreshableCache( + ctx context.Context, + logger logrus.FieldLogger, + provider Provider, +) *RefreshableCache { + c := &RefreshableCache{logger: logger, provider: provider} + c.refresh(ctx) + return c +} + +// RefreshLoop is a blocking call that returns once passed in context is cancelled. +// It continuously reloads cache data. If data retrieval from the Provider returns +// an error the cache data remains the same without any changes. +func (rc *RefreshableCache) RefreshLoop(ctx context.Context, ticker tick.Ticker) { + ticker.Reset() + for { + select { + case <-ticker.C(): + rc.refresh(ctx) + ticker.Reset() + case <-ctx.Done(): + return + } + } +} + +// Get returns current cached value for the feature flag and true as a second return parameter. +// If flag is not found the both return values are false. +func (rc *RefreshableCache) Get(_ context.Context, name string) (bool, bool) { + rc.mtx.RLock() + defer rc.mtx.RUnlock() + val, found := rc.flags[name] + return val, found +} + +func (rc *RefreshableCache) refresh(ctx context.Context) { + ctx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() + + newFlags, err := rc.provider.GetAll(ctx) + if err != nil { + rc.logger.Errorf("failure on fetching the state of the feature flags: %v", err) + // In case of an issue with flags retrieval, proceed without updating cached data. + return + } + + flagsClone := make(map[string]bool, len(newFlags)) + for k, v := range newFlags { + flagsClone[k] = v + } + + rc.mtx.Lock() + defer rc.mtx.Unlock() + rc.flags = flagsClone +} diff --git a/internal/metadata/featureflag/cache_test.go b/internal/metadata/featureflag/cache_test.go new file mode 100644 index 000000000..fa4ad9060 --- /dev/null +++ b/internal/metadata/featureflag/cache_test.go @@ -0,0 +1,181 @@ +//go:build !gitaly_test_sha256 + +package featureflag + +import ( + "context" + "testing" + "time" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitaly/v15/internal/helper/tick" +) + +func TestGetCache(t *testing.T) { + defaultCache := GetCache() + require.IsType(t, NoopCache{}, defaultCache) +} + +func TestSetCache(t *testing.T) { + oldCache := GetCache() + t.Cleanup(func() { + SetCache(oldCache) + }) + SetCache(NewRefreshableCache(createContext(), logrus.New(), stubProvider{})) + require.NotEqual(t, oldCache, GetCache()) +} + +func TestRefreshableCache_Get(t *testing.T) { + t.Parallel() + ctx := createContext() + const flagName = "f1" + for _, tc := range []struct { + desc string + provider Provider + expValue bool + expFound bool + }{ + { + desc: "not found", + provider: stubProvider{}, + expValue: false, + expFound: false, + }, + { + desc: "found, disabled", + provider: stubProvider{getAll: func(context.Context) (map[string]bool, error) { + return map[string]bool{flagName: false}, nil + }}, + expValue: false, + expFound: true, + }, + { + desc: "found, enabled", + provider: stubProvider{getAll: func(context.Context) (map[string]bool, error) { + return map[string]bool{flagName: true}, nil + }}, + expValue: true, + expFound: true, + }, + { + desc: "error from provider on initialisation", + provider: stubProvider{getAll: func(context.Context) (map[string]bool, error) { + return nil, assert.AnError + }}, + expValue: false, + expFound: false, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + rcache := NewRefreshableCache(ctx, logrus.New(), tc.provider) + res, ok := rcache.Get(ctx, flagName) + require.Equal(t, tc.expFound, ok) + require.Equal(t, tc.expValue, res) + }) + } + + t.Run("error from provider on refresh", func(t *testing.T) { + var called int + provider := stubProvider{getAll: func(context.Context) (map[string]bool, error) { + t.Helper() + called++ + switch called { + case 1: + return map[string]bool{flagName: true}, nil + case 2: + return nil, assert.AnError + case 3: + require.FailNow(t, "unexpected GetAll call on the Provider") + } + return nil, nil + }} + rcache := NewRefreshableCache(ctx, logrus.New(), provider) + res1, _ := rcache.Get(ctx, flagName) + require.True(t, res1) + rcache.refresh(ctx) // we don't start background refresh that is why it is done manually + res2, _ := rcache.Get(ctx, flagName) + require.Equal(t, res1, res2, "the old value shouldn't change") + }) +} + +func TestRefreshableCache_RefreshLoop(t *testing.T) { + t.Parallel() + const flagName = "f1" + var called int + provider := stubProvider{getAll: func(context.Context) (map[string]bool, error) { + t.Helper() + called++ + switch called { + case 1: // initialisation fetch + return map[string]bool{flagName: false}, nil + case 2: // first refresh + return nil, assert.AnError + case 3: + return map[string]bool{flagName: true}, nil + case 4: + require.FailNow(t, "unexpected GetAll call on the Provider") + } + return nil, nil + }} + + refreshDone := make(chan any) + ticker := tick.NewManualTicker() + ticker.ResetFunc = func() { + refreshDone <- struct{}{} + } + + ctx := createContext() + refreshLoopCtx, cancel := context.WithCancel(ctx) + defer cancel() + + rcache := NewRefreshableCache(ctx, logrus.New(), provider) + // No refresh done, value after initialization. + res, ok := rcache.Get(ctx, flagName) + require.True(t, ok) + require.False(t, res) + + refreshLoopDone := make(chan any) + go func() { + defer close(refreshLoopDone) + rcache.RefreshLoop(refreshLoopCtx, ticker) + }() + // Consumption from the channel unblocks ticker.Reset() done before loop starts. + <-refreshDone + + // Refresh done, but because Provider returned an error, no data changes. + ticker.Tick() + <-refreshDone + res, ok = rcache.Get(ctx, flagName) + require.True(t, ok) + require.False(t, res) + + // Refresh done, Provider returned a new data. + ticker.Tick() + <-refreshDone + res, ok = rcache.Get(ctx, flagName) + require.True(t, ok) + require.True(t, res) + + cancel() + require.Eventually(t, func() bool { + select { + case <-refreshLoopDone: + return true + default: + return false + } + }, time.Second, time.Millisecond*10) +} + +type stubProvider struct { + getAll func(ctx context.Context) (map[string]bool, error) +} + +func (sp stubProvider) GetAll(ctx context.Context) (map[string]bool, error) { + if sp.getAll == nil { + return nil, nil + } + return sp.getAll(ctx) +} diff --git a/internal/metadata/featureflag/featureflag.go b/internal/metadata/featureflag/featureflag.go index 24481d108..a945a78e3 100644 --- a/internal/metadata/featureflag/featureflag.go +++ b/internal/metadata/featureflag/featureflag.go @@ -136,6 +136,10 @@ func (ff FeatureFlag) IsEnabled(ctx context.Context) bool { } } + if val, found := ff.fromCache(ctx); found { + return val + } + return ff.OnByDefault } @@ -173,3 +177,7 @@ func (ff FeatureFlag) valueFromContext(ctx context.Context) (string, bool) { return val[0], true } + +func (ff FeatureFlag) fromCache(ctx context.Context) (bool, bool) { + return GetCache().Get(ctx, ff.Name) +} diff --git a/internal/metadata/featureflag/featureflag_test.go b/internal/metadata/featureflag/featureflag_test.go index b2776d94a..715098b91 100644 --- a/internal/metadata/featureflag/featureflag_test.go +++ b/internal/metadata/featureflag/featureflag_test.go @@ -3,6 +3,7 @@ package featureflag import ( + "context" "fmt" "testing" @@ -79,6 +80,7 @@ func TestFeatureFlag_enabled(t *testing.T) { shouldPanic bool enabled bool onByDefault bool + cache Cache }{ { desc: "empty name", @@ -168,8 +170,30 @@ func TestFeatureFlag_enabled(t *testing.T) { enabled: true, onByDefault: true, }, + { + desc: "flag missing in metadata but it is available in cache", + flag: "flag", + headers: map[string]string{}, + shouldPanic: false, + enabled: true, + onByDefault: false, + cache: staticCache{"flag": true}, + }, + { + desc: "flag missing in metadata and in cache", + flag: "flag", + headers: map[string]string{}, + shouldPanic: false, + enabled: true, + onByDefault: true, + cache: staticCache{}, + }, } { t.Run(tc.desc, func(t *testing.T) { + if tc.cache != nil { + defer SetCache(GetCache()) + SetCache(tc.cache) + } ctx := metadata.NewIncomingContext(createContext(), metadata.New(tc.headers)) var ff FeatureFlag @@ -184,3 +208,10 @@ func TestFeatureFlag_enabled(t *testing.T) { }) } } + +type staticCache map[string]bool + +func (sc staticCache) Get(_ context.Context, name string) (bool, bool) { + val, found := sc[name] + return val, found +} |