diff options
Diffstat (limited to 'proto/go/internal/linter/method.go')
-rw-r--r-- | proto/go/internal/linter/method.go | 28 |
1 files changed, 19 insertions, 9 deletions
diff --git a/proto/go/internal/linter/method.go b/proto/go/internal/linter/method.go index a93d4cd85..0c695e0cf 100644 --- a/proto/go/internal/linter/method.go +++ b/proto/go/internal/linter/method.go @@ -1,16 +1,18 @@ package linter import ( + "errors" "fmt" "strings" - "github.com/golang/protobuf/proto" "github.com/golang/protobuf/protoc-gen-go/descriptor" + plugin "github.com/golang/protobuf/protoc-gen-go/plugin" "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb" "gitlab.com/gitlab-org/gitaly/proto/go/internal" ) type methodLinter struct { + req *plugin.CodeGeneratorRequest fileDesc *descriptor.FileDescriptorProto methodDesc *descriptor.MethodDescriptorProto opMsg *gitalypb.OperationMsg @@ -138,7 +140,8 @@ func (ml methodLinter) ensureValidTargetRepository(expected int) error { func (ml methodLinter) getTopLevelMsgs() (map[string]*descriptor.DescriptorProto, error) { topLevelMsgs := map[string]*descriptor.DescriptorProto{} - types, err := getFileTypes(ml.fileDesc.GetName()) + + types, err := getFileTypes(ml.fileDesc.GetName(), ml.req) if err != nil { return nil, err } @@ -220,16 +223,23 @@ func findChildMsg(topLevelMsgs map[string]*descriptor.DescriptorProto, t *descri return nil, fmt.Errorf("could not find message type %q", msgName) } -func getFileTypes(filename string) ([]*descriptor.DescriptorProto, error) { - sharedFD, err := internal.ExtractFile(proto.FileDescriptor(filename)) - if err != nil { - return nil, err +func getFileTypes(filename string, req *plugin.CodeGeneratorRequest) ([]*descriptor.DescriptorProto, error) { + var types []*descriptor.DescriptorProto + var protoFile *descriptor.FileDescriptorProto + for _, pf := range req.ProtoFile { + if pf.Name != nil && *pf.Name == filename { + types = pf.GetMessageType() + protoFile = pf + break + } } - types := sharedFD.GetMessageType() + if protoFile == nil { + return nil, errors.New("proto file could not be found: " + filename) + } - for _, dep := range sharedFD.Dependency { - depTypes, err := getFileTypes(dep) + for _, dep := range protoFile.Dependency { + depTypes, err := getFileTypes(dep, req) if err != nil { return nil, err } |