Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorQuang-Minh Nguyen <qmnguyen@gitlab.com>2021-07-21 07:20:54 +0300
committerQuang-Minh Nguyen <qmnguyen@gitlab.com>2021-07-22 13:36:00 +0300
commitbd0e4759b615c9024e9c3608c6cdf021c5bffa14 (patch)
tree16cf4bb0e44d289a1fedeb206a5576f58514c4e2
parent9aa765051521550dd95193cc1bdbfad331d8b6b3 (diff)
Implement StreamRPC Server Stop and GracefulStop (branch: qmnguyen0711/1181-draft-handle-streamrpc-server-graceful-shutdown)
Issue: https://gitlab.com/gitlab-com/gl-infra/scalability/-/issues/1181 Changelog: added
-rw-r--r--internal/gitaly/server/server_factory.go31
-rw-r--r--internal/gitaly/server/server_factory_test.go170
-rw-r--r--internal/streamrpc/rpc_test.go71
-rw-r--r--internal/streamrpc/server.go88
4 files changed, 317 insertions, 43 deletions
diff --git a/internal/gitaly/server/server_factory.go b/internal/gitaly/server/server_factory.go
index bae8fad63..38b0db51f 100644
--- a/internal/gitaly/server/server_factory.go
+++ b/internal/gitaly/server/server_factory.go
@@ -22,12 +22,19 @@ import (
// GitalyServerFactory is a factory of gitaly grpc servers
type GitalyServerFactory struct {
- registry *backchannel.Registry
- cacheInvalidator cache.Invalidator
- cfg config.Cfg
- logger *logrus.Entry
- externalServers []*grpc.Server
- internalServers []*grpc.Server
+ registry *backchannel.Registry
+ cacheInvalidator cache.Invalidator
+ cfg config.Cfg
+ logger *logrus.Entry
+ externalServers []stopper
+ externalStreamRPCServers []stopper
+ internalServers []stopper
+ internalStreamRPCServers []stopper
+}
+
+type stopper interface {
+ Stop()
+ GracefulStop()
}
// NewGitalyServerFactory allows to create and start secure/insecure 'grpc.Server'-s with gitaly-ruby
@@ -101,9 +108,11 @@ func (s *GitalyServerFactory) StartWorkers(ctx context.Context, l logrus.FieldLo
// Stop immediately stops all servers created by the GitalyServerFactory.
func (s *GitalyServerFactory) Stop() {
- for _, servers := range [][]*grpc.Server{
+ for _, servers := range [][]stopper{
s.externalServers,
+ s.externalStreamRPCServers,
s.internalServers,
+ s.internalStreamRPCServers,
} {
for _, server := range servers {
server.Stop()
@@ -116,15 +125,17 @@ func (s *GitalyServerFactory) Stop() {
// can still complete their requests to the internal servers. This is important for hooks calling
// back to Gitaly.
func (s *GitalyServerFactory) GracefulStop() {
- for _, servers := range [][]*grpc.Server{
+ for _, servers := range [][]stopper{
s.externalServers,
+ s.externalStreamRPCServers,
s.internalServers,
+ s.internalStreamRPCServers,
} {
var wg sync.WaitGroup
for _, server := range servers {
wg.Add(1)
- go func(server *grpc.Server) {
+ go func(server stopper) {
defer wg.Done()
server.GracefulStop()
}(server)
@@ -144,6 +155,7 @@ func (s *GitalyServerFactory) CreateExternal(secure bool) (*grpc.Server, *stream
}
s.externalServers = append(s.externalServers, grpcServer)
+ s.externalStreamRPCServers = append(s.externalStreamRPCServers, streamRPCServer)
return grpcServer, streamRPCServer, nil
}
@@ -158,6 +170,7 @@ func (s *GitalyServerFactory) CreateInternal() (*grpc.Server, *streamrpc.Server,
}
s.internalServers = append(s.internalServers, grpcServer)
+ s.internalStreamRPCServers = append(s.internalStreamRPCServers, streamRPCServer)
return grpcServer, streamRPCServer, nil
}
diff --git a/internal/gitaly/server/server_factory_test.go b/internal/gitaly/server/server_factory_test.go
index 34c9c46c3..df7115826 100644
--- a/internal/gitaly/server/server_factory_test.go
+++ b/internal/gitaly/server/server_factory_test.go
@@ -70,20 +70,28 @@ func TestGitalyServerFactory(t *testing.T) {
sf := NewGitalyServerFactory(cfg, testhelper.DiscardTestEntry(t), backchannel.NewRegistry(), cache.New(cfg, config.NewLocator(cfg)))
t.Cleanup(sf.Stop)
- tcpHealthClient := check(t, ctx, sf, cfg, repo, starter.TCP, "localhost:0")
+ tcpHealthClient, tcpStreamRPCCall := check(t, ctx, sf, cfg, repo, starter.TCP, "localhost:0")
socket := testhelper.GetTemporaryGitalySocketFileName(t)
t.Cleanup(func() { require.NoError(t, os.RemoveAll(socket)) })
- socketHealthClient := check(t, ctx, sf, cfg, repo, starter.Unix, socket)
+ socketHealthClient, socketStreamRPCCall := check(t, ctx, sf, cfg, repo, starter.Unix, socket)
sf.GracefulStop() // stops all started servers(listeners)
+ // gRPC requests should return errors
_, tcpErr := tcpHealthClient.Check(ctx, &healthpb.HealthCheckRequest{})
require.Equal(t, codes.Unavailable, status.Code(tcpErr))
_, socketErr := socketHealthClient.Check(ctx, &healthpb.HealthCheckRequest{})
require.Equal(t, codes.Unavailable, status.Code(socketErr))
+
+ // StreamRPC requests should return errors as well
+ _, _, err := tcpStreamRPCCall()
+ require.Error(t, err)
+
+ _, _, err = socketStreamRPCCall()
+ require.Error(t, err)
})
}
@@ -107,6 +115,23 @@ func TestGitalyServerFactory_closeOrder(t *testing.T) {
require.Equal(t, errQuickRPC, err)
}
+ streamRPCCallQuick := func(dial streamrpc.DialFunc, shouldSucceed bool) {
+ err := streamrpc.Call(
+ ctx,
+ dial,
+ "/Service/Quick",
+ &gitalypb.TestStreamRequest{},
+ func(c net.Conn) error {
+ return nil
+ },
+ )
+ if !shouldSucceed {
+ require.Error(t, err)
+ return
+ }
+
+ require.EqualError(t, err, errQuickRPC.Error())
+ }
invokeBlocking := func(conn *grpc.ClientConn) chan struct{} {
rpcFinished := make(chan struct{})
@@ -122,6 +147,25 @@ func TestGitalyServerFactory_closeOrder(t *testing.T) {
return rpcFinished
}
+ streamRPCCallBlocking := func(dial streamrpc.DialFunc) chan struct{} {
+ streamFinished := make(chan struct{})
+
+ go func() {
+ defer close(streamFinished)
+ err := streamrpc.Call(
+ ctx,
+ dial,
+ "/Service/Blocking",
+ &gitalypb.TestStreamRequest{},
+ func(c net.Conn) error {
+ return nil
+ },
+ )
+ require.EqualError(t, err, errBlockingRPC.Error())
+ }()
+ return streamFinished
+ }
+
waitUntilFailure := func(conn *grpc.ClientConn) {
for {
err := conn.Invoke(ctx, "/Service/Quick", &healthpb.HealthCheckRequest{}, &healthpb.HealthCheckRequest{})
@@ -134,44 +178,85 @@ func TestGitalyServerFactory_closeOrder(t *testing.T) {
}
}
+ waitUntilStreamRPCFailure := func(dial streamrpc.DialFunc) {
+ for {
+ err := streamrpc.Call(
+ ctx,
+ dial,
+ "/Service/Quick",
+ &gitalypb.TestStreamRequest{},
+ func(c net.Conn) error {
+ return nil
+ },
+ )
+ if err != nil && err.Error() == errQuickRPC.Error() {
+ continue
+ }
+
+ require.Error(t, err)
+ break
+ }
+ }
+
var internalConn, externalConn *grpc.ClientConn
var internalIsBlocking, externalIsBlocking chan struct{}
+
+ var internalStreamRPCDial, externalStreamRPCDial streamrpc.DialFunc
+ var internalStreamRPCIsBlocking, externalStreamRPCIsBlocking chan struct{}
+
var releaseInternalBlock, releaseExternalBlock chan struct{}
+ var releaseInternalStreamRPCBlock, releaseExternalStreamRPCBlock chan struct{}
+
for _, builder := range []struct {
- createServer func() *grpc.Server
- conn **grpc.ClientConn
- isBlocking *chan struct{}
- releaseBlock *chan struct{}
+ createServer func() (*grpc.Server, *streamrpc.Server)
+ conn **grpc.ClientConn
+ isBlocking *chan struct{}
+ releaseBlock *chan struct{}
+ streamRPCIsBlocking *chan struct{}
+ streamRPCReleaseBlock *chan struct{}
+ streamRPCDial *streamrpc.DialFunc
}{
{
- createServer: func() *grpc.Server {
- server, _, err := sf.CreateInternal()
+ createServer: func() (*grpc.Server, *streamrpc.Server) {
+ server, streamRPCServer, err := sf.CreateInternal()
require.NoError(t, err)
- return server
+ return server, streamRPCServer
},
- conn: &internalConn,
- isBlocking: &internalIsBlocking,
- releaseBlock: &releaseInternalBlock,
+ conn: &internalConn,
+ isBlocking: &internalIsBlocking,
+ releaseBlock: &releaseInternalBlock,
+ streamRPCIsBlocking: &internalStreamRPCIsBlocking,
+ streamRPCReleaseBlock: &releaseInternalStreamRPCBlock,
+ streamRPCDial: &internalStreamRPCDial,
},
{
- createServer: func() *grpc.Server {
- server, _, err := sf.CreateExternal(false)
+ createServer: func() (*grpc.Server, *streamrpc.Server) {
+ server, streamRPCServer, err := sf.CreateExternal(false)
require.NoError(t, err)
- return server
+ return server, streamRPCServer
},
- conn: &externalConn,
- isBlocking: &externalIsBlocking,
- releaseBlock: &releaseExternalBlock,
+ conn: &externalConn,
+ isBlocking: &externalIsBlocking,
+ releaseBlock: &releaseExternalBlock,
+ streamRPCIsBlocking: &externalStreamRPCIsBlocking,
+ streamRPCReleaseBlock: &releaseExternalStreamRPCBlock,
+ streamRPCDial: &externalStreamRPCDial,
},
} {
- server := builder.createServer()
+ server, streamRPCServer := builder.createServer()
releaseBlock := make(chan struct{})
*builder.releaseBlock = releaseBlock
+ streamRPCReleaseBlock := make(chan struct{})
+ *builder.streamRPCReleaseBlock = streamRPCReleaseBlock
+
isBlocking := make(chan struct{})
*builder.isBlocking = isBlocking
+ streamRPCIsBlocking := make(chan struct{})
+ *builder.streamRPCIsBlocking = streamRPCIsBlocking
+
server.RegisterService(&grpc.ServiceDesc{
ServiceName: "Service",
Methods: []grpc.MethodDesc{
@@ -193,6 +278,28 @@ func TestGitalyServerFactory_closeOrder(t *testing.T) {
HandlerType: (*interface{})(nil),
}, server)
+ streamRPCServer.RegisterService(&grpc.ServiceDesc{
+ ServiceName: "Service",
+ Methods: []grpc.MethodDesc{
+ {
+ MethodName: "Quick",
+ Handler: func(interface{}, context.Context, func(interface{}) error, grpc.UnaryServerInterceptor) (interface{}, error) {
+ return nil, errQuickRPC
+ },
+ },
+ {
+ MethodName: "Blocking",
+ Handler: func(interface{}, context.Context, func(interface{}) error, grpc.UnaryServerInterceptor) (interface{}, error) {
+ close(streamRPCIsBlocking)
+ _, _ = streamrpc.AcceptConnection(ctx)
+ <-streamRPCReleaseBlock
+ return nil, errBlockingRPC
+ },
+ },
+ },
+ HandlerType: (*interface{})(nil),
+ }, streamRPCServer)
+
ln, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
defer ln.Close()
@@ -201,15 +308,21 @@ func TestGitalyServerFactory_closeOrder(t *testing.T) {
*builder.conn, err = grpc.DialContext(ctx, ln.Addr().String(), grpc.WithInsecure())
require.NoError(t, err)
+
+ *builder.streamRPCDial = streamrpc.DialNet("tcp://" + ln.Addr().String())
}
// both servers should be up and accepting RPCs
invokeQuick(externalConn, true)
invokeQuick(internalConn, true)
+ streamRPCCallQuick(internalStreamRPCDial, true)
+ streamRPCCallQuick(externalStreamRPCDial, true)
// invoke a blocking RPC on the external server to block the graceful shutdown
invokeBlocking(externalConn)
<-externalIsBlocking
+ streamRPCCallBlocking(externalStreamRPCDial)
+ <-externalStreamRPCIsBlocking
shutdownCompeleted := make(chan struct{})
go func() {
@@ -220,35 +333,45 @@ func TestGitalyServerFactory_closeOrder(t *testing.T) {
// wait until the graceful shutdown is in progress and new RPCs are no longer accepted on the
// external servers
waitUntilFailure(externalConn)
+ waitUntilStreamRPCFailure(externalStreamRPCDial)
// internal sockets should still accept RPCs even if external sockets are gracefully closing.
invokeQuick(internalConn, true)
+ streamRPCCallQuick(internalStreamRPCDial, true)
// block on the internal server
internalBlockingRPCFinished := invokeBlocking(internalConn)
<-internalIsBlocking
+ internalBlockingStreamRPCFinished := streamRPCCallBlocking(internalStreamRPCDial)
+ <-internalStreamRPCIsBlocking
// release the external server's blocking RPC so the graceful shutdown can complete and proceed to
// shutting down the internal servers.
close(releaseExternalBlock)
+ close(releaseExternalStreamRPCBlock)
// wait until the graceful shutdown is in progress and new RPCs are no longer accepted on the internal
// servers
waitUntilFailure(internalConn)
+ waitUntilStreamRPCFailure(externalStreamRPCDial)
// neither internal nor external servers should be accepting new RPCs anymore
invokeQuick(externalConn, false)
invokeQuick(internalConn, false)
+ streamRPCCallQuick(internalStreamRPCDial, false)
+ streamRPCCallQuick(externalStreamRPCDial, false)
// wait until the blocking rpc has successfully completed
close(releaseInternalBlock)
<-internalBlockingRPCFinished
+ close(releaseInternalStreamRPCBlock)
+ <-internalBlockingStreamRPCFinished
// wait until the graceful shutdown completes
<-shutdownCompeleted
}
-func check(t *testing.T, ctx context.Context, sf *GitalyServerFactory, cfg config.Cfg, repo *gitalypb.Repository, schema, addr string) healthpb.HealthClient {
+func check(t *testing.T, ctx context.Context, sf *GitalyServerFactory, cfg config.Cfg, repo *gitalypb.Repository, schema, addr string) (healthpb.HealthClient, func() ([]byte, []byte, error)) {
t.Helper()
var grpcConn *grpc.ClientConn
@@ -308,11 +431,14 @@ func check(t *testing.T, ctx context.Context, sf *GitalyServerFactory, cfg confi
require.Equal(t, healthpb.HealthCheckResponse_SERVING, resp.Status)
// Make a streamRPC call
- in, out, err := checkStreamRPC(t, streamRPCDial, repo)
+ streamRPCCall := func() ([]byte, []byte, error) {
+ return checkStreamRPC(t, streamRPCDial, repo)
+ }
+ in, out, err := streamRPCCall()
require.NoError(t, err)
require.Equal(t, in, out, "byte stream works")
- return healthClient
+ return healthClient, streamRPCCall
}
func registerStreamRPCServers(t *testing.T, srv *streamrpc.Server, cfg config.Cfg) {
diff --git a/internal/streamrpc/rpc_test.go b/internal/streamrpc/rpc_test.go
index c93448036..8e974ae11 100644
--- a/internal/streamrpc/rpc_test.go
+++ b/internal/streamrpc/rpc_test.go
@@ -290,11 +290,80 @@ func TestCall_credentials(t *testing.T) {
require.Equal(t, inputs, receivedValues)
}
+func TestServer_Stop(t *testing.T) {
+ setup := func() (*Server, func() error, chan struct{}, chan struct{}) {
+ isBlocking := make(chan struct{})
+ releaseBlock := make(chan struct{})
+
+ handler := func(ctx context.Context, in *testpb.StreamRequest) (*emptypb.Empty, error) {
+ close(isBlocking)
+ _, err := AcceptConnection(ctx)
+ <-releaseBlock
+ return nil, err
+ }
+
+ streamRPCServer := NewServer()
+ testpb.RegisterTestServer(streamRPCServer, &server{testHandler: handler})
+ call := func() error {
+ return Call(
+ context.Background(),
+ startServer(t, streamRPCServer, nil),
+ "/test.streamrpc.Test/Stream",
+ &testpb.StreamRequest{},
+ func(c net.Conn) error {
+ return nil
+ },
+ )
+ }
+ return streamRPCServer, call, isBlocking, releaseBlock
+ }
+
+ t.Run("normal_stop", func(t *testing.T) {
+ streamRPCServer, call, isBlocking, releaseBlock := setup()
+
+ callErrors := make(chan error, 1)
+ go func() { callErrors <- call() }()
+
+ <-isBlocking
+ streamRPCServer.Stop()
+ close(releaseBlock)
+
+ require.Error(t, <-callErrors)
+ require.Error(t, call())
+ })
+
+ t.Run("graceful_stop", func(t *testing.T) {
+ streamRPCServer, call, isBlocking, releaseBlock := setup()
+
+ callErrors := make(chan error)
+ go func() { callErrors <- call() }()
+
+ <-isBlocking
+ shutdownFinished := make(chan struct{})
+ go func() {
+ streamRPCServer.GracefulStop()
+ close(shutdownFinished)
+ }()
+ close(releaseBlock)
+
+ require.NoError(t, <-callErrors)
+ <-shutdownFinished
+ require.Error(t, call())
+ })
+}
+
func startServer(t *testing.T, s *Server, th testHandler) DialFunc {
t.Helper()
- testpb.RegisterTestServer(s, &server{testHandler: th})
+ if th != nil {
+ testpb.RegisterTestServer(s, &server{testHandler: th})
+ }
client, server := socketPair(t)
go func() { _ = s.Handle(server) }()
+ t.Cleanup(func() {
+ client.Close()
+ server.Close()
+ })
+
return func(time.Duration) (net.Conn, error) { return client, nil }
}
diff --git a/internal/streamrpc/server.go b/internal/streamrpc/server.go
index 2c4771af5..021cdffc4 100644
--- a/internal/streamrpc/server.go
+++ b/internal/streamrpc/server.go
@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"net"
+ "sync"
"time"
"github.com/golang/protobuf/proto"
@@ -19,6 +20,10 @@ var _ grpc.ServiceRegistrar = &Server{}
type Server struct {
methods map[string]*method
interceptor grpc.UnaryServerInterceptor
+ stopped bool
+ sessions map[*serverSession]bool
+ handleMu sync.Mutex
+ handleCond *sync.Cond
}
type method struct {
@@ -34,8 +39,10 @@ type ServerOption func(*Server)
// grpc-go RegisterFooServer functions.
func NewServer(opts ...ServerOption) *Server {
s := &Server{
- methods: make(map[string]*method),
+ methods: make(map[string]*method),
+ sessions: make(map[*serverSession]bool),
}
+ s.handleCond = sync.NewCond(&s.handleMu)
for _, o := range opts {
o(s)
}
@@ -61,22 +68,81 @@ func (s *Server) UseInterceptor(interceptor grpc.UnaryServerInterceptor) {
s.interceptor = interceptor
}
+func (s *Server) addSession(session *serverSession) error {
+ s.handleMu.Lock()
+ if s.stopped {
+ s.handleMu.Unlock()
+ return fmt.Errorf("streamrpc: server already stopped")
+ }
+ s.sessions[session] = true
+ s.handleMu.Unlock()
+
+ s.handleCond.Broadcast()
+
+ return nil
+}
+
+func (s *Server) removeSession(session *serverSession) {
+ s.handleMu.Lock()
+ session.C.Close()
+ delete(s.sessions, session)
+ s.handleMu.Unlock()
+
+ s.handleCond.Broadcast()
+}
+
+// Stop stops StreamRPC server. It immediately stops all in-flight sessions,
+// and prevents any further call in the future.
+func (s *Server) Stop() {
+ s.handleMu.Lock()
+ defer s.handleMu.Unlock()
+
+ if s.stopped {
+ return
+ }
+
+ for session := range s.sessions {
+ session.C.Close()
+ delete(s.sessions, session)
+ }
+ s.stopped = true
+}
+
+// GracefulStop stops StreamRPC server gracefully. It prevents the server from
+// accepting new calls, and blocks until all pending calls finish.
+func (s *Server) GracefulStop() {
+ s.handleMu.Lock()
+ defer s.handleMu.Unlock()
+
+ if s.stopped {
+ return
+ }
+
+ for len(s.sessions) > 0 {
+ s.handleCond.Wait()
+ }
+ s.stopped = true
+}
+
// Handle handles an incoming network connection with the StreamRPC
// protocol. It is intended to be called from a net.Listener.Accept loop
// (or something equivalent).
func (s *Server) Handle(c net.Conn) error {
- defer c.Close()
-
deadline := time.Now().Add(defaultHandshakeTimeout)
+ session := &serverSession{
+ C: c,
+ deadline: deadline,
+ }
+ if err := s.addSession(session); err != nil {
+ return err
+ }
+ defer s.removeSession(session)
+
req, err := recvFrame(c, deadline)
if err != nil {
return err
}
- session := &serverSession{
- c: c,
- deadline: deadline,
- }
if err := s.handleSession(session, req); err != nil {
return session.reject(err)
}
@@ -133,7 +199,7 @@ func AcceptConnection(ctx context.Context) (net.Conn, error) {
// serverSession wraps an incoming connection whose handshake has not
// been completed yet.
type serverSession struct {
- c net.Conn
+ C net.Conn
accepted bool
deadline time.Time
}
@@ -146,11 +212,11 @@ func (ss *serverSession) Accept() (net.Conn, error) {
}
ss.accepted = true
- if err := sendFrame(ss.c, nil, ss.deadline); err != nil {
+ if err := sendFrame(ss.C, nil, ss.deadline); err != nil {
return nil, fmt.Errorf("accept session: %w", err)
}
- return ss.c, nil
+ return ss.C, nil
}
func (ss *serverSession) reject(err error) error {
@@ -163,5 +229,5 @@ func (ss *serverSession) reject(err error) error {
return fmt.Errorf("mashal response: %w", err)
}
- return sendFrame(ss.c, buf, ss.deadline)
+ return sendFrame(ss.C, buf, ss.deadline)
}