Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorToon Claes <toon@gitlab.com>2021-09-08 17:19:48 +0300
committerToon Claes <toon@gitlab.com>2021-09-08 17:19:48 +0300
commitd63338d5a5e13cd36de6621ecbf2eff62ebefc7a (patch)
tree2d06cb75facca6ee55bb928255be69c7a402d763
parent719b070eb2faa9407482fbe3345331d30172d202 (diff)
parentb3ab3bfa5914d889c5c00e2c79fd03b4957a5a5a (diff)
Merge branch 'pks-tx-file-locking' into 'master'
Guard transactional file modifications against concurrent modifications See merge request gitlab-org/gitaly!3821
-rw-r--r--internal/cache/diskcache.go2
-rw-r--r--internal/cache/keyer.go2
-rw-r--r--internal/gitaly/service/repository/apply_gitattributes.go42
-rw-r--r--internal/gitaly/service/repository/apply_gitattributes_test.go72
-rw-r--r--internal/gitaly/service/repository/fullpath.go45
-rw-r--r--internal/gitaly/service/repository/fullpath_test.go17
-rw-r--r--internal/gitaly/service/repository/replicate.go2
-rw-r--r--internal/gitaly/storage/metadata.go2
-rw-r--r--internal/gitaly/transaction/manager.go19
-rw-r--r--internal/gitaly/transaction/voting.go77
-rw-r--r--internal/gitaly/transaction/voting_test.go210
-rw-r--r--internal/metadata/featureflag/ff_tx_file_locking.go5
-rw-r--r--internal/safe/file_writer.go26
-rw-r--r--internal/safe/file_writer_test.go40
-rw-r--r--internal/safe/locking_file_writer.go222
-rw-r--r--internal/safe/locking_file_writer_test.go304
16 files changed, 1015 insertions, 72 deletions
diff --git a/internal/cache/diskcache.go b/internal/cache/diskcache.go
index 98f839fbc..3fb5bf4bb 100644
--- a/internal/cache/diskcache.go
+++ b/internal/cache/diskcache.go
@@ -289,7 +289,7 @@ func (c *DiskCache) PutStream(ctx context.Context, repo *gitalypb.Repository, re
return err
}
- sf, err := safe.CreateFileWriter(reqPath)
+ sf, err := safe.NewFileWriter(reqPath)
if err != nil {
return err
}
diff --git a/internal/cache/keyer.go b/internal/cache/keyer.go
index a46fa22a7..dad5df90d 100644
--- a/internal/cache/keyer.go
+++ b/internal/cache/keyer.go
@@ -66,7 +66,7 @@ func (keyer leaseKeyer) updateLatest(ctx context.Context, repo *gitalypb.Reposit
return "", err
}
- latest, err := safe.CreateFileWriter(lPath)
+ latest, err := safe.NewFileWriter(lPath)
if err != nil {
return "", err
}
diff --git a/internal/gitaly/service/repository/apply_gitattributes.go b/internal/gitaly/service/repository/apply_gitattributes.go
index 9dbbdbc7c..a26b1cc14 100644
--- a/internal/gitaly/service/repository/apply_gitattributes.go
+++ b/internal/gitaly/service/repository/apply_gitattributes.go
@@ -12,11 +12,13 @@ import (
"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus"
"gitlab.com/gitlab-org/gitaly/v14/internal/git"
"gitlab.com/gitlab-org/gitaly/v14/internal/git/catfile"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/transaction"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/helper"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/metadata/featureflag"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/safe"
"gitlab.com/gitlab-org/gitaly/v14/internal/transaction/txinfo"
"gitlab.com/gitlab-org/gitaly/v14/internal/transaction/voting"
"gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
- "google.golang.org/grpc/codes"
- "google.golang.org/grpc/status"
)
const attributesFileMode os.FileMode = 0o644
@@ -28,7 +30,7 @@ func (s *server) applyGitattributes(ctx context.Context, c catfile.Batch, repoPa
_, err := c.Info(ctx, git.Revision(revision))
if err != nil {
if catfile.IsNotFound(err) {
- return status.Errorf(codes.InvalidArgument, "Revision doesn't exist")
+ return helper.ErrInvalidArgumentf("revision does not exist")
}
return err
@@ -59,9 +61,34 @@ func (s *server) applyGitattributes(ctx context.Context, c catfile.Batch, repoPa
return err
}
+ blobObj, err := c.Blob(ctx, git.Revision(blobInfo.Oid))
+ if err != nil {
+ return err
+ }
+
+ if featureflag.TxFileLocking.IsEnabled(ctx) {
+ writer, err := safe.NewLockingFileWriter(attributesPath, safe.LockingFileWriterConfig{
+ FileWriterConfig: safe.FileWriterConfig{FileMode: attributesFileMode},
+ })
+ if err != nil {
+ return fmt.Errorf("creating gitattributes writer: %w", err)
+ }
+ defer writer.Close()
+
+ if _, err := io.CopyN(writer, blobObj.Reader, blobInfo.Size); err != nil {
+ return err
+ }
+
+ if err := transaction.CommitLockedFile(ctx, s.txManager, writer); err != nil {
+ return fmt.Errorf("committing gitattributes: %w", err)
+ }
+
+ return nil
+ }
+
tempFile, err := ioutil.TempFile(infoPath, "attributes")
if err != nil {
- return status.Errorf(codes.Internal, "ApplyGitAttributes: creating temp file: %v", err)
+ return helper.ErrInternalf("creating temporary gitattributes file: %w", err)
}
defer func() {
if err := os.Remove(tempFile.Name()); err != nil && !errors.Is(err, os.ErrNotExist) {
@@ -69,11 +96,6 @@ func (s *server) applyGitattributes(ctx context.Context, c catfile.Batch, repoPa
}
}()
- blobObj, err := c.Blob(ctx, git.Revision(blobInfo.Oid))
- if err != nil {
- return err
- }
-
// Write attributes to temp file
if _, err := io.CopyN(tempFile, blobObj.Reader, blobInfo.Size); err != nil {
return err
@@ -128,7 +150,7 @@ func (s *server) ApplyGitattributes(ctx context.Context, in *gitalypb.ApplyGitat
}
if err := git.ValidateRevision(in.GetRevision()); err != nil {
- return nil, status.Errorf(codes.InvalidArgument, "ApplyGitAttributes: revision: %v", err)
+ return nil, helper.ErrInvalidArgumentf("revision: %v", err)
}
c, err := s.catfileCache.BatchProcess(ctx, repo)
diff --git a/internal/gitaly/service/repository/apply_gitattributes_test.go b/internal/gitaly/service/repository/apply_gitattributes_test.go
index 03227950f..3cebc3969 100644
--- a/internal/gitaly/service/repository/apply_gitattributes_test.go
+++ b/internal/gitaly/service/repository/apply_gitattributes_test.go
@@ -16,11 +16,13 @@ import (
"gitlab.com/gitlab-org/gitaly/v14/internal/git"
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/service"
"gitlab.com/gitlab-org/gitaly/v14/internal/helper"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/metadata/featureflag"
"gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
"gitlab.com/gitlab-org/gitaly/v14/internal/testhelper/testassert"
"gitlab.com/gitlab-org/gitaly/v14/internal/testhelper/testcfg"
"gitlab.com/gitlab-org/gitaly/v14/internal/testhelper/testserver"
"gitlab.com/gitlab-org/gitaly/v14/internal/transaction/txinfo"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/transaction/voting"
"gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
@@ -28,6 +30,12 @@ import (
)
func TestApplyGitattributesSuccess(t *testing.T) {
+ testhelper.NewFeatureSets([]featureflag.FeatureFlag{
+ featureflag.TxFileLocking,
+ }).Run(t, testApplyGitattributesSuccess)
+}
+
+func testApplyGitattributesSuccess(t *testing.T, ctx context.Context) {
t.Parallel()
cfg, repo, _, client := setupRepositoryService(t)
@@ -57,18 +65,18 @@ func TestApplyGitattributesSuccess(t *testing.T) {
if err := os.RemoveAll(infoPath); err != nil {
t.Fatal(err)
}
- assertGitattributesApplied(t, client, repo, attributesPath, test.revision, test.contents)
+ assertGitattributesApplied(t, ctx, client, repo, attributesPath, test.revision, test.contents)
// Test when no git attributes file exists
if err := os.Remove(attributesPath); err != nil && !os.IsNotExist(err) {
t.Fatal(err)
}
- assertGitattributesApplied(t, client, repo, attributesPath, test.revision, test.contents)
+ assertGitattributesApplied(t, ctx, client, repo, attributesPath, test.revision, test.contents)
// Test when a git attributes file already exists
require.NoError(t, os.MkdirAll(infoPath, 0o755))
require.NoError(t, ioutil.WriteFile(attributesPath, []byte("*.docx diff=word"), 0o644))
- assertGitattributesApplied(t, client, repo, attributesPath, test.revision, test.contents)
+ assertGitattributesApplied(t, ctx, client, repo, attributesPath, test.revision, test.contents)
})
}
}
@@ -86,6 +94,12 @@ func (s *testTransactionServer) VoteTransaction(ctx context.Context, in *gitalyp
}
func TestApplyGitattributesWithTransaction(t *testing.T) {
+ testhelper.NewFeatureSets([]featureflag.FeatureFlag{
+ featureflag.TxFileLocking,
+ }).Run(t, testApplyGitattributesWithTransaction)
+}
+
+func testApplyGitattributesWithTransaction(t *testing.T, ctx context.Context) {
t.Parallel()
cfg, repo, repoPath := testcfg.BuildWithRepo(t)
@@ -106,9 +120,6 @@ func TestApplyGitattributesWithTransaction(t *testing.T) {
// carefully crafted transaction and server information.
logger := testhelper.DiscardTestEntry(t)
- ctx, cancel := testhelper.Context()
- defer cancel()
-
client := newMuxedRepositoryClient(t, ctx, cfg, "unix://"+cfg.GitalyInternalSocketPath(),
backchannel.NewClientHandshaker(logger, func() backchannel.Server {
srv := grpc.NewServer()
@@ -128,12 +139,18 @@ func TestApplyGitattributesWithTransaction(t *testing.T) {
desc: "successful vote writes gitattributes",
revision: []byte("e63f41fe459e62e1228fcef60d7189127aeba95a"),
voteFn: func(t *testing.T, request *gitalypb.VoteTransactionRequest) (*gitalypb.VoteTransactionResponse, error) {
- oid, err := git.NewObjectIDFromHex("36814a3da051159a1683479e7a1487120309db8f")
- require.NoError(t, err)
- hash, err := oid.Bytes()
- require.NoError(t, err)
-
- require.Equal(t, hash, request.ReferenceUpdatesHash)
+ var expectedHash []byte
+ if featureflag.TxFileLocking.IsEnabled(ctx) {
+ vote := voting.VoteFromData([]byte("/custom-highlighting/*.gitlab-custom gitlab-language=ruby\n"))
+ expectedHash = vote.Bytes()
+ } else {
+ oid, err := git.NewObjectIDFromHex("36814a3da051159a1683479e7a1487120309db8f")
+ require.NoError(t, err)
+ expectedHash, err = oid.Bytes()
+ require.NoError(t, err)
+ }
+
+ require.Equal(t, expectedHash, request.ReferenceUpdatesHash)
return &gitalypb.VoteTransactionResponse{
State: gitalypb.VoteTransactionResponse_COMMIT,
}, nil
@@ -149,7 +166,13 @@ func TestApplyGitattributesWithTransaction(t *testing.T) {
}, nil
},
shouldExist: false,
- expectedErr: status.Error(codes.Unknown, "could not commit gitattributes: vote failed: transaction was aborted"),
+ expectedErr: func() error {
+ if featureflag.TxFileLocking.IsEnabled(ctx) {
+ return status.Error(codes.Unknown, "committing gitattributes: voting on locked file: preimage vote: transaction was aborted")
+ }
+
+ return status.Error(codes.Unknown, "could not commit gitattributes: vote failed: transaction was aborted")
+ }(),
},
{
desc: "failing vote does not write gitattributes",
@@ -158,7 +181,14 @@ func TestApplyGitattributesWithTransaction(t *testing.T) {
return nil, errors.New("foobar")
},
shouldExist: false,
- expectedErr: status.Error(codes.Unknown, "could not commit gitattributes: vote failed: rpc error: code = Unknown desc = foobar"),
+
+ expectedErr: func() error {
+ if featureflag.TxFileLocking.IsEnabled(ctx) {
+ return status.Error(codes.Unknown, "committing gitattributes: voting on locked file: preimage vote: rpc error: code = Unknown desc = foobar")
+ }
+
+ return status.Error(codes.Unknown, "could not commit gitattributes: vote failed: rpc error: code = Unknown desc = foobar")
+ }(),
},
{
desc: "commit without gitattributes performs vote",
@@ -203,6 +233,12 @@ func TestApplyGitattributesWithTransaction(t *testing.T) {
}
func TestApplyGitattributesFailure(t *testing.T) {
+ testhelper.NewFeatureSets([]featureflag.FeatureFlag{
+ featureflag.TxFileLocking,
+ }).Run(t, testApplyGitattributesFailure)
+}
+
+func testApplyGitattributesFailure(t *testing.T, ctx context.Context) {
t.Parallel()
_, repo, _, client := setupRepositoryService(t)
@@ -250,9 +286,6 @@ func TestApplyGitattributesFailure(t *testing.T) {
for _, test := range tests {
t.Run(fmt.Sprintf("%+v", test), func(t *testing.T) {
- ctx, cancel := testhelper.Context()
- defer cancel()
-
req := &gitalypb.ApplyGitattributesRequest{Repository: test.repo, Revision: test.revision}
_, err := client.ApplyGitattributes(ctx, req)
testhelper.RequireGrpcError(t, err, test.code)
@@ -260,12 +293,9 @@ func TestApplyGitattributesFailure(t *testing.T) {
}
}
-func assertGitattributesApplied(t *testing.T, client gitalypb.RepositoryServiceClient, testRepo *gitalypb.Repository, attributesPath string, revision, expectedContents []byte) {
+func assertGitattributesApplied(t *testing.T, ctx context.Context, client gitalypb.RepositoryServiceClient, testRepo *gitalypb.Repository, attributesPath string, revision, expectedContents []byte) {
t.Helper()
- ctx, cancel := testhelper.Context()
- defer cancel()
-
req := &gitalypb.ApplyGitattributesRequest{Repository: testRepo, Revision: revision}
c, err := client.ApplyGitattributes(ctx, req)
diff --git a/internal/gitaly/service/repository/fullpath.go b/internal/gitaly/service/repository/fullpath.go
index 1c22f0815..86da226e4 100644
--- a/internal/gitaly/service/repository/fullpath.go
+++ b/internal/gitaly/service/repository/fullpath.go
@@ -2,9 +2,13 @@ package repository
import (
"context"
- "fmt"
+ "path/filepath"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/git"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/transaction"
"gitlab.com/gitlab-org/gitaly/v14/internal/helper"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/metadata/featureflag"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/safe"
"gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
)
@@ -26,16 +30,49 @@ func (s *server) SetFullPath(
repo := s.localrepo(request.GetRepository())
+ if featureflag.TxFileLocking.IsEnabled(ctx) {
+ repoPath, err := repo.Path()
+ if err != nil {
+ return nil, helper.ErrInternalf("getting repository path: %w", err)
+ }
+ configPath := filepath.Join(repoPath, "config")
+
+ writer, err := safe.NewLockingFileWriter(configPath, safe.LockingFileWriterConfig{
+ SeedContents: true,
+ })
+ if err != nil {
+ return nil, helper.ErrInternalf("creating config writer: %w", err)
+ }
+ defer writer.Close()
+
+ if err := repo.ExecAndWait(ctx, git.SubCmd{
+ Name: "config",
+ Flags: []git.Option{
+ git.Flag{Name: "--replace-all"},
+ git.ValueFlag{Name: "--file", Value: writer.Path()},
+ },
+ Args: []string{fullPathKey, request.GetPath()},
+ }); err != nil {
+ return nil, helper.ErrInternalf("writing full path: %w", err)
+ }
+
+ if err := transaction.CommitLockedFile(ctx, s.txManager, writer); err != nil {
+ return nil, helper.ErrInternalf("committing config: %w", err)
+ }
+
+ return &gitalypb.SetFullPathResponse{}, nil
+ }
+
if err := s.voteOnConfig(ctx, request.GetRepository()); err != nil {
- return nil, helper.ErrInternal(fmt.Errorf("preimage vote on config: %w", err))
+ return nil, helper.ErrInternalf("preimage vote on config: %w", err)
}
if err := repo.Config().Set(ctx, fullPathKey, request.GetPath()); err != nil {
- return nil, helper.ErrInternal(fmt.Errorf("writing config: %w", err))
+ return nil, helper.ErrInternalf("writing config: %w", err)
}
if err := s.voteOnConfig(ctx, request.GetRepository()); err != nil {
- return nil, helper.ErrInternal(fmt.Errorf("postimage vote on config: %w", err))
+ return nil, helper.ErrInternalf("postimage vote on config: %w", err)
}
return &gitalypb.SetFullPathResponse{}, nil
diff --git a/internal/gitaly/service/repository/fullpath_test.go b/internal/gitaly/service/repository/fullpath_test.go
index 883e313fe..b04d6d09d 100644
--- a/internal/gitaly/service/repository/fullpath_test.go
+++ b/internal/gitaly/service/repository/fullpath_test.go
@@ -1,6 +1,7 @@
package repository
import (
+ "context"
"fmt"
"os"
"path/filepath"
@@ -11,6 +12,7 @@ import (
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/config"
"gitlab.com/gitlab-org/gitaly/v14/internal/helper"
"gitlab.com/gitlab-org/gitaly/v14/internal/helper/text"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/metadata/featureflag"
"gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
"gitlab.com/gitlab-org/gitaly/v14/internal/testhelper/testassert"
"gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
@@ -19,10 +21,13 @@ import (
func TestSetFullPath(t *testing.T) {
t.Parallel()
- cfg, client := setupRepositoryServiceWithoutRepo(t)
+ testhelper.NewFeatureSets([]featureflag.FeatureFlag{
+ featureflag.TxFileLocking,
+ }).Run(t, testSetFullPath)
+}
- ctx, cancel := testhelper.Context()
- defer cancel()
+func testSetFullPath(t *testing.T, ctx context.Context) {
+ cfg, client := setupRepositoryServiceWithoutRepo(t)
t.Run("missing repository", func(t *testing.T) {
response, err := client.SetFullPath(ctx, &gitalypb.SetFullPathRequest{
@@ -74,8 +79,12 @@ func TestSetFullPath(t *testing.T) {
require.Nil(t, response)
- expectedErr := fmt.Sprintf("rpc error: code = Internal desc = writing config: rpc "+
+ expectedErr := fmt.Sprintf("rpc error: code = NotFound desc = writing config: rpc "+
"error: code = NotFound desc = GetRepoPath: not a git repository: %q", repoPath)
+ if featureflag.TxFileLocking.IsEnabled(ctx) {
+ expectedErr = fmt.Sprintf("rpc error: code = NotFound desc = getting repository path: rpc "+
+ "error: code = NotFound desc = GetRepoPath: not a git repository: %q", repoPath)
+ }
require.EqualError(t, err, expectedErr)
})
diff --git a/internal/gitaly/service/repository/replicate.go b/internal/gitaly/service/repository/replicate.go
index ad3965747..b4bcaf676 100644
--- a/internal/gitaly/service/repository/replicate.go
+++ b/internal/gitaly/service/repository/replicate.go
@@ -276,7 +276,7 @@ func writeFile(path string, mode os.FileMode, reader io.Reader) error {
return err
}
- fw, err := safe.CreateFileWriter(path)
+ fw, err := safe.NewFileWriter(path)
if err != nil {
return err
}
diff --git a/internal/gitaly/storage/metadata.go b/internal/gitaly/storage/metadata.go
index 62381be7d..8368d93e7 100644
--- a/internal/gitaly/storage/metadata.go
+++ b/internal/gitaly/storage/metadata.go
@@ -28,7 +28,7 @@ func WriteMetadataFile(storagePath string) error {
return err
}
- fw, err := safe.CreateFileWriter(path)
+ fw, err := safe.NewFileWriter(path)
if err != nil {
return err
}
diff --git a/internal/gitaly/transaction/manager.go b/internal/gitaly/transaction/manager.go
index d3f3cf356..e53088115 100644
--- a/internal/gitaly/transaction/manager.go
+++ b/internal/gitaly/transaction/manager.go
@@ -168,22 +168,3 @@ func (m *PoolManager) Stop(ctx context.Context, tx txinfo.Transaction) error {
func (m *PoolManager) log(ctx context.Context) logrus.FieldLogger {
return ctxlogrus.Extract(ctx).WithField("component", "transaction.PoolManager")
}
-
-// RunOnContext runs the given function if the context identifies a transaction.
-func RunOnContext(ctx context.Context, fn func(txinfo.Transaction) error) error {
- transaction, err := txinfo.TransactionFromContext(ctx)
- if err != nil {
- if errors.Is(err, txinfo.ErrTransactionNotFound) {
- return nil
- }
- return err
- }
- return fn(transaction)
-}
-
-// VoteOnContext casts the vote on a transaction identified by the context, if there is any.
-func VoteOnContext(ctx context.Context, m Manager, vote voting.Vote) error {
- return RunOnContext(ctx, func(transaction txinfo.Transaction) error {
- return m.Vote(ctx, transaction, vote)
- })
-}
diff --git a/internal/gitaly/transaction/voting.go b/internal/gitaly/transaction/voting.go
new file mode 100644
index 000000000..071d944a8
--- /dev/null
+++ b/internal/gitaly/transaction/voting.go
@@ -0,0 +1,77 @@
+package transaction
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+
+ "gitlab.com/gitlab-org/gitaly/v14/internal/safe"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/transaction/txinfo"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/transaction/voting"
+)
+
+// RunOnContext runs the given function if the context identifies a transaction.
+func RunOnContext(ctx context.Context, fn func(txinfo.Transaction) error) error {
+ transaction, err := txinfo.TransactionFromContext(ctx)
+ if err != nil {
+ if errors.Is(err, txinfo.ErrTransactionNotFound) {
+ return nil
+ }
+ return err
+ }
+ return fn(transaction)
+}
+
+// VoteOnContext casts the vote on a transaction identified by the context, if there is any.
+func VoteOnContext(ctx context.Context, m Manager, vote voting.Vote) error {
+ return RunOnContext(ctx, func(transaction txinfo.Transaction) error {
+ return m.Vote(ctx, transaction, vote)
+ })
+}
+
+// CommitLockedFile will lock, vote and commit the LockingFileWriter in a race-free manner.
+func CommitLockedFile(ctx context.Context, m Manager, writer *safe.LockingFileWriter) (returnedErr error) {
+ if err := writer.Lock(); err != nil {
+ return fmt.Errorf("locking file: %w", err)
+ }
+
+ var vote voting.Vote
+ if err := RunOnContext(ctx, func(tx txinfo.Transaction) error {
+ hasher := voting.NewVoteHash()
+
+ lockedFile, err := os.Open(writer.Path())
+ if err != nil {
+ return fmt.Errorf("opening locked file: %w", err)
+ }
+ defer lockedFile.Close()
+
+ if _, err := io.Copy(hasher, lockedFile); err != nil {
+ return fmt.Errorf("hashing locked file: %w", err)
+ }
+
+ vote, err = hasher.Vote()
+ if err != nil {
+ return fmt.Errorf("computing vote for locked file: %w", err)
+ }
+
+ if err := m.Vote(ctx, tx, vote); err != nil {
+ return fmt.Errorf("preimage vote: %w", err)
+ }
+
+ return nil
+ }); err != nil {
+ return fmt.Errorf("voting on locked file: %w", err)
+ }
+
+ if err := writer.Commit(); err != nil {
+ return fmt.Errorf("committing file: %w", err)
+ }
+
+ if err := VoteOnContext(ctx, m, vote); err != nil {
+ return fmt.Errorf("postimage vote: %w", err)
+ }
+
+ return nil
+}
diff --git a/internal/gitaly/transaction/voting_test.go b/internal/gitaly/transaction/voting_test.go
new file mode 100644
index 000000000..8912a3032
--- /dev/null
+++ b/internal/gitaly/transaction/voting_test.go
@@ -0,0 +1,210 @@
+package transaction
+
+import (
+ "context"
+ "fmt"
+ "io/ioutil"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/backchannel"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/safe"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/transaction/txinfo"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/transaction/voting"
+ "google.golang.org/grpc/peer"
+)
+
+func TestRunOnContext(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ backchannelPeer := &peer.Peer{
+ AuthInfo: backchannel.WithID(nil, 1234),
+ }
+
+ t.Run("without transaction", func(t *testing.T) {
+ require.NoError(t, RunOnContext(ctx, func(tx txinfo.Transaction) error {
+ t.Fatal("this function should not be executed")
+ return nil
+ }))
+ })
+
+ t.Run("with transaction and no error", func(t *testing.T) {
+ ctx, err := txinfo.InjectTransaction(ctx, 5678, "node", true)
+ require.NoError(t, err)
+ ctx = peer.NewContext(ctx, backchannelPeer)
+
+ callbackExecuted := false
+ require.NoError(t, RunOnContext(ctx, func(tx txinfo.Transaction) error {
+ require.Equal(t, txinfo.Transaction{
+ ID: 5678,
+ Node: "node",
+ Primary: true,
+ BackchannelID: 1234,
+ }, tx)
+ callbackExecuted = true
+ return nil
+ }))
+ require.True(t, callbackExecuted, "callback should have been executed")
+ })
+
+ t.Run("with transaction and error", func(t *testing.T) {
+ ctx, err := txinfo.InjectTransaction(ctx, 5678, "node", true)
+ require.NoError(t, err)
+ ctx = peer.NewContext(ctx, backchannelPeer)
+
+ expectedErr := fmt.Errorf("any error")
+ require.Equal(t, expectedErr, RunOnContext(ctx, func(txinfo.Transaction) error {
+ return expectedErr
+ }))
+ })
+
+ t.Run("with transaction but missing peer", func(t *testing.T) {
+ ctx, err := txinfo.InjectTransaction(ctx, 5678, "node", true)
+ require.NoError(t, err)
+ require.EqualError(t, RunOnContext(ctx, nil), "get peer id: no peer info in context")
+ })
+}
+
+func TestVoteOnContext(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ backchannelPeer := &peer.Peer{
+ AuthInfo: backchannel.WithID(nil, 1234),
+ }
+
+ vote := voting.VoteFromData([]byte("1"))
+
+ t.Run("without transaction", func(t *testing.T) {
+ require.NoError(t, VoteOnContext(ctx, &MockManager{}, voting.Vote{}))
+ })
+
+ t.Run("successful vote", func(t *testing.T) {
+ ctx, err := txinfo.InjectTransaction(ctx, 5678, "node", true)
+ require.NoError(t, err)
+ ctx = peer.NewContext(ctx, backchannelPeer)
+
+ callbackExecuted := false
+ require.NoError(t, VoteOnContext(ctx, &MockManager{
+ VoteFn: func(ctx context.Context, tx txinfo.Transaction, vote voting.Vote) error {
+ require.Equal(t, txinfo.Transaction{
+ ID: 5678,
+ Node: "node",
+ Primary: true,
+ BackchannelID: 1234,
+ }, tx)
+ callbackExecuted = true
+ return nil
+ },
+ }, vote))
+ require.True(t, callbackExecuted, "callback should have been executed")
+ })
+
+ t.Run("failing vote", func(t *testing.T) {
+ ctx, err := txinfo.InjectTransaction(ctx, 5678, "node", true)
+ require.NoError(t, err)
+ ctx = peer.NewContext(ctx, backchannelPeer)
+
+ expectedErr := fmt.Errorf("any error")
+ require.Equal(t, expectedErr, VoteOnContext(ctx, &MockManager{
+ VoteFn: func(ctx context.Context, tx txinfo.Transaction, vote voting.Vote) error {
+ return expectedErr
+ },
+ }, vote))
+ })
+}
+
+func TestCommitLockedFile(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ backchannelPeer := &peer.Peer{
+ AuthInfo: backchannel.WithID(nil, 1234),
+ }
+
+ t.Run("without transaction", func(t *testing.T) {
+ file := filepath.Join(testhelper.TempDir(t), "file")
+
+ writer, err := safe.NewLockingFileWriter(file)
+ require.NoError(t, err)
+ _, err = writer.Write([]byte("contents"))
+ require.NoError(t, err)
+
+ require.NoError(t, CommitLockedFile(ctx, &MockManager{}, writer))
+ require.Equal(t, []byte("contents"), testhelper.MustReadFile(t, file))
+ })
+
+ ctx, err := txinfo.InjectTransaction(ctx, 5678, "node", true)
+ require.NoError(t, err)
+ ctx = peer.NewContext(ctx, backchannelPeer)
+
+ t.Run("successful transaction", func(t *testing.T) {
+ file := filepath.Join(testhelper.TempDir(t), "file")
+
+ writer, err := safe.NewLockingFileWriter(file)
+ require.NoError(t, err)
+ _, err = writer.Write([]byte("contents"))
+ require.NoError(t, err)
+
+ calls := 0
+ require.NoError(t, CommitLockedFile(ctx, &MockManager{
+ VoteFn: func(ctx context.Context, tx txinfo.Transaction, vote voting.Vote) error {
+ require.Equal(t, txinfo.Transaction{
+ ID: 5678,
+ Node: "node",
+ Primary: true,
+ BackchannelID: 1234,
+ }, tx)
+ require.Equal(t, voting.VoteFromData([]byte("contents")), vote)
+ calls++
+ return nil
+ },
+ }, writer))
+ require.Equal(t, 2, calls, "expected two votes")
+
+ require.Equal(t, []byte("contents"), testhelper.MustReadFile(t, file))
+ })
+
+ t.Run("failing transaction", func(t *testing.T) {
+ file := filepath.Join(testhelper.TempDir(t), "file")
+
+ writer, err := safe.NewLockingFileWriter(file)
+ require.NoError(t, err)
+ _, err = writer.Write([]byte("contents"))
+ require.NoError(t, err)
+
+ err = CommitLockedFile(ctx, &MockManager{
+ VoteFn: func(context.Context, txinfo.Transaction, voting.Vote) error {
+ return fmt.Errorf("some error")
+ },
+ }, writer)
+ require.EqualError(t, err, "voting on locked file: preimage vote: some error")
+
+ require.NoFileExists(t, file)
+ })
+
+ t.Run("concurrent modification", func(t *testing.T) {
+ file := filepath.Join(testhelper.TempDir(t), "file")
+
+ writer, err := safe.NewLockingFileWriter(file)
+ require.NoError(t, err)
+ _, err = writer.Write([]byte("contents"))
+ require.NoError(t, err)
+
+ err = CommitLockedFile(ctx, &MockManager{
+ VoteFn: func(context.Context, txinfo.Transaction, voting.Vote) error {
+ // This shouldn't typically happen given that the file is locked,
+ // but we concurrently update the file after our first vote.
+ require.NoError(t, ioutil.WriteFile(file, []byte("something"),
+ 0o666))
+ return nil
+ },
+ }, writer)
+ require.EqualError(t, err, "committing file: file concurrently created")
+
+ require.Equal(t, []byte("something"), testhelper.MustReadFile(t, file))
+ })
+}
diff --git a/internal/metadata/featureflag/ff_tx_file_locking.go b/internal/metadata/featureflag/ff_tx_file_locking.go
new file mode 100644
index 000000000..08f1c9f6c
--- /dev/null
+++ b/internal/metadata/featureflag/ff_tx_file_locking.go
@@ -0,0 +1,5 @@
+package featureflag
+
+// TxFileLocking enables two-phase voting on files with proper locking semantics such that no races
+// can exist anymore.
+var TxFileLocking = NewFeatureFlag("tx_file_locking", false)
diff --git a/internal/safe/file_writer.go b/internal/safe/file_writer.go
index 149ce75ac..49d7903d8 100644
--- a/internal/safe/file_writer.go
+++ b/internal/safe/file_writer.go
@@ -21,8 +21,23 @@ type FileWriter struct {
commitOrClose sync.Once
}
-// CreateFileWriter takes path as an absolute path of the target file and creates a new FileWriter by attempting to create a tempfile
-func CreateFileWriter(path string) (*FileWriter, error) {
+// FileWriterConfig contains configuration for the `NewFileWriter()` function.
+type FileWriterConfig struct {
+ // FileMode is the desired file mode of the committed target file. If left at its default
+ // value, then no file mode will be explicitly set for the file.
+ FileMode os.FileMode
+}
+
+// NewFileWriter takes path as an absolute path of the target file and creates a new FileWriter by
+// attempting to create a tempfile. This function either takes no FileWriterConfig or exactly one.
+func NewFileWriter(path string, optionalCfg ...FileWriterConfig) (*FileWriter, error) {
+ var cfg FileWriterConfig
+ if len(optionalCfg) == 1 {
+ cfg = optionalCfg[0]
+ } else if len(optionalCfg) > 1 {
+ return nil, fmt.Errorf("file writer created with more than one config")
+ }
+
writer := &FileWriter{path: path}
directory := filepath.Dir(path)
@@ -32,6 +47,13 @@ func CreateFileWriter(path string) (*FileWriter, error) {
return nil, err
}
+ if cfg.FileMode != 0 {
+ if err := tmpFile.Chmod(cfg.FileMode); err != nil {
+ _ = writer.Close()
+ return nil, err
+ }
+ }
+
writer.tmpFile = tmpFile
return writer, nil
diff --git a/internal/safe/file_writer_test.go b/internal/safe/file_writer_test.go
index 8be19dfd1..d6c83ab03 100644
--- a/internal/safe/file_writer_test.go
+++ b/internal/safe/file_writer_test.go
@@ -5,6 +5,7 @@ import (
"fmt"
"io"
"io/ioutil"
+ "os"
"path/filepath"
"sync"
"testing"
@@ -14,12 +15,12 @@ import (
"gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
)
-func TestFile(t *testing.T) {
+func TestFileWriter_successful(t *testing.T) {
dir := testhelper.TempDir(t)
filePath := filepath.Join(dir, "test_file_contents")
fileContents := "very important contents"
- file, err := safe.CreateFileWriter(filePath)
+ file, err := safe.NewFileWriter(filePath)
require.NoError(t, err)
_, err = io.Copy(file, bytes.NewBufferString(fileContents))
@@ -38,7 +39,30 @@ func TestFile(t *testing.T) {
require.Equal(t, filepath.Base(filePath), filesInTempDir[0].Name())
}
-func TestFileRace(t *testing.T) {
+func TestFileWriter_multipleConfigs(t *testing.T) {
+ _, err := safe.NewFileWriter("something", safe.FileWriterConfig{},
+ safe.FileWriterConfig{})
+ require.Equal(t, fmt.Errorf("file writer created with more than one config"), err)
+}
+
+func TestFileWriter_mode(t *testing.T) {
+ dir := testhelper.TempDir(t)
+
+ target := filepath.Join(dir, "file")
+ require.NoError(t, ioutil.WriteFile(target, []byte("contents"), 0o600))
+
+ writer, err := safe.NewFileWriter(target, safe.FileWriterConfig{
+ FileMode: 0o060,
+ })
+ require.NoError(t, err)
+ require.NoError(t, writer.Commit())
+
+ fi, err := os.Stat(target)
+ require.NoError(t, err)
+ require.Equal(t, os.FileMode(0o060), fi.Mode())
+}
+
+func TestFileWriter_race(t *testing.T) {
dir := testhelper.TempDir(t)
filePath := filepath.Join(dir, "test_file_contents")
@@ -48,7 +72,7 @@ func TestFileRace(t *testing.T) {
for i := 0; i < 10; i++ {
wg.Add(1)
go func(i int) {
- w, err := safe.CreateFileWriter(filePath)
+ w, err := safe.NewFileWriter(filePath)
require.NoError(t, err)
_, err = w.Write([]byte(fmt.Sprintf("message # %d", i)))
require.NoError(t, err)
@@ -64,11 +88,11 @@ func TestFileRace(t *testing.T) {
require.Len(t, filesInTempDir, 1, "make sure no other files were written")
}
-func TestFileCloseBeforeCommit(t *testing.T) {
+func TestFileWriter_closeBeforeCommit(t *testing.T) {
dir := testhelper.TempDir(t)
dstPath := filepath.Join(dir, "safety_meow")
- sf, err := safe.CreateFileWriter(dstPath)
+ sf, err := safe.NewFileWriter(dstPath)
require.NoError(t, err)
require.True(t, !dirEmpty(t, dir), "should contain something")
@@ -82,11 +106,11 @@ func TestFileCloseBeforeCommit(t *testing.T) {
require.Equal(t, safe.ErrAlreadyDone, sf.Commit())
}
-func TestFileCommitBeforeClose(t *testing.T) {
+func TestFileWriter_commitBeforeClose(t *testing.T) {
dir := testhelper.TempDir(t)
dstPath := filepath.Join(dir, "safety_meow")
- sf, err := safe.CreateFileWriter(dstPath)
+ sf, err := safe.NewFileWriter(dstPath)
require.NoError(t, err)
require.False(t, dirEmpty(t, dir), "should contain something")
diff --git a/internal/safe/locking_file_writer.go b/internal/safe/locking_file_writer.go
new file mode 100644
index 000000000..b7e8405c2
--- /dev/null
+++ b/internal/safe/locking_file_writer.go
@@ -0,0 +1,222 @@
+package safe
+
+import (
+ "fmt"
+ "io"
+ "os"
+)
+
+type lockingFileWriterState int
+
+const (
+ lockingFileWriterStateOpen = lockingFileWriterState(iota)
+ lockingFileWriterStateLocked
+ lockingFileWriterStateClosed
+)
+
+// LockingFileWriter is a FileWriter which locks the target file on commit and checks whether it
+// has been modified since the LockingFileWriter has been created. The user must first create a new
+// LockingFileWriter via `NewLockingFileWriter()`, at which point it is open for writes. The writer
+// must be `Lock()`ed before `Commit()`ting changes.
+type LockingFileWriter struct {
+ writer *FileWriter
+ fi os.FileInfo
+ state lockingFileWriterState
+}
+
+// LockingFileWriterConfig contains configuration for the `NewLockingFileWriter()` function.
+type LockingFileWriterConfig struct {
+ // FileWriterConfig is the configuration for the embedded FileWriter.
+ FileWriterConfig
+ // SeedContents will seed the FileWriter's file with contents of the target file if
+ // set. If the target file does not exist, then the file remains empty.
+ SeedContents bool
+}
+
+// NewLockingFileWriter creates a new LockingFileWriter for the given path. At creation, it
+// stats the target file and caches its current size and last modification time such that it can
+// compare on commit whether the file has changed.
+func NewLockingFileWriter(path string, optionalCfg ...LockingFileWriterConfig) (*LockingFileWriter, error) {
+ var cfg LockingFileWriterConfig
+ if len(optionalCfg) == 1 {
+ cfg = optionalCfg[0]
+ } else if len(optionalCfg) > 1 {
+ return nil, fmt.Errorf("locking file writer created with more than one config")
+ }
+
+ targetFile, err := os.Open(path)
+ if err != nil && !os.IsNotExist(err) {
+ return nil, fmt.Errorf("opening target file: %w", err)
+ }
+ defer targetFile.Close()
+
+ var targetFileInfo os.FileInfo
+ if targetFile != nil {
+ targetFileInfo, err = targetFile.Stat()
+ if err != nil {
+ return nil, fmt.Errorf("statting target file: %w", err)
+ }
+ }
+
+ writer, err := NewFileWriter(path, cfg.FileWriterConfig)
+ if err != nil {
+ return nil, fmt.Errorf("creating file writer: %w", err)
+ }
+
+ if targetFile != nil && cfg.SeedContents {
+ _, err := io.Copy(writer, targetFile)
+ if err != nil {
+ return nil, fmt.Errorf("seeding file writer: %w", err)
+ }
+
+ // We need to sync the file to disk such that it's possible to modify its contents
+ // via an external process. Otherwise, external processes may only see partially
+ // written files.
+ if err := writer.tmpFile.Sync(); err != nil {
+ return nil, fmt.Errorf("flushing seeded contents: %w", err)
+ }
+ }
+
+ return &LockingFileWriter{
+ writer: writer,
+ fi: targetFileInfo,
+ }, nil
+}
+
+// Write writes to the FileWriter. Must be called on an open LockingFileWriter.
+func (fw *LockingFileWriter) Write(p []byte) (int, error) {
+ if fw.state != lockingFileWriterStateOpen {
+ return 0, fmt.Errorf("file writer not accepting writes")
+ }
+
+ return fw.writer.Write(p)
+}
+
+// Close closes the FileWriter and removes any locks and temporary files without updating the target
+// file. Does nothing if the file has already been closed.
+func (fw *LockingFileWriter) Close() error {
+ var err error
+ switch fw.state {
+ case lockingFileWriterStateOpen:
+ // No lock has been taken yet, so we don't have to unlock.
+ case lockingFileWriterStateLocked:
+ err = fw.unlock()
+ case lockingFileWriterStateClosed:
+ return nil
+ default:
+ return fmt.Errorf("invalid state %d", fw.state)
+ }
+
+ if writerErr := fw.writer.Close(); writerErr != nil && err == nil {
+ err = fmt.Errorf("closing writer: %w", writerErr)
+ }
+
+ fw.state = lockingFileWriterStateClosed
+
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// Lock locks the file writer such that no other process can concurrently update the same file. Must
+// be called on an open LockingFileWriter.
+func (fw *LockingFileWriter) Lock() error {
+ if fw.state != lockingFileWriterStateOpen {
+ return fmt.Errorf("file writer not lockable")
+ }
+
+ if err := fw.checkConcurrentModification(); err != nil {
+ return err
+ }
+
+ lock, err := os.OpenFile(fw.lockPath(), os.O_CREATE|os.O_EXCL|os.O_RDONLY, 0o400)
+ if err != nil {
+ if os.IsExist(err) {
+ return fmt.Errorf("file already locked")
+ }
+
+ return fmt.Errorf("creating lock file: %w", err)
+ }
+ _ = lock.Close()
+
+ fw.state = lockingFileWriterStateLocked
+
+ return nil
+}
+
+func (fw *LockingFileWriter) unlock() error {
+ // We only want to unlock in case we have locked this file ourselves. Otherwise, we risk
+ // removing the lock from another, concurrent locking file writer.
+ if fw.state != lockingFileWriterStateLocked {
+ return fmt.Errorf("file writer not locked")
+ }
+
+ if err := os.Remove(fw.lockPath()); err != nil {
+ return fmt.Errorf("removing lock file: %w", err)
+ }
+
+ fw.state = lockingFileWriterStateClosed
+
+ return nil
+}
+
+// Commit writes whatever has been written to the Filewriter to the target file if and only if the
+// target file has not been modified meanwhile. The writer must be `Lock()`ed first. The writer
+// will be closed after this call, with all locks and temporary files having been removed.
+func (fw *LockingFileWriter) Commit() (returnedErr error) {
+ if fw.state != lockingFileWriterStateLocked {
+ return fmt.Errorf("file writer not locked")
+ }
+
+ // While we have already checked that there was no concurrent modification when locking the
+ // file, we do so again here in order to verify that no other processes which are unaware of
+ // the locking semantics have changed the file. This may be overly cautious, but on the
+ // other hand the single stat(3P) call shouldn't be all that expensive in the first place.
+ if err := fw.checkConcurrentModification(); err != nil {
+ return err
+ }
+
+ if err := fw.writer.Commit(); err != nil {
+ return fmt.Errorf("committing file: %w", err)
+ }
+
+ if err := fw.unlock(); err != nil {
+ return fmt.Errorf("unlocking file: %w", err)
+ }
+
+ return nil
+}
+
+func (fw *LockingFileWriter) checkConcurrentModification() error {
+ fi, err := os.Stat(fw.writer.path)
+ if err != nil && !os.IsNotExist(err) {
+ return fmt.Errorf("statting path: %w", err)
+ }
+
+ if fw.fi == nil && fi != nil {
+ return fmt.Errorf("file concurrently created")
+ }
+ if fw.fi != nil && fi == nil {
+ return fmt.Errorf("file concurrently deleted")
+ }
+ if fw.fi != nil && fi != nil {
+ if fw.fi.Size() != fi.Size() || fw.fi.ModTime() != fi.ModTime() || fw.fi.Mode() != fi.Mode() {
+ return fmt.Errorf("file concurrently modified")
+ }
+ }
+
+ return nil
+}
+
+// Path returns the path of the intermediate file the FileWriter is writing to. Exposing the path
+// allows an external process to write to the file directly. While it would be preferable to use the
+// io.Writer interface instead, this is not easily doable e.g. for Git processes.
+func (fw *LockingFileWriter) Path() string {
+ return fw.writer.tmpFile.Name()
+}
+
+func (fw *LockingFileWriter) lockPath() string {
+ return fw.writer.path + ".lock"
+}
diff --git a/internal/safe/locking_file_writer_test.go b/internal/safe/locking_file_writer_test.go
new file mode 100644
index 000000000..bf813a2d3
--- /dev/null
+++ b/internal/safe/locking_file_writer_test.go
@@ -0,0 +1,304 @@
+package safe_test
+
+import (
+ "fmt"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/git/gittest"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/safe"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper/testcfg"
+)
+
+func TestLockingFileWriter_lifecycle(t *testing.T) {
+ t.Parallel()
+
+ t.Run("normal lifecycle", func(t *testing.T) {
+ writer, err := safe.NewLockingFileWriter(filepath.Join(testhelper.TempDir(t), "file"))
+ require.NoError(t, err)
+ require.NoError(t, writer.Lock())
+ require.NoError(t, writer.Commit())
+ require.NoError(t, writer.Close())
+ })
+
+ t.Run("multiple locks fail", func(t *testing.T) {
+ writer, err := safe.NewLockingFileWriter(filepath.Join(testhelper.TempDir(t), "file"))
+ require.NoError(t, err)
+ require.NoError(t, writer.Lock())
+ require.Equal(t, fmt.Errorf("file writer not lockable"), writer.Lock())
+ })
+
+ t.Run("commit without lock fails", func(t *testing.T) {
+ writer, err := safe.NewLockingFileWriter(filepath.Join(testhelper.TempDir(t), "file"))
+ require.NoError(t, err)
+ require.Equal(t, fmt.Errorf("file writer not locked"), writer.Commit())
+ })
+
+ t.Run("multiple commits fail", func(t *testing.T) {
+ writer, err := safe.NewLockingFileWriter(filepath.Join(testhelper.TempDir(t), "file"))
+ require.NoError(t, err)
+ require.NoError(t, writer.Lock())
+ require.NoError(t, writer.Commit())
+ require.Equal(t, fmt.Errorf("file writer not locked"), writer.Commit())
+ })
+
+ t.Run("lock after close fails", func(t *testing.T) {
+ writer, err := safe.NewLockingFileWriter(filepath.Join(testhelper.TempDir(t), "file"))
+ require.NoError(t, err)
+ require.NoError(t, writer.Close())
+ require.Equal(t, fmt.Errorf("file writer not lockable"), writer.Lock())
+ })
+
+ t.Run("multiple closes succeed", func(t *testing.T) {
+ writer, err := safe.NewLockingFileWriter(filepath.Join(testhelper.TempDir(t), "file"))
+ require.NoError(t, err)
+ require.NoError(t, writer.Close())
+ require.NoError(t, writer.Close())
+ })
+}
+
+func TestLockingFileWriter_stateCleanup(t *testing.T) {
+ t.Parallel()
+
+ t.Run("commit", func(t *testing.T) {
+ file := filepath.Join(testhelper.TempDir(t), "file")
+ lock := file + ".lock"
+
+ writer, err := safe.NewLockingFileWriter(file)
+ require.NoError(t, err)
+ require.FileExists(t, writer.Path())
+ require.NoFileExists(t, lock)
+
+ require.NoError(t, writer.Lock())
+ require.FileExists(t, writer.Path())
+ require.FileExists(t, lock)
+
+ require.NoError(t, writer.Commit())
+ require.NoFileExists(t, writer.Path())
+ require.NoFileExists(t, lock)
+ })
+
+ t.Run("close", func(t *testing.T) {
+ file := filepath.Join(testhelper.TempDir(t), "file")
+ lock := file + ".lock"
+
+ writer, err := safe.NewLockingFileWriter(file)
+ require.NoError(t, err)
+ require.FileExists(t, writer.Path())
+ require.NoFileExists(t, lock)
+
+ require.NoError(t, writer.Lock())
+ require.FileExists(t, writer.Path())
+ require.FileExists(t, lock)
+
+ require.NoError(t, writer.Close())
+ require.NoFileExists(t, writer.Path())
+ require.NoFileExists(t, lock)
+ })
+}
+
+func TestLockingFileWriter_createsNewFiles(t *testing.T) {
+ t.Parallel()
+
+ target := filepath.Join(testhelper.TempDir(t), "file")
+
+ writer, err := safe.NewLockingFileWriter(target)
+ require.NoError(t, err)
+ _, err = writer.Write([]byte("created"))
+ require.NoError(t, err)
+ require.NoError(t, writer.Lock())
+ require.NoError(t, writer.Commit())
+
+ require.Equal(t, []byte("created"), testhelper.MustReadFile(t, target))
+}
+
+func TestLockingFileWriter_createsEmptyFiles(t *testing.T) {
+ t.Parallel()
+
+ target := filepath.Join(testhelper.TempDir(t), "file")
+
+ writer, err := safe.NewLockingFileWriter(target)
+ require.NoError(t, err)
+ require.NoError(t, writer.Lock())
+ require.NoError(t, writer.Commit())
+
+ require.Equal(t, []byte{}, testhelper.MustReadFile(t, target))
+}
+
+func TestLockingFileWriter_seedingWithNonExistentTarget(t *testing.T) {
+ t.Parallel()
+
+ target := filepath.Join(testhelper.TempDir(t), "file")
+
+ writer, err := safe.NewLockingFileWriter(target, safe.LockingFileWriterConfig{
+ SeedContents: true,
+ })
+ require.NoError(t, err)
+ require.NoError(t, writer.Lock())
+ require.NoError(t, writer.Commit())
+
+ require.Equal(t, []byte{}, testhelper.MustReadFile(t, target))
+}
+
+func TestLockingFileWriter_seedingWithExistingTarget(t *testing.T) {
+ t.Parallel()
+
+ target := filepath.Join(testhelper.TempDir(t), "file")
+ require.NoError(t, ioutil.WriteFile(target, []byte("seed"), 0o644))
+
+ writer, err := safe.NewLockingFileWriter(target, safe.LockingFileWriterConfig{
+ SeedContents: true,
+ })
+ require.NoError(t, err)
+ _, err = writer.Write([]byte("append"))
+ require.NoError(t, err)
+ require.NoError(t, writer.Lock())
+ require.NoError(t, writer.Commit())
+
+ require.Equal(t, []byte("seedappend"), testhelper.MustReadFile(t, target))
+}
+
+func TestLockingFileWriter_modifiesExistingFiles(t *testing.T) {
+ t.Parallel()
+
+ target := filepath.Join(testhelper.TempDir(t), "file")
+ require.NoError(t, ioutil.WriteFile(target, []byte("preexisting"), 0o644))
+
+ writer, err := safe.NewLockingFileWriter(target)
+ require.NoError(t, err)
+ _, err = writer.Write([]byte("modified"))
+ require.NoError(t, err)
+ require.NoError(t, writer.Lock())
+ require.NoError(t, writer.Commit())
+
+ require.Equal(t, []byte("modified"), testhelper.MustReadFile(t, target))
+}
+
+func TestLockingFileWriter_modifiesExistingFilesWithMode(t *testing.T) {
+ t.Parallel()
+
+ target := filepath.Join(testhelper.TempDir(t), "file")
+ require.NoError(t, ioutil.WriteFile(target, []byte("preexisting"), 0o644))
+
+ writer, err := safe.NewLockingFileWriter(target, safe.LockingFileWriterConfig{
+ FileWriterConfig: safe.FileWriterConfig{FileMode: 0o060},
+ })
+ require.NoError(t, err)
+ require.NoError(t, writer.Lock())
+ require.NoError(t, writer.Commit())
+
+ fi, err := os.Stat(target)
+ require.NoError(t, err)
+ require.Equal(t, os.FileMode(0o060), fi.Mode())
+}
+
+func TestLockingFileWriter_concurrentCreation(t *testing.T) {
+ t.Parallel()
+
+ target := filepath.Join(testhelper.TempDir(t), "file")
+
+ writer, err := safe.NewLockingFileWriter(target)
+ require.NoError(t, err)
+
+ // Create file concurrently.
+ require.NoError(t, ioutil.WriteFile(target, []byte("concurrent"), 0o644))
+
+ require.Equal(t, fmt.Errorf("file concurrently created"), writer.Lock())
+
+ require.Equal(t, []byte("concurrent"), testhelper.MustReadFile(t, target))
+}
+
+func TestLockingFileWriter_concurrentDeletion(t *testing.T) {
+ t.Parallel()
+
+ target := filepath.Join(testhelper.TempDir(t), "file")
+
+ require.NoError(t, ioutil.WriteFile(target, []byte("base"), 0o644))
+ writer, err := safe.NewLockingFileWriter(target)
+ require.NoError(t, err)
+
+ // Delete file concurrently.
+ require.NoError(t, os.Remove(target))
+
+ require.Equal(t, fmt.Errorf("file concurrently deleted"), writer.Lock())
+
+ require.NoFileExists(t, target)
+}
+
+func TestLockingFileWriter_concurrentModification(t *testing.T) {
+ t.Parallel()
+
+ target := filepath.Join(testhelper.TempDir(t), "file")
+
+ require.NoError(t, ioutil.WriteFile(target, []byte("base"), 0o644))
+ writer, err := safe.NewLockingFileWriter(target)
+ require.NoError(t, err)
+
+ // Concurrently modify the file.
+ require.NoError(t, ioutil.WriteFile(target, []byte("concurrent"), 0o644))
+
+ require.Equal(t, fmt.Errorf("file concurrently modified"), writer.Lock())
+
+ require.Equal(t, []byte("concurrent"), testhelper.MustReadFile(t, target))
+}
+
+func TestLockingFileWriter_concurrentLocking(t *testing.T) {
+ t.Parallel()
+
+ file := filepath.Join(testhelper.TempDir(t), "file")
+
+ first, err := safe.NewLockingFileWriter(file)
+ require.NoError(t, err)
+ _, err = first.Write([]byte("first"))
+ require.NoError(t, err)
+
+ second, err := safe.NewLockingFileWriter(file)
+ require.NoError(t, err)
+ _, err = second.Write([]byte("second"))
+ require.NoError(t, err)
+
+ require.NoError(t, first.Lock())
+ require.Equal(t, fmt.Errorf("file already locked"), second.Lock())
+ require.NoError(t, first.Commit())
+
+ require.Equal(t, []byte("first"), testhelper.MustReadFile(t, file))
+}
+
+func TestLockingFileWriter_locked(t *testing.T) {
+ t.Parallel()
+
+ target := filepath.Join(testhelper.TempDir(t), "file")
+ require.NoError(t, ioutil.WriteFile(target, []byte("base"), 0o644))
+
+ writer, err := safe.NewLockingFileWriter(target)
+ require.NoError(t, err)
+
+ // Concurrently lock the file.
+ require.NoError(t, ioutil.WriteFile(target+".lock", nil, 0o644))
+
+ require.Equal(t, fmt.Errorf("file already locked"), writer.Lock())
+
+ require.Equal(t, []byte("base"), testhelper.MustReadFile(t, target))
+}
+
+func TestLockingFileWriter_externalProcess(t *testing.T) {
+ t.Parallel()
+
+ cfg := testcfg.Build(t)
+
+ target := filepath.Join(testhelper.TempDir(t), "file")
+ require.NoError(t, ioutil.WriteFile(target, []byte("base"), 0o644))
+
+ writer, err := safe.NewLockingFileWriter(target)
+ require.NoError(t, err)
+
+ gittest.Exec(t, cfg, "config", "-f", writer.Path(), "some.config", "true")
+ require.NoError(t, writer.Lock())
+ require.NoError(t, writer.Commit())
+
+ require.Equal(t, []byte("[some]\n\tconfig = true\n"), testhelper.MustReadFile(t, target))
+}