diff options
Diffstat (limited to 'internal/service/commit/tree_entry.go')
-rw-r--r-- | internal/service/commit/tree_entry.go | 120 |
1 files changed, 104 insertions, 16 deletions
diff --git a/internal/service/commit/tree_entry.go b/internal/service/commit/tree_entry.go index 8498b8436..074cf92fa 100644 --- a/internal/service/commit/tree_entry.go +++ b/internal/service/commit/tree_entry.go @@ -1,6 +1,7 @@ package commit import ( + "errors" "fmt" "io" "strings" @@ -14,14 +15,59 @@ import ( "google.golang.org/grpc/status" ) -func sendTreeEntry(stream gitalypb.CommitService_TreeEntryServer, c *catfile.Batch, revision, path string, limit int64) error { +func findAndSendTreeEntry(stream gitalypb.CommitService_TreeEntryServer, c *catfile.Batch, revision, path string, limit int64) error { treeEntry, err := NewTreeEntryFinder(c).FindByRevisionAndPath(revision, path) if err != nil { return err } + response, blobReader, dataLength, err := getTreeEntry(c, treeEntry, limit) + if err != nil { + return err + } + + if blobReader == nil { + return helper.DecorateError(codes.Unavailable, stream.Send(response)) + } + + sw := findTreeEntryStreamWriter(stream, response, blobReader) + + if _, err = io.CopyN(sw, blobReader, dataLength); err != nil { + return err + } + + return nil +} + +func findAndSendTreeEntries(stream gitalypb.CommitService_FindTreeEntriesServer, c *catfile.Batch, revision string, paths [][]byte, limit int64) error { + for _, path := range paths { + treeEntry, err := NewTreeEntryFinder(c).FindByRevisionAndPath(revision, string(path)) + if err != nil { + return err + } + + response, blobReader, dataLength, err := getTreeEntry(c, treeEntry, limit) + if err != nil { + return err + } + + if blobReader == nil { + helper.DecorateError(codes.Unavailable, stream.Send(response)) + continue + } + + sw := findTreeEntriesStreamWriter(stream, response, blobReader) + + if _, err = io.CopyN(sw, blobReader, dataLength); err != nil { + return err + } + } + return nil +} + +func getTreeEntry(c *catfile.Batch, treeEntry *gitalypb.TreeEntry, limit int64) (*gitalypb.TreeEntryResponse, io.Reader, int64, error) { if treeEntry == nil || len(treeEntry.Oid) == 0 { - return helper.DecorateError(codes.Unavailable, stream.Send(&gitalypb.TreeEntryResponse{})) + return nil, nil, 0, helper.DecorateError(codes.Unavailable, errors.New("tree entry not found")) } if treeEntry.Type == gitalypb.TreeEntry_COMMIT { @@ -30,17 +76,14 @@ func sendTreeEntry(stream gitalypb.CommitService_TreeEntryServer, c *catfile.Bat Mode: treeEntry.Mode, Oid: treeEntry.Oid, } - if err := stream.Send(response); err != nil { - return status.Errorf(codes.Unavailable, "TreeEntry: send: %v", err) - } - return nil + return response, nil, 0, nil } if treeEntry.Type == gitalypb.TreeEntry_TREE { treeInfo, err := c.Info(treeEntry.Oid) if err != nil { - return err + return nil, nil, 0, err } response := &gitalypb.TreeEntryResponse{ @@ -49,16 +92,16 @@ func sendTreeEntry(stream gitalypb.CommitService_TreeEntryServer, c *catfile.Bat Size: treeInfo.Size, Mode: treeEntry.Mode, } - return helper.DecorateError(codes.Unavailable, stream.Send(response)) + return response, nil, 0, nil } objectInfo, err := c.Info(treeEntry.Oid) if err != nil { - return status.Errorf(codes.Internal, "TreeEntry: %v", err) + return nil, nil, 0, status.Errorf(codes.Internal, "TreeEntry: %v", err) } if strings.ToLower(treeEntry.Type.String()) != objectInfo.Type { - return status.Errorf( + return nil, nil, 0, status.Errorf( codes.Internal, "TreeEntry: mismatched object type: tree-oid=%s object-oid=%s entry-type=%s object-type=%s", treeEntry.Oid, objectInfo.Oid, treeEntry.Type.String(), objectInfo.Type, @@ -77,15 +120,19 @@ func sendTreeEntry(stream gitalypb.CommitService_TreeEntryServer, c *catfile.Bat Mode: treeEntry.Mode, } if dataLength == 0 { - return helper.DecorateError(codes.Unavailable, stream.Send(response)) + return response, nil, 0, nil } blobReader, err := c.Blob(objectInfo.Oid) if err != nil { - return err + return nil, nil, 0, err } - sw := streamio.NewWriter(func(p []byte) error { + return response, blobReader, dataLength, nil +} + +func findTreeEntryStreamWriter(stream gitalypb.CommitService_FindTreeEntriesServer, response *gitalypb.TreeEntryResponse, blobReader io.Reader) io.Writer { + return streamio.NewWriter(func(p []byte) error { response.Data = p if err := stream.Send(response); err != nil { @@ -97,9 +144,21 @@ func sendTreeEntry(stream gitalypb.CommitService_TreeEntryServer, c *catfile.Bat return nil }) +} + +func findTreeEntriesStreamWriter(stream gitalypb.CommitService_FindTreeEntriesServer, response *gitalypb.TreeEntryResponse, blobReader io.Reader) io.Writer { + return streamio.NewWriter(func(p []byte) error { + response.Data = p + + if err := stream.Send(response); err != nil { + return status.Errorf(codes.Unavailable, "TreeEntry: send: %v", err) + } - _, err = io.CopyN(sw, blobReader, dataLength) - return err + // Use a new response so we don't send other fields (Size, ...) over and over + response = &gitalypb.TreeEntryResponse{} + + return nil + }) } func (s *server) TreeEntry(in *gitalypb.TreeEntryRequest, stream gitalypb.CommitService_TreeEntryServer) error { @@ -119,7 +178,7 @@ func (s *server) TreeEntry(in *gitalypb.TreeEntryRequest, stream gitalypb.Commit return err } - return sendTreeEntry(stream, c, string(in.GetRevision()), requestPath, in.GetLimit()) + return findAndSendTreeEntry(stream, c, string(in.GetRevision()), requestPath, in.GetLimit()) } func validateRequest(in *gitalypb.TreeEntryRequest) error { @@ -133,3 +192,32 @@ func validateRequest(in *gitalypb.TreeEntryRequest) error { return nil } + +func (s *server) FindTreeEntries(in *gitalypb.FindTreeEntriesRequest, stream gitalypb.CommitService_FindTreeEntriesServer) error { + if err := validateFindTreeEntriesRequest(in); err != nil { + return helper.ErrInvalidArgument(err) + } + + c, err := catfile.New(stream.Context(), in.GetRepository()) + if err != nil { + return helper.ErrInternal(err) + } + + if err := findAndSendTreeEntries(stream, c, string(in.GetRevision()), in.GetPaths(), in.GetLimit()); err != nil { + return helper.ErrInternal(err) + } + + return nil +} + +func validateFindTreeEntriesRequest(in *gitalypb.FindTreeEntriesRequest) error { + if err := git.ValidateRevision(in.GetRevision()); err != nil { + return err + } + + if len(in.GetPaths()) == 0 { + return fmt.Errorf("empty Path") + } + + return nil +} |