diff options
-rw-r--r-- | internal/gitaly/service/repository/write_ref.go | 51 |
1 files changed, 46 insertions, 5 deletions
diff --git a/internal/gitaly/service/repository/write_ref.go b/internal/gitaly/service/repository/write_ref.go index eeb32561c..b37e8d1a3 100644 --- a/internal/gitaly/service/repository/write_ref.go +++ b/internal/gitaly/service/repository/write_ref.go @@ -8,7 +8,9 @@ import ( "gitlab.com/gitlab-org/gitaly/v16/internal/git" "gitlab.com/gitlab-org/gitaly/v16/internal/git/localrepo" "gitlab.com/gitlab-org/gitaly/v16/internal/git/updateref" + "gitlab.com/gitlab-org/gitaly/v16/internal/gitaly" "gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/service" + "gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/txutil" "gitlab.com/gitlab-org/gitaly/v16/internal/structerr" "gitlab.com/gitlab-org/gitaly/v16/proto/go/gitalypb" ) @@ -17,8 +19,15 @@ func (s *server) WriteRef(ctx context.Context, req *gitalypb.WriteRefRequest) (* if err := validateWriteRefRequest(req); err != nil { return nil, structerr.NewInvalidArgument("%w", err) } - if err := s.writeRef(ctx, req); err != nil { - return nil, structerr.NewInternal("%w", err) + + if s.partitionManager != nil { + if err := s.writeRefWAL(ctx, req); err != nil { + return nil, err + } + } else { + if err := s.writeRef(ctx, req); err != nil { + return nil, structerr.NewInternal("%w", err) + } } return &gitalypb.WriteRefResponse{}, nil @@ -38,7 +47,30 @@ func (s *server) writeRef(ctx context.Context, req *gitalypb.WriteRefRequest) er return updateRef(ctx, repo, req) } -func updateRef(ctx context.Context, repo *localrepo.Repo, req *gitalypb.WriteRefRequest) (returnedErr error) { +func (s *server) writeRefWAL(ctx context.Context, req *gitalypb.WriteRefRequest) error { + tx, err := s.partitionManager.Begin(ctx, req.GetRepository()) + if err != nil { + return fmt.Errorf("begin: %w", err) + } + defer txutil.LogRollback(ctx, tx) + + if string(req.Ref) == "HEAD" { + tx.SetDefaultBranch(git.ReferenceName(req.GetRevision())) + } else { + oldOID, newOID, err := resolveObjectIDs(ctx, s.localrepo(req.GetRepository()), req) + if err != nil { + return err + } + + tx.UpdateReferences(gitaly.ReferenceUpdates{ + git.ReferenceName(req.GetRef()): {Force: oldOID == "", OldOID: oldOID, NewOID: newOID}, + }) + } + + return tx.Commit(ctx) +} + +func resolveObjectIDs(ctx context.Context, repo *localrepo.Repo, req *gitalypb.WriteRefRequest) (git.ObjectID, git.ObjectID, error) { var newObjectID git.ObjectID if git.ObjectHashSHA1.IsZeroOID(git.ObjectID(req.GetRevision())) { // Passing the all-zeroes object ID as new value means that we should delete the @@ -52,7 +84,7 @@ func updateRef(ctx context.Context, repo *localrepo.Repo, req *gitalypb.WriteRef var err error newObjectID, err = repo.ResolveRevision(ctx, git.Revision(req.GetRevision())+"^{object}") if err != nil { - return fmt.Errorf("resolving new revision: %w", err) + return "", "", fmt.Errorf("resolving new revision: %w", err) } } @@ -66,11 +98,20 @@ func updateRef(ctx context.Context, repo *localrepo.Repo, req *gitalypb.WriteRef var err error oldObjectID, err = repo.ResolveRevision(ctx, git.Revision(req.GetOldRevision())+"^{object}") if err != nil { - return fmt.Errorf("resolving old revision: %w", err) + return "", "", fmt.Errorf("resolving old revision: %w", err) } } } + return oldObjectID, newObjectID, nil +} + +func updateRef(ctx context.Context, repo *localrepo.Repo, req *gitalypb.WriteRefRequest) (returnedErr error) { + oldObjectID, newObjectID, err := resolveObjectIDs(ctx, repo, req) + if err != nil { + return err + } + u, err := updateref.New(ctx, repo) if err != nil { return fmt.Errorf("error when running creating new updater: %w", err) |