diff options
author | Patrick Steinhardt <psteinhardt@gitlab.com> | 2023-08-11 14:11:22 +0300 |
---|---|---|
committer | Patrick Steinhardt <psteinhardt@gitlab.com> | 2023-08-14 14:14:35 +0300 |
commit | 7a9a9bdfb9321c5cb60804e63e08e2336d22ffd6 (patch) | |
tree | 3991ff9c2eb740a3809d0cb7e634bf8feb5feac3 | |
parent | da843b37d7177a7b1d11c9e4eb1396fbaf84ff1d (diff) |
testhelper: Introduce function to easily receive from streaming RPCs
The logic required to read a streaming RPC until we hit its end is
comparatively complex and not exactly pretty. Many of our tests get it
wrong, which leads to cases where we don't properly check for error
conditions.
Introduce two new test helpers `Receive()` and `ReceiveAndFold()` that
allow the caller to receive all results as well as any potential error
code and optionally fold the results in order to convert them. Convert
our tests to use these helpers.
12 files changed, 83 insertions, 195 deletions
diff --git a/internal/gitaly/service/commit/commit_signatures_test.go b/internal/gitaly/service/commit/commit_signatures_test.go index b9d32f8a1..1a5e6671a 100644 --- a/internal/gitaly/service/commit/commit_signatures_test.go +++ b/internal/gitaly/service/commit/commit_signatures_test.go @@ -3,9 +3,7 @@ package commit import ( "bytes" "context" - "errors" "fmt" - "io" "strings" "testing" @@ -352,31 +350,24 @@ aC1lZDI1NTE5AAAAQKgC1TFLVZOqvVs2AqCp2lhkRAUtZsDa89RgHOOsYAC3T1kB stream, err := client.GetCommitSignatures(ctx, setup.request) require.NoError(t, err) - var actualResponses []*gitalypb.GetCommitSignaturesResponse - for { - var response *gitalypb.GetCommitSignaturesResponse - response, err = stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - err = nil - } - - break - } - + actualResponses, err := testhelper.ReceiveAndFold(stream.Recv, func( + result []*gitalypb.GetCommitSignaturesResponse, + response *gitalypb.GetCommitSignaturesResponse, + ) []*gitalypb.GetCommitSignaturesResponse { // We don't need to do any fiddling when we have a commit ID, which would signify // another returned commit. if response.CommitId != "" { - actualResponses = append(actualResponses, response) - continue + return append(result, response) } // But when we don't have a commit ID we append both the signature and signed text so // that it becomes easier to test for these values, as they might otherwise be split. - currentResponse := actualResponses[len(actualResponses)-1] + currentResponse := result[len(result)-1] currentResponse.Signature = append(currentResponse.Signature, response.Signature...) currentResponse.SignedText = append(currentResponse.SignedText, response.SignedText...) - } + + return result + }) testhelper.RequireGrpcError(t, setup.expectedErr, err) require.Len(t, actualResponses, len(setup.expectedResponses)) diff --git a/internal/gitaly/service/commit/filter_shas_with_signatures_test.go b/internal/gitaly/service/commit/filter_shas_with_signatures_test.go index 8d6cf9990..f198fd344 100644 --- a/internal/gitaly/service/commit/filter_shas_with_signatures_test.go +++ b/internal/gitaly/service/commit/filter_shas_with_signatures_test.go @@ -2,9 +2,7 @@ package commit import ( "bytes" - "errors" "fmt" - "io" "strings" "testing" @@ -150,20 +148,7 @@ gpgsig -----BEGIN PGP SIGNATURE----- require.NoError(t, stream.Send(tc.request)) require.NoError(t, stream.CloseSend()) - var responses []*gitalypb.FilterShasWithSignaturesResponse - for { - var response *gitalypb.FilterShasWithSignaturesResponse - response, err = stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - err = nil - } - - break - } - - responses = append(responses, response) - } + responses, err := testhelper.Receive(stream.Recv) testhelper.RequireGrpcError(t, tc.expectedErr, err) testhelper.ProtoEqual(t, tc.expectedResponses, responses) }) diff --git a/internal/gitaly/service/commit/find_all_commits_test.go b/internal/gitaly/service/commit/find_all_commits_test.go index d302b684f..350c2626e 100644 --- a/internal/gitaly/service/commit/find_all_commits_test.go +++ b/internal/gitaly/service/commit/find_all_commits_test.go @@ -1,8 +1,6 @@ package commit import ( - "errors" - "io" "testing" "time" @@ -161,21 +159,9 @@ func TestFindAllCommits(t *testing.T) { stream, err := client.FindAllCommits(ctx, tc.request) require.NoError(t, err) - var actualCommits []*gitalypb.GitCommit - for { - var response *gitalypb.FindAllCommitsResponse - response, err = stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - err = nil - } - - break - } - - actualCommits = append(actualCommits, response.GetCommits()...) - } - + actualCommits, err := testhelper.ReceiveAndFold(stream.Recv, func(result []*gitalypb.GitCommit, response *gitalypb.FindAllCommitsResponse) []*gitalypb.GitCommit { + return append(result, response.GetCommits()...) + }) testhelper.RequireGrpcError(t, tc.expectedErr, err) testhelper.ProtoEqual(t, tc.expectedCommits, actualCommits) }) diff --git a/internal/gitaly/service/commit/list_all_commits_test.go b/internal/gitaly/service/commit/list_all_commits_test.go index 1c5139f08..11f4b26ac 100644 --- a/internal/gitaly/service/commit/list_all_commits_test.go +++ b/internal/gitaly/service/commit/list_all_commits_test.go @@ -1,7 +1,6 @@ package commit import ( - "errors" "io" "os" "path/filepath" @@ -164,20 +163,9 @@ func TestListAllCommits(t *testing.T) { stream, err := client.ListAllCommits(ctx, setup.request) require.NoError(t, err) - var actualCommits []*gitalypb.GitCommit - for { - var response *gitalypb.ListAllCommitsResponse - response, err = stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - err = nil - } - - break - } - - actualCommits = append(actualCommits, response.Commits...) - } + actualCommits, err := testhelper.ReceiveAndFold(stream.Recv, func(result []*gitalypb.GitCommit, response *gitalypb.ListAllCommitsResponse) []*gitalypb.GitCommit { + return append(result, response.Commits...) + }) testhelper.RequireGrpcError(t, setup.expectedErr, err) if setup.skipCommitValidation { diff --git a/internal/gitaly/service/commit/list_commits_test.go b/internal/gitaly/service/commit/list_commits_test.go index 284be5fc2..b69a69958 100644 --- a/internal/gitaly/service/commit/list_commits_test.go +++ b/internal/gitaly/service/commit/list_commits_test.go @@ -1,8 +1,6 @@ package commit import ( - "errors" - "io" "testing" "time" @@ -263,20 +261,10 @@ func TestListCommits(t *testing.T) { stream, err := client.ListCommits(ctx, tc.request) require.NoError(t, err) - var commits []*gitalypb.GitCommit - for { - response, err := stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - break - } - - require.Equal(t, tc.expectedErr, err) - } - - commits = append(commits, response.Commits...) - } - + commits, err := testhelper.ReceiveAndFold(stream.Recv, func(result []*gitalypb.GitCommit, response *gitalypb.ListCommitsResponse) []*gitalypb.GitCommit { + return append(result, response.GetCommits()...) + }) + require.NoError(t, err) testhelper.ProtoEqual(t, tc.expectedCommits, commits) }) } diff --git a/internal/gitaly/service/commit/list_files_test.go b/internal/gitaly/service/commit/list_files_test.go index 95c066e31..4204e1b2a 100644 --- a/internal/gitaly/service/commit/list_files_test.go +++ b/internal/gitaly/service/commit/list_files_test.go @@ -1,8 +1,6 @@ package commit import ( - "errors" - "io" "testing" "github.com/stretchr/testify/require" @@ -174,22 +172,11 @@ func TestListFiles(t *testing.T) { stream, err := client.ListFiles(ctx, tc.request) require.NoError(t, err) - var files [][]byte - for { - resp, err := stream.Recv() - if err != nil { - if !errors.Is(err, io.EOF) { - testhelper.RequireGrpcError(t, tc.expectedErr, err) - } - - break - } - require.NoError(t, err) - - files = append(files, resp.GetPaths()...) - } - - require.ElementsMatch(t, files, tc.expectedPaths) + paths, err := testhelper.ReceiveAndFold(stream.Recv, func(result [][]byte, response *gitalypb.ListFilesResponse) [][]byte { + return append(result, response.GetPaths()...) + }) + testhelper.RequireGrpcError(t, tc.expectedErr, err) + require.ElementsMatch(t, paths, tc.expectedPaths) }) } } diff --git a/internal/gitaly/service/commit/list_last_commits_for_tree_test.go b/internal/gitaly/service/commit/list_last_commits_for_tree_test.go index 6ba36cc60..e1f4544f2 100644 --- a/internal/gitaly/service/commit/list_last_commits_for_tree_test.go +++ b/internal/gitaly/service/commit/list_last_commits_for_tree_test.go @@ -1,8 +1,6 @@ package commit import ( - "errors" - "io" "testing" "unicode/utf8" @@ -339,21 +337,12 @@ func TestListLastCommitsForTree(t *testing.T) { stream, err := client.ListLastCommitsForTree(ctx, setup.request) require.NoError(t, err) - var commits []*gitalypb.ListLastCommitsForTreeResponse_CommitForTree - for { - var response *gitalypb.ListLastCommitsForTreeResponse - - response, err = stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - err = nil - } - break - } - - commits = append(commits, response.Commits...) - } - + commits, err := testhelper.ReceiveAndFold(stream.Recv, func( + result []*gitalypb.ListLastCommitsForTreeResponse_CommitForTree, + response *gitalypb.ListLastCommitsForTreeResponse, + ) []*gitalypb.ListLastCommitsForTreeResponse_CommitForTree { + return append(result, response.Commits...) + }) testhelper.RequireGrpcError(t, setup.expectedErr, err) testhelper.ProtoEqual(t, setup.expectedCommits, commits) }) diff --git a/internal/gitaly/service/commit/tree_entries_test.go b/internal/gitaly/service/commit/tree_entries_test.go index 03a2c989e..82c88c283 100644 --- a/internal/gitaly/service/commit/tree_entries_test.go +++ b/internal/gitaly/service/commit/tree_entries_test.go @@ -2,7 +2,6 @@ package commit import ( "context" - "errors" "fmt" "io" "strconv" @@ -1485,18 +1484,10 @@ func BenchmarkGetTreeEntries(b *testing.B) { stream, err := client.GetTreeEntries(ctx, tc.request) require.NoError(b, err) - entriesReceived := 0 - for { - response, err := stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - break - } - require.NoError(b, err) - } - - entriesReceived += len(response.Entries) - } + entriesReceived, err := testhelper.ReceiveAndFold(stream.Recv, func(result int, response *gitalypb.GetTreeEntriesResponse) int { + return result + len(response.Entries) + }) + require.NoError(b, err) require.Equal(b, tc.expectedEntries, entriesReceived) } }) diff --git a/internal/gitaly/service/commit/tree_entry_test.go b/internal/gitaly/service/commit/tree_entry_test.go index 357a1e0c1..9823fa1c3 100644 --- a/internal/gitaly/service/commit/tree_entry_test.go +++ b/internal/gitaly/service/commit/tree_entry_test.go @@ -1,8 +1,6 @@ package commit import ( - "errors" - "io" "testing" "github.com/stretchr/testify/require" @@ -317,22 +315,7 @@ func TestTreeEntry(t *testing.T) { stream, err := client.TreeEntry(ctx, tc.request) require.NoError(t, err) - var responses []*gitalypb.TreeEntryResponse - for { - var response *gitalypb.TreeEntryResponse - - response, err = stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - err = nil - } - - break - } - - responses = append(responses, response) - } - + responses, err := testhelper.Receive(stream.Recv) testhelper.RequireGrpcError(t, tc.expectedErr, err) testhelper.ProtoEqual(t, tc.expectedResponses, responses) }) diff --git a/internal/gitaly/service/ref/refnames_containing_test.go b/internal/gitaly/service/ref/refnames_containing_test.go index d84a825ad..b4e4cc053 100644 --- a/internal/gitaly/service/ref/refnames_containing_test.go +++ b/internal/gitaly/service/ref/refnames_containing_test.go @@ -1,9 +1,7 @@ package ref import ( - "errors" "fmt" - "io" "testing" "github.com/stretchr/testify/require" @@ -107,23 +105,12 @@ func TestListTagNamesContainingCommit(t *testing.T) { stream, err := client.ListTagNamesContainingCommit(ctx, tc.request) require.NoError(t, err) - var tagNames []string - for { - var response *gitalypb.ListTagNamesContainingCommitResponse - response, err = stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - err = nil - } - - break - } - + tagNames, err := testhelper.ReceiveAndFold(stream.Recv, func(result []string, response *gitalypb.ListTagNamesContainingCommitResponse) []string { for _, tagName := range response.GetTagNames() { - tagNames = append(tagNames, string(tagName)) + result = append(result, string(tagName)) } - } - + return result + }) testhelper.RequireGrpcError(t, tc.expectedErr, err) require.ElementsMatch(t, tc.expectedTags, tagNames) }) @@ -225,23 +212,12 @@ func TestListBranchNamesContainingCommit(t *testing.T) { stream, err := client.ListBranchNamesContainingCommit(ctx, tc.request) require.NoError(t, err) - var branchNames []string - for { - var response *gitalypb.ListBranchNamesContainingCommitResponse - response, err = stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - err = nil - } - - break - } - + branchNames, err := testhelper.ReceiveAndFold(stream.Recv, func(result []string, response *gitalypb.ListBranchNamesContainingCommitResponse) []string { for _, branchName := range response.GetBranchNames() { - branchNames = append(branchNames, string(branchName)) + result = append(result, string(branchName)) } - } - + return result + }) testhelper.RequireGrpcError(t, tc.expectedErr, err) require.ElementsMatch(t, tc.expectedBranches, branchNames) }) diff --git a/internal/gitaly/service/ref/tag_signatures_test.go b/internal/gitaly/service/ref/tag_signatures_test.go index a4a314586..6605165dc 100644 --- a/internal/gitaly/service/ref/tag_signatures_test.go +++ b/internal/gitaly/service/ref/tag_signatures_test.go @@ -3,9 +3,7 @@ package ref import ( - "errors" "fmt" - "io" "strings" "testing" @@ -174,19 +172,13 @@ func TestGetTagSignatures(t *testing.T) { }) require.NoError(t, err) - var signatures []*gitalypb.GetTagSignaturesResponse_TagSignature - for { - resp, err := stream.Recv() - if err != nil { - if !errors.Is(err, io.EOF) { - testhelper.RequireGrpcError(t, tc.expectedErr, err) - } - break - } - - signatures = append(signatures, resp.Signatures...) - } - + signatures, err := testhelper.ReceiveAndFold(stream.Recv, func( + result []*gitalypb.GetTagSignaturesResponse_TagSignature, + response *gitalypb.GetTagSignaturesResponse, + ) []*gitalypb.GetTagSignaturesResponse_TagSignature { + return append(result, response.GetSignatures()...) + }) + testhelper.RequireGrpcError(t, tc.expectedErr, err) testhelper.ProtoEqual(t, tc.expectedSignatures, signatures) }) } diff --git a/internal/testhelper/grpc.go b/internal/testhelper/grpc.go index f60110566..89a462486 100644 --- a/internal/testhelper/grpc.go +++ b/internal/testhelper/grpc.go @@ -2,7 +2,9 @@ package testhelper import ( "context" + "errors" "fmt" + "io" "testing" "github.com/google/go-cmp/cmp" @@ -163,3 +165,33 @@ func WithInterceptedMetadataItems(err structerr.Error, items ...structerr.Metada func ToInterceptedMetadata(err structerr.Error) structerr.Error { return WithInterceptedMetadataItems(err, err.MetadataItems()...) } + +// Receive receives all responses from the receiver until an error is encountered or `io.EOF` is received. +func Receive[Response any](receiver func() (Response, error)) ([]Response, error) { + var responses []Response + for { + response, err := receiver() + if err != nil { + if errors.Is(err, io.EOF) { + err = nil + } + + return responses, err + } + + responses = append(responses, response) + } +} + +// ReceiveAndFold receives all responses from the receiver and then folds the results with the given folder. The folder +// will be called even in the case where the receiver returns an error. +func ReceiveAndFold[Response, Result any](receiver func() (Response, error), folder func(initial Result, response Response) Result) (Result, error) { + responses, err := Receive(receiver) + + var result Result + for _, response := range responses { + result = folder(result, response) + } + + return result, err +} |