diff options
author | Sami Hiltunen <shiltunen@gitlab.com> | 2021-03-09 11:49:39 +0300 |
---|---|---|
committer | Sami Hiltunen <shiltunen@gitlab.com> | 2021-03-09 11:49:39 +0300 |
commit | 059a82773ec2b5afc115442270d663cccc68451c (patch) | |
tree | ec78f8db231ce7b395e26a0387a199305712e038 | |
parent | 154e48d6bb6b9229de600d285990916e32b4bb2f (diff) | |
parent | 914523f5e19d651fb47b2201f604ed8e8c67bc8e (diff) |
Merge branch 'jv-sideband-writer' into 'master'
Add pktline side-band-64 writer
See merge request gitlab-org/gitaly!3215
-rw-r--r-- | changelogs/unreleased/jv-sideband-writer.yml | 5 | ||||
-rw-r--r-- | internal/git/pktline/pkt_line_test.go | 183 | ||||
-rw-r--r-- | internal/git/pktline/pktline.go | 75 |
3 files changed, 258 insertions, 5 deletions
diff --git a/changelogs/unreleased/jv-sideband-writer.yml b/changelogs/unreleased/jv-sideband-writer.yml new file mode 100644 index 000000000..066939f35 --- /dev/null +++ b/changelogs/unreleased/jv-sideband-writer.yml @@ -0,0 +1,5 @@ +--- +title: Add pktline side-band-64 writer +merge_request: 3215 +author: +type: changed diff --git a/internal/git/pktline/pkt_line_test.go b/internal/git/pktline/pkt_line_test.go index 8cbbeed07..32694a7e0 100644 --- a/internal/git/pktline/pkt_line_test.go +++ b/internal/git/pktline/pkt_line_test.go @@ -2,6 +2,10 @@ package pktline import ( "bytes" + "errors" + "io" + "math" + "math/rand" "strings" "testing" @@ -9,11 +13,11 @@ import ( ) var ( - largestString = strings.Repeat("z", 0xffff-4) + largestString = strings.Repeat("z", 65516) ) func TestScanner(t *testing.T) { - largestPacket := "ffff" + largestString + largestPacket := "fff0" + largestString testCases := []struct { desc string in string @@ -125,7 +129,7 @@ func TestWriteString(t *testing.T) { { desc: "largest possible string", in: largestString, - out: "ffff" + largestString, + out: "fff0" + largestString, }, { desc: "string that is too large", @@ -157,3 +161,176 @@ func TestWriteFlush(t *testing.T) { require.NoError(t, WriteFlush(w)) require.Equal(t, "0000", w.String()) } + +func TestSidebandWriter_boundaries(t *testing.T) { + testCases := []struct { + desc string + in string + band byte + out string + }{ + { + desc: "empty", + in: "", + band: 0, + out: "", + }, + { + desc: "1 byte", + in: "x", + band: 1, + out: "0006\x01x", + }, + { + desc: "65514 bytes", + in: strings.Repeat("x", 65514), + band: 255, + out: "ffef\xff" + strings.Repeat("x", 65514), + }, + { + desc: "65515 bytes: max per sideband packets", + in: strings.Repeat("x", 65515), + band: 254, + out: "fff0\xfe" + strings.Repeat("x", 65515), + }, + { + desc: "65516 bytes: split across two packets", + in: strings.Repeat("x", 65516), + band: 253, + out: "fff0\xfd" + strings.Repeat("x", 65515) + "0006\xfdx", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + buf := &bytes.Buffer{} + w := NewSidebandWriter(buf).Writer(tc.band) + + n, err := w.Write([]byte(tc.in)) + require.NoError(t, err) + require.Equal(t, n, len(tc.in)) + + require.Equal(t, tc.out, buf.String()) + }) + } +} + +func TestSidebandWriter_concurrency(t *testing.T) { + const N = math.MaxUint8 + 1 + + buf := &bytes.Buffer{} + sw := NewSidebandWriter(buf) + inputs := make([][]byte, N) + writeErrors := make(chan error, N) + start := make(chan struct{}) + + for i := 0; i < N; i++ { + inputs[i] = make([]byte, 1024) + _, _ = rand.Read(inputs[i]) // math/rand.Read never fails + + go func(i int) { + <-start + w := sw.Writer(byte(i)) + writeErrors <- func() error { + data := inputs[i] + for j := 0; j < len(data); j++ { + n, err := w.Write(data[j : j+1]) + if err != nil { + return err + } + if n != 1 { + return io.ErrShortWrite + } + } + + return nil + }() + }(i) + } + + close(start) + for i := 0; i < N; i++ { + require.NoError(t, <-writeErrors) + } + + outputs := make([][]byte, N) + scanner := NewScanner(buf) + for scanner.Scan() { + data := Data(scanner.Bytes()) + require.NotEmpty(t, data) + band := data[0] + outputs[band] = append(outputs[band], data[1:]...) + } + + require.NoError(t, scanner.Err()) + + require.Equal(t, inputs, outputs) +} + +func TestEachSidebandPacket(t *testing.T) { + callbackError := errors.New("callback failed") + + testCases := []struct { + desc string + in string + out map[byte]string + err error + callback func(byte, []byte) error + }{ + { + desc: "empty", + out: map[byte]string{}, + }, + { + desc: "empty with failing callback: callback does not run", + out: map[byte]string{}, + callback: func(byte, []byte) error { panic("oh no") }, + }, + { + desc: "valid stream", + in: "0008\x00foo0008\x01bar0008\xfequx0008\xffbaz", + out: map[byte]string{0: "foo", 1: "bar", 254: "qux", 255: "baz"}, + }, + { + desc: "valid stream, failing callback", + in: "0008\x00foo0008\x01bar0008\xfequx0008\xffbaz", + callback: func(byte, []byte) error { return callbackError }, + err: callbackError, + }, + { + desc: "interrupted stream", + in: "ffff\x10hello world!!", + err: io.ErrUnexpectedEOF, + }, + { + desc: "stream without band", + in: "0004", + err: &errNotSideband{pkt: "0004"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + out := make(map[byte]string) + callback := tc.callback + if callback == nil { + callback = func(band byte, data []byte) error { + out[band] += string(data) + return nil + } + } + + err := EachSidebandPacket(strings.NewReader(tc.in), callback) + if tc.err != nil { + require.Equal(t, tc.err, err) + return + } + + require.NoError(t, err) + + if tc.callback == nil { + require.Equal(t, tc.out, out) + } + }) + } +} diff --git a/internal/git/pktline/pktline.go b/internal/git/pktline/pktline.go index d012c8c7d..665d6cbde 100644 --- a/internal/git/pktline/pktline.go +++ b/internal/git/pktline/pktline.go @@ -9,10 +9,11 @@ import ( "fmt" "io" "strconv" + "sync" ) const ( - maxPktSize = 0xffff + maxPktSize = 65520 // https://gitlab.com/gitlab-org/git/-/blob/v2.30.0/pkt-line.h#L216 pktDelim = "0001" ) @@ -100,7 +101,7 @@ func pktLineSplitter(data []byte, atEOF bool) (advance int, token []byte, err er // data contains incomplete packet if atEOF { - return 0, nil, fmt.Errorf("pktLineSplitter: less than %d bytes in input %q", pktLength, data) + return 0, nil, io.ErrUnexpectedEOF } return 0, nil, nil // want more data @@ -108,3 +109,73 @@ func pktLineSplitter(data []byte, atEOF bool) (advance int, token []byte, err er return pktLength, data[:pktLength], nil } + +// SidebandWriter multiplexes byte streams into a single side-band-64k stream. +type SidebandWriter struct { + w io.Writer + m sync.Mutex +} + +// NewSidebandWriter instantiates a new SidebandWriter. +func NewSidebandWriter(w io.Writer) *SidebandWriter { return &SidebandWriter{w: w} } + +func (sw *SidebandWriter) writeBand(band byte, data []byte) (int, error) { + sw.m.Lock() + defer sw.m.Unlock() + + n := 0 + for len(data) > 0 { + chunkSize := len(data) + const headerSize = 5 + if max := maxPktSize - headerSize; chunkSize > max { + chunkSize = max + } + + if _, err := fmt.Fprintf(sw.w, "%04x%s", chunkSize+headerSize, []byte{band}); err != nil { + return n, err + } + + if _, err := sw.w.Write(data[:chunkSize]); err != nil { + return n, err + } + data = data[chunkSize:] + n += chunkSize + } + + return n, nil +} + +// Writer returns an io.Writer that writes into the multiplexed stream. +// Writers for different bands can be used concurrently. +func (sw *SidebandWriter) Writer(band byte) io.Writer { + return writerFunc(func(p []byte) (int, error) { + return sw.writeBand(band, p) + }) +} + +type writerFunc func([]byte) (int, error) + +func (wf writerFunc) Write(p []byte) (int, error) { return wf(p) } + +type errNotSideband struct{ pkt string } + +func (err *errNotSideband) Error() string { return fmt.Sprintf("invalid sideband packet: %q", err.pkt) } + +// EachSidebandPacket iterates over a side-band-64k pktline stream. For +// each packet, it will call fn with the band ID and the packet. Fn must +// not retain the packet. +func EachSidebandPacket(r io.Reader, fn func(byte, []byte) error) error { + scanner := NewScanner(r) + + for scanner.Scan() { + data := Data(scanner.Bytes()) + if len(data) == 0 { + return &errNotSideband{scanner.Text()} + } + if err := fn(data[0], data[1:]); err != nil { + return err + } + } + + return scanner.Err() +} |