diff options
author | James Fargher <proglottis@gmail.com> | 2023-05-01 00:40:48 +0300 |
---|---|---|
committer | James Fargher <proglottis@gmail.com> | 2023-05-01 00:40:48 +0300 |
commit | 62e36a86900bdce323716c34dd877ad424309bc8 (patch) | |
tree | 618ee4a71661b2c09ea4ff72a44810ecafe09ef2 | |
parent | 512b7d09c44261f4018bd0cb89470315207f474c (diff) | |
parent | e12acb508bdfb85f8a15c3e9c714848d1d214993 (diff) |
Merge branch 'pks-protoregistry-method-info-simplifications' into 'master'
protoregistry: Refactor code to retrieve Protobuf fields tagged with extensions
See merge request https://gitlab.com/gitlab-org/gitaly/-/merge_requests/5683
Merged-by: James Fargher <proglottis@gmail.com>
Approved-by: James Fargher <proglottis@gmail.com>
Approved-by: Will Chandler <wchandler@gitlab.com>
Reviewed-by: James Fargher <proglottis@gmail.com>
Reviewed-by: Patrick Steinhardt <psteinhardt@gitlab.com>
Reviewed-by: Will Chandler <wchandler@gitlab.com>
Co-authored-by: Patrick Steinhardt <psteinhardt@gitlab.com>
-rw-r--r-- | internal/praefect/coordinator.go | 22 | ||||
-rw-r--r-- | internal/praefect/coordinator_test.go | 2 | ||||
-rw-r--r-- | internal/praefect/protoregistry/find_oid.go | 162 | ||||
-rw-r--r-- | internal/praefect/protoregistry/find_oid_test.go | 198 | ||||
-rw-r--r-- | internal/praefect/protoregistry/method_info.go | 309 | ||||
-rw-r--r-- | internal/praefect/protoregistry/method_info_test.go | 434 | ||||
-rw-r--r-- | internal/praefect/protoregistry/protoregistry.go | 368 | ||||
-rw-r--r-- | internal/praefect/protoregistry/protoregistry_test.go | 250 |
8 files changed, 869 insertions, 876 deletions
diff --git a/internal/praefect/coordinator.go b/internal/praefect/coordinator.go index 5b88eaf87..10773f5e0 100644 --- a/internal/praefect/coordinator.go +++ b/internal/praefect/coordinator.go @@ -373,9 +373,15 @@ func (c *Coordinator) mutatorStreamParameters(ctx context.Context, call grpcCall } var additionalRepoRelativePath string - if additionalRepo, ok, err := call.methodInfo.AdditionalRepo(call.msg); err != nil { + if additionalRepo, err := call.methodInfo.AdditionalRepo(call.msg); errors.Is(err, protoregistry.ErrTargetRepoMissing) { + // We can land here in two cases: either the message doesn't have an additional + // repository, or the repository wasn't set. The former case is obviously fine, but + // the latter case is fine, too, given that the additional repository may be an + // optional field for some RPC calls. The Gitaly-side RPC handlers should know to + // handle this case anyway, so we just leave the field unset in that case. + } else if err != nil { return nil, structerr.NewInvalidArgument("%w", err) - } else if ok { + } else { additionalRepoRelativePath = additionalRepo.GetRelativePath() } @@ -700,7 +706,7 @@ func (c *Coordinator) StreamDirector(ctx context.Context, fullMethodName string, func (c *Coordinator) directStorageScopedMessage(ctx context.Context, mi protoregistry.MethodInfo, msg proto.Message) (*proxy.StreamParameters, error) { virtualStorage, err := mi.Storage(msg) if err != nil { - return nil, structerr.NewInvalidArgument("%w", err) + return nil, structerr.NewInvalidArgument("storage scoped: %w", err) } if virtualStorage == "" { @@ -794,12 +800,12 @@ func rewrittenRepositoryMessage(mi protoregistry.MethodInfo, m proto.Message, st targetRepo.StorageName = storage targetRepo.RelativePath = relativePath - additionalRepo, ok, err := mi.AdditionalRepo(m) - if err != nil { + if additionalRepo, err := mi.AdditionalRepo(m); errors.Is(err, protoregistry.ErrTargetRepoMissing) { + // Nothing to rewrite in case the additional repository either doesn't exist in the + // message or wasn't set by the caller. + } else if err != nil { return nil, structerr.NewInvalidArgument("%w", err) - } - - if ok { + } else { additionalRepo.StorageName = storage additionalRepo.RelativePath = additionalRelativePath } diff --git a/internal/praefect/coordinator_test.go b/internal/praefect/coordinator_test.go index c3e2328f6..82f698f51 100644 --- a/internal/praefect/coordinator_test.go +++ b/internal/praefect/coordinator_test.go @@ -1726,7 +1726,7 @@ func TestStreamDirectorStorageScopeError(t *testing.T) { result, ok := status.FromError(err) require.True(t, ok) require.Equal(t, codes.InvalidArgument, result.Code()) - require.Equal(t, "storage scoped: target storage is invalid", result.Message()) + require.Equal(t, "storage scoped: target storage field not found", result.Message()) }) t.Run("unknown storage provided", func(t *testing.T) { diff --git a/internal/praefect/protoregistry/find_oid.go b/internal/praefect/protoregistry/find_oid.go deleted file mode 100644 index 8ce91c0dd..000000000 --- a/internal/praefect/protoregistry/find_oid.go +++ /dev/null @@ -1,162 +0,0 @@ -package protoregistry - -import ( - "errors" - "fmt" - "reflect" - "regexp" - "strconv" - - "gitlab.com/gitlab-org/gitaly/v15/proto/go/gitalypb" - "google.golang.org/protobuf/proto" -) - -const ( - protobufTag = "protobuf" - protobufOneOfTag = "protobuf_oneof" -) - -// ErrTargetRepoMissing indicates that the target repo is missing or not set -var ErrTargetRepoMissing = errors.New("empty Repository") - -func reflectFindRepoTarget(pbMsg proto.Message, targetOID []int) (*gitalypb.Repository, error) { - msgV, e := reflectFindOID(pbMsg, targetOID) - if e != nil { - if e == ErrProtoFieldEmpty { - return nil, ErrTargetRepoMissing - } - return nil, e - } - - targetRepo, ok := msgV.Interface().(*gitalypb.Repository) - if !ok { - return nil, fmt.Errorf("repo target OID %v points to non-Repo type %+v", targetOID, msgV.Interface()) - } - - return targetRepo, nil -} - -func reflectFindStorage(pbMsg proto.Message, targetOID []int) (string, error) { - msgV, e := reflectFindOID(pbMsg, targetOID) - if e != nil { - return "", e - } - - targetRepo, ok := msgV.Interface().(string) - if !ok { - return "", fmt.Errorf("repo target OID %v points to non-string type %+v", targetOID, msgV.Interface()) - } - - return targetRepo, nil -} - -func reflectSetStorage(pbMsg proto.Message, targetOID []int, storage string) error { - msgV, err := reflectFindOID(pbMsg, targetOID) - if err != nil { - return err - } - - msgV.Set(reflect.ValueOf(storage)) - return nil -} - -// ErrProtoFieldEmpty indicates the protobuf field is empty -var ErrProtoFieldEmpty = errors.New("proto field is empty") - -// reflectFindOID finds the target repository by using the OID to -// navigate the struct tags -// Warning: this reflection filled function is full of forbidden dark elf magic -func reflectFindOID(pbMsg proto.Message, targetOID []int) (reflect.Value, error) { - msgV := reflect.ValueOf(pbMsg) - for _, fieldNo := range targetOID { - var err error - - msgV, err = findProtoField(msgV, fieldNo) - if err != nil { - return reflect.Value{}, fmt.Errorf( - "unable to descend OID %+v into message %s: %w", - targetOID, proto.MessageName(pbMsg), err, - ) - } - } - return msgV, nil -} - -// matches a tag string like "bytes,1,opt,name=repository,proto3" -var protobufTagRegex = regexp.MustCompile(`^(.*?),(\d+),(.*?),name=(.*?),proto3(\,oneof)?$`) - -const ( - protobufTagRegexGroups = 6 - protobufTagRegexFieldGroup = 2 -) - -func findProtoField(msgV reflect.Value, protoField int) (reflect.Value, error) { - if msgV.IsZero() { - return reflect.Value{}, ErrProtoFieldEmpty - } - - msgV = reflect.Indirect(msgV) - for i := 0; i < msgV.NumField(); i++ { - field := msgV.Type().Field(i) - - ok, err := tryNumberedField(field, protoField) - if err != nil { - return reflect.Value{}, err - } - if ok { - return msgV.FieldByName(field.Name), nil - } - - oneofField, ok := tryOneOfField(msgV, field, protoField) - if !ok { - continue - } - return oneofField, nil - } - - err := fmt.Errorf( - "unable to find protobuf field %d in message %s", - protoField, msgV.Type().Name(), - ) - return reflect.Value{}, err -} - -func tryNumberedField(field reflect.StructField, protoField int) (bool, error) { - tag := field.Tag.Get(protobufTag) - matches := protobufTagRegex.FindStringSubmatch(tag) - if len(matches) == protobufTagRegexGroups { - fieldStr := matches[protobufTagRegexFieldGroup] - if fieldStr == strconv.Itoa(protoField) { - return true, nil - } - } - - return false, nil -} - -func tryOneOfField(msgV reflect.Value, field reflect.StructField, protoField int) (reflect.Value, bool) { - if msgV.IsZero() { - return reflect.Value{}, false - } - - oneOfTag := field.Tag.Get(protobufOneOfTag) - if oneOfTag == "" { - return reflect.Value{}, false // empty tag means this is not a oneOf field - } - - // try all of the oneOf fields until a match is found - msgV = msgV.FieldByName(field.Name).Elem().Elem() - for i := 0; i < msgV.NumField(); i++ { - field = msgV.Type().Field(i) - - ok, err := tryNumberedField(field, protoField) - if err != nil { - return reflect.Value{}, false - } - if ok { - return msgV.FieldByName(field.Name), true - } - } - - return reflect.Value{}, false -} diff --git a/internal/praefect/protoregistry/find_oid_test.go b/internal/praefect/protoregistry/find_oid_test.go deleted file mode 100644 index 1aaccd5e9..000000000 --- a/internal/praefect/protoregistry/find_oid_test.go +++ /dev/null @@ -1,198 +0,0 @@ -package protoregistry_test - -import ( - "errors" - "fmt" - "testing" - - "github.com/stretchr/testify/require" - "gitlab.com/gitlab-org/gitaly/v15/internal/praefect/protoregistry" - "gitlab.com/gitlab-org/gitaly/v15/proto/go/gitalypb" - "google.golang.org/protobuf/proto" -) - -func TestProtoRegistryTargetRepo(t *testing.T) { - testRepos := []*gitalypb.Repository{ - { - GitAlternateObjectDirectories: []string{"a", "b", "c"}, - GitObjectDirectory: "d", - GlProjectPath: "e", - GlRepository: "f", - RelativePath: "g", - StorageName: "h", - }, - { - GitAlternateObjectDirectories: []string{"1", "2", "3"}, - GitObjectDirectory: "4", - GlProjectPath: "5", - GlRepository: "6", - RelativePath: "7", - StorageName: "8", - }, - } - - testcases := []struct { - desc string - svc string - method string - pbMsg proto.Message - expectRepo *gitalypb.Repository - expectAdditionalRepo *gitalypb.Repository - expectErr error - }{ - { - desc: "valid request type single depth", - svc: "RepositoryService", - method: "OptimizeRepository", - pbMsg: &gitalypb.OptimizeRepositoryRequest{ - Repository: testRepos[0], - }, - expectRepo: testRepos[0], - }, - { - desc: "target nested in oneOf", - svc: "OperationService", - method: "UserCommitFiles", - pbMsg: &gitalypb.UserCommitFilesRequest{ - UserCommitFilesRequestPayload: &gitalypb.UserCommitFilesRequest_Header{ - Header: &gitalypb.UserCommitFilesRequestHeader{ - Repository: testRepos[1], - }, - }, - }, - expectRepo: testRepos[1], - }, - { - desc: "target nested, includes additional repository", - svc: "ObjectPoolService", - method: "FetchIntoObjectPool", - pbMsg: &gitalypb.FetchIntoObjectPoolRequest{ - Origin: testRepos[0], - ObjectPool: &gitalypb.ObjectPool{Repository: testRepos[1]}, - }, - expectRepo: testRepos[1], - expectAdditionalRepo: testRepos[0], - }, - { - desc: "target repo is nil", - svc: "RepositoryService", - method: "OptimizeRepository", - pbMsg: &gitalypb.OptimizeRepositoryRequest{Repository: nil}, - expectErr: protoregistry.ErrTargetRepoMissing, - }, - } - - for _, tc := range testcases { - desc := fmt.Sprintf("%s:%s %s", tc.svc, tc.method, tc.desc) - t.Run(desc, func(t *testing.T) { - info, err := protoregistry.GitalyProtoPreregistered.LookupMethod(fmt.Sprintf("/gitaly.%s/%s", tc.svc, tc.method)) - require.NoError(t, err) - - actualTarget, actualErr := info.TargetRepo(tc.pbMsg) - require.Equal(t, tc.expectErr, actualErr) - - // not only do we want the value to be the same, but we actually want the - // exact same instance to be returned - if tc.expectRepo != actualTarget { - t.Fatal("pointers do not match") - } - - if tc.expectAdditionalRepo != nil { - additionalRepo, ok, err := info.AdditionalRepo(tc.pbMsg) - require.True(t, ok) - require.NoError(t, err) - require.Equal(t, tc.expectAdditionalRepo, additionalRepo) - } - }) - } -} - -func TestProtoRegistryStorage(t *testing.T) { - testcases := []struct { - desc string - svc string - method string - pbMsg proto.Message - expectStorage string - expectErr error - }{ - { - desc: "valid request type single depth", - svc: "NamespaceService", - method: "AddNamespace", - pbMsg: &gitalypb.AddNamespaceRequest{ - StorageName: "some_storage", - }, - expectStorage: "some_storage", - }, - { - desc: "incorrect request type", - svc: "RepositoryService", - method: "OptimizeRepository", - pbMsg: &gitalypb.OptimizeRepositoryResponse{}, - expectErr: errors.New("proto message gitaly.OptimizeRepositoryResponse does not match expected RPC request message gitaly.OptimizeRepositoryRequest"), - }, - } - - for _, tc := range testcases { - desc := fmt.Sprintf("%s:%s %s", tc.svc, tc.method, tc.desc) - t.Run(desc, func(t *testing.T) { - info, err := protoregistry.GitalyProtoPreregistered.LookupMethod(fmt.Sprintf("/gitaly.%s/%s", tc.svc, tc.method)) - require.NoError(t, err) - - actualStorage, actualErr := info.Storage(tc.pbMsg) - require.Equal(t, tc.expectErr, actualErr) - - // not only do we want the value to be the same, but we actually want the - // exact same instance to be returned - if tc.expectStorage != actualStorage { - t.Fatal("pointers do not match") - } - }) - } -} - -func TestMethodInfo_SetStorage(t *testing.T) { - testCases := []struct { - desc string - service string - method string - pbMsg proto.Message - storage string - expectErr error - }{ - { - desc: "valid request type", - service: "NamespaceService", - method: "AddNamespace", - pbMsg: &gitalypb.AddNamespaceRequest{ - StorageName: "old_storage", - }, - storage: "new_storage", - }, - { - desc: "incorrect request type", - service: "RepositoryService", - method: "OptimizeRepository", - pbMsg: &gitalypb.OptimizeRepositoryResponse{}, - expectErr: errors.New("proto message gitaly.OptimizeRepositoryResponse does not match expected RPC request message gitaly.OptimizeRepositoryRequest"), - }, - } - - for _, tc := range testCases { - t.Run(tc.desc, func(t *testing.T) { - info, err := protoregistry.GitalyProtoPreregistered.LookupMethod("/gitaly." + tc.service + "/" + tc.method) - require.NoError(t, err) - - err = info.SetStorage(tc.pbMsg, tc.storage) - if tc.expectErr == nil { - require.NoError(t, err) - changed, err := info.Storage(tc.pbMsg) - require.NoError(t, err) - require.Equal(t, tc.storage, changed) - } else { - require.Equal(t, tc.expectErr, err) - } - }) - } -} diff --git a/internal/praefect/protoregistry/method_info.go b/internal/praefect/protoregistry/method_info.go new file mode 100644 index 000000000..91dab9261 --- /dev/null +++ b/internal/praefect/protoregistry/method_info.go @@ -0,0 +1,309 @@ +package protoregistry + +import ( + "errors" + "fmt" + "strings" + + "gitlab.com/gitlab-org/gitaly/v15/internal/protoutil" + "gitlab.com/gitlab-org/gitaly/v15/proto/go/gitalypb" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protopath" + "google.golang.org/protobuf/reflect/protorange" + "google.golang.org/protobuf/reflect/protoreflect" + protoreg "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/types/descriptorpb" +) + +// OpType represents the operation type for a RPC method +type OpType int + +const ( + // OpUnknown = unknown operation type + OpUnknown OpType = iota + // OpAccessor = accessor operation type (ready only) + OpAccessor + // OpMutator = mutator operation type (modifies a repository) + OpMutator + // OpMaintenance is an operation which performs maintenance-tasks on the repository. It + // shouldn't ever result in a user-visible change in behaviour, except that it may repair + // corrupt data. + OpMaintenance +) + +// Scope represents the intended scope of an RPC method +type Scope int + +const ( + // ScopeUnknown is the default scope until determined otherwise + ScopeUnknown Scope = iota + // ScopeRepository indicates an RPC's scope is limited to a repository + ScopeRepository + // ScopeStorage indicates an RPC is scoped to an entire storage location + ScopeStorage +) + +func (s Scope) String() string { + switch s { + case ScopeStorage: + return "storage" + case ScopeRepository: + return "repository" + default: + return fmt.Sprintf("N/A: %d", s) + } +} + +var protoScope = map[gitalypb.OperationMsg_Scope]Scope{ + gitalypb.OperationMsg_REPOSITORY: ScopeRepository, + gitalypb.OperationMsg_STORAGE: ScopeStorage, +} + +// MethodInfo contains metadata about the RPC method. Refer to documentation +// for message type "OperationMsg" shared.proto in ./proto for +// more documentation. +type MethodInfo struct { + Operation OpType + Scope Scope + requestName string // protobuf message name for input type + requestFactory protoFactory + fullMethodName string +} + +// TargetRepo returns the target repository for a protobuf message if it exists +func (mi MethodInfo) TargetRepo(msg proto.Message) (*gitalypb.Repository, error) { + return mi.getRepo(msg, gitalypb.E_TargetRepository) +} + +// AdditionalRepo returns the additional repository for a Protobuf message that needs a storage +// rewritten if it exists. +func (mi MethodInfo) AdditionalRepo(msg proto.Message) (*gitalypb.Repository, error) { + return mi.getRepo(msg, gitalypb.E_AdditionalRepository) +} + +//nolint:revive // This is unintentionally missing documentation. +func (mi MethodInfo) FullMethodName() string { + return mi.fullMethodName +} + +// ErrTargetRepoMissing indicates that the target repo is missing or not set +var ErrTargetRepoMissing = errors.New("empty Repository") + +func (mi MethodInfo) getRepo(msg proto.Message, extensionType protoreflect.ExtensionType) (*gitalypb.Repository, error) { + if mi.requestName != string(proto.MessageName(msg)) { + return nil, fmt.Errorf( + "proto message %s does not match expected RPC request message %s", + proto.MessageName(msg), mi.requestName, + ) + } + + field, err := findFieldByExtension(msg, extensionType) + if err != nil { + if errors.Is(err, errFieldNotFound) { + return nil, ErrTargetRepoMissing + } + + return nil, err + } + + if field.desc.Kind() != protoreflect.MessageKind { + return nil, fmt.Errorf("expected repository message, got %s", field.desc.Kind().String()) + } + + switch fieldMsg := field.value.Message().Interface().(type) { + case *gitalypb.Repository: + return fieldMsg, nil + case *gitalypb.ObjectPool: + repo := fieldMsg.GetRepository() + if repo == nil { + return nil, ErrTargetRepoMissing + } + + return repo, nil + default: + return nil, fmt.Errorf("repository message has unexpected type %T", fieldMsg) + } +} + +// Storage returns the storage name for a protobuf message if it exists +func (mi MethodInfo) Storage(msg proto.Message) (string, error) { + field, err := mi.getStorageField(msg) + if err != nil { + return "", err + } + + return field.value.String(), nil +} + +// SetStorage sets the storage name for a protobuf message +func (mi MethodInfo) SetStorage(msg proto.Message, storage string) error { + field, err := mi.getStorageField(msg) + if err != nil { + return err + } + + msg.ProtoReflect().Set(field.desc, protoreflect.ValueOfString(storage)) + + return nil +} + +func (mi MethodInfo) getStorageField(msg proto.Message) (valueField, error) { + if mi.requestName != string(proto.MessageName(msg)) { + return valueField{}, fmt.Errorf( + "proto message %s does not match expected RPC request message %s", + proto.MessageName(msg), mi.requestName, + ) + } + + field, err := findFieldByExtension(msg, gitalypb.E_Storage) + if err != nil { + if errors.Is(err, errFieldNotFound) { + return valueField{}, fmt.Errorf("target storage field not found") + } + return valueField{}, err + } + + if field.desc.Kind() != protoreflect.StringKind { + return valueField{}, fmt.Errorf("expected string, got %s", field.desc.Kind().String()) + } + + return field, nil +} + +// UnmarshalRequestProto will unmarshal the bytes into the method's request +// message type +func (mi MethodInfo) UnmarshalRequestProto(b []byte) (proto.Message, error) { + return mi.requestFactory(b) +} + +type protoFactory func([]byte) (proto.Message, error) + +func methodReqFactory(method *descriptorpb.MethodDescriptorProto) (protoFactory, error) { + // for some reason, the descriptor prepends a dot not expected in Go + inputTypeName := strings.TrimPrefix(method.GetInputType(), ".") + + inputType, err := protoreg.GlobalTypes.FindMessageByName(protoreflect.FullName(inputTypeName)) + if err != nil { + return nil, fmt.Errorf("no message type found for %w", err) + } + + f := func(buf []byte) (proto.Message, error) { + pb := inputType.New().Interface() + if err := proto.Unmarshal(buf, pb); err != nil { + return nil, err + } + + return pb, nil + } + + return f, nil +} + +func parseMethodInfo( + p *descriptorpb.FileDescriptorProto, + methodDesc *descriptorpb.MethodDescriptorProto, + fullMethodName string, +) (MethodInfo, error) { + opMsg, err := protoutil.GetOpExtension(methodDesc) + if err != nil { + return MethodInfo{}, err + } + + var opCode OpType + + switch opMsg.GetOp() { + case gitalypb.OperationMsg_ACCESSOR: + opCode = OpAccessor + case gitalypb.OperationMsg_MUTATOR: + opCode = OpMutator + case gitalypb.OperationMsg_MAINTENANCE: + opCode = OpMaintenance + default: + opCode = OpUnknown + } + + // for some reason, the protobuf descriptor contains an extra dot in front + // of the request name that the generated code does not. This trimming keeps + // the two copies consistent for comparisons. + requestName := strings.TrimLeft(methodDesc.GetInputType(), ".") + + reqFactory, err := methodReqFactory(methodDesc) + if err != nil { + return MethodInfo{}, err + } + + scope, ok := protoScope[opMsg.GetScopeLevel()] + if !ok { + return MethodInfo{}, fmt.Errorf("encountered unknown method scope %d", opMsg.GetScopeLevel()) + } + + mi := MethodInfo{ + Operation: opCode, + Scope: scope, + requestName: requestName, + requestFactory: reqFactory, + fullMethodName: fullMethodName, + } + + return mi, nil +} + +type valueField struct { + desc protoreflect.FieldDescriptor + value protoreflect.Value +} + +// findFieldsByExtension will search through all populated fields and returns all of those which +// have the given extension type set. +func findFieldsByExtension(msg proto.Message, extensionType protoreflect.ExtensionType) ([]valueField, error) { + var valueFields []valueField + + if err := (protorange.Options{Stable: true}).Range(msg.ProtoReflect(), func(values protopath.Values) error { + value := values.Index(-1) + + fieldDescriptor := value.Step.FieldDescriptor() + if fieldDescriptor == nil { + return nil + } + + opts := fieldDescriptor.Options().(*descriptorpb.FieldOptions) + if !proto.HasExtension(opts, extensionType) { + return nil + } + + valueFields = append(valueFields, valueField{ + desc: fieldDescriptor, + value: value.Value, + }) + + return nil + }, nil); err != nil { + return nil, fmt.Errorf("ranging over message: %w", err) + } + + return valueFields, nil +} + +var ( + errFieldNotFound = errors.New("field not found") + errFieldAmbiguous = errors.New("field is ambiguous") +) + +// findFieldByExtension is a wrapper around findFieldsByExtension that returns a single field +// descriptor, only. Returns a errFieldNotFound error in case the field wasn't found, and a +// errFieldAmbiguous error in case there are multiple fields with the same extension. +func findFieldByExtension(msg proto.Message, extensionType protoreflect.ExtensionType) (valueField, error) { + fields, err := findFieldsByExtension(msg, extensionType) + if err != nil { + return valueField{}, err + } + + switch len(fields) { + case 1: + return fields[0], nil + case 0: + return valueField{}, errFieldNotFound + default: + return valueField{}, errFieldAmbiguous + } +} diff --git a/internal/praefect/protoregistry/method_info_test.go b/internal/praefect/protoregistry/method_info_test.go new file mode 100644 index 000000000..1532a4d99 --- /dev/null +++ b/internal/praefect/protoregistry/method_info_test.go @@ -0,0 +1,434 @@ +package protoregistry + +import ( + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitaly/v15/internal/testhelper" + "gitlab.com/gitlab-org/gitaly/v15/proto/go/gitalypb" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" +) + +func TestMethodInfo_getRepo(t *testing.T) { + t.Parallel() + + testRepos := []*gitalypb.Repository{ + { + GitAlternateObjectDirectories: []string{"a", "b", "c"}, + GitObjectDirectory: "d", + GlProjectPath: "e", + GlRepository: "f", + RelativePath: "g", + StorageName: "h", + }, + { + GitAlternateObjectDirectories: []string{"1", "2", "3"}, + GitObjectDirectory: "4", + GlProjectPath: "5", + GlRepository: "6", + RelativePath: "7", + StorageName: "8", + }, + } + + testcases := []struct { + desc string + svc string + method string + pbMsg proto.Message + expectRepo *gitalypb.Repository + expectErr error + expectAdditionalRepo *gitalypb.Repository + expectAdditionalErr error + }{ + { + desc: "valid request type single depth", + svc: "RepositoryService", + method: "OptimizeRepository", + pbMsg: &gitalypb.OptimizeRepositoryRequest{ + Repository: testRepos[0], + }, + expectRepo: testRepos[0], + expectAdditionalErr: ErrTargetRepoMissing, + }, + { + desc: "unset oneof", + svc: "OperationService", + method: "UserCommitFiles", + pbMsg: &gitalypb.UserCommitFilesRequest{}, + expectErr: ErrTargetRepoMissing, + expectAdditionalErr: ErrTargetRepoMissing, + }, + { + desc: "unset value in oneof", + svc: "OperationService", + method: "UserCommitFiles", + pbMsg: &gitalypb.UserCommitFilesRequest{ + UserCommitFilesRequestPayload: &gitalypb.UserCommitFilesRequest_Header{}, + }, + expectErr: ErrTargetRepoMissing, + expectAdditionalErr: ErrTargetRepoMissing, + }, + { + desc: "unset repository in oneof", + svc: "OperationService", + method: "UserCommitFiles", + pbMsg: &gitalypb.UserCommitFilesRequest{ + UserCommitFilesRequestPayload: &gitalypb.UserCommitFilesRequest_Header{ + Header: &gitalypb.UserCommitFilesRequestHeader{}, + }, + }, + expectErr: ErrTargetRepoMissing, + expectAdditionalErr: ErrTargetRepoMissing, + }, + { + desc: "target nested in oneOf", + svc: "OperationService", + method: "UserCommitFiles", + pbMsg: &gitalypb.UserCommitFilesRequest{ + UserCommitFilesRequestPayload: &gitalypb.UserCommitFilesRequest_Header{ + Header: &gitalypb.UserCommitFilesRequestHeader{ + Repository: testRepos[1], + }, + }, + }, + expectRepo: testRepos[1], + expectAdditionalErr: ErrTargetRepoMissing, + }, + { + desc: "target nested, includes additional repository", + svc: "ObjectPoolService", + method: "FetchIntoObjectPool", + pbMsg: &gitalypb.FetchIntoObjectPoolRequest{ + Origin: testRepos[0], + ObjectPool: &gitalypb.ObjectPool{Repository: testRepos[1]}, + }, + expectRepo: testRepos[1], + expectAdditionalRepo: testRepos[0], + }, + { + desc: "target repo is nil", + svc: "RepositoryService", + method: "OptimizeRepository", + pbMsg: &gitalypb.OptimizeRepositoryRequest{Repository: nil}, + expectErr: ErrTargetRepoMissing, + expectAdditionalErr: ErrTargetRepoMissing, + }, + } + + for _, tc := range testcases { + t.Run(tc.desc, func(t *testing.T) { + info, err := GitalyProtoPreregistered.LookupMethod(fmt.Sprintf("/gitaly.%s/%s", tc.svc, tc.method)) + require.NoError(t, err) + + t.Run("TargetRepo", func(t *testing.T) { + repo, err := info.TargetRepo(tc.pbMsg) + require.Equal(t, tc.expectErr, err) + require.Same(t, tc.expectRepo, repo) + }) + + t.Run("AdditionalRepo", func(t *testing.T) { + additionalRepo, err := info.AdditionalRepo(tc.pbMsg) + require.Equal(t, tc.expectAdditionalErr, err) + require.Same(t, tc.expectAdditionalRepo, additionalRepo) + }) + }) + } +} + +func TestMethodInfo_Storage(t *testing.T) { + t.Parallel() + + testcases := []struct { + desc string + svc string + method string + pbMsg proto.Message + expectStorage string + expectErr error + }{ + { + desc: "valid request type single depth", + svc: "NamespaceService", + method: "AddNamespace", + pbMsg: &gitalypb.AddNamespaceRequest{ + StorageName: "some_storage", + }, + expectStorage: "some_storage", + }, + { + desc: "incorrect request type", + svc: "RepositoryService", + method: "OptimizeRepository", + pbMsg: &gitalypb.OptimizeRepositoryResponse{}, + expectErr: errors.New("proto message gitaly.OptimizeRepositoryResponse does not match expected RPC request message gitaly.OptimizeRepositoryRequest"), + }, + } + + for _, tc := range testcases { + desc := fmt.Sprintf("%s:%s %s", tc.svc, tc.method, tc.desc) + t.Run(desc, func(t *testing.T) { + info, err := GitalyProtoPreregistered.LookupMethod(fmt.Sprintf("/gitaly.%s/%s", tc.svc, tc.method)) + require.NoError(t, err) + + actualStorage, actualErr := info.Storage(tc.pbMsg) + require.Equal(t, tc.expectErr, actualErr) + + // not only do we want the value to be the same, but we actually want the + // exact same instance to be returned + if tc.expectStorage != actualStorage { + t.Fatal("pointers do not match") + } + }) + } +} + +func TestMethodInfo_SetStorage(t *testing.T) { + t.Parallel() + + testCases := []struct { + desc string + service string + method string + pbMsg proto.Message + storage string + expectErr error + }{ + { + desc: "valid request type", + service: "NamespaceService", + method: "AddNamespace", + pbMsg: &gitalypb.AddNamespaceRequest{ + StorageName: "old_storage", + }, + storage: "new_storage", + }, + { + desc: "incorrect request type", + service: "RepositoryService", + method: "OptimizeRepository", + pbMsg: &gitalypb.OptimizeRepositoryResponse{}, + expectErr: errors.New("proto message gitaly.OptimizeRepositoryResponse does not match expected RPC request message gitaly.OptimizeRepositoryRequest"), + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + info, err := GitalyProtoPreregistered.LookupMethod("/gitaly." + tc.service + "/" + tc.method) + require.NoError(t, err) + + err = info.SetStorage(tc.pbMsg, tc.storage) + if tc.expectErr == nil { + require.NoError(t, err) + changed, err := info.Storage(tc.pbMsg) + require.NoError(t, err) + require.Equal(t, tc.storage, changed) + } else { + require.Equal(t, tc.expectErr, err) + } + }) + } +} + +func TestMethodInfo_RequestFactory(t *testing.T) { + t.Parallel() + + mInfo, err := GitalyProtoPreregistered.LookupMethod("/gitaly.RepositoryService/RepositoryExists") + require.NoError(t, err) + + pb, err := mInfo.UnmarshalRequestProto([]byte{}) + require.NoError(t, err) + + testhelper.ProtoEqual(t, &gitalypb.RepositoryExistsRequest{}, pb) +} + +func TestMethodInfoScope(t *testing.T) { + for _, tt := range []struct { + method string + scope Scope + }{ + { + method: "/gitaly.RepositoryService/RepositoryExists", + scope: ScopeRepository, + }, + } { + t.Run(tt.method, func(t *testing.T) { + mInfo, err := GitalyProtoPreregistered.LookupMethod(tt.method) + require.NoError(t, err) + + require.Exactly(t, tt.scope, mInfo.Scope) + }) + } +} + +func TestFindFieldsByExtension(t *testing.T) { + t.Parallel() + + repo := &gitalypb.Repository{StorageName: "storage", RelativePath: "relative-path"} + + for _, tc := range []struct { + desc string + msg proto.Message + extension protoreflect.ExtensionType + expectedFields []string + }{ + { + desc: "unset field does not match", + msg: &gitalypb.OptimizeRepositoryRequest{ + Repository: nil, + }, + extension: gitalypb.E_TargetRepository, + }, + { + desc: "matching field", + msg: &gitalypb.OptimizeRepositoryRequest{ + Repository: repo, + }, + extension: gitalypb.E_TargetRepository, + expectedFields: []string{ + "gitaly.OptimizeRepositoryRequest.repository", + }, + }, + { + desc: "no matching field", + msg: &gitalypb.OptimizeRepositoryRequest{ + Repository: repo, + }, + extension: gitalypb.E_AdditionalRepository, + }, + { + desc: "multiple fields with distinct extensions", + msg: &gitalypb.FetchIntoObjectPoolRequest{ + Origin: repo, + ObjectPool: &gitalypb.ObjectPool{Repository: repo}, + }, + extension: gitalypb.E_AdditionalRepository, + expectedFields: []string{ + "gitaly.FetchIntoObjectPoolRequest.origin", + }, + }, + { + desc: "matching field in unset oneOf", + msg: &gitalypb.UserCommitFilesRequest{}, + extension: gitalypb.E_TargetRepository, + }, + { + desc: "matching field in empty oneOf", + msg: &gitalypb.UserCommitFilesRequest{ + UserCommitFilesRequestPayload: &gitalypb.UserCommitFilesRequest_Header{}, + }, + extension: gitalypb.E_TargetRepository, + }, + { + desc: "matching field with unset oneOf repository", + msg: &gitalypb.UserCommitFilesRequest{ + UserCommitFilesRequestPayload: &gitalypb.UserCommitFilesRequest_Header{ + Header: &gitalypb.UserCommitFilesRequestHeader{}, + }, + }, + extension: gitalypb.E_TargetRepository, + }, + { + desc: "matching field in oneOf", + msg: &gitalypb.UserCommitFilesRequest{ + UserCommitFilesRequestPayload: &gitalypb.UserCommitFilesRequest_Header{ + Header: &gitalypb.UserCommitFilesRequestHeader{ + Repository: repo, + }, + }, + }, + extension: gitalypb.E_TargetRepository, + expectedFields: []string{ + "gitaly.UserCommitFilesRequestHeader.repository", + }, + }, + } { + tc := tc + + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + + t.Run("findFieldsByExtension", func(t *testing.T) { + fields, err := findFieldsByExtension(tc.msg, tc.extension) + require.NoError(t, err) + + var fieldNames []string + for _, field := range fields { + fieldNames = append(fieldNames, string(field.desc.FullName())) + } + require.Equal(t, tc.expectedFields, fieldNames) + }) + + t.Run("findFieldByExtension", func(t *testing.T) { + field, err := findFieldByExtension(tc.msg, tc.extension) + + switch len(tc.expectedFields) { + case 0: + require.Equal(t, valueField{}, field) + require.Equal(t, errFieldNotFound, err) + case 1: + require.NoError(t, err) + require.Equal(t, tc.expectedFields[0], string(field.desc.FullName())) + default: + require.Equal(t, valueField{}, field) + require.Equal(t, errFieldAmbiguous, err) + } + }) + }) + } +} + +func BenchmarkMethodInfo(b *testing.B) { + for _, bc := range []struct { + desc string + method string + request proto.Message + expectedErr error + }{ + { + desc: "unset target repository", + method: "/gitaly.RepositoryService/OptimizeRepository", + request: &gitalypb.OptimizeRepositoryRequest{ + Repository: nil, + }, + expectedErr: ErrTargetRepoMissing, + }, + { + desc: "target repository", + method: "/gitaly.RepositoryService/OptimizeRepository", + request: &gitalypb.OptimizeRepositoryRequest{ + Repository: &gitalypb.Repository{ + StorageName: "something", + RelativePath: "something", + }, + }, + }, + { + desc: "target object pool", + method: "/gitaly.ObjectPoolService/FetchIntoObjectPool", + request: &gitalypb.FetchIntoObjectPoolRequest{ + ObjectPool: &gitalypb.ObjectPool{ + Repository: &gitalypb.Repository{ + StorageName: "something", + RelativePath: "something", + }, + }, + }, + }, + } { + b.Run(bc.desc, func(b *testing.B) { + mi, err := GitalyProtoPreregistered.LookupMethod(bc.method) + require.NoError(b, err) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, err := mi.TargetRepo(bc.request) + require.Equal(b, bc.expectedErr, err) + } + }) + } +} diff --git a/internal/praefect/protoregistry/protoregistry.go b/internal/praefect/protoregistry/protoregistry.go index 0399bf885..7fed1831f 100644 --- a/internal/praefect/protoregistry/protoregistry.go +++ b/internal/praefect/protoregistry/protoregistry.go @@ -2,13 +2,10 @@ package protoregistry import ( "fmt" - "strings" "gitlab.com/gitlab-org/gitaly/v15/internal/protoutil" "gitlab.com/gitlab-org/gitaly/v15/proto/go/gitalypb" - "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protodesc" - "google.golang.org/protobuf/reflect/protoreflect" protoreg "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/descriptorpb" ) @@ -25,138 +22,6 @@ func init() { } } -// OpType represents the operation type for a RPC method -type OpType int - -const ( - // OpUnknown = unknown operation type - OpUnknown OpType = iota - // OpAccessor = accessor operation type (ready only) - OpAccessor - // OpMutator = mutator operation type (modifies a repository) - OpMutator - // OpMaintenance is an operation which performs maintenance-tasks on the repository. It - // shouldn't ever result in a user-visible change in behaviour, except that it may repair - // corrupt data. - OpMaintenance -) - -// Scope represents the intended scope of an RPC method -type Scope int - -const ( - // ScopeUnknown is the default scope until determined otherwise - ScopeUnknown Scope = iota - // ScopeRepository indicates an RPC's scope is limited to a repository - ScopeRepository - // ScopeStorage indicates an RPC is scoped to an entire storage location - ScopeStorage -) - -func (s Scope) String() string { - switch s { - case ScopeStorage: - return "storage" - case ScopeRepository: - return "repository" - default: - return fmt.Sprintf("N/A: %d", s) - } -} - -var protoScope = map[gitalypb.OperationMsg_Scope]Scope{ - gitalypb.OperationMsg_REPOSITORY: ScopeRepository, - gitalypb.OperationMsg_STORAGE: ScopeStorage, -} - -// MethodInfo contains metadata about the RPC method. Refer to documentation -// for message type "OperationMsg" shared.proto in ./proto for -// more documentation. -type MethodInfo struct { - Operation OpType - Scope Scope - targetRepo []int - additionalRepo []int - requestName string // protobuf message name for input type - requestFactory protoFactory - storage []int - fullMethodName string -} - -// TargetRepo returns the target repository for a protobuf message if it exists -func (mi MethodInfo) TargetRepo(msg proto.Message) (*gitalypb.Repository, error) { - return mi.getRepo(msg, mi.targetRepo) -} - -// AdditionalRepo returns the additional repository for a protobuf message that needs a storage rewritten -// if it exists -func (mi MethodInfo) AdditionalRepo(msg proto.Message) (*gitalypb.Repository, bool, error) { - if mi.additionalRepo == nil { - return nil, false, nil - } - - repo, err := mi.getRepo(msg, mi.additionalRepo) - - return repo, true, err -} - -//nolint:revive // This is unintentionally missing documentation. -func (mi MethodInfo) FullMethodName() string { - return mi.fullMethodName -} - -func (mi MethodInfo) getRepo(msg proto.Message, targetOid []int) (*gitalypb.Repository, error) { - if mi.requestName != string(proto.MessageName(msg)) { - return nil, fmt.Errorf( - "proto message %s does not match expected RPC request message %s", - proto.MessageName(msg), mi.requestName, - ) - } - - repo, err := reflectFindRepoTarget(msg, targetOid) - switch { - case err != nil: - return nil, err - case repo == nil: - // it is possible for the target repo to not be set (especially in our unit - // tests designed to fail and this should return an error to prevent nil - // pointer dereferencing - return nil, ErrTargetRepoMissing - default: - return repo, nil - } -} - -// Storage returns the storage name for a protobuf message if it exists -func (mi MethodInfo) Storage(msg proto.Message) (string, error) { - if mi.requestName != string(proto.MessageName(msg)) { - return "", fmt.Errorf( - "proto message %s does not match expected RPC request message %s", - proto.MessageName(msg), mi.requestName, - ) - } - - return reflectFindStorage(msg, mi.storage) -} - -// SetStorage sets the storage name for a protobuf message -func (mi MethodInfo) SetStorage(msg proto.Message, storage string) error { - if mi.requestName != string(proto.MessageName(msg)) { - return fmt.Errorf( - "proto message %s does not match expected RPC request message %s", - proto.MessageName(msg), mi.requestName, - ) - } - - return reflectSetStorage(msg, mi.storage, storage) -} - -// UnmarshalRequestProto will unmarshal the bytes into the method's request -// message type -func (mi MethodInfo) UnmarshalRequestProto(b []byte) (proto.Message, error) { - return mi.requestFactory(b) -} - // Registry contains info about RPC methods type Registry struct { protos map[string]MethodInfo @@ -214,239 +79,6 @@ func NewFromPaths(paths ...string) (*Registry, error) { return New(fds...) } -type protoFactory func([]byte) (proto.Message, error) - -func methodReqFactory(method *descriptorpb.MethodDescriptorProto) (protoFactory, error) { - // for some reason, the descriptor prepends a dot not expected in Go - inputTypeName := strings.TrimPrefix(method.GetInputType(), ".") - - inputType, err := protoreg.GlobalTypes.FindMessageByName(protoreflect.FullName(inputTypeName)) - if err != nil { - return nil, fmt.Errorf("no message type found for %w", err) - } - - f := func(buf []byte) (proto.Message, error) { - pb := inputType.New().Interface() - if err := proto.Unmarshal(buf, pb); err != nil { - return nil, err - } - - return pb, nil - } - - return f, nil -} - -func parseMethodInfo( - p *descriptorpb.FileDescriptorProto, - methodDesc *descriptorpb.MethodDescriptorProto, - fullMethodName string, -) (MethodInfo, error) { - opMsg, err := protoutil.GetOpExtension(methodDesc) - if err != nil { - return MethodInfo{}, err - } - - var opCode OpType - - switch opMsg.GetOp() { - case gitalypb.OperationMsg_ACCESSOR: - opCode = OpAccessor - case gitalypb.OperationMsg_MUTATOR: - opCode = OpMutator - case gitalypb.OperationMsg_MAINTENANCE: - opCode = OpMaintenance - default: - opCode = OpUnknown - } - - // for some reason, the protobuf descriptor contains an extra dot in front - // of the request name that the generated code does not. This trimming keeps - // the two copies consistent for comparisons. - requestName := strings.TrimLeft(methodDesc.GetInputType(), ".") - - reqFactory, err := methodReqFactory(methodDesc) - if err != nil { - return MethodInfo{}, err - } - - scope, ok := protoScope[opMsg.GetScopeLevel()] - if !ok { - return MethodInfo{}, fmt.Errorf("encountered unknown method scope %d", opMsg.GetScopeLevel()) - } - - mi := MethodInfo{ - Operation: opCode, - Scope: scope, - requestName: requestName, - requestFactory: reqFactory, - fullMethodName: fullMethodName, - } - - topLevelMsgs, err := getTopLevelMsgs(p) - if err != nil { - return MethodInfo{}, err - } - - typeName, err := lastName(methodDesc.GetInputType()) - if err != nil { - return MethodInfo{}, err - } - - if scope == ScopeRepository { - m := matcher{ - match: protoutil.GetTargetRepositoryExtension, - subMatch: protoutil.GetRepositoryExtension, - expectedType: ".gitaly.Repository", - topLevelMsgs: topLevelMsgs, - } - - targetRepo, err := m.findField(topLevelMsgs[typeName]) - if err != nil { - return MethodInfo{}, err - } - if targetRepo == nil { - return MethodInfo{}, fmt.Errorf("unable to find target repository for method: %s", requestName) - } - mi.targetRepo = targetRepo - - m.match = protoutil.GetAdditionalRepositoryExtension - additionalRepo, err := m.findField(topLevelMsgs[typeName]) - if err != nil { - return MethodInfo{}, err - } - mi.additionalRepo = additionalRepo - } else if scope == ScopeStorage { - m := matcher{ - match: protoutil.GetStorageExtension, - topLevelMsgs: topLevelMsgs, - } - storage, err := m.findField(topLevelMsgs[typeName]) - if err != nil { - return MethodInfo{}, err - } - if storage == nil { - return MethodInfo{}, fmt.Errorf("unable to find storage for method: %s", requestName) - } - mi.storage = storage - } - - return mi, nil -} - -func getFileTypes(filename string) ([]*descriptorpb.DescriptorProto, error) { - fd, err := protoreg.GlobalFiles.FindFileByPath(filename) - if err != nil { - return nil, err - } - sharedFD := protodesc.ToFileDescriptorProto(fd) - - types := sharedFD.GetMessageType() - - for _, dep := range sharedFD.Dependency { - depTypes, err := getFileTypes(dep) - if err != nil { - return nil, err - } - types = append(types, depTypes...) - } - - return types, nil -} - -func getTopLevelMsgs(p *descriptorpb.FileDescriptorProto) (map[string]*descriptorpb.DescriptorProto, error) { - topLevelMsgs := map[string]*descriptorpb.DescriptorProto{} - types, err := getFileTypes(p.GetName()) - if err != nil { - return nil, err - } - for _, msg := range types { - topLevelMsgs[msg.GetName()] = msg - } - return topLevelMsgs, nil -} - -// Matcher helps find field matching credentials. At first match method is used to check fields -// recursively. Then if field matches but type don't match expectedType subMatch method is used -// from this point. This matcher assumes that only one field in the message matches the credentials. -type matcher struct { - match func(*descriptorpb.FieldDescriptorProto) (bool, error) - subMatch func(*descriptorpb.FieldDescriptorProto) (bool, error) - expectedType string // fully qualified name of expected type e.g. ".gitaly.Repository" - topLevelMsgs map[string]*descriptorpb.DescriptorProto // Map of all top level messages in given file and it dependencies. Result of getTopLevelMsgs should be used. -} - -func (m matcher) findField(t *descriptorpb.DescriptorProto) ([]int, error) { - for _, f := range t.GetField() { - match, err := m.match(f) - if err != nil { - return nil, err - } - if match { - if f.GetTypeName() == m.expectedType { - return []int{int(f.GetNumber())}, nil - } else if m.subMatch != nil { - m.match = m.subMatch - m.subMatch = nil - } else { - return nil, fmt.Errorf("found wrong type, expected: %s, got: %s", m.expectedType, f.GetTypeName()) - } - } - - childMsg, err := findChildMsg(m.topLevelMsgs, t, f) - if err != nil { - return nil, err - } - - if childMsg != nil { - nestedField, err := m.findField(childMsg) - if err != nil { - return nil, err - } - if nestedField != nil { - return append([]int{int(f.GetNumber())}, nestedField...), nil - } - } - } - return nil, nil -} - -func findChildMsg(topLevelMsgs map[string]*descriptorpb.DescriptorProto, t *descriptorpb.DescriptorProto, f *descriptorpb.FieldDescriptorProto) (*descriptorpb.DescriptorProto, error) { - var childType *descriptorpb.DescriptorProto - const msgPrimitive = "TYPE_MESSAGE" - if primitive := f.GetType().String(); primitive != msgPrimitive { - return nil, nil - } - - msgName, err := lastName(f.GetTypeName()) - if err != nil { - return nil, err - } - - for _, nestedType := range t.GetNestedType() { - if msgName == nestedType.GetName() { - return nestedType, nil - } - } - - if childType = topLevelMsgs[msgName]; childType != nil { - return childType, nil - } - - return nil, fmt.Errorf("could not find message type %q", msgName) -} - -func lastName(inputType string) (string, error) { - tokens := strings.Split(inputType, ".") - - msgName := tokens[len(tokens)-1] - if msgName == "" { - return "", fmt.Errorf("unable to parse method input type: %s", inputType) - } - - return msgName, nil -} - // LookupMethod looks up an MethodInfo by service and method name func (pr *Registry) LookupMethod(fullMethodName string) (MethodInfo, error) { methodInfo, ok := pr.protos[fullMethodName] diff --git a/internal/praefect/protoregistry/protoregistry_test.go b/internal/praefect/protoregistry/protoregistry_test.go index d6aabf834..9e448c443 100644 --- a/internal/praefect/protoregistry/protoregistry_test.go +++ b/internal/praefect/protoregistry/protoregistry_test.go @@ -1,143 +1,142 @@ -package protoregistry_test +package protoregistry import ( "fmt" "testing" "github.com/stretchr/testify/require" - "gitlab.com/gitlab-org/gitaly/v15/internal/praefect/protoregistry" - "gitlab.com/gitlab-org/gitaly/v15/internal/testhelper" - "gitlab.com/gitlab-org/gitaly/v15/proto/go/gitalypb" ) func TestNewProtoRegistry(t *testing.T) { - expectedResults := map[string]map[string]protoregistry.OpType{ + t.Parallel() + + expectedResults := map[string]map[string]OpType{ "BlobService": { - "GetBlob": protoregistry.OpAccessor, - "GetBlobs": protoregistry.OpAccessor, - "GetLFSPointers": protoregistry.OpAccessor, + "GetBlob": OpAccessor, + "GetBlobs": OpAccessor, + "GetLFSPointers": OpAccessor, }, "CleanupService": { - "ApplyBfgObjectMapStream": protoregistry.OpMutator, + "ApplyBfgObjectMapStream": OpMutator, }, "CommitService": { - "CommitIsAncestor": protoregistry.OpAccessor, - "CommitLanguages": protoregistry.OpAccessor, - "CommitStats": protoregistry.OpAccessor, - "CommitsByMessage": protoregistry.OpAccessor, - "CountCommits": protoregistry.OpAccessor, - "CountDivergingCommits": protoregistry.OpAccessor, - "FilterShasWithSignatures": protoregistry.OpAccessor, - "FindAllCommits": protoregistry.OpAccessor, - "FindCommit": protoregistry.OpAccessor, - "FindCommits": protoregistry.OpAccessor, - "GetTreeEntries": protoregistry.OpAccessor, - "LastCommitForPath": protoregistry.OpAccessor, - "ListCommitsByOid": protoregistry.OpAccessor, - "ListFiles": protoregistry.OpAccessor, - "ListLastCommitsForTree": protoregistry.OpAccessor, - "RawBlame": protoregistry.OpAccessor, - "TreeEntry": protoregistry.OpAccessor, + "CommitIsAncestor": OpAccessor, + "CommitLanguages": OpAccessor, + "CommitStats": OpAccessor, + "CommitsByMessage": OpAccessor, + "CountCommits": OpAccessor, + "CountDivergingCommits": OpAccessor, + "FilterShasWithSignatures": OpAccessor, + "FindAllCommits": OpAccessor, + "FindCommit": OpAccessor, + "FindCommits": OpAccessor, + "GetTreeEntries": OpAccessor, + "LastCommitForPath": OpAccessor, + "ListCommitsByOid": OpAccessor, + "ListFiles": OpAccessor, + "ListLastCommitsForTree": OpAccessor, + "RawBlame": OpAccessor, + "TreeEntry": OpAccessor, }, "ConflictsService": { - "ListConflictFiles": protoregistry.OpAccessor, - "ResolveConflicts": protoregistry.OpMutator, + "ListConflictFiles": OpAccessor, + "ResolveConflicts": OpMutator, }, "DiffService": { - "CommitDelta": protoregistry.OpAccessor, - "CommitDiff": protoregistry.OpAccessor, - "DiffStats": protoregistry.OpAccessor, - "RawDiff": protoregistry.OpAccessor, - "RawPatch": protoregistry.OpAccessor, + "CommitDelta": OpAccessor, + "CommitDiff": OpAccessor, + "DiffStats": OpAccessor, + "RawDiff": OpAccessor, + "RawPatch": OpAccessor, }, "NamespaceService": { - "AddNamespace": protoregistry.OpMutator, - "NamespaceExists": protoregistry.OpAccessor, - "RemoveNamespace": protoregistry.OpMutator, - "RenameNamespace": protoregistry.OpMutator, + "AddNamespace": OpMutator, + "NamespaceExists": OpAccessor, + "RemoveNamespace": OpMutator, + "RenameNamespace": OpMutator, }, "ObjectPoolService": { - "CreateObjectPool": protoregistry.OpMutator, - "DeleteObjectPool": protoregistry.OpMutator, - "DisconnectGitAlternates": protoregistry.OpMutator, - "LinkRepositoryToObjectPool": protoregistry.OpMutator, + "CreateObjectPool": OpMutator, + "DeleteObjectPool": OpMutator, + "DisconnectGitAlternates": OpMutator, + "LinkRepositoryToObjectPool": OpMutator, }, "OperationService": { - "UserApplyPatch": protoregistry.OpMutator, - "UserCherryPick": protoregistry.OpMutator, - "UserCommitFiles": protoregistry.OpMutator, - "UserCreateBranch": protoregistry.OpMutator, - "UserCreateTag": protoregistry.OpMutator, - "UserDeleteBranch": protoregistry.OpMutator, - "UserDeleteTag": protoregistry.OpMutator, - "UserFFBranch": protoregistry.OpMutator, - "UserMergeBranch": protoregistry.OpMutator, - "UserMergeToRef": protoregistry.OpMutator, - "UserRevert": protoregistry.OpMutator, - "UserSquash": protoregistry.OpMutator, - "UserUpdateBranch": protoregistry.OpMutator, - "UserUpdateSubmodule": protoregistry.OpMutator, + "UserApplyPatch": OpMutator, + "UserCherryPick": OpMutator, + "UserCommitFiles": OpMutator, + "UserCreateBranch": OpMutator, + "UserCreateTag": OpMutator, + "UserDeleteBranch": OpMutator, + "UserDeleteTag": OpMutator, + "UserFFBranch": OpMutator, + "UserMergeBranch": OpMutator, + "UserMergeToRef": OpMutator, + "UserRevert": OpMutator, + "UserSquash": OpMutator, + "UserUpdateBranch": OpMutator, + "UserUpdateSubmodule": OpMutator, }, "RefService": { - "DeleteRefs": protoregistry.OpMutator, - "FindAllBranchNames": protoregistry.OpAccessor, - "FindAllBranches": protoregistry.OpAccessor, - "FindAllRemoteBranches": protoregistry.OpAccessor, - "FindAllTagNames": protoregistry.OpAccessor, - "FindAllTags": protoregistry.OpAccessor, - "FindBranch": protoregistry.OpAccessor, - "FindDefaultBranchName": protoregistry.OpAccessor, - "FindLocalBranches": protoregistry.OpAccessor, - "GetTagMessages": protoregistry.OpAccessor, - "ListBranchNamesContainingCommit": protoregistry.OpAccessor, - "ListTagNamesContainingCommit": protoregistry.OpAccessor, - "RefExists": protoregistry.OpAccessor, + "DeleteRefs": OpMutator, + "FindAllBranchNames": OpAccessor, + "FindAllBranches": OpAccessor, + "FindAllRemoteBranches": OpAccessor, + "FindAllTagNames": OpAccessor, + "FindAllTags": OpAccessor, + "FindBranch": OpAccessor, + "FindDefaultBranchName": OpAccessor, + "FindLocalBranches": OpAccessor, + "GetTagMessages": OpAccessor, + "ListBranchNamesContainingCommit": OpAccessor, + "ListTagNamesContainingCommit": OpAccessor, + "RefExists": OpAccessor, }, "RemoteService": { - "FindRemoteRepository": protoregistry.OpAccessor, - "FindRemoteRootRef": protoregistry.OpAccessor, - "UpdateRemoteMirror": protoregistry.OpAccessor, + "FindRemoteRepository": OpAccessor, + "FindRemoteRootRef": OpAccessor, + "UpdateRemoteMirror": OpAccessor, }, "RepositoryService": { - "ApplyGitattributes": protoregistry.OpMutator, - "BackupCustomHooks": protoregistry.OpAccessor, - "CalculateChecksum": protoregistry.OpAccessor, - "CreateBundle": protoregistry.OpAccessor, - "CreateFork": protoregistry.OpMutator, - "CreateRepository": protoregistry.OpMutator, - "CreateRepositoryFromBundle": protoregistry.OpMutator, - "CreateRepositoryFromSnapshot": protoregistry.OpMutator, - "CreateRepositoryFromURL": protoregistry.OpMutator, - "FetchBundle": protoregistry.OpMutator, - "FetchRemote": protoregistry.OpMutator, - "FetchSourceBranch": protoregistry.OpMutator, - "FindLicense": protoregistry.OpAccessor, - "FindMergeBase": protoregistry.OpAccessor, - "Fsck": protoregistry.OpAccessor, - "GetArchive": protoregistry.OpAccessor, - "GetInfoAttributes": protoregistry.OpAccessor, - "GetRawChanges": protoregistry.OpAccessor, - "GetSnapshot": protoregistry.OpAccessor, - "HasLocalBranches": protoregistry.OpAccessor, - "OptimizeRepository": protoregistry.OpMaintenance, - "PruneUnreachableObjects": protoregistry.OpMaintenance, - "RepositoryExists": protoregistry.OpAccessor, - "RepositorySize": protoregistry.OpAccessor, - "RestoreCustomHooks": protoregistry.OpMutator, - "SearchFilesByContent": protoregistry.OpAccessor, - "SearchFilesByName": protoregistry.OpAccessor, - "WriteRef": protoregistry.OpMutator, + "ApplyGitattributes": OpMutator, + "BackupCustomHooks": OpAccessor, + "CalculateChecksum": OpAccessor, + "CreateBundle": OpAccessor, + "CreateFork": OpMutator, + "CreateRepository": OpMutator, + "CreateRepositoryFromBundle": OpMutator, + "CreateRepositoryFromSnapshot": OpMutator, + "CreateRepositoryFromURL": OpMutator, + "FetchBundle": OpMutator, + "FetchRemote": OpMutator, + "FetchSourceBranch": OpMutator, + "FindLicense": OpAccessor, + "FindMergeBase": OpAccessor, + "Fsck": OpAccessor, + "GetArchive": OpAccessor, + "GetInfoAttributes": OpAccessor, + "GetRawChanges": OpAccessor, + "GetSnapshot": OpAccessor, + "HasLocalBranches": OpAccessor, + "OptimizeRepository": OpMaintenance, + "PruneUnreachableObjects": OpMaintenance, + "RepositoryExists": OpAccessor, + "RepositorySize": OpAccessor, + "RestoreCustomHooks": OpMutator, + "SearchFilesByContent": OpAccessor, + "SearchFilesByName": OpAccessor, + "WriteRef": OpMutator, }, "SmartHTTPService": { - "InfoRefsReceivePack": protoregistry.OpAccessor, - "InfoRefsUploadPack": protoregistry.OpAccessor, - "PostReceivePack": protoregistry.OpMutator, - "PostUploadPackWithSidechannel": protoregistry.OpAccessor, + "InfoRefsReceivePack": OpAccessor, + "InfoRefsUploadPack": OpAccessor, + "PostReceivePack": OpMutator, + "PostUploadPackWithSidechannel": OpAccessor, }, "SSHService": { - "SSHReceivePack": protoregistry.OpMutator, - "SSHUploadArchive": protoregistry.OpAccessor, - "SSHUploadPack": protoregistry.OpAccessor, + "SSHReceivePack": OpMutator, + "SSHUploadArchive": OpAccessor, + "SSHUploadPack": OpAccessor, }, } @@ -145,17 +144,19 @@ func TestNewProtoRegistry(t *testing.T) { for methodName, opType := range methods { method := fmt.Sprintf("/gitaly.%s/%s", serviceName, methodName) - methodInfo, err := protoregistry.GitalyProtoPreregistered.LookupMethod(method) + methodInfo, err := GitalyProtoPreregistered.LookupMethod(method) require.NoError(t, err) require.Equalf(t, opType, methodInfo.Operation, "expect %s:%s to have the correct op type", serviceName, methodName) require.Equal(t, method, methodInfo.FullMethodName()) - require.False(t, protoregistry.GitalyProtoPreregistered.IsInterceptedMethod(method), method) + require.False(t, GitalyProtoPreregistered.IsInterceptedMethod(method), method) } } } func TestNewProtoRegistry_IsInterceptedMethod(t *testing.T) { + t.Parallel() + for service, methods := range map[string][]string{ "ServerService": { "ServerInfo", @@ -175,8 +176,8 @@ func TestNewProtoRegistry_IsInterceptedMethod(t *testing.T) { for _, method := range methods { t.Run(method, func(t *testing.T) { fullMethodName := fmt.Sprintf("/gitaly.%s/%s", service, method) - require.True(t, protoregistry.GitalyProtoPreregistered.IsInterceptedMethod(fullMethodName)) - methodInfo, err := protoregistry.GitalyProtoPreregistered.LookupMethod(fullMethodName) + require.True(t, GitalyProtoPreregistered.IsInterceptedMethod(fullMethodName)) + methodInfo, err := GitalyProtoPreregistered.LookupMethod(fullMethodName) require.Empty(t, methodInfo) require.Error(t, err, "full method name not found:") }) @@ -184,32 +185,3 @@ func TestNewProtoRegistry_IsInterceptedMethod(t *testing.T) { }) } } - -func TestRequestFactory(t *testing.T) { - mInfo, err := protoregistry.GitalyProtoPreregistered.LookupMethod("/gitaly.RepositoryService/RepositoryExists") - require.NoError(t, err) - - pb, err := mInfo.UnmarshalRequestProto([]byte{}) - require.NoError(t, err) - - testhelper.ProtoEqual(t, &gitalypb.RepositoryExistsRequest{}, pb) -} - -func TestMethodInfoScope(t *testing.T) { - for _, tt := range []struct { - method string - scope protoregistry.Scope - }{ - { - method: "/gitaly.RepositoryService/RepositoryExists", - scope: protoregistry.ScopeRepository, - }, - } { - t.Run(tt.method, func(t *testing.T) { - mInfo, err := protoregistry.GitalyProtoPreregistered.LookupMethod(tt.method) - require.NoError(t, err) - - require.Exactly(t, tt.scope, mInfo.Scope) - }) - } -} |