Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMateusz Nowotyński <maxmati4@gmail.com>2019-11-29 02:24:46 +0300
committerjramsay <maxmati4@gmail.com>2020-02-06 22:29:01 +0300
commiteef10e2217463c7c3582604ebdebae36679e43f4 (patch)
tree8781b89a8c0320449d17871fc0112364171e25f7 /internal/praefect/protoregistry
parentfc5321467ae7f2a9d39e81f3f700495292fd785e (diff)
Use field annotation for target and additional repository
Instead of setting OID in the RPC method use annotation in the field (`target_repository` and `additional_repository`). Having only this 2 annotations created a problem with messages that can be either target or additional repository (for example `ObjectPool`). Those are marked with `repository` annotation and `target_repository` and `additional_repository` are used in the parent messages. Signed-off-by: Mateusz Nowotyński <maxmati4@gmail.com> Signed-off-by: jramsay <maxmati4@gmail.com>
Diffstat (limited to 'internal/praefect/protoregistry')
-rw-r--r--internal/praefect/protoregistry/protoregistry.go132
-rw-r--r--internal/praefect/protoregistry/protoregistry_internal_test.go42
2 files changed, 72 insertions, 102 deletions
diff --git a/internal/praefect/protoregistry/protoregistry.go b/internal/praefect/protoregistry/protoregistry.go
index 3e3dc5da0..3fd3d64c2 100644
--- a/internal/praefect/protoregistry/protoregistry.go
+++ b/internal/praefect/protoregistry/protoregistry.go
@@ -3,11 +3,9 @@ package protoregistry
import (
"bytes"
"compress/gzip"
- "errors"
"fmt"
"io/ioutil"
"reflect"
- "strconv"
"strings"
"sync"
@@ -270,29 +268,45 @@ func parseMethodInfo(p *descriptor.FileDescriptorProto, methodDesc *descriptor.M
requestFactory: reqFactory,
}
+ topLevelMsgs, err := getTopLevelMsgs(p)
+ if err != nil {
+ return MethodInfo{}, err
+ }
+
+ typeName, err := lastName(methodDesc.GetInputType())
+ if err != nil {
+ return MethodInfo{}, err
+ }
+
if scope == ScopeRepository {
- targetRepo, err := parseOID(opMsg.GetTargetRepositoryField())
+ m := matcher{
+ match: getTargetRepositoryExtension,
+ subMatch: getRepositoryExtension,
+ expectedType: ".gitaly.Repository",
+ topLevelMsgs: topLevelMsgs,
+ }
+
+ targetRepo, err := m.findField(topLevelMsgs[typeName])
if err != nil {
return MethodInfo{}, err
}
+ if targetRepo == nil {
+ return MethodInfo{}, fmt.Errorf("unable to find target repository for method: %s", requestName)
+ }
mi.targetRepo = targetRepo
- if opMsg.GetAdditionalRepositoryField() != "" {
- mi.additionalRepo, err = parseOID(opMsg.GetAdditionalRepositoryField())
- if err != nil {
- return MethodInfo{}, err
- }
- }
- } else if scope == ScopeStorage {
- topLevelMsgs, err := getTopLevelMsgs(p)
+ m.match = getAdditionalRepositoryExtension
+ additionalRepo, err := m.findField(topLevelMsgs[typeName])
if err != nil {
return MethodInfo{}, err
}
- typeName, err := lastName(methodDesc.GetInputType())
- if err != nil {
- return MethodInfo{}, err
+ mi.additionalRepo = additionalRepo
+ } else if scope == ScopeStorage {
+ m := matcher{
+ match: getStorageExtension,
+ topLevelMsgs: topLevelMsgs,
}
- storage, err := findStorageField(topLevelMsgs, topLevelMsgs[typeName])
+ storage, err := m.findField(topLevelMsgs[typeName])
if err != nil {
return MethodInfo{}, err
}
@@ -337,13 +351,29 @@ func getTopLevelMsgs(p *descriptor.FileDescriptorProto) (map[string]*descriptor.
}
func getStorageExtension(m *descriptor.FieldDescriptorProto) (bool, error) {
+ return getBoolExtension(m, gitalypb.E_Storage)
+}
+
+func getTargetRepositoryExtension(m *descriptor.FieldDescriptorProto) (bool, error) {
+ return getBoolExtension(m, gitalypb.E_TargetRepository)
+}
+
+func getAdditionalRepositoryExtension(m *descriptor.FieldDescriptorProto) (bool, error) {
+ return getBoolExtension(m, gitalypb.E_AdditionalRepository)
+}
+
+func getRepositoryExtension(m *descriptor.FieldDescriptorProto) (bool, error) {
+ return getBoolExtension(m, gitalypb.E_Repository)
+}
+
+func getBoolExtension(m *descriptor.FieldDescriptorProto, extension *proto.ExtensionDesc) (bool, error) {
options := m.GetOptions()
- if !proto.HasExtension(options, gitalypb.E_Storage) {
+ if !proto.HasExtension(options, extension) {
return false, nil
}
- ext, err := proto.GetExtension(options, gitalypb.E_Storage)
+ ext, err := proto.GetExtension(options, extension)
if err != nil {
return false, err
}
@@ -360,28 +390,45 @@ func getStorageExtension(m *descriptor.FieldDescriptorProto) (bool, error) {
return *storageMsg, nil
}
-func findStorageField(topLevelMsgs map[string]*descriptor.DescriptorProto, t *descriptor.DescriptorProto) ([]int, error) {
+// Matcher helps find field matching credentials. At first match method is used to check fields
+// recursively. Then if field matches but type don't match expectedType subMatch method is used
+// from this point. This matcher assumes that only one field in the message matches the credentials.
+type matcher struct {
+ match func(*descriptor.FieldDescriptorProto) (bool, error)
+ subMatch func(*descriptor.FieldDescriptorProto) (bool, error)
+ expectedType string // fully qualified name of expected type e.g. ".gitaly.Repository"
+ topLevelMsgs map[string]*descriptor.DescriptorProto // Map of all top level messages in given file and it dependencies. Result of getTopLevelMsgs should be used.
+}
+
+func (m matcher) findField(t *descriptor.DescriptorProto) ([]int, error) {
for _, f := range t.GetField() {
- storage, err := getStorageExtension(f)
+ match, err := m.match(f)
if err != nil {
return nil, err
}
- if storage {
- return []int{int(f.GetNumber())}, nil
+ if match {
+ if f.GetTypeName() == m.expectedType {
+ return []int{int(f.GetNumber())}, nil
+ } else if m.subMatch != nil {
+ m.match = m.subMatch
+ m.subMatch = nil
+ } else {
+ return nil, fmt.Errorf("found wrong type, expected: %s, got: %s", m.expectedType, f.GetTypeName())
+ }
}
- childMsg, err := findChildMsg(topLevelMsgs, t, f)
+ childMsg, err := findChildMsg(m.topLevelMsgs, t, f)
if err != nil {
return nil, err
}
if childMsg != nil {
- nestedStorageField, err := findStorageField(topLevelMsgs, childMsg)
+ nestedField, err := m.findField(childMsg)
if err != nil {
return nil, err
}
- if nestedStorageField != nil {
- return append([]int{int(f.GetNumber())}, nestedStorageField...), nil
+ if nestedField != nil {
+ return append([]int{int(f.GetNumber())}, nestedField...), nil
}
}
}
@@ -424,41 +471,6 @@ func lastName(inputType string) (string, error) {
return msgName, nil
}
-// parses a string like "1.1" and returns a slice of ints
-func parseOID(rawFieldOID string) ([]int, error) {
- var fieldNos []int
-
- if rawFieldOID == "" {
- return fieldNos, nil
- }
-
- fieldNoStrs := strings.Split(rawFieldOID, ".")
-
- if len(fieldNoStrs) < 1 {
- return nil,
- fmt.Errorf("OID string contains no field numbers: %s", fieldNoStrs)
- }
-
- fieldNos = make([]int, len(fieldNoStrs))
-
- for i, fieldNoStr := range fieldNoStrs {
- fieldNo, err := strconv.Atoi(fieldNoStr)
- if err != nil {
- return nil,
- fmt.Errorf(
- "unable to parse target field OID %s: %s",
- rawFieldOID, err,
- )
- }
- if fieldNo == 0 {
- return nil, errors.New("zero is an invalid field number")
- }
- fieldNos[i] = fieldNo
- }
-
- return fieldNos, nil
-}
-
// LookupMethod looks up an MethodInfo by service and method name
func (pr *Registry) LookupMethod(fullMethodName string) (MethodInfo, error) {
pr.RLock()
diff --git a/internal/praefect/protoregistry/protoregistry_internal_test.go b/internal/praefect/protoregistry/protoregistry_internal_test.go
deleted file mode 100644
index b04b864bf..000000000
--- a/internal/praefect/protoregistry/protoregistry_internal_test.go
+++ /dev/null
@@ -1,42 +0,0 @@
-package protoregistry
-
-import (
- "errors"
- "testing"
-
- "github.com/stretchr/testify/require"
-)
-
-func TestParseOID(t *testing.T) {
- for _, tt := range []struct {
- raw string
- expectOID []int
- expectErr error
- }{
- {
- raw: "",
- },
- {
- raw: "1",
- expectOID: []int{1},
- },
- {
- raw: "1.1",
- expectOID: []int{1, 1},
- },
- {
- raw: "1.2.1",
- expectOID: []int{1, 2, 1},
- },
- {
- raw: "a.b.c",
- expectErr: errors.New("unable to parse target field OID a.b.c: strconv.Atoi: parsing \"a\": invalid syntax"),
- },
- } {
- t.Run(tt.raw, func(t *testing.T) {
- actualOID, actualErr := parseOID(tt.raw)
- require.Equal(t, tt.expectOID, actualOID)
- require.Equal(t, tt.expectErr, actualErr)
- })
- }
-}