Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJacob Vosmaer <jacob@gitlab.com>2021-06-15 15:33:18 +0300
committerJacob Vosmaer <jacob@gitlab.com>2021-06-30 18:19:30 +0300
commitc95298c125a680a006153d5aca5d3dbb575ce352 (patch)
tree49a270994fad3cad104ce8e47fcb06884070a4cd
parent82a7a8e90f5bf3f0cae18d158a28eb8a7a1693c6 (diff)
Separate listenmux from backchannel
Changelog: other
-rw-r--r--internal/backchannel/backchannel.go12
-rw-r--r--internal/backchannel/backchannel_example_test.go6
-rw-r--r--internal/backchannel/backchannel_test.go13
-rw-r--r--internal/backchannel/server.go56
-rw-r--r--internal/gitaly/client/dial_test.go6
-rw-r--r--internal/gitaly/server/server.go10
-rw-r--r--internal/listenmux/mux.go89
-rw-r--r--internal/listenmux/mux_test.go291
-rw-r--r--internal/praefect/nodes/sql_elector_test.go6
-rw-r--r--internal/praefect/server_test.go7
10 files changed, 432 insertions, 64 deletions
diff --git a/internal/backchannel/backchannel.go b/internal/backchannel/backchannel.go
index 2e0e919cf..a429ea877 100644
--- a/internal/backchannel/backchannel.go
+++ b/internal/backchannel/backchannel.go
@@ -14,19 +14,15 @@
// independent gRPC sessions on a single connection. This allows for dialing back to the client from
// the server to establish another gRPC session where the server and client roles are switched.
//
-// The server side supports clients that are unaware of the multiplexing. The server peeks the incoming
-// network stream to see if it starts with the magic bytes that indicate a multiplexing aware client.
-// If the magic bytes are present, the server initiates the multiplexing session and dials back to the client
-// over the already established network connection. If the magic bytes are not present, the server restores the
-// the bytes back into the original network stream and handles it without a multiplexing session.
+// The server side uses listenmux to support clients that are unaware of the multiplexing.
//
// Usage:
// 1. Implement a ServerFactory, which is simply a function that returns a Server that can serve on the backchannel
// connection. Plug in the ClientHandshake to the Clientconn via grpc.WithTransportCredentials when dialing.
// This ensures all connections established by gRPC work with a multiplexing session and have a backchannel Server serving.
-// 2. Configure the ServerHandshake on the server side by passing it into the gRPC server via the grpc.Creds option.
-// The ServerHandshake method is called on each newly established connection. It peeks the network stream to see if a
-// multiplexing session should be initiated. If so, it also dials back to the client's backchannel server. Server
+// 2. Create a *listenmux.Mux and register a *ServerHandshaker with it.
+// 3. Pass the *listenmux.Mux into the grpc Server using grpc.Creds.
+// The Handshake method is called on each newly established connection that presents the backchannel magic bytes. It dials back to the client's backchannel server. Server
// makes the backchannel connection's available later via the Registry's Backchannel method. The ID of the
// peer associated with the current RPC handler can be fetched via GetPeerID. The returned ID can be used
// to access the correct backchannel connection from the Registry.
diff --git a/internal/backchannel/backchannel_example_test.go b/internal/backchannel/backchannel_example_test.go
index 0b06b4779..63f4e5c5d 100644
--- a/internal/backchannel/backchannel_example_test.go
+++ b/internal/backchannel/backchannel_example_test.go
@@ -7,6 +7,7 @@ import (
"github.com/sirupsen/logrus"
"gitlab.com/gitlab-org/gitaly/v14/internal/backchannel"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/listenmux"
"gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
@@ -31,11 +32,12 @@ func Example() {
// it creates the backchannel connection and stores it into the registry. For each connection,
// the ServerHandshaker passes down the peer ID via the context. The peer ID identifies a
// backchannel connection.
- handshaker := backchannel.NewServerHandshaker(logger, insecure.NewCredentials(), registry, nil)
+ lm := listenmux.New(insecure.NewCredentials())
+ lm.Register(backchannel.NewServerHandshaker(logger, registry, nil))
// Create the server
srv := grpc.NewServer(
- grpc.Creds(handshaker),
+ grpc.Creds(lm),
grpc.UnknownServiceHandler(func(srv interface{}, stream grpc.ServerStream) error {
fmt.Println("Gitaly received a transactional mutator")
diff --git a/internal/backchannel/backchannel_test.go b/internal/backchannel/backchannel_test.go
index a599d6880..1191737b5 100644
--- a/internal/backchannel/backchannel_test.go
+++ b/internal/backchannel/backchannel_test.go
@@ -13,6 +13,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/listenmux"
"gitlab.com/gitlab-org/gitaly/v14/internal/testhelper/testassert"
"gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
"google.golang.org/grpc"
@@ -39,9 +40,9 @@ func newLogger() *logrus.Entry {
func TestBackchannel_concurrentRequestsFromMultipleClients(t *testing.T) {
var interceptorInvoked int32
registry := NewRegistry()
- handshaker := NewServerHandshaker(
+ lm := listenmux.New(insecure.NewCredentials())
+ lm.Register(NewServerHandshaker(
newLogger(),
- insecure.NewCredentials(),
registry,
[]grpc.DialOption{
grpc.WithUnaryInterceptor(func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
@@ -49,13 +50,13 @@ func TestBackchannel_concurrentRequestsFromMultipleClients(t *testing.T) {
return invoker(ctx, method, req, reply, cc, opts...)
}),
},
- )
+ ))
ln, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
errNonMultiplexed := status.Error(codes.FailedPrecondition, ErrNonMultiplexedConnection.Error())
- srv := grpc.NewServer(grpc.Creds(handshaker))
+ srv := grpc.NewServer(grpc.Creds(lm))
gitalypb.RegisterRefTransactionServer(srv, mockTransactionServer{
voteTransactionFunc: func(ctx context.Context, req *gitalypb.VoteTransactionRequest) (*gitalypb.VoteTransactionResponse, error) {
@@ -182,8 +183,10 @@ func Benchmark(b *testing.B) {
b.Run(fmt.Sprintf("message size %dkb", messageSize/1024), func(b *testing.B) {
var serverOpts []grpc.ServerOption
if tc.multiplexed {
+ lm := listenmux.New(insecure.NewCredentials())
+ lm.Register(NewServerHandshaker(newLogger(), NewRegistry(), nil))
serverOpts = []grpc.ServerOption{
- grpc.Creds(NewServerHandshaker(newLogger(), insecure.NewCredentials(), NewRegistry(), nil)),
+ grpc.Creds(lm),
}
}
diff --git a/internal/backchannel/server.go b/internal/backchannel/server.go
index c22b46e37..b2a1f58da 100644
--- a/internal/backchannel/server.go
+++ b/internal/backchannel/server.go
@@ -1,12 +1,9 @@
package backchannel
import (
- "bytes"
"context"
"errors"
"fmt"
- "io"
- "io/ioutil"
"net"
"github.com/hashicorp/yamux"
@@ -53,58 +50,27 @@ func WithID(authInfo credentials.AuthInfo, id ID) credentials.AuthInfo {
type ServerHandshaker struct {
registry *Registry
logger *logrus.Entry
- credentials.TransportCredentials
dialOpts []grpc.DialOption
}
+// Magic is used by listenmux to retrieve the magic string for
+// backchannel connections.
+func (s *ServerHandshaker) Magic() string { return string(magicBytes) }
+
// NewServerHandshaker returns a new server side implementation of the backchannel. The provided TransportCredentials
// are handshaked prior to initializing the multiplexing session. The Registry is used to store the backchannel connections.
// DialOptions can be used to set custom dial options for the backchannel connections. They must not contain a dialer or
// transport credentials as those set by the handshaker.
-func NewServerHandshaker(logger *logrus.Entry, tc credentials.TransportCredentials, reg *Registry, dialOpts []grpc.DialOption) credentials.TransportCredentials {
- return ServerHandshaker{
- TransportCredentials: tc,
- registry: reg,
- logger: logger,
- dialOpts: dialOpts,
- }
-}
-
-// restoredConn allows for restoring the connection's stream after peeking it. If the connection
-// was not multiplexed, the peeked bytes are restored back into the stream.
-type restoredConn struct {
- net.Conn
- reader io.Reader
+func NewServerHandshaker(logger *logrus.Entry, reg *Registry, dialOpts []grpc.DialOption) *ServerHandshaker {
+ return &ServerHandshaker{registry: reg, logger: logger, dialOpts: dialOpts}
}
-func (rc *restoredConn) Read(b []byte) (int, error) { return rc.reader.Read(b) }
-
-// ServerHandshake peeks the connection to determine whether the client supports establishing a
-// backchannel by multiplexing the network connection. If so, it establishes a gRPC ClientConn back
-// to the client and stores it's ID in the AuthInfo where it can be later accessed by the RPC handlers.
-// gRPC sets an IO timeout on the connection before calling ServerHandshake, so we don't have to handle
+// Handshake establishes a gRPC ClientConn back to the backchannel client
+// on the other side and stores its ID in the AuthInfo where it can be
+// later accessed by the RPC handlers. gRPC sets an IO timeout on the
+// connection before calling ServerHandshake, so we don't have to handle
// timeouts separately.
-func (s ServerHandshaker) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
- conn, authInfo, err := s.TransportCredentials.ServerHandshake(conn)
- if err != nil {
- return nil, nil, fmt.Errorf("wrapped server handshake: %w", err)
- }
-
- peeked, err := ioutil.ReadAll(io.LimitReader(conn, int64(len(magicBytes))))
- if err != nil {
- return nil, nil, fmt.Errorf("peek network stream: %w", err)
- }
-
- if !bytes.Equal(peeked, magicBytes) {
- // If the client connection is not multiplexed, restore the peeked bytes back into the stream.
- // We also set a 0 peer ID in the authInfo to indicate that the server handshake was attempted
- // but this was not a multiplexed connection.
- return &restoredConn{
- Conn: conn,
- reader: io.MultiReader(bytes.NewReader(peeked), conn),
- }, authInfo, nil
- }
-
+func (s *ServerHandshaker) Handshake(conn net.Conn, authInfo credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error) {
// It is not necessary to clean up any of the multiplexing-related sessions on errors as the
// gRPC server closes the conn if there is an error, which closes the multiplexing
// session as well.
diff --git a/internal/gitaly/client/dial_test.go b/internal/gitaly/client/dial_test.go
index 7db99c2ff..cdd7fffa1 100644
--- a/internal/gitaly/client/dial_test.go
+++ b/internal/gitaly/client/dial_test.go
@@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitaly/v14/internal/backchannel"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/listenmux"
"gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
"gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
"google.golang.org/grpc"
@@ -21,8 +22,11 @@ func TestDial(t *testing.T) {
logger := testhelper.DiscardTestEntry(t)
+ lm := listenmux.New(insecure.NewCredentials())
+ lm.Register(backchannel.NewServerHandshaker(logger, backchannel.NewRegistry(), nil))
+
srv := grpc.NewServer(
- grpc.Creds(backchannel.NewServerHandshaker(logger, insecure.NewCredentials(), backchannel.NewRegistry(), nil)),
+ grpc.Creds(lm),
grpc.UnknownServiceHandler(func(srv interface{}, stream grpc.ServerStream) error {
_, err := backchannel.GetPeerID(stream.Context())
if err == backchannel.ErrNonMultiplexedConnection {
diff --git a/internal/gitaly/server/server.go b/internal/gitaly/server/server.go
index a2540761b..7602cd287 100644
--- a/internal/gitaly/server/server.go
+++ b/internal/gitaly/server/server.go
@@ -16,6 +16,7 @@ import (
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/config"
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/server/auth"
"gitlab.com/gitlab-org/gitaly/v14/internal/helper/fieldextractors"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/listenmux"
gitalylog "gitlab.com/gitlab-org/gitaly/v14/internal/log"
"gitlab.com/gitlab-org/gitaly/v14/internal/logsanitizer"
"gitlab.com/gitlab-org/gitaly/v14/internal/middleware/cache"
@@ -94,8 +95,15 @@ func New(
})
}
+ lm := listenmux.New(transportCredentials)
+ lm.Register(backchannel.NewServerHandshaker(
+ logrusEntry,
+ registry,
+ []grpc.DialOption{client.UnaryInterceptor()},
+ ))
+
opts := []grpc.ServerOption{
- grpc.Creds(backchannel.NewServerHandshaker(logrusEntry, transportCredentials, registry, []grpc.DialOption{client.UnaryInterceptor()})),
+ grpc.Creds(lm),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
grpc_ctxtags.StreamServerInterceptor(ctxTagOpts...),
grpccorrelation.StreamServerCorrelationInterceptor(), // Must be above the metadata handler
diff --git a/internal/listenmux/mux.go b/internal/listenmux/mux.go
new file mode 100644
index 000000000..ad1f468a6
--- /dev/null
+++ b/internal/listenmux/mux.go
@@ -0,0 +1,89 @@
+package listenmux
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net"
+
+ "google.golang.org/grpc/credentials"
+)
+
+const magicLen = 11
+
+// Handshaker represents a multiplexed connection type.
+type Handshaker interface {
+ // Handshake is called with a valid Conn and AuthInfo when grpc-go
+ // accepts a connection that presents the Magic() magic bytes. From the
+ // point of view of grpc-go, it is part of the
+ // credentials.TransportCredentials.ServerHandshake callback. The return
+ // values of Handshake become the return values of
+ // credentials.TransportCredentials.ServerHandshake.
+ Handshake(net.Conn, credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error)
+
+ // Magic returns the magic bytes that clients must send on the wire to
+ // reach this handshaker. The string must be no more than 11 bytes long.
+ // If it is longer, this handshaker will never be called.
+ Magic() string
+}
+
+var _ credentials.TransportCredentials = &Mux{}
+
+// Mux is a listener multiplexer that plugs into grpc-go as a "TransportCredentials" callback.
+type Mux struct {
+ credentials.TransportCredentials
+ handshakers map[string]Handshaker
+}
+
+// Register registers a handshaker. It is not thread-safe.
+func (m *Mux) Register(h Handshaker) {
+ if len(h.Magic()) != magicLen {
+ panic("wrong magic bytes length")
+ }
+
+ m.handshakers[h.Magic()] = h
+}
+
+// New returns a *Mux that wraps existing transport credentials. This
+// does nothing interesting unless you also call Register to add
+// handshakers to the Mux.
+func New(tc credentials.TransportCredentials) *Mux {
+ return &Mux{
+ TransportCredentials: tc,
+ handshakers: make(map[string]Handshaker),
+ }
+}
+
+// restoredConn allows for restoring the connection's stream after peeking it. If the connection
+// was not multiplexed, the peeked bytes are restored back into the stream.
+type restoredConn struct {
+ net.Conn
+ reader io.Reader
+}
+
+func (rc *restoredConn) Read(b []byte) (int, error) { return rc.reader.Read(b) }
+
+// ServerHandshake peeks the connection to determine whether the client
+// wants to make a multiplexed connection. It is part of the
+// TransportCredentials interface.
+func (m *Mux) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
+ conn, authInfo, err := m.TransportCredentials.ServerHandshake(conn)
+ if err != nil {
+ return nil, nil, fmt.Errorf("wrapped server handshake: %w", err)
+ }
+
+ peeked, err := ioutil.ReadAll(io.LimitReader(conn, magicLen))
+ if err != nil {
+ return nil, nil, fmt.Errorf("peek network stream: %w", err)
+ }
+
+ if h, ok := m.handshakers[string(peeked)]; ok {
+ return h.Handshake(conn, authInfo)
+ }
+
+ return &restoredConn{
+ Conn: conn,
+ reader: io.MultiReader(bytes.NewReader(peeked), conn),
+ }, authInfo, nil
+}
diff --git a/internal/listenmux/mux_test.go b/internal/listenmux/mux_test.go
new file mode 100644
index 000000000..905329fbe
--- /dev/null
+++ b/internal/listenmux/mux_test.go
@@ -0,0 +1,291 @@
+package listenmux
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "math/rand"
+ "net"
+ "syscall"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/credentials/insecure"
+ "google.golang.org/grpc/health"
+ healthgrpc "google.golang.org/grpc/health/grpc_health_v1"
+)
+
+type handshakeFunc func(net.Conn, credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error)
+
+func (hf handshakeFunc) Handshake(c net.Conn, ai credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error) {
+ return hf(c, ai)
+}
+
+const testmux = "test mux "
+
+func (hf handshakeFunc) Magic() string { return testmux }
+
+func serverWithHandshaker(t *testing.T, h Handshaker) string {
+ t.Helper()
+
+ l, err := net.Listen("tcp", "localhost:0")
+ require.NoError(t, err)
+ t.Cleanup(func() { l.Close() })
+
+ tc := New(insecure.NewCredentials())
+ if h != nil {
+ tc.Register(h)
+ }
+
+ s := grpc.NewServer(
+ grpc.Creds(tc),
+ )
+ t.Cleanup(s.Stop)
+
+ healthgrpc.RegisterHealthServer(s, health.NewServer())
+
+ go func() { assert.NoError(t, s.Serve(l)) }()
+
+ return l.Addr().String()
+}
+
+func checkHealth(t *testing.T, cc *grpc.ClientConn) {
+ t.Helper()
+ _, err := healthgrpc.NewHealthClient(cc).Check(context.Background(), &healthgrpc.HealthCheckRequest{})
+ require.NoError(t, err)
+}
+
+func TestMux_normalClientNoMux(t *testing.T) {
+ addr := serverWithHandshaker(t, nil)
+
+ cc, err := grpc.Dial(addr, grpc.WithInsecure())
+ require.NoError(t, err)
+ defer cc.Close()
+
+ checkHealth(t, cc)
+}
+
+func TestMux_normalClientMuxIgnored(t *testing.T) {
+ addr := serverWithHandshaker(t,
+ handshakeFunc(func(net.Conn, credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error) {
+ t.Error("never called")
+ return nil, nil, nil
+ }),
+ )
+
+ cc, err := grpc.Dial(addr, grpc.WithInsecure())
+ require.NoError(t, err)
+ defer cc.Close()
+
+ checkHealth(t, cc)
+}
+
+func TestMux_muxClientPassesThrough(t *testing.T) {
+ handshakerCalled := false
+
+ addr := serverWithHandshaker(t,
+ handshakeFunc(func(c net.Conn, ai credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error) {
+ handshakerCalled = true
+ return c, ai, nil
+ }),
+ )
+
+ cc, err := grpc.Dial(
+ "ignored",
+ grpc.WithInsecure(),
+ grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
+ c, err := net.Dial("tcp", addr)
+ if err != nil {
+ return nil, err
+ }
+
+ if _, err := io.WriteString(c, testmux); err != nil {
+ return nil, err
+ }
+
+ return c, nil
+ }),
+ )
+ require.NoError(t, err)
+ defer cc.Close()
+
+ checkHealth(t, cc)
+
+ require.True(t, handshakerCalled)
+}
+
+func readN(t *testing.T, r io.Reader, n int) []byte {
+ t.Helper()
+ buf := make([]byte, n)
+ _, err := io.ReadFull(r, buf)
+ require.NoError(t, err)
+ return buf
+}
+
+func TestMux_handshakerStealsConnection(t *testing.T) {
+ connCh := make(chan net.Conn, 1)
+ addr := serverWithHandshaker(t,
+ handshakeFunc(func(c net.Conn, _ credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error) {
+ connCh <- c
+ return nil, nil, credentials.ErrConnDispatched
+ }),
+ )
+
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+
+ serverConn := <-connCh
+ defer serverConn.Close()
+
+ // Give grpc-go a chance to close the connection, which it shouldn't
+ time.Sleep(100 * time.Millisecond)
+
+ ping := readN(t, serverConn, 4)
+ require.Equal(t, "ping", string(ping))
+
+ _, err := io.WriteString(serverConn, "pong")
+ require.NoError(t, err)
+ }()
+
+ c, err := net.Dial("tcp", addr)
+ require.NoError(t, err)
+ defer c.Close()
+
+ _, err = io.WriteString(c, testmux+"ping")
+ require.NoError(t, err)
+
+ pong := readN(t, c, 4)
+ require.Equal(t, "pong", string(pong))
+
+ <-done
+}
+
+func TestMux_handshakerReturnsError(t *testing.T) {
+ addr := serverWithHandshaker(t,
+ handshakeFunc(func(_ net.Conn, _ credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error) {
+ return nil, nil, errors.New("something went wrong")
+ }),
+ )
+
+ c, err := net.Dial("tcp", addr)
+ require.NoError(t, err)
+ defer c.Close()
+
+ _, err = io.WriteString(c, testmux)
+ require.NoError(t, err)
+
+ require.NoError(t, c.SetDeadline(time.Now().Add(1*time.Second)))
+
+ buf := make([]byte, 1)
+ _, err = io.ReadFull(c, buf)
+ require.Equal(t, io.EOF, err, "EOF tells us that grpc-go closed the connection")
+}
+
+func TestMux_concurrency(t *testing.T) {
+ const N = 100
+
+ // We want to open a lot of network connections. Raise the limits for the
+ // process as far as we're allowed.
+ var limit syscall.Rlimit
+ require.NoError(t, syscall.Getrlimit(syscall.RLIMIT_NOFILE, &limit))
+ limit.Cur = limit.Max
+ require.NoError(t, syscall.Setrlimit(syscall.RLIMIT_NOFILE, &limit))
+
+ streamServerErrors := make(chan error, N)
+
+ addr := serverWithHandshaker(t,
+ handshakeFunc(func(c net.Conn, _ credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error) {
+ go func() {
+ streamServerErrors <- func() error {
+ defer c.Close()
+ if _, err := io.Copy(c, c); err != nil {
+ return err
+ }
+ return c.Close()
+ }()
+ }()
+
+ return nil, nil, credentials.ErrConnDispatched
+ }),
+ )
+
+ start := make(chan struct{})
+
+ streamClientErrors := make(chan error, N)
+ grpcHealthErrors := make(chan error, N)
+
+ for i := 0; i < N; i++ {
+ go func() {
+ <-start
+ streamClientErrors <- func() error {
+ c, err := net.Dial("tcp", addr)
+ if err != nil {
+ return err
+ }
+ defer c.Close()
+
+ if err := c.SetDeadline(time.Now().Add(1 * time.Second)); err != nil {
+ return err
+ }
+
+ if _, err := io.WriteString(c, testmux); err != nil {
+ return err
+ }
+
+ buf := make([]byte, 128)
+ if _, err = rand.Read(buf); err != nil {
+ return err
+ }
+
+ if n, err := c.Write(buf); err != nil || n < len(buf) {
+ return fmt.Errorf("write error or short write: %w", err)
+ }
+
+ if err := c.(*net.TCPConn).CloseWrite(); err != nil {
+ return err
+ }
+
+ out, err := ioutil.ReadAll(c)
+ if err != nil {
+ return err
+ }
+ if !bytes.Equal(buf, out) {
+ return fmt.Errorf("expected %x, got %x", buf, out)
+ }
+
+ return c.Close()
+ }()
+ }()
+
+ go func() {
+ <-start
+ grpcHealthErrors <- func() error {
+ cc, err := grpc.Dial(addr, grpc.WithInsecure())
+ if err != nil {
+ return err
+ }
+ defer cc.Close()
+
+ client := healthgrpc.NewHealthClient(cc)
+ _, err = client.Check(context.Background(), &healthgrpc.HealthCheckRequest{})
+ return err
+ }()
+ }()
+ }
+
+ close(start)
+
+ for i := 0; i < N; i++ {
+ require.NoError(t, <-streamServerErrors)
+ require.NoError(t, <-streamClientErrors)
+ require.NoError(t, <-grpcHealthErrors)
+ }
+}
diff --git a/internal/praefect/nodes/sql_elector_test.go b/internal/praefect/nodes/sql_elector_test.go
index 882f744fb..31a6818f4 100644
--- a/internal/praefect/nodes/sql_elector_test.go
+++ b/internal/praefect/nodes/sql_elector_test.go
@@ -12,6 +12,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitaly/v14/internal/backchannel"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/listenmux"
"gitlab.com/gitlab-org/gitaly/v14/internal/praefect/config"
"gitlab.com/gitlab-org/gitaly/v14/internal/praefect/datastore/glsql"
"gitlab.com/gitlab-org/gitaly/v14/internal/praefect/protoregistry"
@@ -430,8 +431,11 @@ func TestConnectionMultiplexing(t *testing.T) {
logger := testhelper.DiscardTestEntry(t)
+ lm := listenmux.New(insecure.NewCredentials())
+ lm.Register(backchannel.NewServerHandshaker(logger, backchannel.NewRegistry(), nil))
+
srv := grpc.NewServer(
- grpc.Creds(backchannel.NewServerHandshaker(logger, insecure.NewCredentials(), backchannel.NewRegistry(), nil)),
+ grpc.Creds(lm),
grpc.UnknownServiceHandler(func(srv interface{}, stream grpc.ServerStream) error {
_, err := backchannel.GetPeerID(stream.Context())
if err == backchannel.ErrNonMultiplexedConnection {
diff --git a/internal/praefect/server_test.go b/internal/praefect/server_test.go
index a7fb0a1ec..5635df00e 100644
--- a/internal/praefect/server_test.go
+++ b/internal/praefect/server_test.go
@@ -26,6 +26,7 @@ import (
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/service/setup"
"gitlab.com/gitlab-org/gitaly/v14/internal/helper"
"gitlab.com/gitlab-org/gitaly/v14/internal/helper/text"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/listenmux"
"gitlab.com/gitlab-org/gitaly/v14/internal/praefect/config"
"gitlab.com/gitlab-org/gitaly/v14/internal/praefect/datastore"
"gitlab.com/gitlab-org/gitaly/v14/internal/praefect/grpc-proxy/proxy"
@@ -57,8 +58,12 @@ func TestNewBackchannelServerFactory(t *testing.T) {
logger := testhelper.DiscardTestEntry(t)
registry := backchannel.NewRegistry()
+
+ lm := listenmux.New(insecure.NewCredentials())
+ lm.Register(backchannel.NewServerHandshaker(logger, registry, nil))
+
server := grpc.NewServer(
- grpc.Creds(backchannel.NewServerHandshaker(logger, insecure.NewCredentials(), registry, nil)),
+ grpc.Creds(lm),
grpc.UnknownServiceHandler(func(srv interface{}, stream grpc.ServerStream) error {
id, err := backchannel.GetPeerID(stream.Context())
if !assert.NoError(t, err) {