diff options
Diffstat (limited to 'streamio/stream.go')
-rw-r--r-- | streamio/stream.go | 81 |
1 files changed, 78 insertions, 3 deletions
diff --git a/streamio/stream.go b/streamio/stream.go index dafade8fb..c1c71db66 100644 --- a/streamio/stream.go +++ b/streamio/stream.go @@ -29,16 +29,91 @@ func (rr *receiveReader) Read(p []byte) (int, error) { return n, nil } -// NewWriter turns sender into an io.Writer. The number of 'bytes -// written' reported back is always len(p). +// WriteTo implements io.WriterTo. +func (rr *receiveReader) WriteTo(w io.Writer) (int64, error) { + var written int64 + + // Deal with left-over state in rr.data and rr.err, if any + if len(rr.data) > 0 { + n, err := w.Write(rr.data) + written += int64(n) + if err != nil { + return written, err + } + } + if rr.err != nil { + return written, rr.err + } + + // Consume the response stream + var errRead, errWrite error + var n int + var buf []byte + for errWrite == nil && errRead != io.EOF { + buf, errRead = rr.receiver() + if errRead != nil && errRead != io.EOF { + return written, errRead + } + + if len(buf) > 0 { + n, errWrite = w.Write(buf) + written += int64(n) + } + } + + return written, errWrite +} + +// NewWriter turns sender into an io.Writer. func NewWriter(sender func(p []byte) error) io.Writer { return &sendWriter{sender: sender} } +var writeBufferSize = 128 * 1024 + type sendWriter struct { sender func([]byte) error } func (sw *sendWriter) Write(p []byte) (int, error) { - return len(p), sw.sender(p) + var sent int + + for len(p) > 0 { + chunkSize := len(p) + if chunkSize > writeBufferSize { + chunkSize = writeBufferSize + } + + if err := sw.sender(p[:chunkSize]); err != nil { + return sent, err + } + + sent += chunkSize + p = p[chunkSize:] + } + + return sent, nil +} + +// ReadFrom implements io.ReaderFrom. +func (sw *sendWriter) ReadFrom(r io.Reader) (int64, error) { + var nRead int64 + buf := make([]byte, writeBufferSize) + + var errRead, errSend error + for errSend == nil && errRead != io.EOF { + var n int + + n, errRead = r.Read(buf) + nRead += int64(n) + if errRead != nil && errRead != io.EOF { + return nRead, errRead + } + + if n > 0 { + errSend = sw.sender(buf[:n]) + } + } + + return nRead, errSend } |