Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'internal/praefect/protoregistry/protoregistry.go')
-rw-r--r--internal/praefect/protoregistry/protoregistry.go132
1 files changed, 72 insertions, 60 deletions
diff --git a/internal/praefect/protoregistry/protoregistry.go b/internal/praefect/protoregistry/protoregistry.go
index 3e3dc5da0..3fd3d64c2 100644
--- a/internal/praefect/protoregistry/protoregistry.go
+++ b/internal/praefect/protoregistry/protoregistry.go
@@ -3,11 +3,9 @@ package protoregistry
import (
"bytes"
"compress/gzip"
- "errors"
"fmt"
"io/ioutil"
"reflect"
- "strconv"
"strings"
"sync"
@@ -270,29 +268,45 @@ func parseMethodInfo(p *descriptor.FileDescriptorProto, methodDesc *descriptor.M
requestFactory: reqFactory,
}
+ topLevelMsgs, err := getTopLevelMsgs(p)
+ if err != nil {
+ return MethodInfo{}, err
+ }
+
+ typeName, err := lastName(methodDesc.GetInputType())
+ if err != nil {
+ return MethodInfo{}, err
+ }
+
if scope == ScopeRepository {
- targetRepo, err := parseOID(opMsg.GetTargetRepositoryField())
+ m := matcher{
+ match: getTargetRepositoryExtension,
+ subMatch: getRepositoryExtension,
+ expectedType: ".gitaly.Repository",
+ topLevelMsgs: topLevelMsgs,
+ }
+
+ targetRepo, err := m.findField(topLevelMsgs[typeName])
if err != nil {
return MethodInfo{}, err
}
+ if targetRepo == nil {
+ return MethodInfo{}, fmt.Errorf("unable to find target repository for method: %s", requestName)
+ }
mi.targetRepo = targetRepo
- if opMsg.GetAdditionalRepositoryField() != "" {
- mi.additionalRepo, err = parseOID(opMsg.GetAdditionalRepositoryField())
- if err != nil {
- return MethodInfo{}, err
- }
- }
- } else if scope == ScopeStorage {
- topLevelMsgs, err := getTopLevelMsgs(p)
+ m.match = getAdditionalRepositoryExtension
+ additionalRepo, err := m.findField(topLevelMsgs[typeName])
if err != nil {
return MethodInfo{}, err
}
- typeName, err := lastName(methodDesc.GetInputType())
- if err != nil {
- return MethodInfo{}, err
+ mi.additionalRepo = additionalRepo
+ } else if scope == ScopeStorage {
+ m := matcher{
+ match: getStorageExtension,
+ topLevelMsgs: topLevelMsgs,
}
- storage, err := findStorageField(topLevelMsgs, topLevelMsgs[typeName])
+ storage, err := m.findField(topLevelMsgs[typeName])
if err != nil {
return MethodInfo{}, err
}
@@ -337,13 +351,29 @@ func getTopLevelMsgs(p *descriptor.FileDescriptorProto) (map[string]*descriptor.
}
func getStorageExtension(m *descriptor.FieldDescriptorProto) (bool, error) {
+ return getBoolExtension(m, gitalypb.E_Storage)
+}
+
+func getTargetRepositoryExtension(m *descriptor.FieldDescriptorProto) (bool, error) {
+ return getBoolExtension(m, gitalypb.E_TargetRepository)
+}
+
+func getAdditionalRepositoryExtension(m *descriptor.FieldDescriptorProto) (bool, error) {
+ return getBoolExtension(m, gitalypb.E_AdditionalRepository)
+}
+
+func getRepositoryExtension(m *descriptor.FieldDescriptorProto) (bool, error) {
+ return getBoolExtension(m, gitalypb.E_Repository)
+}
+
+func getBoolExtension(m *descriptor.FieldDescriptorProto, extension *proto.ExtensionDesc) (bool, error) {
options := m.GetOptions()
- if !proto.HasExtension(options, gitalypb.E_Storage) {
+ if !proto.HasExtension(options, extension) {
return false, nil
}
- ext, err := proto.GetExtension(options, gitalypb.E_Storage)
+ ext, err := proto.GetExtension(options, extension)
if err != nil {
return false, err
}
@@ -360,28 +390,45 @@ func getStorageExtension(m *descriptor.FieldDescriptorProto) (bool, error) {
return *storageMsg, nil
}
-func findStorageField(topLevelMsgs map[string]*descriptor.DescriptorProto, t *descriptor.DescriptorProto) ([]int, error) {
+// Matcher helps find field matching credentials. At first match method is used to check fields
+// recursively. Then if field matches but type don't match expectedType subMatch method is used
+// from this point. This matcher assumes that only one field in the message matches the credentials.
+type matcher struct {
+ match func(*descriptor.FieldDescriptorProto) (bool, error)
+ subMatch func(*descriptor.FieldDescriptorProto) (bool, error)
+ expectedType string // fully qualified name of expected type e.g. ".gitaly.Repository"
+ topLevelMsgs map[string]*descriptor.DescriptorProto // Map of all top level messages in given file and it dependencies. Result of getTopLevelMsgs should be used.
+}
+
+func (m matcher) findField(t *descriptor.DescriptorProto) ([]int, error) {
for _, f := range t.GetField() {
- storage, err := getStorageExtension(f)
+ match, err := m.match(f)
if err != nil {
return nil, err
}
- if storage {
- return []int{int(f.GetNumber())}, nil
+ if match {
+ if f.GetTypeName() == m.expectedType {
+ return []int{int(f.GetNumber())}, nil
+ } else if m.subMatch != nil {
+ m.match = m.subMatch
+ m.subMatch = nil
+ } else {
+ return nil, fmt.Errorf("found wrong type, expected: %s, got: %s", m.expectedType, f.GetTypeName())
+ }
}
- childMsg, err := findChildMsg(topLevelMsgs, t, f)
+ childMsg, err := findChildMsg(m.topLevelMsgs, t, f)
if err != nil {
return nil, err
}
if childMsg != nil {
- nestedStorageField, err := findStorageField(topLevelMsgs, childMsg)
+ nestedField, err := m.findField(childMsg)
if err != nil {
return nil, err
}
- if nestedStorageField != nil {
- return append([]int{int(f.GetNumber())}, nestedStorageField...), nil
+ if nestedField != nil {
+ return append([]int{int(f.GetNumber())}, nestedField...), nil
}
}
}
@@ -424,41 +471,6 @@ func lastName(inputType string) (string, error) {
return msgName, nil
}
-// parses a string like "1.1" and returns a slice of ints
-func parseOID(rawFieldOID string) ([]int, error) {
- var fieldNos []int
-
- if rawFieldOID == "" {
- return fieldNos, nil
- }
-
- fieldNoStrs := strings.Split(rawFieldOID, ".")
-
- if len(fieldNoStrs) < 1 {
- return nil,
- fmt.Errorf("OID string contains no field numbers: %s", fieldNoStrs)
- }
-
- fieldNos = make([]int, len(fieldNoStrs))
-
- for i, fieldNoStr := range fieldNoStrs {
- fieldNo, err := strconv.Atoi(fieldNoStr)
- if err != nil {
- return nil,
- fmt.Errorf(
- "unable to parse target field OID %s: %s",
- rawFieldOID, err,
- )
- }
- if fieldNo == 0 {
- return nil, errors.New("zero is an invalid field number")
- }
- fieldNos[i] = fieldNo
- }
-
- return fieldNos, nil
-}
-
// LookupMethod looks up an MethodInfo by service and method name
func (pr *Registry) LookupMethod(fullMethodName string) (MethodInfo, error) {
pr.RLock()