Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjramsay <jcai@gitlab.com>2019-11-28 10:25:36 +0300
committerjramsay <jcai@gitlab.com>2019-11-28 22:54:06 +0300
commit9dc660fe4ccd652737b63de217e5a9ae0f3c6b93 (patch)
tree209b0e231d8be9d739cca228bbc016fdf2d4dc77
parentbee4674b85e80337da69937e5de8687299452cc3 (diff)
Fix upload pack request racy test
-rw-r--r--internal/service/ssh/server.go43
-rw-r--r--internal/service/ssh/testhelper_test.go4
-rw-r--r--internal/service/ssh/upload_archive.go11
-rw-r--r--internal/service/ssh/upload_archive_test.go9
-rw-r--r--internal/service/ssh/upload_pack.go11
-rw-r--r--internal/service/ssh/upload_pack_test.go8
6 files changed, 50 insertions, 36 deletions
diff --git a/internal/service/ssh/server.go b/internal/service/ssh/server.go
index 095812072..6933b76b6 100644
--- a/internal/service/ssh/server.go
+++ b/internal/service/ssh/server.go
@@ -1,12 +1,49 @@
package ssh
-import "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+import (
+ "time"
+
+ "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+)
+
+var (
+ defaultUploadPackRequestTimeout = 10 * time.Minute
+ defaultUploadArchiveRequestTimeout = time.Minute
+)
type server struct {
+ uploadPackRequestTimeout time.Duration
+ uploadArchiveRequestTimeout time.Duration
gitalypb.UnimplementedSSHServiceServer
}
// NewServer creates a new instance of a grpc SSHServer
-func NewServer() gitalypb.SSHServiceServer {
- return &server{}
+func NewServer(serverOpts ...ServerOpt) gitalypb.SSHServiceServer {
+ s := &server{
+ uploadPackRequestTimeout: defaultUploadPackRequestTimeout,
+ uploadArchiveRequestTimeout: defaultUploadArchiveRequestTimeout,
+ }
+
+ for _, serverOpt := range serverOpts {
+ serverOpt(s)
+ }
+
+ return s
+}
+
+// ServerOpt is a self referential option for server
+type ServerOpt func(s *server)
+
+// WithUploadPackRequestTimeout sets the upload pack request timeout
+func WithUploadPackRequestTimeout(d time.Duration) ServerOpt {
+ return func(s *server) {
+ s.uploadPackRequestTimeout = d
+ }
+}
+
+// WithArchiveRequestTimeout sets the upload pack request timeout
+func WithArchiveRequestTimeout(d time.Duration) ServerOpt {
+ return func(s *server) {
+ s.uploadArchiveRequestTimeout = d
+ }
}
diff --git a/internal/service/ssh/testhelper_test.go b/internal/service/ssh/testhelper_test.go
index aa47bcdc2..171430590 100644
--- a/internal/service/ssh/testhelper_test.go
+++ b/internal/service/ssh/testhelper_test.go
@@ -57,7 +57,7 @@ func mustGetCwd() string {
return wd
}
-func runSSHServer(t *testing.T) (*grpc.Server, string) {
+func runSSHServer(t *testing.T, serverOpts ...ServerOpt) (*grpc.Server, string) {
server := testhelper.NewTestGrpcServer(t, nil, nil)
serverSocketPath := testhelper.GetTemporaryGitalySocketFileName()
@@ -66,7 +66,7 @@ func runSSHServer(t *testing.T) (*grpc.Server, string) {
t.Fatal(err)
}
- gitalypb.RegisterSSHServiceServer(server, NewServer())
+ gitalypb.RegisterSSHServiceServer(server, NewServer(serverOpts...))
reflection.Register(server)
go server.Serve(listener)
diff --git a/internal/service/ssh/upload_archive.go b/internal/service/ssh/upload_archive.go
index ae8cdcc3c..ec61468a1 100644
--- a/internal/service/ssh/upload_archive.go
+++ b/internal/service/ssh/upload_archive.go
@@ -3,7 +3,6 @@ package ssh
import (
"context"
"fmt"
- "time"
"gitlab.com/gitlab-org/gitaly/internal/command"
"gitlab.com/gitlab-org/gitaly/internal/git"
@@ -13,10 +12,6 @@ import (
"gitlab.com/gitlab-org/gitaly/streamio"
)
-var (
- uploadArchiveRequestTimeout = time.Minute
-)
-
func (s *server) SSHUploadArchive(stream gitalypb.SSHService_SSHUploadArchiveServer) error {
req, err := stream.Recv() // First request contains Repository only
if err != nil {
@@ -26,14 +21,14 @@ func (s *server) SSHUploadArchive(stream gitalypb.SSHService_SSHUploadArchiveSer
return helper.ErrInvalidArgument(err)
}
- if err = sshUploadArchive(stream, req); err != nil {
+ if err = s.sshUploadArchive(stream, req); err != nil {
return helper.ErrInternal(err)
}
return nil
}
-func sshUploadArchive(stream gitalypb.SSHService_SSHUploadArchiveServer, req *gitalypb.SSHUploadArchiveRequest) error {
+func (s *server) sshUploadArchive(stream gitalypb.SSHService_SSHUploadArchiveServer, req *gitalypb.SSHUploadArchiveRequest) error {
ctx, cancelCtx := context.WithCancel(stream.Context())
defer cancelCtx()
@@ -66,7 +61,7 @@ func sshUploadArchive(stream gitalypb.SSHService_SSHUploadArchiveServer, req *gi
//
// Place a timeout on receiving the flush packet to mitigate use-after-check
// attacks
- go monitor.Monitor(pktline.PktFlush(), uploadArchiveRequestTimeout, cancelCtx)
+ go monitor.Monitor(pktline.PktFlush(), s.uploadArchiveRequestTimeout, cancelCtx)
if err := cmd.Wait(); err != nil {
if status, ok := command.ExitStatus(err); ok {
diff --git a/internal/service/ssh/upload_archive_test.go b/internal/service/ssh/upload_archive_test.go
index f58013a5c..c1420cc15 100644
--- a/internal/service/ssh/upload_archive_test.go
+++ b/internal/service/ssh/upload_archive_test.go
@@ -16,15 +16,8 @@ import (
"google.golang.org/grpc/codes"
)
-var (
- originalUploadArchiveRequestTimeout = uploadArchiveRequestTimeout
-)
-
func TestFailedUploadArchiveRequestDueToTimeout(t *testing.T) {
- uploadArchiveRequestTimeout = time.Millisecond
- defer func() { uploadArchiveRequestTimeout = originalUploadArchiveRequestTimeout }()
-
- server, serverSocketPath := runSSHServer(t)
+ server, serverSocketPath := runSSHServer(t, WithArchiveRequestTimeout(100*time.Microsecond))
defer server.Stop()
client, conn := newSSHClient(t, serverSocketPath)
diff --git a/internal/service/ssh/upload_pack.go b/internal/service/ssh/upload_pack.go
index 094afab2c..3973a0955 100644
--- a/internal/service/ssh/upload_pack.go
+++ b/internal/service/ssh/upload_pack.go
@@ -3,7 +3,6 @@ package ssh
import (
"context"
"fmt"
- "time"
"gitlab.com/gitlab-org/gitaly/internal/command"
"gitlab.com/gitlab-org/gitaly/internal/git"
@@ -13,10 +12,6 @@ import (
"gitlab.com/gitlab-org/gitaly/streamio"
)
-var (
- uploadPackRequestTimeout = 10 * time.Minute
-)
-
func (s *server) SSHUploadPack(stream gitalypb.SSHService_SSHUploadPackServer) error {
req, err := stream.Recv() // First request contains Repository only
if err != nil {
@@ -27,14 +22,14 @@ func (s *server) SSHUploadPack(stream gitalypb.SSHService_SSHUploadPackServer) e
return helper.ErrInvalidArgument(err)
}
- if err = sshUploadPack(stream, req); err != nil {
+ if err = s.sshUploadPack(stream, req); err != nil {
return helper.ErrInternal(err)
}
return nil
}
-func sshUploadPack(stream gitalypb.SSHService_SSHUploadPackServer, req *gitalypb.SSHUploadPackRequest) error {
+func (s *server) sshUploadPack(stream gitalypb.SSHService_SSHUploadPackServer, req *gitalypb.SSHUploadPackRequest) error {
ctx, cancelCtx := context.WithCancel(stream.Context())
defer cancelCtx()
@@ -77,7 +72,7 @@ func sshUploadPack(stream gitalypb.SSHService_SSHUploadPackServer, req *gitalypb
// "flush" tells the server it can terminate, while "done" tells it to start
// generating a packfile. Add a timeout to the second case to mitigate
// use-after-check attacks.
- go monitor.Monitor(pktline.PktDone(), uploadPackRequestTimeout, cancelCtx)
+ go monitor.Monitor(pktline.PktDone(), s.uploadPackRequestTimeout, cancelCtx)
if err := cmd.Wait(); err != nil {
if status, ok := command.ExitStatus(err); ok {
diff --git a/internal/service/ssh/upload_pack_test.go b/internal/service/ssh/upload_pack_test.go
index d72b35b6f..6d6d3ebdf 100644
--- a/internal/service/ssh/upload_pack_test.go
+++ b/internal/service/ssh/upload_pack_test.go
@@ -20,15 +20,9 @@ import (
"google.golang.org/grpc/codes"
)
-var (
- originalUploadPackRequestTimeout = uploadPackRequestTimeout
-)
-
func TestFailedUploadPackRequestDueToTimeout(t *testing.T) {
- uploadPackRequestTimeout = time.Millisecond
- defer func() { uploadPackRequestTimeout = originalUploadPackRequestTimeout }()
+ server, serverSocketPath := runSSHServer(t, WithUploadPackRequestTimeout(10*time.Microsecond))
- server, serverSocketPath := runSSHServer(t)
defer server.Stop()
client, conn := newSSHClient(t, serverSocketPath)