Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMateusz Nowotyński <maxmati4@gmail.com>2019-11-12 01:17:26 +0300
committerMateusz Nowotyński <maxmati4@gmail.com>2019-11-13 12:23:49 +0300
commit0d391bfa70d3d123a7454d1ddb8033110bc6c856 (patch)
tree9d7f58b933497f2f981ef70ea6f5769f1bbc2278
parenta45b8d9cb4e3ad6ddc7533cce277d4c29bc5ee22 (diff)
Protobuf registry extract storage name
-rw-r--r--internal/praefect/protoregistry/find_oid.go (renamed from internal/praefect/protoregistry/targetrepo.go)45
-rw-r--r--internal/praefect/protoregistry/find_oid_test.go (renamed from internal/praefect/protoregistry/targetrepo_test.go)48
-rw-r--r--internal/praefect/protoregistry/protoregistry.go153
3 files changed, 231 insertions, 15 deletions
diff --git a/internal/praefect/protoregistry/targetrepo.go b/internal/praefect/protoregistry/find_oid.go
index ebca49fee..c17324a96 100644
--- a/internal/praefect/protoregistry/targetrepo.go
+++ b/internal/praefect/protoregistry/find_oid.go
@@ -19,32 +19,51 @@ const (
// ErrTargetRepoMissing indicates that the target repo is missing or not set
var ErrTargetRepoMissing = errors.New("target repo is not set")
-// reflectFindRepoTarget finds the target repository by using the OID to
-// navigate the struct tags
-// Warning: this reflection filled function is full of forbidden dark elf magic
func reflectFindRepoTarget(pbMsg proto.Message, targetOID []int) (*gitalypb.Repository, error) {
- var targetRepo *gitalypb.Repository
+ msgV, e := reflectFindOID(pbMsg, targetOID)
+ if e != nil {
+ return nil, e
+ }
- msgV := reflect.ValueOf(pbMsg)
+ targetRepo, ok := msgV.Interface().(*gitalypb.Repository)
+ if !ok {
+ return nil, fmt.Errorf("repo target OID %v points to non-Repo type %+v", targetOID, msgV.Interface())
+ }
+ return targetRepo, nil
+}
+
+func reflectFindStorage(pbMsg proto.Message, targetOID []int) (string, error) {
+ msgV, e := reflectFindOID(pbMsg, targetOID)
+ if e != nil {
+ return "", e
+ }
+
+ targetRepo, ok := msgV.Interface().(string)
+ if !ok {
+ return "", fmt.Errorf("repo target OID %v points to non-string type %+v", targetOID, msgV.Interface())
+ }
+
+ return targetRepo, nil
+}
+
+// reflectFindOID finds the target repository by using the OID to
+// navigate the struct tags
+// Warning: this reflection filled function is full of forbidden dark elf magic
+func reflectFindOID(pbMsg proto.Message, targetOID []int) (reflect.Value, error) {
+ msgV := reflect.ValueOf(pbMsg)
for _, fieldNo := range targetOID {
var err error
msgV, err = findProtoField(msgV, fieldNo)
if err != nil {
- return nil, fmt.Errorf(
+ return reflect.Value{}, fmt.Errorf(
"unable to descend OID %+v into message %s: %s",
targetOID, proto.MessageName(pbMsg), err,
)
}
}
-
- targetRepo, ok := msgV.Interface().(*gitalypb.Repository)
- if !ok {
- return nil, fmt.Errorf("repo target OID %v points to non-Repo type %+v", targetOID, msgV.Interface())
- }
-
- return targetRepo, nil
+ return msgV, nil
}
// matches a tag string like "bytes,1,opt,name=repository,proto3"
diff --git a/internal/praefect/protoregistry/targetrepo_test.go b/internal/praefect/protoregistry/find_oid_test.go
index 1982b5d64..6c6a837eb 100644
--- a/internal/praefect/protoregistry/targetrepo_test.go
+++ b/internal/praefect/protoregistry/find_oid_test.go
@@ -117,3 +117,51 @@ func TestProtoRegistryTargetRepo(t *testing.T) {
})
}
}
+
+func TestProtoRegistryStorage(t *testing.T) {
+ r := protoregistry.New()
+ require.NoError(t, r.RegisterFiles(protoregistry.GitalyProtoFileDescriptors...))
+
+ testcases := []struct {
+ desc string
+ svc string
+ method string
+ pbMsg proto.Message
+ expectStorage string
+ expectErr error
+ }{
+ {
+ desc: "valid request type single depth",
+ svc: "NamespaceService",
+ method: "AddNamespace",
+ pbMsg: &gitalypb.AddNamespaceRequest{
+ StorageName: "some_storage",
+ },
+ expectStorage: "some_storage",
+ },
+ {
+ desc: "incorrect request type",
+ svc: "RepositoryService",
+ method: "RepackIncremental",
+ pbMsg: &gitalypb.RepackIncrementalResponse{},
+ expectErr: errors.New("proto message gitaly.RepackIncrementalResponse does not match expected RPC request message gitaly.RepackIncrementalRequest"),
+ },
+ }
+
+ for _, tc := range testcases {
+ desc := fmt.Sprintf("%s:%s %s", tc.svc, tc.method, tc.desc)
+ t.Run(desc, func(t *testing.T) {
+ info, err := r.LookupMethod(fmt.Sprintf("/gitaly.%s/%s", tc.svc, tc.method))
+ require.NoError(t, err)
+
+ actualStorage, actualErr := info.Storage(tc.pbMsg)
+ require.Equal(t, tc.expectErr, actualErr)
+
+ // not only do we want the value to be the same, but we actually want the
+ // exact same instance to be returned
+ if tc.expectStorage != actualStorage {
+ t.Fatal("pointers do not match")
+ }
+ })
+ }
+}
diff --git a/internal/praefect/protoregistry/protoregistry.go b/internal/praefect/protoregistry/protoregistry.go
index d4cf34c69..3e3dc5da0 100644
--- a/internal/praefect/protoregistry/protoregistry.go
+++ b/internal/praefect/protoregistry/protoregistry.go
@@ -85,6 +85,7 @@ type MethodInfo struct {
additionalRepo []int
requestName string // protobuf message name for input type
requestFactory protoFactory
+ storage []int
}
// TargetRepo returns the target repository for a protobuf message if it exists
@@ -126,6 +127,18 @@ func (mi MethodInfo) getRepo(msg proto.Message, targetOid []int) (*gitalypb.Repo
}
}
+// Storage returns the storage name for a protobuf message if it exists
+func (mi MethodInfo) Storage(msg proto.Message) (string, error) {
+ if mi.requestName != proto.MessageName(msg) {
+ return "", fmt.Errorf(
+ "proto message %s does not match expected RPC request message %s",
+ proto.MessageName(msg), mi.requestName,
+ )
+ }
+
+ return reflectFindStorage(msg, mi.storage)
+}
+
// UnmarshalRequestProto will unmarshal the bytes into the method's request
// message type
func (mi MethodInfo) UnmarshalRequestProto(b []byte) (proto.Message, error) {
@@ -154,7 +167,7 @@ func (pr *Registry) RegisterFiles(protos ...*descriptor.FileDescriptorProto) err
for _, p := range protos {
for _, svc := range p.GetService() {
for _, method := range svc.GetMethod() {
- mi, err := parseMethodInfo(method)
+ mi, err := parseMethodInfo(p, method)
if err != nil {
return err
}
@@ -218,7 +231,7 @@ func methodReqFactory(method *descriptor.MethodDescriptorProto) (protoFactory, e
return f, nil
}
-func parseMethodInfo(methodDesc *descriptor.MethodDescriptorProto) (MethodInfo, error) {
+func parseMethodInfo(p *descriptor.FileDescriptorProto, methodDesc *descriptor.MethodDescriptorProto) (MethodInfo, error) {
opMsg, err := getOpExtension(methodDesc)
if err != nil {
return MethodInfo{}, err
@@ -270,11 +283,147 @@ func parseMethodInfo(methodDesc *descriptor.MethodDescriptorProto) (MethodInfo,
return MethodInfo{}, err
}
}
+ } else if scope == ScopeStorage {
+ topLevelMsgs, err := getTopLevelMsgs(p)
+ if err != nil {
+ return MethodInfo{}, err
+ }
+ typeName, err := lastName(methodDesc.GetInputType())
+ if err != nil {
+ return MethodInfo{}, err
+ }
+ storage, err := findStorageField(topLevelMsgs, topLevelMsgs[typeName])
+ if err != nil {
+ return MethodInfo{}, err
+ }
+ if storage == nil {
+ return MethodInfo{}, fmt.Errorf("unable to find storage for method: %s", requestName)
+ }
+ mi.storage = storage
}
return mi, nil
}
+func getFileTypes(filename string) ([]*descriptor.DescriptorProto, error) {
+ sharedFD, err := ExtractFileDescriptor(proto.FileDescriptor(filename))
+ if err != nil {
+ return nil, err
+ }
+
+ types := sharedFD.GetMessageType()
+
+ for _, dep := range sharedFD.Dependency {
+ depTypes, err := getFileTypes(dep)
+ if err != nil {
+ return nil, err
+ }
+ types = append(types, depTypes...)
+ }
+
+ return types, nil
+}
+
+func getTopLevelMsgs(p *descriptor.FileDescriptorProto) (map[string]*descriptor.DescriptorProto, error) {
+ topLevelMsgs := map[string]*descriptor.DescriptorProto{}
+ types, err := getFileTypes(p.GetName())
+ if err != nil {
+ return nil, err
+ }
+ for _, msg := range types {
+ topLevelMsgs[msg.GetName()] = msg
+ }
+ return topLevelMsgs, nil
+}
+
+func getStorageExtension(m *descriptor.FieldDescriptorProto) (bool, error) {
+ options := m.GetOptions()
+
+ if !proto.HasExtension(options, gitalypb.E_Storage) {
+ return false, nil
+ }
+
+ ext, err := proto.GetExtension(options, gitalypb.E_Storage)
+ if err != nil {
+ return false, err
+ }
+
+ storageMsg, ok := ext.(*bool)
+ if !ok {
+ return false, fmt.Errorf("unable to obtain bool from %#v", ext)
+ }
+
+ if storageMsg == nil {
+ return false, nil
+ }
+
+ return *storageMsg, nil
+}
+
+func findStorageField(topLevelMsgs map[string]*descriptor.DescriptorProto, t *descriptor.DescriptorProto) ([]int, error) {
+ for _, f := range t.GetField() {
+ storage, err := getStorageExtension(f)
+ if err != nil {
+ return nil, err
+ }
+ if storage {
+ return []int{int(f.GetNumber())}, nil
+ }
+
+ childMsg, err := findChildMsg(topLevelMsgs, t, f)
+ if err != nil {
+ return nil, err
+ }
+
+ if childMsg != nil {
+ nestedStorageField, err := findStorageField(topLevelMsgs, childMsg)
+ if err != nil {
+ return nil, err
+ }
+ if nestedStorageField != nil {
+ return append([]int{int(f.GetNumber())}, nestedStorageField...), nil
+ }
+ }
+ }
+ return nil, nil
+}
+
+func findChildMsg(topLevelMsgs map[string]*descriptor.DescriptorProto, t *descriptor.DescriptorProto, f *descriptor.FieldDescriptorProto) (*descriptor.DescriptorProto, error) {
+ var childType *descriptor.DescriptorProto
+ const msgPrimitive = "TYPE_MESSAGE"
+ if primitive := f.GetType().String(); primitive != msgPrimitive {
+ return nil, nil
+ }
+
+ msgName, err := lastName(f.GetTypeName())
+ if err != nil {
+ return nil, err
+ }
+
+ for _, nestedType := range t.GetNestedType() {
+ if msgName == nestedType.GetName() {
+ return nestedType, nil
+ }
+ }
+
+ if childType = topLevelMsgs[msgName]; childType != nil {
+ return childType, nil
+ }
+
+ return nil, fmt.Errorf("could not find message type %q", msgName)
+}
+
+func lastName(inputType string) (string, error) {
+ tokens := strings.Split(inputType, ".")
+
+ msgName := tokens[len(tokens)-1]
+ if msgName == "" {
+ return "", fmt.Errorf("unable to parse method input type: %s", inputType)
+ }
+
+ return msgName, nil
+}
+
// parses a string like "1.1" and returns a slice of ints
func parseOID(rawFieldOID string) ([]int, error) {
var fieldNos []int