diff options
author | Jacob Vosmaer (GitLab) <jacob@gitlab.com> | 2017-09-11 15:46:15 +0300 |
---|---|---|
committer | Jacob Vosmaer (GitLab) <jacob@gitlab.com> | 2017-09-11 15:46:15 +0300 |
commit | da0ab569f1eec7032afde72c6b8cce5d1b2a22de (patch) | |
tree | 189e9dfeeda280124a7e0acd64bc85cd29653b88 | |
parent | aea5e82175047b514dbd0e0a75053b1e014c7612 (diff) | |
parent | 5776a29cb41046ebbe4e0741b5ed50862980cdb1 (diff) |
Merge branch 'deprecate-command-close' into 'master'
Use context cancellation instead of command.Close
See merge request !332
35 files changed, 192 insertions, 148 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index ef1864d77..6d6832714 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,10 @@ # Gitaly changelog -v0.39.0 +UNRELEASED +- Use context cancellation instead of command.Close + https://gitlab.com/gitlab-org/gitaly/merge_requests/332 +v0.39.0 - Reimplement FindAllTags RPC in Ruby https://gitlab.com/gitlab-org/gitaly/merge_requests/334 - Re-use gitaly-ruby client connection diff --git a/internal/command/command.go b/internal/command/command.go index 21aa5fec9..f81320cd2 100644 --- a/internal/command/command.go +++ b/internal/command/command.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "io/ioutil" "os" "os/exec" "path" @@ -36,15 +37,36 @@ var exportedEnvVars = []string{ "GIT_TRACE_SETUP", } -// Command encapsulates operations with commands creates with NewCommand +// Command encapsulates a running exec.Cmd. The embedded exec.Cmd is +// terminated and reaped automatically when the context.Context that +// created it is canceled. type Command struct { - io.Reader - *exec.Cmd + reader io.Reader + cmd *exec.Cmd context context.Context startTime time.Time - done chan struct{} - closeOnce sync.Once - closeErr error + + waitError error + waitOnce sync.Once +} + +// Read calls Read() on the stdout pipe of the command. +func (c *Command) Read(p []byte) (int, error) { + if c.reader == nil { + panic("command has no reader") + } + + return c.reader.Read(p) +} + +// Wait calls Wait() on the exec.Cmd instance inside the command. This +// blocks until the command has finished and reports the command exit +// status via the error return value. Use ExitStatus to get the integer +// exit status from the error returned by Wait(). +func (c *Command) Wait() error { + c.waitOnce.Do(c.wait) + + return c.waitError } // GitPath returns the path to the `git` binary. See `SetGitPath` for details @@ -76,18 +98,32 @@ func GitlabShell(ctx context.Context, envs []string, executable string, args ... return New(ctx, exec.Command(path.Join(shellPath, executable), args...), nil, nil, nil, envs...) } -// New creates a Command from an exec.Cmd +var wg = &sync.WaitGroup{} + +// WaitAllDone waits for all commands started by the command package to +// finish. This can only be called once in the lifecycle of the current +// Go process. +func WaitAllDone() { + wg.Wait() +} + +// New creates a Command from an exec.Cmd. On success, the Command +// contains a running subprocess. When ctx is canceled the embedded +// process will be terminated and reaped automatically. func New(ctx context.Context, cmd *exec.Cmd, stdin io.Reader, stdout, stderr io.Writer, env ...string) (*Command, error) { - grpc_logrus.Extract(ctx).WithFields(log.Fields{ - "path": cmd.Path, - "args": cmd.Args, - }).Info("spawn") + logPid := -1 + defer func() { + grpc_logrus.Extract(ctx).WithFields(log.Fields{ + "pid": logPid, + "path": cmd.Path, + "args": cmd.Args, + }).Info("spawn") + }() command := &Command{ - Cmd: cmd, + cmd: cmd, startTime: time.Now(), context: ctx, - done: make(chan struct{}), } // Explicitly set the environment for the command @@ -120,7 +156,7 @@ func New(ctx context.Context, cmd *exec.Cmd, stdin io.Reader, stdout, stderr io. if err != nil { return nil, fmt.Errorf("GitCommand: stdout: %v", err) } - command.Reader = pipe + command.reader = pipe } if stderr != nil { @@ -134,18 +170,22 @@ func New(ctx context.Context, cmd *exec.Cmd, stdin io.Reader, stdout, stderr io. return nil, fmt.Errorf("GitCommand: start %v: %v", cmd.Args, err) } + // The goroutine below is responsible for terminating and reaping the + // process when ctx is canceled. + wg.Add(1) go func() { - select { - case <-command.done: - case <-ctx.Done(): - } + <-ctx.Done() if process := cmd.Process; process != nil && process.Pid > 0 { // Send SIGTERM to the process group of cmd syscall.Kill(-process.Pid, syscall.SIGTERM) } + command.Wait() + wg.Done() }() + logPid = cmd.Process.Pid + return command, nil } @@ -159,20 +199,18 @@ func exportEnvironment(env []string) []string { return env } -// Close will send a SIGTERM signal to the process group -// belonging to the `cmd` process -func (c *Command) Close() error { - c.closeOnce.Do(c.close) - return c.closeErr -} +// This function should never be called directly, use Wait(). +func (c *Command) wait() { + if c.reader != nil { + // Prevent the command from blocking on writing to its stdout. + io.Copy(ioutil.Discard, c.reader) + } -func (c *Command) close() { - close(c.done) - c.closeErr = c.Cmd.Wait() + c.waitError = c.cmd.Wait() exitCode := 0 - if c.closeErr != nil { - if exitStatus, ok := ExitStatus(c.closeErr); ok { + if c.waitError != nil { + if exitStatus, ok := ExitStatus(c.waitError); ok { exitCode = exitStatus } } @@ -180,7 +218,7 @@ func (c *Command) close() { c.logProcessComplete(c.context, exitCode) } -// ExitStatus will return the exit-code from an error +// ExitStatus will return the exit-code from an error returned by Wait(). func ExitStatus(err error) (int, bool) { exitError, ok := err.(*exec.ExitError) if !ok { @@ -196,7 +234,7 @@ func ExitStatus(err error) (int, bool) { } func (c *Command) logProcessComplete(ctx context.Context, exitCode int) { - cmd := c.Cmd + cmd := c.cmd systemTime := cmd.ProcessState.SystemTime() userTime := cmd.ProcessState.UserTime() diff --git a/internal/git/catfile/catfile.go b/internal/git/catfile/catfile.go index 5fbae763c..144595eda 100644 --- a/internal/git/catfile/catfile.go +++ b/internal/git/catfile/catfile.go @@ -36,7 +36,6 @@ func CatFile(ctx context.Context, repoPath string, handler Handler) error { if err != nil { return grpc.Errorf(codes.Internal, "CatFile: cmd: %v", err) } - defer cmd.Close() defer stdinWriter.Close() defer stdinReader.Close() diff --git a/internal/git/log/commit.go b/internal/git/log/commit.go index 62ec6c65c..574511030 100644 --- a/internal/git/log/commit.go +++ b/internal/git/log/commit.go @@ -39,7 +39,6 @@ func GetCommit(ctx context.Context, repo *pb.Repository, revision string, path s if err != nil { return nil, err } - defer cmd.Close() logParser := NewLogParser(cmd) if ok := logParser.Parse(); !ok { diff --git a/internal/helper/repo.go b/internal/helper/repo.go index ed392912b..9e4999248 100644 --- a/internal/helper/repo.go +++ b/internal/helper/repo.go @@ -92,8 +92,6 @@ func IsValidRef(ctx context.Context, path, ref string) bool { if err != nil { return false } - defer cmd.Close() - cmd.Stdout, cmd.Stderr, cmd.Stdin = nil, nil, nil return cmd.Wait() == nil } diff --git a/internal/linguist/linguist.go b/internal/linguist/linguist.go index b8ab27269..0a0359626 100644 --- a/internal/linguist/linguist.go +++ b/internal/linguist/linguist.go @@ -32,7 +32,6 @@ func Stats(ctx context.Context, repoPath string, commitID string) (map[string]in if err != nil { return nil, err } - defer reader.Close() data, err := ioutil.ReadAll(reader) if err != nil { diff --git a/internal/rubyserver/rubyserver_test.go b/internal/rubyserver/rubyserver_test.go index fbc9b033a..cb3dfd1e2 100644 --- a/internal/rubyserver/rubyserver_test.go +++ b/internal/rubyserver/rubyserver_test.go @@ -1,7 +1,6 @@ package rubyserver import ( - "context" "testing" "github.com/stretchr/testify/assert" @@ -45,13 +44,17 @@ func TestSetHeaders(t *testing.T) { } for _, tc := range testCases { - ctx, err := SetHeaders(context.Background(), tc.repo) + ctx, cancel := testhelper.Context() + defer cancel() + + clientCtx, err := SetHeaders(ctx, tc.repo) + if tc.errType != codes.OK { testhelper.AssertGrpcError(t, err, tc.errType, "") - assert.Nil(t, ctx) + assert.Nil(t, clientCtx) } else { assert.NoError(t, err) - assert.NotNil(t, ctx) + assert.NotNil(t, clientCtx) } } } diff --git a/internal/server/auth_test.go b/internal/server/auth_test.go index f5c8ae197..a92d61188 100644 --- a/internal/server/auth_test.go +++ b/internal/server/auth_test.go @@ -1,7 +1,6 @@ package server import ( - "context" "net" "testing" "time" @@ -130,8 +129,11 @@ func dial(opts []grpc.DialOption) (*grpc.ClientConn, error) { } func healthCheck(conn *grpc.ClientConn) error { + ctx, cancel := testhelper.Context() + defer cancel() + client := healthpb.NewHealthClient(conn) - _, err := client.Check(context.Background(), &healthpb.HealthCheckRequest{}) + _, err := client.Check(ctx, &healthpb.HealthCheckRequest{}) return err } diff --git a/internal/service/blob/get_blob.go b/internal/service/blob/get_blob.go index 2cd1344e5..573292e66 100644 --- a/internal/service/blob/get_blob.go +++ b/internal/service/blob/get_blob.go @@ -34,7 +34,6 @@ func (s *server) GetBlob(in *pb.GetBlobRequest, stream pb.BlobService_GetBlobSer if err != nil { return grpc.Errorf(codes.Internal, "GetBlob: cmd: %v", err) } - defer cmd.Close() defer stdinWriter.Close() defer stdinReader.Close() diff --git a/internal/service/blob/get_blob_test.go b/internal/service/blob/get_blob_test.go index 2440fa23d..464000a36 100644 --- a/internal/service/blob/get_blob_test.go +++ b/internal/service/blob/get_blob_test.go @@ -2,7 +2,6 @@ package blob import ( "bytes" - "context" "fmt" "io" "testing" @@ -71,7 +70,10 @@ func TestSuccessfulGetBlob(t *testing.T) { Limit: int64(tc.limit), } - stream, err := client.GetBlob(context.Background(), request) + ctx, cancel := testhelper.Context() + defer cancel() + + stream, err := client.GetBlob(ctx, request) require.NoError(t, err, "initiate RPC") reportedSize, reportedOid, data, err := getBlob(stream) @@ -97,7 +99,10 @@ func TestGetBlobNotFound(t *testing.T) { Oid: "doesnotexist", } - stream, err := client.GetBlob(context.Background(), request) + ctx, cancel := testhelper.Context() + defer cancel() + + stream, err := client.GetBlob(ctx, request) require.NoError(t, err) reportedSize, reportedOid, data, err := getBlob(stream) @@ -150,7 +155,10 @@ func TestFailedGetBlobRequestDueToValidationError(t *testing.T) { } for _, rpcRequest := range rpcRequests { - stream, err := client.GetBlob(context.Background(), &rpcRequest) + ctx, cancel := testhelper.Context() + defer cancel() + + stream, err := client.GetBlob(ctx, &rpcRequest) require.NoError(t, err, rpcRequest) _, err = stream.Recv() require.NotEqual(t, io.EOF, err, rpcRequest) diff --git a/internal/service/commit/commits_helper.go b/internal/service/commit/commits_helper.go index efd6a849a..9042cb3b3 100644 --- a/internal/service/commit/commits_helper.go +++ b/internal/service/commit/commits_helper.go @@ -19,7 +19,6 @@ func sendCommits(ctx context.Context, sender commitsSender, repo *pb.Repository, if err != nil { return err } - defer cmd.Close() logParser := log.NewLogParser(cmd) diff --git a/internal/service/commit/count_commits.go b/internal/service/commit/count_commits.go index 6663acb13..a0c519a4a 100644 --- a/internal/service/commit/count_commits.go +++ b/internal/service/commit/count_commits.go @@ -44,7 +44,6 @@ func (s *server) CountCommits(ctx context.Context, in *pb.CountCommitsRequest) ( if err != nil { return nil, grpc.Errorf(codes.Internal, "CountCommits: cmd: %v", err) } - defer cmd.Close() var count int64 countStr, readAllErr := ioutil.ReadAll(cmd) diff --git a/internal/service/commit/find_commits_test.go b/internal/service/commit/find_commits_test.go index 963b52bc0..54f8e83e3 100644 --- a/internal/service/commit/find_commits_test.go +++ b/internal/service/commit/find_commits_test.go @@ -6,6 +6,7 @@ import ( "testing" pb "gitlab.com/gitlab-org/gitaly-proto/go" + "gitlab.com/gitlab-org/gitaly/internal/testhelper" "github.com/golang/protobuf/ptypes/timestamp" "github.com/stretchr/testify/require" @@ -186,7 +187,10 @@ func TestSuccessfulFindCommitsRequest(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - stream, err := client.FindCommits(context.Background(), tc.request) + ctx, cancel := testhelper.Context() + defer cancel() + + stream, err := client.FindCommits(ctx, tc.request) require.NoError(t, err) var ids []string diff --git a/internal/service/commit/isancestor.go b/internal/service/commit/isancestor.go index 53990c900..cafa08edb 100644 --- a/internal/service/commit/isancestor.go +++ b/internal/service/commit/isancestor.go @@ -1,9 +1,6 @@ package commit import ( - "io/ioutil" - "os/exec" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -40,12 +37,10 @@ func commitIsAncestorName(ctx context.Context, path, ancestorID, childID string) "childSha": childID, }).Debug("commitIsAncestor") - osCommand := exec.Command(command.GitPath(), "--git-dir", path, "merge-base", "--is-ancestor", ancestorID, childID) - cmd, err := command.New(ctx, osCommand, nil, ioutil.Discard, nil) + cmd, err := command.Git(ctx, "--git-dir", path, "merge-base", "--is-ancestor", ancestorID, childID) if err != nil { return false, grpc.Errorf(codes.Internal, err.Error()) } - defer cmd.Close() return cmd.Wait() == nil, nil } diff --git a/internal/service/commit/languages.go b/internal/service/commit/languages.go index 8618f3f46..dddc84235 100644 --- a/internal/service/commit/languages.go +++ b/internal/service/commit/languages.go @@ -82,7 +82,6 @@ func lookupRevision(ctx context.Context, repoPath string, revision string) (stri if err != nil { return "", err } - defer revParse.Close() revParseBytes, err := ioutil.ReadAll(revParse) if err != nil { diff --git a/internal/service/commit/list_files.go b/internal/service/commit/list_files.go index b994f12e2..1c9488385 100644 --- a/internal/service/commit/list_files.go +++ b/internal/service/commit/list_files.go @@ -39,7 +39,6 @@ func (s *server) ListFiles(in *pb.ListFilesRequest, stream pb.CommitService_List if err != nil { return grpc.Errorf(codes.Internal, err.Error()) } - defer cmd.Close() return lines.Send(cmd, listFilesWriter(stream), []byte{'\x00'}) } diff --git a/internal/service/commit/raw_blame.go b/internal/service/commit/raw_blame.go index a0342e5bf..1bfca2647 100644 --- a/internal/service/commit/raw_blame.go +++ b/internal/service/commit/raw_blame.go @@ -33,7 +33,6 @@ func (s *server) RawBlame(in *pb.RawBlameRequest, stream pb.CommitService_RawBla if err != nil { return grpc.Errorf(codes.Internal, "RawBlame: cmd: %v", err) } - defer cmd.Close() sw := streamio.NewWriter(func(p []byte) error { return stream.Send(&pb.RawBlameResponse{Data: p}) diff --git a/internal/service/diff/commit.go b/internal/service/diff/commit.go index 974e528cd..e9455dca4 100644 --- a/internal/service/diff/commit.go +++ b/internal/service/diff/commit.go @@ -219,7 +219,6 @@ func eachDiff(ctx context.Context, rpc string, cmdArgs []string, limits diff.Lim if err != nil { return grpc.Errorf(codes.Internal, "%s: cmd: %v", rpc, err) } - defer cmd.Close() diffParser := diff.NewDiffParser(cmd, limits) @@ -234,7 +233,7 @@ func eachDiff(ctx context.Context, rpc string, cmdArgs []string, limits diff.Lim } if err := cmd.Wait(); err != nil { - return grpc.Errorf(codes.Unavailable, "%s: cmd wait for %v: %v", rpc, cmd.Args, err) + return grpc.Errorf(codes.Unavailable, "%s: %v", rpc, err) } return nil diff --git a/internal/service/diff/patch_test.go b/internal/service/diff/patch_test.go index 52bba4a8c..32d9415d2 100644 --- a/internal/service/diff/patch_test.go +++ b/internal/service/diff/patch_test.go @@ -6,8 +6,6 @@ import ( "github.com/stretchr/testify/assert" - "golang.org/x/net/context" - pb "gitlab.com/gitlab-org/gitaly-proto/go" "gitlab.com/gitlab-org/gitaly/internal/testhelper" ) @@ -37,30 +35,33 @@ func TestSuccessfulCommitPatchRequest(t *testing.T) { } for _, testCase := range testCases { - t.Log(testCase.desc) - - request := &pb.CommitPatchRequest{ - Repository: testRepo, - Revision: testCase.revision, - } + t.Run(testCase.desc, func(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() - c, err := client.CommitPatch(context.Background(), request) - if err != nil { - t.Fatal(err) - } + request := &pb.CommitPatchRequest{ + Repository: testRepo, + Revision: testCase.revision, + } - data := []byte{} - for { - r, err := c.Recv() - if err == io.EOF { - break - } else if err != nil { + c, err := client.CommitPatch(ctx, request) + if err != nil { t.Fatal(err) } - data = append(data, r.GetData()...) - } + data := []byte{} + for { + r, err := c.Recv() + if err == io.EOF { + break + } else if err != nil { + t.Fatal(err) + } + + data = append(data, r.GetData()...) + } - assert.Equal(t, testCase.diff, data) + assert.Equal(t, testCase.diff, data) + }) } } diff --git a/internal/service/ref/branches_test.go b/internal/service/ref/branches_test.go index 585f3b467..31620c07a 100644 --- a/internal/service/ref/branches_test.go +++ b/internal/service/ref/branches_test.go @@ -1,6 +1,7 @@ package ref import ( + "context" "os/exec" "testing" @@ -10,22 +11,24 @@ import ( pb "gitlab.com/gitlab-org/gitaly-proto/go" "github.com/stretchr/testify/require" - "golang.org/x/net/context" "google.golang.org/grpc/codes" ) func TestSuccessfulCreateBranchRequest(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + server := runRefServiceServer(t) defer server.Stop() client, conn := newRefClient(t) defer conn.Close() - headCommit, err := log.GetCommit(context.Background(), testRepo, "HEAD", "") + headCommit, err := log.GetCommit(ctx, testRepo, "HEAD", "") require.NoError(t, err) startPoint := "c7fbe50c7c7419d9701eebe64b1fdacc3df5b9dd" - startPointCommit, err := log.GetCommit(context.Background(), testRepo, startPoint, "") + startPointCommit, err := log.GetCommit(ctx, testRepo, startPoint, "") require.NoError(t, err) testCases := []struct { @@ -210,6 +213,9 @@ func TestFailedDeleteBranchRequest(t *testing.T) { } func TestSuccessfulFindBranchRequest(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + server := runRefServiceServer(t) defer server.Stop() @@ -217,7 +223,7 @@ func TestSuccessfulFindBranchRequest(t *testing.T) { defer conn.Close() branchNameInput := "master" - branchTarget, err := log.GetCommit(context.Background(), testRepo, branchNameInput, "") + branchTarget, err := log.GetCommit(ctx, testRepo, branchNameInput, "") require.NoError(t, err) branch := &pb.Branch{ diff --git a/internal/service/ref/refexists.go b/internal/service/ref/refexists.go index 6f88522cb..be5e61ccc 100644 --- a/internal/service/ref/refexists.go +++ b/internal/service/ref/refexists.go @@ -1,8 +1,6 @@ package ref import ( - "io/ioutil" - "os/exec" "strings" log "github.com/Sirupsen/logrus" @@ -41,12 +39,10 @@ func refExists(ctx context.Context, repoPath string, ref string) (bool, error) { return false, grpc.Errorf(codes.InvalidArgument, "invalid refname") } - osCommand := exec.Command(command.GitPath(), "--git-dir", repoPath, "show-ref", "--verify", "--quiet", ref) - cmd, err := command.New(ctx, osCommand, nil, ioutil.Discard, nil) + cmd, err := command.Git(ctx, "--git-dir", repoPath, "show-ref", "--verify", "--quiet", ref) if err != nil { return false, grpc.Errorf(codes.Internal, err.Error()) } - defer cmd.Close() err = cmd.Wait() if err == nil { diff --git a/internal/service/ref/refexists_test.go b/internal/service/ref/refexists_test.go index 2204b39ee..c0275bbe5 100644 --- a/internal/service/ref/refexists_test.go +++ b/internal/service/ref/refexists_test.go @@ -3,12 +3,11 @@ package ref import ( "testing" - "golang.org/x/net/context" - "google.golang.org/grpc" "google.golang.org/grpc/codes" pb "gitlab.com/gitlab-org/gitaly-proto/go" + "gitlab.com/gitlab-org/gitaly/internal/testhelper" ) func TestRefExists(t *testing.T) { @@ -37,6 +36,9 @@ func TestRefExists(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + server := runRefServiceServer(t) defer server.Stop() @@ -45,7 +47,7 @@ func TestRefExists(t *testing.T) { req := &pb.RefExistsRequest{Repository: tt.repo, Ref: []byte(tt.ref)} - got, err := client.RefExists(context.Background(), req) + got, err := client.RefExists(ctx, req) if grpc.Code(err) != tt.wantErr { t.Errorf("server.RefExists() error = %v, wantErr %v", err, tt.wantErr) diff --git a/internal/service/ref/refname.go b/internal/service/ref/refname.go index 1b16a34c8..06a2d537f 100644 --- a/internal/service/ref/refname.go +++ b/internal/service/ref/refname.go @@ -46,7 +46,6 @@ func findRefName(ctx context.Context, path, commitID, prefix string) (string, er if err != nil { return "", err } - defer cmd.Close() scanner := bufio.NewScanner(cmd) scanner.Scan() diff --git a/internal/service/ref/refs.go b/internal/service/ref/refs.go index 30fefce72..d50b388a4 100644 --- a/internal/service/ref/refs.go +++ b/internal/service/ref/refs.go @@ -57,7 +57,6 @@ func findRefs(ctx context.Context, writer lines.Sender, repo *pb.Repository, pat if err != nil { return err } - defer cmd.Close() if err := lines.Send(cmd, writer, opts.delim); err != nil { return err @@ -110,7 +109,6 @@ func _findBranchNames(ctx context.Context, repoPath string) ([][]byte, error) { if err != nil { return nil, err } - defer cmd.Close() scanner := bufio.NewScanner(cmd) for scanner.Scan() { @@ -134,7 +132,6 @@ func _headReference(ctx context.Context, repoPath string) ([]byte, error) { if err != nil { return nil, err } - defer cmd.Close() scanner := bufio.NewScanner(cmd) scanner.Scan() diff --git a/internal/service/repository/apply_gitattributes_test.go b/internal/service/repository/apply_gitattributes_test.go index fbbf4d6f3..50ea86751 100644 --- a/internal/service/repository/apply_gitattributes_test.go +++ b/internal/service/repository/apply_gitattributes_test.go @@ -1,6 +1,7 @@ package repository import ( + "fmt" "io/ioutil" "os" "path" @@ -11,7 +12,6 @@ import ( "github.com/stretchr/testify/assert" pb "gitlab.com/gitlab-org/gitaly-proto/go" "gitlab.com/gitlab-org/gitaly/internal/testhelper" - "golang.org/x/net/context" ) func TestApplyGitattributesSuccess(t *testing.T) { @@ -108,16 +108,23 @@ func TestApplyGitattributesFailure(t *testing.T) { } for _, test := range tests { - req := &pb.ApplyGitattributesRequest{Repository: test.repo, Revision: test.revision} - _, err := client.ApplyGitattributes(context.Background(), req) - testhelper.AssertGrpcError(t, err, test.code, "") - } + t.Run(fmt.Sprintf("%+v", test), func(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + req := &pb.ApplyGitattributesRequest{Repository: test.repo, Revision: test.revision} + _, err := client.ApplyGitattributes(ctx, req) + testhelper.AssertGrpcError(t, err, test.code, "") + }) + } } func assertGitattributesApplied(t *testing.T, client pb.RepositoryServiceClient, attributesPath string, revision, expectedContents []byte) { + ctx, cancel := testhelper.Context() + defer cancel() + req := &pb.ApplyGitattributesRequest{Repository: testRepo, Revision: revision} - c, err := client.ApplyGitattributes(context.Background(), req) + c, err := client.ApplyGitattributes(ctx, req) assert.NoError(t, err) assert.NotNil(t, c) diff --git a/internal/service/repository/fetch_remote.go b/internal/service/repository/fetch_remote.go index f247b5e8d..e5438b4a4 100644 --- a/internal/service/repository/fetch_remote.go +++ b/internal/service/repository/fetch_remote.go @@ -2,8 +2,6 @@ package repository import ( "fmt" - "io" - "io/ioutil" log "github.com/Sirupsen/logrus" "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" @@ -36,10 +34,6 @@ func (server) FetchRemote(ctx context.Context, in *pb.FetchRemoteRequest) (*pb.F return nil, grpc.Errorf(codes.Internal, err.Error()) } - if _, err = io.Copy(ioutil.Discard, cmd); err != nil { - return nil, grpc.Errorf(codes.Internal, err.Error()) - } - if err = cmd.Wait(); err != nil { return nil, grpc.Errorf(codes.Internal, err.Error()) } diff --git a/internal/service/repository/fetch_remote_test.go b/internal/service/repository/fetch_remote_test.go index 3da43409f..a4776ffe9 100644 --- a/internal/service/repository/fetch_remote_test.go +++ b/internal/service/repository/fetch_remote_test.go @@ -1,7 +1,6 @@ package repository import ( - "context" "io/ioutil" "os" "path" @@ -108,6 +107,9 @@ func TestFetchRemoteArgsBuilder(t *testing.T) { // NOTE: Only tests that `gitlab-shell` is being called, not what it does. func TestFetchRemoteSuccess(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + dir, err := ioutil.TempDir("", "gitlab-shell.") require.NoError(t, err) defer func(dir string) { @@ -135,7 +137,7 @@ func TestFetchRemoteSuccess(t *testing.T) { os.RemoveAll(path) }(cloneRepo) - resp, err := client.FetchRemote(context.Background(), &pb.FetchRemoteRequest{ + resp, err := client.FetchRemote(ctx, &pb.FetchRemoteRequest{ Repository: cloneRepo, Remote: "my-remote", }) @@ -172,6 +174,9 @@ func TestFetchRemoteFailure(t *testing.T) { for _, tc := range tests { t.Run(tc.desc, func(t *testing.T) { + ctx, cancel := testhelper.Context() + defer cancel() + if len(tc.shellPath) == 0 { defer func(oldPath string) { config.Config.GitlabShell.Dir = oldPath @@ -179,7 +184,7 @@ func TestFetchRemoteFailure(t *testing.T) { config.Config.GitlabShell.Dir = "" } - resp, err := client.FetchRemote(context.Background(), tc.req) + resp, err := client.FetchRemote(ctx, tc.req) testhelper.AssertGrpcError(t, err, tc.code, tc.err) assert.Error(t, err) assert.Nil(t, resp) diff --git a/internal/service/repository/gc.go b/internal/service/repository/gc.go index b7c162621..f8e7a155c 100644 --- a/internal/service/repository/gc.go +++ b/internal/service/repository/gc.go @@ -1,9 +1,6 @@ package repository import ( - "io" - "io/ioutil" - log "github.com/Sirupsen/logrus" "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" "golang.org/x/net/context" @@ -37,10 +34,6 @@ func (server) GarbageCollect(ctx context.Context, in *pb.GarbageCollectRequest) return nil, grpc.Errorf(codes.Internal, err.Error()) } - if _, err := io.Copy(ioutil.Discard, cmd); err != nil { - return nil, grpc.Errorf(codes.Internal, err.Error()) - } - if err := cmd.Wait(); err != nil { return nil, grpc.Errorf(codes.Internal, err.Error()) } diff --git a/internal/service/repository/repack.go b/internal/service/repository/repack.go index 55fe576ab..43201bafc 100644 --- a/internal/service/repository/repack.go +++ b/internal/service/repository/repack.go @@ -1,9 +1,6 @@ package repository import ( - "io" - "io/ioutil" - log "github.com/Sirupsen/logrus" "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" "golang.org/x/net/context" @@ -52,10 +49,6 @@ func repackCommand(ctx context.Context, rpcName string, repo *pb.Repository, bit return grpc.Errorf(codes.Internal, err.Error()) } - if _, err := io.Copy(ioutil.Discard, cmd); err != nil { - return grpc.Errorf(codes.Internal, err.Error()) - } - if err := cmd.Wait(); err != nil { return grpc.Errorf(codes.Internal, err.Error()) } diff --git a/internal/service/smarthttp/inforefs.go b/internal/service/smarthttp/inforefs.go index b01d3911d..ad71839cf 100644 --- a/internal/service/smarthttp/inforefs.go +++ b/internal/service/smarthttp/inforefs.go @@ -44,7 +44,6 @@ func handleInfoRefs(ctx context.Context, service string, repo *pb.Repository, w if err != nil { return grpc.Errorf(codes.Internal, "GetInfoRefs: cmd: %v", err) } - defer cmd.Close() if err := pktLine(w, fmt.Sprintf("# service=git-%s\n", service)); err != nil { return grpc.Errorf(codes.Internal, "GetInfoRefs: pktLine: %v", err) @@ -55,11 +54,11 @@ func handleInfoRefs(ctx context.Context, service string, repo *pb.Repository, w } if _, err := io.Copy(w, cmd); err != nil { - return grpc.Errorf(codes.Internal, "GetInfoRefs: copy output of %v: %v", cmd.Args, err) + return grpc.Errorf(codes.Internal, "GetInfoRefs: %v", err) } if err := cmd.Wait(); err != nil { - return grpc.Errorf(codes.Internal, "GetInfoRefs: wait for %v: %v", cmd.Args, err) + return grpc.Errorf(codes.Internal, "GetInfoRefs: %v", err) } return nil diff --git a/internal/service/smarthttp/receive_pack.go b/internal/service/smarthttp/receive_pack.go index e0d4fc021..3009ff008 100644 --- a/internal/service/smarthttp/receive_pack.go +++ b/internal/service/smarthttp/receive_pack.go @@ -53,12 +53,11 @@ func (s *server) PostReceivePack(stream pb.SmartHTTPService_PostReceivePackServe cmd, err := command.New(stream.Context(), osCommand, stdin, stdout, nil, env...) if err != nil { - return grpc.Errorf(codes.Unavailable, "PostReceivePack: cmd: %v", err) + return grpc.Errorf(codes.Unavailable, "PostReceivePack: %v", err) } - defer cmd.Close() if err := cmd.Wait(); err != nil { - return grpc.Errorf(codes.Unavailable, "PostReceivePack: cmd wait for %v: %v", cmd.Args, err) + return grpc.Errorf(codes.Unavailable, "PostReceivePack: %v", err) } return nil diff --git a/internal/service/smarthttp/upload_pack.go b/internal/service/smarthttp/upload_pack.go index 8baf6d851..273aa9abb 100644 --- a/internal/service/smarthttp/upload_pack.go +++ b/internal/service/smarthttp/upload_pack.go @@ -66,7 +66,6 @@ func (s *server) PostUploadPack(stream pb.SmartHTTPService_PostUploadPackServer) if err != nil { return grpc.Errorf(codes.Unavailable, "PostUploadPack: cmd: %v", err) } - defer cmd.Close() if err := cmd.Wait(); err != nil { pw.Close() // ensure scanDeepen returns @@ -77,7 +76,7 @@ func (s *server) PostUploadPack(stream pb.SmartHTTPService_PostUploadPackServer) deepenCount.Inc() return nil } - return grpc.Errorf(codes.Unavailable, "PostUploadPack: cmd wait for %v: %v", cmd.Args, err) + return grpc.Errorf(codes.Unavailable, "PostUploadPack: %v", err) } return nil diff --git a/internal/service/ssh/receive_pack.go b/internal/service/ssh/receive_pack.go index 0fa80d6b7..8dd456af4 100644 --- a/internal/service/ssh/receive_pack.go +++ b/internal/service/ssh/receive_pack.go @@ -59,7 +59,6 @@ func (s *server) SSHReceivePack(stream pb.SSHService_SSHReceivePackServer) error if err != nil { return grpc.Errorf(codes.Unavailable, "SSHReceivePack: cmd: %v", err) } - defer cmd.Close() if err := cmd.Wait(); err != nil { if status, ok := command.ExitStatus(err); ok { @@ -68,7 +67,7 @@ func (s *server) SSHReceivePack(stream pb.SSHService_SSHReceivePackServer) error stream.Send(&pb.SSHReceivePackResponse{ExitStatus: &pb.ExitStatus{Value: int32(status)}}), ) } - return grpc.Errorf(codes.Unavailable, "SSHReceivePack: cmd wait for %v: %v", cmd.Args, err) + return grpc.Errorf(codes.Unavailable, "SSHReceivePack: %v", err) } return nil diff --git a/internal/service/ssh/upload_pack.go b/internal/service/ssh/upload_pack.go index 811d4eb45..1fe56c568 100644 --- a/internal/service/ssh/upload_pack.go +++ b/internal/service/ssh/upload_pack.go @@ -53,7 +53,6 @@ func (s *server) SSHUploadPack(stream pb.SSHService_SSHUploadPackServer) error { if err != nil { return grpc.Errorf(codes.Unavailable, "SSHUploadPack: cmd: %v", err) } - defer cmd.Close() if err := cmd.Wait(); err != nil { if status, ok := command.ExitStatus(err); ok { @@ -62,7 +61,7 @@ func (s *server) SSHUploadPack(stream pb.SSHService_SSHUploadPackServer) error { stream.Send(&pb.SSHUploadPackResponse{ExitStatus: &pb.ExitStatus{Value: int32(status)}}), ) } - return grpc.Errorf(codes.Unavailable, "SSHUploadPack: cmd wait for %v: %v", cmd.Args, err) + return grpc.Errorf(codes.Unavailable, "SSHUploadPack: %v", err) } return nil diff --git a/internal/testhelper/testhelper.go b/internal/testhelper/testhelper.go index 69e47ea15..4c4807430 100644 --- a/internal/testhelper/testhelper.go +++ b/internal/testhelper/testhelper.go @@ -2,6 +2,7 @@ package testhelper import ( "bytes" + "context" "fmt" "io" "io/ioutil" @@ -13,6 +14,7 @@ import ( "strings" "syscall" "testing" + "time" log "github.com/Sirupsen/logrus" @@ -196,8 +198,20 @@ func NewTestGrpcServer(t *testing.T, streamInterceptors []grpc.StreamServerInter } // MustHaveNoChildProcess panics if it finds a running or finished child -// process. +// process. It waits for 2 seconds for processes to be cleaned up by other +// goroutines. func MustHaveNoChildProcess() { + waitDone := make(chan struct{}) + go func() { + command.WaitAllDone() + close(waitDone) + }() + + select { + case <-waitDone: + case <-time.After(2 * time.Second): + } + mustFindNoFinishedChildProcess() mustFindNoRunningChildProcess() } @@ -209,11 +223,7 @@ func mustFindNoFinishedChildProcess() { // rusage. Use WNOHANG to return immediately if there is no child waiting // to be reaped. wpid, err := syscall.Wait4(-1, nil, syscall.WNOHANG, nil) - if err != nil { - return - } - - if wpid > 0 { + if err == nil && wpid > 0 { panic(fmt.Errorf("wait4 found child process %d", wpid)) } } @@ -226,7 +236,7 @@ func mustFindNoRunningChildProcess() { if err == nil { pidsComma := strings.Replace(strings.TrimSpace(string(out)), ",", "\n", -1) psOut, _ := exec.Command("ps", "-o", "pid,args", "-p", pidsComma).Output() - panic(fmt.Sprintf("found running child processes %s:\n%s", pidsComma, psOut)) + panic(fmt.Errorf("found running child processes %s:\n%s", pidsComma, psOut)) } if status, ok := command.ExitStatus(err); ok && status == 1 { @@ -234,5 +244,10 @@ func mustFindNoRunningChildProcess() { return } - panic(fmt.Sprintf("%s: %v", desc, err)) + panic(fmt.Errorf("%s: %v", desc, err)) +} + +// Context returns a cancellable context. +func Context() (context.Context, func()) { + return context.WithCancel(context.Background()) } |