diff options
author | Sami Hiltunen <shiltunen@gitlab.com> | 2021-07-30 17:35:48 +0300 |
---|---|---|
committer | Quang-Minh Nguyen <qmnguyen@gitlab.com> | 2021-08-18 07:40:18 +0300 |
commit | d6b0cc37a83023e4dff276fee49c395f5ff5ee95 (patch) | |
tree | 36bf46e70919e27f64b69f865456dc4ac1ca7720 | |
parent | a8520a1568f0c0515eef6931c01b3fa8e55e7985 (diff) |
dirty multiplexing sidechannel implementationsmh-muxed-stream-proto
-rw-r--r-- | cmd/streamrpc/main.go | 270 | ||||
-rw-r--r-- | internal/backchannel/backchannel_example_test.go | 2 | ||||
-rw-r--r-- | internal/backchannel/backchannel_test.go | 4 | ||||
-rw-r--r-- | internal/backchannel/grpc.go | 74 | ||||
-rw-r--r-- | internal/backchannel/server.go | 60 | ||||
-rw-r--r-- | internal/gitaly/client/dial_test.go | 2 | ||||
-rw-r--r-- | internal/gitaly/server/server.go | 2 | ||||
-rw-r--r-- | internal/praefect/server_test.go | 2 |
8 files changed, 355 insertions, 61 deletions
diff --git a/cmd/streamrpc/main.go b/cmd/streamrpc/main.go new file mode 100644 index 000000000..b8878cfd0 --- /dev/null +++ b/cmd/streamrpc/main.go @@ -0,0 +1,270 @@ +package main + +import ( + "context" + "encoding/binary" + "errors" + "flag" + "fmt" + "io" + "log" + "net" + "strconv" + "time" + + "github.com/hashicorp/yamux" + "gitlab.com/gitlab-org/gitaly/v14/internal/listenmux" + "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" +) + +func main() { + server := flag.Bool("server", false, "take the server role") + flag.Parse() + + if *server { + gitaly() + } else { + client() + } +} + +// Registry is the glue between RPC invocations waiting to receive sidechannel +// data and the listener accepting new sidechannel connections. +type registry struct { + nextID int64 + waiters map[int64]func(io.Reader) +} + +// Await registers a new waiter that should be invoked when the sidechannel +// data is available. It returns a sidechannel id the client sents to the server +// which the server back when opening the sidechannel. The sidechannel id allows +// for correlating the accepted sidechannel streams to the callbacks waiting for them. +func (r *registry) Await(callback func(r io.Reader)) (int64, func()) { + id := r.nextID + r.nextID++ + r.waiters[id] = callback + return id, func() { delete(r.waiters, id) } +} + +// Listen waits accepts sidechannel connections from the server. The server sends +// the sidechannel's id as the first thing on the stream which allows the client to +// match the sidechannel data that the server is about to send with one of the +// registered callbacks. +func (r *registry) Listen(ln *yamux.Session) error { + for { + sidechannel, err := ln.Accept() + if err != nil { + return fmt.Errorf("accept: %w", err) + } + + var id int64 + if err := binary.Read(sidechannel, binary.BigEndian, &id); err != nil { + return fmt.Errorf("read sidechannel id: %w", err) + } + + callback, ok := r.waiters[id] + if !ok { + return fmt.Errorf("invalid sidechannel id: %v", id) + } + + log.Printf("received id: %d", id) + callback(sidechannel) + } +} + +type clientHandshake struct { + credentials.TransportCredentials + *registry +} + +// ClientHandshake opens one stream for the gRPC session and starts listening for +// incoming sidechannel connections. +func (ch clientHandshake) ClientHandshake(ctx context.Context, serverName string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + conn, authInfo, err := ch.TransportCredentials.ClientHandshake(ctx, serverName, conn) + if err != nil { + log.Fatal("original client handshake: %w", err) + } + + _, err = conn.Write([]byte("mux00000000")) + if err != nil { + log.Fatal("write grpc stream magic: %w", err) + } + + session, err := yamux.Client(conn, nil) + if err != nil { + log.Fatal("yamux client: %w", err) + } + + // First open the client's gRPC stream to the server. + grpcStream, err := session.Open() + if err != nil { + log.Fatal("open grpc stream: %w", err) + } + + // Here we begin listening for the incoming sidechannel connections. + go func() { + if err := ch.registry.Listen(session); err != nil { + log.Fatalf("listen for sidechannels: %v", err) + } + }() + + return grpcStream, authInfo, nil +} + +// client is the client side of the sidechannel protocol. +func client() { + registry := ®istry{waiters: make(map[int64]func(io.Reader))} + clientConn, err := grpc.Dial("localhost:8080", grpc.WithTransportCredentials( + clientHandshake{ + TransportCredentials: insecure.NewCredentials(), + registry: registry, + }, + )) + if err != nil { + log.Fatal("dial server: %w", err) + } + + for { + // Each invocation of this func is a separate rpc call + func() { + // Before calling the Gitaly, we set up a callback that + // will handle the sidechannel data that the server sends. + // This also returns the stream is the server should send back to + // us so we can correlate this await call to the sidechannel + // that is being opened. + streamID, clean := registry.Await(func(r io.Reader) { + sidechannelData, err := io.ReadAll(r) + if err != nil { + log.Println("read sidechannel data: %w", err) + } + + log.Printf("received sidechannel data: %q", sidechannelData) + }) + defer clean() + + // here we just send hte server the sidechannel's id it should send back to us. + ctx := metadata.NewOutgoingContext(context.Background(), metadata.New(map[string]string{ + "sidechannel-id": fmt.Sprintf("%d", streamID), + })) + + // perform the RPC call + if err := clientConn.Invoke(ctx, "/Gitaly/Mutator", &gitalypb.CreateBranchRequest{}, &gitalypb.CreateBranchResponse{}); err != nil { + log.Fatalf("call server: %v", err) + } + fmt.Println() + time.Sleep(time.Second) + }() + } + +} + +// gitaly is the server side of the sidechannel implementation. +func gitaly() { + ln, err := net.Listen("tcp", "localhost:8080") + if err != nil { + log.Fatal("listen: %w", err) + } + + // here we use listenmux to have a multiplexed transport easily open + lnMux := listenmux.New(insecure.NewCredentials()) + lnMux.Register(muxedServer{}) + + if err := grpc.NewServer(grpc.Creds(lnMux), grpc.UnknownServiceHandler(func(srv interface{}, stream grpc.ServerStream) error { + fmt.Println("Gitaly received RPC") + + // We use the AuthInfo to pass the RPC handler a streamOpener it can + // use to open a sidechannel back to the client. + peerInfo, ok := peer.FromContext(stream.Context()) + if !ok { + return errors.New("no peer info in context") + } + + streamOpener, ok := peerInfo.AuthInfo.(interface { + OpenSidechannel(context.Context) (io.WriteCloser, error) + }) + if !ok { + return fmt.Errorf("not a stream opener: %T", peerInfo.AuthInfo) + } + + // Here we open the sidechannel to the client + fmt.Println("Gitaly opening sidechannel") + sidechannel, err := streamOpener.OpenSidechannel(stream.Context()) + if err != nil { + return fmt.Errorf("open stream: %w", err) + } + + // With the channel open, we can write data to it + fmt.Println("Gitaly writing to sidechannel") + _, err = sidechannel.Write([]byte("data sent over sidechannel")) + // we close the channel so the client knows we are done writing. + sidechannel.Close() + if err != nil { + return fmt.Errorf("write to sidechannel: %w", err) + } + + // Send a gRPC response back to the client. + fmt.Println("Gitaly responding to RPC") + return stream.SendMsg(&gitalypb.CreateBranchResponse{}) + })).Serve(ln); err != nil { + log.Fatal("serve: %w") + } +} + +// streamOpener is injected via the AuthInfo in the context to the RPC handlers +// on the server. +type streamOpener struct { + credentials.AuthInfo + session *yamux.Session +} + +// OpenSidechannel opens a new sidechannel to the client. +func (opener *streamOpener) OpenSidechannel(ctx context.Context) (io.WriteCloser, error) { + // First step is to extract the sidechannel id the client sent us so + // we can send it back. The ID is how the client correlates this particular sidechannel + // to a specific RPC invocation on its side. + md, _ := metadata.FromIncomingContext(ctx) + id, _ := strconv.ParseInt(md["sidechannel-id"][0], 10, 64) + log.Printf("opening sidechannel id %d", id) + + // open a new stream in the multiplexing session + stream, err := opener.session.Open() + if err != nil { + return nil, fmt.Errorf("open stream: %w", err) + } + + // write the sidechannel id so the client can correlate this stream to an RPC. + if err := binary.Write(stream, binary.BigEndian, id); err != nil { + stream.Close() + return nil, fmt.Errorf("write stream id: %w", err) + } + + // sidechannel is ready for use. + return stream, nil +} + +type muxedServer struct{} + +// The handshake just opens a multiplexing session on top of the network connection +// that we received from the client. +func (muxedServer) Handshake(conn net.Conn, authInfo credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error) { + session, err := yamux.Server(conn, nil) + if err != nil { + log.Fatal("yamux server: %w", err) + } + + clientToServerGRPC, err := session.Accept() + if err != nil { + log.Fatal("accept grpc stream: %w", err) + } + + return clientToServerGRPC, &streamOpener{ + AuthInfo: authInfo, session: session, + }, nil +} + +func (muxedServer) Magic() string { return "mux00000000" } diff --git a/internal/backchannel/backchannel_example_test.go b/internal/backchannel/backchannel_example_test.go index 63f4e5c5d..a75a5cbf8 100644 --- a/internal/backchannel/backchannel_example_test.go +++ b/internal/backchannel/backchannel_example_test.go @@ -33,7 +33,7 @@ func Example() { // the ServerHandshaker passes down the peer ID via the context. The peer ID identifies a // backchannel connection. lm := listenmux.New(insecure.NewCredentials()) - lm.Register(backchannel.NewServerHandshaker(logger, registry, nil)) + lm.Register(backchannel.NewGRPCHandshaker(logger, registry, nil)) // Create the server srv := grpc.NewServer( diff --git a/internal/backchannel/backchannel_test.go b/internal/backchannel/backchannel_test.go index 1191737b5..330ce1380 100644 --- a/internal/backchannel/backchannel_test.go +++ b/internal/backchannel/backchannel_test.go @@ -41,7 +41,7 @@ func TestBackchannel_concurrentRequestsFromMultipleClients(t *testing.T) { var interceptorInvoked int32 registry := NewRegistry() lm := listenmux.New(insecure.NewCredentials()) - lm.Register(NewServerHandshaker( + lm.Register(NewGRPCHandshaker( newLogger(), registry, []grpc.DialOption{ @@ -184,7 +184,7 @@ func Benchmark(b *testing.B) { var serverOpts []grpc.ServerOption if tc.multiplexed { lm := listenmux.New(insecure.NewCredentials()) - lm.Register(NewServerHandshaker(newLogger(), NewRegistry(), nil)) + lm.Register(NewGRPCHandshaker(newLogger(), NewRegistry(), nil)) serverOpts = []grpc.ServerOption{ grpc.Creds(lm), } diff --git a/internal/backchannel/grpc.go b/internal/backchannel/grpc.go new file mode 100644 index 000000000..915897aa4 --- /dev/null +++ b/internal/backchannel/grpc.go @@ -0,0 +1,74 @@ +package backchannel + +import ( + "context" + "errors" + "fmt" + "net" + + "github.com/hashicorp/yamux" + "github.com/sirupsen/logrus" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/peer" +) + +// authInfoWrapper is used to pass the peer id through the context to the RPC handlers. +type authInfoWrapper struct { + id ID + credentials.AuthInfo +} + +func (w authInfoWrapper) peerID() ID { return w.id } + +// GetPeerID gets the ID of the current peer connection. +func GetPeerID(ctx context.Context) (ID, error) { + peerInfo, ok := peer.FromContext(ctx) + if !ok { + return 0, errors.New("no peer info in context") + } + + wrapper, ok := peerInfo.AuthInfo.(interface{ peerID() ID }) + if !ok { + return 0, ErrNonMultiplexedConnection + } + + return wrapper.peerID(), nil +} + +// WithID stores the ID in the provided AuthInfo so it can be later accessed by the RPC handler. +// This is exported to facilitate testing. +func WithID(authInfo credentials.AuthInfo, id ID) credentials.AuthInfo { + return authInfoWrapper{id: id, AuthInfo: authInfo} +} + +// NewGRPCHandshaker returns a new server side implementation of the backchannel. The provided TransportCredentials +// are handshaked prior to initializing the multiplexing session. The Registry is used to store the backchannel connections. +// DialOptions can be used to set custom dial options for the backchannel connections. They must not contain a dialer or +// transport credentials as those set by the handshaker. +func NewGRPCHandshaker(logger *logrus.Entry, reg *Registry, dialOpts []grpc.DialOption) *ServerHandshaker { + return &ServerHandshaker{registry: reg, logger: logger, dialOpts: dialOpts, setup: gRPCSetup} +} + +func gRPCSetup(s *ServerHandshaker, conn net.Conn, authInfo credentials.AuthInfo, muxSession *yamux.Session) (credentials.AuthInfo, func()(), error) { + // The address does not actually matter but we set it so clientConn.Target returns a meaningful value. + // WithInsecure is used as the multiplexer operates within a TLS session already if one is configured. + backchannelConn, err := grpc.Dial( + "multiplexed/"+conn.RemoteAddr().String(), + append( + s.dialOpts, + grpc.WithInsecure(), + grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { return muxSession.Open() }), + )..., + ) + if err != nil { + return nil, nil, fmt.Errorf("dial backchannel: %w", err) + } + + id := s.registry.RegisterBackchannel(backchannelConn) + return WithID(authInfo, id), func() { + s.registry.RemoveBackchannel(id) + backchannelConn.Close() + muxSession.Close() + }, nil +} diff --git a/internal/backchannel/server.go b/internal/backchannel/server.go index b2a1f58da..981cfb59e 100644 --- a/internal/backchannel/server.go +++ b/internal/backchannel/server.go @@ -1,7 +1,6 @@ package backchannel import ( - "context" "errors" "fmt" "net" @@ -10,61 +9,24 @@ import ( "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/grpc/credentials" - "google.golang.org/grpc/peer" ) // ErrNonMultiplexedConnection is returned when attempting to get the peer id of a non-multiplexed // connection. var ErrNonMultiplexedConnection = errors.New("non-multiplexed connection") -// authInfoWrapper is used to pass the peer id through the context to the RPC handlers. -type authInfoWrapper struct { - id ID - credentials.AuthInfo -} - -func (w authInfoWrapper) peerID() ID { return w.id } - -// GetPeerID gets the ID of the current peer connection. -func GetPeerID(ctx context.Context) (ID, error) { - peerInfo, ok := peer.FromContext(ctx) - if !ok { - return 0, errors.New("no peer info in context") - } - - wrapper, ok := peerInfo.AuthInfo.(interface{ peerID() ID }) - if !ok { - return 0, ErrNonMultiplexedConnection - } - - return wrapper.peerID(), nil -} - -// WithID stores the ID in the provided AuthInfo so it can be later accessed by the RPC handler. -// This is exported to facilitate testing. -func WithID(authInfo credentials.AuthInfo, id ID) credentials.AuthInfo { - return authInfoWrapper{id: id, AuthInfo: authInfo} -} - // ServerHandshaker implements the server side handshake of the multiplexed connection. type ServerHandshaker struct { registry *Registry logger *logrus.Entry dialOpts []grpc.DialOption + setup func (*ServerHandshaker, net.Conn, credentials.AuthInfo, *yamux.Session) (credentials.AuthInfo, func()(), error) } // Magic is used by listenmux to retrieve the magic string for // backchannel connections. func (s *ServerHandshaker) Magic() string { return string(magicBytes) } -// NewServerHandshaker returns a new server side implementation of the backchannel. The provided TransportCredentials -// are handshaked prior to initializing the multiplexing session. The Registry is used to store the backchannel connections. -// DialOptions can be used to set custom dial options for the backchannel connections. They must not contain a dialer or -// transport credentials as those set by the handshaker. -func NewServerHandshaker(logger *logrus.Entry, reg *Registry, dialOpts []grpc.DialOption) *ServerHandshaker { - return &ServerHandshaker{registry: reg, logger: logger, dialOpts: dialOpts} -} - // Handshake establishes a gRPC ClientConn back to the backchannel client // on the other side and stores its ID in the AuthInfo where it can be // later accessed by the RPC handlers. gRPC sets an IO timeout on the @@ -91,34 +53,22 @@ func (s *ServerHandshaker) Handshake(conn net.Conn, authInfo credentials.AuthInf return nil, nil, fmt.Errorf("accept client's stream: %w", err) } - // The address does not actually matter but we set it so clientConn.Target returns a meaningful value. - // WithInsecure is used as the multiplexer operates within a TLS session already if one is configured. - backchannelConn, err := grpc.Dial( - "multiplexed/"+conn.RemoteAddr().String(), - append( - s.dialOpts, - grpc.WithInsecure(), - grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { return muxSession.Open() }), - )..., - ) + authInfo, closeFunc, err := s.setup(s, conn, authInfo, muxSession) if err != nil { logger.Close() - return nil, nil, fmt.Errorf("dial backchannel: %w", err) + return nil, nil, err } - id := s.registry.RegisterBackchannel(backchannelConn) // The returned connection must close the underlying network connection, we redirect the close // to the muxSession which also closes the underlying connection. return connCloser{ Conn: clientToServerStream, close: func() error { - s.registry.RemoveBackchannel(id) - backchannelConn.Close() - muxSession.Close() + closeFunc() logger.Close() return nil }, }, - WithID(authInfo, id), + authInfo, nil } diff --git a/internal/gitaly/client/dial_test.go b/internal/gitaly/client/dial_test.go index cdd7fffa1..1146474f4 100644 --- a/internal/gitaly/client/dial_test.go +++ b/internal/gitaly/client/dial_test.go @@ -23,7 +23,7 @@ func TestDial(t *testing.T) { logger := testhelper.DiscardTestEntry(t) lm := listenmux.New(insecure.NewCredentials()) - lm.Register(backchannel.NewServerHandshaker(logger, backchannel.NewRegistry(), nil)) + lm.Register(backchannel.NewGRPCHandshaker(logger, backchannel.NewRegistry(), nil)) srv := grpc.NewServer( grpc.Creds(lm), diff --git a/internal/gitaly/server/server.go b/internal/gitaly/server/server.go index ffdf55c10..7d5150f10 100644 --- a/internal/gitaly/server/server.go +++ b/internal/gitaly/server/server.go @@ -97,7 +97,7 @@ func New( } lm := listenmux.New(transportCredentials) - lm.Register(backchannel.NewServerHandshaker( + lm.Register(backchannel.NewGRPCHandshaker( logrusEntry, registry, []grpc.DialOption{client.UnaryInterceptor()}, diff --git a/internal/praefect/server_test.go b/internal/praefect/server_test.go index d7ba38e5f..d6327303a 100644 --- a/internal/praefect/server_test.go +++ b/internal/praefect/server_test.go @@ -60,7 +60,7 @@ func TestNewBackchannelServerFactory(t *testing.T) { registry := backchannel.NewRegistry() lm := listenmux.New(insecure.NewCredentials()) - lm.Register(backchannel.NewServerHandshaker(logger, registry, nil)) + lm.Register(backchannel.NewGRPCHandshaker(logger, registry, nil)) server := grpc.NewServer( grpc.Creds(lm), |