Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJacob Vosmaer (GitLab) <jacob@gitlab.com>2017-09-11 15:46:15 +0300
committerJacob Vosmaer (GitLab) <jacob@gitlab.com>2017-09-11 15:46:15 +0300
commitda0ab569f1eec7032afde72c6b8cce5d1b2a22de (patch)
tree189e9dfeeda280124a7e0acd64bc85cd29653b88
parentaea5e82175047b514dbd0e0a75053b1e014c7612 (diff)
parent5776a29cb41046ebbe4e0741b5ed50862980cdb1 (diff)
Merge branch 'deprecate-command-close' into 'master'
Use context cancellation instead of command.Close See merge request !332
-rw-r--r--CHANGELOG.md5
-rw-r--r--internal/command/command.go100
-rw-r--r--internal/git/catfile/catfile.go1
-rw-r--r--internal/git/log/commit.go1
-rw-r--r--internal/helper/repo.go2
-rw-r--r--internal/linguist/linguist.go1
-rw-r--r--internal/rubyserver/rubyserver_test.go11
-rw-r--r--internal/server/auth_test.go6
-rw-r--r--internal/service/blob/get_blob.go1
-rw-r--r--internal/service/blob/get_blob_test.go16
-rw-r--r--internal/service/commit/commits_helper.go1
-rw-r--r--internal/service/commit/count_commits.go1
-rw-r--r--internal/service/commit/find_commits_test.go6
-rw-r--r--internal/service/commit/isancestor.go7
-rw-r--r--internal/service/commit/languages.go1
-rw-r--r--internal/service/commit/list_files.go1
-rw-r--r--internal/service/commit/raw_blame.go1
-rw-r--r--internal/service/diff/commit.go3
-rw-r--r--internal/service/diff/patch_test.go43
-rw-r--r--internal/service/ref/branches_test.go14
-rw-r--r--internal/service/ref/refexists.go6
-rw-r--r--internal/service/ref/refexists_test.go8
-rw-r--r--internal/service/ref/refname.go1
-rw-r--r--internal/service/ref/refs.go3
-rw-r--r--internal/service/repository/apply_gitattributes_test.go19
-rw-r--r--internal/service/repository/fetch_remote.go6
-rw-r--r--internal/service/repository/fetch_remote_test.go11
-rw-r--r--internal/service/repository/gc.go7
-rw-r--r--internal/service/repository/repack.go7
-rw-r--r--internal/service/smarthttp/inforefs.go5
-rw-r--r--internal/service/smarthttp/receive_pack.go5
-rw-r--r--internal/service/smarthttp/upload_pack.go3
-rw-r--r--internal/service/ssh/receive_pack.go3
-rw-r--r--internal/service/ssh/upload_pack.go3
-rw-r--r--internal/testhelper/testhelper.go31
35 files changed, 192 insertions, 148 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index ef1864d77..6d6832714 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,7 +1,10 @@
# Gitaly changelog
-v0.39.0
+UNRELEASED
+- Use context cancellation instead of command.Close
+ https://gitlab.com/gitlab-org/gitaly/merge_requests/332
+v0.39.0
- Reimplement FindAllTags RPC in Ruby
https://gitlab.com/gitlab-org/gitaly/merge_requests/334
- Re-use gitaly-ruby client connection
diff --git a/internal/command/command.go b/internal/command/command.go
index 21aa5fec9..f81320cd2 100644
--- a/internal/command/command.go
+++ b/internal/command/command.go
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
+ "io/ioutil"
"os"
"os/exec"
"path"
@@ -36,15 +37,36 @@ var exportedEnvVars = []string{
"GIT_TRACE_SETUP",
}
-// Command encapsulates operations with commands creates with NewCommand
+// Command encapsulates a running exec.Cmd. The embedded exec.Cmd is
+// terminated and reaped automatically when the context.Context that
+// created it is canceled.
type Command struct {
- io.Reader
- *exec.Cmd
+ reader io.Reader
+ cmd *exec.Cmd
context context.Context
startTime time.Time
- done chan struct{}
- closeOnce sync.Once
- closeErr error
+
+ waitError error
+ waitOnce sync.Once
+}
+
+// Read calls Read() on the stdout pipe of the command.
+func (c *Command) Read(p []byte) (int, error) {
+ if c.reader == nil {
+ panic("command has no reader")
+ }
+
+ return c.reader.Read(p)
+}
+
+// Wait calls Wait() on the exec.Cmd instance inside the command. This
+// blocks until the command has finished and reports the command exit
+// status via the error return value. Use ExitStatus to get the integer
+// exit status from the error returned by Wait().
+func (c *Command) Wait() error {
+ c.waitOnce.Do(c.wait)
+
+ return c.waitError
}
// GitPath returns the path to the `git` binary. See `SetGitPath` for details
@@ -76,18 +98,32 @@ func GitlabShell(ctx context.Context, envs []string, executable string, args ...
return New(ctx, exec.Command(path.Join(shellPath, executable), args...), nil, nil, nil, envs...)
}
-// New creates a Command from an exec.Cmd
+var wg = &sync.WaitGroup{}
+
+// WaitAllDone waits for all commands started by the command package to
+// finish. This can only be called once in the lifecycle of the current
+// Go process.
+func WaitAllDone() {
+ wg.Wait()
+}
+
+// New creates a Command from an exec.Cmd. On success, the Command
+// contains a running subprocess. When ctx is canceled the embedded
+// process will be terminated and reaped automatically.
func New(ctx context.Context, cmd *exec.Cmd, stdin io.Reader, stdout, stderr io.Writer, env ...string) (*Command, error) {
- grpc_logrus.Extract(ctx).WithFields(log.Fields{
- "path": cmd.Path,
- "args": cmd.Args,
- }).Info("spawn")
+ logPid := -1
+ defer func() {
+ grpc_logrus.Extract(ctx).WithFields(log.Fields{
+ "pid": logPid,
+ "path": cmd.Path,
+ "args": cmd.Args,
+ }).Info("spawn")
+ }()
command := &Command{
- Cmd: cmd,
+ cmd: cmd,
startTime: time.Now(),
context: ctx,
- done: make(chan struct{}),
}
// Explicitly set the environment for the command
@@ -120,7 +156,7 @@ func New(ctx context.Context, cmd *exec.Cmd, stdin io.Reader, stdout, stderr io.
if err != nil {
return nil, fmt.Errorf("GitCommand: stdout: %v", err)
}
- command.Reader = pipe
+ command.reader = pipe
}
if stderr != nil {
@@ -134,18 +170,22 @@ func New(ctx context.Context, cmd *exec.Cmd, stdin io.Reader, stdout, stderr io.
return nil, fmt.Errorf("GitCommand: start %v: %v", cmd.Args, err)
}
+ // The goroutine below is responsible for terminating and reaping the
+ // process when ctx is canceled.
+ wg.Add(1)
go func() {
- select {
- case <-command.done:
- case <-ctx.Done():
- }
+ <-ctx.Done()
if process := cmd.Process; process != nil && process.Pid > 0 {
// Send SIGTERM to the process group of cmd
syscall.Kill(-process.Pid, syscall.SIGTERM)
}
+ command.Wait()
+ wg.Done()
}()
+ logPid = cmd.Process.Pid
+
return command, nil
}
@@ -159,20 +199,18 @@ func exportEnvironment(env []string) []string {
return env
}
-// Close will send a SIGTERM signal to the process group
-// belonging to the `cmd` process
-func (c *Command) Close() error {
- c.closeOnce.Do(c.close)
- return c.closeErr
-}
+// This function should never be called directly, use Wait().
+func (c *Command) wait() {
+ if c.reader != nil {
+ // Prevent the command from blocking on writing to its stdout.
+ io.Copy(ioutil.Discard, c.reader)
+ }
-func (c *Command) close() {
- close(c.done)
- c.closeErr = c.Cmd.Wait()
+ c.waitError = c.cmd.Wait()
exitCode := 0
- if c.closeErr != nil {
- if exitStatus, ok := ExitStatus(c.closeErr); ok {
+ if c.waitError != nil {
+ if exitStatus, ok := ExitStatus(c.waitError); ok {
exitCode = exitStatus
}
}
@@ -180,7 +218,7 @@ func (c *Command) close() {
c.logProcessComplete(c.context, exitCode)
}
-// ExitStatus will return the exit-code from an error
+// ExitStatus will return the exit-code from an error returned by Wait().
func ExitStatus(err error) (int, bool) {
exitError, ok := err.(*exec.ExitError)
if !ok {
@@ -196,7 +234,7 @@ func ExitStatus(err error) (int, bool) {
}
func (c *Command) logProcessComplete(ctx context.Context, exitCode int) {
- cmd := c.Cmd
+ cmd := c.cmd
systemTime := cmd.ProcessState.SystemTime()
userTime := cmd.ProcessState.UserTime()
diff --git a/internal/git/catfile/catfile.go b/internal/git/catfile/catfile.go
index 5fbae763c..144595eda 100644
--- a/internal/git/catfile/catfile.go
+++ b/internal/git/catfile/catfile.go
@@ -36,7 +36,6 @@ func CatFile(ctx context.Context, repoPath string, handler Handler) error {
if err != nil {
return grpc.Errorf(codes.Internal, "CatFile: cmd: %v", err)
}
- defer cmd.Close()
defer stdinWriter.Close()
defer stdinReader.Close()
diff --git a/internal/git/log/commit.go b/internal/git/log/commit.go
index 62ec6c65c..574511030 100644
--- a/internal/git/log/commit.go
+++ b/internal/git/log/commit.go
@@ -39,7 +39,6 @@ func GetCommit(ctx context.Context, repo *pb.Repository, revision string, path s
if err != nil {
return nil, err
}
- defer cmd.Close()
logParser := NewLogParser(cmd)
if ok := logParser.Parse(); !ok {
diff --git a/internal/helper/repo.go b/internal/helper/repo.go
index ed392912b..9e4999248 100644
--- a/internal/helper/repo.go
+++ b/internal/helper/repo.go
@@ -92,8 +92,6 @@ func IsValidRef(ctx context.Context, path, ref string) bool {
if err != nil {
return false
}
- defer cmd.Close()
- cmd.Stdout, cmd.Stderr, cmd.Stdin = nil, nil, nil
return cmd.Wait() == nil
}
diff --git a/internal/linguist/linguist.go b/internal/linguist/linguist.go
index b8ab27269..0a0359626 100644
--- a/internal/linguist/linguist.go
+++ b/internal/linguist/linguist.go
@@ -32,7 +32,6 @@ func Stats(ctx context.Context, repoPath string, commitID string) (map[string]in
if err != nil {
return nil, err
}
- defer reader.Close()
data, err := ioutil.ReadAll(reader)
if err != nil {
diff --git a/internal/rubyserver/rubyserver_test.go b/internal/rubyserver/rubyserver_test.go
index fbc9b033a..cb3dfd1e2 100644
--- a/internal/rubyserver/rubyserver_test.go
+++ b/internal/rubyserver/rubyserver_test.go
@@ -1,7 +1,6 @@
package rubyserver
import (
- "context"
"testing"
"github.com/stretchr/testify/assert"
@@ -45,13 +44,17 @@ func TestSetHeaders(t *testing.T) {
}
for _, tc := range testCases {
- ctx, err := SetHeaders(context.Background(), tc.repo)
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ clientCtx, err := SetHeaders(ctx, tc.repo)
+
if tc.errType != codes.OK {
testhelper.AssertGrpcError(t, err, tc.errType, "")
- assert.Nil(t, ctx)
+ assert.Nil(t, clientCtx)
} else {
assert.NoError(t, err)
- assert.NotNil(t, ctx)
+ assert.NotNil(t, clientCtx)
}
}
}
diff --git a/internal/server/auth_test.go b/internal/server/auth_test.go
index f5c8ae197..a92d61188 100644
--- a/internal/server/auth_test.go
+++ b/internal/server/auth_test.go
@@ -1,7 +1,6 @@
package server
import (
- "context"
"net"
"testing"
"time"
@@ -130,8 +129,11 @@ func dial(opts []grpc.DialOption) (*grpc.ClientConn, error) {
}
func healthCheck(conn *grpc.ClientConn) error {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
client := healthpb.NewHealthClient(conn)
- _, err := client.Check(context.Background(), &healthpb.HealthCheckRequest{})
+ _, err := client.Check(ctx, &healthpb.HealthCheckRequest{})
return err
}
diff --git a/internal/service/blob/get_blob.go b/internal/service/blob/get_blob.go
index 2cd1344e5..573292e66 100644
--- a/internal/service/blob/get_blob.go
+++ b/internal/service/blob/get_blob.go
@@ -34,7 +34,6 @@ func (s *server) GetBlob(in *pb.GetBlobRequest, stream pb.BlobService_GetBlobSer
if err != nil {
return grpc.Errorf(codes.Internal, "GetBlob: cmd: %v", err)
}
- defer cmd.Close()
defer stdinWriter.Close()
defer stdinReader.Close()
diff --git a/internal/service/blob/get_blob_test.go b/internal/service/blob/get_blob_test.go
index 2440fa23d..464000a36 100644
--- a/internal/service/blob/get_blob_test.go
+++ b/internal/service/blob/get_blob_test.go
@@ -2,7 +2,6 @@ package blob
import (
"bytes"
- "context"
"fmt"
"io"
"testing"
@@ -71,7 +70,10 @@ func TestSuccessfulGetBlob(t *testing.T) {
Limit: int64(tc.limit),
}
- stream, err := client.GetBlob(context.Background(), request)
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ stream, err := client.GetBlob(ctx, request)
require.NoError(t, err, "initiate RPC")
reportedSize, reportedOid, data, err := getBlob(stream)
@@ -97,7 +99,10 @@ func TestGetBlobNotFound(t *testing.T) {
Oid: "doesnotexist",
}
- stream, err := client.GetBlob(context.Background(), request)
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ stream, err := client.GetBlob(ctx, request)
require.NoError(t, err)
reportedSize, reportedOid, data, err := getBlob(stream)
@@ -150,7 +155,10 @@ func TestFailedGetBlobRequestDueToValidationError(t *testing.T) {
}
for _, rpcRequest := range rpcRequests {
- stream, err := client.GetBlob(context.Background(), &rpcRequest)
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ stream, err := client.GetBlob(ctx, &rpcRequest)
require.NoError(t, err, rpcRequest)
_, err = stream.Recv()
require.NotEqual(t, io.EOF, err, rpcRequest)
diff --git a/internal/service/commit/commits_helper.go b/internal/service/commit/commits_helper.go
index efd6a849a..9042cb3b3 100644
--- a/internal/service/commit/commits_helper.go
+++ b/internal/service/commit/commits_helper.go
@@ -19,7 +19,6 @@ func sendCommits(ctx context.Context, sender commitsSender, repo *pb.Repository,
if err != nil {
return err
}
- defer cmd.Close()
logParser := log.NewLogParser(cmd)
diff --git a/internal/service/commit/count_commits.go b/internal/service/commit/count_commits.go
index 6663acb13..a0c519a4a 100644
--- a/internal/service/commit/count_commits.go
+++ b/internal/service/commit/count_commits.go
@@ -44,7 +44,6 @@ func (s *server) CountCommits(ctx context.Context, in *pb.CountCommitsRequest) (
if err != nil {
return nil, grpc.Errorf(codes.Internal, "CountCommits: cmd: %v", err)
}
- defer cmd.Close()
var count int64
countStr, readAllErr := ioutil.ReadAll(cmd)
diff --git a/internal/service/commit/find_commits_test.go b/internal/service/commit/find_commits_test.go
index 963b52bc0..54f8e83e3 100644
--- a/internal/service/commit/find_commits_test.go
+++ b/internal/service/commit/find_commits_test.go
@@ -6,6 +6,7 @@ import (
"testing"
pb "gitlab.com/gitlab-org/gitaly-proto/go"
+ "gitlab.com/gitlab-org/gitaly/internal/testhelper"
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/stretchr/testify/require"
@@ -186,7 +187,10 @@ func TestSuccessfulFindCommitsRequest(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
- stream, err := client.FindCommits(context.Background(), tc.request)
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ stream, err := client.FindCommits(ctx, tc.request)
require.NoError(t, err)
var ids []string
diff --git a/internal/service/commit/isancestor.go b/internal/service/commit/isancestor.go
index 53990c900..cafa08edb 100644
--- a/internal/service/commit/isancestor.go
+++ b/internal/service/commit/isancestor.go
@@ -1,9 +1,6 @@
package commit
import (
- "io/ioutil"
- "os/exec"
-
"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
@@ -40,12 +37,10 @@ func commitIsAncestorName(ctx context.Context, path, ancestorID, childID string)
"childSha": childID,
}).Debug("commitIsAncestor")
- osCommand := exec.Command(command.GitPath(), "--git-dir", path, "merge-base", "--is-ancestor", ancestorID, childID)
- cmd, err := command.New(ctx, osCommand, nil, ioutil.Discard, nil)
+ cmd, err := command.Git(ctx, "--git-dir", path, "merge-base", "--is-ancestor", ancestorID, childID)
if err != nil {
return false, grpc.Errorf(codes.Internal, err.Error())
}
- defer cmd.Close()
return cmd.Wait() == nil, nil
}
diff --git a/internal/service/commit/languages.go b/internal/service/commit/languages.go
index 8618f3f46..dddc84235 100644
--- a/internal/service/commit/languages.go
+++ b/internal/service/commit/languages.go
@@ -82,7 +82,6 @@ func lookupRevision(ctx context.Context, repoPath string, revision string) (stri
if err != nil {
return "", err
}
- defer revParse.Close()
revParseBytes, err := ioutil.ReadAll(revParse)
if err != nil {
diff --git a/internal/service/commit/list_files.go b/internal/service/commit/list_files.go
index b994f12e2..1c9488385 100644
--- a/internal/service/commit/list_files.go
+++ b/internal/service/commit/list_files.go
@@ -39,7 +39,6 @@ func (s *server) ListFiles(in *pb.ListFilesRequest, stream pb.CommitService_List
if err != nil {
return grpc.Errorf(codes.Internal, err.Error())
}
- defer cmd.Close()
return lines.Send(cmd, listFilesWriter(stream), []byte{'\x00'})
}
diff --git a/internal/service/commit/raw_blame.go b/internal/service/commit/raw_blame.go
index a0342e5bf..1bfca2647 100644
--- a/internal/service/commit/raw_blame.go
+++ b/internal/service/commit/raw_blame.go
@@ -33,7 +33,6 @@ func (s *server) RawBlame(in *pb.RawBlameRequest, stream pb.CommitService_RawBla
if err != nil {
return grpc.Errorf(codes.Internal, "RawBlame: cmd: %v", err)
}
- defer cmd.Close()
sw := streamio.NewWriter(func(p []byte) error {
return stream.Send(&pb.RawBlameResponse{Data: p})
diff --git a/internal/service/diff/commit.go b/internal/service/diff/commit.go
index 974e528cd..e9455dca4 100644
--- a/internal/service/diff/commit.go
+++ b/internal/service/diff/commit.go
@@ -219,7 +219,6 @@ func eachDiff(ctx context.Context, rpc string, cmdArgs []string, limits diff.Lim
if err != nil {
return grpc.Errorf(codes.Internal, "%s: cmd: %v", rpc, err)
}
- defer cmd.Close()
diffParser := diff.NewDiffParser(cmd, limits)
@@ -234,7 +233,7 @@ func eachDiff(ctx context.Context, rpc string, cmdArgs []string, limits diff.Lim
}
if err := cmd.Wait(); err != nil {
- return grpc.Errorf(codes.Unavailable, "%s: cmd wait for %v: %v", rpc, cmd.Args, err)
+ return grpc.Errorf(codes.Unavailable, "%s: %v", rpc, err)
}
return nil
diff --git a/internal/service/diff/patch_test.go b/internal/service/diff/patch_test.go
index 52bba4a8c..32d9415d2 100644
--- a/internal/service/diff/patch_test.go
+++ b/internal/service/diff/patch_test.go
@@ -6,8 +6,6 @@ import (
"github.com/stretchr/testify/assert"
- "golang.org/x/net/context"
-
pb "gitlab.com/gitlab-org/gitaly-proto/go"
"gitlab.com/gitlab-org/gitaly/internal/testhelper"
)
@@ -37,30 +35,33 @@ func TestSuccessfulCommitPatchRequest(t *testing.T) {
}
for _, testCase := range testCases {
- t.Log(testCase.desc)
-
- request := &pb.CommitPatchRequest{
- Repository: testRepo,
- Revision: testCase.revision,
- }
+ t.Run(testCase.desc, func(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
- c, err := client.CommitPatch(context.Background(), request)
- if err != nil {
- t.Fatal(err)
- }
+ request := &pb.CommitPatchRequest{
+ Repository: testRepo,
+ Revision: testCase.revision,
+ }
- data := []byte{}
- for {
- r, err := c.Recv()
- if err == io.EOF {
- break
- } else if err != nil {
+ c, err := client.CommitPatch(ctx, request)
+ if err != nil {
t.Fatal(err)
}
- data = append(data, r.GetData()...)
- }
+ data := []byte{}
+ for {
+ r, err := c.Recv()
+ if err == io.EOF {
+ break
+ } else if err != nil {
+ t.Fatal(err)
+ }
+
+ data = append(data, r.GetData()...)
+ }
- assert.Equal(t, testCase.diff, data)
+ assert.Equal(t, testCase.diff, data)
+ })
}
}
diff --git a/internal/service/ref/branches_test.go b/internal/service/ref/branches_test.go
index 585f3b467..31620c07a 100644
--- a/internal/service/ref/branches_test.go
+++ b/internal/service/ref/branches_test.go
@@ -1,6 +1,7 @@
package ref
import (
+ "context"
"os/exec"
"testing"
@@ -10,22 +11,24 @@ import (
pb "gitlab.com/gitlab-org/gitaly-proto/go"
"github.com/stretchr/testify/require"
- "golang.org/x/net/context"
"google.golang.org/grpc/codes"
)
func TestSuccessfulCreateBranchRequest(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
server := runRefServiceServer(t)
defer server.Stop()
client, conn := newRefClient(t)
defer conn.Close()
- headCommit, err := log.GetCommit(context.Background(), testRepo, "HEAD", "")
+ headCommit, err := log.GetCommit(ctx, testRepo, "HEAD", "")
require.NoError(t, err)
startPoint := "c7fbe50c7c7419d9701eebe64b1fdacc3df5b9dd"
- startPointCommit, err := log.GetCommit(context.Background(), testRepo, startPoint, "")
+ startPointCommit, err := log.GetCommit(ctx, testRepo, startPoint, "")
require.NoError(t, err)
testCases := []struct {
@@ -210,6 +213,9 @@ func TestFailedDeleteBranchRequest(t *testing.T) {
}
func TestSuccessfulFindBranchRequest(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
server := runRefServiceServer(t)
defer server.Stop()
@@ -217,7 +223,7 @@ func TestSuccessfulFindBranchRequest(t *testing.T) {
defer conn.Close()
branchNameInput := "master"
- branchTarget, err := log.GetCommit(context.Background(), testRepo, branchNameInput, "")
+ branchTarget, err := log.GetCommit(ctx, testRepo, branchNameInput, "")
require.NoError(t, err)
branch := &pb.Branch{
diff --git a/internal/service/ref/refexists.go b/internal/service/ref/refexists.go
index 6f88522cb..be5e61ccc 100644
--- a/internal/service/ref/refexists.go
+++ b/internal/service/ref/refexists.go
@@ -1,8 +1,6 @@
package ref
import (
- "io/ioutil"
- "os/exec"
"strings"
log "github.com/Sirupsen/logrus"
@@ -41,12 +39,10 @@ func refExists(ctx context.Context, repoPath string, ref string) (bool, error) {
return false, grpc.Errorf(codes.InvalidArgument, "invalid refname")
}
- osCommand := exec.Command(command.GitPath(), "--git-dir", repoPath, "show-ref", "--verify", "--quiet", ref)
- cmd, err := command.New(ctx, osCommand, nil, ioutil.Discard, nil)
+ cmd, err := command.Git(ctx, "--git-dir", repoPath, "show-ref", "--verify", "--quiet", ref)
if err != nil {
return false, grpc.Errorf(codes.Internal, err.Error())
}
- defer cmd.Close()
err = cmd.Wait()
if err == nil {
diff --git a/internal/service/ref/refexists_test.go b/internal/service/ref/refexists_test.go
index 2204b39ee..c0275bbe5 100644
--- a/internal/service/ref/refexists_test.go
+++ b/internal/service/ref/refexists_test.go
@@ -3,12 +3,11 @@ package ref
import (
"testing"
- "golang.org/x/net/context"
-
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
pb "gitlab.com/gitlab-org/gitaly-proto/go"
+ "gitlab.com/gitlab-org/gitaly/internal/testhelper"
)
func TestRefExists(t *testing.T) {
@@ -37,6 +36,9 @@ func TestRefExists(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
server := runRefServiceServer(t)
defer server.Stop()
@@ -45,7 +47,7 @@ func TestRefExists(t *testing.T) {
req := &pb.RefExistsRequest{Repository: tt.repo, Ref: []byte(tt.ref)}
- got, err := client.RefExists(context.Background(), req)
+ got, err := client.RefExists(ctx, req)
if grpc.Code(err) != tt.wantErr {
t.Errorf("server.RefExists() error = %v, wantErr %v", err, tt.wantErr)
diff --git a/internal/service/ref/refname.go b/internal/service/ref/refname.go
index 1b16a34c8..06a2d537f 100644
--- a/internal/service/ref/refname.go
+++ b/internal/service/ref/refname.go
@@ -46,7 +46,6 @@ func findRefName(ctx context.Context, path, commitID, prefix string) (string, er
if err != nil {
return "", err
}
- defer cmd.Close()
scanner := bufio.NewScanner(cmd)
scanner.Scan()
diff --git a/internal/service/ref/refs.go b/internal/service/ref/refs.go
index 30fefce72..d50b388a4 100644
--- a/internal/service/ref/refs.go
+++ b/internal/service/ref/refs.go
@@ -57,7 +57,6 @@ func findRefs(ctx context.Context, writer lines.Sender, repo *pb.Repository, pat
if err != nil {
return err
}
- defer cmd.Close()
if err := lines.Send(cmd, writer, opts.delim); err != nil {
return err
@@ -110,7 +109,6 @@ func _findBranchNames(ctx context.Context, repoPath string) ([][]byte, error) {
if err != nil {
return nil, err
}
- defer cmd.Close()
scanner := bufio.NewScanner(cmd)
for scanner.Scan() {
@@ -134,7 +132,6 @@ func _headReference(ctx context.Context, repoPath string) ([]byte, error) {
if err != nil {
return nil, err
}
- defer cmd.Close()
scanner := bufio.NewScanner(cmd)
scanner.Scan()
diff --git a/internal/service/repository/apply_gitattributes_test.go b/internal/service/repository/apply_gitattributes_test.go
index fbbf4d6f3..50ea86751 100644
--- a/internal/service/repository/apply_gitattributes_test.go
+++ b/internal/service/repository/apply_gitattributes_test.go
@@ -1,6 +1,7 @@
package repository
import (
+ "fmt"
"io/ioutil"
"os"
"path"
@@ -11,7 +12,6 @@ import (
"github.com/stretchr/testify/assert"
pb "gitlab.com/gitlab-org/gitaly-proto/go"
"gitlab.com/gitlab-org/gitaly/internal/testhelper"
- "golang.org/x/net/context"
)
func TestApplyGitattributesSuccess(t *testing.T) {
@@ -108,16 +108,23 @@ func TestApplyGitattributesFailure(t *testing.T) {
}
for _, test := range tests {
- req := &pb.ApplyGitattributesRequest{Repository: test.repo, Revision: test.revision}
- _, err := client.ApplyGitattributes(context.Background(), req)
- testhelper.AssertGrpcError(t, err, test.code, "")
- }
+ t.Run(fmt.Sprintf("%+v", test), func(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+ req := &pb.ApplyGitattributesRequest{Repository: test.repo, Revision: test.revision}
+ _, err := client.ApplyGitattributes(ctx, req)
+ testhelper.AssertGrpcError(t, err, test.code, "")
+ })
+ }
}
func assertGitattributesApplied(t *testing.T, client pb.RepositoryServiceClient, attributesPath string, revision, expectedContents []byte) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
req := &pb.ApplyGitattributesRequest{Repository: testRepo, Revision: revision}
- c, err := client.ApplyGitattributes(context.Background(), req)
+ c, err := client.ApplyGitattributes(ctx, req)
assert.NoError(t, err)
assert.NotNil(t, c)
diff --git a/internal/service/repository/fetch_remote.go b/internal/service/repository/fetch_remote.go
index f247b5e8d..e5438b4a4 100644
--- a/internal/service/repository/fetch_remote.go
+++ b/internal/service/repository/fetch_remote.go
@@ -2,8 +2,6 @@ package repository
import (
"fmt"
- "io"
- "io/ioutil"
log "github.com/Sirupsen/logrus"
"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus"
@@ -36,10 +34,6 @@ func (server) FetchRemote(ctx context.Context, in *pb.FetchRemoteRequest) (*pb.F
return nil, grpc.Errorf(codes.Internal, err.Error())
}
- if _, err = io.Copy(ioutil.Discard, cmd); err != nil {
- return nil, grpc.Errorf(codes.Internal, err.Error())
- }
-
if err = cmd.Wait(); err != nil {
return nil, grpc.Errorf(codes.Internal, err.Error())
}
diff --git a/internal/service/repository/fetch_remote_test.go b/internal/service/repository/fetch_remote_test.go
index 3da43409f..a4776ffe9 100644
--- a/internal/service/repository/fetch_remote_test.go
+++ b/internal/service/repository/fetch_remote_test.go
@@ -1,7 +1,6 @@
package repository
import (
- "context"
"io/ioutil"
"os"
"path"
@@ -108,6 +107,9 @@ func TestFetchRemoteArgsBuilder(t *testing.T) {
// NOTE: Only tests that `gitlab-shell` is being called, not what it does.
func TestFetchRemoteSuccess(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
dir, err := ioutil.TempDir("", "gitlab-shell.")
require.NoError(t, err)
defer func(dir string) {
@@ -135,7 +137,7 @@ func TestFetchRemoteSuccess(t *testing.T) {
os.RemoveAll(path)
}(cloneRepo)
- resp, err := client.FetchRemote(context.Background(), &pb.FetchRemoteRequest{
+ resp, err := client.FetchRemote(ctx, &pb.FetchRemoteRequest{
Repository: cloneRepo,
Remote: "my-remote",
})
@@ -172,6 +174,9 @@ func TestFetchRemoteFailure(t *testing.T) {
for _, tc := range tests {
t.Run(tc.desc, func(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
if len(tc.shellPath) == 0 {
defer func(oldPath string) {
config.Config.GitlabShell.Dir = oldPath
@@ -179,7 +184,7 @@ func TestFetchRemoteFailure(t *testing.T) {
config.Config.GitlabShell.Dir = ""
}
- resp, err := client.FetchRemote(context.Background(), tc.req)
+ resp, err := client.FetchRemote(ctx, tc.req)
testhelper.AssertGrpcError(t, err, tc.code, tc.err)
assert.Error(t, err)
assert.Nil(t, resp)
diff --git a/internal/service/repository/gc.go b/internal/service/repository/gc.go
index b7c162621..f8e7a155c 100644
--- a/internal/service/repository/gc.go
+++ b/internal/service/repository/gc.go
@@ -1,9 +1,6 @@
package repository
import (
- "io"
- "io/ioutil"
-
log "github.com/Sirupsen/logrus"
"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus"
"golang.org/x/net/context"
@@ -37,10 +34,6 @@ func (server) GarbageCollect(ctx context.Context, in *pb.GarbageCollectRequest)
return nil, grpc.Errorf(codes.Internal, err.Error())
}
- if _, err := io.Copy(ioutil.Discard, cmd); err != nil {
- return nil, grpc.Errorf(codes.Internal, err.Error())
- }
-
if err := cmd.Wait(); err != nil {
return nil, grpc.Errorf(codes.Internal, err.Error())
}
diff --git a/internal/service/repository/repack.go b/internal/service/repository/repack.go
index 55fe576ab..43201bafc 100644
--- a/internal/service/repository/repack.go
+++ b/internal/service/repository/repack.go
@@ -1,9 +1,6 @@
package repository
import (
- "io"
- "io/ioutil"
-
log "github.com/Sirupsen/logrus"
"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus"
"golang.org/x/net/context"
@@ -52,10 +49,6 @@ func repackCommand(ctx context.Context, rpcName string, repo *pb.Repository, bit
return grpc.Errorf(codes.Internal, err.Error())
}
- if _, err := io.Copy(ioutil.Discard, cmd); err != nil {
- return grpc.Errorf(codes.Internal, err.Error())
- }
-
if err := cmd.Wait(); err != nil {
return grpc.Errorf(codes.Internal, err.Error())
}
diff --git a/internal/service/smarthttp/inforefs.go b/internal/service/smarthttp/inforefs.go
index b01d3911d..ad71839cf 100644
--- a/internal/service/smarthttp/inforefs.go
+++ b/internal/service/smarthttp/inforefs.go
@@ -44,7 +44,6 @@ func handleInfoRefs(ctx context.Context, service string, repo *pb.Repository, w
if err != nil {
return grpc.Errorf(codes.Internal, "GetInfoRefs: cmd: %v", err)
}
- defer cmd.Close()
if err := pktLine(w, fmt.Sprintf("# service=git-%s\n", service)); err != nil {
return grpc.Errorf(codes.Internal, "GetInfoRefs: pktLine: %v", err)
@@ -55,11 +54,11 @@ func handleInfoRefs(ctx context.Context, service string, repo *pb.Repository, w
}
if _, err := io.Copy(w, cmd); err != nil {
- return grpc.Errorf(codes.Internal, "GetInfoRefs: copy output of %v: %v", cmd.Args, err)
+ return grpc.Errorf(codes.Internal, "GetInfoRefs: %v", err)
}
if err := cmd.Wait(); err != nil {
- return grpc.Errorf(codes.Internal, "GetInfoRefs: wait for %v: %v", cmd.Args, err)
+ return grpc.Errorf(codes.Internal, "GetInfoRefs: %v", err)
}
return nil
diff --git a/internal/service/smarthttp/receive_pack.go b/internal/service/smarthttp/receive_pack.go
index e0d4fc021..3009ff008 100644
--- a/internal/service/smarthttp/receive_pack.go
+++ b/internal/service/smarthttp/receive_pack.go
@@ -53,12 +53,11 @@ func (s *server) PostReceivePack(stream pb.SmartHTTPService_PostReceivePackServe
cmd, err := command.New(stream.Context(), osCommand, stdin, stdout, nil, env...)
if err != nil {
- return grpc.Errorf(codes.Unavailable, "PostReceivePack: cmd: %v", err)
+ return grpc.Errorf(codes.Unavailable, "PostReceivePack: %v", err)
}
- defer cmd.Close()
if err := cmd.Wait(); err != nil {
- return grpc.Errorf(codes.Unavailable, "PostReceivePack: cmd wait for %v: %v", cmd.Args, err)
+ return grpc.Errorf(codes.Unavailable, "PostReceivePack: %v", err)
}
return nil
diff --git a/internal/service/smarthttp/upload_pack.go b/internal/service/smarthttp/upload_pack.go
index 8baf6d851..273aa9abb 100644
--- a/internal/service/smarthttp/upload_pack.go
+++ b/internal/service/smarthttp/upload_pack.go
@@ -66,7 +66,6 @@ func (s *server) PostUploadPack(stream pb.SmartHTTPService_PostUploadPackServer)
if err != nil {
return grpc.Errorf(codes.Unavailable, "PostUploadPack: cmd: %v", err)
}
- defer cmd.Close()
if err := cmd.Wait(); err != nil {
pw.Close() // ensure scanDeepen returns
@@ -77,7 +76,7 @@ func (s *server) PostUploadPack(stream pb.SmartHTTPService_PostUploadPackServer)
deepenCount.Inc()
return nil
}
- return grpc.Errorf(codes.Unavailable, "PostUploadPack: cmd wait for %v: %v", cmd.Args, err)
+ return grpc.Errorf(codes.Unavailable, "PostUploadPack: %v", err)
}
return nil
diff --git a/internal/service/ssh/receive_pack.go b/internal/service/ssh/receive_pack.go
index 0fa80d6b7..8dd456af4 100644
--- a/internal/service/ssh/receive_pack.go
+++ b/internal/service/ssh/receive_pack.go
@@ -59,7 +59,6 @@ func (s *server) SSHReceivePack(stream pb.SSHService_SSHReceivePackServer) error
if err != nil {
return grpc.Errorf(codes.Unavailable, "SSHReceivePack: cmd: %v", err)
}
- defer cmd.Close()
if err := cmd.Wait(); err != nil {
if status, ok := command.ExitStatus(err); ok {
@@ -68,7 +67,7 @@ func (s *server) SSHReceivePack(stream pb.SSHService_SSHReceivePackServer) error
stream.Send(&pb.SSHReceivePackResponse{ExitStatus: &pb.ExitStatus{Value: int32(status)}}),
)
}
- return grpc.Errorf(codes.Unavailable, "SSHReceivePack: cmd wait for %v: %v", cmd.Args, err)
+ return grpc.Errorf(codes.Unavailable, "SSHReceivePack: %v", err)
}
return nil
diff --git a/internal/service/ssh/upload_pack.go b/internal/service/ssh/upload_pack.go
index 811d4eb45..1fe56c568 100644
--- a/internal/service/ssh/upload_pack.go
+++ b/internal/service/ssh/upload_pack.go
@@ -53,7 +53,6 @@ func (s *server) SSHUploadPack(stream pb.SSHService_SSHUploadPackServer) error {
if err != nil {
return grpc.Errorf(codes.Unavailable, "SSHUploadPack: cmd: %v", err)
}
- defer cmd.Close()
if err := cmd.Wait(); err != nil {
if status, ok := command.ExitStatus(err); ok {
@@ -62,7 +61,7 @@ func (s *server) SSHUploadPack(stream pb.SSHService_SSHUploadPackServer) error {
stream.Send(&pb.SSHUploadPackResponse{ExitStatus: &pb.ExitStatus{Value: int32(status)}}),
)
}
- return grpc.Errorf(codes.Unavailable, "SSHUploadPack: cmd wait for %v: %v", cmd.Args, err)
+ return grpc.Errorf(codes.Unavailable, "SSHUploadPack: %v", err)
}
return nil
diff --git a/internal/testhelper/testhelper.go b/internal/testhelper/testhelper.go
index 69e47ea15..4c4807430 100644
--- a/internal/testhelper/testhelper.go
+++ b/internal/testhelper/testhelper.go
@@ -2,6 +2,7 @@ package testhelper
import (
"bytes"
+ "context"
"fmt"
"io"
"io/ioutil"
@@ -13,6 +14,7 @@ import (
"strings"
"syscall"
"testing"
+ "time"
log "github.com/Sirupsen/logrus"
@@ -196,8 +198,20 @@ func NewTestGrpcServer(t *testing.T, streamInterceptors []grpc.StreamServerInter
}
// MustHaveNoChildProcess panics if it finds a running or finished child
-// process.
+// process. It waits for 2 seconds for processes to be cleaned up by other
+// goroutines.
func MustHaveNoChildProcess() {
+ waitDone := make(chan struct{})
+ go func() {
+ command.WaitAllDone()
+ close(waitDone)
+ }()
+
+ select {
+ case <-waitDone:
+ case <-time.After(2 * time.Second):
+ }
+
mustFindNoFinishedChildProcess()
mustFindNoRunningChildProcess()
}
@@ -209,11 +223,7 @@ func mustFindNoFinishedChildProcess() {
// rusage. Use WNOHANG to return immediately if there is no child waiting
// to be reaped.
wpid, err := syscall.Wait4(-1, nil, syscall.WNOHANG, nil)
- if err != nil {
- return
- }
-
- if wpid > 0 {
+ if err == nil && wpid > 0 {
panic(fmt.Errorf("wait4 found child process %d", wpid))
}
}
@@ -226,7 +236,7 @@ func mustFindNoRunningChildProcess() {
if err == nil {
pidsComma := strings.Replace(strings.TrimSpace(string(out)), ",", "\n", -1)
psOut, _ := exec.Command("ps", "-o", "pid,args", "-p", pidsComma).Output()
- panic(fmt.Sprintf("found running child processes %s:\n%s", pidsComma, psOut))
+ panic(fmt.Errorf("found running child processes %s:\n%s", pidsComma, psOut))
}
if status, ok := command.ExitStatus(err); ok && status == 1 {
@@ -234,5 +244,10 @@ func mustFindNoRunningChildProcess() {
return
}
- panic(fmt.Sprintf("%s: %v", desc, err))
+ panic(fmt.Errorf("%s: %v", desc, err))
+}
+
+// Context returns a cancellable context.
+func Context() (context.Context, func()) {
+ return context.WithCancel(context.Background())
}