Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJacob Vosmaer (GitLab) <jacob@gitlab.com>2018-01-12 01:05:17 +0300
committerAlejandro Rodríguez <alejorro70@gmail.com>2018-01-12 01:05:17 +0300
commita761485aebeb2f4ebd189950595a693e86619e5a (patch)
treed83032d5de4df885f6c52af0b5bd35174ed6c9dd /internal/rubyserver
parent66e77b3d89e4b07ed257e6305518e5f39ae68163 (diff)
Check repo existence before passing to gitaly-ruby
Diffstat (limited to 'internal/rubyserver')
-rw-r--r--internal/rubyserver/proxy.go19
-rw-r--r--internal/rubyserver/rubyserver_test.go57
2 files changed, 63 insertions, 13 deletions
diff --git a/internal/rubyserver/proxy.go b/internal/rubyserver/proxy.go
index 7364e9e5f..29aa97814 100644
--- a/internal/rubyserver/proxy.go
+++ b/internal/rubyserver/proxy.go
@@ -23,14 +23,31 @@ const (
repoAltDirsHeader = "gitaly-repo-alt-dirs"
)
+// SetHeadersWithoutRepoCheck adds headers that tell gitaly-ruby the full
+// path to the repository. It is not an error if the repository does not
+// yet exist. This can be used on RPC calls that will create a
+// repository.
+func SetHeadersWithoutRepoCheck(ctx context.Context, repo *pb.Repository) (context.Context, error) {
+ return setHeaders(ctx, repo, false)
+}
+
// SetHeaders adds headers that tell gitaly-ruby the full path to the repository.
func SetHeaders(ctx context.Context, repo *pb.Repository) (context.Context, error) {
+ return setHeaders(ctx, repo, true)
+}
+
+func setHeaders(ctx context.Context, repo *pb.Repository, mustExist bool) (context.Context, error) {
storagePath, err := helper.GetStorageByName(repo.GetStorageName())
if err != nil {
return nil, err
}
- repoPath, err := helper.GetPath(repo)
+ var repoPath string
+ if mustExist {
+ repoPath, err = helper.GetRepoPath(repo)
+ } else {
+ repoPath, err = helper.GetPath(repo)
+ }
if err != nil {
return nil, err
}
diff --git a/internal/rubyserver/rubyserver_test.go b/internal/rubyserver/rubyserver_test.go
index 77943731e..5233d720c 100644
--- a/internal/rubyserver/rubyserver_test.go
+++ b/internal/rubyserver/rubyserver_test.go
@@ -1,6 +1,7 @@
package rubyserver
import (
+ "context"
"testing"
"github.com/stretchr/testify/assert"
@@ -23,32 +24,64 @@ func TestStopSafe(t *testing.T) {
}
func TestSetHeaders(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
testCases := []struct {
+ desc string
repo *pb.Repository
errType codes.Code
+ setter func(context.Context, *pb.Repository) (context.Context, error)
}{
{
+ desc: "SetHeaders invalid storage",
repo: &pb.Repository{StorageName: "foo", RelativePath: "bar.git"},
errType: codes.InvalidArgument,
+ setter: SetHeaders,
+ },
+ {
+ desc: "SetHeaders invalid rel path",
+ repo: &pb.Repository{StorageName: testRepo.StorageName, RelativePath: "bar.git"},
+ errType: codes.NotFound,
+ setter: SetHeaders,
},
{
+ desc: "SetHeaders OK",
repo: testRepo,
errType: codes.OK,
+ setter: SetHeaders,
+ },
+ {
+ desc: "SetHeadersWithoutRepoCheck invalid storage",
+ repo: &pb.Repository{StorageName: "foo", RelativePath: "bar.git"},
+ errType: codes.InvalidArgument,
+ setter: SetHeadersWithoutRepoCheck,
+ },
+ {
+ desc: "SetHeadersWithoutRepoCheck invalid relative path",
+ repo: &pb.Repository{StorageName: testRepo.StorageName, RelativePath: "bar.git"},
+ errType: codes.OK,
+ setter: SetHeadersWithoutRepoCheck,
+ },
+ {
+ desc: "SetHeadersWithoutRepoCheck OK",
+ repo: testRepo,
+ errType: codes.OK,
+ setter: SetHeadersWithoutRepoCheck,
},
}
for _, tc := range testCases {
- ctx, cancel := testhelper.Context()
- defer cancel()
-
- clientCtx, err := SetHeaders(ctx, tc.repo)
-
- if tc.errType != codes.OK {
- testhelper.AssertGrpcError(t, err, tc.errType, "")
- assert.Nil(t, clientCtx)
- } else {
- assert.NoError(t, err)
- assert.NotNil(t, clientCtx)
- }
+ t.Run(tc.desc, func(t *testing.T) {
+ clientCtx, err := tc.setter(ctx, tc.repo)
+
+ if tc.errType != codes.OK {
+ testhelper.AssertGrpcError(t, err, tc.errType, "")
+ assert.Nil(t, clientCtx)
+ } else {
+ assert.NoError(t, err)
+ assert.NotNil(t, clientCtx)
+ }
+ })
}
}