diff options
author | Mateusz Nowotyński <maxmati4@gmail.com> | 2019-11-29 02:24:46 +0300 |
---|---|---|
committer | jramsay <maxmati4@gmail.com> | 2020-02-06 22:29:01 +0300 |
commit | eef10e2217463c7c3582604ebdebae36679e43f4 (patch) | |
tree | 8781b89a8c0320449d17871fc0112364171e25f7 /internal/praefect/protoregistry | |
parent | fc5321467ae7f2a9d39e81f3f700495292fd785e (diff) |
Use field annotation for target and additional repository
Instead of setting OID in the RPC method use annotation in the field
(`target_repository` and `additional_repository`). Having only this 2
annotations created a problem with messages that can be either target
or additional repository (for example `ObjectPool`). Those are marked
with `repository` annotation and `target_repository` and
`additional_repository` are used in the parent messages.
Signed-off-by: Mateusz Nowotyński <maxmati4@gmail.com>
Signed-off-by: jramsay <maxmati4@gmail.com>
Diffstat (limited to 'internal/praefect/protoregistry')
-rw-r--r-- | internal/praefect/protoregistry/protoregistry.go | 132 | ||||
-rw-r--r-- | internal/praefect/protoregistry/protoregistry_internal_test.go | 42 |
2 files changed, 72 insertions, 102 deletions
diff --git a/internal/praefect/protoregistry/protoregistry.go b/internal/praefect/protoregistry/protoregistry.go index 3e3dc5da0..3fd3d64c2 100644 --- a/internal/praefect/protoregistry/protoregistry.go +++ b/internal/praefect/protoregistry/protoregistry.go @@ -3,11 +3,9 @@ package protoregistry import ( "bytes" "compress/gzip" - "errors" "fmt" "io/ioutil" "reflect" - "strconv" "strings" "sync" @@ -270,29 +268,45 @@ func parseMethodInfo(p *descriptor.FileDescriptorProto, methodDesc *descriptor.M requestFactory: reqFactory, } + topLevelMsgs, err := getTopLevelMsgs(p) + if err != nil { + return MethodInfo{}, err + } + + typeName, err := lastName(methodDesc.GetInputType()) + if err != nil { + return MethodInfo{}, err + } + if scope == ScopeRepository { - targetRepo, err := parseOID(opMsg.GetTargetRepositoryField()) + m := matcher{ + match: getTargetRepositoryExtension, + subMatch: getRepositoryExtension, + expectedType: ".gitaly.Repository", + topLevelMsgs: topLevelMsgs, + } + + targetRepo, err := m.findField(topLevelMsgs[typeName]) if err != nil { return MethodInfo{}, err } + if targetRepo == nil { + return MethodInfo{}, fmt.Errorf("unable to find target repository for method: %s", requestName) + } mi.targetRepo = targetRepo - if opMsg.GetAdditionalRepositoryField() != "" { - mi.additionalRepo, err = parseOID(opMsg.GetAdditionalRepositoryField()) - if err != nil { - return MethodInfo{}, err - } - } - } else if scope == ScopeStorage { - topLevelMsgs, err := getTopLevelMsgs(p) + m.match = getAdditionalRepositoryExtension + additionalRepo, err := m.findField(topLevelMsgs[typeName]) if err != nil { return MethodInfo{}, err } - typeName, err := lastName(methodDesc.GetInputType()) - if err != nil { - return MethodInfo{}, err + mi.additionalRepo = additionalRepo + } else if scope == ScopeStorage { + m := matcher{ + match: getStorageExtension, + topLevelMsgs: topLevelMsgs, } - storage, err := findStorageField(topLevelMsgs, topLevelMsgs[typeName]) + storage, err := m.findField(topLevelMsgs[typeName]) if err != nil { return MethodInfo{}, err } @@ -337,13 +351,29 @@ func getTopLevelMsgs(p *descriptor.FileDescriptorProto) (map[string]*descriptor. } func getStorageExtension(m *descriptor.FieldDescriptorProto) (bool, error) { + return getBoolExtension(m, gitalypb.E_Storage) +} + +func getTargetRepositoryExtension(m *descriptor.FieldDescriptorProto) (bool, error) { + return getBoolExtension(m, gitalypb.E_TargetRepository) +} + +func getAdditionalRepositoryExtension(m *descriptor.FieldDescriptorProto) (bool, error) { + return getBoolExtension(m, gitalypb.E_AdditionalRepository) +} + +func getRepositoryExtension(m *descriptor.FieldDescriptorProto) (bool, error) { + return getBoolExtension(m, gitalypb.E_Repository) +} + +func getBoolExtension(m *descriptor.FieldDescriptorProto, extension *proto.ExtensionDesc) (bool, error) { options := m.GetOptions() - if !proto.HasExtension(options, gitalypb.E_Storage) { + if !proto.HasExtension(options, extension) { return false, nil } - ext, err := proto.GetExtension(options, gitalypb.E_Storage) + ext, err := proto.GetExtension(options, extension) if err != nil { return false, err } @@ -360,28 +390,45 @@ func getStorageExtension(m *descriptor.FieldDescriptorProto) (bool, error) { return *storageMsg, nil } -func findStorageField(topLevelMsgs map[string]*descriptor.DescriptorProto, t *descriptor.DescriptorProto) ([]int, error) { +// Matcher helps find field matching credentials. At first match method is used to check fields +// recursively. Then if field matches but type don't match expectedType subMatch method is used +// from this point. This matcher assumes that only one field in the message matches the credentials. +type matcher struct { + match func(*descriptor.FieldDescriptorProto) (bool, error) + subMatch func(*descriptor.FieldDescriptorProto) (bool, error) + expectedType string // fully qualified name of expected type e.g. ".gitaly.Repository" + topLevelMsgs map[string]*descriptor.DescriptorProto // Map of all top level messages in given file and it dependencies. Result of getTopLevelMsgs should be used. +} + +func (m matcher) findField(t *descriptor.DescriptorProto) ([]int, error) { for _, f := range t.GetField() { - storage, err := getStorageExtension(f) + match, err := m.match(f) if err != nil { return nil, err } - if storage { - return []int{int(f.GetNumber())}, nil + if match { + if f.GetTypeName() == m.expectedType { + return []int{int(f.GetNumber())}, nil + } else if m.subMatch != nil { + m.match = m.subMatch + m.subMatch = nil + } else { + return nil, fmt.Errorf("found wrong type, expected: %s, got: %s", m.expectedType, f.GetTypeName()) + } } - childMsg, err := findChildMsg(topLevelMsgs, t, f) + childMsg, err := findChildMsg(m.topLevelMsgs, t, f) if err != nil { return nil, err } if childMsg != nil { - nestedStorageField, err := findStorageField(topLevelMsgs, childMsg) + nestedField, err := m.findField(childMsg) if err != nil { return nil, err } - if nestedStorageField != nil { - return append([]int{int(f.GetNumber())}, nestedStorageField...), nil + if nestedField != nil { + return append([]int{int(f.GetNumber())}, nestedField...), nil } } } @@ -424,41 +471,6 @@ func lastName(inputType string) (string, error) { return msgName, nil } -// parses a string like "1.1" and returns a slice of ints -func parseOID(rawFieldOID string) ([]int, error) { - var fieldNos []int - - if rawFieldOID == "" { - return fieldNos, nil - } - - fieldNoStrs := strings.Split(rawFieldOID, ".") - - if len(fieldNoStrs) < 1 { - return nil, - fmt.Errorf("OID string contains no field numbers: %s", fieldNoStrs) - } - - fieldNos = make([]int, len(fieldNoStrs)) - - for i, fieldNoStr := range fieldNoStrs { - fieldNo, err := strconv.Atoi(fieldNoStr) - if err != nil { - return nil, - fmt.Errorf( - "unable to parse target field OID %s: %s", - rawFieldOID, err, - ) - } - if fieldNo == 0 { - return nil, errors.New("zero is an invalid field number") - } - fieldNos[i] = fieldNo - } - - return fieldNos, nil -} - // LookupMethod looks up an MethodInfo by service and method name func (pr *Registry) LookupMethod(fullMethodName string) (MethodInfo, error) { pr.RLock() diff --git a/internal/praefect/protoregistry/protoregistry_internal_test.go b/internal/praefect/protoregistry/protoregistry_internal_test.go deleted file mode 100644 index b04b864bf..000000000 --- a/internal/praefect/protoregistry/protoregistry_internal_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package protoregistry - -import ( - "errors" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestParseOID(t *testing.T) { - for _, tt := range []struct { - raw string - expectOID []int - expectErr error - }{ - { - raw: "", - }, - { - raw: "1", - expectOID: []int{1}, - }, - { - raw: "1.1", - expectOID: []int{1, 1}, - }, - { - raw: "1.2.1", - expectOID: []int{1, 2, 1}, - }, - { - raw: "a.b.c", - expectErr: errors.New("unable to parse target field OID a.b.c: strconv.Atoi: parsing \"a\": invalid syntax"), - }, - } { - t.Run(tt.raw, func(t *testing.T) { - actualOID, actualErr := parseOID(tt.raw) - require.Equal(t, tt.expectOID, actualOID) - require.Equal(t, tt.expectErr, actualErr) - }) - } -} |