Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorQuang-Minh Nguyen <qmnguyen@gitlab.com>2021-09-14 09:13:19 +0300
committerQuang-Minh Nguyen <qmnguyen@gitlab.com>2021-09-22 11:19:49 +0300
commit8c48ef49e6b97de2850b94571822aeae32fdb2e4 (patch)
tree4d2d09b7727f819983887c42af7f475c6ded892a
parent858ab5adcc2996f16e958160f3f31fe519300bc8 (diff)
Add half-close capability to Gitaly sidechannel
Issue: https://gitlab.com/gitlab-com/gl-infra/scalability/-/issues/1278 Changelog: added
-rw-r--r--internal/git/pktline/pktline.go13
-rw-r--r--internal/sidechannel/conn.go181
-rw-r--r--internal/sidechannel/conn_test.go179
-rw-r--r--internal/sidechannel/registry.go6
-rw-r--r--internal/sidechannel/registry_test.go45
-rw-r--r--internal/sidechannel/sidechannel.go6
-rw-r--r--internal/sidechannel/sidechannel_test.go6
-rw-r--r--proto/go/gitalypb/protolist.go1
8 files changed, 411 insertions, 26 deletions
diff --git a/internal/git/pktline/pktline.go b/internal/git/pktline/pktline.go
index 6a37f9651..1841a94de 100644
--- a/internal/git/pktline/pktline.go
+++ b/internal/git/pktline/pktline.go
@@ -15,16 +15,19 @@ import (
const (
// MaxSidebandData is the maximum number of bytes that fits into one Git
// pktline side-band-64k packet.
- MaxSidebandData = maxPktSize - 5
+ MaxSidebandData = MaxPktSize - 5
- maxPktSize = 65520 // https://gitlab.com/gitlab-org/git/-/blob/v2.30.0/pkt-line.h#L216
+ // MaxPktSize is the maximum size of content of a Git pktline side-band-64k
+ // packet, excluding size of length and band number
+ // https://gitlab.com/gitlab-org/git/-/blob/v2.30.0/pkt-line.h#L216
+ MaxPktSize = 65520
pktDelim = "0001"
)
// NewScanner returns a bufio.Scanner that splits on Git pktline boundaries
func NewScanner(r io.Reader) *bufio.Scanner {
scanner := bufio.NewScanner(r)
- scanner.Buffer(make([]byte, maxPktSize), maxPktSize)
+ scanner.Buffer(make([]byte, MaxPktSize), MaxPktSize)
scanner.Split(pktLineSplitter)
return scanner
}
@@ -44,7 +47,7 @@ func IsFlush(pkt []byte) bool {
// WriteString writes a string with pkt-line framing
func WriteString(w io.Writer, str string) (int, error) {
pktLen := len(str) + 4
- if pktLen > maxPktSize {
+ if pktLen > MaxPktSize {
return 0, fmt.Errorf("string too large: %d bytes", len(str))
}
@@ -131,7 +134,7 @@ func (sw *SidebandWriter) writeBand(band byte, data []byte) (int, error) {
for len(data) > 0 {
chunkSize := len(data)
const headerSize = 5
- if max := maxPktSize - headerSize; chunkSize > max {
+ if max := MaxPktSize - headerSize; chunkSize > max {
chunkSize = max
}
diff --git a/internal/sidechannel/conn.go b/internal/sidechannel/conn.go
new file mode 100644
index 000000000..5af660fab
--- /dev/null
+++ b/internal/sidechannel/conn.go
@@ -0,0 +1,181 @@
+package sidechannel
+
+import (
+ "fmt"
+ "io"
+ "net"
+
+ "gitlab.com/gitlab-org/gitaly/v14/internal/git/pktline"
+ "gitlab.com/gitlab-org/gitaly/v14/streamio"
+)
+
+// ServerConn and ClientConn implement an asymmetric framing protocol to
+// exchange data between clients and servers in Sidechannel. A typical flow
+// looks like following:
+// - The client writes data into the connecton.
+// - The client half-closes the connection. The server is aware of this event
+// when reading operations return EOF.
+// - The server writes the data back to the client, then close the connection.
+// - The client read the data until EOF
+//
+// Half-close ability is important to signal the server that the client
+// finishes data transformation. As sidechannel is built on top of Yamux
+// stream, half-close ability is not supported. Therefore, we apply a
+// length-prefix framing protocol, simiarly to Git pktline protocol, except we
+// omit the band number. The close or half-close event are signaled by sending
+// a flush packet.
+//
+// This is an example of the data written into the wire:
+//
+// | 4-byte length, including size of length itself.
+// v
+// 0009Hello0000
+// ^
+// | Flush packet signaling a half-close event
+//
+// Many methods in battle-tested pktline package are re-used to save us some
+// times. At the moment, we don't need server-client half-closed ability. And
+// it may affect the performance when wrapping huge data sent from the server.
+
+const (
+ // maxChunkSize is the maximum chunk size of data. The chunk size must include 4-byte
+ // length prefix. This constant is different from MaxSidebandData because
+ // we don't include the sideband number.
+ maxChunkSize = pktline.MaxPktSize - 4
+)
+
+// ServerConn is a wrapper around net.Conn with the support of half-closed
+// capacity for sidechannel. This struct is expected to be used by
+// sidechannel's server only.
+type ServerConn struct {
+ conn net.Conn
+ r io.Reader
+}
+
+func newServerConn(c net.Conn) *ServerConn {
+ scanner := pktline.NewScanner(c)
+ reader := streamio.NewReader(func() ([]byte, error) {
+ if !scanner.Scan() {
+ if err := scanner.Err(); err != nil {
+ return nil, err
+ }
+ // If there is any error while scanning, scanner.Err() returns a
+ // non-nil error. If scanner.Err() returns nil, the connection
+ // reaches end-of-file. However, the effect of returning io.EOF is
+ // that we allow two kinds of streams: "000fhello world0000" (with
+ // trailing 0000) and "000fhello world" (without trialing 0000).
+ // Having optional behaviors like this is a source of complexity.
+ // We should not allow "000fhello world" without the trailing 0000.
+ return nil, io.ErrUnexpectedEOF
+ }
+
+ if pktline.IsFlush(scanner.Bytes()) {
+ return nil, io.EOF
+ }
+
+ data := scanner.Bytes()
+ if len(data) < 4 {
+ return nil, fmt.Errorf("sidechannel: invalid packet %q", data)
+ }
+
+ // pktline treats 0001, 0002, or 0003 as magic empty packets
+ // They are irrelevant to sidechannel, hence should be rejected
+ if len(data) == 4 {
+ if s := string(data); s == "0001" || s == "0002" || s == "0003" {
+ return nil, fmt.Errorf("sidechannel: invalid header %s", string(data[3]))
+ }
+ }
+
+ return data[4:], nil
+ })
+
+ return &ServerConn{conn: c, r: reader}
+}
+
+// Read reads up to len(p) bytes into p. It returns the number of bytes read or
+// any error encountered. This struct overrides Read() to extract the data
+// wrapped in a frame generated by ClientConn.Write().
+func (cc *ServerConn) Read(p []byte) (n int, err error) {
+ return cc.r.Read(p)
+}
+
+// Write writes data to the connection. This method fallbacks to underlying
+// connection without any modificiation.
+func (cc *ServerConn) Write(b []byte) (n int, err error) {
+ return cc.conn.Write(b)
+}
+
+// Close closes the connection. This method fallbacks to underlying
+// connection without any modificiation.
+func (cc *ServerConn) Close() error {
+ return cc.conn.Close()
+}
+
+// ClientConn is a wrapper around net.Conn with the support of half-closed
+// capacity for sidechannel. This struct is expected to use by sidechannel's
+// client only.
+type ClientConn struct {
+ conn net.Conn
+ writeClosed bool
+}
+
+func newClientConn(c net.Conn) *ClientConn {
+ return &ClientConn{conn: c}
+}
+
+// Read reads data from the connection. This method fallbacks to underlying
+// connection without any modificiation.
+func (cc *ClientConn) Read(b []byte) (n int, err error) {
+ return cc.conn.Read(b)
+}
+
+// Write writes len(p) bytes from p to the underlying data stream. It returns
+// the number of bytes written from p and any error encountered that caused the
+// write to stop early. This method overrides Write() to wrap the writing data
+// into a frame. The frame is then extracted and read by ServerConn.Read().
+func (cc *ClientConn) Write(p []byte) (int, error) {
+ if cc.writeClosed {
+ return 0, fmt.Errorf("sidechannel: write into a half-closed connection")
+ }
+
+ var n int
+
+ for len(p) > 0 {
+ chunk := maxChunkSize
+ if len(p) < chunk {
+ chunk = len(p)
+ }
+
+ if _, err := fmt.Fprintf(cc.conn, "%04x", chunk+4); err != nil {
+ return n, err
+ }
+ if _, err := cc.conn.Write(p[:chunk]); err != nil {
+ return n, err
+ }
+ n += chunk
+ p = p[chunk:]
+ }
+
+ return n, nil
+}
+
+func (cc *ClientConn) close() error {
+ return cc.conn.Close()
+}
+
+// CloseWrite shuts down the writing side of the connection. After this call,
+// any read operations from the server return EOF. The reading side is still
+// functional so that the server is still able to write back to the client. Any
+// attempt to write into a half-closed connection returns an error.
+func (cc *ClientConn) CloseWrite() error {
+ if cc.writeClosed {
+ return nil
+ }
+
+ cc.writeClosed = true
+ if err := pktline.WriteFlush(cc.conn); err != nil {
+ return err
+ }
+
+ return nil
+}
diff --git a/internal/sidechannel/conn_test.go b/internal/sidechannel/conn_test.go
new file mode 100644
index 000000000..257fd4273
--- /dev/null
+++ b/internal/sidechannel/conn_test.go
@@ -0,0 +1,179 @@
+package sidechannel
+
+import (
+ "fmt"
+ "io"
+ "net"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestClientWrite(t *testing.T) {
+ largeString := strings.Repeat("a", maxChunkSize)
+
+ testCases := []struct {
+ desc string
+ in string
+ out string
+ err error
+ }{
+ {desc: "empty", out: ""},
+ {desc: "1-byte string", in: "h", out: "0005h"},
+ {desc: "short string", in: "hello", out: "0009hello"},
+ {desc: "short string 2", in: "hello this world", out: "0014hello this world"},
+ {
+ desc: "large string",
+ in: largeString,
+ out: "fff0" + largeString,
+ },
+ {
+ desc: "very large string",
+ in: largeString + "b",
+ out: "fff0" + largeString + "0005b",
+ },
+ }
+
+ type result struct {
+ data string
+ err error
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ client, server := net.Pipe()
+ ch := make(chan result, 1)
+ go func() {
+ out, err := io.ReadAll(server)
+ ch <- result{data: string(out), err: err}
+ }()
+
+ cc := newClientConn(client)
+ written, err := cc.Write([]byte(tc.in))
+ if tc.err != nil {
+ require.Equal(t, tc.err, err)
+ return
+ }
+
+ require.NoError(t, err)
+ require.NoError(t, cc.close())
+ require.Equal(t, written, len(tc.in))
+
+ res := <-ch
+ require.NoError(t, res.err)
+ require.Equal(t, tc.out, res.data)
+ })
+ }
+}
+
+func TestServerRead(t *testing.T) {
+ largeString := strings.Repeat("a", maxChunkSize)
+
+ testCases := []struct {
+ desc string
+ in string
+ out string
+ err error
+ }{
+ {desc: "empty", in: "0000", out: ""},
+ {desc: "empty 2", in: "00040000", out: ""},
+ {desc: "1-byte string", in: "0005h0000", out: "h"},
+ {desc: "short string", in: "0009hello0000", out: "hello"},
+ {desc: "short string 2", in: "0014hello this world0000", out: "hello this world"},
+ {desc: "invalid header 1", in: "0001", err: fmt.Errorf("sidechannel: invalid header 1")},
+ {desc: "invalid header 2", in: "0002", err: fmt.Errorf("sidechannel: invalid header 2")},
+ {desc: "invalid header 3", in: "0003", err: fmt.Errorf("sidechannel: invalid header 3")},
+ {desc: "multiple short strings", in: "0009hello0008this0009world0000", out: "hellothisworld"},
+ {
+ desc: "large string",
+ in: "fff0" + largeString + "0000",
+ out: largeString,
+ },
+ {
+ desc: "very large string",
+ in: "fff0" + largeString + "0005b0000",
+ out: largeString + "b",
+ },
+ {desc: "flush packet", in: "0009hello0000trashtrash", out: "hello"},
+ {desc: "unexpected closed without trailing 0000", in: "0009hello", err: io.ErrUnexpectedEOF},
+ }
+
+ type result struct {
+ data string
+ err error
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ client, server := net.Pipe()
+ ch := make(chan result, 1)
+
+ go func() {
+ cc := newServerConn(server)
+ out, err := io.ReadAll(cc)
+ ch <- result{data: string(out), err: err}
+ }()
+
+ written, err := client.Write([]byte(tc.in))
+ require.Equal(t, written, len(tc.in))
+ require.NoError(t, err)
+
+ client.Close()
+
+ res := <-ch
+ if tc.err != nil {
+ require.Equal(t, tc.err, res.err)
+ return
+ }
+
+ require.NoError(t, res.err)
+ require.Equal(t, tc.out, res.data)
+ })
+ }
+}
+
+func TestHalfClose(t *testing.T) {
+ type result struct {
+ data string
+ err error
+ }
+
+ client, server := net.Pipe()
+ clientCc := newClientConn(client)
+ serverCc := newServerConn(server)
+
+ ch := make(chan result, 1)
+
+ go func() {
+ out, err := io.ReadAll(serverCc)
+ ch <- result{data: string(out), err: err}
+ }()
+
+ written, err := clientCc.Write([]byte("Ping"))
+ require.Equal(t, written, 4)
+ require.NoError(t, err)
+
+ require.NoError(t, clientCc.CloseWrite())
+
+ res := <-ch
+ require.NoError(t, res.err)
+ require.Equal(t, "Ping", res.data)
+
+ _, err = clientCc.Write([]byte("Should not be sent"))
+ require.EqualError(t, err, "sidechannel: write into a half-closed connection")
+
+ go func() {
+ out, err := io.ReadAll(clientCc)
+ ch <- result{data: string(out), err: err}
+ }()
+
+ _, err = serverCc.Write([]byte("Pong"))
+ require.NoError(t, err)
+
+ serverCc.Close()
+
+ res = <-ch
+ require.NoError(t, res.err)
+ require.Equal(t, "Pong", res.data)
+}
diff --git a/internal/sidechannel/registry.go b/internal/sidechannel/registry.go
index 32985a9a1..1b96e8c88 100644
--- a/internal/sidechannel/registry.go
+++ b/internal/sidechannel/registry.go
@@ -25,7 +25,7 @@ type Waiter struct {
registry *Registry
errC chan error
accept chan net.Conn
- callback func(net.Conn) error
+ callback func(*ClientConn) error
}
// NewRegistry returns a new Registry instance
@@ -40,7 +40,7 @@ func NewRegistry() *Registry {
// connection arrives, the callback function is executed with arrived
// connection in a new goroutine. The caller receives execution result via
// waiter.Wait().
-func (s *Registry) Register(callback func(net.Conn) error) *Waiter {
+func (s *Registry) Register(callback func(*ClientConn) error) *Waiter {
s.mu.Lock()
defer s.mu.Unlock()
@@ -104,7 +104,7 @@ func (w *Waiter) run() {
if conn := <-w.accept; conn != nil {
defer conn.Close()
- w.errC <- w.callback(conn)
+ w.errC <- w.callback(newClientConn(conn))
}
}
diff --git a/internal/sidechannel/registry_test.go b/internal/sidechannel/registry_test.go
index 07d2f85c1..4db4e6ac5 100644
--- a/internal/sidechannel/registry_test.go
+++ b/internal/sidechannel/registry_test.go
@@ -4,8 +4,10 @@ import (
"fmt"
"io"
"net"
+ "os"
"strconv"
"sync"
+ "syscall"
"testing"
"github.com/stretchr/testify/require"
@@ -17,7 +19,7 @@ func TestRegistry(t *testing.T) {
t.Run("waiter removed from the registry right after connection received", func(t *testing.T) {
triggerCallback := make(chan struct{})
- waiter := registry.Register(func(conn net.Conn) error {
+ waiter := registry.Register(func(conn *ClientConn) error {
<-triggerCallback
return nil
})
@@ -25,7 +27,7 @@ func TestRegistry(t *testing.T) {
require.Equal(t, 1, registry.waiting())
- client, _ := net.Pipe()
+ client, _ := socketPair(t)
require.NoError(t, registry.receive(waiter.id, client))
require.Equal(t, 0, registry.waiting())
@@ -37,17 +39,20 @@ func TestRegistry(t *testing.T) {
t.Run("pull connections successfully", func(t *testing.T) {
wg := sync.WaitGroup{}
- var servers []net.Conn
+ var servers []*ServerConn
for i := 0; i < N; i++ {
- client, server := net.Pipe()
- servers = append(servers, server)
+ client, server := socketPair(t)
+ servers = append(servers, newServerConn(server))
wg.Add(1)
go func(i int) {
- waiter := registry.Register(func(conn net.Conn) error {
- _, err := fmt.Fprintf(conn, "%d", i)
- return err
+ waiter := registry.Register(func(conn *ClientConn) error {
+ if _, err := fmt.Fprintf(conn, "%d", i); err != nil {
+ return err
+ }
+
+ return conn.CloseWrite()
})
defer waiter.Close()
@@ -70,14 +75,14 @@ func TestRegistry(t *testing.T) {
})
t.Run("push connection to non-existing ID", func(t *testing.T) {
- client, _ := net.Pipe()
+ client, _ := socketPair(t)
err := registry.receive(registry.nextID+1, client)
require.EqualError(t, err, "sidechannel registry: ID not registered")
requireConnClosed(t, client)
})
t.Run("pre-maturely close the waiter", func(t *testing.T) {
- waiter := registry.Register(func(conn net.Conn) error { panic("never execute") })
+ waiter := registry.Register(func(conn *ClientConn) error { panic("never execute") })
require.NoError(t, waiter.Close())
require.Equal(t, 0, registry.waiting())
})
@@ -86,7 +91,23 @@ func TestRegistry(t *testing.T) {
func requireConnClosed(t *testing.T, conn net.Conn) {
one := make([]byte, 1)
_, err := conn.Read(one)
- require.EqualError(t, err, "io: read/write on closed pipe")
+ require.Errorf(t, err, "use of closed network connection")
_, err = conn.Write(one)
- require.EqualError(t, err, "io: read/write on closed pipe")
+ require.Errorf(t, err, "use of closed network connection")
+}
+
+func socketPair(t *testing.T) (net.Conn, net.Conn) {
+ conns := make([]net.Conn, 2)
+ fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
+ require.NoError(t, err)
+
+ for i, fd := range fds[:] {
+ f := os.NewFile(uintptr(fd), "socket pair")
+ c, err := net.FileConn(f)
+ require.NoError(t, err)
+ require.NoError(t, f.Close())
+ t.Cleanup(func() { c.Close() })
+ conns[i] = c
+ }
+ return conns[0], conns[1]
}
diff --git a/internal/sidechannel/sidechannel.go b/internal/sidechannel/sidechannel.go
index c133fcc6e..a2083fb04 100644
--- a/internal/sidechannel/sidechannel.go
+++ b/internal/sidechannel/sidechannel.go
@@ -26,7 +26,7 @@ const (
// OpenSidechannel opens a sidechannel connection from the stream opener
// extracted from the current peer connection.
-func OpenSidechannel(ctx context.Context) (_ net.Conn, err error) {
+func OpenSidechannel(ctx context.Context) (_ *ServerConn, err error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, fmt.Errorf("sidechannel: failed to extract incoming metadata")
@@ -68,7 +68,7 @@ func OpenSidechannel(ctx context.Context) (_ net.Conn, err error) {
return nil, err
}
- return stream, nil
+ return newServerConn(stream), nil
}
// RegisterSidechannel registers the caller into the waiting list of the
@@ -76,7 +76,7 @@ func OpenSidechannel(ctx context.Context) (_ net.Conn, err error) {
// The caller is expected to establish the request with the returned context. The
// callback is executed automatically when the sidechannel connection arrives.
// The result is pushed to the error channel of the returned waiter.
-func RegisterSidechannel(ctx context.Context, registry *Registry, callback func(net.Conn) error) (context.Context, *Waiter) {
+func RegisterSidechannel(ctx context.Context, registry *Registry, callback func(*ClientConn) error) (context.Context, *Waiter) {
waiter := registry.Register(callback)
ctxOut := metadata.AppendToOutgoingContext(ctx, sidechannelMetadataKey, fmt.Sprintf("%d", waiter.id))
return ctxOut, waiter
diff --git a/internal/sidechannel/sidechannel_test.go b/internal/sidechannel/sidechannel_test.go
index c85517952..f1d625504 100644
--- a/internal/sidechannel/sidechannel_test.go
+++ b/internal/sidechannel/sidechannel_test.go
@@ -47,7 +47,7 @@ func TestSidechannel(t *testing.T) {
conn, registry := dial(t, addr)
err = call(
context.Background(), conn, registry,
- func(conn net.Conn) error {
+ func(conn *ClientConn) error {
errC := make(chan error, 1)
go func() {
var err error
@@ -113,7 +113,7 @@ func TestSidechannelConcurrency(t *testing.T) {
err := call(
context.Background(), conn, registry,
- func(conn net.Conn) error {
+ func(conn *ClientConn) error {
errC := make(chan error, 1)
go func() {
var err error
@@ -187,7 +187,7 @@ func dial(t *testing.T, addr string) (*grpc.ClientConn, *Registry) {
return conn, registry
}
-func call(ctx context.Context, conn *grpc.ClientConn, registry *Registry, handler func(net.Conn) error) error {
+func call(ctx context.Context, conn *grpc.ClientConn, registry *Registry, handler func(*ClientConn) error) error {
client := healthpb.NewHealthClient(conn)
ctxOut, waiter := RegisterSidechannel(ctx, registry, handler)
diff --git a/proto/go/gitalypb/protolist.go b/proto/go/gitalypb/protolist.go
index a15916f70..afb0510ae 100644
--- a/proto/go/gitalypb/protolist.go
+++ b/proto/go/gitalypb/protolist.go
@@ -9,6 +9,7 @@ var GitalyProtos = []string{
"commit.proto",
"conflicts.proto",
"diff.proto",
+ "errors.proto",
"hook.proto",
"internal.proto",
"lint.proto",