diff options
author | Sami Hiltunen <shiltunen@gitlab.com> | 2020-09-15 17:56:52 +0300 |
---|---|---|
committer | Sami Hiltunen <shiltunen@gitlab.com> | 2020-09-17 20:10:03 +0300 |
commit | 769b6ffbc54a7d9faf0dd25f8e20434e5972e7ca (patch) | |
tree | c53ee84b57ae95d95fff38836d8f1f175001725f /internal/praefect/protoregistry | |
parent | 6206616d93ece66349075c83d553a0a9af3e9cad (diff) |
add intercepted option to mark a service handled by praefect
Adds an option for marking a gRPC service as handled by Praefect.
Services handled by Praefect do not need operation or scope
annotations as they are only used for proxying logic.
Diffstat (limited to 'internal/praefect/protoregistry')
-rw-r--r-- | internal/praefect/protoregistry/protoregistry.go | 100 | ||||
-rw-r--r-- | internal/praefect/protoregistry/protoregistry_test.go | 38 |
2 files changed, 61 insertions, 77 deletions
diff --git a/internal/praefect/protoregistry/protoregistry.go b/internal/praefect/protoregistry/protoregistry.go index 7acc94d03..62fc906f9 100644 --- a/internal/praefect/protoregistry/protoregistry.go +++ b/internal/praefect/protoregistry/protoregistry.go @@ -10,6 +10,7 @@ import ( "github.com/golang/protobuf/proto" "github.com/golang/protobuf/protoc-gen-go/descriptor" + "gitlab.com/gitlab-org/gitaly/internal/protoutil" "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" ) @@ -171,49 +172,44 @@ func (mi MethodInfo) UnmarshalRequestProto(b []byte) (proto.Message, error) { // Registry contains info about RPC methods type Registry struct { protos map[string]MethodInfo + // interceptedMethods contains the set of methods which are intercepted + // by Praefect instead of proxying. + interceptedMethods map[string]struct{} } // New creates a new ProtoRegistry with info from one or more descriptor.FileDescriptorProto func New(protos ...*descriptor.FileDescriptorProto) (*Registry, error) { methods := make(map[string]MethodInfo) + interceptedMethods := make(map[string]struct{}) for _, p := range protos { for _, svc := range p.GetService() { for _, method := range svc.GetMethod() { + fullMethodName := fmt.Sprintf("/%s.%s/%s", + p.GetPackage(), svc.GetName(), method.GetName(), + ) + + if intercepted, err := protoutil.IsInterceptedService(svc); err != nil { + return nil, fmt.Errorf("is intercepted: %w", err) + } else if intercepted { + interceptedMethods[fullMethodName] = struct{}{} + continue + } + mi, err := parseMethodInfo(p, method) if err != nil { return nil, err } - fullMethodName := fmt.Sprintf( - "/%s.%s/%s", - p.GetPackage(), svc.GetName(), method.GetName(), - ) methods[fullMethodName] = mi } } } - return &Registry{protos: methods}, nil -} - -func getOpExtension(m *descriptor.MethodDescriptorProto) (*gitalypb.OperationMsg, error) { - options := m.GetOptions() - - if !proto.HasExtension(options, gitalypb.E_OpType) { - return nil, fmt.Errorf("method %s missing op_type option", m.GetName()) - } - - ext, err := proto.GetExtension(options, gitalypb.E_OpType) - if err != nil { - return nil, fmt.Errorf("unable to get Gitaly custom OpType extension: %s", err) - } - - opMsg, ok := ext.(*gitalypb.OperationMsg) - if !ok { - return nil, fmt.Errorf("unable to obtain OperationMsg from %#v", ext) - } - return opMsg, nil + return &Registry{ + protos: methods, + interceptedMethods: interceptedMethods, + }, nil } type protoFactory func([]byte) (proto.Message, error) @@ -245,7 +241,7 @@ func methodReqFactory(method *descriptor.MethodDescriptorProto) (protoFactory, e } func parseMethodInfo(p *descriptor.FileDescriptorProto, methodDesc *descriptor.MethodDescriptorProto) (MethodInfo, error) { - opMsg, err := getOpExtension(methodDesc) + opMsg, err := protoutil.GetOpExtension(methodDesc) if err != nil { return MethodInfo{}, err } @@ -295,8 +291,8 @@ func parseMethodInfo(p *descriptor.FileDescriptorProto, methodDesc *descriptor.M if scope == ScopeRepository { m := matcher{ - match: getTargetRepositoryExtension, - subMatch: getRepositoryExtension, + match: protoutil.GetTargetRepositoryExtension, + subMatch: protoutil.GetRepositoryExtension, expectedType: ".gitaly.Repository", topLevelMsgs: topLevelMsgs, } @@ -310,7 +306,7 @@ func parseMethodInfo(p *descriptor.FileDescriptorProto, methodDesc *descriptor.M } mi.targetRepo = targetRepo - m.match = getAdditionalRepositoryExtension + m.match = protoutil.GetAdditionalRepositoryExtension additionalRepo, err := m.findField(topLevelMsgs[typeName]) if err != nil { return MethodInfo{}, err @@ -318,7 +314,7 @@ func parseMethodInfo(p *descriptor.FileDescriptorProto, methodDesc *descriptor.M mi.additionalRepo = additionalRepo } else if scope == ScopeStorage { m := matcher{ - match: getStorageExtension, + match: protoutil.GetStorageExtension, topLevelMsgs: topLevelMsgs, } storage, err := m.findField(topLevelMsgs[typeName]) @@ -365,46 +361,6 @@ func getTopLevelMsgs(p *descriptor.FileDescriptorProto) (map[string]*descriptor. return topLevelMsgs, nil } -func getStorageExtension(m *descriptor.FieldDescriptorProto) (bool, error) { - return getBoolExtension(m, gitalypb.E_Storage) -} - -func getTargetRepositoryExtension(m *descriptor.FieldDescriptorProto) (bool, error) { - return getBoolExtension(m, gitalypb.E_TargetRepository) -} - -func getAdditionalRepositoryExtension(m *descriptor.FieldDescriptorProto) (bool, error) { - return getBoolExtension(m, gitalypb.E_AdditionalRepository) -} - -func getRepositoryExtension(m *descriptor.FieldDescriptorProto) (bool, error) { - return getBoolExtension(m, gitalypb.E_Repository) -} - -func getBoolExtension(m *descriptor.FieldDescriptorProto, extension *proto.ExtensionDesc) (bool, error) { - options := m.GetOptions() - - if !proto.HasExtension(options, extension) { - return false, nil - } - - ext, err := proto.GetExtension(options, extension) - if err != nil { - return false, err - } - - storageMsg, ok := ext.(*bool) - if !ok { - return false, fmt.Errorf("unable to obtain bool from %#v", ext) - } - - if storageMsg == nil { - return false, nil - } - - return *storageMsg, nil -} - // Matcher helps find field matching credentials. At first match method is used to check fields // recursively. Then if field matches but type don't match expectedType subMatch method is used // from this point. This matcher assumes that only one field in the message matches the credentials. @@ -495,6 +451,12 @@ func (pr *Registry) LookupMethod(fullMethodName string) (MethodInfo, error) { return methodInfo, nil } +// IsInterceptedMethod returns whether Praefect intercepts the method call instead of proxying it. +func (pr *Registry) IsInterceptedMethod(fullMethodName string) bool { + _, ok := pr.interceptedMethods[fullMethodName] + return ok +} + // ExtractFileDescriptor extracts a FileDescriptorProto from a gzip'd buffer. // https://github.com/golang/protobuf/blob/9eb2c01ac278a5d89ce4b2be68fe4500955d8179/descriptor/descriptor.go#L50 func ExtractFileDescriptor(gz []byte) (*descriptor.FileDescriptorProto, error) { diff --git a/internal/praefect/protoregistry/protoregistry_test.go b/internal/praefect/protoregistry/protoregistry_test.go index 030c971c1..df7874404 100644 --- a/internal/praefect/protoregistry/protoregistry_test.go +++ b/internal/praefect/protoregistry/protoregistry_test.go @@ -151,9 +151,6 @@ func TestNewProtoRegistry(t *testing.T) { "BackupCustomHooks": protoregistry.OpAccessor, "FetchHTTPRemote": protoregistry.OpMutator, }, - "ServerService": map[string]protoregistry.OpType{ - "ServerInfo": protoregistry.OpAccessor, - }, "SmartHTTPService": map[string]protoregistry.OpType{ "InfoRefsUploadPack": protoregistry.OpAccessor, "InfoRefsReceivePack": protoregistry.OpAccessor, @@ -179,13 +176,42 @@ func TestNewProtoRegistry(t *testing.T) { for serviceName, methods := range expectedResults { for methodName, opType := range methods { - methodInfo, err := r.LookupMethod(fmt.Sprintf("/gitaly.%s/%s", serviceName, methodName)) + method := fmt.Sprintf("/gitaly.%s/%s", serviceName, methodName) + methodInfo, err := r.LookupMethod(method) require.NoError(t, err) assert.Equalf(t, opType, methodInfo.Operation, "expect %s:%s to have the correct op type", serviceName, methodName) + require.False(t, r.IsInterceptedMethod(method), method) } } } +func TestNewProtoRegistry_IsInterceptedMethod(t *testing.T) { + for service, methods := range map[string][]string{ + "ServerService": { + "ServerInfo", + "DiskStatistics", + }, + "PraefectInfoService": { + "RepositoryReplicas", + "ConsistencyCheck", + "DatalossCheck", + "SetAuthoritativeStorage", + }, + } { + t.Run(service, func(t *testing.T) { + for _, method := range methods { + t.Run(method, func(t *testing.T) { + fullMethodName := fmt.Sprintf("/gitaly.%s/%s", service, method) + require.True(t, protoregistry.GitalyProtoPreregistered.IsInterceptedMethod(fullMethodName)) + methodInfo, err := protoregistry.GitalyProtoPreregistered.LookupMethod(fullMethodName) + require.Empty(t, methodInfo) + require.Error(t, err, "full method name not found:") + }) + } + }) + } +} + func TestRequestFactory(t *testing.T) { mInfo, err := protoregistry.GitalyProtoPreregistered.LookupMethod("/gitaly.RepositoryService/RepositoryExists") require.NoError(t, err) @@ -205,10 +231,6 @@ func TestMethodInfoScope(t *testing.T) { method: "/gitaly.RepositoryService/RepositoryExists", scope: protoregistry.ScopeRepository, }, - { - method: "/gitaly.ServerService/ServerInfo", - scope: protoregistry.ScopeServer, - }, } { t.Run(tt.method, func(t *testing.T) { mInfo, err := protoregistry.GitalyProtoPreregistered.LookupMethod(tt.method) |