Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'streamio/stream_test.go')
-rw-r--r--streamio/stream_test.go88
1 files changed, 85 insertions, 3 deletions
diff --git a/streamio/stream_test.go b/streamio/stream_test.go
index 456d17e80..d263a68ab 100644
--- a/streamio/stream_test.go
+++ b/streamio/stream_test.go
@@ -25,7 +25,7 @@ func TestReceiveSources(t *testing.T) {
}
for _, tc := range testCases {
- data, err := ioutil.ReadAll(NewReader(receiverFromReader(tc.r)))
+ data, err := ioutil.ReadAll(&opaqueReader{NewReader(receiverFromReader(tc.r))})
require.NoError(t, err, tc.desc)
require.Equal(t, testData, string(data), tc.desc)
}
@@ -35,15 +35,38 @@ func TestReadSizes(t *testing.T) {
testData := "Hello this is the test data that will be received. It goes on for a while bla bla bla."
for n := 1; n < 100; n *= 3 {
desc := fmt.Sprintf("reads of size %d", n)
- buffer := make([]byte, n)
result := &bytes.Buffer{}
reader := &opaqueReader{NewReader(receiverFromReader(strings.NewReader(testData)))}
- _, err := io.CopyBuffer(&opaqueWriter{result}, reader, buffer)
+ _, err := io.CopyBuffer(&opaqueWriter{result}, reader, make([]byte, n))
+
require.NoError(t, err, desc)
require.Equal(t, testData, result.String())
}
}
+func TestWriterTo(t *testing.T) {
+ testData := "Hello this is the test data that will be received. It goes on for a while bla bla bla."
+ testCases := []struct {
+ desc string
+ r io.Reader
+ }{
+ {desc: "base", r: strings.NewReader(testData)},
+ {desc: "dataerr", r: iotest.DataErrReader(strings.NewReader(testData))},
+ {desc: "onebyte", r: iotest.OneByteReader(strings.NewReader(testData))},
+ {desc: "dataerr(onebyte)", r: iotest.DataErrReader(iotest.OneByteReader(strings.NewReader(testData)))},
+ }
+
+ for _, tc := range testCases {
+ result := &bytes.Buffer{}
+ reader := NewReader(receiverFromReader(tc.r))
+ n, err := reader.(io.WriterTo).WriteTo(result)
+
+ require.NoError(t, err, tc.desc)
+ require.Equal(t, int64(len(testData)), n, tc.desc)
+ require.Equal(t, testData, result.String(), tc.desc)
+ }
+}
+
func receiverFromReader(r io.Reader) func() ([]byte, error) {
return func() ([]byte, error) {
data := make([]byte, 10)
@@ -61,3 +84,62 @@ type opaqueReader struct {
type opaqueWriter struct {
io.Writer
}
+
+func TestWriterChunking(t *testing.T) {
+ defer func(oldBufferSize int) {
+ writeBufferSize = oldBufferSize
+ }(writeBufferSize)
+ writeBufferSize = 5
+
+ testData := "Hello this is some test data"
+ ts := &testSender{}
+ w := NewWriter(ts.send)
+ _, err := io.CopyBuffer(&opaqueWriter{w}, strings.NewReader(testData), make([]byte, 10))
+
+ require.NoError(t, err)
+ require.Equal(t, testData, string(bytes.Join(ts.sends, nil)))
+ for _, send := range ts.sends {
+ require.True(t, len(send) <= writeBufferSize, "send calls may not exceed writeBufferSize")
+ }
+}
+
+type testSender struct {
+ sends [][]byte
+}
+
+func (ts *testSender) send(p []byte) error {
+ buf := make([]byte, len(p))
+ copy(buf, p)
+ ts.sends = append(ts.sends, buf)
+ return nil
+}
+
+func TestReadFrom(t *testing.T) {
+ defer func(oldBufferSize int) {
+ writeBufferSize = oldBufferSize
+ }(writeBufferSize)
+ writeBufferSize = 5
+
+ testData := "Hello this is the test data that will be received. It goes on for a while bla bla bla."
+ testCases := []struct {
+ desc string
+ r io.Reader
+ }{
+ {desc: "base", r: strings.NewReader(testData)},
+ {desc: "dataerr", r: iotest.DataErrReader(strings.NewReader(testData))},
+ {desc: "onebyte", r: iotest.OneByteReader(strings.NewReader(testData))},
+ {desc: "dataerr(onebyte)", r: iotest.DataErrReader(iotest.OneByteReader(strings.NewReader(testData)))},
+ }
+
+ for _, tc := range testCases {
+ ts := &testSender{}
+ n, err := NewWriter(ts.send).(io.ReaderFrom).ReadFrom(tc.r)
+
+ require.NoError(t, err, tc.desc)
+ require.Equal(t, int64(len(testData)), n, tc.desc)
+ require.Equal(t, testData, string(bytes.Join(ts.sends, nil)), tc.desc)
+ for _, send := range ts.sends {
+ require.True(t, len(send) <= writeBufferSize, "send calls may not exceed writeBufferSize")
+ }
+ }
+}