diff options
Diffstat (limited to 'vendor/google.golang.org/grpc/reflection/serverreflection.go')
-rw-r--r-- | vendor/google.golang.org/grpc/reflection/serverreflection.go | 223 |
1 files changed, 84 insertions, 139 deletions
diff --git a/vendor/google.golang.org/grpc/reflection/serverreflection.go b/vendor/google.golang.org/grpc/reflection/serverreflection.go index dd22a2da7..1bfbf3e78 100644 --- a/vendor/google.golang.org/grpc/reflection/serverreflection.go +++ b/vendor/google.golang.org/grpc/reflection/serverreflection.go @@ -45,8 +45,7 @@ import ( "io" "io/ioutil" "reflect" - "sort" - "sync" + "strings" "github.com/golang/protobuf/proto" dpb "github.com/golang/protobuf/protoc-gen-go/descriptor" @@ -58,10 +57,8 @@ import ( type serverReflectionServer struct { s *grpc.Server - - initSymbols sync.Once - serviceNames []string - symbols map[string]*dpb.FileDescriptorProto // map of fully-qualified names to files + // TODO add more cache if necessary + serviceInfo map[string]grpc.ServiceInfo // cache for s.GetServiceInfo() } // Register registers the server reflection service on the given gRPC server. @@ -79,112 +76,6 @@ type protoMessage interface { Descriptor() ([]byte, []int) } -func (s *serverReflectionServer) getSymbols() (svcNames []string, symbolIndex map[string]*dpb.FileDescriptorProto) { - s.initSymbols.Do(func() { - serviceInfo := s.s.GetServiceInfo() - - s.symbols = map[string]*dpb.FileDescriptorProto{} - s.serviceNames = make([]string, 0, len(serviceInfo)) - processed := map[string]struct{}{} - for svc, info := range serviceInfo { - s.serviceNames = append(s.serviceNames, svc) - fdenc, ok := parseMetadata(info.Metadata) - if !ok { - continue - } - fd, err := decodeFileDesc(fdenc) - if err != nil { - continue - } - s.processFile(fd, processed) - } - sort.Strings(s.serviceNames) - }) - - return s.serviceNames, s.symbols -} - -func (s *serverReflectionServer) processFile(fd *dpb.FileDescriptorProto, processed map[string]struct{}) { - filename := fd.GetName() - if _, ok := processed[filename]; ok { - return - } - processed[filename] = struct{}{} - - prefix := fd.GetPackage() - - for _, msg := range fd.MessageType { - s.processMessage(fd, prefix, msg) - } - for _, en := range fd.EnumType { - s.processEnum(fd, prefix, en) - } - for _, ext := range fd.Extension { - s.processField(fd, prefix, ext) - } - for _, svc := range fd.Service { - svcName := fqn(prefix, svc.GetName()) - s.symbols[svcName] = fd - for _, meth := range svc.Method { - name := fqn(svcName, meth.GetName()) - s.symbols[name] = fd - } - } - - for _, dep := range fd.Dependency { - fdenc := proto.FileDescriptor(dep) - fdDep, err := decodeFileDesc(fdenc) - if err != nil { - continue - } - s.processFile(fdDep, processed) - } -} - -func (s *serverReflectionServer) processMessage(fd *dpb.FileDescriptorProto, prefix string, msg *dpb.DescriptorProto) { - msgName := fqn(prefix, msg.GetName()) - s.symbols[msgName] = fd - - for _, nested := range msg.NestedType { - s.processMessage(fd, msgName, nested) - } - for _, en := range msg.EnumType { - s.processEnum(fd, msgName, en) - } - for _, ext := range msg.Extension { - s.processField(fd, msgName, ext) - } - for _, fld := range msg.Field { - s.processField(fd, msgName, fld) - } - for _, oneof := range msg.OneofDecl { - oneofName := fqn(msgName, oneof.GetName()) - s.symbols[oneofName] = fd - } -} - -func (s *serverReflectionServer) processEnum(fd *dpb.FileDescriptorProto, prefix string, en *dpb.EnumDescriptorProto) { - enName := fqn(prefix, en.GetName()) - s.symbols[enName] = fd - - for _, val := range en.Value { - valName := fqn(enName, val.GetName()) - s.symbols[valName] = fd - } -} - -func (s *serverReflectionServer) processField(fd *dpb.FileDescriptorProto, prefix string, fld *dpb.FieldDescriptorProto) { - fldName := fqn(prefix, fld.GetName()) - s.symbols[fldName] = fd -} - -func fqn(prefix, name string) string { - if prefix == "" { - return name - } - return prefix + "." + name -} - // fileDescForType gets the file descriptor for the given type. // The given type should be a proto message. func (s *serverReflectionServer) fileDescForType(st reflect.Type) (*dpb.FileDescriptorProto, error) { @@ -194,12 +85,12 @@ func (s *serverReflectionServer) fileDescForType(st reflect.Type) (*dpb.FileDesc } enc, _ := m.Descriptor() - return decodeFileDesc(enc) + return s.decodeFileDesc(enc) } // decodeFileDesc does decompression and unmarshalling on the given // file descriptor byte slice. -func decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) { +func (s *serverReflectionServer) decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) { raw, err := decompress(enc) if err != nil { return nil, fmt.Errorf("failed to decompress enc: %v", err) @@ -225,7 +116,7 @@ func decompress(b []byte) ([]byte, error) { return out, nil } -func typeForName(name string) (reflect.Type, error) { +func (s *serverReflectionServer) typeForName(name string) (reflect.Type, error) { pt := proto.MessageType(name) if pt == nil { return nil, fmt.Errorf("unknown type: %q", name) @@ -235,7 +126,7 @@ func typeForName(name string) (reflect.Type, error) { return st, nil } -func fileDescContainingExtension(st reflect.Type, ext int32) (*dpb.FileDescriptorProto, error) { +func (s *serverReflectionServer) fileDescContainingExtension(st reflect.Type, ext int32) (*dpb.FileDescriptorProto, error) { m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message) if !ok { return nil, fmt.Errorf("failed to create message from type: %v", st) @@ -253,7 +144,7 @@ func fileDescContainingExtension(st reflect.Type, ext int32) (*dpb.FileDescripto return nil, fmt.Errorf("failed to find registered extension for extension number %v", ext) } - return decodeFileDesc(proto.FileDescriptor(extDesc.Filename)) + return s.decodeFileDesc(proto.FileDescriptor(extDesc.Filename)) } func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]int32, error) { @@ -277,13 +168,53 @@ func (s *serverReflectionServer) fileDescEncodingByFilename(name string) ([]byte if enc == nil { return nil, fmt.Errorf("unknown file: %v", name) } - fd, err := decodeFileDesc(enc) + fd, err := s.decodeFileDesc(enc) if err != nil { return nil, err } return proto.Marshal(fd) } +// serviceMetadataForSymbol finds the metadata for name in s.serviceInfo. +// name should be a service name or a method name. +func (s *serverReflectionServer) serviceMetadataForSymbol(name string) (interface{}, error) { + if s.serviceInfo == nil { + s.serviceInfo = s.s.GetServiceInfo() + } + + // Check if it's a service name. + if info, ok := s.serviceInfo[name]; ok { + return info.Metadata, nil + } + + // Check if it's a method name. + pos := strings.LastIndex(name, ".") + // Not a valid method name. + if pos == -1 { + return nil, fmt.Errorf("unknown symbol: %v", name) + } + + info, ok := s.serviceInfo[name[:pos]] + // Substring before last "." is not a service name. + if !ok { + return nil, fmt.Errorf("unknown symbol: %v", name) + } + + // Search the method name in info.Methods. + var found bool + for _, m := range info.Methods { + if m.Name == name[pos+1:] { + found = true + break + } + } + if found { + return info.Metadata, nil + } + + return nil, fmt.Errorf("unknown symbol: %v", name) +} + // parseMetadata finds the file descriptor bytes specified meta. // For SupportPackageIsVersion4, m is the name of the proto file, we // call proto.FileDescriptor to get the byte slice. @@ -306,21 +237,33 @@ func parseMetadata(meta interface{}) ([]byte, bool) { // does marshalling on it and returns the marshalled result. // The given symbol can be a type, a service or a method. func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) ([]byte, error) { - _, symbols := s.getSymbols() - fd := symbols[name] - if fd == nil { - // Check if it's a type name that was not present in the - // transitive dependencies of the registered services. - if st, err := typeForName(name); err == nil { - fd, err = s.fileDescForType(st) - if err != nil { - return nil, err - } + var ( + fd *dpb.FileDescriptorProto + ) + // Check if it's a type name. + if st, err := s.typeForName(name); err == nil { + fd, err = s.fileDescForType(st) + if err != nil { + return nil, err } - } + } else { // Check if it's a service name or a method name. + meta, err := s.serviceMetadataForSymbol(name) - if fd == nil { - return nil, fmt.Errorf("unknown symbol: %v", name) + // Metadata not found. + if err != nil { + return nil, err + } + + // Metadata not valid. + enc, ok := parseMetadata(meta) + if !ok { + return nil, fmt.Errorf("invalid file descriptor for symbol: %v", name) + } + + fd, err = s.decodeFileDesc(enc) + if err != nil { + return nil, err + } } return proto.Marshal(fd) @@ -329,11 +272,11 @@ func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) ( // fileDescEncodingContainingExtension finds the file descriptor containing given extension, // does marshalling on it and returns the marshalled result. func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32) ([]byte, error) { - st, err := typeForName(typeName) + st, err := s.typeForName(typeName) if err != nil { return nil, err } - fd, err := fileDescContainingExtension(st, extNum) + fd, err := s.fileDescContainingExtension(st, extNum) if err != nil { return nil, err } @@ -342,7 +285,7 @@ func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName st // allExtensionNumbersForTypeName returns all extension numbers for the given type. func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]int32, error) { - st, err := typeForName(name) + st, err := s.typeForName(name) if err != nil { return nil, err } @@ -431,12 +374,14 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio } } case *rpb.ServerReflectionRequest_ListServices: - svcNames, _ := s.getSymbols() - serviceResponses := make([]*rpb.ServiceResponse, len(svcNames)) - for i, n := range svcNames { - serviceResponses[i] = &rpb.ServiceResponse{ + if s.serviceInfo == nil { + s.serviceInfo = s.s.GetServiceInfo() + } + serviceResponses := make([]*rpb.ServiceResponse, 0, len(s.serviceInfo)) + for n := range s.serviceInfo { + serviceResponses = append(serviceResponses, &rpb.ServiceResponse{ Name: n, - } + }) } out.MessageResponse = &rpb.ServerReflectionResponse_ListServicesResponse{ ListServicesResponse: &rpb.ListServiceResponse{ |