diff options
author | Mateusz Nowotyński <maxmati4@gmail.com> | 2019-11-12 01:17:26 +0300 |
---|---|---|
committer | Mateusz Nowotyński <maxmati4@gmail.com> | 2019-11-13 12:23:49 +0300 |
commit | 0d391bfa70d3d123a7454d1ddb8033110bc6c856 (patch) | |
tree | 9d7f58b933497f2f981ef70ea6f5769f1bbc2278 | |
parent | a45b8d9cb4e3ad6ddc7533cce277d4c29bc5ee22 (diff) |
Protobuf registry extract storage name
-rw-r--r-- | internal/praefect/protoregistry/find_oid.go (renamed from internal/praefect/protoregistry/targetrepo.go) | 45 | ||||
-rw-r--r-- | internal/praefect/protoregistry/find_oid_test.go (renamed from internal/praefect/protoregistry/targetrepo_test.go) | 48 | ||||
-rw-r--r-- | internal/praefect/protoregistry/protoregistry.go | 153 |
3 files changed, 231 insertions, 15 deletions
diff --git a/internal/praefect/protoregistry/targetrepo.go b/internal/praefect/protoregistry/find_oid.go index ebca49fee..c17324a96 100644 --- a/internal/praefect/protoregistry/targetrepo.go +++ b/internal/praefect/protoregistry/find_oid.go @@ -19,32 +19,51 @@ const ( // ErrTargetRepoMissing indicates that the target repo is missing or not set var ErrTargetRepoMissing = errors.New("target repo is not set") -// reflectFindRepoTarget finds the target repository by using the OID to -// navigate the struct tags -// Warning: this reflection filled function is full of forbidden dark elf magic func reflectFindRepoTarget(pbMsg proto.Message, targetOID []int) (*gitalypb.Repository, error) { - var targetRepo *gitalypb.Repository + msgV, e := reflectFindOID(pbMsg, targetOID) + if e != nil { + return nil, e + } - msgV := reflect.ValueOf(pbMsg) + targetRepo, ok := msgV.Interface().(*gitalypb.Repository) + if !ok { + return nil, fmt.Errorf("repo target OID %v points to non-Repo type %+v", targetOID, msgV.Interface()) + } + return targetRepo, nil +} + +func reflectFindStorage(pbMsg proto.Message, targetOID []int) (string, error) { + msgV, e := reflectFindOID(pbMsg, targetOID) + if e != nil { + return "", e + } + + targetRepo, ok := msgV.Interface().(string) + if !ok { + return "", fmt.Errorf("repo target OID %v points to non-string type %+v", targetOID, msgV.Interface()) + } + + return targetRepo, nil +} + +// reflectFindOID finds the target repository by using the OID to +// navigate the struct tags +// Warning: this reflection filled function is full of forbidden dark elf magic +func reflectFindOID(pbMsg proto.Message, targetOID []int) (reflect.Value, error) { + msgV := reflect.ValueOf(pbMsg) for _, fieldNo := range targetOID { var err error msgV, err = findProtoField(msgV, fieldNo) if err != nil { - return nil, fmt.Errorf( + return reflect.Value{}, fmt.Errorf( "unable to descend OID %+v into message %s: %s", targetOID, proto.MessageName(pbMsg), err, ) } } - - targetRepo, ok := msgV.Interface().(*gitalypb.Repository) - if !ok { - return nil, fmt.Errorf("repo target OID %v points to non-Repo type %+v", targetOID, msgV.Interface()) - } - - return targetRepo, nil + return msgV, nil } // matches a tag string like "bytes,1,opt,name=repository,proto3" diff --git a/internal/praefect/protoregistry/targetrepo_test.go b/internal/praefect/protoregistry/find_oid_test.go index 1982b5d64..6c6a837eb 100644 --- a/internal/praefect/protoregistry/targetrepo_test.go +++ b/internal/praefect/protoregistry/find_oid_test.go @@ -117,3 +117,51 @@ func TestProtoRegistryTargetRepo(t *testing.T) { }) } } + +func TestProtoRegistryStorage(t *testing.T) { + r := protoregistry.New() + require.NoError(t, r.RegisterFiles(protoregistry.GitalyProtoFileDescriptors...)) + + testcases := []struct { + desc string + svc string + method string + pbMsg proto.Message + expectStorage string + expectErr error + }{ + { + desc: "valid request type single depth", + svc: "NamespaceService", + method: "AddNamespace", + pbMsg: &gitalypb.AddNamespaceRequest{ + StorageName: "some_storage", + }, + expectStorage: "some_storage", + }, + { + desc: "incorrect request type", + svc: "RepositoryService", + method: "RepackIncremental", + pbMsg: &gitalypb.RepackIncrementalResponse{}, + expectErr: errors.New("proto message gitaly.RepackIncrementalResponse does not match expected RPC request message gitaly.RepackIncrementalRequest"), + }, + } + + for _, tc := range testcases { + desc := fmt.Sprintf("%s:%s %s", tc.svc, tc.method, tc.desc) + t.Run(desc, func(t *testing.T) { + info, err := r.LookupMethod(fmt.Sprintf("/gitaly.%s/%s", tc.svc, tc.method)) + require.NoError(t, err) + + actualStorage, actualErr := info.Storage(tc.pbMsg) + require.Equal(t, tc.expectErr, actualErr) + + // not only do we want the value to be the same, but we actually want the + // exact same instance to be returned + if tc.expectStorage != actualStorage { + t.Fatal("pointers do not match") + } + }) + } +} diff --git a/internal/praefect/protoregistry/protoregistry.go b/internal/praefect/protoregistry/protoregistry.go index d4cf34c69..3e3dc5da0 100644 --- a/internal/praefect/protoregistry/protoregistry.go +++ b/internal/praefect/protoregistry/protoregistry.go @@ -85,6 +85,7 @@ type MethodInfo struct { additionalRepo []int requestName string // protobuf message name for input type requestFactory protoFactory + storage []int } // TargetRepo returns the target repository for a protobuf message if it exists @@ -126,6 +127,18 @@ func (mi MethodInfo) getRepo(msg proto.Message, targetOid []int) (*gitalypb.Repo } } +// Storage returns the storage name for a protobuf message if it exists +func (mi MethodInfo) Storage(msg proto.Message) (string, error) { + if mi.requestName != proto.MessageName(msg) { + return "", fmt.Errorf( + "proto message %s does not match expected RPC request message %s", + proto.MessageName(msg), mi.requestName, + ) + } + + return reflectFindStorage(msg, mi.storage) +} + // UnmarshalRequestProto will unmarshal the bytes into the method's request // message type func (mi MethodInfo) UnmarshalRequestProto(b []byte) (proto.Message, error) { @@ -154,7 +167,7 @@ func (pr *Registry) RegisterFiles(protos ...*descriptor.FileDescriptorProto) err for _, p := range protos { for _, svc := range p.GetService() { for _, method := range svc.GetMethod() { - mi, err := parseMethodInfo(method) + mi, err := parseMethodInfo(p, method) if err != nil { return err } @@ -218,7 +231,7 @@ func methodReqFactory(method *descriptor.MethodDescriptorProto) (protoFactory, e return f, nil } -func parseMethodInfo(methodDesc *descriptor.MethodDescriptorProto) (MethodInfo, error) { +func parseMethodInfo(p *descriptor.FileDescriptorProto, methodDesc *descriptor.MethodDescriptorProto) (MethodInfo, error) { opMsg, err := getOpExtension(methodDesc) if err != nil { return MethodInfo{}, err @@ -270,11 +283,147 @@ func parseMethodInfo(methodDesc *descriptor.MethodDescriptorProto) (MethodInfo, return MethodInfo{}, err } } + } else if scope == ScopeStorage { + topLevelMsgs, err := getTopLevelMsgs(p) + if err != nil { + return MethodInfo{}, err + } + typeName, err := lastName(methodDesc.GetInputType()) + if err != nil { + return MethodInfo{}, err + } + storage, err := findStorageField(topLevelMsgs, topLevelMsgs[typeName]) + if err != nil { + return MethodInfo{}, err + } + if storage == nil { + return MethodInfo{}, fmt.Errorf("unable to find storage for method: %s", requestName) + } + mi.storage = storage } return mi, nil } +func getFileTypes(filename string) ([]*descriptor.DescriptorProto, error) { + sharedFD, err := ExtractFileDescriptor(proto.FileDescriptor(filename)) + if err != nil { + return nil, err + } + + types := sharedFD.GetMessageType() + + for _, dep := range sharedFD.Dependency { + depTypes, err := getFileTypes(dep) + if err != nil { + return nil, err + } + types = append(types, depTypes...) + } + + return types, nil +} + +func getTopLevelMsgs(p *descriptor.FileDescriptorProto) (map[string]*descriptor.DescriptorProto, error) { + topLevelMsgs := map[string]*descriptor.DescriptorProto{} + types, err := getFileTypes(p.GetName()) + if err != nil { + return nil, err + } + for _, msg := range types { + topLevelMsgs[msg.GetName()] = msg + } + return topLevelMsgs, nil +} + +func getStorageExtension(m *descriptor.FieldDescriptorProto) (bool, error) { + options := m.GetOptions() + + if !proto.HasExtension(options, gitalypb.E_Storage) { + return false, nil + } + + ext, err := proto.GetExtension(options, gitalypb.E_Storage) + if err != nil { + return false, err + } + + storageMsg, ok := ext.(*bool) + if !ok { + return false, fmt.Errorf("unable to obtain bool from %#v", ext) + } + + if storageMsg == nil { + return false, nil + } + + return *storageMsg, nil +} + +func findStorageField(topLevelMsgs map[string]*descriptor.DescriptorProto, t *descriptor.DescriptorProto) ([]int, error) { + for _, f := range t.GetField() { + storage, err := getStorageExtension(f) + if err != nil { + return nil, err + } + if storage { + return []int{int(f.GetNumber())}, nil + } + + childMsg, err := findChildMsg(topLevelMsgs, t, f) + if err != nil { + return nil, err + } + + if childMsg != nil { + nestedStorageField, err := findStorageField(topLevelMsgs, childMsg) + if err != nil { + return nil, err + } + if nestedStorageField != nil { + return append([]int{int(f.GetNumber())}, nestedStorageField...), nil + } + } + } + return nil, nil +} + +func findChildMsg(topLevelMsgs map[string]*descriptor.DescriptorProto, t *descriptor.DescriptorProto, f *descriptor.FieldDescriptorProto) (*descriptor.DescriptorProto, error) { + var childType *descriptor.DescriptorProto + const msgPrimitive = "TYPE_MESSAGE" + if primitive := f.GetType().String(); primitive != msgPrimitive { + return nil, nil + } + + msgName, err := lastName(f.GetTypeName()) + if err != nil { + return nil, err + } + + for _, nestedType := range t.GetNestedType() { + if msgName == nestedType.GetName() { + return nestedType, nil + } + } + + if childType = topLevelMsgs[msgName]; childType != nil { + return childType, nil + } + + return nil, fmt.Errorf("could not find message type %q", msgName) +} + +func lastName(inputType string) (string, error) { + tokens := strings.Split(inputType, ".") + + msgName := tokens[len(tokens)-1] + if msgName == "" { + return "", fmt.Errorf("unable to parse method input type: %s", inputType) + } + + return msgName, nil +} + // parses a string like "1.1" and returns a slice of ints func parseOID(rawFieldOID string) ([]int, error) { var fieldNos []int |