From 7adbb3678dbfc72dd712405f1b7e3436d29b0052 Mon Sep 17 00:00:00 2001 From: Jacob Vosmaer Date: Mon, 10 Jul 2017 17:31:38 +0200 Subject: Streamio optimizations --- streamio/stream.go | 81 +++++++++++++++++++++++++++++++++++++++++++-- streamio/stream_test.go | 88 +++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 163 insertions(+), 6 deletions(-) (limited to 'streamio') diff --git a/streamio/stream.go b/streamio/stream.go index dafade8fb..c1c71db66 100644 --- a/streamio/stream.go +++ b/streamio/stream.go @@ -29,16 +29,91 @@ func (rr *receiveReader) Read(p []byte) (int, error) { return n, nil } -// NewWriter turns sender into an io.Writer. The number of 'bytes -// written' reported back is always len(p). +// WriteTo implements io.WriterTo. +func (rr *receiveReader) WriteTo(w io.Writer) (int64, error) { + var written int64 + + // Deal with left-over state in rr.data and rr.err, if any + if len(rr.data) > 0 { + n, err := w.Write(rr.data) + written += int64(n) + if err != nil { + return written, err + } + } + if rr.err != nil { + return written, rr.err + } + + // Consume the response stream + var errRead, errWrite error + var n int + var buf []byte + for errWrite == nil && errRead != io.EOF { + buf, errRead = rr.receiver() + if errRead != nil && errRead != io.EOF { + return written, errRead + } + + if len(buf) > 0 { + n, errWrite = w.Write(buf) + written += int64(n) + } + } + + return written, errWrite +} + +// NewWriter turns sender into an io.Writer. func NewWriter(sender func(p []byte) error) io.Writer { return &sendWriter{sender: sender} } +var writeBufferSize = 128 * 1024 + type sendWriter struct { sender func([]byte) error } func (sw *sendWriter) Write(p []byte) (int, error) { - return len(p), sw.sender(p) + var sent int + + for len(p) > 0 { + chunkSize := len(p) + if chunkSize > writeBufferSize { + chunkSize = writeBufferSize + } + + if err := sw.sender(p[:chunkSize]); err != nil { + return sent, err + } + + sent += chunkSize + p = p[chunkSize:] + } + + return sent, nil +} + +// ReadFrom implements io.ReaderFrom. +func (sw *sendWriter) ReadFrom(r io.Reader) (int64, error) { + var nRead int64 + buf := make([]byte, writeBufferSize) + + var errRead, errSend error + for errSend == nil && errRead != io.EOF { + var n int + + n, errRead = r.Read(buf) + nRead += int64(n) + if errRead != nil && errRead != io.EOF { + return nRead, errRead + } + + if n > 0 { + errSend = sw.sender(buf[:n]) + } + } + + return nRead, errSend } diff --git a/streamio/stream_test.go b/streamio/stream_test.go index 456d17e80..d263a68ab 100644 --- a/streamio/stream_test.go +++ b/streamio/stream_test.go @@ -25,7 +25,7 @@ func TestReceiveSources(t *testing.T) { } for _, tc := range testCases { - data, err := ioutil.ReadAll(NewReader(receiverFromReader(tc.r))) + data, err := ioutil.ReadAll(&opaqueReader{NewReader(receiverFromReader(tc.r))}) require.NoError(t, err, tc.desc) require.Equal(t, testData, string(data), tc.desc) } @@ -35,15 +35,38 @@ func TestReadSizes(t *testing.T) { testData := "Hello this is the test data that will be received. It goes on for a while bla bla bla." for n := 1; n < 100; n *= 3 { desc := fmt.Sprintf("reads of size %d", n) - buffer := make([]byte, n) result := &bytes.Buffer{} reader := &opaqueReader{NewReader(receiverFromReader(strings.NewReader(testData)))} - _, err := io.CopyBuffer(&opaqueWriter{result}, reader, buffer) + _, err := io.CopyBuffer(&opaqueWriter{result}, reader, make([]byte, n)) + require.NoError(t, err, desc) require.Equal(t, testData, result.String()) } } +func TestWriterTo(t *testing.T) { + testData := "Hello this is the test data that will be received. It goes on for a while bla bla bla." + testCases := []struct { + desc string + r io.Reader + }{ + {desc: "base", r: strings.NewReader(testData)}, + {desc: "dataerr", r: iotest.DataErrReader(strings.NewReader(testData))}, + {desc: "onebyte", r: iotest.OneByteReader(strings.NewReader(testData))}, + {desc: "dataerr(onebyte)", r: iotest.DataErrReader(iotest.OneByteReader(strings.NewReader(testData)))}, + } + + for _, tc := range testCases { + result := &bytes.Buffer{} + reader := NewReader(receiverFromReader(tc.r)) + n, err := reader.(io.WriterTo).WriteTo(result) + + require.NoError(t, err, tc.desc) + require.Equal(t, int64(len(testData)), n, tc.desc) + require.Equal(t, testData, result.String(), tc.desc) + } +} + func receiverFromReader(r io.Reader) func() ([]byte, error) { return func() ([]byte, error) { data := make([]byte, 10) @@ -61,3 +84,62 @@ type opaqueReader struct { type opaqueWriter struct { io.Writer } + +func TestWriterChunking(t *testing.T) { + defer func(oldBufferSize int) { + writeBufferSize = oldBufferSize + }(writeBufferSize) + writeBufferSize = 5 + + testData := "Hello this is some test data" + ts := &testSender{} + w := NewWriter(ts.send) + _, err := io.CopyBuffer(&opaqueWriter{w}, strings.NewReader(testData), make([]byte, 10)) + + require.NoError(t, err) + require.Equal(t, testData, string(bytes.Join(ts.sends, nil))) + for _, send := range ts.sends { + require.True(t, len(send) <= writeBufferSize, "send calls may not exceed writeBufferSize") + } +} + +type testSender struct { + sends [][]byte +} + +func (ts *testSender) send(p []byte) error { + buf := make([]byte, len(p)) + copy(buf, p) + ts.sends = append(ts.sends, buf) + return nil +} + +func TestReadFrom(t *testing.T) { + defer func(oldBufferSize int) { + writeBufferSize = oldBufferSize + }(writeBufferSize) + writeBufferSize = 5 + + testData := "Hello this is the test data that will be received. It goes on for a while bla bla bla." + testCases := []struct { + desc string + r io.Reader + }{ + {desc: "base", r: strings.NewReader(testData)}, + {desc: "dataerr", r: iotest.DataErrReader(strings.NewReader(testData))}, + {desc: "onebyte", r: iotest.OneByteReader(strings.NewReader(testData))}, + {desc: "dataerr(onebyte)", r: iotest.DataErrReader(iotest.OneByteReader(strings.NewReader(testData)))}, + } + + for _, tc := range testCases { + ts := &testSender{} + n, err := NewWriter(ts.send).(io.ReaderFrom).ReadFrom(tc.r) + + require.NoError(t, err, tc.desc) + require.Equal(t, int64(len(testData)), n, tc.desc) + require.Equal(t, testData, string(bytes.Join(ts.sends, nil)), tc.desc) + for _, send := range ts.sends { + require.True(t, len(send) <= writeBufferSize, "send calls may not exceed writeBufferSize") + } + } +} -- cgit v1.2.3