Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJacob Vosmaer <jacob@gitlab.com>2017-07-10 18:31:38 +0300
committerJacob Vosmaer <jacob@gitlab.com>2017-07-13 13:20:53 +0300
commit7adbb3678dbfc72dd712405f1b7e3436d29b0052 (patch)
treebaf7afb3224aee7726b3a0a4bcc1a108ca326270 /streamio
parent2a122b1b90b14b2c518b9e2a4cefbeea0d71bc18 (diff)
Streamio optimizations
Diffstat (limited to 'streamio')
-rw-r--r--streamio/stream.go81
-rw-r--r--streamio/stream_test.go88
2 files changed, 163 insertions, 6 deletions
diff --git a/streamio/stream.go b/streamio/stream.go
index dafade8fb..c1c71db66 100644
--- a/streamio/stream.go
+++ b/streamio/stream.go
@@ -29,16 +29,91 @@ func (rr *receiveReader) Read(p []byte) (int, error) {
return n, nil
}
-// NewWriter turns sender into an io.Writer. The number of 'bytes
-// written' reported back is always len(p).
+// WriteTo implements io.WriterTo.
+func (rr *receiveReader) WriteTo(w io.Writer) (int64, error) {
+ var written int64
+
+ // Deal with left-over state in rr.data and rr.err, if any
+ if len(rr.data) > 0 {
+ n, err := w.Write(rr.data)
+ written += int64(n)
+ if err != nil {
+ return written, err
+ }
+ }
+ if rr.err != nil {
+ return written, rr.err
+ }
+
+ // Consume the response stream
+ var errRead, errWrite error
+ var n int
+ var buf []byte
+ for errWrite == nil && errRead != io.EOF {
+ buf, errRead = rr.receiver()
+ if errRead != nil && errRead != io.EOF {
+ return written, errRead
+ }
+
+ if len(buf) > 0 {
+ n, errWrite = w.Write(buf)
+ written += int64(n)
+ }
+ }
+
+ return written, errWrite
+}
+
+// NewWriter turns sender into an io.Writer.
func NewWriter(sender func(p []byte) error) io.Writer {
return &sendWriter{sender: sender}
}
+var writeBufferSize = 128 * 1024
+
type sendWriter struct {
sender func([]byte) error
}
func (sw *sendWriter) Write(p []byte) (int, error) {
- return len(p), sw.sender(p)
+ var sent int
+
+ for len(p) > 0 {
+ chunkSize := len(p)
+ if chunkSize > writeBufferSize {
+ chunkSize = writeBufferSize
+ }
+
+ if err := sw.sender(p[:chunkSize]); err != nil {
+ return sent, err
+ }
+
+ sent += chunkSize
+ p = p[chunkSize:]
+ }
+
+ return sent, nil
+}
+
+// ReadFrom implements io.ReaderFrom.
+func (sw *sendWriter) ReadFrom(r io.Reader) (int64, error) {
+ var nRead int64
+ buf := make([]byte, writeBufferSize)
+
+ var errRead, errSend error
+ for errSend == nil && errRead != io.EOF {
+ var n int
+
+ n, errRead = r.Read(buf)
+ nRead += int64(n)
+ if errRead != nil && errRead != io.EOF {
+ return nRead, errRead
+ }
+
+ if n > 0 {
+ errSend = sw.sender(buf[:n])
+ }
+ }
+
+ return nRead, errSend
}
diff --git a/streamio/stream_test.go b/streamio/stream_test.go
index 456d17e80..d263a68ab 100644
--- a/streamio/stream_test.go
+++ b/streamio/stream_test.go
@@ -25,7 +25,7 @@ func TestReceiveSources(t *testing.T) {
}
for _, tc := range testCases {
- data, err := ioutil.ReadAll(NewReader(receiverFromReader(tc.r)))
+ data, err := ioutil.ReadAll(&opaqueReader{NewReader(receiverFromReader(tc.r))})
require.NoError(t, err, tc.desc)
require.Equal(t, testData, string(data), tc.desc)
}
@@ -35,15 +35,38 @@ func TestReadSizes(t *testing.T) {
testData := "Hello this is the test data that will be received. It goes on for a while bla bla bla."
for n := 1; n < 100; n *= 3 {
desc := fmt.Sprintf("reads of size %d", n)
- buffer := make([]byte, n)
result := &bytes.Buffer{}
reader := &opaqueReader{NewReader(receiverFromReader(strings.NewReader(testData)))}
- _, err := io.CopyBuffer(&opaqueWriter{result}, reader, buffer)
+ _, err := io.CopyBuffer(&opaqueWriter{result}, reader, make([]byte, n))
+
require.NoError(t, err, desc)
require.Equal(t, testData, result.String())
}
}
+func TestWriterTo(t *testing.T) {
+ testData := "Hello this is the test data that will be received. It goes on for a while bla bla bla."
+ testCases := []struct {
+ desc string
+ r io.Reader
+ }{
+ {desc: "base", r: strings.NewReader(testData)},
+ {desc: "dataerr", r: iotest.DataErrReader(strings.NewReader(testData))},
+ {desc: "onebyte", r: iotest.OneByteReader(strings.NewReader(testData))},
+ {desc: "dataerr(onebyte)", r: iotest.DataErrReader(iotest.OneByteReader(strings.NewReader(testData)))},
+ }
+
+ for _, tc := range testCases {
+ result := &bytes.Buffer{}
+ reader := NewReader(receiverFromReader(tc.r))
+ n, err := reader.(io.WriterTo).WriteTo(result)
+
+ require.NoError(t, err, tc.desc)
+ require.Equal(t, int64(len(testData)), n, tc.desc)
+ require.Equal(t, testData, result.String(), tc.desc)
+ }
+}
+
func receiverFromReader(r io.Reader) func() ([]byte, error) {
return func() ([]byte, error) {
data := make([]byte, 10)
@@ -61,3 +84,62 @@ type opaqueReader struct {
type opaqueWriter struct {
io.Writer
}
+
+func TestWriterChunking(t *testing.T) {
+ defer func(oldBufferSize int) {
+ writeBufferSize = oldBufferSize
+ }(writeBufferSize)
+ writeBufferSize = 5
+
+ testData := "Hello this is some test data"
+ ts := &testSender{}
+ w := NewWriter(ts.send)
+ _, err := io.CopyBuffer(&opaqueWriter{w}, strings.NewReader(testData), make([]byte, 10))
+
+ require.NoError(t, err)
+ require.Equal(t, testData, string(bytes.Join(ts.sends, nil)))
+ for _, send := range ts.sends {
+ require.True(t, len(send) <= writeBufferSize, "send calls may not exceed writeBufferSize")
+ }
+}
+
+type testSender struct {
+ sends [][]byte
+}
+
+func (ts *testSender) send(p []byte) error {
+ buf := make([]byte, len(p))
+ copy(buf, p)
+ ts.sends = append(ts.sends, buf)
+ return nil
+}
+
+func TestReadFrom(t *testing.T) {
+ defer func(oldBufferSize int) {
+ writeBufferSize = oldBufferSize
+ }(writeBufferSize)
+ writeBufferSize = 5
+
+ testData := "Hello this is the test data that will be received. It goes on for a while bla bla bla."
+ testCases := []struct {
+ desc string
+ r io.Reader
+ }{
+ {desc: "base", r: strings.NewReader(testData)},
+ {desc: "dataerr", r: iotest.DataErrReader(strings.NewReader(testData))},
+ {desc: "onebyte", r: iotest.OneByteReader(strings.NewReader(testData))},
+ {desc: "dataerr(onebyte)", r: iotest.DataErrReader(iotest.OneByteReader(strings.NewReader(testData)))},
+ }
+
+ for _, tc := range testCases {
+ ts := &testSender{}
+ n, err := NewWriter(ts.send).(io.ReaderFrom).ReadFrom(tc.r)
+
+ require.NoError(t, err, tc.desc)
+ require.Equal(t, int64(len(testData)), n, tc.desc)
+ require.Equal(t, testData, string(bytes.Join(ts.sends, nil)), tc.desc)
+ for _, send := range ts.sends {
+ require.True(t, len(send) <= writeBufferSize, "send calls may not exceed writeBufferSize")
+ }
+ }
+}