Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJames Fargher <proglottis@gmail.com>2023-05-01 00:40:48 +0300
committerJames Fargher <proglottis@gmail.com>2023-05-01 00:40:48 +0300
commit62e36a86900bdce323716c34dd877ad424309bc8 (patch)
tree618ee4a71661b2c09ea4ff72a44810ecafe09ef2
parent512b7d09c44261f4018bd0cb89470315207f474c (diff)
parente12acb508bdfb85f8a15c3e9c714848d1d214993 (diff)
Merge branch 'pks-protoregistry-method-info-simplifications' into 'master'
protoregistry: Refactor code to retrieve Protobuf fields tagged with extensions See merge request https://gitlab.com/gitlab-org/gitaly/-/merge_requests/5683 Merged-by: James Fargher <proglottis@gmail.com> Approved-by: James Fargher <proglottis@gmail.com> Approved-by: Will Chandler <wchandler@gitlab.com> Reviewed-by: James Fargher <proglottis@gmail.com> Reviewed-by: Patrick Steinhardt <psteinhardt@gitlab.com> Reviewed-by: Will Chandler <wchandler@gitlab.com> Co-authored-by: Patrick Steinhardt <psteinhardt@gitlab.com>
-rw-r--r--internal/praefect/coordinator.go22
-rw-r--r--internal/praefect/coordinator_test.go2
-rw-r--r--internal/praefect/protoregistry/find_oid.go162
-rw-r--r--internal/praefect/protoregistry/find_oid_test.go198
-rw-r--r--internal/praefect/protoregistry/method_info.go309
-rw-r--r--internal/praefect/protoregistry/method_info_test.go434
-rw-r--r--internal/praefect/protoregistry/protoregistry.go368
-rw-r--r--internal/praefect/protoregistry/protoregistry_test.go250
8 files changed, 869 insertions, 876 deletions
diff --git a/internal/praefect/coordinator.go b/internal/praefect/coordinator.go
index 5b88eaf87..10773f5e0 100644
--- a/internal/praefect/coordinator.go
+++ b/internal/praefect/coordinator.go
@@ -373,9 +373,15 @@ func (c *Coordinator) mutatorStreamParameters(ctx context.Context, call grpcCall
}
var additionalRepoRelativePath string
- if additionalRepo, ok, err := call.methodInfo.AdditionalRepo(call.msg); err != nil {
+ if additionalRepo, err := call.methodInfo.AdditionalRepo(call.msg); errors.Is(err, protoregistry.ErrTargetRepoMissing) {
+ // We can land here in two cases: either the message doesn't have an additional
+ // repository, or the repository wasn't set. The former case is obviously fine, but
+ // the latter case is fine, too, given that the additional repository may be an
+ // optional field for some RPC calls. The Gitaly-side RPC handlers should know to
+ // handle this case anyway, so we just leave the field unset in that case.
+ } else if err != nil {
return nil, structerr.NewInvalidArgument("%w", err)
- } else if ok {
+ } else {
additionalRepoRelativePath = additionalRepo.GetRelativePath()
}
@@ -700,7 +706,7 @@ func (c *Coordinator) StreamDirector(ctx context.Context, fullMethodName string,
func (c *Coordinator) directStorageScopedMessage(ctx context.Context, mi protoregistry.MethodInfo, msg proto.Message) (*proxy.StreamParameters, error) {
virtualStorage, err := mi.Storage(msg)
if err != nil {
- return nil, structerr.NewInvalidArgument("%w", err)
+ return nil, structerr.NewInvalidArgument("storage scoped: %w", err)
}
if virtualStorage == "" {
@@ -794,12 +800,12 @@ func rewrittenRepositoryMessage(mi protoregistry.MethodInfo, m proto.Message, st
targetRepo.StorageName = storage
targetRepo.RelativePath = relativePath
- additionalRepo, ok, err := mi.AdditionalRepo(m)
- if err != nil {
+ if additionalRepo, err := mi.AdditionalRepo(m); errors.Is(err, protoregistry.ErrTargetRepoMissing) {
+ // Nothing to rewrite in case the additional repository either doesn't exist in the
+ // message or wasn't set by the caller.
+ } else if err != nil {
return nil, structerr.NewInvalidArgument("%w", err)
- }
-
- if ok {
+ } else {
additionalRepo.StorageName = storage
additionalRepo.RelativePath = additionalRelativePath
}
diff --git a/internal/praefect/coordinator_test.go b/internal/praefect/coordinator_test.go
index c3e2328f6..82f698f51 100644
--- a/internal/praefect/coordinator_test.go
+++ b/internal/praefect/coordinator_test.go
@@ -1726,7 +1726,7 @@ func TestStreamDirectorStorageScopeError(t *testing.T) {
result, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, codes.InvalidArgument, result.Code())
- require.Equal(t, "storage scoped: target storage is invalid", result.Message())
+ require.Equal(t, "storage scoped: target storage field not found", result.Message())
})
t.Run("unknown storage provided", func(t *testing.T) {
diff --git a/internal/praefect/protoregistry/find_oid.go b/internal/praefect/protoregistry/find_oid.go
deleted file mode 100644
index 8ce91c0dd..000000000
--- a/internal/praefect/protoregistry/find_oid.go
+++ /dev/null
@@ -1,162 +0,0 @@
-package protoregistry
-
-import (
- "errors"
- "fmt"
- "reflect"
- "regexp"
- "strconv"
-
- "gitlab.com/gitlab-org/gitaly/v15/proto/go/gitalypb"
- "google.golang.org/protobuf/proto"
-)
-
-const (
- protobufTag = "protobuf"
- protobufOneOfTag = "protobuf_oneof"
-)
-
-// ErrTargetRepoMissing indicates that the target repo is missing or not set
-var ErrTargetRepoMissing = errors.New("empty Repository")
-
-func reflectFindRepoTarget(pbMsg proto.Message, targetOID []int) (*gitalypb.Repository, error) {
- msgV, e := reflectFindOID(pbMsg, targetOID)
- if e != nil {
- if e == ErrProtoFieldEmpty {
- return nil, ErrTargetRepoMissing
- }
- return nil, e
- }
-
- targetRepo, ok := msgV.Interface().(*gitalypb.Repository)
- if !ok {
- return nil, fmt.Errorf("repo target OID %v points to non-Repo type %+v", targetOID, msgV.Interface())
- }
-
- return targetRepo, nil
-}
-
-func reflectFindStorage(pbMsg proto.Message, targetOID []int) (string, error) {
- msgV, e := reflectFindOID(pbMsg, targetOID)
- if e != nil {
- return "", e
- }
-
- targetRepo, ok := msgV.Interface().(string)
- if !ok {
- return "", fmt.Errorf("repo target OID %v points to non-string type %+v", targetOID, msgV.Interface())
- }
-
- return targetRepo, nil
-}
-
-func reflectSetStorage(pbMsg proto.Message, targetOID []int, storage string) error {
- msgV, err := reflectFindOID(pbMsg, targetOID)
- if err != nil {
- return err
- }
-
- msgV.Set(reflect.ValueOf(storage))
- return nil
-}
-
-// ErrProtoFieldEmpty indicates the protobuf field is empty
-var ErrProtoFieldEmpty = errors.New("proto field is empty")
-
-// reflectFindOID finds the target repository by using the OID to
-// navigate the struct tags
-// Warning: this reflection filled function is full of forbidden dark elf magic
-func reflectFindOID(pbMsg proto.Message, targetOID []int) (reflect.Value, error) {
- msgV := reflect.ValueOf(pbMsg)
- for _, fieldNo := range targetOID {
- var err error
-
- msgV, err = findProtoField(msgV, fieldNo)
- if err != nil {
- return reflect.Value{}, fmt.Errorf(
- "unable to descend OID %+v into message %s: %w",
- targetOID, proto.MessageName(pbMsg), err,
- )
- }
- }
- return msgV, nil
-}
-
-// matches a tag string like "bytes,1,opt,name=repository,proto3"
-var protobufTagRegex = regexp.MustCompile(`^(.*?),(\d+),(.*?),name=(.*?),proto3(\,oneof)?$`)
-
-const (
- protobufTagRegexGroups = 6
- protobufTagRegexFieldGroup = 2
-)
-
-func findProtoField(msgV reflect.Value, protoField int) (reflect.Value, error) {
- if msgV.IsZero() {
- return reflect.Value{}, ErrProtoFieldEmpty
- }
-
- msgV = reflect.Indirect(msgV)
- for i := 0; i < msgV.NumField(); i++ {
- field := msgV.Type().Field(i)
-
- ok, err := tryNumberedField(field, protoField)
- if err != nil {
- return reflect.Value{}, err
- }
- if ok {
- return msgV.FieldByName(field.Name), nil
- }
-
- oneofField, ok := tryOneOfField(msgV, field, protoField)
- if !ok {
- continue
- }
- return oneofField, nil
- }
-
- err := fmt.Errorf(
- "unable to find protobuf field %d in message %s",
- protoField, msgV.Type().Name(),
- )
- return reflect.Value{}, err
-}
-
-func tryNumberedField(field reflect.StructField, protoField int) (bool, error) {
- tag := field.Tag.Get(protobufTag)
- matches := protobufTagRegex.FindStringSubmatch(tag)
- if len(matches) == protobufTagRegexGroups {
- fieldStr := matches[protobufTagRegexFieldGroup]
- if fieldStr == strconv.Itoa(protoField) {
- return true, nil
- }
- }
-
- return false, nil
-}
-
-func tryOneOfField(msgV reflect.Value, field reflect.StructField, protoField int) (reflect.Value, bool) {
- if msgV.IsZero() {
- return reflect.Value{}, false
- }
-
- oneOfTag := field.Tag.Get(protobufOneOfTag)
- if oneOfTag == "" {
- return reflect.Value{}, false // empty tag means this is not a oneOf field
- }
-
- // try all of the oneOf fields until a match is found
- msgV = msgV.FieldByName(field.Name).Elem().Elem()
- for i := 0; i < msgV.NumField(); i++ {
- field = msgV.Type().Field(i)
-
- ok, err := tryNumberedField(field, protoField)
- if err != nil {
- return reflect.Value{}, false
- }
- if ok {
- return msgV.FieldByName(field.Name), true
- }
- }
-
- return reflect.Value{}, false
-}
diff --git a/internal/praefect/protoregistry/find_oid_test.go b/internal/praefect/protoregistry/find_oid_test.go
deleted file mode 100644
index 1aaccd5e9..000000000
--- a/internal/praefect/protoregistry/find_oid_test.go
+++ /dev/null
@@ -1,198 +0,0 @@
-package protoregistry_test
-
-import (
- "errors"
- "fmt"
- "testing"
-
- "github.com/stretchr/testify/require"
- "gitlab.com/gitlab-org/gitaly/v15/internal/praefect/protoregistry"
- "gitlab.com/gitlab-org/gitaly/v15/proto/go/gitalypb"
- "google.golang.org/protobuf/proto"
-)
-
-func TestProtoRegistryTargetRepo(t *testing.T) {
- testRepos := []*gitalypb.Repository{
- {
- GitAlternateObjectDirectories: []string{"a", "b", "c"},
- GitObjectDirectory: "d",
- GlProjectPath: "e",
- GlRepository: "f",
- RelativePath: "g",
- StorageName: "h",
- },
- {
- GitAlternateObjectDirectories: []string{"1", "2", "3"},
- GitObjectDirectory: "4",
- GlProjectPath: "5",
- GlRepository: "6",
- RelativePath: "7",
- StorageName: "8",
- },
- }
-
- testcases := []struct {
- desc string
- svc string
- method string
- pbMsg proto.Message
- expectRepo *gitalypb.Repository
- expectAdditionalRepo *gitalypb.Repository
- expectErr error
- }{
- {
- desc: "valid request type single depth",
- svc: "RepositoryService",
- method: "OptimizeRepository",
- pbMsg: &gitalypb.OptimizeRepositoryRequest{
- Repository: testRepos[0],
- },
- expectRepo: testRepos[0],
- },
- {
- desc: "target nested in oneOf",
- svc: "OperationService",
- method: "UserCommitFiles",
- pbMsg: &gitalypb.UserCommitFilesRequest{
- UserCommitFilesRequestPayload: &gitalypb.UserCommitFilesRequest_Header{
- Header: &gitalypb.UserCommitFilesRequestHeader{
- Repository: testRepos[1],
- },
- },
- },
- expectRepo: testRepos[1],
- },
- {
- desc: "target nested, includes additional repository",
- svc: "ObjectPoolService",
- method: "FetchIntoObjectPool",
- pbMsg: &gitalypb.FetchIntoObjectPoolRequest{
- Origin: testRepos[0],
- ObjectPool: &gitalypb.ObjectPool{Repository: testRepos[1]},
- },
- expectRepo: testRepos[1],
- expectAdditionalRepo: testRepos[0],
- },
- {
- desc: "target repo is nil",
- svc: "RepositoryService",
- method: "OptimizeRepository",
- pbMsg: &gitalypb.OptimizeRepositoryRequest{Repository: nil},
- expectErr: protoregistry.ErrTargetRepoMissing,
- },
- }
-
- for _, tc := range testcases {
- desc := fmt.Sprintf("%s:%s %s", tc.svc, tc.method, tc.desc)
- t.Run(desc, func(t *testing.T) {
- info, err := protoregistry.GitalyProtoPreregistered.LookupMethod(fmt.Sprintf("/gitaly.%s/%s", tc.svc, tc.method))
- require.NoError(t, err)
-
- actualTarget, actualErr := info.TargetRepo(tc.pbMsg)
- require.Equal(t, tc.expectErr, actualErr)
-
- // not only do we want the value to be the same, but we actually want the
- // exact same instance to be returned
- if tc.expectRepo != actualTarget {
- t.Fatal("pointers do not match")
- }
-
- if tc.expectAdditionalRepo != nil {
- additionalRepo, ok, err := info.AdditionalRepo(tc.pbMsg)
- require.True(t, ok)
- require.NoError(t, err)
- require.Equal(t, tc.expectAdditionalRepo, additionalRepo)
- }
- })
- }
-}
-
-func TestProtoRegistryStorage(t *testing.T) {
- testcases := []struct {
- desc string
- svc string
- method string
- pbMsg proto.Message
- expectStorage string
- expectErr error
- }{
- {
- desc: "valid request type single depth",
- svc: "NamespaceService",
- method: "AddNamespace",
- pbMsg: &gitalypb.AddNamespaceRequest{
- StorageName: "some_storage",
- },
- expectStorage: "some_storage",
- },
- {
- desc: "incorrect request type",
- svc: "RepositoryService",
- method: "OptimizeRepository",
- pbMsg: &gitalypb.OptimizeRepositoryResponse{},
- expectErr: errors.New("proto message gitaly.OptimizeRepositoryResponse does not match expected RPC request message gitaly.OptimizeRepositoryRequest"),
- },
- }
-
- for _, tc := range testcases {
- desc := fmt.Sprintf("%s:%s %s", tc.svc, tc.method, tc.desc)
- t.Run(desc, func(t *testing.T) {
- info, err := protoregistry.GitalyProtoPreregistered.LookupMethod(fmt.Sprintf("/gitaly.%s/%s", tc.svc, tc.method))
- require.NoError(t, err)
-
- actualStorage, actualErr := info.Storage(tc.pbMsg)
- require.Equal(t, tc.expectErr, actualErr)
-
- // not only do we want the value to be the same, but we actually want the
- // exact same instance to be returned
- if tc.expectStorage != actualStorage {
- t.Fatal("pointers do not match")
- }
- })
- }
-}
-
-func TestMethodInfo_SetStorage(t *testing.T) {
- testCases := []struct {
- desc string
- service string
- method string
- pbMsg proto.Message
- storage string
- expectErr error
- }{
- {
- desc: "valid request type",
- service: "NamespaceService",
- method: "AddNamespace",
- pbMsg: &gitalypb.AddNamespaceRequest{
- StorageName: "old_storage",
- },
- storage: "new_storage",
- },
- {
- desc: "incorrect request type",
- service: "RepositoryService",
- method: "OptimizeRepository",
- pbMsg: &gitalypb.OptimizeRepositoryResponse{},
- expectErr: errors.New("proto message gitaly.OptimizeRepositoryResponse does not match expected RPC request message gitaly.OptimizeRepositoryRequest"),
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.desc, func(t *testing.T) {
- info, err := protoregistry.GitalyProtoPreregistered.LookupMethod("/gitaly." + tc.service + "/" + tc.method)
- require.NoError(t, err)
-
- err = info.SetStorage(tc.pbMsg, tc.storage)
- if tc.expectErr == nil {
- require.NoError(t, err)
- changed, err := info.Storage(tc.pbMsg)
- require.NoError(t, err)
- require.Equal(t, tc.storage, changed)
- } else {
- require.Equal(t, tc.expectErr, err)
- }
- })
- }
-}
diff --git a/internal/praefect/protoregistry/method_info.go b/internal/praefect/protoregistry/method_info.go
new file mode 100644
index 000000000..91dab9261
--- /dev/null
+++ b/internal/praefect/protoregistry/method_info.go
@@ -0,0 +1,309 @@
+package protoregistry
+
+import (
+ "errors"
+ "fmt"
+ "strings"
+
+ "gitlab.com/gitlab-org/gitaly/v15/internal/protoutil"
+ "gitlab.com/gitlab-org/gitaly/v15/proto/go/gitalypb"
+ "google.golang.org/protobuf/proto"
+ "google.golang.org/protobuf/reflect/protopath"
+ "google.golang.org/protobuf/reflect/protorange"
+ "google.golang.org/protobuf/reflect/protoreflect"
+ protoreg "google.golang.org/protobuf/reflect/protoregistry"
+ "google.golang.org/protobuf/types/descriptorpb"
+)
+
+// OpType represents the operation type for a RPC method
+type OpType int
+
+const (
+ // OpUnknown = unknown operation type
+ OpUnknown OpType = iota
+ // OpAccessor = accessor operation type (ready only)
+ OpAccessor
+ // OpMutator = mutator operation type (modifies a repository)
+ OpMutator
+ // OpMaintenance is an operation which performs maintenance-tasks on the repository. It
+ // shouldn't ever result in a user-visible change in behaviour, except that it may repair
+ // corrupt data.
+ OpMaintenance
+)
+
+// Scope represents the intended scope of an RPC method
+type Scope int
+
+const (
+ // ScopeUnknown is the default scope until determined otherwise
+ ScopeUnknown Scope = iota
+ // ScopeRepository indicates an RPC's scope is limited to a repository
+ ScopeRepository
+ // ScopeStorage indicates an RPC is scoped to an entire storage location
+ ScopeStorage
+)
+
+func (s Scope) String() string {
+ switch s {
+ case ScopeStorage:
+ return "storage"
+ case ScopeRepository:
+ return "repository"
+ default:
+ return fmt.Sprintf("N/A: %d", s)
+ }
+}
+
+var protoScope = map[gitalypb.OperationMsg_Scope]Scope{
+ gitalypb.OperationMsg_REPOSITORY: ScopeRepository,
+ gitalypb.OperationMsg_STORAGE: ScopeStorage,
+}
+
+// MethodInfo contains metadata about the RPC method. Refer to documentation
+// for message type "OperationMsg" shared.proto in ./proto for
+// more documentation.
+type MethodInfo struct {
+ Operation OpType
+ Scope Scope
+ requestName string // protobuf message name for input type
+ requestFactory protoFactory
+ fullMethodName string
+}
+
+// TargetRepo returns the target repository for a protobuf message if it exists
+func (mi MethodInfo) TargetRepo(msg proto.Message) (*gitalypb.Repository, error) {
+ return mi.getRepo(msg, gitalypb.E_TargetRepository)
+}
+
+// AdditionalRepo returns the additional repository for a Protobuf message that needs a storage
+// rewritten if it exists.
+func (mi MethodInfo) AdditionalRepo(msg proto.Message) (*gitalypb.Repository, error) {
+ return mi.getRepo(msg, gitalypb.E_AdditionalRepository)
+}
+
+//nolint:revive // This is unintentionally missing documentation.
+func (mi MethodInfo) FullMethodName() string {
+ return mi.fullMethodName
+}
+
+// ErrTargetRepoMissing indicates that the target repo is missing or not set
+var ErrTargetRepoMissing = errors.New("empty Repository")
+
+func (mi MethodInfo) getRepo(msg proto.Message, extensionType protoreflect.ExtensionType) (*gitalypb.Repository, error) {
+ if mi.requestName != string(proto.MessageName(msg)) {
+ return nil, fmt.Errorf(
+ "proto message %s does not match expected RPC request message %s",
+ proto.MessageName(msg), mi.requestName,
+ )
+ }
+
+ field, err := findFieldByExtension(msg, extensionType)
+ if err != nil {
+ if errors.Is(err, errFieldNotFound) {
+ return nil, ErrTargetRepoMissing
+ }
+
+ return nil, err
+ }
+
+ if field.desc.Kind() != protoreflect.MessageKind {
+ return nil, fmt.Errorf("expected repository message, got %s", field.desc.Kind().String())
+ }
+
+ switch fieldMsg := field.value.Message().Interface().(type) {
+ case *gitalypb.Repository:
+ return fieldMsg, nil
+ case *gitalypb.ObjectPool:
+ repo := fieldMsg.GetRepository()
+ if repo == nil {
+ return nil, ErrTargetRepoMissing
+ }
+
+ return repo, nil
+ default:
+ return nil, fmt.Errorf("repository message has unexpected type %T", fieldMsg)
+ }
+}
+
+// Storage returns the storage name for a protobuf message if it exists
+func (mi MethodInfo) Storage(msg proto.Message) (string, error) {
+ field, err := mi.getStorageField(msg)
+ if err != nil {
+ return "", err
+ }
+
+ return field.value.String(), nil
+}
+
+// SetStorage sets the storage name for a protobuf message
+func (mi MethodInfo) SetStorage(msg proto.Message, storage string) error {
+ field, err := mi.getStorageField(msg)
+ if err != nil {
+ return err
+ }
+
+ msg.ProtoReflect().Set(field.desc, protoreflect.ValueOfString(storage))
+
+ return nil
+}
+
+func (mi MethodInfo) getStorageField(msg proto.Message) (valueField, error) {
+ if mi.requestName != string(proto.MessageName(msg)) {
+ return valueField{}, fmt.Errorf(
+ "proto message %s does not match expected RPC request message %s",
+ proto.MessageName(msg), mi.requestName,
+ )
+ }
+
+ field, err := findFieldByExtension(msg, gitalypb.E_Storage)
+ if err != nil {
+ if errors.Is(err, errFieldNotFound) {
+ return valueField{}, fmt.Errorf("target storage field not found")
+ }
+ return valueField{}, err
+ }
+
+ if field.desc.Kind() != protoreflect.StringKind {
+ return valueField{}, fmt.Errorf("expected string, got %s", field.desc.Kind().String())
+ }
+
+ return field, nil
+}
+
+// UnmarshalRequestProto will unmarshal the bytes into the method's request
+// message type
+func (mi MethodInfo) UnmarshalRequestProto(b []byte) (proto.Message, error) {
+ return mi.requestFactory(b)
+}
+
+type protoFactory func([]byte) (proto.Message, error)
+
+func methodReqFactory(method *descriptorpb.MethodDescriptorProto) (protoFactory, error) {
+ // for some reason, the descriptor prepends a dot not expected in Go
+ inputTypeName := strings.TrimPrefix(method.GetInputType(), ".")
+
+ inputType, err := protoreg.GlobalTypes.FindMessageByName(protoreflect.FullName(inputTypeName))
+ if err != nil {
+ return nil, fmt.Errorf("no message type found for %w", err)
+ }
+
+ f := func(buf []byte) (proto.Message, error) {
+ pb := inputType.New().Interface()
+ if err := proto.Unmarshal(buf, pb); err != nil {
+ return nil, err
+ }
+
+ return pb, nil
+ }
+
+ return f, nil
+}
+
+func parseMethodInfo(
+ p *descriptorpb.FileDescriptorProto,
+ methodDesc *descriptorpb.MethodDescriptorProto,
+ fullMethodName string,
+) (MethodInfo, error) {
+ opMsg, err := protoutil.GetOpExtension(methodDesc)
+ if err != nil {
+ return MethodInfo{}, err
+ }
+
+ var opCode OpType
+
+ switch opMsg.GetOp() {
+ case gitalypb.OperationMsg_ACCESSOR:
+ opCode = OpAccessor
+ case gitalypb.OperationMsg_MUTATOR:
+ opCode = OpMutator
+ case gitalypb.OperationMsg_MAINTENANCE:
+ opCode = OpMaintenance
+ default:
+ opCode = OpUnknown
+ }
+
+ // for some reason, the protobuf descriptor contains an extra dot in front
+ // of the request name that the generated code does not. This trimming keeps
+ // the two copies consistent for comparisons.
+ requestName := strings.TrimLeft(methodDesc.GetInputType(), ".")
+
+ reqFactory, err := methodReqFactory(methodDesc)
+ if err != nil {
+ return MethodInfo{}, err
+ }
+
+ scope, ok := protoScope[opMsg.GetScopeLevel()]
+ if !ok {
+ return MethodInfo{}, fmt.Errorf("encountered unknown method scope %d", opMsg.GetScopeLevel())
+ }
+
+ mi := MethodInfo{
+ Operation: opCode,
+ Scope: scope,
+ requestName: requestName,
+ requestFactory: reqFactory,
+ fullMethodName: fullMethodName,
+ }
+
+ return mi, nil
+}
+
+type valueField struct {
+ desc protoreflect.FieldDescriptor
+ value protoreflect.Value
+}
+
+// findFieldsByExtension will search through all populated fields and returns all of those which
+// have the given extension type set.
+func findFieldsByExtension(msg proto.Message, extensionType protoreflect.ExtensionType) ([]valueField, error) {
+ var valueFields []valueField
+
+ if err := (protorange.Options{Stable: true}).Range(msg.ProtoReflect(), func(values protopath.Values) error {
+ value := values.Index(-1)
+
+ fieldDescriptor := value.Step.FieldDescriptor()
+ if fieldDescriptor == nil {
+ return nil
+ }
+
+ opts := fieldDescriptor.Options().(*descriptorpb.FieldOptions)
+ if !proto.HasExtension(opts, extensionType) {
+ return nil
+ }
+
+ valueFields = append(valueFields, valueField{
+ desc: fieldDescriptor,
+ value: value.Value,
+ })
+
+ return nil
+ }, nil); err != nil {
+ return nil, fmt.Errorf("ranging over message: %w", err)
+ }
+
+ return valueFields, nil
+}
+
+var (
+ errFieldNotFound = errors.New("field not found")
+ errFieldAmbiguous = errors.New("field is ambiguous")
+)
+
+// findFieldByExtension is a wrapper around findFieldsByExtension that returns a single field
+// descriptor, only. Returns a errFieldNotFound error in case the field wasn't found, and a
+// errFieldAmbiguous error in case there are multiple fields with the same extension.
+func findFieldByExtension(msg proto.Message, extensionType protoreflect.ExtensionType) (valueField, error) {
+ fields, err := findFieldsByExtension(msg, extensionType)
+ if err != nil {
+ return valueField{}, err
+ }
+
+ switch len(fields) {
+ case 1:
+ return fields[0], nil
+ case 0:
+ return valueField{}, errFieldNotFound
+ default:
+ return valueField{}, errFieldAmbiguous
+ }
+}
diff --git a/internal/praefect/protoregistry/method_info_test.go b/internal/praefect/protoregistry/method_info_test.go
new file mode 100644
index 000000000..1532a4d99
--- /dev/null
+++ b/internal/praefect/protoregistry/method_info_test.go
@@ -0,0 +1,434 @@
+package protoregistry
+
+import (
+ "errors"
+ "fmt"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitaly/v15/internal/testhelper"
+ "gitlab.com/gitlab-org/gitaly/v15/proto/go/gitalypb"
+ "google.golang.org/protobuf/proto"
+ "google.golang.org/protobuf/reflect/protoreflect"
+)
+
+func TestMethodInfo_getRepo(t *testing.T) {
+ t.Parallel()
+
+ testRepos := []*gitalypb.Repository{
+ {
+ GitAlternateObjectDirectories: []string{"a", "b", "c"},
+ GitObjectDirectory: "d",
+ GlProjectPath: "e",
+ GlRepository: "f",
+ RelativePath: "g",
+ StorageName: "h",
+ },
+ {
+ GitAlternateObjectDirectories: []string{"1", "2", "3"},
+ GitObjectDirectory: "4",
+ GlProjectPath: "5",
+ GlRepository: "6",
+ RelativePath: "7",
+ StorageName: "8",
+ },
+ }
+
+ testcases := []struct {
+ desc string
+ svc string
+ method string
+ pbMsg proto.Message
+ expectRepo *gitalypb.Repository
+ expectErr error
+ expectAdditionalRepo *gitalypb.Repository
+ expectAdditionalErr error
+ }{
+ {
+ desc: "valid request type single depth",
+ svc: "RepositoryService",
+ method: "OptimizeRepository",
+ pbMsg: &gitalypb.OptimizeRepositoryRequest{
+ Repository: testRepos[0],
+ },
+ expectRepo: testRepos[0],
+ expectAdditionalErr: ErrTargetRepoMissing,
+ },
+ {
+ desc: "unset oneof",
+ svc: "OperationService",
+ method: "UserCommitFiles",
+ pbMsg: &gitalypb.UserCommitFilesRequest{},
+ expectErr: ErrTargetRepoMissing,
+ expectAdditionalErr: ErrTargetRepoMissing,
+ },
+ {
+ desc: "unset value in oneof",
+ svc: "OperationService",
+ method: "UserCommitFiles",
+ pbMsg: &gitalypb.UserCommitFilesRequest{
+ UserCommitFilesRequestPayload: &gitalypb.UserCommitFilesRequest_Header{},
+ },
+ expectErr: ErrTargetRepoMissing,
+ expectAdditionalErr: ErrTargetRepoMissing,
+ },
+ {
+ desc: "unset repository in oneof",
+ svc: "OperationService",
+ method: "UserCommitFiles",
+ pbMsg: &gitalypb.UserCommitFilesRequest{
+ UserCommitFilesRequestPayload: &gitalypb.UserCommitFilesRequest_Header{
+ Header: &gitalypb.UserCommitFilesRequestHeader{},
+ },
+ },
+ expectErr: ErrTargetRepoMissing,
+ expectAdditionalErr: ErrTargetRepoMissing,
+ },
+ {
+ desc: "target nested in oneOf",
+ svc: "OperationService",
+ method: "UserCommitFiles",
+ pbMsg: &gitalypb.UserCommitFilesRequest{
+ UserCommitFilesRequestPayload: &gitalypb.UserCommitFilesRequest_Header{
+ Header: &gitalypb.UserCommitFilesRequestHeader{
+ Repository: testRepos[1],
+ },
+ },
+ },
+ expectRepo: testRepos[1],
+ expectAdditionalErr: ErrTargetRepoMissing,
+ },
+ {
+ desc: "target nested, includes additional repository",
+ svc: "ObjectPoolService",
+ method: "FetchIntoObjectPool",
+ pbMsg: &gitalypb.FetchIntoObjectPoolRequest{
+ Origin: testRepos[0],
+ ObjectPool: &gitalypb.ObjectPool{Repository: testRepos[1]},
+ },
+ expectRepo: testRepos[1],
+ expectAdditionalRepo: testRepos[0],
+ },
+ {
+ desc: "target repo is nil",
+ svc: "RepositoryService",
+ method: "OptimizeRepository",
+ pbMsg: &gitalypb.OptimizeRepositoryRequest{Repository: nil},
+ expectErr: ErrTargetRepoMissing,
+ expectAdditionalErr: ErrTargetRepoMissing,
+ },
+ }
+
+ for _, tc := range testcases {
+ t.Run(tc.desc, func(t *testing.T) {
+ info, err := GitalyProtoPreregistered.LookupMethod(fmt.Sprintf("/gitaly.%s/%s", tc.svc, tc.method))
+ require.NoError(t, err)
+
+ t.Run("TargetRepo", func(t *testing.T) {
+ repo, err := info.TargetRepo(tc.pbMsg)
+ require.Equal(t, tc.expectErr, err)
+ require.Same(t, tc.expectRepo, repo)
+ })
+
+ t.Run("AdditionalRepo", func(t *testing.T) {
+ additionalRepo, err := info.AdditionalRepo(tc.pbMsg)
+ require.Equal(t, tc.expectAdditionalErr, err)
+ require.Same(t, tc.expectAdditionalRepo, additionalRepo)
+ })
+ })
+ }
+}
+
+func TestMethodInfo_Storage(t *testing.T) {
+ t.Parallel()
+
+ testcases := []struct {
+ desc string
+ svc string
+ method string
+ pbMsg proto.Message
+ expectStorage string
+ expectErr error
+ }{
+ {
+ desc: "valid request type single depth",
+ svc: "NamespaceService",
+ method: "AddNamespace",
+ pbMsg: &gitalypb.AddNamespaceRequest{
+ StorageName: "some_storage",
+ },
+ expectStorage: "some_storage",
+ },
+ {
+ desc: "incorrect request type",
+ svc: "RepositoryService",
+ method: "OptimizeRepository",
+ pbMsg: &gitalypb.OptimizeRepositoryResponse{},
+ expectErr: errors.New("proto message gitaly.OptimizeRepositoryResponse does not match expected RPC request message gitaly.OptimizeRepositoryRequest"),
+ },
+ }
+
+ for _, tc := range testcases {
+ desc := fmt.Sprintf("%s:%s %s", tc.svc, tc.method, tc.desc)
+ t.Run(desc, func(t *testing.T) {
+ info, err := GitalyProtoPreregistered.LookupMethod(fmt.Sprintf("/gitaly.%s/%s", tc.svc, tc.method))
+ require.NoError(t, err)
+
+ actualStorage, actualErr := info.Storage(tc.pbMsg)
+ require.Equal(t, tc.expectErr, actualErr)
+
+ // not only do we want the value to be the same, but we actually want the
+ // exact same instance to be returned
+ if tc.expectStorage != actualStorage {
+ t.Fatal("pointers do not match")
+ }
+ })
+ }
+}
+
+func TestMethodInfo_SetStorage(t *testing.T) {
+ t.Parallel()
+
+ testCases := []struct {
+ desc string
+ service string
+ method string
+ pbMsg proto.Message
+ storage string
+ expectErr error
+ }{
+ {
+ desc: "valid request type",
+ service: "NamespaceService",
+ method: "AddNamespace",
+ pbMsg: &gitalypb.AddNamespaceRequest{
+ StorageName: "old_storage",
+ },
+ storage: "new_storage",
+ },
+ {
+ desc: "incorrect request type",
+ service: "RepositoryService",
+ method: "OptimizeRepository",
+ pbMsg: &gitalypb.OptimizeRepositoryResponse{},
+ expectErr: errors.New("proto message gitaly.OptimizeRepositoryResponse does not match expected RPC request message gitaly.OptimizeRepositoryRequest"),
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ info, err := GitalyProtoPreregistered.LookupMethod("/gitaly." + tc.service + "/" + tc.method)
+ require.NoError(t, err)
+
+ err = info.SetStorage(tc.pbMsg, tc.storage)
+ if tc.expectErr == nil {
+ require.NoError(t, err)
+ changed, err := info.Storage(tc.pbMsg)
+ require.NoError(t, err)
+ require.Equal(t, tc.storage, changed)
+ } else {
+ require.Equal(t, tc.expectErr, err)
+ }
+ })
+ }
+}
+
+func TestMethodInfo_RequestFactory(t *testing.T) {
+ t.Parallel()
+
+ mInfo, err := GitalyProtoPreregistered.LookupMethod("/gitaly.RepositoryService/RepositoryExists")
+ require.NoError(t, err)
+
+ pb, err := mInfo.UnmarshalRequestProto([]byte{})
+ require.NoError(t, err)
+
+ testhelper.ProtoEqual(t, &gitalypb.RepositoryExistsRequest{}, pb)
+}
+
+func TestMethodInfoScope(t *testing.T) {
+ for _, tt := range []struct {
+ method string
+ scope Scope
+ }{
+ {
+ method: "/gitaly.RepositoryService/RepositoryExists",
+ scope: ScopeRepository,
+ },
+ } {
+ t.Run(tt.method, func(t *testing.T) {
+ mInfo, err := GitalyProtoPreregistered.LookupMethod(tt.method)
+ require.NoError(t, err)
+
+ require.Exactly(t, tt.scope, mInfo.Scope)
+ })
+ }
+}
+
+func TestFindFieldsByExtension(t *testing.T) {
+ t.Parallel()
+
+ repo := &gitalypb.Repository{StorageName: "storage", RelativePath: "relative-path"}
+
+ for _, tc := range []struct {
+ desc string
+ msg proto.Message
+ extension protoreflect.ExtensionType
+ expectedFields []string
+ }{
+ {
+ desc: "unset field does not match",
+ msg: &gitalypb.OptimizeRepositoryRequest{
+ Repository: nil,
+ },
+ extension: gitalypb.E_TargetRepository,
+ },
+ {
+ desc: "matching field",
+ msg: &gitalypb.OptimizeRepositoryRequest{
+ Repository: repo,
+ },
+ extension: gitalypb.E_TargetRepository,
+ expectedFields: []string{
+ "gitaly.OptimizeRepositoryRequest.repository",
+ },
+ },
+ {
+ desc: "no matching field",
+ msg: &gitalypb.OptimizeRepositoryRequest{
+ Repository: repo,
+ },
+ extension: gitalypb.E_AdditionalRepository,
+ },
+ {
+ desc: "multiple fields with distinct extensions",
+ msg: &gitalypb.FetchIntoObjectPoolRequest{
+ Origin: repo,
+ ObjectPool: &gitalypb.ObjectPool{Repository: repo},
+ },
+ extension: gitalypb.E_AdditionalRepository,
+ expectedFields: []string{
+ "gitaly.FetchIntoObjectPoolRequest.origin",
+ },
+ },
+ {
+ desc: "matching field in unset oneOf",
+ msg: &gitalypb.UserCommitFilesRequest{},
+ extension: gitalypb.E_TargetRepository,
+ },
+ {
+ desc: "matching field in empty oneOf",
+ msg: &gitalypb.UserCommitFilesRequest{
+ UserCommitFilesRequestPayload: &gitalypb.UserCommitFilesRequest_Header{},
+ },
+ extension: gitalypb.E_TargetRepository,
+ },
+ {
+ desc: "matching field with unset oneOf repository",
+ msg: &gitalypb.UserCommitFilesRequest{
+ UserCommitFilesRequestPayload: &gitalypb.UserCommitFilesRequest_Header{
+ Header: &gitalypb.UserCommitFilesRequestHeader{},
+ },
+ },
+ extension: gitalypb.E_TargetRepository,
+ },
+ {
+ desc: "matching field in oneOf",
+ msg: &gitalypb.UserCommitFilesRequest{
+ UserCommitFilesRequestPayload: &gitalypb.UserCommitFilesRequest_Header{
+ Header: &gitalypb.UserCommitFilesRequestHeader{
+ Repository: repo,
+ },
+ },
+ },
+ extension: gitalypb.E_TargetRepository,
+ expectedFields: []string{
+ "gitaly.UserCommitFilesRequestHeader.repository",
+ },
+ },
+ } {
+ tc := tc
+
+ t.Run(tc.desc, func(t *testing.T) {
+ t.Parallel()
+
+ t.Run("findFieldsByExtension", func(t *testing.T) {
+ fields, err := findFieldsByExtension(tc.msg, tc.extension)
+ require.NoError(t, err)
+
+ var fieldNames []string
+ for _, field := range fields {
+ fieldNames = append(fieldNames, string(field.desc.FullName()))
+ }
+ require.Equal(t, tc.expectedFields, fieldNames)
+ })
+
+ t.Run("findFieldByExtension", func(t *testing.T) {
+ field, err := findFieldByExtension(tc.msg, tc.extension)
+
+ switch len(tc.expectedFields) {
+ case 0:
+ require.Equal(t, valueField{}, field)
+ require.Equal(t, errFieldNotFound, err)
+ case 1:
+ require.NoError(t, err)
+ require.Equal(t, tc.expectedFields[0], string(field.desc.FullName()))
+ default:
+ require.Equal(t, valueField{}, field)
+ require.Equal(t, errFieldAmbiguous, err)
+ }
+ })
+ })
+ }
+}
+
+func BenchmarkMethodInfo(b *testing.B) {
+ for _, bc := range []struct {
+ desc string
+ method string
+ request proto.Message
+ expectedErr error
+ }{
+ {
+ desc: "unset target repository",
+ method: "/gitaly.RepositoryService/OptimizeRepository",
+ request: &gitalypb.OptimizeRepositoryRequest{
+ Repository: nil,
+ },
+ expectedErr: ErrTargetRepoMissing,
+ },
+ {
+ desc: "target repository",
+ method: "/gitaly.RepositoryService/OptimizeRepository",
+ request: &gitalypb.OptimizeRepositoryRequest{
+ Repository: &gitalypb.Repository{
+ StorageName: "something",
+ RelativePath: "something",
+ },
+ },
+ },
+ {
+ desc: "target object pool",
+ method: "/gitaly.ObjectPoolService/FetchIntoObjectPool",
+ request: &gitalypb.FetchIntoObjectPoolRequest{
+ ObjectPool: &gitalypb.ObjectPool{
+ Repository: &gitalypb.Repository{
+ StorageName: "something",
+ RelativePath: "something",
+ },
+ },
+ },
+ },
+ } {
+ b.Run(bc.desc, func(b *testing.B) {
+ mi, err := GitalyProtoPreregistered.LookupMethod(bc.method)
+ require.NoError(b, err)
+
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ _, err := mi.TargetRepo(bc.request)
+ require.Equal(b, bc.expectedErr, err)
+ }
+ })
+ }
+}
diff --git a/internal/praefect/protoregistry/protoregistry.go b/internal/praefect/protoregistry/protoregistry.go
index 0399bf885..7fed1831f 100644
--- a/internal/praefect/protoregistry/protoregistry.go
+++ b/internal/praefect/protoregistry/protoregistry.go
@@ -2,13 +2,10 @@ package protoregistry
import (
"fmt"
- "strings"
"gitlab.com/gitlab-org/gitaly/v15/internal/protoutil"
"gitlab.com/gitlab-org/gitaly/v15/proto/go/gitalypb"
- "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
- "google.golang.org/protobuf/reflect/protoreflect"
protoreg "google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/descriptorpb"
)
@@ -25,138 +22,6 @@ func init() {
}
}
-// OpType represents the operation type for a RPC method
-type OpType int
-
-const (
- // OpUnknown = unknown operation type
- OpUnknown OpType = iota
- // OpAccessor = accessor operation type (ready only)
- OpAccessor
- // OpMutator = mutator operation type (modifies a repository)
- OpMutator
- // OpMaintenance is an operation which performs maintenance-tasks on the repository. It
- // shouldn't ever result in a user-visible change in behaviour, except that it may repair
- // corrupt data.
- OpMaintenance
-)
-
-// Scope represents the intended scope of an RPC method
-type Scope int
-
-const (
- // ScopeUnknown is the default scope until determined otherwise
- ScopeUnknown Scope = iota
- // ScopeRepository indicates an RPC's scope is limited to a repository
- ScopeRepository
- // ScopeStorage indicates an RPC is scoped to an entire storage location
- ScopeStorage
-)
-
-func (s Scope) String() string {
- switch s {
- case ScopeStorage:
- return "storage"
- case ScopeRepository:
- return "repository"
- default:
- return fmt.Sprintf("N/A: %d", s)
- }
-}
-
-var protoScope = map[gitalypb.OperationMsg_Scope]Scope{
- gitalypb.OperationMsg_REPOSITORY: ScopeRepository,
- gitalypb.OperationMsg_STORAGE: ScopeStorage,
-}
-
-// MethodInfo contains metadata about the RPC method. Refer to documentation
-// for message type "OperationMsg" shared.proto in ./proto for
-// more documentation.
-type MethodInfo struct {
- Operation OpType
- Scope Scope
- targetRepo []int
- additionalRepo []int
- requestName string // protobuf message name for input type
- requestFactory protoFactory
- storage []int
- fullMethodName string
-}
-
-// TargetRepo returns the target repository for a protobuf message if it exists
-func (mi MethodInfo) TargetRepo(msg proto.Message) (*gitalypb.Repository, error) {
- return mi.getRepo(msg, mi.targetRepo)
-}
-
-// AdditionalRepo returns the additional repository for a protobuf message that needs a storage rewritten
-// if it exists
-func (mi MethodInfo) AdditionalRepo(msg proto.Message) (*gitalypb.Repository, bool, error) {
- if mi.additionalRepo == nil {
- return nil, false, nil
- }
-
- repo, err := mi.getRepo(msg, mi.additionalRepo)
-
- return repo, true, err
-}
-
-//nolint:revive // This is unintentionally missing documentation.
-func (mi MethodInfo) FullMethodName() string {
- return mi.fullMethodName
-}
-
-func (mi MethodInfo) getRepo(msg proto.Message, targetOid []int) (*gitalypb.Repository, error) {
- if mi.requestName != string(proto.MessageName(msg)) {
- return nil, fmt.Errorf(
- "proto message %s does not match expected RPC request message %s",
- proto.MessageName(msg), mi.requestName,
- )
- }
-
- repo, err := reflectFindRepoTarget(msg, targetOid)
- switch {
- case err != nil:
- return nil, err
- case repo == nil:
- // it is possible for the target repo to not be set (especially in our unit
- // tests designed to fail and this should return an error to prevent nil
- // pointer dereferencing
- return nil, ErrTargetRepoMissing
- default:
- return repo, nil
- }
-}
-
-// Storage returns the storage name for a protobuf message if it exists
-func (mi MethodInfo) Storage(msg proto.Message) (string, error) {
- if mi.requestName != string(proto.MessageName(msg)) {
- return "", fmt.Errorf(
- "proto message %s does not match expected RPC request message %s",
- proto.MessageName(msg), mi.requestName,
- )
- }
-
- return reflectFindStorage(msg, mi.storage)
-}
-
-// SetStorage sets the storage name for a protobuf message
-func (mi MethodInfo) SetStorage(msg proto.Message, storage string) error {
- if mi.requestName != string(proto.MessageName(msg)) {
- return fmt.Errorf(
- "proto message %s does not match expected RPC request message %s",
- proto.MessageName(msg), mi.requestName,
- )
- }
-
- return reflectSetStorage(msg, mi.storage, storage)
-}
-
-// UnmarshalRequestProto will unmarshal the bytes into the method's request
-// message type
-func (mi MethodInfo) UnmarshalRequestProto(b []byte) (proto.Message, error) {
- return mi.requestFactory(b)
-}
-
// Registry contains info about RPC methods
type Registry struct {
protos map[string]MethodInfo
@@ -214,239 +79,6 @@ func NewFromPaths(paths ...string) (*Registry, error) {
return New(fds...)
}
-type protoFactory func([]byte) (proto.Message, error)
-
-func methodReqFactory(method *descriptorpb.MethodDescriptorProto) (protoFactory, error) {
- // for some reason, the descriptor prepends a dot not expected in Go
- inputTypeName := strings.TrimPrefix(method.GetInputType(), ".")
-
- inputType, err := protoreg.GlobalTypes.FindMessageByName(protoreflect.FullName(inputTypeName))
- if err != nil {
- return nil, fmt.Errorf("no message type found for %w", err)
- }
-
- f := func(buf []byte) (proto.Message, error) {
- pb := inputType.New().Interface()
- if err := proto.Unmarshal(buf, pb); err != nil {
- return nil, err
- }
-
- return pb, nil
- }
-
- return f, nil
-}
-
-func parseMethodInfo(
- p *descriptorpb.FileDescriptorProto,
- methodDesc *descriptorpb.MethodDescriptorProto,
- fullMethodName string,
-) (MethodInfo, error) {
- opMsg, err := protoutil.GetOpExtension(methodDesc)
- if err != nil {
- return MethodInfo{}, err
- }
-
- var opCode OpType
-
- switch opMsg.GetOp() {
- case gitalypb.OperationMsg_ACCESSOR:
- opCode = OpAccessor
- case gitalypb.OperationMsg_MUTATOR:
- opCode = OpMutator
- case gitalypb.OperationMsg_MAINTENANCE:
- opCode = OpMaintenance
- default:
- opCode = OpUnknown
- }
-
- // for some reason, the protobuf descriptor contains an extra dot in front
- // of the request name that the generated code does not. This trimming keeps
- // the two copies consistent for comparisons.
- requestName := strings.TrimLeft(methodDesc.GetInputType(), ".")
-
- reqFactory, err := methodReqFactory(methodDesc)
- if err != nil {
- return MethodInfo{}, err
- }
-
- scope, ok := protoScope[opMsg.GetScopeLevel()]
- if !ok {
- return MethodInfo{}, fmt.Errorf("encountered unknown method scope %d", opMsg.GetScopeLevel())
- }
-
- mi := MethodInfo{
- Operation: opCode,
- Scope: scope,
- requestName: requestName,
- requestFactory: reqFactory,
- fullMethodName: fullMethodName,
- }
-
- topLevelMsgs, err := getTopLevelMsgs(p)
- if err != nil {
- return MethodInfo{}, err
- }
-
- typeName, err := lastName(methodDesc.GetInputType())
- if err != nil {
- return MethodInfo{}, err
- }
-
- if scope == ScopeRepository {
- m := matcher{
- match: protoutil.GetTargetRepositoryExtension,
- subMatch: protoutil.GetRepositoryExtension,
- expectedType: ".gitaly.Repository",
- topLevelMsgs: topLevelMsgs,
- }
-
- targetRepo, err := m.findField(topLevelMsgs[typeName])
- if err != nil {
- return MethodInfo{}, err
- }
- if targetRepo == nil {
- return MethodInfo{}, fmt.Errorf("unable to find target repository for method: %s", requestName)
- }
- mi.targetRepo = targetRepo
-
- m.match = protoutil.GetAdditionalRepositoryExtension
- additionalRepo, err := m.findField(topLevelMsgs[typeName])
- if err != nil {
- return MethodInfo{}, err
- }
- mi.additionalRepo = additionalRepo
- } else if scope == ScopeStorage {
- m := matcher{
- match: protoutil.GetStorageExtension,
- topLevelMsgs: topLevelMsgs,
- }
- storage, err := m.findField(topLevelMsgs[typeName])
- if err != nil {
- return MethodInfo{}, err
- }
- if storage == nil {
- return MethodInfo{}, fmt.Errorf("unable to find storage for method: %s", requestName)
- }
- mi.storage = storage
- }
-
- return mi, nil
-}
-
-func getFileTypes(filename string) ([]*descriptorpb.DescriptorProto, error) {
- fd, err := protoreg.GlobalFiles.FindFileByPath(filename)
- if err != nil {
- return nil, err
- }
- sharedFD := protodesc.ToFileDescriptorProto(fd)
-
- types := sharedFD.GetMessageType()
-
- for _, dep := range sharedFD.Dependency {
- depTypes, err := getFileTypes(dep)
- if err != nil {
- return nil, err
- }
- types = append(types, depTypes...)
- }
-
- return types, nil
-}
-
-func getTopLevelMsgs(p *descriptorpb.FileDescriptorProto) (map[string]*descriptorpb.DescriptorProto, error) {
- topLevelMsgs := map[string]*descriptorpb.DescriptorProto{}
- types, err := getFileTypes(p.GetName())
- if err != nil {
- return nil, err
- }
- for _, msg := range types {
- topLevelMsgs[msg.GetName()] = msg
- }
- return topLevelMsgs, nil
-}
-
-// Matcher helps find field matching credentials. At first match method is used to check fields
-// recursively. Then if field matches but type don't match expectedType subMatch method is used
-// from this point. This matcher assumes that only one field in the message matches the credentials.
-type matcher struct {
- match func(*descriptorpb.FieldDescriptorProto) (bool, error)
- subMatch func(*descriptorpb.FieldDescriptorProto) (bool, error)
- expectedType string // fully qualified name of expected type e.g. ".gitaly.Repository"
- topLevelMsgs map[string]*descriptorpb.DescriptorProto // Map of all top level messages in given file and it dependencies. Result of getTopLevelMsgs should be used.
-}
-
-func (m matcher) findField(t *descriptorpb.DescriptorProto) ([]int, error) {
- for _, f := range t.GetField() {
- match, err := m.match(f)
- if err != nil {
- return nil, err
- }
- if match {
- if f.GetTypeName() == m.expectedType {
- return []int{int(f.GetNumber())}, nil
- } else if m.subMatch != nil {
- m.match = m.subMatch
- m.subMatch = nil
- } else {
- return nil, fmt.Errorf("found wrong type, expected: %s, got: %s", m.expectedType, f.GetTypeName())
- }
- }
-
- childMsg, err := findChildMsg(m.topLevelMsgs, t, f)
- if err != nil {
- return nil, err
- }
-
- if childMsg != nil {
- nestedField, err := m.findField(childMsg)
- if err != nil {
- return nil, err
- }
- if nestedField != nil {
- return append([]int{int(f.GetNumber())}, nestedField...), nil
- }
- }
- }
- return nil, nil
-}
-
-func findChildMsg(topLevelMsgs map[string]*descriptorpb.DescriptorProto, t *descriptorpb.DescriptorProto, f *descriptorpb.FieldDescriptorProto) (*descriptorpb.DescriptorProto, error) {
- var childType *descriptorpb.DescriptorProto
- const msgPrimitive = "TYPE_MESSAGE"
- if primitive := f.GetType().String(); primitive != msgPrimitive {
- return nil, nil
- }
-
- msgName, err := lastName(f.GetTypeName())
- if err != nil {
- return nil, err
- }
-
- for _, nestedType := range t.GetNestedType() {
- if msgName == nestedType.GetName() {
- return nestedType, nil
- }
- }
-
- if childType = topLevelMsgs[msgName]; childType != nil {
- return childType, nil
- }
-
- return nil, fmt.Errorf("could not find message type %q", msgName)
-}
-
-func lastName(inputType string) (string, error) {
- tokens := strings.Split(inputType, ".")
-
- msgName := tokens[len(tokens)-1]
- if msgName == "" {
- return "", fmt.Errorf("unable to parse method input type: %s", inputType)
- }
-
- return msgName, nil
-}
-
// LookupMethod looks up an MethodInfo by service and method name
func (pr *Registry) LookupMethod(fullMethodName string) (MethodInfo, error) {
methodInfo, ok := pr.protos[fullMethodName]
diff --git a/internal/praefect/protoregistry/protoregistry_test.go b/internal/praefect/protoregistry/protoregistry_test.go
index d6aabf834..9e448c443 100644
--- a/internal/praefect/protoregistry/protoregistry_test.go
+++ b/internal/praefect/protoregistry/protoregistry_test.go
@@ -1,143 +1,142 @@
-package protoregistry_test
+package protoregistry
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
- "gitlab.com/gitlab-org/gitaly/v15/internal/praefect/protoregistry"
- "gitlab.com/gitlab-org/gitaly/v15/internal/testhelper"
- "gitlab.com/gitlab-org/gitaly/v15/proto/go/gitalypb"
)
func TestNewProtoRegistry(t *testing.T) {
- expectedResults := map[string]map[string]protoregistry.OpType{
+ t.Parallel()
+
+ expectedResults := map[string]map[string]OpType{
"BlobService": {
- "GetBlob": protoregistry.OpAccessor,
- "GetBlobs": protoregistry.OpAccessor,
- "GetLFSPointers": protoregistry.OpAccessor,
+ "GetBlob": OpAccessor,
+ "GetBlobs": OpAccessor,
+ "GetLFSPointers": OpAccessor,
},
"CleanupService": {
- "ApplyBfgObjectMapStream": protoregistry.OpMutator,
+ "ApplyBfgObjectMapStream": OpMutator,
},
"CommitService": {
- "CommitIsAncestor": protoregistry.OpAccessor,
- "CommitLanguages": protoregistry.OpAccessor,
- "CommitStats": protoregistry.OpAccessor,
- "CommitsByMessage": protoregistry.OpAccessor,
- "CountCommits": protoregistry.OpAccessor,
- "CountDivergingCommits": protoregistry.OpAccessor,
- "FilterShasWithSignatures": protoregistry.OpAccessor,
- "FindAllCommits": protoregistry.OpAccessor,
- "FindCommit": protoregistry.OpAccessor,
- "FindCommits": protoregistry.OpAccessor,
- "GetTreeEntries": protoregistry.OpAccessor,
- "LastCommitForPath": protoregistry.OpAccessor,
- "ListCommitsByOid": protoregistry.OpAccessor,
- "ListFiles": protoregistry.OpAccessor,
- "ListLastCommitsForTree": protoregistry.OpAccessor,
- "RawBlame": protoregistry.OpAccessor,
- "TreeEntry": protoregistry.OpAccessor,
+ "CommitIsAncestor": OpAccessor,
+ "CommitLanguages": OpAccessor,
+ "CommitStats": OpAccessor,
+ "CommitsByMessage": OpAccessor,
+ "CountCommits": OpAccessor,
+ "CountDivergingCommits": OpAccessor,
+ "FilterShasWithSignatures": OpAccessor,
+ "FindAllCommits": OpAccessor,
+ "FindCommit": OpAccessor,
+ "FindCommits": OpAccessor,
+ "GetTreeEntries": OpAccessor,
+ "LastCommitForPath": OpAccessor,
+ "ListCommitsByOid": OpAccessor,
+ "ListFiles": OpAccessor,
+ "ListLastCommitsForTree": OpAccessor,
+ "RawBlame": OpAccessor,
+ "TreeEntry": OpAccessor,
},
"ConflictsService": {
- "ListConflictFiles": protoregistry.OpAccessor,
- "ResolveConflicts": protoregistry.OpMutator,
+ "ListConflictFiles": OpAccessor,
+ "ResolveConflicts": OpMutator,
},
"DiffService": {
- "CommitDelta": protoregistry.OpAccessor,
- "CommitDiff": protoregistry.OpAccessor,
- "DiffStats": protoregistry.OpAccessor,
- "RawDiff": protoregistry.OpAccessor,
- "RawPatch": protoregistry.OpAccessor,
+ "CommitDelta": OpAccessor,
+ "CommitDiff": OpAccessor,
+ "DiffStats": OpAccessor,
+ "RawDiff": OpAccessor,
+ "RawPatch": OpAccessor,
},
"NamespaceService": {
- "AddNamespace": protoregistry.OpMutator,
- "NamespaceExists": protoregistry.OpAccessor,
- "RemoveNamespace": protoregistry.OpMutator,
- "RenameNamespace": protoregistry.OpMutator,
+ "AddNamespace": OpMutator,
+ "NamespaceExists": OpAccessor,
+ "RemoveNamespace": OpMutator,
+ "RenameNamespace": OpMutator,
},
"ObjectPoolService": {
- "CreateObjectPool": protoregistry.OpMutator,
- "DeleteObjectPool": protoregistry.OpMutator,
- "DisconnectGitAlternates": protoregistry.OpMutator,
- "LinkRepositoryToObjectPool": protoregistry.OpMutator,
+ "CreateObjectPool": OpMutator,
+ "DeleteObjectPool": OpMutator,
+ "DisconnectGitAlternates": OpMutator,
+ "LinkRepositoryToObjectPool": OpMutator,
},
"OperationService": {
- "UserApplyPatch": protoregistry.OpMutator,
- "UserCherryPick": protoregistry.OpMutator,
- "UserCommitFiles": protoregistry.OpMutator,
- "UserCreateBranch": protoregistry.OpMutator,
- "UserCreateTag": protoregistry.OpMutator,
- "UserDeleteBranch": protoregistry.OpMutator,
- "UserDeleteTag": protoregistry.OpMutator,
- "UserFFBranch": protoregistry.OpMutator,
- "UserMergeBranch": protoregistry.OpMutator,
- "UserMergeToRef": protoregistry.OpMutator,
- "UserRevert": protoregistry.OpMutator,
- "UserSquash": protoregistry.OpMutator,
- "UserUpdateBranch": protoregistry.OpMutator,
- "UserUpdateSubmodule": protoregistry.OpMutator,
+ "UserApplyPatch": OpMutator,
+ "UserCherryPick": OpMutator,
+ "UserCommitFiles": OpMutator,
+ "UserCreateBranch": OpMutator,
+ "UserCreateTag": OpMutator,
+ "UserDeleteBranch": OpMutator,
+ "UserDeleteTag": OpMutator,
+ "UserFFBranch": OpMutator,
+ "UserMergeBranch": OpMutator,
+ "UserMergeToRef": OpMutator,
+ "UserRevert": OpMutator,
+ "UserSquash": OpMutator,
+ "UserUpdateBranch": OpMutator,
+ "UserUpdateSubmodule": OpMutator,
},
"RefService": {
- "DeleteRefs": protoregistry.OpMutator,
- "FindAllBranchNames": protoregistry.OpAccessor,
- "FindAllBranches": protoregistry.OpAccessor,
- "FindAllRemoteBranches": protoregistry.OpAccessor,
- "FindAllTagNames": protoregistry.OpAccessor,
- "FindAllTags": protoregistry.OpAccessor,
- "FindBranch": protoregistry.OpAccessor,
- "FindDefaultBranchName": protoregistry.OpAccessor,
- "FindLocalBranches": protoregistry.OpAccessor,
- "GetTagMessages": protoregistry.OpAccessor,
- "ListBranchNamesContainingCommit": protoregistry.OpAccessor,
- "ListTagNamesContainingCommit": protoregistry.OpAccessor,
- "RefExists": protoregistry.OpAccessor,
+ "DeleteRefs": OpMutator,
+ "FindAllBranchNames": OpAccessor,
+ "FindAllBranches": OpAccessor,
+ "FindAllRemoteBranches": OpAccessor,
+ "FindAllTagNames": OpAccessor,
+ "FindAllTags": OpAccessor,
+ "FindBranch": OpAccessor,
+ "FindDefaultBranchName": OpAccessor,
+ "FindLocalBranches": OpAccessor,
+ "GetTagMessages": OpAccessor,
+ "ListBranchNamesContainingCommit": OpAccessor,
+ "ListTagNamesContainingCommit": OpAccessor,
+ "RefExists": OpAccessor,
},
"RemoteService": {
- "FindRemoteRepository": protoregistry.OpAccessor,
- "FindRemoteRootRef": protoregistry.OpAccessor,
- "UpdateRemoteMirror": protoregistry.OpAccessor,
+ "FindRemoteRepository": OpAccessor,
+ "FindRemoteRootRef": OpAccessor,
+ "UpdateRemoteMirror": OpAccessor,
},
"RepositoryService": {
- "ApplyGitattributes": protoregistry.OpMutator,
- "BackupCustomHooks": protoregistry.OpAccessor,
- "CalculateChecksum": protoregistry.OpAccessor,
- "CreateBundle": protoregistry.OpAccessor,
- "CreateFork": protoregistry.OpMutator,
- "CreateRepository": protoregistry.OpMutator,
- "CreateRepositoryFromBundle": protoregistry.OpMutator,
- "CreateRepositoryFromSnapshot": protoregistry.OpMutator,
- "CreateRepositoryFromURL": protoregistry.OpMutator,
- "FetchBundle": protoregistry.OpMutator,
- "FetchRemote": protoregistry.OpMutator,
- "FetchSourceBranch": protoregistry.OpMutator,
- "FindLicense": protoregistry.OpAccessor,
- "FindMergeBase": protoregistry.OpAccessor,
- "Fsck": protoregistry.OpAccessor,
- "GetArchive": protoregistry.OpAccessor,
- "GetInfoAttributes": protoregistry.OpAccessor,
- "GetRawChanges": protoregistry.OpAccessor,
- "GetSnapshot": protoregistry.OpAccessor,
- "HasLocalBranches": protoregistry.OpAccessor,
- "OptimizeRepository": protoregistry.OpMaintenance,
- "PruneUnreachableObjects": protoregistry.OpMaintenance,
- "RepositoryExists": protoregistry.OpAccessor,
- "RepositorySize": protoregistry.OpAccessor,
- "RestoreCustomHooks": protoregistry.OpMutator,
- "SearchFilesByContent": protoregistry.OpAccessor,
- "SearchFilesByName": protoregistry.OpAccessor,
- "WriteRef": protoregistry.OpMutator,
+ "ApplyGitattributes": OpMutator,
+ "BackupCustomHooks": OpAccessor,
+ "CalculateChecksum": OpAccessor,
+ "CreateBundle": OpAccessor,
+ "CreateFork": OpMutator,
+ "CreateRepository": OpMutator,
+ "CreateRepositoryFromBundle": OpMutator,
+ "CreateRepositoryFromSnapshot": OpMutator,
+ "CreateRepositoryFromURL": OpMutator,
+ "FetchBundle": OpMutator,
+ "FetchRemote": OpMutator,
+ "FetchSourceBranch": OpMutator,
+ "FindLicense": OpAccessor,
+ "FindMergeBase": OpAccessor,
+ "Fsck": OpAccessor,
+ "GetArchive": OpAccessor,
+ "GetInfoAttributes": OpAccessor,
+ "GetRawChanges": OpAccessor,
+ "GetSnapshot": OpAccessor,
+ "HasLocalBranches": OpAccessor,
+ "OptimizeRepository": OpMaintenance,
+ "PruneUnreachableObjects": OpMaintenance,
+ "RepositoryExists": OpAccessor,
+ "RepositorySize": OpAccessor,
+ "RestoreCustomHooks": OpMutator,
+ "SearchFilesByContent": OpAccessor,
+ "SearchFilesByName": OpAccessor,
+ "WriteRef": OpMutator,
},
"SmartHTTPService": {
- "InfoRefsReceivePack": protoregistry.OpAccessor,
- "InfoRefsUploadPack": protoregistry.OpAccessor,
- "PostReceivePack": protoregistry.OpMutator,
- "PostUploadPackWithSidechannel": protoregistry.OpAccessor,
+ "InfoRefsReceivePack": OpAccessor,
+ "InfoRefsUploadPack": OpAccessor,
+ "PostReceivePack": OpMutator,
+ "PostUploadPackWithSidechannel": OpAccessor,
},
"SSHService": {
- "SSHReceivePack": protoregistry.OpMutator,
- "SSHUploadArchive": protoregistry.OpAccessor,
- "SSHUploadPack": protoregistry.OpAccessor,
+ "SSHReceivePack": OpMutator,
+ "SSHUploadArchive": OpAccessor,
+ "SSHUploadPack": OpAccessor,
},
}
@@ -145,17 +144,19 @@ func TestNewProtoRegistry(t *testing.T) {
for methodName, opType := range methods {
method := fmt.Sprintf("/gitaly.%s/%s", serviceName, methodName)
- methodInfo, err := protoregistry.GitalyProtoPreregistered.LookupMethod(method)
+ methodInfo, err := GitalyProtoPreregistered.LookupMethod(method)
require.NoError(t, err)
require.Equalf(t, opType, methodInfo.Operation, "expect %s:%s to have the correct op type", serviceName, methodName)
require.Equal(t, method, methodInfo.FullMethodName())
- require.False(t, protoregistry.GitalyProtoPreregistered.IsInterceptedMethod(method), method)
+ require.False(t, GitalyProtoPreregistered.IsInterceptedMethod(method), method)
}
}
}
func TestNewProtoRegistry_IsInterceptedMethod(t *testing.T) {
+ t.Parallel()
+
for service, methods := range map[string][]string{
"ServerService": {
"ServerInfo",
@@ -175,8 +176,8 @@ func TestNewProtoRegistry_IsInterceptedMethod(t *testing.T) {
for _, method := range methods {
t.Run(method, func(t *testing.T) {
fullMethodName := fmt.Sprintf("/gitaly.%s/%s", service, method)
- require.True(t, protoregistry.GitalyProtoPreregistered.IsInterceptedMethod(fullMethodName))
- methodInfo, err := protoregistry.GitalyProtoPreregistered.LookupMethod(fullMethodName)
+ require.True(t, GitalyProtoPreregistered.IsInterceptedMethod(fullMethodName))
+ methodInfo, err := GitalyProtoPreregistered.LookupMethod(fullMethodName)
require.Empty(t, methodInfo)
require.Error(t, err, "full method name not found:")
})
@@ -184,32 +185,3 @@ func TestNewProtoRegistry_IsInterceptedMethod(t *testing.T) {
})
}
}
-
-func TestRequestFactory(t *testing.T) {
- mInfo, err := protoregistry.GitalyProtoPreregistered.LookupMethod("/gitaly.RepositoryService/RepositoryExists")
- require.NoError(t, err)
-
- pb, err := mInfo.UnmarshalRequestProto([]byte{})
- require.NoError(t, err)
-
- testhelper.ProtoEqual(t, &gitalypb.RepositoryExistsRequest{}, pb)
-}
-
-func TestMethodInfoScope(t *testing.T) {
- for _, tt := range []struct {
- method string
- scope protoregistry.Scope
- }{
- {
- method: "/gitaly.RepositoryService/RepositoryExists",
- scope: protoregistry.ScopeRepository,
- },
- } {
- t.Run(tt.method, func(t *testing.T) {
- mInfo, err := protoregistry.GitalyProtoPreregistered.LookupMethod(tt.method)
- require.NoError(t, err)
-
- require.Exactly(t, tt.scope, mInfo.Scope)
- })
- }
-}