Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSami Hiltunen <shiltunen@gitlab.com>2021-09-09 11:32:29 +0300
committerSami Hiltunen <shiltunen@gitlab.com>2021-09-09 11:32:29 +0300
commit4534c1b11dfc9cb500ca792a058e15bb4e710e52 (patch)
tree322f1304e45ac528bc1143c4912efae6e0c9825b
parent8c4e8fe7b42e1f3d98a8ffb9a7090d171707b99a (diff)
parentfd80520b100eb0d0af17054eae1623129fb966bc (diff)
Merge branch 'qmnguyen0711/1216-create-sidechannel-as-a-sub-protocol-of-backchannel' into 'master'
Create "sidechannel" as a sub-protocol of "backchannel" See merge request gitlab-org/gitaly!3768
-rw-r--r--internal/backchannel/server.go29
-rw-r--r--internal/sidechannel/registry.go121
-rw-r--r--internal/sidechannel/registry_test.go95
-rw-r--r--internal/sidechannel/sidechannel.go121
-rw-r--r--internal/sidechannel/sidechannel_test.go220
5 files changed, 579 insertions, 7 deletions
diff --git a/internal/backchannel/server.go b/internal/backchannel/server.go
index b2a1f58da..1d40c0241 100644
--- a/internal/backchannel/server.go
+++ b/internal/backchannel/server.go
@@ -19,11 +19,13 @@ var ErrNonMultiplexedConnection = errors.New("non-multiplexed connection")
// authInfoWrapper is used to pass the peer id through the context to the RPC handlers.
type authInfoWrapper struct {
- id ID
+ id ID
+ session *yamux.Session
credentials.AuthInfo
}
-func (w authInfoWrapper) peerID() ID { return w.id }
+func (w authInfoWrapper) peerID() ID { return w.id }
+func (w authInfoWrapper) yamuxSession() *yamux.Session { return w.session }
// GetPeerID gets the ID of the current peer connection.
func GetPeerID(ctx context.Context) (ID, error) {
@@ -40,10 +42,23 @@ func GetPeerID(ctx context.Context) (ID, error) {
return wrapper.peerID(), nil
}
-// WithID stores the ID in the provided AuthInfo so it can be later accessed by the RPC handler.
-// This is exported to facilitate testing.
-func WithID(authInfo credentials.AuthInfo, id ID) credentials.AuthInfo {
- return authInfoWrapper{id: id, AuthInfo: authInfo}
+// GetYamuxSession gets the yamux session of the current peer connection.
+func GetYamuxSession(ctx context.Context) (*yamux.Session, error) {
+ peerInfo, ok := peer.FromContext(ctx)
+ if !ok {
+ return nil, errors.New("no peer info in context")
+ }
+
+ wrapper, ok := peerInfo.AuthInfo.(interface{ yamuxSession() *yamux.Session })
+ if !ok {
+ return nil, ErrNonMultiplexedConnection
+ }
+
+ return wrapper.yamuxSession(), nil
+}
+
+func withSessionInfo(authInfo credentials.AuthInfo, id ID, muxSession *yamux.Session) credentials.AuthInfo {
+ return authInfoWrapper{id: id, AuthInfo: authInfo, session: muxSession}
}
// ServerHandshaker implements the server side handshake of the multiplexed connection.
@@ -119,6 +134,6 @@ func (s *ServerHandshaker) Handshake(conn net.Conn, authInfo credentials.AuthInf
return nil
},
},
- WithID(authInfo, id),
+ withSessionInfo(authInfo, id, muxSession),
nil
}
diff --git a/internal/sidechannel/registry.go b/internal/sidechannel/registry.go
new file mode 100644
index 000000000..50726ab7e
--- /dev/null
+++ b/internal/sidechannel/registry.go
@@ -0,0 +1,121 @@
+package sidechannel
+
+import (
+ "fmt"
+ "net"
+ "sync"
+)
+
+// sidechannelID is the type of ID used to differeniate sidechannel connections
+// in the same registry
+type sidechannelID int64
+
+// Registry manages sidechannel connections. It allows the RPC
+// handlers to wait for the secondary incoming connection made by the client.
+type Registry struct {
+ nextID sidechannelID
+ waiters map[sidechannelID]*Waiter
+ mu sync.Mutex
+}
+
+// Waiter lets the caller waits until a connection with matched id is pushed
+// into the registry, then execute the callback
+type Waiter struct {
+ id sidechannelID
+ registry *Registry
+ errC chan error
+ accept chan net.Conn
+ callback func(net.Conn) error
+}
+
+// NewRegistry returns a new Registry instance
+func NewRegistry() *Registry {
+ return &Registry{
+ waiters: make(map[sidechannelID]*Waiter),
+ }
+}
+
+// Register registers the caller into the waiting list. The caller must provide
+// a callback function. The caller receives a waiter instance. After the
+// connection arrives, the callback function is executed with arrived
+// connection in a new goroutine. The caller receives execution result via
+// waiter.Wait().
+func (s *Registry) Register(callback func(net.Conn) error) (*Waiter, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ waiter := &Waiter{
+ id: s.nextID,
+ registry: s,
+ errC: make(chan error),
+ accept: make(chan net.Conn),
+ callback: callback,
+ }
+ s.nextID++
+
+ go waiter.run()
+ s.waiters[waiter.id] = waiter
+ return waiter, nil
+}
+
+// receive looks into the registry for a waiter with the given ID. If
+// there is an associated ID, the waiter is removed from the registry, and the
+// connection is pushed into the waiter's accept channel. After the callback is done, the
+// connection is closed. When the ID is not found, an error is returned and the
+// connection is closed immediately.
+func (s *Registry) receive(id sidechannelID, conn net.Conn) (err error) {
+ s.mu.Lock()
+ defer func() {
+ s.mu.Unlock()
+ if err != nil {
+ conn.Close()
+ }
+ }()
+
+ waiter, exist := s.waiters[id]
+ if !exist {
+ return fmt.Errorf("sidechannel registry: ID not registered")
+ }
+ delete(s.waiters, waiter.id)
+ waiter.accept <- conn
+
+ return nil
+}
+
+func (s *Registry) removeWaiter(waiter *Waiter) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if _, exist := s.waiters[waiter.id]; exist {
+ delete(s.waiters, waiter.id)
+ close(waiter.accept)
+ }
+}
+
+func (s *Registry) waiting() int {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ return len(s.waiters)
+}
+
+func (w *Waiter) run() {
+ defer close(w.errC)
+
+ if conn := <-w.accept; conn != nil {
+ defer conn.Close()
+ w.errC <- w.callback(conn)
+ }
+}
+
+// Close cleans the waiter, removes it from the registry. If the callback is
+// executing, this method is blocked until the callback is done.
+func (w *Waiter) Close() error {
+ w.registry.removeWaiter(w)
+ return <-w.errC
+}
+
+// Wait waits until either the callback is executed, or the waiter is closed
+func (w *Waiter) Wait() error {
+ return <-w.errC
+}
diff --git a/internal/sidechannel/registry_test.go b/internal/sidechannel/registry_test.go
new file mode 100644
index 000000000..a0a4a8bf7
--- /dev/null
+++ b/internal/sidechannel/registry_test.go
@@ -0,0 +1,95 @@
+package sidechannel
+
+import (
+ "fmt"
+ "io/ioutil"
+ "net"
+ "strconv"
+ "sync"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestRegistry(t *testing.T) {
+ const N = 10
+ registry := NewRegistry()
+
+ t.Run("waiter removed from the registry right after connection received", func(t *testing.T) {
+ triggerCallback := make(chan struct{})
+ waiter, err := registry.Register(func(conn net.Conn) error {
+ <-triggerCallback
+ return nil
+ })
+ require.NoError(t, err)
+ defer waiter.Close()
+
+ require.Equal(t, 1, registry.waiting())
+
+ client, _ := net.Pipe()
+ require.NoError(t, registry.receive(waiter.id, client))
+ require.Equal(t, 0, registry.waiting())
+
+ close(triggerCallback)
+
+ require.NoError(t, waiter.Wait())
+ requireConnClosed(t, client)
+ })
+
+ t.Run("pull connections successfully", func(t *testing.T) {
+ wg := sync.WaitGroup{}
+ var servers []net.Conn
+
+ for i := 0; i < N; i++ {
+ client, server := net.Pipe()
+ servers = append(servers, server)
+
+ wg.Add(1)
+ go func(i int) {
+ waiter, err := registry.Register(func(conn net.Conn) error {
+ _, err := fmt.Fprintf(conn, "%d", i)
+ return err
+ })
+ require.NoError(t, err)
+ defer waiter.Close()
+
+ require.NoError(t, registry.receive(waiter.id, client))
+ require.NoError(t, waiter.Wait())
+ requireConnClosed(t, client)
+
+ wg.Done()
+ }(i)
+ }
+
+ for i := 0; i < N; i++ {
+ out, err := ioutil.ReadAll(servers[i])
+ require.NoError(t, err)
+ require.Equal(t, strconv.Itoa(i), string(out))
+ }
+
+ wg.Wait()
+ require.Equal(t, 0, registry.waiting())
+ })
+
+ t.Run("push connection to non-existing ID", func(t *testing.T) {
+ client, _ := net.Pipe()
+ err := registry.receive(registry.nextID+1, client)
+ require.EqualError(t, err, "sidechannel registry: ID not registered")
+ requireConnClosed(t, client)
+ })
+
+ t.Run("pre-maturely close the waiter", func(t *testing.T) {
+ waiter, err := registry.Register(func(conn net.Conn) error { panic("never execute") })
+ require.NoError(t, err)
+ require.NoError(t, waiter.Close())
+ require.Equal(t, 0, registry.waiting())
+ })
+}
+
+func requireConnClosed(t *testing.T, conn net.Conn) {
+ one := make([]byte, 1)
+ _, err := conn.Read(one)
+ require.EqualError(t, err, "io: read/write on closed pipe")
+ _, err = conn.Write(one)
+ require.EqualError(t, err, "io: read/write on closed pipe")
+}
diff --git a/internal/sidechannel/sidechannel.go b/internal/sidechannel/sidechannel.go
new file mode 100644
index 000000000..eb716b022
--- /dev/null
+++ b/internal/sidechannel/sidechannel.go
@@ -0,0 +1,121 @@
+package sidechannel
+
+import (
+ "context"
+ "encoding/binary"
+ "fmt"
+ "net"
+ "strconv"
+ "time"
+
+ "gitlab.com/gitlab-org/gitaly/v14/internal/backchannel"
+ "google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/metadata"
+)
+
+var magicBytes = []byte("sidechannel")
+
+// sidechannelTimeout is the timeout for establishing a sidechannel
+// connection. The sidechannel is supposed to be opened on the same wire with
+// incoming grpc request. There won't be real handshaking involved, so it
+// should be fast.
+const (
+ sidechannelTimeout = 5 * time.Second
+ sidechannelMetadataKey = "gitaly-sidechannel-id"
+)
+
+// OpenSidechannel opens a sidechannel connection from the stream opener
+// extracted from the current peer connection.
+func OpenSidechannel(ctx context.Context) (_ net.Conn, err error) {
+ md, ok := metadata.FromIncomingContext(ctx)
+ if !ok {
+ return nil, fmt.Errorf("sidechannel: failed to extract incoming metadata")
+ }
+ ids := md.Get(sidechannelMetadataKey)
+ if len(ids) == 0 {
+ return nil, fmt.Errorf("sidechannel: sidechannel-id not found in incoming metadata")
+ }
+ sidechannelID, _ := strconv.ParseInt(ids[len(ids)-1], 10, 64)
+
+ muxSession, err := backchannel.GetYamuxSession(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("sidechannel: fail to extract yamux session: %w", err)
+ }
+
+ stream, err := muxSession.Open()
+ if err != nil {
+ return nil, fmt.Errorf("sidechannel: open stream: %w", err)
+ }
+ defer func() {
+ if err != nil {
+ stream.Close()
+ }
+ }()
+
+ if err := stream.SetDeadline(time.Now().Add(sidechannelTimeout)); err != nil {
+ return nil, err
+ }
+
+ if _, err := stream.Write(magicBytes); err != nil {
+ return nil, fmt.Errorf("sidechannel: write magic bytes: %w", err)
+ }
+
+ if err := binary.Write(stream, binary.BigEndian, sidechannelID); err != nil {
+ return nil, fmt.Errorf("sidechannel: write stream id: %w", err)
+ }
+
+ if err := stream.SetDeadline(time.Time{}); err != nil {
+ return nil, err
+ }
+
+ return stream, nil
+}
+
+// RegisterSidechannel registers the caller into the waiting list of the
+// sidechannel registry and injects the sidechannel ID into outgoing metadata.
+// The caller is expected to establish the request with the returned context. The
+// callback is executed automatically when the sidechannel connection arrives.
+// The result is pushed to the error channel of the returned waiter.
+func RegisterSidechannel(ctx context.Context, registry *Registry, callback func(net.Conn) error) (context.Context, *Waiter, error) {
+ waiter, err := registry.Register(callback)
+ if err != nil {
+ return ctx, nil, err
+ }
+
+ ctxOut := metadata.AppendToOutgoingContext(ctx, sidechannelMetadataKey, fmt.Sprintf("%d", waiter.id))
+ return ctxOut, waiter, nil
+}
+
+// ServerHandshaker implements the server-side sidechannel handshake.
+type ServerHandshaker struct {
+ registry *Registry
+}
+
+// Magic returns the magic bytes for sidechannel
+func (s *ServerHandshaker) Magic() string {
+ return string(magicBytes)
+}
+
+// Handshake implements the handshaking logic for sidechannel so that
+// this handshaker reads the sidechannel ID from the wire, and then delegates
+// the connection to the sidechannel registry
+func (s *ServerHandshaker) Handshake(conn net.Conn, authInfo credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error) {
+ var sidechannelID sidechannelID
+ if err := binary.Read(conn, binary.BigEndian, &sidechannelID); err != nil {
+ return nil, nil, fmt.Errorf("sidechannel: fail to extract sidechannel ID: %w", err)
+ }
+
+ if err := s.registry.receive(sidechannelID, conn); err != nil {
+ return nil, nil, err
+ }
+
+ // credentials.ErrConnDispatched, indicating that the connection is already
+ // dispatched out of gRPC. gRPC should leave it alone and exit in peace.
+ return nil, nil, credentials.ErrConnDispatched
+}
+
+// NewServerHandshaker creates a new handshaker for sidechannel to
+// embed into listenmux.
+func NewServerHandshaker(registry *Registry) *ServerHandshaker {
+ return &ServerHandshaker{registry: registry}
+}
diff --git a/internal/sidechannel/sidechannel_test.go b/internal/sidechannel/sidechannel_test.go
new file mode 100644
index 000000000..20fcd1d80
--- /dev/null
+++ b/internal/sidechannel/sidechannel_test.go
@@ -0,0 +1,220 @@
+package sidechannel
+
+import (
+ "bytes"
+ "context"
+ "io"
+ "io/ioutil"
+ "math/rand"
+ "net"
+ "sync"
+ "testing"
+
+ "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/backchannel"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/listenmux"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials/insecure"
+ healthpb "google.golang.org/grpc/health/grpc_health_v1"
+)
+
+func TestSidechannel(t *testing.T) {
+ const blobSize = 1024 * 1024
+
+ in := make([]byte, blobSize)
+ _, err := rand.Read(in)
+ require.NoError(t, err)
+
+ var out []byte
+ require.NotEqual(t, in, out)
+
+ addr := startServer(
+ t,
+ func(context context.Context, request *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) {
+ conn, err := OpenSidechannel(context)
+ if err != nil {
+ return nil, err
+ }
+ defer conn.Close()
+
+ if _, err = io.CopyN(conn, conn, blobSize); err != nil {
+ return nil, err
+ }
+ return &healthpb.HealthCheckResponse{}, conn.Close()
+ },
+ )
+
+ conn, registry := dial(t, addr)
+ err = call(
+ context.Background(), conn, registry,
+ func(conn net.Conn) error {
+ errC := make(chan error, 1)
+ go func() {
+ var err error
+ out, err = ioutil.ReadAll(conn)
+ errC <- err
+ }()
+
+ _, err = io.Copy(conn, bytes.NewReader(in))
+ require.NoError(t, err)
+ require.NoError(t, <-errC)
+
+ return nil
+ },
+ )
+ require.NoError(t, err)
+ require.Equal(t, in, out, "byte stream works")
+}
+
+// Conduct multiple requests with sidechannel included on the same grpc
+// connection.
+func TestSidechannelConcurrency(t *testing.T) {
+ const concurrency = 10
+ const blobSize = 1024 * 1024
+
+ ins := make([][]byte, concurrency)
+ for i := 0; i < concurrency; i++ {
+ ins[i] = make([]byte, blobSize)
+ _, err := rand.Read(ins[i])
+ require.NoError(t, err)
+ }
+
+ outs := make([][]byte, concurrency)
+ for i := 0; i < concurrency; i++ {
+ require.NotEqual(t, ins[i], outs[i])
+ }
+
+ addr := startServer(
+ t,
+ func(context context.Context, request *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) {
+ conn, err := OpenSidechannel(context)
+ if err != nil {
+ return nil, err
+ }
+ defer conn.Close()
+
+ if _, err = io.CopyN(conn, conn, blobSize); err != nil {
+ return nil, err
+ }
+
+ return &healthpb.HealthCheckResponse{}, conn.Close()
+ },
+ )
+
+ conn, registry := dial(t, addr)
+
+ errors := make(chan error, concurrency)
+
+ wg := sync.WaitGroup{}
+ for i := 0; i < concurrency; i++ {
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+
+ err := call(
+ context.Background(), conn, registry,
+ func(conn net.Conn) error {
+ errC := make(chan error, 1)
+ go func() {
+ var err error
+ outs[i], err = ioutil.ReadAll(conn)
+ errC <- err
+ }()
+
+ if _, err := io.Copy(conn, bytes.NewReader(ins[i])); err != nil {
+ return err
+ }
+ if err := <-errC; err != nil {
+ return err
+ }
+
+ return nil
+ },
+ )
+ errors <- err
+ }(i)
+ }
+ wg.Wait()
+
+ for i := 0; i < concurrency; i++ {
+ require.Equal(t, ins[i], outs[i], "byte stream works")
+ require.NoError(t, <-errors)
+ }
+}
+
+func startServer(t *testing.T, th testHandler, opts ...grpc.ServerOption) string {
+ t.Helper()
+
+ logger := logrus.NewEntry(logrus.New())
+
+ lm := listenmux.New(insecure.NewCredentials())
+ lm.Register(backchannel.NewServerHandshaker(logger, backchannel.NewRegistry(), nil))
+
+ opts = append(opts, grpc.Creds(lm))
+
+ s := grpc.NewServer(opts...)
+ t.Cleanup(func() { s.Stop() })
+
+ handler := &server{testHandler: th}
+ healthpb.RegisterHealthServer(s, handler)
+
+ lis, err := net.Listen("tcp", "localhost:0")
+ require.NoError(t, err)
+ t.Cleanup(func() { lis.Close() })
+
+ go func() { s.Serve(lis) }()
+
+ return lis.Addr().String()
+}
+
+func dial(t *testing.T, addr string) (*grpc.ClientConn, *Registry) {
+ registry := NewRegistry()
+ logger := logrus.NewEntry(logrus.New())
+
+ factory := func() backchannel.Server {
+ lm := listenmux.New(insecure.NewCredentials())
+ lm.Register(NewServerHandshaker(registry))
+ return grpc.NewServer(grpc.Creds(lm))
+ }
+
+ clientHandshaker := backchannel.NewClientHandshaker(logger, factory)
+ dialOpt := grpc.WithTransportCredentials(clientHandshaker.ClientHandshake(insecure.NewCredentials()))
+
+ conn, err := grpc.Dial(addr, dialOpt)
+ require.NoError(t, err)
+ t.Cleanup(func() { conn.Close() })
+
+ return conn, registry
+}
+
+func call(ctx context.Context, conn *grpc.ClientConn, registry *Registry, handler func(net.Conn) error) error {
+ client := healthpb.NewHealthClient(conn)
+
+ ctxOut, waiter, err := RegisterSidechannel(ctx, registry, handler)
+ defer waiter.Close()
+ if err != nil {
+ return err
+ }
+
+ if _, err := client.Check(ctxOut, &healthpb.HealthCheckRequest{}); err != nil {
+ return err
+ }
+
+ if err := waiter.Wait(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+type testHandler func(context.Context, *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error)
+
+type server struct {
+ healthpb.UnimplementedHealthServer
+ testHandler
+}
+
+func (s *server) Check(context context.Context, request *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) {
+ return s.testHandler(context, request)
+}