diff options
author | Quang-Minh Nguyen <qmnguyen@gitlab.com> | 2021-09-14 09:13:19 +0300 |
---|---|---|
committer | Quang-Minh Nguyen <qmnguyen@gitlab.com> | 2021-09-22 11:19:49 +0300 |
commit | 8c48ef49e6b97de2850b94571822aeae32fdb2e4 (patch) | |
tree | 4d2d09b7727f819983887c42af7f475c6ded892a /internal | |
parent | 858ab5adcc2996f16e958160f3f31fe519300bc8 (diff) |
Add half-close capability to Gitaly sidechannel
Issue: https://gitlab.com/gitlab-com/gl-infra/scalability/-/issues/1278
Changelog: added
Diffstat (limited to 'internal')
-rw-r--r-- | internal/git/pktline/pktline.go | 13 | ||||
-rw-r--r-- | internal/sidechannel/conn.go | 181 | ||||
-rw-r--r-- | internal/sidechannel/conn_test.go | 179 | ||||
-rw-r--r-- | internal/sidechannel/registry.go | 6 | ||||
-rw-r--r-- | internal/sidechannel/registry_test.go | 45 | ||||
-rw-r--r-- | internal/sidechannel/sidechannel.go | 6 | ||||
-rw-r--r-- | internal/sidechannel/sidechannel_test.go | 6 |
7 files changed, 410 insertions, 26 deletions
diff --git a/internal/git/pktline/pktline.go b/internal/git/pktline/pktline.go index 6a37f9651..1841a94de 100644 --- a/internal/git/pktline/pktline.go +++ b/internal/git/pktline/pktline.go @@ -15,16 +15,19 @@ import ( const ( // MaxSidebandData is the maximum number of bytes that fits into one Git // pktline side-band-64k packet. - MaxSidebandData = maxPktSize - 5 + MaxSidebandData = MaxPktSize - 5 - maxPktSize = 65520 // https://gitlab.com/gitlab-org/git/-/blob/v2.30.0/pkt-line.h#L216 + // MaxPktSize is the maximum total size of a Git pktline packet, including + // the 4-byte length prefix and, for side-band-64k, the 1-byte band number. + // https://gitlab.com/gitlab-org/git/-/blob/v2.30.0/pkt-line.h#L216 + MaxPktSize = 65520 pktDelim = "0001" ) // NewScanner returns a bufio.Scanner that splits on Git pktline boundaries func NewScanner(r io.Reader) *bufio.Scanner { scanner := bufio.NewScanner(r) - scanner.Buffer(make([]byte, maxPktSize), maxPktSize) + scanner.Buffer(make([]byte, MaxPktSize), MaxPktSize) scanner.Split(pktLineSplitter) return scanner } @@ -44,7 +47,7 @@ func IsFlush(pkt []byte) bool { // WriteString writes a string with pkt-line framing func WriteString(w io.Writer, str string) (int, error) { pktLen := len(str) + 4 - if pktLen > maxPktSize { + if pktLen > MaxPktSize { return 0, fmt.Errorf("string too large: %d bytes", len(str)) } @@ -131,7 +134,7 @@ func (sw *SidebandWriter) writeBand(band byte, data []byte) (int, error) { for len(data) > 0 { chunkSize := len(data) const headerSize = 5 - if max := maxPktSize - headerSize; chunkSize > max { + if max := MaxPktSize - headerSize; chunkSize > max { chunkSize = max } diff --git a/internal/sidechannel/conn.go b/internal/sidechannel/conn.go new file mode 100644 index 000000000..5af660fab --- /dev/null +++ b/internal/sidechannel/conn.go @@ -0,0 +1,181 @@ +package sidechannel + +import ( + "fmt" + "io" + "net" + + "gitlab.com/gitlab-org/gitaly/v14/internal/git/pktline" + "gitlab.com/gitlab-org/gitaly/v14/streamio" +) + +// 
ServerConn and ClientConn implement an asymmetric framing protocol to +// exchange data between clients and servers in Sidechannel. A typical flow +// looks like the following: +// - The client writes data into the connection. +// - The client half-closes the connection. The server is aware of this event +// when reading operations return EOF. +// - The server writes the data back to the client, then closes the connection. +// - The client reads the data until EOF. +// +// Half-close ability is important to signal the server that the client +// has finished transmitting data. As sidechannel is built on top of a Yamux +// stream, half-close ability is not supported. Therefore, we apply a +// length-prefix framing protocol, similar to the Git pktline protocol, except we +// omit the band number. The close or half-close events are signaled by sending +// a flush packet. +// +// This is an example of the data written into the wire: +// +// | 4-byte length, including size of length itself. +// v +// 0009Hello0000 +// ^ +// | Flush packet signaling a half-close event +// +// Many methods in the battle-tested pktline package are re-used to save us some +// time. At the moment, we don't need server-client half-close ability, and +// it may affect the performance when wrapping huge data sent from the server. + +const ( + // maxChunkSize is the maximum size of the data carried by one chunk. The data plus the + // 4-byte length prefix must fit in pktline.MaxPktSize. This constant is different from MaxSidebandData because + // we don't include the sideband number. + maxChunkSize = pktline.MaxPktSize - 4 +) + +// ServerConn is a wrapper around net.Conn with the support of half-close +// capability for sidechannel. This struct is expected to be used by +// sidechannel's server only. 
+type ServerConn struct { + conn net.Conn + r io.Reader +} + +func newServerConn(c net.Conn) *ServerConn { + scanner := pktline.NewScanner(c) + reader := streamio.NewReader(func() ([]byte, error) { + if !scanner.Scan() { + if err := scanner.Err(); err != nil { + return nil, err + } + // If there is any error while scanning, scanner.Err() returns a + // non-nil error. If scanner.Err() returns nil, the connection + // reaches end-of-file. However, the effect of returning io.EOF is + // that we allow two kinds of streams: "000fhello world0000" (with + // trailing 0000) and "000fhello world" (without trialing 0000). + // Having optional behaviors like this is a source of complexity. + // We should not allow "000fhello world" without the trailing 0000. + return nil, io.ErrUnexpectedEOF + } + + if pktline.IsFlush(scanner.Bytes()) { + return nil, io.EOF + } + + data := scanner.Bytes() + if len(data) < 4 { + return nil, fmt.Errorf("sidechannel: invalid packet %q", data) + } + + // pktline treats 0001, 0002, or 0003 as magic empty packets + // They are irrelevant to sidechannel, hence should be rejected + if len(data) == 4 { + if s := string(data); s == "0001" || s == "0002" || s == "0003" { + return nil, fmt.Errorf("sidechannel: invalid header %s", string(data[3])) + } + } + + return data[4:], nil + }) + + return &ServerConn{conn: c, r: reader} +} + +// Read reads up to len(p) bytes into p. It returns the number of bytes read or +// any error encountered. This struct overrides Read() to extract the data +// wrapped in a frame generated by ClientConn.Write(). +func (cc *ServerConn) Read(p []byte) (n int, err error) { + return cc.r.Read(p) +} + +// Write writes data to the connection. This method fallbacks to underlying +// connection without any modificiation. +func (cc *ServerConn) Write(b []byte) (n int, err error) { + return cc.conn.Write(b) +} + +// Close closes the connection. This method fallbacks to underlying +// connection without any modificiation. 
+func (cc *ServerConn) Close() error { + return cc.conn.Close() +} + +// ClientConn is a wrapper around net.Conn with the support of half-closed +// capacity for sidechannel. This struct is expected to use by sidechannel's +// client only. +type ClientConn struct { + conn net.Conn + writeClosed bool +} + +func newClientConn(c net.Conn) *ClientConn { + return &ClientConn{conn: c} +} + +// Read reads data from the connection. This method fallbacks to underlying +// connection without any modificiation. +func (cc *ClientConn) Read(b []byte) (n int, err error) { + return cc.conn.Read(b) +} + +// Write writes len(p) bytes from p to the underlying data stream. It returns +// the number of bytes written from p and any error encountered that caused the +// write to stop early. This method overrides Write() to wrap the writing data +// into a frame. The frame is then extracted and read by ServerConn.Read(). +func (cc *ClientConn) Write(p []byte) (int, error) { + if cc.writeClosed { + return 0, fmt.Errorf("sidechannel: write into a half-closed connection") + } + + var n int + + for len(p) > 0 { + chunk := maxChunkSize + if len(p) < chunk { + chunk = len(p) + } + + if _, err := fmt.Fprintf(cc.conn, "%04x", chunk+4); err != nil { + return n, err + } + if _, err := cc.conn.Write(p[:chunk]); err != nil { + return n, err + } + n += chunk + p = p[chunk:] + } + + return n, nil +} + +func (cc *ClientConn) close() error { + return cc.conn.Close() +} + +// CloseWrite shuts down the writing side of the connection. After this call, +// any read operations from the server return EOF. The reading side is still +// functional so that the server is still able to write back to the client. Any +// attempt to write into a half-closed connection returns an error. 
+func (cc *ClientConn) CloseWrite() error { + if cc.writeClosed { + return nil + } + + cc.writeClosed = true + if err := pktline.WriteFlush(cc.conn); err != nil { + return err + } + + return nil +} diff --git a/internal/sidechannel/conn_test.go b/internal/sidechannel/conn_test.go new file mode 100644 index 000000000..257fd4273 --- /dev/null +++ b/internal/sidechannel/conn_test.go @@ -0,0 +1,179 @@ +package sidechannel + +import ( + "fmt" + "io" + "net" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestClientWrite(t *testing.T) { + largeString := strings.Repeat("a", maxChunkSize) + + testCases := []struct { + desc string + in string + out string + err error + }{ + {desc: "empty", out: ""}, + {desc: "1-byte string", in: "h", out: "0005h"}, + {desc: "short string", in: "hello", out: "0009hello"}, + {desc: "short string 2", in: "hello this world", out: "0014hello this world"}, + { + desc: "large string", + in: largeString, + out: "fff0" + largeString, + }, + { + desc: "very large string", + in: largeString + "b", + out: "fff0" + largeString + "0005b", + }, + } + + type result struct { + data string + err error + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + client, server := net.Pipe() + ch := make(chan result, 1) + go func() { + out, err := io.ReadAll(server) + ch <- result{data: string(out), err: err} + }() + + cc := newClientConn(client) + written, err := cc.Write([]byte(tc.in)) + if tc.err != nil { + require.Equal(t, tc.err, err) + return + } + + require.NoError(t, err) + require.NoError(t, cc.close()) + require.Equal(t, written, len(tc.in)) + + res := <-ch + require.NoError(t, res.err) + require.Equal(t, tc.out, res.data) + }) + } +} + +func TestServerRead(t *testing.T) { + largeString := strings.Repeat("a", maxChunkSize) + + testCases := []struct { + desc string + in string + out string + err error + }{ + {desc: "empty", in: "0000", out: ""}, + {desc: "empty 2", in: "00040000", out: ""}, + {desc: 
"1-byte string", in: "0005h0000", out: "h"}, + {desc: "short string", in: "0009hello0000", out: "hello"}, + {desc: "short string 2", in: "0014hello this world0000", out: "hello this world"}, + {desc: "invalid header 1", in: "0001", err: fmt.Errorf("sidechannel: invalid header 1")}, + {desc: "invalid header 2", in: "0002", err: fmt.Errorf("sidechannel: invalid header 2")}, + {desc: "invalid header 3", in: "0003", err: fmt.Errorf("sidechannel: invalid header 3")}, + {desc: "multiple short strings", in: "0009hello0008this0009world0000", out: "hellothisworld"}, + { + desc: "large string", + in: "fff0" + largeString + "0000", + out: largeString, + }, + { + desc: "very large string", + in: "fff0" + largeString + "0005b0000", + out: largeString + "b", + }, + {desc: "flush packet", in: "0009hello0000trashtrash", out: "hello"}, + {desc: "unexpected closed without trailing 0000", in: "0009hello", err: io.ErrUnexpectedEOF}, + } + + type result struct { + data string + err error + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + client, server := net.Pipe() + ch := make(chan result, 1) + + go func() { + cc := newServerConn(server) + out, err := io.ReadAll(cc) + ch <- result{data: string(out), err: err} + }() + + written, err := client.Write([]byte(tc.in)) + require.Equal(t, written, len(tc.in)) + require.NoError(t, err) + + client.Close() + + res := <-ch + if tc.err != nil { + require.Equal(t, tc.err, res.err) + return + } + + require.NoError(t, res.err) + require.Equal(t, tc.out, res.data) + }) + } +} + +func TestHalfClose(t *testing.T) { + type result struct { + data string + err error + } + + client, server := net.Pipe() + clientCc := newClientConn(client) + serverCc := newServerConn(server) + + ch := make(chan result, 1) + + go func() { + out, err := io.ReadAll(serverCc) + ch <- result{data: string(out), err: err} + }() + + written, err := clientCc.Write([]byte("Ping")) + require.Equal(t, written, 4) + require.NoError(t, err) + + 
require.NoError(t, clientCc.CloseWrite()) + + res := <-ch + require.NoError(t, res.err) + require.Equal(t, "Ping", res.data) + + _, err = clientCc.Write([]byte("Should not be sent")) + require.EqualError(t, err, "sidechannel: write into a half-closed connection") + + go func() { + out, err := io.ReadAll(clientCc) + ch <- result{data: string(out), err: err} + }() + + _, err = serverCc.Write([]byte("Pong")) + require.NoError(t, err) + + serverCc.Close() + + res = <-ch + require.NoError(t, res.err) + require.Equal(t, "Pong", res.data) +} diff --git a/internal/sidechannel/registry.go b/internal/sidechannel/registry.go index 32985a9a1..1b96e8c88 100644 --- a/internal/sidechannel/registry.go +++ b/internal/sidechannel/registry.go @@ -25,7 +25,7 @@ type Waiter struct { registry *Registry errC chan error accept chan net.Conn - callback func(net.Conn) error + callback func(*ClientConn) error } // NewRegistry returns a new Registry instance @@ -40,7 +40,7 @@ func NewRegistry() *Registry { // connection arrives, the callback function is executed with arrived // connection in a new goroutine. The caller receives execution result via // waiter.Wait(). 
-func (s *Registry) Register(callback func(net.Conn) error) *Waiter { +func (s *Registry) Register(callback func(*ClientConn) error) *Waiter { s.mu.Lock() defer s.mu.Unlock() @@ -104,7 +104,7 @@ func (w *Waiter) run() { if conn := <-w.accept; conn != nil { defer conn.Close() - w.errC <- w.callback(conn) + w.errC <- w.callback(newClientConn(conn)) } } diff --git a/internal/sidechannel/registry_test.go b/internal/sidechannel/registry_test.go index 07d2f85c1..4db4e6ac5 100644 --- a/internal/sidechannel/registry_test.go +++ b/internal/sidechannel/registry_test.go @@ -4,8 +4,10 @@ import ( "fmt" "io" "net" + "os" "strconv" "sync" + "syscall" "testing" "github.com/stretchr/testify/require" @@ -17,7 +19,7 @@ func TestRegistry(t *testing.T) { t.Run("waiter removed from the registry right after connection received", func(t *testing.T) { triggerCallback := make(chan struct{}) - waiter := registry.Register(func(conn net.Conn) error { + waiter := registry.Register(func(conn *ClientConn) error { <-triggerCallback return nil }) @@ -25,7 +27,7 @@ func TestRegistry(t *testing.T) { require.Equal(t, 1, registry.waiting()) - client, _ := net.Pipe() + client, _ := socketPair(t) require.NoError(t, registry.receive(waiter.id, client)) require.Equal(t, 0, registry.waiting()) @@ -37,17 +39,20 @@ func TestRegistry(t *testing.T) { t.Run("pull connections successfully", func(t *testing.T) { wg := sync.WaitGroup{} - var servers []net.Conn + var servers []*ServerConn for i := 0; i < N; i++ { - client, server := net.Pipe() - servers = append(servers, server) + client, server := socketPair(t) + servers = append(servers, newServerConn(server)) wg.Add(1) go func(i int) { - waiter := registry.Register(func(conn net.Conn) error { - _, err := fmt.Fprintf(conn, "%d", i) - return err + waiter := registry.Register(func(conn *ClientConn) error { + if _, err := fmt.Fprintf(conn, "%d", i); err != nil { + return err + } + + return conn.CloseWrite() }) defer waiter.Close() @@ -70,14 +75,14 @@ func 
TestRegistry(t *testing.T) { }) t.Run("push connection to non-existing ID", func(t *testing.T) { - client, _ := net.Pipe() + client, _ := socketPair(t) err := registry.receive(registry.nextID+1, client) require.EqualError(t, err, "sidechannel registry: ID not registered") requireConnClosed(t, client) }) t.Run("pre-maturely close the waiter", func(t *testing.T) { - waiter := registry.Register(func(conn net.Conn) error { panic("never execute") }) + waiter := registry.Register(func(conn *ClientConn) error { panic("never execute") }) require.NoError(t, waiter.Close()) require.Equal(t, 0, registry.waiting()) }) @@ -86,7 +91,23 @@ func TestRegistry(t *testing.T) { func requireConnClosed(t *testing.T, conn net.Conn) { one := make([]byte, 1) _, err := conn.Read(one) - require.EqualError(t, err, "io: read/write on closed pipe") + require.Errorf(t, err, "use of closed network connection") _, err = conn.Write(one) - require.EqualError(t, err, "io: read/write on closed pipe") + require.Errorf(t, err, "use of closed network connection") +} + +func socketPair(t *testing.T) (net.Conn, net.Conn) { + conns := make([]net.Conn, 2) + fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) + require.NoError(t, err) + + for i, fd := range fds[:] { + f := os.NewFile(uintptr(fd), "socket pair") + c, err := net.FileConn(f) + require.NoError(t, err) + require.NoError(t, f.Close()) + t.Cleanup(func() { c.Close() }) + conns[i] = c + } + return conns[0], conns[1] } diff --git a/internal/sidechannel/sidechannel.go b/internal/sidechannel/sidechannel.go index c133fcc6e..a2083fb04 100644 --- a/internal/sidechannel/sidechannel.go +++ b/internal/sidechannel/sidechannel.go @@ -26,7 +26,7 @@ const ( // OpenSidechannel opens a sidechannel connection from the stream opener // extracted from the current peer connection. 
-func OpenSidechannel(ctx context.Context) (_ net.Conn, err error) { +func OpenSidechannel(ctx context.Context) (_ *ServerConn, err error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { return nil, fmt.Errorf("sidechannel: failed to extract incoming metadata") @@ -68,7 +68,7 @@ func OpenSidechannel(ctx context.Context) (_ net.Conn, err error) { return nil, err } - return stream, nil + return newServerConn(stream), nil } // RegisterSidechannel registers the caller into the waiting list of the @@ -76,7 +76,7 @@ func OpenSidechannel(ctx context.Context) (_ net.Conn, err error) { // The caller is expected to establish the request with the returned context. The // callback is executed automatically when the sidechannel connection arrives. // The result is pushed to the error channel of the returned waiter. -func RegisterSidechannel(ctx context.Context, registry *Registry, callback func(net.Conn) error) (context.Context, *Waiter) { +func RegisterSidechannel(ctx context.Context, registry *Registry, callback func(*ClientConn) error) (context.Context, *Waiter) { waiter := registry.Register(callback) ctxOut := metadata.AppendToOutgoingContext(ctx, sidechannelMetadataKey, fmt.Sprintf("%d", waiter.id)) return ctxOut, waiter diff --git a/internal/sidechannel/sidechannel_test.go b/internal/sidechannel/sidechannel_test.go index c85517952..f1d625504 100644 --- a/internal/sidechannel/sidechannel_test.go +++ b/internal/sidechannel/sidechannel_test.go @@ -47,7 +47,7 @@ func TestSidechannel(t *testing.T) { conn, registry := dial(t, addr) err = call( context.Background(), conn, registry, - func(conn net.Conn) error { + func(conn *ClientConn) error { errC := make(chan error, 1) go func() { var err error @@ -113,7 +113,7 @@ func TestSidechannelConcurrency(t *testing.T) { err := call( context.Background(), conn, registry, - func(conn net.Conn) error { + func(conn *ClientConn) error { errC := make(chan error, 1) go func() { var err error @@ -187,7 +187,7 @@ func dial(t *testing.T, 
addr string) (*grpc.ClientConn, *Registry) { return conn, registry } -func call(ctx context.Context, conn *grpc.ClientConn, registry *Registry, handler func(net.Conn) error) error { +func call(ctx context.Context, conn *grpc.ClientConn, registry *Registry, handler func(*ClientConn) error) error { client := healthpb.NewHealthClient(conn) ctxOut, waiter := RegisterSidechannel(ctx, registry, handler) |