Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPavlo Strokov <pstrokov@gitlab.com>2021-12-14 11:51:58 +0300
committerPavlo Strokov <pstrokov@gitlab.com>2021-12-14 11:51:58 +0300
commit0e0aeb5ca4488903f41acd392911bd89d1ad3d6d (patch)
tree3a5dd5ffa5521f4b16f6f332d4aac53323ea0040
parent125f0fc0e49db4dd46f2e905f34178ce880dd79e (diff)
parentf7071f97997fb7734cf6903285b007336186b22a (diff)
Merge branch 'pks-linting-disallow-standard-context' into 'master'
lint: Disallow use of "normal" contexts See merge request gitlab-org/gitaly!4190
-rw-r--r--.golangci.yml5
-rw-r--r--auth/extract_test.go7
-rw-r--r--cmd/gitaly-hooks/hooks_test.go31
-rw-r--r--internal/backchannel/backchannel_test.go4
-rw-r--r--internal/cache/diskcache_test.go196
-rw-r--r--internal/cgroups/v1_linux_test.go3
-rw-r--r--internal/command/command_test.go3
-rw-r--r--internal/git/catfile/cache_test.go5
-rw-r--r--internal/git/objectpool/clone_test.go4
-rw-r--r--internal/git/objectpool/fetch_test.go12
-rw-r--r--internal/git/objectpool/link_test.go8
-rw-r--r--internal/git/objectpool/pool_test.go13
-rw-r--r--internal/git/objectpool/testhelper_test.go4
-rw-r--r--internal/gitaly/hook/sidechannel_test.go12
-rw-r--r--internal/gitaly/maintenance/daily_test.go5
-rw-r--r--internal/gitaly/storage/servers_test.go4
-rw-r--r--internal/gitlab/http_client_test.go27
-rw-r--r--internal/listenmux/mux_test.go25
-rw-r--r--internal/log/log_test.go32
-rw-r--r--internal/metadata/featureflag/context.go2
-rw-r--r--internal/metadata/featureflag/context_test.go78
-rw-r--r--internal/metadata/featureflag/featureflag_test.go3
-rw-r--r--internal/metadata/metadata_test.go9
-rw-r--r--internal/middleware/featureflag/featureflag_handler_test.go2
-rw-r--r--internal/middleware/limithandler/concurrency_limiter_test.go6
-rw-r--r--internal/middleware/metadatahandler/metadatahandler_test.go7
-rw-r--r--internal/middleware/sentryhandler/sentryhandler_test.go5
-rw-r--r--internal/praefect/coordinator_test.go8
-rw-r--r--internal/praefect/repocleaner/action_log_test.go7
-rw-r--r--internal/sidechannel/sidechannel_test.go11
-rw-r--r--internal/streamcache/cache_test.go53
-rw-r--r--internal/testhelper/featureset_test.go45
-rw-r--r--internal/testhelper/testhelper.go26
33 files changed, 423 insertions, 239 deletions
diff --git a/.golangci.yml b/.golangci.yml
index dce1c42d7..b1359a125 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -53,6 +53,11 @@ linters-settings:
# following functions is thus disallowed. and a code smell.
- ^context.WithDeadline$
- ^context.WithTimeout$
+ # Tests should always use `testhelper.Context()`: this context has
+ # special handling for feature flags which allows us to assert that
+ # they're tested as expected.
+ - ^context.Background$
+ - ^context.TODO$
stylecheck:
# ST1000 checks for missing package comments. We don't use these for most
# packages, so let's disable this check.
diff --git a/auth/extract_test.go b/auth/extract_test.go
index cf70aa8eb..7aa14ca7e 100644
--- a/auth/extract_test.go
+++ b/auth/extract_test.go
@@ -1,12 +1,12 @@
package gitalyauth
import (
- "context"
"testing"
"time"
"github.com/grpc-ecosystem/go-grpc-middleware/util/metautils"
"github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
)
func TestCheckTokenV2(t *testing.T) {
@@ -77,9 +77,12 @@ func TestCheckTokenV2(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
md := metautils.NiceMD{}
md.Set("authorization", "Bearer "+tc.token)
- result := CheckToken(md.ToIncoming(context.Background()), string(secret), targetTime)
+ result := CheckToken(md.ToIncoming(ctx), string(secret), targetTime)
require.Equal(t, tc.result, result)
})
diff --git a/cmd/gitaly-hooks/hooks_test.go b/cmd/gitaly-hooks/hooks_test.go
index 856730900..399e6e538 100644
--- a/cmd/gitaly-hooks/hooks_test.go
+++ b/cmd/gitaly-hooks/hooks_test.go
@@ -122,6 +122,9 @@ func TestHooksPrePostReceive(t *testing.T) {
}
func testHooksPrePostReceive(t *testing.T, cfg config.Cfg, repo *gitalypb.Repository, repoPath string) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
secretToken := "secret token"
glID := "key-1234"
glUsername := "iamgitlab"
@@ -184,7 +187,7 @@ func testHooksPrePostReceive(t *testing.T, cfg config.Cfg, repo *gitalypb.Reposi
cmd.Stdin = stdin
cmd.Env = envForHooks(
t,
- context.Background(),
+ ctx,
cfg,
repo,
glHookValues{
@@ -242,6 +245,9 @@ func testHooksPrePostReceive(t *testing.T, cfg config.Cfg, repo *gitalypb.Reposi
}
func TestHooksUpdate(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
glID := "key-1234"
glUsername := "iamgitlab"
glProtocol := "ssh"
@@ -261,21 +267,21 @@ func TestHooksUpdate(t *testing.T) {
runHookServiceServer(t, cfg)
- testHooksUpdate(t, cfg, glHookValues{
+ testHooksUpdate(t, ctx, cfg, glHookValues{
GLID: glID,
GLUsername: glUsername,
GLProtocol: glProtocol,
})
}
-func testHooksUpdate(t *testing.T, cfg config.Cfg, glValues glHookValues) {
+func testHooksUpdate(t *testing.T, ctx context.Context, cfg config.Cfg, glValues glHookValues) {
repo, repoPath := gittest.CloneRepo(t, cfg, cfg.Storages[0])
refval, oldval, newval := "refval", strings.Repeat("a", 40), strings.Repeat("b", 40)
updateHookPath, err := filepath.Abs("../../ruby/git-hooks/update")
require.NoError(t, err)
cmd := exec.Command(updateHookPath, refval, oldval, newval)
- cmd.Env = envForHooks(t, context.Background(), cfg, repo, glValues, proxyValues{})
+ cmd.Env = envForHooks(t, ctx, cfg, repo, glValues, proxyValues{})
cmd.Dir = repoPath
tempDir := testhelper.TempDir(t)
@@ -395,6 +401,9 @@ func TestHooksPostReceiveFailed(t *testing.T) {
for _, tc := range testcases {
t.Run(tc.desc, func(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
hooksPayload, err := git.NewHooksPayload(
cfg,
repo,
@@ -409,11 +418,11 @@ func TestHooksPostReceiveFailed(t *testing.T) {
Protocol: glProtocol,
},
git.PostReceiveHook,
- rawFeatureFlags(context.Background()),
+ rawFeatureFlags(ctx),
).Env()
require.NoError(t, err)
- env := envForHooks(t, context.Background(), cfg, repo, glHookValues{}, proxyValues{})
+ env := envForHooks(t, ctx, cfg, repo, glHookValues{}, proxyValues{})
env = append(env, hooksPayload)
cmd := exec.Command(postReceiveHookPath)
@@ -466,13 +475,16 @@ func TestHooksNotAllowed(t *testing.T) {
var stderr, stdout bytes.Buffer
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
preReceiveHookPath, err := filepath.Abs("../../ruby/git-hooks/pre-receive")
require.NoError(t, err)
cmd := exec.Command(preReceiveHookPath)
cmd.Stderr = &stderr
cmd.Stdout = &stdout
cmd.Stdin = strings.NewReader(changes)
- cmd.Env = envForHooks(t, context.Background(), cfg, repo,
+ cmd.Env = envForHooks(t, ctx, cfg, repo,
glHookValues{
GLID: glID,
GLUsername: glUsername,
@@ -660,6 +672,9 @@ func TestGitalyHooksPackObjects(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
hook.Reset()
tempDir := testhelper.TempDir(t)
@@ -668,7 +683,7 @@ func TestGitalyHooksPackObjects(t *testing.T) {
args = append(args, repoPath, tempDir)
gittest.ExecOpts(t, cfg, gittest.ExecConfig{
- Env: (envForHooks(t, context.Background(), cfg, repo, glHookValues{}, proxyValues{})),
+ Env: (envForHooks(t, ctx, cfg, repo, glHookValues{}, proxyValues{})),
}, args...)
})
}
diff --git a/internal/backchannel/backchannel_test.go b/internal/backchannel/backchannel_test.go
index 474aa86d2..7e001bf22 100644
--- a/internal/backchannel/backchannel_test.go
+++ b/internal/backchannel/backchannel_test.go
@@ -77,7 +77,7 @@ func TestBackchannel_concurrentRequestsFromMultipleClients(t *testing.T) {
defer srv.Stop()
go srv.Serve(ln)
- ctx, cancel := context.WithCancel(context.Background())
+ ctx, cancel := testhelper.Context()
defer cancel()
start := make(chan struct{})
@@ -208,7 +208,7 @@ func Benchmark(b *testing.B) {
defer srv.Stop()
go srv.Serve(ln)
- ctx, cancel := context.WithCancel(context.Background())
+ ctx, cancel := testhelper.Context()
defer cancel()
opts := []grpc.DialOption{grpc.WithBlock(), grpc.WithInsecure()}
diff --git a/internal/cache/diskcache_test.go b/internal/cache/diskcache_test.go
index 497b89dee..d52dbbc7f 100644
--- a/internal/cache/diskcache_test.go
+++ b/internal/cache/diskcache_test.go
@@ -18,33 +18,12 @@ import (
)
func TestStreamDBNaiveKeyer(t *testing.T) {
- cfg := testcfg.Build(t)
-
- testRepo1, _ := gittest.CloneRepo(t, cfg, cfg.Storages[0])
- testRepo2, _ := gittest.CloneRepo(t, cfg, cfg.Storages[0])
-
- locator := config.NewLocator(cfg)
-
- ctx, cancel := context.WithCancel(context.Background())
- defer cancel()
-
- ctx = testhelper.SetCtxGrpcMethod(ctx, "InfoRefsUploadPack")
-
- cache := New(cfg, locator)
-
- req1 := &gitalypb.InfoRefsRequest{
- Repository: testRepo1,
- }
- req2 := &gitalypb.InfoRefsRequest{
- Repository: testRepo2,
- }
-
- expectGetMiss := func(req *gitalypb.InfoRefsRequest) {
+ expectGetMiss := func(ctx context.Context, cache Cache, req *gitalypb.InfoRefsRequest) {
_, err := cache.GetStream(ctx, req.Repository, req)
require.Equal(t, ErrReqNotFound, err)
}
- expectGetHit := func(expectStr string, req *gitalypb.InfoRefsRequest) {
+ expectGetHit := func(ctx context.Context, cache Cache, req *gitalypb.InfoRefsRequest, expectStr string) {
actualStream, err := cache.GetStream(ctx, req.Repository, req)
require.NoError(t, err)
actualBytes, err := io.ReadAll(actualStream)
@@ -52,70 +31,126 @@ func TestStreamDBNaiveKeyer(t *testing.T) {
require.Equal(t, expectStr, string(actualBytes))
}
- invalidationEvent := func(repo *gitalypb.Repository) {
+ invalidationEvent := func(ctx context.Context, cache Cache, repo *gitalypb.Repository) {
lease, err := cache.StartLease(repo)
require.NoError(t, err)
// imagine repo being modified here
require.NoError(t, lease.EndLease(ctx))
}
- storeAndRetrieve := func(expectStr string, req *gitalypb.InfoRefsRequest) {
+ storeAndRetrieve := func(ctx context.Context, cache Cache, req *gitalypb.InfoRefsRequest, expectStr string) {
require.NoError(t, cache.PutStream(ctx, req.Repository, req, strings.NewReader(expectStr)))
- expectGetHit(expectStr, req)
+ expectGetHit(ctx, cache, req, expectStr)
+ }
+
+ cfg := testcfg.Build(t)
+
+ repo1, _ := gittest.CloneRepo(t, cfg, cfg.Storages[0])
+ repo2, _ := gittest.CloneRepo(t, cfg, cfg.Storages[0])
+
+ locator := config.NewLocator(cfg)
+
+ req1 := &gitalypb.InfoRefsRequest{
+ Repository: repo1,
+ }
+ req2 := &gitalypb.InfoRefsRequest{
+ Repository: repo2,
}
- // cache is initially empty
- expectGetMiss(req1)
- expectGetMiss(req2)
-
- // populate cache
- repo1contents := "store and retrieve value in repo 1"
- storeAndRetrieve(repo1contents, req1)
- repo2contents := "store and retrieve value in repo 2"
- storeAndRetrieve(repo2contents, req2)
-
- // invalidation makes previous value stale and unreachable
- invalidationEvent(req1.Repository)
- expectGetMiss(req1)
- expectGetHit(repo2contents, req2) // repo1 invalidation doesn't affect repo2
-
- // store new value for same cache value but at new generation
- expectStream2 := "not what you were looking for"
- require.NoError(t, cache.PutStream(ctx, req1.Repository, req1, strings.NewReader(expectStream2)))
- expectGetHit(expectStream2, req1)
-
- // enabled feature flags affect caching
- oldCtx := ctx
- ctx = featureflag.IncomingCtxWithFeatureFlag(ctx, featureflag.FeatureFlag{Name: "meow", OnByDefault: false}, true)
- expectGetMiss(req1)
- ctx = oldCtx
- expectGetHit(expectStream2, req1)
-
- // start critical section without closing
- repo1Lease, err := cache.StartLease(req1.Repository)
- require.NoError(t, err)
-
- // accessing repo cache with open critical section should fail
- _, err = cache.GetStream(ctx, req1.Repository, req1)
- require.Equal(t, err, ErrPendingExists)
- err = cache.PutStream(ctx, req1.Repository, req1, strings.NewReader(repo1contents))
- require.Equal(t, err, ErrPendingExists)
-
- expectGetHit(repo2contents, req2) // other repo caches should be unaffected
-
- // opening and closing a new critical zone doesn't resolve the issue
- invalidationEvent(req1.Repository)
- _, err = cache.GetStream(ctx, req1.Repository, req1)
- require.Equal(t, err, ErrPendingExists)
-
- // only completing/removing the pending generation file will allow access
- require.NoError(t, repo1Lease.EndLease(ctx))
- expectGetMiss(req1)
-
- // creating a lease on a repo that doesn't exist yet should succeed
- req1.Repository.RelativePath += "-does-not-exist"
- _, err = cache.StartLease(req1.Repository)
- require.NoError(t, err)
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+ ctx = testhelper.SetCtxGrpcMethod(ctx, "InfoRefsUploadPack")
+
+ t.Run("empty cache", func(t *testing.T) {
+ cache := New(cfg, locator)
+
+ expectGetMiss(ctx, cache, req1)
+ expectGetMiss(ctx, cache, req2)
+ })
+
+ t.Run("store and retrieve", func(t *testing.T) {
+ cache := New(cfg, locator)
+ storeAndRetrieve(ctx, cache, req1, "content-1")
+ storeAndRetrieve(ctx, cache, req2, "content-2")
+ })
+
+ t.Run("invalidation", func(t *testing.T) {
+ cache := New(cfg, locator)
+
+ storeAndRetrieve(ctx, cache, req1, "content-1")
+ storeAndRetrieve(ctx, cache, req2, "content-2")
+
+ invalidationEvent(ctx, cache, req1.Repository)
+
+ expectGetMiss(ctx, cache, req1)
+ expectGetHit(ctx, cache, req2, "content-2")
+ })
+
+ t.Run("overwrite existing entry", func(t *testing.T) {
+ cache := New(cfg, locator)
+
+ storeAndRetrieve(ctx, cache, req1, "content-1")
+
+ require.NoError(t, cache.PutStream(ctx, req1.Repository, req1, strings.NewReader("not what you were looking for")))
+ expectGetHit(ctx, cache, req1, "not what you were looking for")
+ })
+
+ t.Run("feature flags affect caching", func(t *testing.T) {
+ cache := New(cfg, locator)
+
+ ctxWithFF := featureflag.IncomingCtxWithFeatureFlag(ctx, featureflag.FeatureFlag{
+ Name: "meow",
+ }, true)
+
+ storeAndRetrieve(ctx, cache, req1, "default")
+ expectGetHit(ctx, cache, req1, "default")
+ expectGetMiss(ctxWithFF, cache, req1)
+
+ storeAndRetrieve(ctxWithFF, cache, req1, "flagged")
+ expectGetHit(ctxWithFF, cache, req1, "flagged")
+ expectGetHit(ctx, cache, req1, "default")
+ })
+
+ t.Run("critical section", func(t *testing.T) {
+ cache := New(cfg, locator)
+
+ storeAndRetrieve(ctx, cache, req2, "unrelated")
+
+ // Start critical section without closing it.
+ repo1Lease, err := cache.StartLease(req1.Repository)
+ require.NoError(t, err)
+
+ // Accessing repo cache with open critical section should fail.
+ _, err = cache.GetStream(ctx, req1.Repository, req1)
+ require.Equal(t, err, ErrPendingExists)
+ err = cache.PutStream(ctx, req1.Repository, req1, strings.NewReader("conflict"))
+ require.Equal(t, err, ErrPendingExists)
+
+ // Other repo caches should be unaffected.
+ expectGetHit(ctx, cache, req2, "unrelated")
+
+ // Opening and closing a new critical section doesn't resolve the issue.
+ invalidationEvent(ctx, cache, req1.Repository)
+ _, err = cache.GetStream(ctx, req1.Repository, req1)
+ require.Equal(t, err, ErrPendingExists)
+
+ // Only completing/removing the pending generation file will allow access.
+ require.NoError(t, repo1Lease.EndLease(ctx))
+ expectGetMiss(ctx, cache, req1)
+ })
+
+ t.Run("nonexisteng repository", func(t *testing.T) {
+ cache := New(cfg, locator)
+
+ nonexistentRepo := &gitalypb.Repository{
+ StorageName: repo1.StorageName,
+ RelativePath: "does-not-exist",
+ }
+
+ // Creating a lease on a repo that doesn't exist yet should succeed.
+ _, err := cache.StartLease(nonexistentRepo)
+ require.NoError(t, err)
+ })
}
func TestLoserCount(t *testing.T) {
@@ -133,7 +168,10 @@ func TestLoserCount(t *testing.T) {
StorageName: "storage-1",
},
}
- ctx := testhelper.SetCtxGrpcMethod(context.Background(), "InfoRefsUploadPack")
+
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+ ctx = testhelper.SetCtxGrpcMethod(ctx, "InfoRefsUploadPack")
leashes := []chan struct{}{make(chan struct{}), make(chan struct{}), make(chan struct{})}
errQ := make(chan error)
diff --git a/internal/cgroups/v1_linux_test.go b/internal/cgroups/v1_linux_test.go
index 252aa9287..0cb522c04 100644
--- a/internal/cgroups/v1_linux_test.go
+++ b/internal/cgroups/v1_linux_test.go
@@ -2,7 +2,6 @@ package cgroups
import (
"bytes"
- "context"
"fmt"
"hash/crc32"
"os"
@@ -72,7 +71,7 @@ func TestAddCommand(t *testing.T) {
}
require.NoError(t, v1Manager1.Setup())
- ctx, cancel := context.WithCancel(context.Background())
+ ctx, cancel := testhelper.Context()
defer cancel()
cmd1 := exec.Command("ls", "-hal", ".")
diff --git a/internal/command/command_test.go b/internal/command/command_test.go
index d3d71f1a2..72be22895 100644
--- a/internal/command/command_test.go
+++ b/internal/command/command_test.go
@@ -2,7 +2,6 @@ package command
import (
"bytes"
- "context"
"fmt"
"io"
"os/exec"
@@ -161,7 +160,7 @@ func TestRejectEmptyContextDone(t *testing.T) {
}
}()
- _, err := New(context.Background(), exec.Command("true"), nil, nil, nil)
+ _, err := New(testhelper.ContextWithoutCancel(), exec.Command("true"), nil, nil, nil)
require.NoError(t, err)
}
diff --git a/internal/git/catfile/cache_test.go b/internal/git/catfile/cache_test.go
index b2e3cedfb..0344aec83 100644
--- a/internal/git/catfile/cache_test.go
+++ b/internal/git/catfile/cache_test.go
@@ -1,7 +1,6 @@
package catfile
import (
- "context"
"errors"
"io"
"os"
@@ -186,7 +185,7 @@ func TestCache_ObjectReader(t *testing.T) {
cache.cachedProcessDone = sync.NewCond(&sync.Mutex{})
t.Run("uncancellable", func(t *testing.T) {
- ctx := context.Background()
+ ctx := testhelper.ContextWithoutCancel()
require.PanicsWithValue(t, "empty ctx.Done() in catfile.Batch.New()", func() {
_, _ = cache.ObjectReader(ctx, repoExecutor)
@@ -322,7 +321,7 @@ func TestCache_ObjectInfoReader(t *testing.T) {
cache.cachedProcessDone = sync.NewCond(&sync.Mutex{})
t.Run("uncancellable", func(t *testing.T) {
- ctx := context.Background()
+ ctx := testhelper.ContextWithoutCancel()
require.PanicsWithValue(t, "empty ctx.Done() in catfile.Batch.New()", func() {
_, _ = cache.ObjectInfoReader(ctx, repoExecutor)
diff --git a/internal/git/objectpool/clone_test.go b/internal/git/objectpool/clone_test.go
index a705c5832..3bf1e7584 100644
--- a/internal/git/objectpool/clone_test.go
+++ b/internal/git/objectpool/clone_test.go
@@ -12,7 +12,7 @@ func TestClone(t *testing.T) {
ctx, cancel := testhelper.Context()
defer cancel()
- pool, testRepo := setupObjectPool(t)
+ pool, testRepo := setupObjectPool(t, ctx)
require.NoError(t, pool.clone(ctx, testRepo))
defer func() {
@@ -27,7 +27,7 @@ func TestCloneExistingPool(t *testing.T) {
ctx, cancel := testhelper.Context()
defer cancel()
- pool, testRepo := setupObjectPool(t)
+ pool, testRepo := setupObjectPool(t, ctx)
require.NoError(t, pool.clone(ctx, testRepo))
defer func() {
diff --git a/internal/git/objectpool/fetch_test.go b/internal/git/objectpool/fetch_test.go
index eaab4dcac..8141dbdee 100644
--- a/internal/git/objectpool/fetch_test.go
+++ b/internal/git/objectpool/fetch_test.go
@@ -19,7 +19,7 @@ func TestFetchFromOriginDangling(t *testing.T) {
ctx, cancel := testhelper.Context()
defer cancel()
- pool, testRepo := setupObjectPool(t)
+ pool, testRepo := setupObjectPool(t, ctx)
require.NoError(t, pool.FetchFromOrigin(ctx, testRepo), "seed pool")
@@ -86,7 +86,7 @@ func TestFetchFromOriginFsck(t *testing.T) {
ctx, cancel := testhelper.Context()
defer cancel()
- pool, repo := setupObjectPool(t)
+ pool, repo := setupObjectPool(t, ctx)
repoPath := filepath.Join(pool.cfg.Storages[0].Path, repo.RelativePath)
require.NoError(t, pool.FetchFromOrigin(ctx, repo), "seed pool")
@@ -110,7 +110,7 @@ func TestFetchFromOriginDeltaIslands(t *testing.T) {
ctx, cancel := testhelper.Context()
defer cancel()
- pool, testRepo := setupObjectPool(t)
+ pool, testRepo := setupObjectPool(t, ctx)
testRepoPath := filepath.Join(pool.cfg.Storages[0].Path, testRepo.RelativePath)
require.NoError(t, pool.FetchFromOrigin(ctx, testRepo), "seed pool")
@@ -133,7 +133,7 @@ func TestFetchFromOriginBitmapHashCache(t *testing.T) {
ctx, cancel := testhelper.Context()
defer cancel()
- pool, testRepo := setupObjectPool(t)
+ pool, testRepo := setupObjectPool(t, ctx)
require.NoError(t, pool.FetchFromOrigin(ctx, testRepo), "seed pool")
@@ -158,7 +158,7 @@ func TestFetchFromOriginRefUpdates(t *testing.T) {
ctx, cancel := testhelper.Context()
defer cancel()
- pool, testRepo := setupObjectPool(t)
+ pool, testRepo := setupObjectPool(t, ctx)
testRepoPath := filepath.Join(pool.cfg.Storages[0].Path, testRepo.RelativePath)
poolPath := pool.FullPath()
@@ -203,7 +203,7 @@ func TestFetchFromOrigin_refs(t *testing.T) {
ctx, cancel := testhelper.Context()
defer cancel()
- pool, _ := setupObjectPool(t)
+ pool, _ := setupObjectPool(t, ctx)
poolPath := pool.FullPath()
// Init the source repo with a bunch of refs.
diff --git a/internal/git/objectpool/link_test.go b/internal/git/objectpool/link_test.go
index c960f48db..40d07f68c 100644
--- a/internal/git/objectpool/link_test.go
+++ b/internal/git/objectpool/link_test.go
@@ -19,7 +19,7 @@ func TestLink(t *testing.T) {
ctx, cancel := testhelper.Context()
defer cancel()
- pool, testRepo := setupObjectPool(t)
+ pool, testRepo := setupObjectPool(t, ctx)
require.NoError(t, pool.Remove(ctx), "make sure pool does not exist prior to creation")
require.NoError(t, pool.Create(ctx, testRepo), "create pool")
@@ -49,7 +49,7 @@ func TestLink_transactional(t *testing.T) {
ctx, cancel := testhelper.Context()
defer cancel()
- pool, poolMember := setupObjectPool(t)
+ pool, poolMember := setupObjectPool(t, ctx)
require.NoError(t, pool.Create(ctx, poolMember))
txManager := transaction.NewTrackingManager()
@@ -74,7 +74,7 @@ func TestLinkRemoveBitmap(t *testing.T) {
ctx, cancel := testhelper.Context()
defer cancel()
- pool, testRepo := setupObjectPool(t)
+ pool, testRepo := setupObjectPool(t, ctx)
require.NoError(t, pool.Init(ctx))
testRepoPath := filepath.Join(pool.cfg.Storages[0].Path, testRepo.RelativePath)
@@ -119,7 +119,7 @@ func TestLinkAbsoluteLinkExists(t *testing.T) {
ctx, cancel := testhelper.Context()
defer cancel()
- pool, testRepo := setupObjectPool(t)
+ pool, testRepo := setupObjectPool(t, ctx)
testRepoPath := filepath.Join(pool.cfg.Storages[0].Path, testRepo.RelativePath)
diff --git a/internal/git/objectpool/pool_test.go b/internal/git/objectpool/pool_test.go
index 94756ee0c..c3e57217d 100644
--- a/internal/git/objectpool/pool_test.go
+++ b/internal/git/objectpool/pool_test.go
@@ -30,7 +30,7 @@ func TestNewFromRepoSuccess(t *testing.T) {
ctx, cancel := testhelper.Context()
defer cancel()
- pool, testRepo := setupObjectPool(t)
+ pool, testRepo := setupObjectPool(t, ctx)
require.NoError(t, pool.Create(ctx, testRepo))
require.NoError(t, pool.Link(ctx, testRepo))
@@ -42,7 +42,10 @@ func TestNewFromRepoSuccess(t *testing.T) {
}
func TestNewFromRepoNoObjectPool(t *testing.T) {
- pool, testRepo := setupObjectPool(t)
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ pool, testRepo := setupObjectPool(t, ctx)
testRepoPath := filepath.Join(pool.cfg.Storages[0].Path, testRepo.RelativePath)
@@ -93,7 +96,7 @@ func TestCreate(t *testing.T) {
ctx, cancel := testhelper.Context()
defer cancel()
- pool, testRepo := setupObjectPool(t)
+ pool, testRepo := setupObjectPool(t, ctx)
testRepoPath := filepath.Join(pool.cfg.Storages[0].Path, testRepo.RelativePath)
@@ -127,7 +130,7 @@ func TestCreateSubDirsExist(t *testing.T) {
ctx, cancel := testhelper.Context()
defer cancel()
- pool, testRepo := setupObjectPool(t)
+ pool, testRepo := setupObjectPool(t, ctx)
err := pool.Create(ctx, testRepo)
require.NoError(t, err)
@@ -143,7 +146,7 @@ func TestRemove(t *testing.T) {
ctx, cancel := testhelper.Context()
defer cancel()
- pool, testRepo := setupObjectPool(t)
+ pool, testRepo := setupObjectPool(t, ctx)
err := pool.Create(ctx, testRepo)
require.NoError(t, err)
diff --git a/internal/git/objectpool/testhelper_test.go b/internal/git/objectpool/testhelper_test.go
index 722e892df..62f89eb57 100644
--- a/internal/git/objectpool/testhelper_test.go
+++ b/internal/git/objectpool/testhelper_test.go
@@ -24,7 +24,7 @@ func TestMain(m *testing.M) {
}))
}
-func setupObjectPool(t *testing.T) (*ObjectPool, *gitalypb.Repository) {
+func setupObjectPool(t *testing.T, ctx context.Context) (*ObjectPool, *gitalypb.Repository) {
t.Helper()
cfg, repo, _ := testcfg.BuildWithRepo(t)
@@ -45,7 +45,7 @@ func setupObjectPool(t *testing.T) (*ObjectPool, *gitalypb.Repository) {
require.NoError(t, err)
t.Cleanup(func() {
- if err := pool.Remove(context.TODO()); err != nil {
+ if err := pool.Remove(ctx); err != nil {
panic(err)
}
})
diff --git a/internal/gitaly/hook/sidechannel_test.go b/internal/gitaly/hook/sidechannel_test.go
index 8571d8a2b..a9ba8d344 100644
--- a/internal/gitaly/hook/sidechannel_test.go
+++ b/internal/gitaly/hook/sidechannel_test.go
@@ -1,20 +1,23 @@
package hook
import (
- "context"
"io"
"net"
"testing"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitaly/v14/internal/metadata"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
grpc_metadata "google.golang.org/grpc/metadata"
)
func TestSidechannel(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
// Client side
ctxOut, wt, err := SetupSidechannel(
- context.Background(),
+ ctx,
func(c *net.UnixConn) error {
_, err := io.WriteString(c, "ping")
return err
@@ -38,6 +41,9 @@ func TestSidechannel(t *testing.T) {
}
func TestGetSidechannel(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
testCases := []string{
"foobar",
"sc.foo/../../bar",
@@ -48,7 +54,7 @@ func TestGetSidechannel(t *testing.T) {
for _, tc := range testCases {
t.Run(tc, func(t *testing.T) {
ctx := grpc_metadata.NewIncomingContext(
- context.Background(),
+ ctx,
map[string][]string{sidechannelHeader: {tc}},
)
_, err := GetSidechannel(ctx)
diff --git a/internal/gitaly/maintenance/daily_test.go b/internal/gitaly/maintenance/daily_test.go
index 0edeb854c..151b61630 100644
--- a/internal/gitaly/maintenance/daily_test.go
+++ b/internal/gitaly/maintenance/daily_test.go
@@ -38,7 +38,10 @@ func TestStartDaily(t *testing.T) {
Duration: config.Duration(time.Hour),
Storages: []string{"meow"},
}
- ctx, cancel := context.WithCancel(context.Background())
+
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
go func() { errQ <- dw.StartDaily(ctx, testhelper.NewDiscardingLogEntry(t), s, fn) }()
startTime := time.Date(1999, 3, 31, 0, 0, 0, 0, time.Local)
diff --git a/internal/gitaly/storage/servers_test.go b/internal/gitaly/storage/servers_test.go
index 5c6def8da..f10666828 100644
--- a/internal/gitaly/storage/servers_test.go
+++ b/internal/gitaly/storage/servers_test.go
@@ -81,6 +81,8 @@ func TestInjectGitalyServers(t *testing.T) {
}
t.Run("brand new context", func(t *testing.T) {
+ //nolint:forbidigo // We need to check for metadata and thus cannot use the
+ // testhelper context, which injects feature flags.
ctx := context.Background()
check(t, ctx)
@@ -89,6 +91,8 @@ func TestInjectGitalyServers(t *testing.T) {
t.Run("context with existing outgoing metadata should not be re-written", func(t *testing.T) {
existing := metadata.New(map[string]string{"foo": "bar"})
+ //nolint:forbidigo // We need to check for metadata and thus cannot use the
+ // testhelper context, which injects feature flags.
ctx := metadata.NewOutgoingContext(context.Background(), existing)
check(t, ctx)
diff --git a/internal/gitlab/http_client_test.go b/internal/gitlab/http_client_test.go
index 7d2c6bfc9..90f9e96f0 100644
--- a/internal/gitlab/http_client_test.go
+++ b/internal/gitlab/http_client_test.go
@@ -1,7 +1,6 @@
package gitlab
import (
- "context"
"encoding/json"
"net/http"
"net/http/httptest"
@@ -120,8 +119,11 @@ func TestAccess_verifyParams(t *testing.T) {
},
}
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
for _, tc := range testCases {
- allowed, _, err := c.Allowed(context.Background(), AllowedParams{
+ allowed, _, err := c.Allowed(ctx, AllowedParams{
RepoPath: tc.repo.RelativePath,
GitObjectDirectory: tc.repo.GitObjectDirectory,
GitAlternateObjectDirectories: tc.repo.GitAlternateObjectDirectories,
@@ -227,7 +229,11 @@ func TestAccess_escapedAndRelativeURLs(t *testing.T) {
prometheus.Config{},
)
require.NoError(t, err)
- allowed, _, err := c.Allowed(context.Background(), AllowedParams{
+
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ allowed, _, err := c.Allowed(ctx, AllowedParams{
RepoPath: repo.RelativePath,
GitObjectDirectory: repo.GitObjectDirectory,
GitAlternateObjectDirectories: repo.GitAlternateObjectDirectories,
@@ -379,7 +385,10 @@ func TestAccess_allowedResponseHandling(t *testing.T) {
mockHistogramVec := promtest.NewMockHistogramVec()
c.latencyMetric = mockHistogramVec
- allowed, message, err := c.Allowed(context.Background(), AllowedParams{
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ allowed, message, err := c.Allowed(ctx, AllowedParams{
RepoPath: repo.RelativePath,
GitObjectDirectory: repo.GitObjectDirectory,
GitAlternateObjectDirectories: repo.GitAlternateObjectDirectories,
@@ -489,7 +498,10 @@ func TestAccess_preReceive(t *testing.T) {
mockHistogramVec := promtest.NewMockHistogramVec()
c.latencyMetric = mockHistogramVec
- success, err := c.PreReceive(context.Background(), "key-123")
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ success, err := c.PreReceive(ctx, "key-123")
require.Equal(t, tc.success, success)
if err != nil {
require.Contains(t, err.Error(), tc.errMsg)
@@ -577,10 +589,13 @@ func TestAccess_postReceive(t *testing.T) {
mockHistogramVec := promtest.NewMockHistogramVec()
c.latencyMetric = mockHistogramVec
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
repositoryID := "project-123"
identifier := "key-123"
changes := "000 000 refs/heads/master"
- success, _, err := c.PostReceive(context.Background(), repositoryID, identifier, changes, tc.pushOptions...)
+ success, _, err := c.PostReceive(ctx, repositoryID, identifier, changes, tc.pushOptions...)
require.Equal(t, tc.success, success)
if err != nil {
require.Contains(t, err.Error(), tc.errMsg)
diff --git a/internal/listenmux/mux_test.go b/internal/listenmux/mux_test.go
index 4ab8c488a..00c63d3c0 100644
--- a/internal/listenmux/mux_test.go
+++ b/internal/listenmux/mux_test.go
@@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
@@ -55,23 +56,29 @@ func serverWithHandshaker(t *testing.T, h Handshaker) string {
return l.Addr().String()
}
-func checkHealth(t *testing.T, cc *grpc.ClientConn) {
+func checkHealth(t *testing.T, ctx context.Context, cc *grpc.ClientConn) {
t.Helper()
- _, err := healthgrpc.NewHealthClient(cc).Check(context.Background(), &healthgrpc.HealthCheckRequest{})
+ _, err := healthgrpc.NewHealthClient(cc).Check(ctx, &healthgrpc.HealthCheckRequest{})
require.NoError(t, err)
}
func TestMux_normalClientNoMux(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
addr := serverWithHandshaker(t, nil)
cc, err := grpc.Dial(addr, grpc.WithInsecure())
require.NoError(t, err)
defer cc.Close()
- checkHealth(t, cc)
+ checkHealth(t, ctx, cc)
}
func TestMux_normalClientMuxIgnored(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
addr := serverWithHandshaker(t,
handshakeFunc(func(net.Conn, credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error) {
t.Error("never called")
@@ -83,10 +90,13 @@ func TestMux_normalClientMuxIgnored(t *testing.T) {
require.NoError(t, err)
defer cc.Close()
- checkHealth(t, cc)
+ checkHealth(t, ctx, cc)
}
func TestMux_muxClientPassesThrough(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
handshakerCalled := false
addr := serverWithHandshaker(t,
@@ -115,7 +125,7 @@ func TestMux_muxClientPassesThrough(t *testing.T) {
require.NoError(t, err)
defer cc.Close()
- checkHealth(t, cc)
+ checkHealth(t, ctx, cc)
require.True(t, handshakerCalled)
}
@@ -221,6 +231,9 @@ func TestMux_concurrency(t *testing.T) {
streamClientErrors := make(chan error, N)
grpcHealthErrors := make(chan error, N)
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
for i := 0; i < N; i++ {
go func() {
<-start
@@ -274,7 +287,7 @@ func TestMux_concurrency(t *testing.T) {
defer cc.Close()
client := healthgrpc.NewHealthClient(cc)
- _, err = client.Check(context.Background(), &healthgrpc.HealthCheckRequest{})
+ _, err = client.Check(ctx, &healthgrpc.HealthCheckRequest{})
return err
}()
}()
diff --git a/internal/log/log_test.go b/internal/log/log_test.go
index 4c508e057..93f4b1325 100644
--- a/internal/log/log_test.go
+++ b/internal/log/log_test.go
@@ -27,7 +27,7 @@ import (
)
func TestPayloadBytes(t *testing.T) {
- ctx := context.Background()
+ ctx := createContext()
logger, hook := test.NewNullLogger()
@@ -261,7 +261,7 @@ func TestConfigure(t *testing.T) {
func TestMessageProducer(t *testing.T) {
triggered := false
MessageProducer(func(ctx context.Context, format string, level logrus.Level, code codes.Code, err error, fields logrus.Fields) {
- require.Equal(t, context.Background(), ctx)
+ require.Equal(t, createContext(), ctx)
require.Equal(t, "format-stub", format)
require.Equal(t, logrus.DebugLevel, level)
require.Equal(t, codes.OutOfRange, code)
@@ -272,20 +272,20 @@ func TestMessageProducer(t *testing.T) {
return logrus.Fields{"a": 1}
}, func(context.Context) logrus.Fields {
return logrus.Fields{"b": "test"}
- })(context.Background(), "format-stub", logrus.DebugLevel, codes.OutOfRange, assert.AnError, logrus.Fields{"c": "stub"})
+ })(createContext(), "format-stub", logrus.DebugLevel, codes.OutOfRange, assert.AnError, logrus.Fields{"c": "stub"})
require.True(t, triggered)
}
func TestPropagationMessageProducer(t *testing.T) {
t.Run("empty context", func(t *testing.T) {
- ctx := context.Background()
+ ctx := createContext()
mp := PropagationMessageProducer(func(context.Context, string, logrus.Level, codes.Code, error, logrus.Fields) {})
mp(ctx, "", logrus.DebugLevel, codes.OK, nil, nil)
})
t.Run("context with holder", func(t *testing.T) {
holder := new(messageProducerHolder)
- ctx := context.WithValue(context.Background(), messageProducerHolderKey{}, holder)
+ ctx := context.WithValue(createContext(), messageProducerHolderKey{}, holder)
triggered := false
mp := PropagationMessageProducer(func(ctx context.Context, format string, level logrus.Level, code codes.Code, err error, fields logrus.Fields) {
triggered = true
@@ -313,7 +313,7 @@ func TestPerRPCLogHandler(t *testing.T) {
}
t.Run("check propagation", func(t *testing.T) {
- ctx := context.Background()
+ ctx := createContext()
ctx = lh.TagConn(ctx, &stats.ConnTagInfo{})
lh.HandleConn(ctx, &stats.ConnBegin{})
ctx = lh.TagRPC(ctx, &stats.RPCTagInfo{})
@@ -334,7 +334,7 @@ func TestPerRPCLogHandler(t *testing.T) {
})
t.Run("log handling", func(t *testing.T) {
- ctx := ctxlogrus.ToContext(context.Background(), logrus.NewEntry(logrus.New()))
+ ctx := ctxlogrus.ToContext(createContext(), logrus.NewEntry(logrus.New()))
ctx = lh.TagRPC(ctx, &stats.RPCTagInfo{})
mpp := ctx.Value(messageProducerHolderKey{}).(*messageProducerHolder)
mpp.format = "message"
@@ -381,7 +381,7 @@ func TestUnaryLogDataCatcherServerInterceptor(t *testing.T) {
t.Run("propagates call", func(t *testing.T) {
interceptor := UnaryLogDataCatcherServerInterceptor()
- resp, err := interceptor(context.Background(), nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) {
+ resp, err := interceptor(createContext(), nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) {
return 42, assert.AnError
})
@@ -391,7 +391,7 @@ func TestUnaryLogDataCatcherServerInterceptor(t *testing.T) {
t.Run("no logger", func(t *testing.T) {
mpp := &messageProducerHolder{}
- ctx := context.WithValue(context.Background(), messageProducerHolderKey{}, mpp)
+ ctx := context.WithValue(createContext(), messageProducerHolderKey{}, mpp)
interceptor := UnaryLogDataCatcherServerInterceptor()
_, _ = interceptor(ctx, nil, nil, handlerStub)
@@ -400,7 +400,7 @@ func TestUnaryLogDataCatcherServerInterceptor(t *testing.T) {
t.Run("caught", func(t *testing.T) {
mpp := &messageProducerHolder{}
- ctx := context.WithValue(context.Background(), messageProducerHolderKey{}, mpp)
+ ctx := context.WithValue(createContext(), messageProducerHolderKey{}, mpp)
ctx = ctxlogrus.ToContext(ctx, logrus.New().WithField("a", 1))
interceptor := UnaryLogDataCatcherServerInterceptor()
_, _ = interceptor(ctx, nil, nil, handlerStub)
@@ -411,7 +411,7 @@ func TestUnaryLogDataCatcherServerInterceptor(t *testing.T) {
func TestStreamLogDataCatcherServerInterceptor(t *testing.T) {
t.Run("propagates call", func(t *testing.T) {
interceptor := StreamLogDataCatcherServerInterceptor()
- ss := &grpcmw.WrappedServerStream{WrappedContext: context.Background()}
+ ss := &grpcmw.WrappedServerStream{WrappedContext: createContext()}
err := interceptor(nil, ss, nil, func(interface{}, grpc.ServerStream) error {
return assert.AnError
})
@@ -421,7 +421,7 @@ func TestStreamLogDataCatcherServerInterceptor(t *testing.T) {
t.Run("no logger", func(t *testing.T) {
mpp := &messageProducerHolder{}
- ctx := context.WithValue(context.Background(), messageProducerHolderKey{}, mpp)
+ ctx := context.WithValue(createContext(), messageProducerHolderKey{}, mpp)
interceptor := StreamLogDataCatcherServerInterceptor()
ss := &grpcmw.WrappedServerStream{WrappedContext: ctx}
@@ -430,7 +430,7 @@ func TestStreamLogDataCatcherServerInterceptor(t *testing.T) {
t.Run("caught", func(t *testing.T) {
mpp := &messageProducerHolder{}
- ctx := context.WithValue(context.Background(), messageProducerHolderKey{}, mpp)
+ ctx := context.WithValue(createContext(), messageProducerHolderKey{}, mpp)
ctx = ctxlogrus.ToContext(ctx, logrus.New().WithField("a", 1))
interceptor := StreamLogDataCatcherServerInterceptor()
@@ -439,3 +439,9 @@ func TestStreamLogDataCatcherServerInterceptor(t *testing.T) {
assert.Equal(t, logrus.Fields{"a": 1}, mpp.fields)
})
}
+
+//nolint:forbidigo // We cannot use `testhelper.Context()` because of a cyclic dependency between
+// this package and the `testhelper` package.
+func createContext() context.Context {
+ return context.Background()
+}
diff --git a/internal/metadata/featureflag/context.go b/internal/metadata/featureflag/context.go
index 16caa2837..9d43bd744 100644
--- a/internal/metadata/featureflag/context.go
+++ b/internal/metadata/featureflag/context.go
@@ -33,6 +33,7 @@ func outgoingCtxWithFeatureFlag(ctx context.Context, key string, enabled bool) c
md = metadata.New(map[string]string{})
}
+ md = md.Copy()
md.Set(key, strconv.FormatBool(enabled))
return metadata.NewOutgoingContext(ctx, md)
@@ -56,6 +57,7 @@ func incomingCtxWithFeatureFlag(ctx context.Context, key string, enabled bool) c
md = metadata.New(map[string]string{})
}
+ md = md.Copy()
md.Set(key, strconv.FormatBool(enabled))
return metadata.NewIncomingContext(ctx, md)
diff --git a/internal/metadata/featureflag/context_test.go b/internal/metadata/featureflag/context_test.go
index 5a40bd834..998c0f259 100644
--- a/internal/metadata/featureflag/context_test.go
+++ b/internal/metadata/featureflag/context_test.go
@@ -5,53 +5,89 @@ import (
"testing"
"github.com/stretchr/testify/require"
+ gitaly_metadata "gitlab.com/gitlab-org/gitaly/v14/internal/metadata"
"google.golang.org/grpc/metadata"
)
-var mockFeatureFlag = FeatureFlag{"turn meow on", false}
+var (
+ ffA = FeatureFlag{"feature-a", false}
+ ffB = FeatureFlag{"feature-b", false}
+)
+
+//nolint:forbidigo // We cannot use `testhelper.Context()` given that it would inject feature flags
+// already.
+func createContext() context.Context {
+ return context.Background()
+}
func TestIncomingCtxWithFeatureFlag(t *testing.T) {
- ctx := context.Background()
- require.False(t, mockFeatureFlag.IsEnabled(ctx))
+ ctx := createContext()
+ require.False(t, ffA.IsEnabled(ctx))
+ require.False(t, ffB.IsEnabled(ctx))
t.Run("enabled", func(t *testing.T) {
- ctx := IncomingCtxWithFeatureFlag(ctx, mockFeatureFlag, true)
- require.True(t, mockFeatureFlag.IsEnabled(ctx))
+ ctx := IncomingCtxWithFeatureFlag(ctx, ffA, true)
+ require.True(t, ffA.IsEnabled(ctx))
})
t.Run("disabled", func(t *testing.T) {
- ctx := IncomingCtxWithFeatureFlag(ctx, mockFeatureFlag, false)
- require.False(t, mockFeatureFlag.IsEnabled(ctx))
+ ctx := IncomingCtxWithFeatureFlag(ctx, ffA, false)
+ require.False(t, ffA.IsEnabled(ctx))
+ })
+
+ t.Run("set multiple flags", func(t *testing.T) {
+ ctxA := IncomingCtxWithFeatureFlag(ctx, ffA, true)
+ ctxB := IncomingCtxWithFeatureFlag(ctxA, ffB, true)
+
+ require.True(t, ffA.IsEnabled(ctxA))
+ require.False(t, ffB.IsEnabled(ctxA))
+
+ require.True(t, ffA.IsEnabled(ctxB))
+ require.True(t, ffB.IsEnabled(ctxB))
})
}
func TestOutgoingCtxWithFeatureFlag(t *testing.T) {
- ctx := context.Background()
- require.False(t, mockFeatureFlag.IsEnabled(ctx))
+ ctx := createContext()
+ require.False(t, ffA.IsEnabled(ctx))
+ require.False(t, ffB.IsEnabled(ctx))
t.Run("enabled", func(t *testing.T) {
- ctx := OutgoingCtxWithFeatureFlag(ctx, mockFeatureFlag, true)
+ ctx := OutgoingCtxWithFeatureFlag(ctx, ffA, true)
// The feature flag is only checked for incoming contexts, so it's not expected to
// be enabled yet.
- require.False(t, mockFeatureFlag.IsEnabled(ctx))
+ require.False(t, ffA.IsEnabled(ctx))
md, ok := metadata.FromOutgoingContext(ctx)
require.True(t, ok)
// It should be enabled after converting it to an incoming context though.
- ctx = metadata.NewIncomingContext(context.Background(), md)
- require.True(t, mockFeatureFlag.IsEnabled(ctx))
+ ctx = metadata.NewIncomingContext(createContext(), md)
+ require.True(t, ffA.IsEnabled(ctx))
})
t.Run("disabled", func(t *testing.T) {
- ctx = OutgoingCtxWithFeatureFlag(ctx, mockFeatureFlag, false)
- require.False(t, mockFeatureFlag.IsEnabled(ctx))
+ ctx = OutgoingCtxWithFeatureFlag(ctx, ffA, false)
+ require.False(t, ffA.IsEnabled(ctx))
md, ok := metadata.FromOutgoingContext(ctx)
require.True(t, ok)
- ctx = metadata.NewIncomingContext(context.Background(), md)
- require.False(t, mockFeatureFlag.IsEnabled(ctx))
+ ctx = metadata.NewIncomingContext(createContext(), md)
+ require.False(t, ffA.IsEnabled(ctx))
+ })
+
+ t.Run("set multiple flags", func(t *testing.T) {
+ ctxA := OutgoingCtxWithFeatureFlag(ctx, ffA, true)
+ ctxB := OutgoingCtxWithFeatureFlag(ctxA, ffB, true)
+
+ ctxA = gitaly_metadata.OutgoingToIncoming(ctxA)
+ require.True(t, ffA.IsEnabled(ctxA))
+ require.False(t, ffB.IsEnabled(ctxA))
+
+ ctxB = gitaly_metadata.OutgoingToIncoming(ctxB)
+ require.True(t, ffA.IsEnabled(ctxB))
+ require.True(t, ffB.IsEnabled(ctxB))
})
}
@@ -76,7 +112,7 @@ func TestGRPCMetadataFeatureFlag(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
md := metadata.New(tc.headers)
- ctx := metadata.NewIncomingContext(context.Background(), md)
+ ctx := metadata.NewIncomingContext(createContext(), md)
require.Equal(t, tc.enabled, FeatureFlag{tc.flag, tc.onByDefault}.IsEnabled(ctx))
})
@@ -91,7 +127,7 @@ func TestAllEnabledFlags(t *testing.T) {
ffPrefix + "bar": "TRUE", // not enabled
}
- ctx := metadata.NewIncomingContext(context.Background(), metadata.New(flags))
+ ctx := metadata.NewIncomingContext(createContext(), metadata.New(flags))
require.ElementsMatch(t, AllFlags(ctx), []string{"meow:true", "foo:true", "woof:false", "bar:TRUE"})
}
@@ -105,7 +141,7 @@ func TestRaw(t *testing.T) {
}
t.Run("RawFromContext", func(t *testing.T) {
- ctx := context.Background()
+ ctx := createContext()
ctx = IncomingCtxWithFeatureFlag(ctx, enabledFlag, true)
ctx = IncomingCtxWithFeatureFlag(ctx, disabledFlag, false)
@@ -113,7 +149,7 @@ func TestRaw(t *testing.T) {
})
t.Run("OutgoingWithRaw", func(t *testing.T) {
- outgoingMD, ok := metadata.FromOutgoingContext(OutgoingWithRaw(context.Background(), raw))
+ outgoingMD, ok := metadata.FromOutgoingContext(OutgoingWithRaw(createContext(), raw))
require.True(t, ok)
require.Equal(t, metadata.MD{
ffPrefix + enabledFlag.Name: {"true"},
diff --git a/internal/metadata/featureflag/featureflag_test.go b/internal/metadata/featureflag/featureflag_test.go
index ef9dc52ca..8944e21fd 100644
--- a/internal/metadata/featureflag/featureflag_test.go
+++ b/internal/metadata/featureflag/featureflag_test.go
@@ -1,7 +1,6 @@
package featureflag
import (
- "context"
"testing"
"github.com/stretchr/testify/require"
@@ -74,7 +73,7 @@ func TestFeatureFlag_enabled(t *testing.T) {
},
} {
t.Run(tc.desc, func(t *testing.T) {
- ctx := metadata.NewIncomingContext(context.Background(), metadata.New(tc.headers))
+ ctx := metadata.NewIncomingContext(createContext(), metadata.New(tc.headers))
ff := FeatureFlag{tc.flag, tc.onByDefault}
require.Equal(t, tc.enabled, ff.IsEnabled(ctx))
diff --git a/internal/metadata/metadata_test.go b/internal/metadata/metadata_test.go
index 71413549e..e361b1c47 100644
--- a/internal/metadata/metadata_test.go
+++ b/internal/metadata/metadata_test.go
@@ -1,20 +1,23 @@
package metadata
import (
- "context"
+ "errors"
"testing"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/storage"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
)
func TestOutgoingToIncoming(t *testing.T) {
- ctx := context.Background()
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
ctx, err := storage.InjectGitalyServers(ctx, "a", "b", "c")
require.NoError(t, err)
_, err = storage.ExtractGitalyServer(ctx, "a")
- require.Equal(t, storage.ErrEmptyMetadata, err,
+ require.Equal(t, errors.New("empty gitaly-servers metadata"), err,
"server should not be found in the incoming context")
ctx = OutgoingToIncoming(ctx)
diff --git a/internal/middleware/featureflag/featureflag_handler_test.go b/internal/middleware/featureflag/featureflag_handler_test.go
index be0ccfb43..f3b68b017 100644
--- a/internal/middleware/featureflag/featureflag_handler_test.go
+++ b/internal/middleware/featureflag/featureflag_handler_test.go
@@ -71,6 +71,8 @@ func setup() (context.Context, *test.Hook) {
}
func setupContext() (context.Context, *test.Hook) {
+ //nolint:forbidigo // We don't want to set up the feature flags which `testhelper.Context()`
+ // would inject here.
ctx := context.Background()
logger := logrus.New()
logger.SetOutput(io.Discard)
diff --git a/internal/middleware/limithandler/concurrency_limiter_test.go b/internal/middleware/limithandler/concurrency_limiter_test.go
index c5070fe39..bff19c688 100644
--- a/internal/middleware/limithandler/concurrency_limiter_test.go
+++ b/internal/middleware/limithandler/concurrency_limiter_test.go
@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
)
type counter struct {
@@ -123,6 +124,9 @@ func TestLimiter(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
expectedGaugeMax := tt.maxConcurrency * tt.buckets
if tt.maxConcurrency <= 0 {
expectedGaugeMax = tt.concurrency
@@ -161,7 +165,7 @@ func TestLimiter(t *testing.T) {
for i := 0; i < tt.iterations; i++ {
lockKey := strconv.Itoa((i ^ counter) % tt.buckets)
- _, err := limiter.Limit(context.Background(), lockKey, func() (interface{}, error) {
+ _, err := limiter.Limit(ctx, lockKey, func() (interface{}, error) {
primePump()
current := gauge.currentVal()
diff --git a/internal/middleware/metadatahandler/metadatahandler_test.go b/internal/middleware/metadatahandler/metadatahandler_test.go
index dc89037be..6aa12b64b 100644
--- a/internal/middleware/metadatahandler/metadatahandler_test.go
+++ b/internal/middleware/metadatahandler/metadatahandler_test.go
@@ -129,12 +129,15 @@ func verifyHandler(ctx context.Context, req interface{}) (interface{}, error) {
}
func TestGRPCTags(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
require := require.New(t)
- ctx := metadata.NewIncomingContext(
+ ctx = metadata.NewIncomingContext(
correlation.ContextWithCorrelation(
correlation.ContextWithClientName(
- context.Background(),
+ ctx,
clientName,
),
correlationID,
diff --git a/internal/middleware/sentryhandler/sentryhandler_test.go b/internal/middleware/sentryhandler/sentryhandler_test.go
index f2643e3a7..fcc75d6b0 100644
--- a/internal/middleware/sentryhandler/sentryhandler_test.go
+++ b/internal/middleware/sentryhandler/sentryhandler_test.go
@@ -83,7 +83,10 @@ func Test_generateSentryEvent(t *testing.T) {
name: "marked to skip",
ctx: func() context.Context {
var result context.Context
- ctx := context.Background()
+
+ ctx, cancel := testhelper.Context()
+ t.Cleanup(cancel)
+
// this is the only way how we could populate context with `tags` assembler
_, err := grpcmwtags.UnaryServerInterceptor()(ctx, nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) {
result = ctx
diff --git a/internal/praefect/coordinator_test.go b/internal/praefect/coordinator_test.go
index 7da00dbb7..70dc062aa 100644
--- a/internal/praefect/coordinator_test.go
+++ b/internal/praefect/coordinator_test.go
@@ -2090,9 +2090,12 @@ func TestNewRequestFinalizer_contextIsDisjointedFromTheRPC(t *testing.T) {
parentDeadline := time.Now()
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
//nolint:forbidigo // We explicitly want to test that the deadline does not propagate into
// the request's context.
- ctx, cancel := context.WithDeadline(context.WithValue(context.Background(), ctxKey{}, "value"), parentDeadline)
+ ctx, cancel = context.WithDeadline(context.WithValue(ctx, ctxKey{}, "value"), parentDeadline)
cancel()
requireSuppressedCancellation := func(t testing.TB, ctx context.Context) {
@@ -2210,6 +2213,9 @@ func TestStreamParametersContext(t *testing.T) {
return metadata.Pairs(pairs...)
}
+ //nolint:forbidigo // We explicitly test context values, so we cannot use the testhelper
+ // context here given that it would contain unrelated data and thus change the system under
+ // test.
for _, tc := range []struct {
desc string
setupContext func() context.Context
diff --git a/internal/praefect/repocleaner/action_log_test.go b/internal/praefect/repocleaner/action_log_test.go
index b8073d91f..1afada69c 100644
--- a/internal/praefect/repocleaner/action_log_test.go
+++ b/internal/praefect/repocleaner/action_log_test.go
@@ -1,18 +1,21 @@
package repocleaner
import (
- "context"
"testing"
"github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
)
func TestLogWarnAction_Perform(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
logger, hook := test.NewNullLogger()
action := NewLogWarnAction(logger)
- err := action.Perform(context.TODO(), "vs1", "g1", []string{"p/1", "p/2"})
+ err := action.Perform(ctx, "vs1", "g1", []string{"p/1", "p/2"})
require.NoError(t, err)
require.Len(t, hook.AllEntries(), 2)
diff --git a/internal/sidechannel/sidechannel_test.go b/internal/sidechannel/sidechannel_test.go
index 285765b51..e5c153f1a 100644
--- a/internal/sidechannel/sidechannel_test.go
+++ b/internal/sidechannel/sidechannel_test.go
@@ -12,12 +12,16 @@ import (
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitaly/v14/internal/backchannel"
"gitlab.com/gitlab-org/gitaly/v14/internal/listenmux"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
healthpb "google.golang.org/grpc/health/grpc_health_v1"
)
func TestSidechannel(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
const blobSize = 1024 * 1024
in := make([]byte, blobSize)
@@ -45,7 +49,7 @@ func TestSidechannel(t *testing.T) {
conn, registry := dial(t, addr)
err = call(
- context.Background(), conn, registry,
+ ctx, conn, registry,
func(conn *ClientConn) error {
errC := make(chan error, 1)
go func() {
@@ -68,6 +72,9 @@ func TestSidechannel(t *testing.T) {
// Conduct multiple requests with sidechannel included on the same grpc
// connection.
func TestSidechannelConcurrency(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
const concurrency = 10
const blobSize = 1024 * 1024
@@ -111,7 +118,7 @@ func TestSidechannelConcurrency(t *testing.T) {
defer wg.Done()
err := call(
- context.Background(), conn, registry,
+ ctx, conn, registry,
func(conn *ClientConn) error {
errC := make(chan error, 1)
go func() {
diff --git a/internal/streamcache/cache_test.go b/internal/streamcache/cache_test.go
index da8b5e157..1bb9009ea 100644
--- a/internal/streamcache/cache_test.go
+++ b/internal/streamcache/cache_test.go
@@ -29,6 +29,9 @@ func newCache(dir string) Cache {
}
func TestCache_writeOneReadMultiple(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
tmp := testhelper.TempDir(t)
c := newCache(tmp)
@@ -50,7 +53,7 @@ func TestCache_writeOneReadMultiple(t *testing.T) {
out, err := io.ReadAll(r)
require.NoError(t, err)
- require.NoError(t, r.Wait(context.Background()))
+ require.NoError(t, r.Wait(ctx))
require.Equal(t, content(0), string(out), "expect cache hits for all i > 0")
})
}
@@ -59,6 +62,9 @@ func TestCache_writeOneReadMultiple(t *testing.T) {
}
func TestCache_manyConcurrentWrites(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
tmp := testhelper.TempDir(t)
c := newCache(tmp)
@@ -94,7 +100,7 @@ func TestCache_manyConcurrentWrites(t *testing.T) {
}
output[i] = string(out)
- return r.Wait(context.Background())
+ return r.Wait(ctx)
}()
}(i)
}
@@ -175,6 +181,9 @@ func TestCache_deletedFile(t *testing.T) {
}
func TestCache_scope(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
tmp := testhelper.TempDir(t)
const (
@@ -215,13 +224,16 @@ func TestCache_scope(t *testing.T) {
out, err := io.ReadAll(r)
require.NoError(t, err)
- require.NoError(t, r.Wait(context.Background()))
+ require.NoError(t, r.Wait(ctx))
require.Equal(t, content, string(out))
}
}
func TestCache_diskCleanup(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
tmp := testhelper.TempDir(t)
const (
@@ -254,7 +266,7 @@ func TestCache_diskCleanup(t *testing.T) {
out1, err := io.ReadAll(r1)
require.NoError(t, err)
require.Equal(t, content(1), string(out1))
- require.NoError(t, r1.Wait(context.Background()))
+ require.NoError(t, r1.Wait(ctx))
// File and index entry should still exist because cleanup goroutines are blocked.
requireCacheFiles(t, tmp, 1)
@@ -292,13 +304,16 @@ func TestCache_diskCleanup(t *testing.T) {
out2, err := io.ReadAll(r2)
require.NoError(t, err)
- require.NoError(t, r2.Wait(context.Background()))
+ require.NoError(t, r2.Wait(ctx))
// Sanity check: no stale value returned by the cache
require.Equal(t, content(2), string(out2))
}
func TestCache_failedWrite(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
tmp := testhelper.TempDir(t)
c := newCache(tmp)
@@ -327,7 +342,7 @@ func TestCache_failedWrite(t *testing.T) {
_, err = io.Copy(io.Discard, r1)
require.NoError(t, err, "errors on the write end are not propagated via Read()")
require.NoError(t, r1.Close(), "errors on the write end are not propagated via Close()")
- require.Error(t, r1.Wait(context.Background()), "error propagation happens via Wait()")
+ require.Error(t, r1.Wait(ctx), "error propagation happens via Wait()")
time.Sleep(10 * time.Millisecond)
@@ -339,7 +354,7 @@ func TestCache_failedWrite(t *testing.T) {
out, err := io.ReadAll(r2)
require.NoError(t, err)
- require.NoError(t, r2.Wait(context.Background()))
+ require.NoError(t, r2.Wait(ctx))
require.Equal(t, happy, string(out))
})
}
@@ -359,6 +374,9 @@ func TestCache_failCreateFile(t *testing.T) {
}
func TestCache_unWriteableFile(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
tmp := testhelper.TempDir(t)
c := newCache(tmp)
@@ -378,12 +396,15 @@ func TestCache_unWriteableFile(t *testing.T) {
_, err = io.ReadAll(r)
require.NoError(t, err)
- err = r.Wait(context.Background())
+ err = r.Wait(ctx)
require.IsType(t, &os.PathError{}, err)
require.Equal(t, "write", err.(*os.PathError).Op)
}
func TestCache_unCloseableFile(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
tmp := testhelper.TempDir(t)
c := newCache(tmp)
@@ -404,7 +425,7 @@ func TestCache_unCloseableFile(t *testing.T) {
_, err = io.ReadAll(r)
require.NoError(t, err)
- err = r.Wait(context.Background())
+ err = r.Wait(ctx)
require.IsType(t, &os.PathError{}, err)
require.Equal(t, "close", err.(*os.PathError).Op)
}
@@ -430,16 +451,21 @@ func TestCache_cannotOpenFileForReading(t *testing.T) {
}
func TestWaiter(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
w := newWaiter()
err := errors.New("test error")
w.SetError(err)
- require.Equal(t, err, w.Wait(context.Background()))
+ require.Equal(t, err, w.Wait(ctx))
}
func TestWaiter_cancel(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
w := newWaiter()
errc := make(chan error, 1)
- ctx, cancel := context.WithCancel(context.Background())
go func() { errc <- w.Wait(ctx) }()
cancel()
@@ -447,6 +473,9 @@ func TestWaiter_cancel(t *testing.T) {
}
func TestNullCache(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
const (
N = 1000
inputSize = 4096
@@ -500,7 +529,7 @@ func TestNullCache(t *testing.T) {
return errors.New("output does not match input")
}
- return s.Wait(context.Background())
+ return s.Wait(ctx)
}()
}()
}
diff --git a/internal/testhelper/featureset_test.go b/internal/testhelper/featureset_test.go
index c6a4e75a8..fdbbba5ed 100644
--- a/internal/testhelper/featureset_test.go
+++ b/internal/testhelper/featureset_test.go
@@ -5,9 +5,8 @@ import (
"testing"
"github.com/stretchr/testify/require"
- "gitlab.com/gitlab-org/gitaly/v14/internal/metadata"
ff "gitlab.com/gitlab-org/gitaly/v14/internal/metadata/featureflag"
- grpc_metadata "google.golang.org/grpc/metadata"
+ "google.golang.org/grpc/metadata"
)
var (
@@ -172,9 +171,6 @@ func TestNewFeatureSetsWithRubyFlags(t *testing.T) {
}
func TestFeatureSets_Run(t *testing.T) {
- var incomingFlags [][2]bool
- var outgoingFlags [][2]bool
-
// This test depends on feature flags being default-enabled in the test
// context, which requires those flags to exist in the ff.All slice. So
// let's just append them here so we do not need to use a "real"
@@ -185,40 +181,25 @@ func TestFeatureSets_Run(t *testing.T) {
}(ff.All)
ff.All = append(ff.All, featureFlagA, featureFlagB)
+ var featureFlags [][2]bool
NewFeatureSets(featureFlagB, featureFlagA).Run(t, func(t *testing.T, ctx context.Context) {
- incomingMD, ok := grpc_metadata.FromIncomingContext(ctx)
+ incomingMD, ok := metadata.FromIncomingContext(ctx)
require.True(t, ok)
- outgoingMD, ok := grpc_metadata.FromOutgoingContext(ctx)
+ outgoingMD, ok := metadata.FromOutgoingContext(ctx)
require.True(t, ok)
- incomingCtx := grpc_metadata.NewIncomingContext(context.Background(), incomingMD)
- outgoingCtx := metadata.OutgoingToIncoming(grpc_metadata.NewOutgoingContext(context.Background(), outgoingMD))
+ require.Equal(t, incomingMD, outgoingMD)
- incomingFlags = append(incomingFlags, [2]bool{
- featureFlagB.IsDisabled(incomingCtx),
- featureFlagA.IsDisabled(incomingCtx),
- })
- outgoingFlags = append(outgoingFlags, [2]bool{
- featureFlagB.IsDisabled(outgoingCtx),
- featureFlagA.IsDisabled(outgoingCtx),
+ featureFlags = append(featureFlags, [2]bool{
+ featureFlagA.IsEnabled(ctx), featureFlagB.IsEnabled(ctx),
})
})
- for _, tc := range []struct {
- desc string
- flags [][2]bool
- }{
- {desc: "incoming context", flags: incomingFlags},
- {desc: "outgoing context", flags: outgoingFlags},
- } {
- t.Run(tc.desc, func(t *testing.T) {
- require.ElementsMatch(t, tc.flags, [][2]bool{
- {false, false},
- {true, false},
- {false, true},
- {true, true},
- })
- })
- }
+ require.Equal(t, [][2]bool{
+ {false, false},
+ {false, true},
+ {true, false},
+ {true, true},
+ }, featureFlags)
}
diff --git a/internal/testhelper/testhelper.go b/internal/testhelper/testhelper.go
index b5b008ff8..b98a7f4d8 100644
--- a/internal/testhelper/testhelper.go
+++ b/internal/testhelper/testhelper.go
@@ -157,18 +157,23 @@ func GetLocalhostListener(t testing.TB) (net.Listener, string) {
}
// ContextOpt returns a new context instance with the new additions to it.
-type ContextOpt func(context.Context) (context.Context, func())
+type ContextOpt func(context.Context) context.Context
// ContextWithLogger allows to inject provided logger into the context.
func ContextWithLogger(logger *log.Entry) ContextOpt {
- return func(ctx context.Context) (context.Context, func()) {
- return ctxlogrus.ToContext(ctx, logger), func() {}
+ return func(ctx context.Context) context.Context {
+ return ctxlogrus.ToContext(ctx, logger)
}
}
// Context returns a cancellable context.
func Context(opts ...ContextOpt) (context.Context, func()) {
- ctx, cancel := context.WithCancel(context.Background())
+ return context.WithCancel(ContextWithoutCancel(opts...))
+}
+
+// ContextWithoutCancel returns a non-cancellable context.
+func ContextWithoutCancel(opts ...ContextOpt) context.Context {
+ ctx := context.Background()
// Enable use of explicit feature flags. Each feature flag which is checked must have been
// explicitly injected into the context, or otherwise we panic. This is a sanity check to
@@ -180,18 +185,11 @@ func Context(opts ...ContextOpt) (context.Context, func()) {
// context.
ctx = featureflag.ContextWithFeatureFlags(ctx, featureflag.RunCommandsInCGroup)
- cancels := make([]func(), len(opts)+1)
- cancels[0] = cancel
- for i, opt := range opts {
- ctx, cancel = opt(ctx)
- cancels[i+1] = cancel
+ for _, opt := range opts {
+ ctx = opt(ctx)
}
- return ctx, func() {
- for i := len(cancels) - 1; i >= 0; i-- {
- cancels[i]()
- }
- }
+ return ctx
}
// CreateGlobalDirectory creates a directory in the test directory that is shared across all