diff options
author | Jacob Vosmaer <jacob@gitlab.com> | 2019-11-29 13:25:10 +0300 |
---|---|---|
committer | Jacob Vosmaer <jacob@gitlab.com> | 2019-11-29 13:25:10 +0300 |
commit | 5060907241e2960a18f4437f84e429a50e14d879 (patch) | |
tree | 865bd620488d8b8429c59f0c00276031ca65fb7b | |
parent | 364a9bae81784b06371513ab2070b142d6d16504 (diff) | |
parent | 9dc660fe4ccd652737b63de217e5a9ae0f3c6b93 (diff) |
Merge branch 'jc-fix-race-test' into 'master'
Fix upload pack request racy test
Closes #2217
See merge request gitlab-org/gitaly!1661
-rw-r--r-- | internal/service/ssh/server.go | 43 | ||||
-rw-r--r-- | internal/service/ssh/testhelper_test.go | 4 | ||||
-rw-r--r-- | internal/service/ssh/upload_archive.go | 11 | ||||
-rw-r--r-- | internal/service/ssh/upload_archive_test.go | 9 | ||||
-rw-r--r-- | internal/service/ssh/upload_pack.go | 11 | ||||
-rw-r--r-- | internal/service/ssh/upload_pack_test.go | 8 |
6 files changed, 50 insertions, 36 deletions
diff --git a/internal/service/ssh/server.go b/internal/service/ssh/server.go index 095812072..6933b76b6 100644 --- a/internal/service/ssh/server.go +++ b/internal/service/ssh/server.go @@ -1,12 +1,49 @@ package ssh -import "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" +import ( + "time" + + "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" +) + +var ( + defaultUploadPackRequestTimeout = 10 * time.Minute + defaultUploadArchiveRequestTimeout = time.Minute +) type server struct { + uploadPackRequestTimeout time.Duration + uploadArchiveRequestTimeout time.Duration gitalypb.UnimplementedSSHServiceServer } // NewServer creates a new instance of a grpc SSHServer -func NewServer() gitalypb.SSHServiceServer { - return &server{} +func NewServer(serverOpts ...ServerOpt) gitalypb.SSHServiceServer { + s := &server{ + uploadPackRequestTimeout: defaultUploadPackRequestTimeout, + uploadArchiveRequestTimeout: defaultUploadArchiveRequestTimeout, + } + + for _, serverOpt := range serverOpts { + serverOpt(s) + } + + return s +} + +// ServerOpt is a self referential option for server +type ServerOpt func(s *server) + +// WithUploadPackRequestTimeout sets the upload pack request timeout +func WithUploadPackRequestTimeout(d time.Duration) ServerOpt { + return func(s *server) { + s.uploadPackRequestTimeout = d + } +} + +// WithArchiveRequestTimeout sets the upload pack request timeout +func WithArchiveRequestTimeout(d time.Duration) ServerOpt { + return func(s *server) { + s.uploadArchiveRequestTimeout = d + } } diff --git a/internal/service/ssh/testhelper_test.go b/internal/service/ssh/testhelper_test.go index aa47bcdc2..171430590 100644 --- a/internal/service/ssh/testhelper_test.go +++ b/internal/service/ssh/testhelper_test.go @@ -57,7 +57,7 @@ func mustGetCwd() string { return wd } -func runSSHServer(t *testing.T) (*grpc.Server, string) { +func runSSHServer(t *testing.T, serverOpts ...ServerOpt) (*grpc.Server, string) { server := testhelper.NewTestGrpcServer(t, nil, nil) serverSocketPath := testhelper.GetTemporaryGitalySocketFileName() @@ -66,7 +66,7 @@ func runSSHServer(t *testing.T) (*grpc.Server, string) { t.Fatal(err) } - gitalypb.RegisterSSHServiceServer(server, NewServer()) + gitalypb.RegisterSSHServiceServer(server, NewServer(serverOpts...)) reflection.Register(server) go server.Serve(listener) diff --git a/internal/service/ssh/upload_archive.go b/internal/service/ssh/upload_archive.go index ae8cdcc3c..ec61468a1 100644 --- a/internal/service/ssh/upload_archive.go +++ b/internal/service/ssh/upload_archive.go @@ -3,7 +3,6 @@ package ssh import ( "context" "fmt" - "time" "gitlab.com/gitlab-org/gitaly/internal/command" "gitlab.com/gitlab-org/gitaly/internal/git" @@ -13,10 +12,6 @@ import ( "gitlab.com/gitlab-org/gitaly/streamio" ) -var ( - uploadArchiveRequestTimeout = time.Minute -) - func (s *server) SSHUploadArchive(stream gitalypb.SSHService_SSHUploadArchiveServer) error { req, err := stream.Recv() // First request contains Repository only if err != nil { @@ -26,14 +21,14 @@ func (s *server) SSHUploadArchive(stream gitalypb.SSHService_SSHUploadArchiveSer return helper.ErrInvalidArgument(err) } - if err = sshUploadArchive(stream, req); err != nil { + if err = s.sshUploadArchive(stream, req); err != nil { return helper.ErrInternal(err) } return nil } -func sshUploadArchive(stream gitalypb.SSHService_SSHUploadArchiveServer, req *gitalypb.SSHUploadArchiveRequest) error { +func (s *server) sshUploadArchive(stream gitalypb.SSHService_SSHUploadArchiveServer, req *gitalypb.SSHUploadArchiveRequest) error { ctx, cancelCtx := context.WithCancel(stream.Context()) defer cancelCtx() @@ -66,7 +61,7 @@ func sshUploadArchive(stream gitalypb.SSHService_SSHUploadArchiveServer, req *gi // // Place a timeout on receiving the flush packet to mitigate use-after-check // attacks - go monitor.Monitor(pktline.PktFlush(), uploadArchiveRequestTimeout, cancelCtx) + go monitor.Monitor(pktline.PktFlush(), s.uploadArchiveRequestTimeout, cancelCtx) if err := cmd.Wait(); err != nil { if status, ok := command.ExitStatus(err); ok { diff --git a/internal/service/ssh/upload_archive_test.go b/internal/service/ssh/upload_archive_test.go index f58013a5c..c1420cc15 100644 --- a/internal/service/ssh/upload_archive_test.go +++ b/internal/service/ssh/upload_archive_test.go @@ -16,15 +16,8 @@ import ( "google.golang.org/grpc/codes" ) -var ( - originalUploadArchiveRequestTimeout = uploadArchiveRequestTimeout -) - func TestFailedUploadArchiveRequestDueToTimeout(t *testing.T) { - uploadArchiveRequestTimeout = time.Millisecond - defer func() { uploadArchiveRequestTimeout = originalUploadArchiveRequestTimeout }() - - server, serverSocketPath := runSSHServer(t) + server, serverSocketPath := runSSHServer(t, WithArchiveRequestTimeout(100*time.Microsecond)) defer server.Stop() client, conn := newSSHClient(t, serverSocketPath) diff --git a/internal/service/ssh/upload_pack.go b/internal/service/ssh/upload_pack.go index 094afab2c..3973a0955 100644 --- a/internal/service/ssh/upload_pack.go +++ b/internal/service/ssh/upload_pack.go @@ -3,7 +3,6 @@ package ssh import ( "context" "fmt" - "time" "gitlab.com/gitlab-org/gitaly/internal/command" "gitlab.com/gitlab-org/gitaly/internal/git" @@ -13,10 +12,6 @@ import ( "gitlab.com/gitlab-org/gitaly/streamio" ) -var ( - uploadPackRequestTimeout = 10 * time.Minute -) - func (s *server) SSHUploadPack(stream gitalypb.SSHService_SSHUploadPackServer) error { req, err := stream.Recv() // First request contains Repository only if err != nil { @@ -27,14 +22,14 @@ func (s *server) SSHUploadPack(stream gitalypb.SSHService_SSHUploadPackServer) e return helper.ErrInvalidArgument(err) } - if err = sshUploadPack(stream, req); err != nil { + if err = s.sshUploadPack(stream, req); err != nil { return helper.ErrInternal(err) } return nil } -func sshUploadPack(stream gitalypb.SSHService_SSHUploadPackServer, req *gitalypb.SSHUploadPackRequest) error { +func (s *server) sshUploadPack(stream gitalypb.SSHService_SSHUploadPackServer, req *gitalypb.SSHUploadPackRequest) error { ctx, cancelCtx := context.WithCancel(stream.Context()) defer cancelCtx() @@ -77,7 +72,7 @@ func sshUploadPack(stream gitalypb.SSHService_SSHUploadPackServer, req *gitalypb // "flush" tells the server it can terminate, while "done" tells it to start // generating a packfile. Add a timeout to the second case to mitigate // use-after-check attacks. - go monitor.Monitor(pktline.PktDone(), uploadPackRequestTimeout, cancelCtx) + go monitor.Monitor(pktline.PktDone(), s.uploadPackRequestTimeout, cancelCtx) if err := cmd.Wait(); err != nil { if status, ok := command.ExitStatus(err); ok { diff --git a/internal/service/ssh/upload_pack_test.go b/internal/service/ssh/upload_pack_test.go index d72b35b6f..6d6d3ebdf 100644 --- a/internal/service/ssh/upload_pack_test.go +++ b/internal/service/ssh/upload_pack_test.go @@ -20,15 +20,9 @@ import ( "google.golang.org/grpc/codes" ) -var ( - originalUploadPackRequestTimeout = uploadPackRequestTimeout -) - func TestFailedUploadPackRequestDueToTimeout(t *testing.T) { - uploadPackRequestTimeout = time.Millisecond - defer func() { uploadPackRequestTimeout = originalUploadPackRequestTimeout }() + server, serverSocketPath := runSSHServer(t, WithUploadPackRequestTimeout(10*time.Microsecond)) - server, serverSocketPath := runSSHServer(t) defer server.Stop() client, conn := newSSHClient(t, serverSocketPath) |