diff options
author | Jacob Vosmaer <jacob@gitlab.com> | 2021-06-15 15:33:18 +0300 |
---|---|---|
committer | Jacob Vosmaer <jacob@gitlab.com> | 2021-06-30 18:19:30 +0300 |
commit | c95298c125a680a006153d5aca5d3dbb575ce352 (patch) | |
tree | 49a270994fad3cad104ce8e47fcb06884070a4cd | |
parent | 82a7a8e90f5bf3f0cae18d158a28eb8a7a1693c6 (diff) |
Separate listenmux from backchannel
Changelog: other
-rw-r--r-- | internal/backchannel/backchannel.go | 12 | ||||
-rw-r--r-- | internal/backchannel/backchannel_example_test.go | 6 | ||||
-rw-r--r-- | internal/backchannel/backchannel_test.go | 13 | ||||
-rw-r--r-- | internal/backchannel/server.go | 56 | ||||
-rw-r--r-- | internal/gitaly/client/dial_test.go | 6 | ||||
-rw-r--r-- | internal/gitaly/server/server.go | 10 | ||||
-rw-r--r-- | internal/listenmux/mux.go | 89 | ||||
-rw-r--r-- | internal/listenmux/mux_test.go | 291 | ||||
-rw-r--r-- | internal/praefect/nodes/sql_elector_test.go | 6 | ||||
-rw-r--r-- | internal/praefect/server_test.go | 7 |
10 files changed, 432 insertions, 64 deletions
diff --git a/internal/backchannel/backchannel.go b/internal/backchannel/backchannel.go index 2e0e919cf..a429ea877 100644 --- a/internal/backchannel/backchannel.go +++ b/internal/backchannel/backchannel.go @@ -14,19 +14,15 @@ // independent gRPC sessions on a single connection. This allows for dialing back to the client from // the server to establish another gRPC session where the server and client roles are switched. // -// The server side supports clients that are unaware of the multiplexing. The server peeks the incoming -// network stream to see if it starts with the magic bytes that indicate a multiplexing aware client. -// If the magic bytes are present, the server initiates the multiplexing session and dials back to the client -// over the already established network connection. If the magic bytes are not present, the server restores the -// the bytes back into the original network stream and handles it without a multiplexing session. +// The server side uses listenmux to support clients that are unaware of the multiplexing. // // Usage: // 1. Implement a ServerFactory, which is simply a function that returns a Server that can serve on the backchannel // connection. Plug in the ClientHandshake to the Clientconn via grpc.WithTransportCredentials when dialing. // This ensures all connections established by gRPC work with a multiplexing session and have a backchannel Server serving. -// 2. Configure the ServerHandshake on the server side by passing it into the gRPC server via the grpc.Creds option. -// The ServerHandshake method is called on each newly established connection. It peeks the network stream to see if a -// multiplexing session should be initiated. If so, it also dials back to the client's backchannel server. Server +// 2. Create a *listenmux.Mux and register a *ServerHandshaker with it. +// 3. Pass the *listenmux.Mux into the grpc Server using grpc.Creds. +// The Handshake method is called on each newly established connection that presents the backchannel magic bytes. It dials back to the client's backchannel server. Server // makes the backchannel connection's available later via the Registry's Backchannel method. The ID of the // peer associated with the current RPC handler can be fetched via GetPeerID. The returned ID can be used // to access the correct backchannel connection from the Registry. diff --git a/internal/backchannel/backchannel_example_test.go b/internal/backchannel/backchannel_example_test.go index 0b06b4779..63f4e5c5d 100644 --- a/internal/backchannel/backchannel_example_test.go +++ b/internal/backchannel/backchannel_example_test.go @@ -7,6 +7,7 @@ import ( "github.com/sirupsen/logrus" "gitlab.com/gitlab-org/gitaly/v14/internal/backchannel" + "gitlab.com/gitlab-org/gitaly/v14/internal/listenmux" "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -31,11 +32,12 @@ func Example() { // it creates the backchannel connection and stores it into the registry. For each connection, // the ServerHandshaker passes down the peer ID via the context. The peer ID identifies a // backchannel connection. - handshaker := backchannel.NewServerHandshaker(logger, insecure.NewCredentials(), registry, nil) + lm := listenmux.New(insecure.NewCredentials()) + lm.Register(backchannel.NewServerHandshaker(logger, registry, nil)) // Create the server srv := grpc.NewServer( - grpc.Creds(handshaker), + grpc.Creds(lm), grpc.UnknownServiceHandler(func(srv interface{}, stream grpc.ServerStream) error { fmt.Println("Gitaly received a transactional mutator") diff --git a/internal/backchannel/backchannel_test.go b/internal/backchannel/backchannel_test.go index a599d6880..1191737b5 100644 --- a/internal/backchannel/backchannel_test.go +++ b/internal/backchannel/backchannel_test.go @@ -13,6 +13,7 @@ import ( "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitaly/v14/internal/listenmux" "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper/testassert" "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb" "google.golang.org/grpc" @@ -39,9 +40,9 @@ func newLogger() *logrus.Entry { func TestBackchannel_concurrentRequestsFromMultipleClients(t *testing.T) { var interceptorInvoked int32 registry := NewRegistry() - handshaker := NewServerHandshaker( + lm := listenmux.New(insecure.NewCredentials()) + lm.Register(NewServerHandshaker( newLogger(), - insecure.NewCredentials(), registry, []grpc.DialOption{ grpc.WithUnaryInterceptor(func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { @@ -49,13 +50,13 @@ func TestBackchannel_concurrentRequestsFromMultipleClients(t *testing.T) { return invoker(ctx, method, req, reply, cc, opts...) }), }, - ) + )) ln, err := net.Listen("tcp", "localhost:0") require.NoError(t, err) errNonMultiplexed := status.Error(codes.FailedPrecondition, ErrNonMultiplexedConnection.Error()) - srv := grpc.NewServer(grpc.Creds(handshaker)) + srv := grpc.NewServer(grpc.Creds(lm)) gitalypb.RegisterRefTransactionServer(srv, mockTransactionServer{ voteTransactionFunc: func(ctx context.Context, req *gitalypb.VoteTransactionRequest) (*gitalypb.VoteTransactionResponse, error) { @@ -182,8 +183,10 @@ func Benchmark(b *testing.B) { b.Run(fmt.Sprintf("message size %dkb", messageSize/1024), func(b *testing.B) { var serverOpts []grpc.ServerOption if tc.multiplexed { + lm := listenmux.New(insecure.NewCredentials()) + lm.Register(NewServerHandshaker(newLogger(), NewRegistry(), nil)) serverOpts = []grpc.ServerOption{ - grpc.Creds(NewServerHandshaker(newLogger(), insecure.NewCredentials(), NewRegistry(), nil)), + grpc.Creds(lm), } } diff --git a/internal/backchannel/server.go b/internal/backchannel/server.go index c22b46e37..b2a1f58da 100644 --- a/internal/backchannel/server.go +++ b/internal/backchannel/server.go @@ -1,12 +1,9 @@ package backchannel import ( - "bytes" "context" "errors" "fmt" - "io" - "io/ioutil" "net" "github.com/hashicorp/yamux" @@ -53,58 +50,27 @@ func WithID(authInfo credentials.AuthInfo, id ID) credentials.AuthInfo { type ServerHandshaker struct { registry *Registry logger *logrus.Entry - credentials.TransportCredentials dialOpts []grpc.DialOption } +// Magic is used by listenmux to retrieve the magic string for +// backchannel connections. +func (s *ServerHandshaker) Magic() string { return string(magicBytes) } + // NewServerHandshaker returns a new server side implementation of the backchannel. The provided TransportCredentials // are handshaked prior to initializing the multiplexing session. The Registry is used to store the backchannel connections. // DialOptions can be used to set custom dial options for the backchannel connections. They must not contain a dialer or // transport credentials as those set by the handshaker. -func NewServerHandshaker(logger *logrus.Entry, tc credentials.TransportCredentials, reg *Registry, dialOpts []grpc.DialOption) credentials.TransportCredentials { - return ServerHandshaker{ - TransportCredentials: tc, - registry: reg, - logger: logger, - dialOpts: dialOpts, - } -} - -// restoredConn allows for restoring the connection's stream after peeking it. If the connection -// was not multiplexed, the peeked bytes are restored back into the stream. -type restoredConn struct { - net.Conn - reader io.Reader +func NewServerHandshaker(logger *logrus.Entry, reg *Registry, dialOpts []grpc.DialOption) *ServerHandshaker { + return &ServerHandshaker{registry: reg, logger: logger, dialOpts: dialOpts} } -func (rc *restoredConn) Read(b []byte) (int, error) { return rc.reader.Read(b) } - -// ServerHandshake peeks the connection to determine whether the client supports establishing a -// backchannel by multiplexing the network connection. If so, it establishes a gRPC ClientConn back -// to the client and stores it's ID in the AuthInfo where it can be later accessed by the RPC handlers. -// gRPC sets an IO timeout on the connection before calling ServerHandshake, so we don't have to handle +// Handshake establishes a gRPC ClientConn back to the backchannel client +// on the other side and stores its ID in the AuthInfo where it can be +// later accessed by the RPC handlers. gRPC sets an IO timeout on the +// connection before calling ServerHandshake, so we don't have to handle // timeouts separately. -func (s ServerHandshaker) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { - conn, authInfo, err := s.TransportCredentials.ServerHandshake(conn) - if err != nil { - return nil, nil, fmt.Errorf("wrapped server handshake: %w", err) - } - - peeked, err := ioutil.ReadAll(io.LimitReader(conn, int64(len(magicBytes)))) - if err != nil { - return nil, nil, fmt.Errorf("peek network stream: %w", err) - } - - if !bytes.Equal(peeked, magicBytes) { - // If the client connection is not multiplexed, restore the peeked bytes back into the stream. - // We also set a 0 peer ID in the authInfo to indicate that the server handshake was attempted - // but this was not a multiplexed connection. - return &restoredConn{ - Conn: conn, - reader: io.MultiReader(bytes.NewReader(peeked), conn), - }, authInfo, nil - } - +func (s *ServerHandshaker) Handshake(conn net.Conn, authInfo credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error) { // It is not necessary to clean up any of the multiplexing-related sessions on errors as the // gRPC server closes the conn if there is an error, which closes the multiplexing // session as well. diff --git a/internal/gitaly/client/dial_test.go b/internal/gitaly/client/dial_test.go index 7db99c2ff..cdd7fffa1 100644 --- a/internal/gitaly/client/dial_test.go +++ b/internal/gitaly/client/dial_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitaly/v14/internal/backchannel" + "gitlab.com/gitlab-org/gitaly/v14/internal/listenmux" "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper" "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb" "google.golang.org/grpc" @@ -21,8 +22,11 @@ func TestDial(t *testing.T) { logger := testhelper.DiscardTestEntry(t) + lm := listenmux.New(insecure.NewCredentials()) + lm.Register(backchannel.NewServerHandshaker(logger, backchannel.NewRegistry(), nil)) + srv := grpc.NewServer( - grpc.Creds(backchannel.NewServerHandshaker(logger, insecure.NewCredentials(), backchannel.NewRegistry(), nil)), + grpc.Creds(lm), grpc.UnknownServiceHandler(func(srv interface{}, stream grpc.ServerStream) error { _, err := backchannel.GetPeerID(stream.Context()) if err == backchannel.ErrNonMultiplexedConnection { diff --git a/internal/gitaly/server/server.go b/internal/gitaly/server/server.go index a2540761b..7602cd287 100644 --- a/internal/gitaly/server/server.go +++ b/internal/gitaly/server/server.go @@ -16,6 +16,7 @@ import ( "gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/config" "gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/server/auth" "gitlab.com/gitlab-org/gitaly/v14/internal/helper/fieldextractors" + "gitlab.com/gitlab-org/gitaly/v14/internal/listenmux" gitalylog "gitlab.com/gitlab-org/gitaly/v14/internal/log" "gitlab.com/gitlab-org/gitaly/v14/internal/logsanitizer" "gitlab.com/gitlab-org/gitaly/v14/internal/middleware/cache" @@ -94,8 +95,15 @@ func New( }) } + lm := listenmux.New(transportCredentials) + lm.Register(backchannel.NewServerHandshaker( + logrusEntry, + registry, + []grpc.DialOption{client.UnaryInterceptor()}, + )) + opts := []grpc.ServerOption{ - grpc.Creds(backchannel.NewServerHandshaker(logrusEntry, transportCredentials, registry, []grpc.DialOption{client.UnaryInterceptor()})), + grpc.Creds(lm), grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( grpc_ctxtags.StreamServerInterceptor(ctxTagOpts...), grpccorrelation.StreamServerCorrelationInterceptor(), // Must be above the metadata handler diff --git a/internal/listenmux/mux.go b/internal/listenmux/mux.go new file mode 100644 index 000000000..ad1f468a6 --- /dev/null +++ b/internal/listenmux/mux.go @@ -0,0 +1,89 @@ +package listenmux + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "net" + + "google.golang.org/grpc/credentials" +) + +const magicLen = 11 + +// Handshaker represents a multiplexed connection type. +type Handshaker interface { + // Handshake is called with a valid Conn and AuthInfo when grpc-go + // accepts a connection that presents the Magic() magic bytes. From the + // point of view of grpc-go, it is part of the + // credentials.TransportCredentials.ServerHandshake callback. The return + // values of Handshake become the return values of + // credentials.TransportCredentials.ServerHandshake. + Handshake(net.Conn, credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error) + + // Magic returns the magic bytes that clients must send on the wire to + // reach this handshaker. The string must be no more than 11 bytes long. + // If it is longer, this handshaker will never be called. + Magic() string +} + +var _ credentials.TransportCredentials = &Mux{} + +// Mux is a listener multiplexer that plugs into grpc-go as a "TransportCredentials" callback. +type Mux struct { + credentials.TransportCredentials + handshakers map[string]Handshaker +} + +// Register registers a handshaker. It is not thread-safe. +func (m *Mux) Register(h Handshaker) { + if len(h.Magic()) != magicLen { + panic("wrong magic bytes length") + } + + m.handshakers[h.Magic()] = h +} + +// New returns a *Mux that wraps existing transport credentials. This +// does nothing interesting unless you also call Register to add +// handshakers to the Mux. +func New(tc credentials.TransportCredentials) *Mux { + return &Mux{ + TransportCredentials: tc, + handshakers: make(map[string]Handshaker), + } +} + +// restoredConn allows for restoring the connection's stream after peeking it. If the connection +// was not multiplexed, the peeked bytes are restored back into the stream. +type restoredConn struct { + net.Conn + reader io.Reader +} + +func (rc *restoredConn) Read(b []byte) (int, error) { return rc.reader.Read(b) } + +// ServerHandshake peeks the connection to determine whether the client +// wants to make a multiplexed connection. It is part of the +// TransportCredentials interface. +func (m *Mux) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + conn, authInfo, err := m.TransportCredentials.ServerHandshake(conn) + if err != nil { + return nil, nil, fmt.Errorf("wrapped server handshake: %w", err) + } + + peeked, err := ioutil.ReadAll(io.LimitReader(conn, magicLen)) + if err != nil { + return nil, nil, fmt.Errorf("peek network stream: %w", err) + } + + if h, ok := m.handshakers[string(peeked)]; ok { + return h.Handshake(conn, authInfo) + } + + return &restoredConn{ + Conn: conn, + reader: io.MultiReader(bytes.NewReader(peeked), conn), + }, authInfo, nil +} diff --git a/internal/listenmux/mux_test.go b/internal/listenmux/mux_test.go new file mode 100644 index 000000000..905329fbe --- /dev/null +++ b/internal/listenmux/mux_test.go @@ -0,0 +1,291 @@ +package listenmux + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "io/ioutil" + "math/rand" + "net" + "syscall" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/health" + healthgrpc "google.golang.org/grpc/health/grpc_health_v1" +) + +type handshakeFunc func(net.Conn, credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error) + +func (hf handshakeFunc) Handshake(c net.Conn, ai credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error) { + return hf(c, ai) +} + +const testmux = "test mux " + +func (hf handshakeFunc) Magic() string { return testmux } + +func serverWithHandshaker(t *testing.T, h Handshaker) string { + t.Helper() + + l, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + t.Cleanup(func() { l.Close() }) + + tc := New(insecure.NewCredentials()) + if h != nil { + tc.Register(h) + } + + s := grpc.NewServer( + grpc.Creds(tc), + ) + t.Cleanup(s.Stop) + + healthgrpc.RegisterHealthServer(s, health.NewServer()) + + go func() { assert.NoError(t, s.Serve(l)) }() + + return l.Addr().String() +} + +func checkHealth(t *testing.T, cc *grpc.ClientConn) { + t.Helper() + _, err := healthgrpc.NewHealthClient(cc).Check(context.Background(), &healthgrpc.HealthCheckRequest{}) + require.NoError(t, err) +} + +func TestMux_normalClientNoMux(t *testing.T) { + addr := serverWithHandshaker(t, nil) + + cc, err := grpc.Dial(addr, grpc.WithInsecure()) + require.NoError(t, err) + defer cc.Close() + + checkHealth(t, cc) +} + +func TestMux_normalClientMuxIgnored(t *testing.T) { + addr := serverWithHandshaker(t, + handshakeFunc(func(net.Conn, credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error) { + t.Error("never called") + return nil, nil, nil + }), + ) + + cc, err := grpc.Dial(addr, grpc.WithInsecure()) + require.NoError(t, err) + defer cc.Close() + + checkHealth(t, cc) +} + +func TestMux_muxClientPassesThrough(t *testing.T) { + handshakerCalled := false + + addr := serverWithHandshaker(t, + handshakeFunc(func(c net.Conn, ai credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error) { + handshakerCalled = true + return c, ai, nil + }), + ) + + cc, err := grpc.Dial( + "ignored", + grpc.WithInsecure(), + grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { + c, err := net.Dial("tcp", addr) + if err != nil { + return nil, err + } + + if _, err := io.WriteString(c, testmux); err != nil { + return nil, err + } + + return c, nil + }), + ) + require.NoError(t, err) + defer cc.Close() + + checkHealth(t, cc) + + require.True(t, handshakerCalled) +} + +func readN(t *testing.T, r io.Reader, n int) []byte { + t.Helper() + buf := make([]byte, n) + _, err := io.ReadFull(r, buf) + require.NoError(t, err) + return buf +} + +func TestMux_handshakerStealsConnection(t *testing.T) { + connCh := make(chan net.Conn, 1) + addr := serverWithHandshaker(t, + handshakeFunc(func(c net.Conn, _ credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error) { + connCh <- c + return nil, nil, credentials.ErrConnDispatched + }), + ) + + done := make(chan struct{}) + go func() { + defer close(done) + + serverConn := <-connCh + defer serverConn.Close() + + // Give grpc-go a chance to close the connection, which it shouldn't + time.Sleep(100 * time.Millisecond) + + ping := readN(t, serverConn, 4) + require.Equal(t, "ping", string(ping)) + + _, err := io.WriteString(serverConn, "pong") + require.NoError(t, err) + }() + + c, err := net.Dial("tcp", addr) + require.NoError(t, err) + defer c.Close() + + _, err = io.WriteString(c, testmux+"ping") + require.NoError(t, err) + + pong := readN(t, c, 4) + require.Equal(t, "pong", string(pong)) + + <-done +} + +func TestMux_handshakerReturnsError(t *testing.T) { + addr := serverWithHandshaker(t, + handshakeFunc(func(_ net.Conn, _ credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error) { + return nil, nil, errors.New("something went wrong") + }), + ) + + c, err := net.Dial("tcp", addr) + require.NoError(t, err) + defer c.Close() + + _, err = io.WriteString(c, testmux) + require.NoError(t, err) + + require.NoError(t, c.SetDeadline(time.Now().Add(1*time.Second))) + + buf := make([]byte, 1) + _, err = io.ReadFull(c, buf) + require.Equal(t, io.EOF, err, "EOF tells us that grpc-go closed the connection") +} + +func TestMux_concurrency(t *testing.T) { + const N = 100 + + // We want to open a lot of network connections. Raise the limits for the + // process as far as we're allowed. + var limit syscall.Rlimit + require.NoError(t, syscall.Getrlimit(syscall.RLIMIT_NOFILE, &limit)) + limit.Cur = limit.Max + require.NoError(t, syscall.Setrlimit(syscall.RLIMIT_NOFILE, &limit)) + + streamServerErrors := make(chan error, N) + + addr := serverWithHandshaker(t, + handshakeFunc(func(c net.Conn, _ credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error) { + go func() { + streamServerErrors <- func() error { + defer c.Close() + if _, err := io.Copy(c, c); err != nil { + return err + } + return c.Close() + }() + }() + + return nil, nil, credentials.ErrConnDispatched + }), + ) + + start := make(chan struct{}) + + streamClientErrors := make(chan error, N) + grpcHealthErrors := make(chan error, N) + + for i := 0; i < N; i++ { + go func() { + <-start + streamClientErrors <- func() error { + c, err := net.Dial("tcp", addr) + if err != nil { + return err + } + defer c.Close() + + if err := c.SetDeadline(time.Now().Add(1 * time.Second)); err != nil { + return err + } + + if _, err := io.WriteString(c, testmux); err != nil { + return err + } + + buf := make([]byte, 128) + if _, err = rand.Read(buf); err != nil { + return err + } + + if n, err := c.Write(buf); err != nil || n < len(buf) { + return fmt.Errorf("write error or short write: %w", err) + } + + if err := c.(*net.TCPConn).CloseWrite(); err != nil { + return err + } + + out, err := ioutil.ReadAll(c) + if err != nil { + return err + } + if !bytes.Equal(buf, out) { + return fmt.Errorf("expected %x, got %x", buf, out) + } + + return c.Close() + }() + }() + + go func() { + <-start + grpcHealthErrors <- func() error { + cc, err := grpc.Dial(addr, grpc.WithInsecure()) + if err != nil { + return err + } + defer cc.Close() + + client := healthgrpc.NewHealthClient(cc) + _, err = client.Check(context.Background(), &healthgrpc.HealthCheckRequest{}) + return err + }() + }() + } + + close(start) + + for i := 0; i < N; i++ { + require.NoError(t, <-streamServerErrors) + require.NoError(t, <-streamClientErrors) + require.NoError(t, <-grpcHealthErrors) + } +} diff --git a/internal/praefect/nodes/sql_elector_test.go b/internal/praefect/nodes/sql_elector_test.go index 882f744fb..31a6818f4 100644 --- a/internal/praefect/nodes/sql_elector_test.go +++ b/internal/praefect/nodes/sql_elector_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitaly/v14/internal/backchannel" + "gitlab.com/gitlab-org/gitaly/v14/internal/listenmux" "gitlab.com/gitlab-org/gitaly/v14/internal/praefect/config" "gitlab.com/gitlab-org/gitaly/v14/internal/praefect/datastore/glsql" "gitlab.com/gitlab-org/gitaly/v14/internal/praefect/protoregistry" @@ -430,8 +431,11 @@ func TestConnectionMultiplexing(t *testing.T) { logger := testhelper.DiscardTestEntry(t) + lm := listenmux.New(insecure.NewCredentials()) + lm.Register(backchannel.NewServerHandshaker(logger, backchannel.NewRegistry(), nil)) + srv := grpc.NewServer( - grpc.Creds(backchannel.NewServerHandshaker(logger, insecure.NewCredentials(), backchannel.NewRegistry(), nil)), + grpc.Creds(lm), grpc.UnknownServiceHandler(func(srv interface{}, stream grpc.ServerStream) error { _, err := backchannel.GetPeerID(stream.Context()) if err == backchannel.ErrNonMultiplexedConnection { diff --git a/internal/praefect/server_test.go b/internal/praefect/server_test.go index a7fb0a1ec..5635df00e 100644 --- a/internal/praefect/server_test.go +++ b/internal/praefect/server_test.go @@ -26,6 +26,7 @@ import ( "gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/service/setup" "gitlab.com/gitlab-org/gitaly/v14/internal/helper" "gitlab.com/gitlab-org/gitaly/v14/internal/helper/text" + "gitlab.com/gitlab-org/gitaly/v14/internal/listenmux" "gitlab.com/gitlab-org/gitaly/v14/internal/praefect/config" "gitlab.com/gitlab-org/gitaly/v14/internal/praefect/datastore" "gitlab.com/gitlab-org/gitaly/v14/internal/praefect/grpc-proxy/proxy" @@ -57,8 +58,12 @@ func TestNewBackchannelServerFactory(t *testing.T) { logger := testhelper.DiscardTestEntry(t) registry := backchannel.NewRegistry() + + lm := listenmux.New(insecure.NewCredentials()) + lm.Register(backchannel.NewServerHandshaker(logger, registry, nil)) + server := grpc.NewServer( - grpc.Creds(backchannel.NewServerHandshaker(logger, insecure.NewCredentials(), registry, nil)), + grpc.Creds(lm), grpc.UnknownServiceHandler(func(srv interface{}, stream grpc.ServerStream) error { id, err := backchannel.GetPeerID(stream.Context()) if !assert.NoError(t, err) { |