author     Sami Hiltunen <shiltunen@gitlab.com>    2021-09-09 11:32:29 +0300
committer  Sami Hiltunen <shiltunen@gitlab.com>    2021-09-09 11:32:29 +0300
commit     4534c1b11dfc9cb500ca792a058e15bb4e710e52 (patch)
tree       322f1304e45ac528bc1143c4912efae6e0c9825b
parent     8c4e8fe7b42e1f3d98a8ffb9a7090d171707b99a (diff)
parent     fd80520b100eb0d0af17054eae1623129fb966bc (diff)
Merge branch 'qmnguyen0711/1216-create-sidechannel-as-a-sub-protocol-of-backchannel' into 'master'
Create "sidechannel" as a sub-protocol of "backchannel"
See merge request gitlab-org/gitaly!3768
-rw-r--r--  internal/backchannel/server.go             29
-rw-r--r--  internal/sidechannel/registry.go          121
-rw-r--r--  internal/sidechannel/registry_test.go      95
-rw-r--r--  internal/sidechannel/sidechannel.go       121
-rw-r--r--  internal/sidechannel/sidechannel_test.go  220
5 files changed, 579 insertions, 7 deletions
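
The MR introduces a small client/server protocol on top of the existing backchannel multiplexing: a client pre-registers a callback in a sidechannel Registry, sends the registration ID to the server as gRPC request metadata, and the server dials back over a fresh yamux stream on the same TCP connection. The following sketch condenses the `call` test helper from the diff below into a single function; the package and function names, and the use of the health-check RPC as a stand-in for a real Gitaly RPC, are ours, not part of the MR.

package client // hypothetical; sidechannel is an internal gitaly package

import (
	"context"
	"io"
	"net"

	"gitlab.com/gitlab-org/gitaly/v14/internal/sidechannel"
	"google.golang.org/grpc"
	healthpb "google.golang.org/grpc/health/grpc_health_v1"
)

// callWithSidechannel assumes conn was dialed with the backchannel client
// handshaker (see dial in sidechannel_test.go below) so that the server can
// open yamux streams back to this process.
func callWithSidechannel(ctx context.Context, conn *grpc.ClientConn, registry *sidechannel.Registry) error {
	// Pre-register a callback; the returned context carries the waiter's ID
	// in the "gitaly-sidechannel-id" metadata header.
	ctx, waiter, err := sidechannel.RegisterSidechannel(ctx, registry, func(sc net.Conn) error {
		_, err := io.Copy(io.Discard, sc) // consume the raw byte stream
		return err
	})
	if err != nil {
		return err
	}
	defer waiter.Close()

	// Issue the RPC with the returned context. While it runs, the server
	// handler calls OpenSidechannel(ctx) and dials back over yamux.
	if _, err := healthpb.NewHealthClient(conn).Check(ctx, &healthpb.HealthCheckRequest{}); err != nil {
		return err
	}

	// Wait returns the callback's result once the sidechannel is done.
	return waiter.Wait()
}
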
diff --git a/internal/backchannel/server.go b/internal/backchannel/server.go
index b2a1f58da..1d40c0241 100644
--- a/internal/backchannel/server.go
+++ b/internal/backchannel/server.go
@@ -19,11 +19,13 @@ var ErrNonMultiplexedConnection = errors.New("non-multiplexed connection")
 
 // authInfoWrapper is used to pass the peer id through the context to the RPC handlers.
 type authInfoWrapper struct {
-	id ID
+	id      ID
+	session *yamux.Session
 	credentials.AuthInfo
 }
 
-func (w authInfoWrapper) peerID() ID { return w.id }
+func (w authInfoWrapper) peerID() ID                   { return w.id }
+func (w authInfoWrapper) yamuxSession() *yamux.Session { return w.session }
 
 // GetPeerID gets the ID of the current peer connection.
 func GetPeerID(ctx context.Context) (ID, error) {
@@ -40,10 +42,23 @@ func GetPeerID(ctx context.Context) (ID, error) {
 	return wrapper.peerID(), nil
 }
 
-// WithID stores the ID in the provided AuthInfo so it can be later accessed by the RPC handler.
-// This is exported to facilitate testing.
-func WithID(authInfo credentials.AuthInfo, id ID) credentials.AuthInfo {
-	return authInfoWrapper{id: id, AuthInfo: authInfo}
+// GetYamuxSession gets the yamux session of the current peer connection.
+func GetYamuxSession(ctx context.Context) (*yamux.Session, error) {
+	peerInfo, ok := peer.FromContext(ctx)
+	if !ok {
+		return nil, errors.New("no peer info in context")
+	}
+
+	wrapper, ok := peerInfo.AuthInfo.(interface{ yamuxSession() *yamux.Session })
+	if !ok {
+		return nil, ErrNonMultiplexedConnection
+	}
+
+	return wrapper.yamuxSession(), nil
+}
+
+func withSessionInfo(authInfo credentials.AuthInfo, id ID, muxSession *yamux.Session) credentials.AuthInfo {
+	return authInfoWrapper{id: id, AuthInfo: authInfo, session: muxSession}
 }
 
 // ServerHandshaker implements the server side handshake of the multiplexed connection.
@@ -119,6 +134,6 @@ func (s *ServerHandshaker) Handshake(conn net.Conn, authInfo credentials.AuthInf
 				return nil
 			},
 		},
-		WithID(authInfo, id),
+		withSessionInfo(authInfo, id, muxSession),
 		nil
 }
diff --git a/internal/sidechannel/registry.go b/internal/sidechannel/registry.go
new file mode 100644
index 000000000..50726ab7e
--- /dev/null
+++ b/internal/sidechannel/registry.go
@@ -0,0 +1,121 @@
+package sidechannel
+
+import (
+	"fmt"
+	"net"
+	"sync"
+)
+
+// sidechannelID is the type of ID used to differentiate sidechannel
+// connections in the same registry.
+type sidechannelID int64
+
+// Registry manages sidechannel connections. It allows the RPC
+// handlers to wait for the secondary incoming connection made by the client.
+type Registry struct {
+	nextID  sidechannelID
+	waiters map[sidechannelID]*Waiter
+	mu      sync.Mutex
+}
+
+// Waiter lets the caller wait until a connection with a matching ID is pushed
+// into the registry, then executes the callback.
+type Waiter struct {
+	id       sidechannelID
+	registry *Registry
+	errC     chan error
+	accept   chan net.Conn
+	callback func(net.Conn) error
+}
+
+// NewRegistry returns a new Registry instance.
+func NewRegistry() *Registry {
+	return &Registry{
+		waiters: make(map[sidechannelID]*Waiter),
+	}
+}
+
+// Register adds the caller to the waiting list and returns a waiter instance.
+// The caller must provide a callback function. After the connection arrives,
+// the callback is executed with the arrived connection in a new goroutine.
+// The caller receives the execution result via waiter.Wait().
+func (s *Registry) Register(callback func(net.Conn) error) (*Waiter, error) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	waiter := &Waiter{
+		id:       s.nextID,
+		registry: s,
+		errC:     make(chan error),
+		accept:   make(chan net.Conn),
+		callback: callback,
+	}
+	s.nextID++
+
+	go waiter.run()
+	s.waiters[waiter.id] = waiter
+	return waiter, nil
+}
+
+// receive looks up the waiter with the given ID in the registry. If one is
+// registered, it is removed from the registry and the connection is pushed
+// into the waiter's accept channel; after the callback is done, the
+// connection is closed. When the ID is not found, an error is returned and
+// the connection is closed immediately.
+func (s *Registry) receive(id sidechannelID, conn net.Conn) (err error) {
+	s.mu.Lock()
+	defer func() {
+		s.mu.Unlock()
+		if err != nil {
+			conn.Close()
+		}
+	}()
+
+	waiter, exist := s.waiters[id]
+	if !exist {
+		return fmt.Errorf("sidechannel registry: ID not registered")
+	}
+	delete(s.waiters, waiter.id)
+	waiter.accept <- conn
+
+	return nil
+}
+
+func (s *Registry) removeWaiter(waiter *Waiter) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if _, exist := s.waiters[waiter.id]; exist {
+		delete(s.waiters, waiter.id)
+		close(waiter.accept)
+	}
+}
+
+func (s *Registry) waiting() int {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	return len(s.waiters)
+}
+
+func (w *Waiter) run() {
+	defer close(w.errC)
+
+	if conn := <-w.accept; conn != nil {
+		defer conn.Close()
+		w.errC <- w.callback(conn)
+	}
+}
+
+// Close cleans up the waiter and removes it from the registry. If the
+// callback is executing, this method blocks until the callback is done.
+func (w *Waiter) Close() error {
+	w.registry.removeWaiter(w)
+	return <-w.errC
+}
+
+// Wait waits until either the callback is executed or the waiter is closed.
+func (w *Waiter) Wait() error {
+	return <-w.errC
+}
diff --git a/internal/sidechannel/registry_test.go b/internal/sidechannel/registry_test.go
new file mode 100644
index 000000000..a0a4a8bf7
--- /dev/null
+++ b/internal/sidechannel/registry_test.go
@@ -0,0 +1,95 @@
+package sidechannel
+
+import (
+	"fmt"
+	"io/ioutil"
+	"net"
+	"strconv"
+	"sync"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestRegistry(t *testing.T) {
+	const N = 10
+	registry := NewRegistry()
+
+	t.Run("waiter removed from the registry right after connection received", func(t *testing.T) {
+		triggerCallback := make(chan struct{})
+		waiter, err := registry.Register(func(conn net.Conn) error {
+			<-triggerCallback
+			return nil
+		})
+		require.NoError(t, err)
+		defer waiter.Close()
+
+		require.Equal(t, 1, registry.waiting())
+
+		client, _ := net.Pipe()
+		require.NoError(t, registry.receive(waiter.id, client))
+		require.Equal(t, 0, registry.waiting())
+
+		close(triggerCallback)
+
+		require.NoError(t, waiter.Wait())
+		requireConnClosed(t, client)
+	})
+
+	t.Run("pull connections successfully", func(t *testing.T) {
+		wg := sync.WaitGroup{}
+		var servers []net.Conn
+
+		for i := 0; i < N; i++ {
+			client, server := net.Pipe()
+			servers = append(servers, server)
+
+			wg.Add(1)
+			go func(i int) {
+				waiter, err := registry.Register(func(conn net.Conn) error {
+					_, err := fmt.Fprintf(conn, "%d", i)
+					return err
+				})
+				require.NoError(t, err)
+				defer waiter.Close()
+
+				require.NoError(t, registry.receive(waiter.id, client))
+				require.NoError(t, waiter.Wait())
+				requireConnClosed(t, client)
+
+				wg.Done()
+			}(i)
+		}
+
+		for i := 0; i < N; i++ {
+			out, err := ioutil.ReadAll(servers[i])
+			require.NoError(t, err)
+			require.Equal(t, strconv.Itoa(i), string(out))
+		}
+
+		wg.Wait()
+		require.Equal(t, 0, registry.waiting())
+	})
+
+	t.Run("push connection to non-existing ID", func(t *testing.T) {
+		client, _ := net.Pipe()
+		err := registry.receive(registry.nextID+1, client)
+		require.EqualError(t, err, "sidechannel registry: ID not registered")
+		requireConnClosed(t, client)
+	})
+
+	t.Run("prematurely close the waiter", func(t *testing.T) {
+		waiter, err := registry.Register(func(conn net.Conn) error { panic("never execute") })
+		require.NoError(t, err)
+		require.NoError(t, waiter.Close())
+		require.Equal(t, 0, registry.waiting())
+	})
+}
+
+func requireConnClosed(t *testing.T, conn net.Conn) {
+	one := make([]byte, 1)
+	_, err := conn.Read(one)
+	require.EqualError(t, err, "io: read/write on closed pipe")
+	_, err = conn.Write(one)
+	require.EqualError(t, err, "io: read/write on closed pipe")
+}
diff --git a/internal/sidechannel/sidechannel.go b/internal/sidechannel/sidechannel.go
new file mode 100644
index 000000000..eb716b022
--- /dev/null
+++ b/internal/sidechannel/sidechannel.go
@@ -0,0 +1,121 @@
+package sidechannel
+
+import (
+	"context"
+	"encoding/binary"
+	"fmt"
+	"net"
+	"strconv"
+	"time"
+
+	"gitlab.com/gitlab-org/gitaly/v14/internal/backchannel"
+	"google.golang.org/grpc/credentials"
+	"google.golang.org/grpc/metadata"
+)
+
+var magicBytes = []byte("sidechannel")
+
+// sidechannelTimeout is the timeout for establishing a sidechannel
+// connection. The sidechannel is supposed to be opened on the same wire as
+// the incoming gRPC request. There is no real handshaking involved, so it
+// should be fast.
+const (
+	sidechannelTimeout     = 5 * time.Second
+	sidechannelMetadataKey = "gitaly-sidechannel-id"
+)
+
+// OpenSidechannel opens a sidechannel connection from the stream opener
+// extracted from the current peer connection.
+func OpenSidechannel(ctx context.Context) (_ net.Conn, err error) {
+	md, ok := metadata.FromIncomingContext(ctx)
+	if !ok {
+		return nil, fmt.Errorf("sidechannel: failed to extract incoming metadata")
+	}
+	ids := md.Get(sidechannelMetadataKey)
+	if len(ids) == 0 {
+		return nil, fmt.Errorf("sidechannel: sidechannel-id not found in incoming metadata")
+	}
+	sidechannelID, _ := strconv.ParseInt(ids[len(ids)-1], 10, 64)
+
+	muxSession, err := backchannel.GetYamuxSession(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("sidechannel: fail to extract yamux session: %w", err)
+	}
+
+	stream, err := muxSession.Open()
+	if err != nil {
+		return nil, fmt.Errorf("sidechannel: open stream: %w", err)
+	}
+	defer func() {
+		if err != nil {
+			stream.Close()
+		}
+	}()
+
+	if err := stream.SetDeadline(time.Now().Add(sidechannelTimeout)); err != nil {
+		return nil, err
+	}
+
+	if _, err := stream.Write(magicBytes); err != nil {
+		return nil, fmt.Errorf("sidechannel: write magic bytes: %w", err)
+	}
+
+	if err := binary.Write(stream, binary.BigEndian, sidechannelID); err != nil {
+		return nil, fmt.Errorf("sidechannel: write stream id: %w", err)
+	}
+
+	if err := stream.SetDeadline(time.Time{}); err != nil {
+		return nil, err
+	}
+
+	return stream, nil
+}
+
+// RegisterSidechannel registers the caller into the waiting list of the
+// sidechannel registry and injects the sidechannel ID into outgoing metadata.
+// The caller is expected to establish the request with the returned context.
+// The callback is executed automatically when the sidechannel connection
+// arrives. The result is pushed to the error channel of the returned waiter.
+func RegisterSidechannel(ctx context.Context, registry *Registry, callback func(net.Conn) error) (context.Context, *Waiter, error) {
+	waiter, err := registry.Register(callback)
+	if err != nil {
+		return ctx, nil, err
+	}
+
+	ctxOut := metadata.AppendToOutgoingContext(ctx, sidechannelMetadataKey, fmt.Sprintf("%d", waiter.id))
+	return ctxOut, waiter, nil
+}
+
+// ServerHandshaker implements the server-side sidechannel handshake.
+type ServerHandshaker struct {
+	registry *Registry
+}
+
+// Magic returns the magic bytes for sidechannel.
+func (s *ServerHandshaker) Magic() string {
+	return string(magicBytes)
+}
+
+// Handshake implements the handshaking logic for sidechannel: the handshaker
+// reads the sidechannel ID from the wire and then delegates the connection
+// to the sidechannel registry.
+func (s *ServerHandshaker) Handshake(conn net.Conn, authInfo credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error) {
+	var sidechannelID sidechannelID
+	if err := binary.Read(conn, binary.BigEndian, &sidechannelID); err != nil {
+		return nil, nil, fmt.Errorf("sidechannel: fail to extract sidechannel ID: %w", err)
+	}
+
+	if err := s.registry.receive(sidechannelID, conn); err != nil {
+		return nil, nil, err
+	}
+
+	// Return credentials.ErrConnDispatched to indicate that the connection
+	// is already dispatched out of gRPC. gRPC should leave it alone and exit
+	// in peace.
+	return nil, nil, credentials.ErrConnDispatched
+}
+
+// NewServerHandshaker creates a new handshaker for sidechannel to embed into
+// listenmux.
+func NewServerHandshaker(registry *Registry) *ServerHandshaker {
+	return &ServerHandshaker{registry: registry}
+}
diff --git a/internal/sidechannel/sidechannel_test.go b/internal/sidechannel/sidechannel_test.go
new file mode 100644
index 000000000..20fcd1d80
--- /dev/null
+++ b/internal/sidechannel/sidechannel_test.go
@@ -0,0 +1,220 @@
+package sidechannel
+
+import (
+	"bytes"
+	"context"
+	"io"
+	"io/ioutil"
+	"math/rand"
+	"net"
+	"sync"
+	"testing"
+
+	"github.com/sirupsen/logrus"
+	"github.com/stretchr/testify/require"
+	"gitlab.com/gitlab-org/gitaly/v14/internal/backchannel"
+	"gitlab.com/gitlab-org/gitaly/v14/internal/listenmux"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/credentials/insecure"
+	healthpb "google.golang.org/grpc/health/grpc_health_v1"
+)
+
+func TestSidechannel(t *testing.T) {
+	const blobSize = 1024 * 1024
+
+	in := make([]byte, blobSize)
+	_, err := rand.Read(in)
+	require.NoError(t, err)
+
+	var out []byte
+	require.NotEqual(t, in, out)
+
+	addr := startServer(
+		t,
+		func(context context.Context, request *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) {
+			conn, err := OpenSidechannel(context)
+			if err != nil {
+				return nil, err
+			}
+			defer conn.Close()
+
+			if _, err = io.CopyN(conn, conn, blobSize); err != nil {
+				return nil, err
+			}
+
+			return &healthpb.HealthCheckResponse{}, conn.Close()
+		},
+	)
+
+	conn, registry := dial(t, addr)
+	err = call(
+		context.Background(), conn, registry,
+		func(conn net.Conn) error {
+			errC := make(chan error, 1)
+			go func() {
+				var err error
+				out, err = ioutil.ReadAll(conn)
+				errC <- err
+			}()
+
+			_, err = io.Copy(conn, bytes.NewReader(in))
+			require.NoError(t, err)
+			require.NoError(t, <-errC)
+
+			return nil
+		},
+	)
+	require.NoError(t, err)
+	require.Equal(t, in, out, "byte stream works")
+}
+
+// Conduct multiple requests, each with a sidechannel, on the same gRPC
+// connection.
+func TestSidechannelConcurrency(t *testing.T) {
+	const concurrency = 10
+	const blobSize = 1024 * 1024
+
+	ins := make([][]byte, concurrency)
+	for i := 0; i < concurrency; i++ {
+		ins[i] = make([]byte, blobSize)
+		_, err := rand.Read(ins[i])
+		require.NoError(t, err)
+	}
+
+	outs := make([][]byte, concurrency)
+	for i := 0; i < concurrency; i++ {
+		require.NotEqual(t, ins[i], outs[i])
+	}
+
+	addr := startServer(
+		t,
+		func(context context.Context, request *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) {
+			conn, err := OpenSidechannel(context)
+			if err != nil {
+				return nil, err
+			}
+			defer conn.Close()
+
+			if _, err = io.CopyN(conn, conn, blobSize); err != nil {
+				return nil, err
+			}
+
+			return &healthpb.HealthCheckResponse{}, conn.Close()
+		},
+	)
+
+	conn, registry := dial(t, addr)
+
+	errors := make(chan error, concurrency)
+
+	wg := sync.WaitGroup{}
+	for i := 0; i < concurrency; i++ {
+		wg.Add(1)
+		go func(i int) {
+			defer wg.Done()
+
+			err := call(
+				context.Background(), conn, registry,
+				func(conn net.Conn) error {
+					errC := make(chan error, 1)
+					go func() {
+						var err error
+						outs[i], err = ioutil.ReadAll(conn)
+						errC <- err
+					}()
+
+					if _, err := io.Copy(conn, bytes.NewReader(ins[i])); err != nil {
+						return err
+					}
+					if err := <-errC; err != nil {
+						return err
+					}
+
+					return nil
+				},
+			)
+			errors <- err
+		}(i)
+	}
+	wg.Wait()
+
+	for i := 0; i < concurrency; i++ {
+		require.Equal(t, ins[i], outs[i], "byte stream works")
+		require.NoError(t, <-errors)
+	}
+}
+
+func startServer(t *testing.T, th testHandler, opts ...grpc.ServerOption) string {
+	t.Helper()
+
+	logger := logrus.NewEntry(logrus.New())
+
+	lm := listenmux.New(insecure.NewCredentials())
+	lm.Register(backchannel.NewServerHandshaker(logger, backchannel.NewRegistry(), nil))
+
+	opts = append(opts, grpc.Creds(lm))
+
+	s := grpc.NewServer(opts...)
+	t.Cleanup(func() { s.Stop() })
+
+	handler := &server{testHandler: th}
+	healthpb.RegisterHealthServer(s, handler)
+
+	lis, err := net.Listen("tcp", "localhost:0")
+	require.NoError(t, err)
+	t.Cleanup(func() { lis.Close() })
+
+	go func() { s.Serve(lis) }()
+
+	return lis.Addr().String()
+}
+
+func dial(t *testing.T, addr string) (*grpc.ClientConn, *Registry) {
+	registry := NewRegistry()
+	logger := logrus.NewEntry(logrus.New())
+
+	factory := func() backchannel.Server {
+		lm := listenmux.New(insecure.NewCredentials())
+		lm.Register(NewServerHandshaker(registry))
+		return grpc.NewServer(grpc.Creds(lm))
+	}
+
+	clientHandshaker := backchannel.NewClientHandshaker(logger, factory)
+	dialOpt := grpc.WithTransportCredentials(clientHandshaker.ClientHandshake(insecure.NewCredentials()))
+
+	conn, err := grpc.Dial(addr, dialOpt)
+	require.NoError(t, err)
+	t.Cleanup(func() { conn.Close() })
+
+	return conn, registry
+}
+
+func call(ctx context.Context, conn *grpc.ClientConn, registry *Registry, handler func(net.Conn) error) error {
+	client := healthpb.NewHealthClient(conn)
+
+	ctxOut, waiter, err := RegisterSidechannel(ctx, registry, handler)
+	if err != nil {
+		return err
+	}
+	defer waiter.Close()
+
+	if _, err := client.Check(ctxOut, &healthpb.HealthCheckRequest{}); err != nil {
+		return err
+	}
+
+	if err := waiter.Wait(); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+type testHandler func(context.Context, *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error)
+
+type server struct {
+	healthpb.UnimplementedHealthServer
+	testHandler
+}
+
+func (s *server) Check(context context.Context, request *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) {
+	return s.testHandler(context, request)
+}
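
For reference, the prelude that OpenSidechannel (called from the server's RPC handler) writes on the fresh yamux stream back to the client is the magic string "sidechannel" (which listenmux uses to route the stream to the registered ServerHandshaker; that is why Magic() returns it) followed by the sidechannel ID as a big-endian int64. A minimal, hypothetical encoder showing that framing (the helper name is ours, not part of the MR):

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

// buildPrelude reproduces, for illustration only, the bytes written on a new
// sidechannel stream: the magic bytes first, then the registry ID that the
// receiving side's Handshake reads back with binary.Read.
func buildPrelude(id int64) []byte {
	buf := bytes.NewBufferString("sidechannel")
	_ = binary.Write(buf, binary.BigEndian, id) // writes to a bytes.Buffer never fail
	return buf.Bytes()
}

func main() {
	fmt.Printf("% x\n", buildPrelude(42)) // 11 magic bytes + 8-byte ID
}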