diff options
author | Pavlo Strokov <pstrokov@gitlab.com> | 2021-12-14 11:51:58 +0300 |
---|---|---|
committer | Pavlo Strokov <pstrokov@gitlab.com> | 2021-12-14 11:51:58 +0300 |
commit | 0e0aeb5ca4488903f41acd392911bd89d1ad3d6d (patch) | |
tree | 3a5dd5ffa5521f4b16f6f332d4aac53323ea0040 | |
parent | 125f0fc0e49db4dd46f2e905f34178ce880dd79e (diff) | |
parent | f7071f97997fb7734cf6903285b007336186b22a (diff) |
Merge branch 'pks-linting-disallow-standard-context' into 'master'
lint: Disallow use of "normal" contexts
See merge request gitlab-org/gitaly!4190
33 files changed, 423 insertions, 239 deletions
diff --git a/.golangci.yml b/.golangci.yml index dce1c42d7..b1359a125 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -53,6 +53,11 @@ linters-settings: # following functions is thus disallowed. and a code smell. - ^context.WithDeadline$ - ^context.WithTimeout$ + # Tests should always use `testhelper.Context()`: this context has + # special handling for feature flags which allows us to assert that + # they're tested as expected. + - ^context.Background$ + - ^context.TODO$ stylecheck: # ST1000 checks for missing package comments. We don't use these for most # packages, so let's disable this check. diff --git a/auth/extract_test.go b/auth/extract_test.go index cf70aa8eb..7aa14ca7e 100644 --- a/auth/extract_test.go +++ b/auth/extract_test.go @@ -1,12 +1,12 @@ package gitalyauth import ( - "context" "testing" "time" "github.com/grpc-ecosystem/go-grpc-middleware/util/metautils" "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper" ) func TestCheckTokenV2(t *testing.T) { @@ -77,9 +77,12 @@ func TestCheckTokenV2(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + md := metautils.NiceMD{} md.Set("authorization", "Bearer "+tc.token) - result := CheckToken(md.ToIncoming(context.Background()), string(secret), targetTime) + result := CheckToken(md.ToIncoming(ctx), string(secret), targetTime) require.Equal(t, tc.result, result) }) diff --git a/cmd/gitaly-hooks/hooks_test.go b/cmd/gitaly-hooks/hooks_test.go index 856730900..399e6e538 100644 --- a/cmd/gitaly-hooks/hooks_test.go +++ b/cmd/gitaly-hooks/hooks_test.go @@ -122,6 +122,9 @@ func TestHooksPrePostReceive(t *testing.T) { } func testHooksPrePostReceive(t *testing.T, cfg config.Cfg, repo *gitalypb.Repository, repoPath string) { + ctx, cancel := testhelper.Context() + defer cancel() + secretToken := "secret token" glID := "key-1234" glUsername := "iamgitlab" @@ -184,7 +187,7 @@ func testHooksPrePostReceive(t *testing.T, cfg config.Cfg, repo *gitalypb.Reposi cmd.Stdin = stdin cmd.Env = envForHooks( t, - context.Background(), + ctx, cfg, repo, glHookValues{ @@ -242,6 +245,9 @@ func testHooksPrePostReceive(t *testing.T, cfg config.Cfg, repo *gitalypb.Reposi } func TestHooksUpdate(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + glID := "key-1234" glUsername := "iamgitlab" glProtocol := "ssh" @@ -261,21 +267,21 @@ func TestHooksUpdate(t *testing.T) { runHookServiceServer(t, cfg) - testHooksUpdate(t, cfg, glHookValues{ + testHooksUpdate(t, ctx, cfg, glHookValues{ GLID: glID, GLUsername: glUsername, GLProtocol: glProtocol, }) } -func testHooksUpdate(t *testing.T, cfg config.Cfg, glValues glHookValues) { +func testHooksUpdate(t *testing.T, ctx context.Context, cfg config.Cfg, glValues glHookValues) { repo, repoPath := gittest.CloneRepo(t, cfg, cfg.Storages[0]) refval, oldval, newval := "refval", strings.Repeat("a", 40), strings.Repeat("b", 40) updateHookPath, err := filepath.Abs("../../ruby/git-hooks/update") require.NoError(t, err) cmd := exec.Command(updateHookPath, refval, oldval, newval) - cmd.Env = envForHooks(t, context.Background(), cfg, repo, glValues, proxyValues{}) + cmd.Env = envForHooks(t, ctx, cfg, repo, glValues, proxyValues{}) cmd.Dir = repoPath tempDir := testhelper.TempDir(t) @@ -395,6 +401,9 @@ func TestHooksPostReceiveFailed(t *testing.T) { for _, tc := range testcases { t.Run(tc.desc, func(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + hooksPayload, err := git.NewHooksPayload( cfg, repo, @@ -409,11 +418,11 @@ func TestHooksPostReceiveFailed(t *testing.T) { Protocol: glProtocol, }, git.PostReceiveHook, - rawFeatureFlags(context.Background()), + rawFeatureFlags(ctx), ).Env() require.NoError(t, err) - env := envForHooks(t, context.Background(), cfg, repo, glHookValues{}, proxyValues{}) + env := envForHooks(t, ctx, cfg, repo, glHookValues{}, proxyValues{}) env = append(env, hooksPayload) cmd := exec.Command(postReceiveHookPath) @@ -466,13 +475,16 @@ func TestHooksNotAllowed(t *testing.T) { var stderr, stdout bytes.Buffer + ctx, cancel := testhelper.Context() + defer cancel() + preReceiveHookPath, err := filepath.Abs("../../ruby/git-hooks/pre-receive") require.NoError(t, err) cmd := exec.Command(preReceiveHookPath) cmd.Stderr = &stderr cmd.Stdout = &stdout cmd.Stdin = strings.NewReader(changes) - cmd.Env = envForHooks(t, context.Background(), cfg, repo, + cmd.Env = envForHooks(t, ctx, cfg, repo, glHookValues{ GLID: glID, GLUsername: glUsername, @@ -660,6 +672,9 @@ func TestGitalyHooksPackObjects(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + hook.Reset() tempDir := testhelper.TempDir(t) @@ -668,7 +683,7 @@ func TestGitalyHooksPackObjects(t *testing.T) { args = append(args, repoPath, tempDir) gittest.ExecOpts(t, cfg, gittest.ExecConfig{ - Env: (envForHooks(t, context.Background(), cfg, repo, glHookValues{}, proxyValues{})), + Env: (envForHooks(t, ctx, cfg, repo, glHookValues{}, proxyValues{})), }, args...) }) } diff --git a/internal/backchannel/backchannel_test.go b/internal/backchannel/backchannel_test.go index 474aa86d2..7e001bf22 100644 --- a/internal/backchannel/backchannel_test.go +++ b/internal/backchannel/backchannel_test.go @@ -77,7 +77,7 @@ func TestBackchannel_concurrentRequestsFromMultipleClients(t *testing.T) { defer srv.Stop() go srv.Serve(ln) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := testhelper.Context() defer cancel() start := make(chan struct{}) @@ -208,7 +208,7 @@ func Benchmark(b *testing.B) { defer srv.Stop() go srv.Serve(ln) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := testhelper.Context() defer cancel() opts := []grpc.DialOption{grpc.WithBlock(), grpc.WithInsecure()} diff --git a/internal/cache/diskcache_test.go b/internal/cache/diskcache_test.go index 497b89dee..d52dbbc7f 100644 --- a/internal/cache/diskcache_test.go +++ b/internal/cache/diskcache_test.go @@ -18,33 +18,12 @@ import ( ) func TestStreamDBNaiveKeyer(t *testing.T) { - cfg := testcfg.Build(t) - - testRepo1, _ := gittest.CloneRepo(t, cfg, cfg.Storages[0]) - testRepo2, _ := gittest.CloneRepo(t, cfg, cfg.Storages[0]) - - locator := config.NewLocator(cfg) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - ctx = testhelper.SetCtxGrpcMethod(ctx, "InfoRefsUploadPack") - - cache := New(cfg, locator) - - req1 := &gitalypb.InfoRefsRequest{ - Repository: testRepo1, - } - req2 := &gitalypb.InfoRefsRequest{ - Repository: testRepo2, - } - - expectGetMiss := func(req *gitalypb.InfoRefsRequest) { + expectGetMiss := func(ctx context.Context, cache Cache, req *gitalypb.InfoRefsRequest) { _, err := cache.GetStream(ctx, req.Repository, req) require.Equal(t, ErrReqNotFound, err) } - expectGetHit := func(expectStr string, req *gitalypb.InfoRefsRequest) { + expectGetHit := func(ctx context.Context, cache Cache, req *gitalypb.InfoRefsRequest, expectStr string) { actualStream, err := cache.GetStream(ctx, req.Repository, req) require.NoError(t, err) actualBytes, err := io.ReadAll(actualStream) @@ -52,70 +31,126 @@ func TestStreamDBNaiveKeyer(t *testing.T) { require.Equal(t, expectStr, string(actualBytes)) } - invalidationEvent := func(repo *gitalypb.Repository) { + invalidationEvent := func(ctx context.Context, cache Cache, repo *gitalypb.Repository) { lease, err := cache.StartLease(repo) require.NoError(t, err) // imagine repo being modified here require.NoError(t, lease.EndLease(ctx)) } - storeAndRetrieve := func(expectStr string, req *gitalypb.InfoRefsRequest) { + storeAndRetrieve := func(ctx context.Context, cache Cache, req *gitalypb.InfoRefsRequest, expectStr string) { require.NoError(t, cache.PutStream(ctx, req.Repository, req, strings.NewReader(expectStr))) - expectGetHit(expectStr, req) + expectGetHit(ctx, cache, req, expectStr) + } + + cfg := testcfg.Build(t) + + repo1, _ := gittest.CloneRepo(t, cfg, cfg.Storages[0]) + repo2, _ := gittest.CloneRepo(t, cfg, cfg.Storages[0]) + + locator := config.NewLocator(cfg) + + req1 := &gitalypb.InfoRefsRequest{ + Repository: repo1, + } + req2 := &gitalypb.InfoRefsRequest{ + Repository: repo2, } - // cache is initially empty - expectGetMiss(req1) - expectGetMiss(req2) - - // populate cache - repo1contents := "store and retrieve value in repo 1" - storeAndRetrieve(repo1contents, req1) - repo2contents := "store and retrieve value in repo 2" - storeAndRetrieve(repo2contents, req2) - - // invalidation makes previous value stale and unreachable - invalidationEvent(req1.Repository) - expectGetMiss(req1) - expectGetHit(repo2contents, req2) // repo1 invalidation doesn't affect repo2 - - // store new value for same cache value but at new generation - expectStream2 := "not what you were looking for" - require.NoError(t, cache.PutStream(ctx, req1.Repository, req1, strings.NewReader(expectStream2))) - expectGetHit(expectStream2, req1) - - // enabled feature flags affect caching - oldCtx := ctx - ctx = featureflag.IncomingCtxWithFeatureFlag(ctx, featureflag.FeatureFlag{Name: "meow", OnByDefault: false}, true) - expectGetMiss(req1) - ctx = oldCtx - expectGetHit(expectStream2, req1) - - // start critical section without closing - repo1Lease, err := cache.StartLease(req1.Repository) - require.NoError(t, err) - - // accessing repo cache with open critical section should fail - _, err = cache.GetStream(ctx, req1.Repository, req1) - require.Equal(t, err, ErrPendingExists) - err = cache.PutStream(ctx, req1.Repository, req1, strings.NewReader(repo1contents)) - require.Equal(t, err, ErrPendingExists) - - expectGetHit(repo2contents, req2) // other repo caches should be unaffected - - // opening and closing a new critical zone doesn't resolve the issue - invalidationEvent(req1.Repository) - _, err = cache.GetStream(ctx, req1.Repository, req1) - require.Equal(t, err, ErrPendingExists) - - // only completing/removing the pending generation file will allow access - require.NoError(t, repo1Lease.EndLease(ctx)) - expectGetMiss(req1) - - // creating a lease on a repo that doesn't exist yet should succeed - req1.Repository.RelativePath += "-does-not-exist" - _, err = cache.StartLease(req1.Repository) - require.NoError(t, err) + ctx, cancel := testhelper.Context() + defer cancel() + ctx = testhelper.SetCtxGrpcMethod(ctx, "InfoRefsUploadPack") + + t.Run("empty cache", func(t *testing.T) { + cache := New(cfg, locator) + + expectGetMiss(ctx, cache, req1) + expectGetMiss(ctx, cache, req2) + }) + + t.Run("store and retrieve", func(t *testing.T) { + cache := New(cfg, locator) + storeAndRetrieve(ctx, cache, req1, "content-1") + storeAndRetrieve(ctx, cache, req2, "content-2") + }) + + t.Run("invalidation", func(t *testing.T) { + cache := New(cfg, locator) + + storeAndRetrieve(ctx, cache, req1, "content-1") + storeAndRetrieve(ctx, cache, req2, "content-2") + + invalidationEvent(ctx, cache, req1.Repository) + + expectGetMiss(ctx, cache, req1) + expectGetHit(ctx, cache, req2, "content-2") + }) + + t.Run("overwrite existing entry", func(t *testing.T) { + cache := New(cfg, locator) + + storeAndRetrieve(ctx, cache, req1, "content-1") + + require.NoError(t, cache.PutStream(ctx, req1.Repository, req1, strings.NewReader("not what you were looking for"))) + expectGetHit(ctx, cache, req1, "not what you were looking for") + }) + + t.Run("feature flags affect caching", func(t *testing.T) { + cache := New(cfg, locator) + + ctxWithFF := featureflag.IncomingCtxWithFeatureFlag(ctx, featureflag.FeatureFlag{ + Name: "meow", + }, true) + + storeAndRetrieve(ctx, cache, req1, "default") + expectGetHit(ctx, cache, req1, "default") + expectGetMiss(ctxWithFF, cache, req1) + + storeAndRetrieve(ctxWithFF, cache, req1, "flagged") + expectGetHit(ctxWithFF, cache, req1, "flagged") + expectGetHit(ctx, cache, req1, "default") + }) + + t.Run("critical section", func(t *testing.T) { + cache := New(cfg, locator) + + storeAndRetrieve(ctx, cache, req2, "unrelated") + + // Start critical section without closing it. + repo1Lease, err := cache.StartLease(req1.Repository) + require.NoError(t, err) + + // Accessing repo cache with open critical section should fail. + _, err = cache.GetStream(ctx, req1.Repository, req1) + require.Equal(t, err, ErrPendingExists) + err = cache.PutStream(ctx, req1.Repository, req1, strings.NewReader("conflict")) + require.Equal(t, err, ErrPendingExists) + + // Other repo caches should be unaffected. + expectGetHit(ctx, cache, req2, "unrelated") + + // Opening and closing a new critical section doesn't resolve the issue. + invalidationEvent(ctx, cache, req1.Repository) + _, err = cache.GetStream(ctx, req1.Repository, req1) + require.Equal(t, err, ErrPendingExists) + + // Only completing/removing the pending generation file will allow access. + require.NoError(t, repo1Lease.EndLease(ctx)) + expectGetMiss(ctx, cache, req1) + }) + + t.Run("nonexisteng repository", func(t *testing.T) { + cache := New(cfg, locator) + + nonexistentRepo := &gitalypb.Repository{ + StorageName: repo1.StorageName, + RelativePath: "does-not-exist", + } + + // Creating a lease on a repo that doesn't exist yet should succeed. + _, err := cache.StartLease(nonexistentRepo) + require.NoError(t, err) + }) } func TestLoserCount(t *testing.T) { @@ -133,7 +168,10 @@ func TestLoserCount(t *testing.T) { StorageName: "storage-1", }, } - ctx := testhelper.SetCtxGrpcMethod(context.Background(), "InfoRefsUploadPack") + + ctx, cancel := testhelper.Context() + defer cancel() + ctx = testhelper.SetCtxGrpcMethod(ctx, "InfoRefsUploadPack") leashes := []chan struct{}{make(chan struct{}), make(chan struct{}), make(chan struct{})} errQ := make(chan error) diff --git a/internal/cgroups/v1_linux_test.go b/internal/cgroups/v1_linux_test.go index 252aa9287..0cb522c04 100644 --- a/internal/cgroups/v1_linux_test.go +++ b/internal/cgroups/v1_linux_test.go @@ -2,7 +2,6 @@ package cgroups import ( "bytes" - "context" "fmt" "hash/crc32" "os" @@ -72,7 +71,7 @@ func TestAddCommand(t *testing.T) { } require.NoError(t, v1Manager1.Setup()) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := testhelper.Context() defer cancel() cmd1 := exec.Command("ls", "-hal", ".") diff --git a/internal/command/command_test.go b/internal/command/command_test.go index d3d71f1a2..72be22895 100644 --- a/internal/command/command_test.go +++ b/internal/command/command_test.go @@ -2,7 +2,6 @@ package command import ( "bytes" - "context" "fmt" "io" "os/exec" @@ -161,7 +160,7 @@ func TestRejectEmptyContextDone(t *testing.T) { } }() - _, err := New(context.Background(), exec.Command("true"), nil, nil, nil) + _, err := New(testhelper.ContextWithoutCancel(), exec.Command("true"), nil, nil, nil) require.NoError(t, err) } diff --git a/internal/git/catfile/cache_test.go b/internal/git/catfile/cache_test.go index b2e3cedfb..0344aec83 100644 --- a/internal/git/catfile/cache_test.go +++ b/internal/git/catfile/cache_test.go @@ -1,7 +1,6 @@ package catfile import ( - "context" "errors" "io" "os" @@ -186,7 +185,7 @@ func TestCache_ObjectReader(t *testing.T) { cache.cachedProcessDone = sync.NewCond(&sync.Mutex{}) t.Run("uncancellable", func(t *testing.T) { - ctx := context.Background() + ctx := testhelper.ContextWithoutCancel() require.PanicsWithValue(t, "empty ctx.Done() in catfile.Batch.New()", func() { _, _ = cache.ObjectReader(ctx, repoExecutor) @@ -322,7 +321,7 @@ func TestCache_ObjectInfoReader(t *testing.T) { cache.cachedProcessDone = sync.NewCond(&sync.Mutex{}) t.Run("uncancellable", func(t *testing.T) { - ctx := context.Background() + ctx := testhelper.ContextWithoutCancel() require.PanicsWithValue(t, "empty ctx.Done() in catfile.Batch.New()", func() { _, _ = cache.ObjectInfoReader(ctx, repoExecutor) diff --git a/internal/git/objectpool/clone_test.go b/internal/git/objectpool/clone_test.go index a705c5832..3bf1e7584 100644 --- a/internal/git/objectpool/clone_test.go +++ b/internal/git/objectpool/clone_test.go @@ -12,7 +12,7 @@ func TestClone(t *testing.T) { ctx, cancel := testhelper.Context() defer cancel() - pool, testRepo := setupObjectPool(t) + pool, testRepo := setupObjectPool(t, ctx) require.NoError(t, pool.clone(ctx, testRepo)) defer func() { @@ -27,7 +27,7 @@ func TestCloneExistingPool(t *testing.T) { ctx, cancel := testhelper.Context() defer cancel() - pool, testRepo := setupObjectPool(t) + pool, testRepo := setupObjectPool(t, ctx) require.NoError(t, pool.clone(ctx, testRepo)) defer func() { diff --git a/internal/git/objectpool/fetch_test.go b/internal/git/objectpool/fetch_test.go index eaab4dcac..8141dbdee 100644 --- a/internal/git/objectpool/fetch_test.go +++ b/internal/git/objectpool/fetch_test.go @@ -19,7 +19,7 @@ func TestFetchFromOriginDangling(t *testing.T) { ctx, cancel := testhelper.Context() defer cancel() - pool, testRepo := setupObjectPool(t) + pool, testRepo := setupObjectPool(t, ctx) require.NoError(t, pool.FetchFromOrigin(ctx, testRepo), "seed pool") @@ -86,7 +86,7 @@ func TestFetchFromOriginFsck(t *testing.T) { ctx, cancel := testhelper.Context() defer cancel() - pool, repo := setupObjectPool(t) + pool, repo := setupObjectPool(t, ctx) repoPath := filepath.Join(pool.cfg.Storages[0].Path, repo.RelativePath) require.NoError(t, pool.FetchFromOrigin(ctx, repo), "seed pool") @@ -110,7 +110,7 @@ func TestFetchFromOriginDeltaIslands(t *testing.T) { ctx, cancel := testhelper.Context() defer cancel() - pool, testRepo := setupObjectPool(t) + pool, testRepo := setupObjectPool(t, ctx) testRepoPath := filepath.Join(pool.cfg.Storages[0].Path, testRepo.RelativePath) require.NoError(t, pool.FetchFromOrigin(ctx, testRepo), "seed pool") @@ -133,7 +133,7 @@ func TestFetchFromOriginBitmapHashCache(t *testing.T) { ctx, cancel := testhelper.Context() defer cancel() - pool, testRepo := setupObjectPool(t) + pool, testRepo := setupObjectPool(t, ctx) require.NoError(t, pool.FetchFromOrigin(ctx, testRepo), "seed pool") @@ -158,7 +158,7 @@ func TestFetchFromOriginRefUpdates(t *testing.T) { ctx, cancel := testhelper.Context() defer cancel() - pool, testRepo := setupObjectPool(t) + pool, testRepo := setupObjectPool(t, ctx) testRepoPath := filepath.Join(pool.cfg.Storages[0].Path, testRepo.RelativePath) poolPath := pool.FullPath() @@ -203,7 +203,7 @@ func TestFetchFromOrigin_refs(t *testing.T) { ctx, cancel := testhelper.Context() defer cancel() - pool, _ := setupObjectPool(t) + pool, _ := setupObjectPool(t, ctx) poolPath := pool.FullPath() // Init the source repo with a bunch of refs. diff --git a/internal/git/objectpool/link_test.go b/internal/git/objectpool/link_test.go index c960f48db..40d07f68c 100644 --- a/internal/git/objectpool/link_test.go +++ b/internal/git/objectpool/link_test.go @@ -19,7 +19,7 @@ func TestLink(t *testing.T) { ctx, cancel := testhelper.Context() defer cancel() - pool, testRepo := setupObjectPool(t) + pool, testRepo := setupObjectPool(t, ctx) require.NoError(t, pool.Remove(ctx), "make sure pool does not exist prior to creation") require.NoError(t, pool.Create(ctx, testRepo), "create pool") @@ -49,7 +49,7 @@ func TestLink_transactional(t *testing.T) { ctx, cancel := testhelper.Context() defer cancel() - pool, poolMember := setupObjectPool(t) + pool, poolMember := setupObjectPool(t, ctx) require.NoError(t, pool.Create(ctx, poolMember)) txManager := transaction.NewTrackingManager() @@ -74,7 +74,7 @@ func TestLinkRemoveBitmap(t *testing.T) { ctx, cancel := testhelper.Context() defer cancel() - pool, testRepo := setupObjectPool(t) + pool, testRepo := setupObjectPool(t, ctx) require.NoError(t, pool.Init(ctx)) testRepoPath := filepath.Join(pool.cfg.Storages[0].Path, testRepo.RelativePath) @@ -119,7 +119,7 @@ func TestLinkAbsoluteLinkExists(t *testing.T) { ctx, cancel := testhelper.Context() defer cancel() - pool, testRepo := setupObjectPool(t) + pool, testRepo := setupObjectPool(t, ctx) testRepoPath := filepath.Join(pool.cfg.Storages[0].Path, testRepo.RelativePath) diff --git a/internal/git/objectpool/pool_test.go b/internal/git/objectpool/pool_test.go index 94756ee0c..c3e57217d 100644 --- a/internal/git/objectpool/pool_test.go +++ b/internal/git/objectpool/pool_test.go @@ -30,7 +30,7 @@ func TestNewFromRepoSuccess(t *testing.T) { ctx, cancel := testhelper.Context() defer cancel() - pool, testRepo := setupObjectPool(t) + pool, testRepo := setupObjectPool(t, ctx) require.NoError(t, pool.Create(ctx, testRepo)) require.NoError(t, pool.Link(ctx, testRepo)) @@ -42,7 +42,10 @@ func TestNewFromRepoSuccess(t *testing.T) { } func TestNewFromRepoNoObjectPool(t *testing.T) { - pool, testRepo := setupObjectPool(t) + ctx, cancel := testhelper.Context() + defer cancel() + + pool, testRepo := setupObjectPool(t, ctx) testRepoPath := filepath.Join(pool.cfg.Storages[0].Path, testRepo.RelativePath) @@ -93,7 +96,7 @@ func TestCreate(t *testing.T) { ctx, cancel := testhelper.Context() defer cancel() - pool, testRepo := setupObjectPool(t) + pool, testRepo := setupObjectPool(t, ctx) testRepoPath := filepath.Join(pool.cfg.Storages[0].Path, testRepo.RelativePath) @@ -127,7 +130,7 @@ func TestCreateSubDirsExist(t *testing.T) { ctx, cancel := testhelper.Context() defer cancel() - pool, testRepo := setupObjectPool(t) + pool, testRepo := setupObjectPool(t, ctx) err := pool.Create(ctx, testRepo) require.NoError(t, err) @@ -143,7 +146,7 @@ func TestRemove(t *testing.T) { ctx, cancel := testhelper.Context() defer cancel() - pool, testRepo := setupObjectPool(t) + pool, testRepo := setupObjectPool(t, ctx) err := pool.Create(ctx, testRepo) require.NoError(t, err) diff --git a/internal/git/objectpool/testhelper_test.go b/internal/git/objectpool/testhelper_test.go index 722e892df..62f89eb57 100644 --- a/internal/git/objectpool/testhelper_test.go +++ b/internal/git/objectpool/testhelper_test.go @@ -24,7 +24,7 @@ func TestMain(m *testing.M) { })) } -func setupObjectPool(t *testing.T) (*ObjectPool, *gitalypb.Repository) { +func setupObjectPool(t *testing.T, ctx context.Context) (*ObjectPool, *gitalypb.Repository) { t.Helper() cfg, repo, _ := testcfg.BuildWithRepo(t) @@ -45,7 +45,7 @@ func setupObjectPool(t *testing.T) (*ObjectPool, *gitalypb.Repository) { require.NoError(t, err) t.Cleanup(func() { - if err := pool.Remove(context.TODO()); err != nil { + if err := pool.Remove(ctx); err != nil { panic(err) } }) diff --git a/internal/gitaly/hook/sidechannel_test.go b/internal/gitaly/hook/sidechannel_test.go index 8571d8a2b..a9ba8d344 100644 --- a/internal/gitaly/hook/sidechannel_test.go +++ b/internal/gitaly/hook/sidechannel_test.go @@ -1,20 +1,23 @@ package hook import ( - "context" "io" "net" "testing" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitaly/v14/internal/metadata" + "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper" grpc_metadata "google.golang.org/grpc/metadata" ) func TestSidechannel(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + // Client side ctxOut, wt, err := SetupSidechannel( - context.Background(), + ctx, func(c *net.UnixConn) error { _, err := io.WriteString(c, "ping") return err @@ -38,6 +41,9 @@ func TestSidechannel(t *testing.T) { } func TestGetSidechannel(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + testCases := []string{ "foobar", "sc.foo/../../bar", @@ -48,7 +54,7 @@ func TestGetSidechannel(t *testing.T) { for _, tc := range testCases { t.Run(tc, func(t *testing.T) { ctx := grpc_metadata.NewIncomingContext( - context.Background(), + ctx, map[string][]string{sidechannelHeader: {tc}}, ) _, err := GetSidechannel(ctx) diff --git a/internal/gitaly/maintenance/daily_test.go b/internal/gitaly/maintenance/daily_test.go index 0edeb854c..151b61630 100644 --- a/internal/gitaly/maintenance/daily_test.go +++ b/internal/gitaly/maintenance/daily_test.go @@ -38,7 +38,10 @@ func TestStartDaily(t *testing.T) { Duration: config.Duration(time.Hour), Storages: []string{"meow"}, } - ctx, cancel := context.WithCancel(context.Background()) + + ctx, cancel := testhelper.Context() + defer cancel() + go func() { errQ <- dw.StartDaily(ctx, testhelper.NewDiscardingLogEntry(t), s, fn) }() startTime := time.Date(1999, 3, 31, 0, 0, 0, 0, time.Local) diff --git a/internal/gitaly/storage/servers_test.go b/internal/gitaly/storage/servers_test.go index 5c6def8da..f10666828 100644 --- a/internal/gitaly/storage/servers_test.go +++ b/internal/gitaly/storage/servers_test.go @@ -81,6 +81,8 @@ func TestInjectGitalyServers(t *testing.T) { } t.Run("brand new context", func(t *testing.T) { + //nolint:forbidigo // We need to check for metadata and thus cannot use the + // testhelper context, which injects feature flags. ctx := context.Background() check(t, ctx) @@ -89,6 +91,8 @@ func TestInjectGitalyServers(t *testing.T) { t.Run("context with existing outgoing metadata should not be re-written", func(t *testing.T) { existing := metadata.New(map[string]string{"foo": "bar"}) + //nolint:forbidigo // We need to check for metadata and thus cannot use the + // testhelper context, which injects feature flags. ctx := metadata.NewOutgoingContext(context.Background(), existing) check(t, ctx) diff --git a/internal/gitlab/http_client_test.go b/internal/gitlab/http_client_test.go index 7d2c6bfc9..90f9e96f0 100644 --- a/internal/gitlab/http_client_test.go +++ b/internal/gitlab/http_client_test.go @@ -1,7 +1,6 @@ package gitlab import ( - "context" "encoding/json" "net/http" "net/http/httptest" @@ -120,8 +119,11 @@ func TestAccess_verifyParams(t *testing.T) { }, } + ctx, cancel := testhelper.Context() + defer cancel() + for _, tc := range testCases { - allowed, _, err := c.Allowed(context.Background(), AllowedParams{ + allowed, _, err := c.Allowed(ctx, AllowedParams{ RepoPath: tc.repo.RelativePath, GitObjectDirectory: tc.repo.GitObjectDirectory, GitAlternateObjectDirectories: tc.repo.GitAlternateObjectDirectories, @@ -227,7 +229,11 @@ func TestAccess_escapedAndRelativeURLs(t *testing.T) { prometheus.Config{}, ) require.NoError(t, err) - allowed, _, err := c.Allowed(context.Background(), AllowedParams{ + + ctx, cancel := testhelper.Context() + defer cancel() + + allowed, _, err := c.Allowed(ctx, AllowedParams{ RepoPath: repo.RelativePath, GitObjectDirectory: repo.GitObjectDirectory, GitAlternateObjectDirectories: repo.GitAlternateObjectDirectories, @@ -379,7 +385,10 @@ func TestAccess_allowedResponseHandling(t *testing.T) { mockHistogramVec := promtest.NewMockHistogramVec() c.latencyMetric = mockHistogramVec - allowed, message, err := c.Allowed(context.Background(), AllowedParams{ + ctx, cancel := testhelper.Context() + defer cancel() + + allowed, message, err := c.Allowed(ctx, AllowedParams{ RepoPath: repo.RelativePath, GitObjectDirectory: repo.GitObjectDirectory, GitAlternateObjectDirectories: repo.GitAlternateObjectDirectories, @@ -489,7 +498,10 @@ func TestAccess_preReceive(t *testing.T) { mockHistogramVec := promtest.NewMockHistogramVec() c.latencyMetric = mockHistogramVec - success, err := c.PreReceive(context.Background(), "key-123") + ctx, cancel := testhelper.Context() + defer cancel() + + success, err := c.PreReceive(ctx, "key-123") require.Equal(t, tc.success, success) if err != nil { require.Contains(t, err.Error(), tc.errMsg) @@ -577,10 +589,13 @@ func TestAccess_postReceive(t *testing.T) { mockHistogramVec := promtest.NewMockHistogramVec() c.latencyMetric = mockHistogramVec + ctx, cancel := testhelper.Context() + defer cancel() + repositoryID := "project-123" identifier := "key-123" changes := "000 000 refs/heads/master" - success, _, err := c.PostReceive(context.Background(), repositoryID, identifier, changes, tc.pushOptions...) + success, _, err := c.PostReceive(ctx, repositoryID, identifier, changes, tc.pushOptions...) require.Equal(t, tc.success, success) if err != nil { require.Contains(t, err.Error(), tc.errMsg) diff --git a/internal/listenmux/mux_test.go b/internal/listenmux/mux_test.go index 4ab8c488a..00c63d3c0 100644 --- a/internal/listenmux/mux_test.go +++ b/internal/listenmux/mux_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" @@ -55,23 +56,29 @@ func serverWithHandshaker(t *testing.T, h Handshaker) string { return l.Addr().String() } -func checkHealth(t *testing.T, cc *grpc.ClientConn) { +func checkHealth(t *testing.T, ctx context.Context, cc *grpc.ClientConn) { t.Helper() - _, err := healthgrpc.NewHealthClient(cc).Check(context.Background(), &healthgrpc.HealthCheckRequest{}) + _, err := healthgrpc.NewHealthClient(cc).Check(ctx, &healthgrpc.HealthCheckRequest{}) require.NoError(t, err) } func TestMux_normalClientNoMux(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + addr := serverWithHandshaker(t, nil) cc, err := grpc.Dial(addr, grpc.WithInsecure()) require.NoError(t, err) defer cc.Close() - checkHealth(t, cc) + checkHealth(t, ctx, cc) } func TestMux_normalClientMuxIgnored(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + addr := serverWithHandshaker(t, handshakeFunc(func(net.Conn, credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error) { t.Error("never called") @@ -83,10 +90,13 @@ func TestMux_normalClientMuxIgnored(t *testing.T) { require.NoError(t, err) defer cc.Close() - checkHealth(t, cc) + checkHealth(t, ctx, cc) } func TestMux_muxClientPassesThrough(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + handshakerCalled := false addr := serverWithHandshaker(t, @@ -115,7 +125,7 @@ func TestMux_muxClientPassesThrough(t *testing.T) { require.NoError(t, err) defer cc.Close() - checkHealth(t, cc) + checkHealth(t, ctx, cc) require.True(t, handshakerCalled) } @@ -221,6 +231,9 @@ func TestMux_concurrency(t *testing.T) { streamClientErrors := make(chan error, N) grpcHealthErrors := make(chan error, N) + ctx, cancel := testhelper.Context() + defer cancel() + for i := 0; i < N; i++ { go func() { <-start @@ -274,7 +287,7 @@ func TestMux_concurrency(t *testing.T) { defer cc.Close() client := healthgrpc.NewHealthClient(cc) - _, err = client.Check(context.Background(), &healthgrpc.HealthCheckRequest{}) + _, err = client.Check(ctx, &healthgrpc.HealthCheckRequest{}) return err }() }() diff --git a/internal/log/log_test.go b/internal/log/log_test.go index 4c508e057..93f4b1325 100644 --- a/internal/log/log_test.go +++ b/internal/log/log_test.go @@ -27,7 +27,7 @@ import ( ) func TestPayloadBytes(t *testing.T) { - ctx := context.Background() + ctx := createContext() logger, hook := test.NewNullLogger() @@ -261,7 +261,7 @@ func TestConfigure(t *testing.T) { func TestMessageProducer(t *testing.T) { triggered := false MessageProducer(func(ctx context.Context, format string, level logrus.Level, code codes.Code, err error, fields logrus.Fields) { - require.Equal(t, context.Background(), ctx) + require.Equal(t, createContext(), ctx) require.Equal(t, "format-stub", format) require.Equal(t, logrus.DebugLevel, level) require.Equal(t, codes.OutOfRange, code) @@ -272,20 +272,20 @@ func TestMessageProducer(t *testing.T) { return logrus.Fields{"a": 1} }, func(context.Context) logrus.Fields { return logrus.Fields{"b": "test"} - })(context.Background(), "format-stub", logrus.DebugLevel, codes.OutOfRange, assert.AnError, logrus.Fields{"c": "stub"}) + })(createContext(), "format-stub", logrus.DebugLevel, codes.OutOfRange, assert.AnError, logrus.Fields{"c": "stub"}) require.True(t, triggered) } func TestPropagationMessageProducer(t *testing.T) { t.Run("empty context", func(t *testing.T) { - ctx := context.Background() + ctx := createContext() mp := PropagationMessageProducer(func(context.Context, string, logrus.Level, codes.Code, error, logrus.Fields) {}) mp(ctx, "", logrus.DebugLevel, codes.OK, nil, nil) }) t.Run("context with holder", func(t *testing.T) { holder := new(messageProducerHolder) - ctx := context.WithValue(context.Background(), messageProducerHolderKey{}, holder) + ctx := context.WithValue(createContext(), messageProducerHolderKey{}, holder) triggered := false mp := PropagationMessageProducer(func(ctx context.Context, format string, level logrus.Level, code codes.Code, err error, fields logrus.Fields) { triggered = true @@ -313,7 +313,7 @@ func TestPerRPCLogHandler(t *testing.T) { } t.Run("check propagation", func(t *testing.T) { - ctx := context.Background() + ctx := createContext() ctx = lh.TagConn(ctx, &stats.ConnTagInfo{}) lh.HandleConn(ctx, &stats.ConnBegin{}) ctx = lh.TagRPC(ctx, &stats.RPCTagInfo{}) @@ -334,7 +334,7 @@ func TestPerRPCLogHandler(t *testing.T) { }) t.Run("log handling", func(t *testing.T) { - ctx := ctxlogrus.ToContext(context.Background(), logrus.NewEntry(logrus.New())) + ctx := ctxlogrus.ToContext(createContext(), logrus.NewEntry(logrus.New())) ctx = lh.TagRPC(ctx, &stats.RPCTagInfo{}) mpp := ctx.Value(messageProducerHolderKey{}).(*messageProducerHolder) mpp.format = "message" @@ -381,7 +381,7 @@ func TestUnaryLogDataCatcherServerInterceptor(t *testing.T) { t.Run("propagates call", func(t *testing.T) { interceptor := UnaryLogDataCatcherServerInterceptor() - resp, err := interceptor(context.Background(), nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) { + resp, err := interceptor(createContext(), nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) { return 42, assert.AnError }) @@ -391,7 +391,7 @@ func TestUnaryLogDataCatcherServerInterceptor(t *testing.T) { t.Run("no logger", func(t *testing.T) { mpp := &messageProducerHolder{} - ctx := context.WithValue(context.Background(), messageProducerHolderKey{}, mpp) + ctx := context.WithValue(createContext(), messageProducerHolderKey{}, mpp) interceptor := UnaryLogDataCatcherServerInterceptor() _, _ = interceptor(ctx, nil, nil, handlerStub) @@ -400,7 +400,7 @@ func TestUnaryLogDataCatcherServerInterceptor(t *testing.T) { t.Run("caught", func(t *testing.T) { mpp := &messageProducerHolder{} - ctx := context.WithValue(context.Background(), messageProducerHolderKey{}, mpp) + ctx := context.WithValue(createContext(), messageProducerHolderKey{}, mpp) ctx = ctxlogrus.ToContext(ctx, logrus.New().WithField("a", 1)) interceptor := UnaryLogDataCatcherServerInterceptor() _, _ = interceptor(ctx, nil, nil, handlerStub) @@ -411,7 +411,7 @@ func TestUnaryLogDataCatcherServerInterceptor(t *testing.T) { func TestStreamLogDataCatcherServerInterceptor(t *testing.T) { t.Run("propagates call", func(t *testing.T) { interceptor := StreamLogDataCatcherServerInterceptor() - ss := &grpcmw.WrappedServerStream{WrappedContext: context.Background()} + ss := &grpcmw.WrappedServerStream{WrappedContext: createContext()} err := interceptor(nil, ss, nil, func(interface{}, grpc.ServerStream) error { return assert.AnError }) @@ -421,7 +421,7 @@ func TestStreamLogDataCatcherServerInterceptor(t *testing.T) { t.Run("no logger", func(t *testing.T) { mpp := &messageProducerHolder{} - ctx := context.WithValue(context.Background(), messageProducerHolderKey{}, mpp) + ctx := context.WithValue(createContext(), messageProducerHolderKey{}, mpp) interceptor := StreamLogDataCatcherServerInterceptor() ss := &grpcmw.WrappedServerStream{WrappedContext: ctx} @@ -430,7 +430,7 @@ func TestStreamLogDataCatcherServerInterceptor(t *testing.T) { t.Run("caught", func(t *testing.T) { mpp := &messageProducerHolder{} - ctx := context.WithValue(context.Background(), messageProducerHolderKey{}, mpp) + ctx := context.WithValue(createContext(), messageProducerHolderKey{}, mpp) ctx = ctxlogrus.ToContext(ctx, logrus.New().WithField("a", 1)) interceptor := StreamLogDataCatcherServerInterceptor() @@ -439,3 +439,9 @@ func TestStreamLogDataCatcherServerInterceptor(t *testing.T) { assert.Equal(t, logrus.Fields{"a": 1}, mpp.fields) }) } + +//nolint:forbidigo // We cannot use `testhelper.Context()` because of a cyclic dependency between +// this package and the `testhelper` package. +func createContext() context.Context { + return context.Background() +} diff --git a/internal/metadata/featureflag/context.go b/internal/metadata/featureflag/context.go index 16caa2837..9d43bd744 100644 --- a/internal/metadata/featureflag/context.go +++ b/internal/metadata/featureflag/context.go @@ -33,6 +33,7 @@ func outgoingCtxWithFeatureFlag(ctx context.Context, key string, enabled bool) c md = metadata.New(map[string]string{}) } + md = md.Copy() md.Set(key, strconv.FormatBool(enabled)) return metadata.NewOutgoingContext(ctx, md) @@ -56,6 +57,7 @@ func incomingCtxWithFeatureFlag(ctx context.Context, key string, enabled bool) c md = metadata.New(map[string]string{}) } + md = md.Copy() md.Set(key, strconv.FormatBool(enabled)) return metadata.NewIncomingContext(ctx, md) diff --git a/internal/metadata/featureflag/context_test.go b/internal/metadata/featureflag/context_test.go index 5a40bd834..998c0f259 100644 --- a/internal/metadata/featureflag/context_test.go +++ b/internal/metadata/featureflag/context_test.go @@ -5,53 +5,89 @@ import ( "testing" "github.com/stretchr/testify/require" + gitaly_metadata "gitlab.com/gitlab-org/gitaly/v14/internal/metadata" "google.golang.org/grpc/metadata" ) -var mockFeatureFlag = FeatureFlag{"turn meow on", false} +var ( + ffA = FeatureFlag{"feature-a", false} + ffB = FeatureFlag{"feature-b", false} +) + +//nolint:forbidigo // We cannot use `testhelper.Context()` given that it would inject feature flags +// already. +func createContext() context.Context { + return context.Background() +} func TestIncomingCtxWithFeatureFlag(t *testing.T) { - ctx := context.Background() - require.False(t, mockFeatureFlag.IsEnabled(ctx)) + ctx := createContext() + require.False(t, ffA.IsEnabled(ctx)) + require.False(t, ffB.IsEnabled(ctx)) t.Run("enabled", func(t *testing.T) { - ctx := IncomingCtxWithFeatureFlag(ctx, mockFeatureFlag, true) - require.True(t, mockFeatureFlag.IsEnabled(ctx)) + ctx := IncomingCtxWithFeatureFlag(ctx, ffA, true) + require.True(t, ffA.IsEnabled(ctx)) }) t.Run("disabled", func(t *testing.T) { - ctx := IncomingCtxWithFeatureFlag(ctx, mockFeatureFlag, false) - require.False(t, mockFeatureFlag.IsEnabled(ctx)) + ctx := IncomingCtxWithFeatureFlag(ctx, ffA, false) + require.False(t, ffA.IsEnabled(ctx)) + }) + + t.Run("set multiple flags", func(t *testing.T) { + ctxA := IncomingCtxWithFeatureFlag(ctx, ffA, true) + ctxB := IncomingCtxWithFeatureFlag(ctxA, ffB, true) + + require.True(t, ffA.IsEnabled(ctxA)) + require.False(t, ffB.IsEnabled(ctxA)) + + require.True(t, ffA.IsEnabled(ctxB)) + require.True(t, ffB.IsEnabled(ctxB)) }) } func TestOutgoingCtxWithFeatureFlag(t *testing.T) { - ctx := context.Background() - require.False(t, mockFeatureFlag.IsEnabled(ctx)) + ctx := createContext() + require.False(t, ffA.IsEnabled(ctx)) + require.False(t, ffB.IsEnabled(ctx)) t.Run("enabled", func(t *testing.T) { - ctx := OutgoingCtxWithFeatureFlag(ctx, mockFeatureFlag, true) + ctx := OutgoingCtxWithFeatureFlag(ctx, ffA, true) // The feature flag is only checked for incoming contexts, so it's not expected to // be enabled yet. - require.False(t, mockFeatureFlag.IsEnabled(ctx)) + require.False(t, ffA.IsEnabled(ctx)) md, ok := metadata.FromOutgoingContext(ctx) require.True(t, ok) // It should be enabled after converting it to an incoming context though. - ctx = metadata.NewIncomingContext(context.Background(), md) - require.True(t, mockFeatureFlag.IsEnabled(ctx)) + ctx = metadata.NewIncomingContext(createContext(), md) + require.True(t, ffA.IsEnabled(ctx)) }) t.Run("disabled", func(t *testing.T) { - ctx = OutgoingCtxWithFeatureFlag(ctx, mockFeatureFlag, false) - require.False(t, mockFeatureFlag.IsEnabled(ctx)) + ctx = OutgoingCtxWithFeatureFlag(ctx, ffA, false) + require.False(t, ffA.IsEnabled(ctx)) md, ok := metadata.FromOutgoingContext(ctx) require.True(t, ok) - ctx = metadata.NewIncomingContext(context.Background(), md) - require.False(t, mockFeatureFlag.IsEnabled(ctx)) + ctx = metadata.NewIncomingContext(createContext(), md) + require.False(t, ffA.IsEnabled(ctx)) + }) + + t.Run("set multiple flags", func(t *testing.T) { + ctxA := OutgoingCtxWithFeatureFlag(ctx, ffA, true) + ctxB := OutgoingCtxWithFeatureFlag(ctxA, ffB, true) + + ctxA = gitaly_metadata.OutgoingToIncoming(ctxA) + require.True(t, ffA.IsEnabled(ctxA)) + require.False(t, ffB.IsEnabled(ctxA)) + + ctxB = gitaly_metadata.OutgoingToIncoming(ctxB) + require.True(t, ffA.IsEnabled(ctxB)) + require.True(t, ffB.IsEnabled(ctxB)) }) } @@ -76,7 +112,7 @@ func TestGRPCMetadataFeatureFlag(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { md := metadata.New(tc.headers) - ctx := metadata.NewIncomingContext(context.Background(), md) + ctx := metadata.NewIncomingContext(createContext(), md) require.Equal(t, tc.enabled, FeatureFlag{tc.flag, tc.onByDefault}.IsEnabled(ctx)) }) @@ -91,7 +127,7 @@ func TestAllEnabledFlags(t *testing.T) { ffPrefix + "bar": "TRUE", // not enabled } - ctx := metadata.NewIncomingContext(context.Background(), metadata.New(flags)) + ctx := metadata.NewIncomingContext(createContext(), metadata.New(flags)) require.ElementsMatch(t, AllFlags(ctx), []string{"meow:true", "foo:true", "woof:false", "bar:TRUE"}) } @@ -105,7 +141,7 @@ func TestRaw(t *testing.T) { } t.Run("RawFromContext", func(t *testing.T) { - ctx := context.Background() + ctx := createContext() ctx = IncomingCtxWithFeatureFlag(ctx, enabledFlag, true) ctx = IncomingCtxWithFeatureFlag(ctx, disabledFlag, false) @@ -113,7 +149,7 @@ func TestRaw(t *testing.T) { }) t.Run("OutgoingWithRaw", func(t *testing.T) { - outgoingMD, ok := metadata.FromOutgoingContext(OutgoingWithRaw(context.Background(), raw)) + outgoingMD, ok := metadata.FromOutgoingContext(OutgoingWithRaw(createContext(), raw)) require.True(t, ok) require.Equal(t, metadata.MD{ ffPrefix + enabledFlag.Name: {"true"}, diff --git a/internal/metadata/featureflag/featureflag_test.go b/internal/metadata/featureflag/featureflag_test.go index ef9dc52ca..8944e21fd 100644 --- a/internal/metadata/featureflag/featureflag_test.go +++ b/internal/metadata/featureflag/featureflag_test.go @@ -1,7 +1,6 @@ package featureflag import ( - "context" "testing" "github.com/stretchr/testify/require" @@ -74,7 +73,7 @@ func TestFeatureFlag_enabled(t *testing.T) { }, } { t.Run(tc.desc, func(t *testing.T) { - ctx := metadata.NewIncomingContext(context.Background(), metadata.New(tc.headers)) + ctx := metadata.NewIncomingContext(createContext(), metadata.New(tc.headers)) ff := FeatureFlag{tc.flag, tc.onByDefault} require.Equal(t, tc.enabled, ff.IsEnabled(ctx)) diff --git a/internal/metadata/metadata_test.go b/internal/metadata/metadata_test.go index 71413549e..e361b1c47 100644 --- a/internal/metadata/metadata_test.go +++ b/internal/metadata/metadata_test.go @@ -1,20 +1,23 @@ package metadata import ( - "context" + "errors" "testing" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/storage" + "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper" ) func TestOutgoingToIncoming(t *testing.T) { - ctx := context.Background() + ctx, cancel := testhelper.Context() + defer cancel() + ctx, err := storage.InjectGitalyServers(ctx, "a", "b", "c") require.NoError(t, err) _, err = storage.ExtractGitalyServer(ctx, "a") - require.Equal(t, storage.ErrEmptyMetadata, err, + require.Equal(t, errors.New("empty gitaly-servers metadata"), err, "server should not be found in the incoming context") ctx = OutgoingToIncoming(ctx) diff --git a/internal/middleware/featureflag/featureflag_handler_test.go b/internal/middleware/featureflag/featureflag_handler_test.go index be0ccfb43..f3b68b017 100644 --- a/internal/middleware/featureflag/featureflag_handler_test.go +++ b/internal/middleware/featureflag/featureflag_handler_test.go @@ -71,6 +71,8 @@ func setup() (context.Context, *test.Hook) { } func setupContext() (context.Context, *test.Hook) { + //nolint:forbidigo // We don't want to set up the feature flags which `testhelper.Context()` + // would inject here. ctx := context.Background() logger := logrus.New() logger.SetOutput(io.Discard) diff --git a/internal/middleware/limithandler/concurrency_limiter_test.go b/internal/middleware/limithandler/concurrency_limiter_test.go index c5070fe39..bff19c688 100644 --- a/internal/middleware/limithandler/concurrency_limiter_test.go +++ b/internal/middleware/limithandler/concurrency_limiter_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper" ) type counter struct { @@ -123,6 +124,9 @@ func TestLimiter(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + expectedGaugeMax := tt.maxConcurrency * tt.buckets if tt.maxConcurrency <= 0 { expectedGaugeMax = tt.concurrency @@ -161,7 +165,7 @@ func TestLimiter(t *testing.T) { for i := 0; i < tt.iterations; i++ { lockKey := strconv.Itoa((i ^ counter) % tt.buckets) - _, err := limiter.Limit(context.Background(), lockKey, func() (interface{}, error) { + _, err := limiter.Limit(ctx, lockKey, func() (interface{}, error) { primePump() current := gauge.currentVal() diff --git a/internal/middleware/metadatahandler/metadatahandler_test.go b/internal/middleware/metadatahandler/metadatahandler_test.go index dc89037be..6aa12b64b 100644 --- a/internal/middleware/metadatahandler/metadatahandler_test.go +++ b/internal/middleware/metadatahandler/metadatahandler_test.go @@ -129,12 +129,15 @@ func verifyHandler(ctx context.Context, req interface{}) (interface{}, error) { } func TestGRPCTags(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + require := require.New(t) - ctx := metadata.NewIncomingContext( + ctx = metadata.NewIncomingContext( correlation.ContextWithCorrelation( correlation.ContextWithClientName( - context.Background(), + ctx, clientName, ), correlationID, diff --git a/internal/middleware/sentryhandler/sentryhandler_test.go b/internal/middleware/sentryhandler/sentryhandler_test.go index f2643e3a7..fcc75d6b0 100644 --- a/internal/middleware/sentryhandler/sentryhandler_test.go +++ b/internal/middleware/sentryhandler/sentryhandler_test.go @@ -83,7 +83,10 @@ func Test_generateSentryEvent(t *testing.T) { name: "marked to skip", ctx: func() context.Context { var result context.Context - ctx := context.Background() + + ctx, cancel := testhelper.Context() + t.Cleanup(cancel) + // this is the only way how we could populate context with `tags` assembler _, err := grpcmwtags.UnaryServerInterceptor()(ctx, nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) { result = ctx diff --git a/internal/praefect/coordinator_test.go b/internal/praefect/coordinator_test.go index 7da00dbb7..70dc062aa 100644 --- a/internal/praefect/coordinator_test.go +++ b/internal/praefect/coordinator_test.go @@ -2090,9 +2090,12 @@ func TestNewRequestFinalizer_contextIsDisjointedFromTheRPC(t *testing.T) { parentDeadline := time.Now() + ctx, cancel := testhelper.Context() + defer cancel() + //nolint:forbidigo // We explicitly want to test that the deadline does not propagate into // the request's context. - ctx, cancel := context.WithDeadline(context.WithValue(context.Background(), ctxKey{}, "value"), parentDeadline) + ctx, cancel = context.WithDeadline(context.WithValue(ctx, ctxKey{}, "value"), parentDeadline) cancel() requireSuppressedCancellation := func(t testing.TB, ctx context.Context) { @@ -2210,6 +2213,9 @@ func TestStreamParametersContext(t *testing.T) { return metadata.Pairs(pairs...) } + //nolint:forbidigo // We explicitly test context values, so we cannot use the testhelper + // context here given that it would contain unrelated data and thus change the system under + // test. for _, tc := range []struct { desc string setupContext func() context.Context diff --git a/internal/praefect/repocleaner/action_log_test.go b/internal/praefect/repocleaner/action_log_test.go index b8073d91f..1afada69c 100644 --- a/internal/praefect/repocleaner/action_log_test.go +++ b/internal/praefect/repocleaner/action_log_test.go @@ -1,18 +1,21 @@ package repocleaner import ( - "context" "testing" "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper" ) func TestLogWarnAction_Perform(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + logger, hook := test.NewNullLogger() action := NewLogWarnAction(logger) - err := action.Perform(context.TODO(), "vs1", "g1", []string{"p/1", "p/2"}) + err := action.Perform(ctx, "vs1", "g1", []string{"p/1", "p/2"}) require.NoError(t, err) require.Len(t, hook.AllEntries(), 2) diff --git a/internal/sidechannel/sidechannel_test.go b/internal/sidechannel/sidechannel_test.go index 285765b51..e5c153f1a 100644 --- a/internal/sidechannel/sidechannel_test.go +++ b/internal/sidechannel/sidechannel_test.go @@ -12,12 +12,16 @@ import ( "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitaly/v14/internal/backchannel" "gitlab.com/gitlab-org/gitaly/v14/internal/listenmux" + "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" healthpb "google.golang.org/grpc/health/grpc_health_v1" ) func TestSidechannel(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + const blobSize = 1024 * 1024 in := make([]byte, blobSize) @@ -45,7 +49,7 @@ func TestSidechannel(t *testing.T) { conn, registry := dial(t, addr) err = call( - context.Background(), conn, registry, + ctx, conn, registry, func(conn *ClientConn) error { errC := make(chan error, 1) go func() { @@ -68,6 +72,9 @@ func TestSidechannel(t *testing.T) { // Conduct multiple requests with sidechannel included on the same grpc // connection. func TestSidechannelConcurrency(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + const concurrency = 10 const blobSize = 1024 * 1024 @@ -111,7 +118,7 @@ func TestSidechannelConcurrency(t *testing.T) { defer wg.Done() err := call( - context.Background(), conn, registry, + ctx, conn, registry, func(conn *ClientConn) error { errC := make(chan error, 1) go func() { diff --git a/internal/streamcache/cache_test.go b/internal/streamcache/cache_test.go index da8b5e157..1bb9009ea 100644 --- a/internal/streamcache/cache_test.go +++ b/internal/streamcache/cache_test.go @@ -29,6 +29,9 @@ func newCache(dir string) Cache { } func TestCache_writeOneReadMultiple(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + tmp := testhelper.TempDir(t) c := newCache(tmp) @@ -50,7 +53,7 @@ func TestCache_writeOneReadMultiple(t *testing.T) { out, err := io.ReadAll(r) require.NoError(t, err) - require.NoError(t, r.Wait(context.Background())) + require.NoError(t, r.Wait(ctx)) require.Equal(t, content(0), string(out), "expect cache hits for all i > 0") }) } @@ -59,6 +62,9 @@ func TestCache_writeOneReadMultiple(t *testing.T) { } func TestCache_manyConcurrentWrites(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + tmp := testhelper.TempDir(t) c := newCache(tmp) @@ -94,7 +100,7 @@ func TestCache_manyConcurrentWrites(t *testing.T) { } output[i] = string(out) - return r.Wait(context.Background()) + return r.Wait(ctx) }() }(i) } @@ -175,6 +181,9 @@ func TestCache_deletedFile(t *testing.T) { } func TestCache_scope(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + tmp := testhelper.TempDir(t) const ( @@ -215,13 +224,16 @@ func TestCache_scope(t *testing.T) { out, err := io.ReadAll(r) require.NoError(t, err) - require.NoError(t, r.Wait(context.Background())) + require.NoError(t, r.Wait(ctx)) require.Equal(t, content, string(out)) } } func TestCache_diskCleanup(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + tmp := testhelper.TempDir(t) const ( @@ -254,7 +266,7 @@ func TestCache_diskCleanup(t *testing.T) { out1, err := io.ReadAll(r1) require.NoError(t, err) require.Equal(t, content(1), string(out1)) - require.NoError(t, r1.Wait(context.Background())) + require.NoError(t, r1.Wait(ctx)) // File and index entry should still exist because cleanup goroutines are blocked. requireCacheFiles(t, tmp, 1) @@ -292,13 +304,16 @@ func TestCache_diskCleanup(t *testing.T) { out2, err := io.ReadAll(r2) require.NoError(t, err) - require.NoError(t, r2.Wait(context.Background())) + require.NoError(t, r2.Wait(ctx)) // Sanity check: no stale value returned by the cache require.Equal(t, content(2), string(out2)) } func TestCache_failedWrite(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + tmp := testhelper.TempDir(t) c := newCache(tmp) @@ -327,7 +342,7 @@ func TestCache_failedWrite(t *testing.T) { _, err = io.Copy(io.Discard, r1) require.NoError(t, err, "errors on the write end are not propagated via Read()") require.NoError(t, r1.Close(), "errors on the write end are not propagated via Close()") - require.Error(t, r1.Wait(context.Background()), "error propagation happens via Wait()") + require.Error(t, r1.Wait(ctx), "error propagation happens via Wait()") time.Sleep(10 * time.Millisecond) @@ -339,7 +354,7 @@ func TestCache_failedWrite(t *testing.T) { out, err := io.ReadAll(r2) require.NoError(t, err) - require.NoError(t, r2.Wait(context.Background())) + require.NoError(t, r2.Wait(ctx)) require.Equal(t, happy, string(out)) }) } @@ -359,6 +374,9 @@ func TestCache_failCreateFile(t *testing.T) { } func TestCache_unWriteableFile(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + tmp := testhelper.TempDir(t) c := newCache(tmp) @@ -378,12 +396,15 @@ func TestCache_unWriteableFile(t *testing.T) { _, err = io.ReadAll(r) require.NoError(t, err) - err = r.Wait(context.Background()) + err = r.Wait(ctx) require.IsType(t, &os.PathError{}, err) require.Equal(t, "write", err.(*os.PathError).Op) } func TestCache_unCloseableFile(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + tmp := testhelper.TempDir(t) c := newCache(tmp) @@ -404,7 +425,7 @@ func TestCache_unCloseableFile(t *testing.T) { _, err = io.ReadAll(r) require.NoError(t, err) - err = r.Wait(context.Background()) + err = r.Wait(ctx) require.IsType(t, &os.PathError{}, err) require.Equal(t, "close", err.(*os.PathError).Op) } @@ -430,16 +451,21 @@ func TestCache_cannotOpenFileForReading(t *testing.T) { } func TestWaiter(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + w := newWaiter() err := errors.New("test error") w.SetError(err) - require.Equal(t, err, w.Wait(context.Background())) + require.Equal(t, err, w.Wait(ctx)) } func TestWaiter_cancel(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + w := newWaiter() errc := make(chan error, 1) - ctx, cancel := context.WithCancel(context.Background()) go func() { errc <- w.Wait(ctx) }() cancel() @@ -447,6 +473,9 @@ func TestWaiter_cancel(t *testing.T) { } func TestNullCache(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + const ( N = 1000 inputSize = 4096 @@ -500,7 +529,7 @@ func TestNullCache(t *testing.T) { return errors.New("output does not match input") } - return s.Wait(context.Background()) + return s.Wait(ctx) }() }() } diff --git a/internal/testhelper/featureset_test.go b/internal/testhelper/featureset_test.go index c6a4e75a8..fdbbba5ed 100644 --- a/internal/testhelper/featureset_test.go +++ b/internal/testhelper/featureset_test.go @@ -5,9 +5,8 @@ import ( "testing" "github.com/stretchr/testify/require" - "gitlab.com/gitlab-org/gitaly/v14/internal/metadata" ff "gitlab.com/gitlab-org/gitaly/v14/internal/metadata/featureflag" - grpc_metadata "google.golang.org/grpc/metadata" + "google.golang.org/grpc/metadata" ) var ( @@ -172,9 +171,6 @@ func TestNewFeatureSetsWithRubyFlags(t *testing.T) { } func TestFeatureSets_Run(t *testing.T) { - var incomingFlags [][2]bool - var outgoingFlags [][2]bool - // This test depends on feature flags being default-enabled in the test // context, which requires those flags to exist in the ff.All slice. So // let's just append them here so we do not need to use a "real" @@ -185,40 +181,25 @@ func TestFeatureSets_Run(t *testing.T) { }(ff.All) ff.All = append(ff.All, featureFlagA, featureFlagB) + var featureFlags [][2]bool NewFeatureSets(featureFlagB, featureFlagA).Run(t, func(t *testing.T, ctx context.Context) { - incomingMD, ok := grpc_metadata.FromIncomingContext(ctx) + incomingMD, ok := metadata.FromIncomingContext(ctx) require.True(t, ok) - outgoingMD, ok := grpc_metadata.FromOutgoingContext(ctx) + outgoingMD, ok := metadata.FromOutgoingContext(ctx) require.True(t, ok) - incomingCtx := grpc_metadata.NewIncomingContext(context.Background(), incomingMD) - outgoingCtx := metadata.OutgoingToIncoming(grpc_metadata.NewOutgoingContext(context.Background(), outgoingMD)) + require.Equal(t, incomingMD, outgoingMD) - incomingFlags = append(incomingFlags, [2]bool{ - featureFlagB.IsDisabled(incomingCtx), - featureFlagA.IsDisabled(incomingCtx), - }) - outgoingFlags = append(outgoingFlags, [2]bool{ - featureFlagB.IsDisabled(outgoingCtx), - featureFlagA.IsDisabled(outgoingCtx), + featureFlags = append(featureFlags, [2]bool{ + featureFlagA.IsEnabled(ctx), featureFlagB.IsEnabled(ctx), }) }) - for _, tc := range []struct { - desc string - flags [][2]bool - }{ - {desc: "incoming context", flags: incomingFlags}, - {desc: "outgoing context", flags: outgoingFlags}, - } { - t.Run(tc.desc, func(t *testing.T) { - require.ElementsMatch(t, tc.flags, [][2]bool{ - {false, false}, - {true, false}, - {false, true}, - {true, true}, - }) - }) - } + require.Equal(t, [][2]bool{ + {false, false}, + {false, true}, + {true, false}, + {true, true}, + }, featureFlags) } diff --git a/internal/testhelper/testhelper.go b/internal/testhelper/testhelper.go index b5b008ff8..b98a7f4d8 100644 --- a/internal/testhelper/testhelper.go +++ b/internal/testhelper/testhelper.go @@ -157,18 +157,23 @@ func GetLocalhostListener(t testing.TB) (net.Listener, string) { } // ContextOpt returns a new context instance with the new additions to it. -type ContextOpt func(context.Context) (context.Context, func()) +type ContextOpt func(context.Context) context.Context // ContextWithLogger allows to inject provided logger into the context. func ContextWithLogger(logger *log.Entry) ContextOpt { - return func(ctx context.Context) (context.Context, func()) { - return ctxlogrus.ToContext(ctx, logger), func() {} + return func(ctx context.Context) context.Context { + return ctxlogrus.ToContext(ctx, logger) } } // Context returns a cancellable context. func Context(opts ...ContextOpt) (context.Context, func()) { - ctx, cancel := context.WithCancel(context.Background()) + return context.WithCancel(ContextWithoutCancel(opts...)) +} + +// ContextWithoutCancel returns a non-cancellable context. +func ContextWithoutCancel(opts ...ContextOpt) context.Context { + ctx := context.Background() // Enable use of explicit feature flags. Each feature flag which is checked must have been // explicitly injected into the context, or otherwise we panic. This is a sanity check to @@ -180,18 +185,11 @@ func Context(opts ...ContextOpt) (context.Context, func()) { // context. ctx = featureflag.ContextWithFeatureFlags(ctx, featureflag.RunCommandsInCGroup) - cancels := make([]func(), len(opts)+1) - cancels[0] = cancel - for i, opt := range opts { - ctx, cancel = opt(ctx) - cancels[i+1] = cancel + for _, opt := range opts { + ctx = opt(ctx) } - return ctx, func() { - for i := len(cancels) - 1; i >= 0; i-- { - cancels[i]() - } - } + return ctx } // CreateGlobalDirectory creates a directory in the test directory that is shared across all |