Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSami Hiltunen <shiltunen@gitlab.com>2021-07-30 17:35:48 +0300
committerQuang-Minh Nguyen <qmnguyen@gitlab.com>2021-08-18 07:40:18 +0300
commitd6b0cc37a83023e4dff276fee49c395f5ff5ee95 (patch)
tree36bf46e70919e27f64b69f865456dc4ac1ca7720
parenta8520a1568f0c0515eef6931c01b3fa8e55e7985 (diff)
dirty multiplexing sidechannel implementationsmh-muxed-stream-proto
-rw-r--r--cmd/streamrpc/main.go270
-rw-r--r--internal/backchannel/backchannel_example_test.go2
-rw-r--r--internal/backchannel/backchannel_test.go4
-rw-r--r--internal/backchannel/grpc.go74
-rw-r--r--internal/backchannel/server.go60
-rw-r--r--internal/gitaly/client/dial_test.go2
-rw-r--r--internal/gitaly/server/server.go2
-rw-r--r--internal/praefect/server_test.go2
8 files changed, 355 insertions, 61 deletions
diff --git a/cmd/streamrpc/main.go b/cmd/streamrpc/main.go
new file mode 100644
index 000000000..b8878cfd0
--- /dev/null
+++ b/cmd/streamrpc/main.go
@@ -0,0 +1,270 @@
+package main
+
+import (
+ "context"
+ "encoding/binary"
+ "errors"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "strconv"
+ "time"
+
+ "github.com/hashicorp/yamux"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/listenmux"
+ "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/credentials/insecure"
+ "google.golang.org/grpc/metadata"
+ "google.golang.org/grpc/peer"
+)
+
+func main() {
+ server := flag.Bool("server", false, "take the server role")
+ flag.Parse()
+
+ if *server {
+ gitaly()
+ } else {
+ client()
+ }
+}
+
+// Registry is the glue between RPC invocations waiting to receive sidechannel
+// data and the listener accepting new sidechannel connections.
+type registry struct {
+ nextID int64
+ waiters map[int64]func(io.Reader)
+}
+
+// Await registers a new waiter that should be invoked when the sidechannel
+// data is available. It returns a sidechannel id the client sents to the server
+// which the server back when opening the sidechannel. The sidechannel id allows
+// for correlating the accepted sidechannel streams to the callbacks waiting for them.
+func (r *registry) Await(callback func(r io.Reader)) (int64, func()) {
+ id := r.nextID
+ r.nextID++
+ r.waiters[id] = callback
+ return id, func() { delete(r.waiters, id) }
+}
+
+// Listen waits accepts sidechannel connections from the server. The server sends
+// the sidechannel's id as the first thing on the stream which allows the client to
+// match the sidechannel data that the server is about to send with one of the
+// registered callbacks.
+func (r *registry) Listen(ln *yamux.Session) error {
+ for {
+ sidechannel, err := ln.Accept()
+ if err != nil {
+ return fmt.Errorf("accept: %w", err)
+ }
+
+ var id int64
+ if err := binary.Read(sidechannel, binary.BigEndian, &id); err != nil {
+ return fmt.Errorf("read sidechannel id: %w", err)
+ }
+
+ callback, ok := r.waiters[id]
+ if !ok {
+ return fmt.Errorf("invalid sidechannel id: %v", id)
+ }
+
+ log.Printf("received id: %d", id)
+ callback(sidechannel)
+ }
+}
+
+type clientHandshake struct {
+ credentials.TransportCredentials
+ *registry
+}
+
+// ClientHandshake opens one stream for the gRPC session and starts listening for
+// incoming sidechannel connections.
+func (ch clientHandshake) ClientHandshake(ctx context.Context, serverName string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
+ conn, authInfo, err := ch.TransportCredentials.ClientHandshake(ctx, serverName, conn)
+ if err != nil {
+ log.Fatal("original client handshake: %w", err)
+ }
+
+ _, err = conn.Write([]byte("mux00000000"))
+ if err != nil {
+ log.Fatal("write grpc stream magic: %w", err)
+ }
+
+ session, err := yamux.Client(conn, nil)
+ if err != nil {
+ log.Fatal("yamux client: %w", err)
+ }
+
+ // First open the client's gRPC stream to the server.
+ grpcStream, err := session.Open()
+ if err != nil {
+ log.Fatal("open grpc stream: %w", err)
+ }
+
+ // Here we begin listening for the incoming sidechannel connections.
+ go func() {
+ if err := ch.registry.Listen(session); err != nil {
+ log.Fatalf("listen for sidechannels: %v", err)
+ }
+ }()
+
+ return grpcStream, authInfo, nil
+}
+
+// client is the client side of the sidechannel protocol.
+func client() {
+ registry := &registry{waiters: make(map[int64]func(io.Reader))}
+ clientConn, err := grpc.Dial("localhost:8080", grpc.WithTransportCredentials(
+ clientHandshake{
+ TransportCredentials: insecure.NewCredentials(),
+ registry: registry,
+ },
+ ))
+ if err != nil {
+ log.Fatal("dial server: %w", err)
+ }
+
+ for {
+ // Each invocation of this func is a separate rpc call
+ func() {
+ // Before calling the Gitaly, we set up a callback that
+ // will handle the sidechannel data that the server sends.
+ // This also returns the stream is the server should send back to
+ // us so we can correlate this await call to the sidechannel
+ // that is being opened.
+ streamID, clean := registry.Await(func(r io.Reader) {
+ sidechannelData, err := io.ReadAll(r)
+ if err != nil {
+ log.Println("read sidechannel data: %w", err)
+ }
+
+ log.Printf("received sidechannel data: %q", sidechannelData)
+ })
+ defer clean()
+
+ // here we just send hte server the sidechannel's id it should send back to us.
+ ctx := metadata.NewOutgoingContext(context.Background(), metadata.New(map[string]string{
+ "sidechannel-id": fmt.Sprintf("%d", streamID),
+ }))
+
+ // perform the RPC call
+ if err := clientConn.Invoke(ctx, "/Gitaly/Mutator", &gitalypb.CreateBranchRequest{}, &gitalypb.CreateBranchResponse{}); err != nil {
+ log.Fatalf("call server: %v", err)
+ }
+ fmt.Println()
+ time.Sleep(time.Second)
+ }()
+ }
+
+}
+
+// gitaly is the server side of the sidechannel implementation.
+func gitaly() {
+ ln, err := net.Listen("tcp", "localhost:8080")
+ if err != nil {
+ log.Fatal("listen: %w", err)
+ }
+
+ // here we use listenmux to have a multiplexed transport easily open
+ lnMux := listenmux.New(insecure.NewCredentials())
+ lnMux.Register(muxedServer{})
+
+ if err := grpc.NewServer(grpc.Creds(lnMux), grpc.UnknownServiceHandler(func(srv interface{}, stream grpc.ServerStream) error {
+ fmt.Println("Gitaly received RPC")
+
+ // We use the AuthInfo to pass the RPC handler a streamOpener it can
+ // use to open a sidechannel back to the client.
+ peerInfo, ok := peer.FromContext(stream.Context())
+ if !ok {
+ return errors.New("no peer info in context")
+ }
+
+ streamOpener, ok := peerInfo.AuthInfo.(interface {
+ OpenSidechannel(context.Context) (io.WriteCloser, error)
+ })
+ if !ok {
+ return fmt.Errorf("not a stream opener: %T", peerInfo.AuthInfo)
+ }
+
+ // Here we open the sidechannel to the client
+ fmt.Println("Gitaly opening sidechannel")
+ sidechannel, err := streamOpener.OpenSidechannel(stream.Context())
+ if err != nil {
+ return fmt.Errorf("open stream: %w", err)
+ }
+
+ // With the channel open, we can write data to it
+ fmt.Println("Gitaly writing to sidechannel")
+ _, err = sidechannel.Write([]byte("data sent over sidechannel"))
+ // we close the channel so the client knows we are done writing.
+ sidechannel.Close()
+ if err != nil {
+ return fmt.Errorf("write to sidechannel: %w", err)
+ }
+
+ // Send a gRPC response back to the client.
+ fmt.Println("Gitaly responding to RPC")
+ return stream.SendMsg(&gitalypb.CreateBranchResponse{})
+ })).Serve(ln); err != nil {
+ log.Fatal("serve: %w")
+ }
+}
+
+// streamOpener is injected via the AuthInfo in the context to the RPC handlers
+// on the server.
+type streamOpener struct {
+ credentials.AuthInfo
+ session *yamux.Session
+}
+
+// OpenSidechannel opens a new sidechannel to the client.
+func (opener *streamOpener) OpenSidechannel(ctx context.Context) (io.WriteCloser, error) {
+ // First step is to extract the sidechannel id the client sent us so
+ // we can send it back. The ID is how the client correlates this particular sidechannel
+ // to a specific RPC invocation on its side.
+ md, _ := metadata.FromIncomingContext(ctx)
+ id, _ := strconv.ParseInt(md["sidechannel-id"][0], 10, 64)
+ log.Printf("opening sidechannel id %d", id)
+
+ // open a new stream in the multiplexing session
+ stream, err := opener.session.Open()
+ if err != nil {
+ return nil, fmt.Errorf("open stream: %w", err)
+ }
+
+ // write the sidechannel id so the client can correlate this stream to an RPC.
+ if err := binary.Write(stream, binary.BigEndian, id); err != nil {
+ stream.Close()
+ return nil, fmt.Errorf("write stream id: %w", err)
+ }
+
+ // sidechannel is ready for use.
+ return stream, nil
+}
+
+type muxedServer struct{}
+
+// The handshake just opens a multiplexing session on top of the network connection
+// that we received from the client.
+func (muxedServer) Handshake(conn net.Conn, authInfo credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error) {
+ session, err := yamux.Server(conn, nil)
+ if err != nil {
+ log.Fatal("yamux server: %w", err)
+ }
+
+ clientToServerGRPC, err := session.Accept()
+ if err != nil {
+ log.Fatal("accept grpc stream: %w", err)
+ }
+
+ return clientToServerGRPC, &streamOpener{
+ AuthInfo: authInfo, session: session,
+ }, nil
+}
+
+func (muxedServer) Magic() string { return "mux00000000" }
diff --git a/internal/backchannel/backchannel_example_test.go b/internal/backchannel/backchannel_example_test.go
index 63f4e5c5d..a75a5cbf8 100644
--- a/internal/backchannel/backchannel_example_test.go
+++ b/internal/backchannel/backchannel_example_test.go
@@ -33,7 +33,7 @@ func Example() {
// the ServerHandshaker passes down the peer ID via the context. The peer ID identifies a
// backchannel connection.
lm := listenmux.New(insecure.NewCredentials())
- lm.Register(backchannel.NewServerHandshaker(logger, registry, nil))
+ lm.Register(backchannel.NewGRPCHandshaker(logger, registry, nil))
// Create the server
srv := grpc.NewServer(
diff --git a/internal/backchannel/backchannel_test.go b/internal/backchannel/backchannel_test.go
index 1191737b5..330ce1380 100644
--- a/internal/backchannel/backchannel_test.go
+++ b/internal/backchannel/backchannel_test.go
@@ -41,7 +41,7 @@ func TestBackchannel_concurrentRequestsFromMultipleClients(t *testing.T) {
var interceptorInvoked int32
registry := NewRegistry()
lm := listenmux.New(insecure.NewCredentials())
- lm.Register(NewServerHandshaker(
+ lm.Register(NewGRPCHandshaker(
newLogger(),
registry,
[]grpc.DialOption{
@@ -184,7 +184,7 @@ func Benchmark(b *testing.B) {
var serverOpts []grpc.ServerOption
if tc.multiplexed {
lm := listenmux.New(insecure.NewCredentials())
- lm.Register(NewServerHandshaker(newLogger(), NewRegistry(), nil))
+ lm.Register(NewGRPCHandshaker(newLogger(), NewRegistry(), nil))
serverOpts = []grpc.ServerOption{
grpc.Creds(lm),
}
diff --git a/internal/backchannel/grpc.go b/internal/backchannel/grpc.go
new file mode 100644
index 000000000..915897aa4
--- /dev/null
+++ b/internal/backchannel/grpc.go
@@ -0,0 +1,74 @@
+package backchannel
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+
+ "github.com/hashicorp/yamux"
+ "github.com/sirupsen/logrus"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/peer"
+)
+
+// authInfoWrapper is used to pass the peer id through the context to the RPC handlers.
+type authInfoWrapper struct {
+ id ID
+ credentials.AuthInfo
+}
+
+func (w authInfoWrapper) peerID() ID { return w.id }
+
+// GetPeerID gets the ID of the current peer connection.
+func GetPeerID(ctx context.Context) (ID, error) {
+ peerInfo, ok := peer.FromContext(ctx)
+ if !ok {
+ return 0, errors.New("no peer info in context")
+ }
+
+ wrapper, ok := peerInfo.AuthInfo.(interface{ peerID() ID })
+ if !ok {
+ return 0, ErrNonMultiplexedConnection
+ }
+
+ return wrapper.peerID(), nil
+}
+
+// WithID stores the ID in the provided AuthInfo so it can be later accessed by the RPC handler.
+// This is exported to facilitate testing.
+func WithID(authInfo credentials.AuthInfo, id ID) credentials.AuthInfo {
+ return authInfoWrapper{id: id, AuthInfo: authInfo}
+}
+
+// NewGRPCHandshaker returns a new server side implementation of the backchannel. The provided TransportCredentials
+// are handshaked prior to initializing the multiplexing session. The Registry is used to store the backchannel connections.
+// DialOptions can be used to set custom dial options for the backchannel connections. They must not contain a dialer or
+// transport credentials as those set by the handshaker.
+func NewGRPCHandshaker(logger *logrus.Entry, reg *Registry, dialOpts []grpc.DialOption) *ServerHandshaker {
+ return &ServerHandshaker{registry: reg, logger: logger, dialOpts: dialOpts, setup: gRPCSetup}
+}
+
+func gRPCSetup(s *ServerHandshaker, conn net.Conn, authInfo credentials.AuthInfo, muxSession *yamux.Session) (credentials.AuthInfo, func()(), error) {
+ // The address does not actually matter but we set it so clientConn.Target returns a meaningful value.
+ // WithInsecure is used as the multiplexer operates within a TLS session already if one is configured.
+ backchannelConn, err := grpc.Dial(
+ "multiplexed/"+conn.RemoteAddr().String(),
+ append(
+ s.dialOpts,
+ grpc.WithInsecure(),
+ grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { return muxSession.Open() }),
+ )...,
+ )
+ if err != nil {
+ return nil, nil, fmt.Errorf("dial backchannel: %w", err)
+ }
+
+ id := s.registry.RegisterBackchannel(backchannelConn)
+ return WithID(authInfo, id), func() {
+ s.registry.RemoveBackchannel(id)
+ backchannelConn.Close()
+ muxSession.Close()
+ }, nil
+}
diff --git a/internal/backchannel/server.go b/internal/backchannel/server.go
index b2a1f58da..981cfb59e 100644
--- a/internal/backchannel/server.go
+++ b/internal/backchannel/server.go
@@ -1,7 +1,6 @@
package backchannel
import (
- "context"
"errors"
"fmt"
"net"
@@ -10,61 +9,24 @@ import (
"github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
- "google.golang.org/grpc/peer"
)
// ErrNonMultiplexedConnection is returned when attempting to get the peer id of a non-multiplexed
// connection.
var ErrNonMultiplexedConnection = errors.New("non-multiplexed connection")
-// authInfoWrapper is used to pass the peer id through the context to the RPC handlers.
-type authInfoWrapper struct {
- id ID
- credentials.AuthInfo
-}
-
-func (w authInfoWrapper) peerID() ID { return w.id }
-
-// GetPeerID gets the ID of the current peer connection.
-func GetPeerID(ctx context.Context) (ID, error) {
- peerInfo, ok := peer.FromContext(ctx)
- if !ok {
- return 0, errors.New("no peer info in context")
- }
-
- wrapper, ok := peerInfo.AuthInfo.(interface{ peerID() ID })
- if !ok {
- return 0, ErrNonMultiplexedConnection
- }
-
- return wrapper.peerID(), nil
-}
-
-// WithID stores the ID in the provided AuthInfo so it can be later accessed by the RPC handler.
-// This is exported to facilitate testing.
-func WithID(authInfo credentials.AuthInfo, id ID) credentials.AuthInfo {
- return authInfoWrapper{id: id, AuthInfo: authInfo}
-}
-
// ServerHandshaker implements the server side handshake of the multiplexed connection.
type ServerHandshaker struct {
registry *Registry
logger *logrus.Entry
dialOpts []grpc.DialOption
+ setup func (*ServerHandshaker, net.Conn, credentials.AuthInfo, *yamux.Session) (credentials.AuthInfo, func()(), error)
}
// Magic is used by listenmux to retrieve the magic string for
// backchannel connections.
func (s *ServerHandshaker) Magic() string { return string(magicBytes) }
-// NewServerHandshaker returns a new server side implementation of the backchannel. The provided TransportCredentials
-// are handshaked prior to initializing the multiplexing session. The Registry is used to store the backchannel connections.
-// DialOptions can be used to set custom dial options for the backchannel connections. They must not contain a dialer or
-// transport credentials as those set by the handshaker.
-func NewServerHandshaker(logger *logrus.Entry, reg *Registry, dialOpts []grpc.DialOption) *ServerHandshaker {
- return &ServerHandshaker{registry: reg, logger: logger, dialOpts: dialOpts}
-}
-
// Handshake establishes a gRPC ClientConn back to the backchannel client
// on the other side and stores its ID in the AuthInfo where it can be
// later accessed by the RPC handlers. gRPC sets an IO timeout on the
@@ -91,34 +53,22 @@ func (s *ServerHandshaker) Handshake(conn net.Conn, authInfo credentials.AuthInf
return nil, nil, fmt.Errorf("accept client's stream: %w", err)
}
- // The address does not actually matter but we set it so clientConn.Target returns a meaningful value.
- // WithInsecure is used as the multiplexer operates within a TLS session already if one is configured.
- backchannelConn, err := grpc.Dial(
- "multiplexed/"+conn.RemoteAddr().String(),
- append(
- s.dialOpts,
- grpc.WithInsecure(),
- grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { return muxSession.Open() }),
- )...,
- )
+ authInfo, closeFunc, err := s.setup(s, conn, authInfo, muxSession)
if err != nil {
logger.Close()
- return nil, nil, fmt.Errorf("dial backchannel: %w", err)
+ return nil, nil, err
}
- id := s.registry.RegisterBackchannel(backchannelConn)
// The returned connection must close the underlying network connection, we redirect the close
// to the muxSession which also closes the underlying connection.
return connCloser{
Conn: clientToServerStream,
close: func() error {
- s.registry.RemoveBackchannel(id)
- backchannelConn.Close()
- muxSession.Close()
+ closeFunc()
logger.Close()
return nil
},
},
- WithID(authInfo, id),
+ authInfo,
nil
}
diff --git a/internal/gitaly/client/dial_test.go b/internal/gitaly/client/dial_test.go
index cdd7fffa1..1146474f4 100644
--- a/internal/gitaly/client/dial_test.go
+++ b/internal/gitaly/client/dial_test.go
@@ -23,7 +23,7 @@ func TestDial(t *testing.T) {
logger := testhelper.DiscardTestEntry(t)
lm := listenmux.New(insecure.NewCredentials())
- lm.Register(backchannel.NewServerHandshaker(logger, backchannel.NewRegistry(), nil))
+ lm.Register(backchannel.NewGRPCHandshaker(logger, backchannel.NewRegistry(), nil))
srv := grpc.NewServer(
grpc.Creds(lm),
diff --git a/internal/gitaly/server/server.go b/internal/gitaly/server/server.go
index ffdf55c10..7d5150f10 100644
--- a/internal/gitaly/server/server.go
+++ b/internal/gitaly/server/server.go
@@ -97,7 +97,7 @@ func New(
}
lm := listenmux.New(transportCredentials)
- lm.Register(backchannel.NewServerHandshaker(
+ lm.Register(backchannel.NewGRPCHandshaker(
logrusEntry,
registry,
[]grpc.DialOption{client.UnaryInterceptor()},
diff --git a/internal/praefect/server_test.go b/internal/praefect/server_test.go
index d7ba38e5f..d6327303a 100644
--- a/internal/praefect/server_test.go
+++ b/internal/praefect/server_test.go
@@ -60,7 +60,7 @@ func TestNewBackchannelServerFactory(t *testing.T) {
registry := backchannel.NewRegistry()
lm := listenmux.New(insecure.NewCredentials())
- lm.Register(backchannel.NewServerHandshaker(logger, registry, nil))
+ lm.Register(backchannel.NewGRPCHandshaker(logger, registry, nil))
server := grpc.NewServer(
grpc.Creds(lm),