diff options
author | Quang-Minh Nguyen <qmnguyen@gitlab.com> | 2021-07-21 07:20:54 +0300 |
---|---|---|
committer | Quang-Minh Nguyen <qmnguyen@gitlab.com> | 2021-07-22 13:36:00 +0300 |
commit | bd0e4759b615c9024e9c3608c6cdf021c5bffa14 (patch) | |
tree | 16cf4bb0e44d289a1fedeb206a5576f58514c4e2 | |
parent | 9aa765051521550dd95193cc1bdbfad331d8b6b3 (diff) |
Implement StreamRPC Server Stop and GracefulStop
branch: qmnguyen0711/1181-draft-handle-streamrpc-server-graceful-shutdown
Issue: https://gitlab.com/gitlab-com/gl-infra/scalability/-/issues/1181
Changelog: added
-rw-r--r-- | internal/gitaly/server/server_factory.go | 31 | ||||
-rw-r--r-- | internal/gitaly/server/server_factory_test.go | 170 | ||||
-rw-r--r-- | internal/streamrpc/rpc_test.go | 71 | ||||
-rw-r--r-- | internal/streamrpc/server.go | 88 |
4 files changed, 317 insertions, 43 deletions
diff --git a/internal/gitaly/server/server_factory.go b/internal/gitaly/server/server_factory.go index bae8fad63..38b0db51f 100644 --- a/internal/gitaly/server/server_factory.go +++ b/internal/gitaly/server/server_factory.go @@ -22,12 +22,19 @@ import ( // GitalyServerFactory is a factory of gitaly grpc servers type GitalyServerFactory struct { - registry *backchannel.Registry - cacheInvalidator cache.Invalidator - cfg config.Cfg - logger *logrus.Entry - externalServers []*grpc.Server - internalServers []*grpc.Server + registry *backchannel.Registry + cacheInvalidator cache.Invalidator + cfg config.Cfg + logger *logrus.Entry + externalServers []stopper + externalStreamRPCServers []stopper + internalServers []stopper + internalStreamRPCServers []stopper +} + +type stopper interface { + Stop() + GracefulStop() } // NewGitalyServerFactory allows to create and start secure/insecure 'grpc.Server'-s with gitaly-ruby @@ -101,9 +108,11 @@ func (s *GitalyServerFactory) StartWorkers(ctx context.Context, l logrus.FieldLo // Stop immediately stops all servers created by the GitalyServerFactory. func (s *GitalyServerFactory) Stop() { - for _, servers := range [][]*grpc.Server{ + for _, servers := range [][]stopper{ s.externalServers, + s.externalStreamRPCServers, s.internalServers, + s.internalStreamRPCServers, } { for _, server := range servers { server.Stop() @@ -116,15 +125,17 @@ func (s *GitalyServerFactory) Stop() { // can still complete their requests to the internal servers. This is important for hooks calling // back to Gitaly. 
func (s *GitalyServerFactory) GracefulStop() { - for _, servers := range [][]*grpc.Server{ + for _, servers := range [][]stopper{ s.externalServers, + s.externalStreamRPCServers, s.internalServers, + s.internalStreamRPCServers, } { var wg sync.WaitGroup for _, server := range servers { wg.Add(1) - go func(server *grpc.Server) { + go func(server stopper) { defer wg.Done() server.GracefulStop() }(server) @@ -144,6 +155,7 @@ func (s *GitalyServerFactory) CreateExternal(secure bool) (*grpc.Server, *stream } s.externalServers = append(s.externalServers, grpcServer) + s.externalStreamRPCServers = append(s.externalStreamRPCServers, streamRPCServer) return grpcServer, streamRPCServer, nil } @@ -158,6 +170,7 @@ func (s *GitalyServerFactory) CreateInternal() (*grpc.Server, *streamrpc.Server, } s.internalServers = append(s.internalServers, grpcServer) + s.internalStreamRPCServers = append(s.internalStreamRPCServers, streamRPCServer) return grpcServer, streamRPCServer, nil } diff --git a/internal/gitaly/server/server_factory_test.go b/internal/gitaly/server/server_factory_test.go index 34c9c46c3..df7115826 100644 --- a/internal/gitaly/server/server_factory_test.go +++ b/internal/gitaly/server/server_factory_test.go @@ -70,20 +70,28 @@ func TestGitalyServerFactory(t *testing.T) { sf := NewGitalyServerFactory(cfg, testhelper.DiscardTestEntry(t), backchannel.NewRegistry(), cache.New(cfg, config.NewLocator(cfg))) t.Cleanup(sf.Stop) - tcpHealthClient := check(t, ctx, sf, cfg, repo, starter.TCP, "localhost:0") + tcpHealthClient, tcpStreamRPCCall := check(t, ctx, sf, cfg, repo, starter.TCP, "localhost:0") socket := testhelper.GetTemporaryGitalySocketFileName(t) t.Cleanup(func() { require.NoError(t, os.RemoveAll(socket)) }) - socketHealthClient := check(t, ctx, sf, cfg, repo, starter.Unix, socket) + socketHealthClient, socketStreamRPCCall := check(t, ctx, sf, cfg, repo, starter.Unix, socket) sf.GracefulStop() // stops all started servers(listeners) + // gRPC requests should return 
errors _, tcpErr := tcpHealthClient.Check(ctx, &healthpb.HealthCheckRequest{}) require.Equal(t, codes.Unavailable, status.Code(tcpErr)) _, socketErr := socketHealthClient.Check(ctx, &healthpb.HealthCheckRequest{}) require.Equal(t, codes.Unavailable, status.Code(socketErr)) + + // StreamRPC requests should return errors as well + _, _, err := tcpStreamRPCCall() + require.Error(t, err) + + _, _, err = socketStreamRPCCall() + require.Error(t, err) }) } @@ -107,6 +115,23 @@ func TestGitalyServerFactory_closeOrder(t *testing.T) { require.Equal(t, errQuickRPC, err) } + streamRPCCallQuick := func(dial streamrpc.DialFunc, shouldSucceed bool) { + err := streamrpc.Call( + ctx, + dial, + "/Service/Quick", + &gitalypb.TestStreamRequest{}, + func(c net.Conn) error { + return nil + }, + ) + if !shouldSucceed { + require.Error(t, err) + return + } + + require.EqualError(t, err, errQuickRPC.Error()) + } invokeBlocking := func(conn *grpc.ClientConn) chan struct{} { rpcFinished := make(chan struct{}) @@ -122,6 +147,25 @@ func TestGitalyServerFactory_closeOrder(t *testing.T) { return rpcFinished } + streamRPCCallBlocking := func(dial streamrpc.DialFunc) chan struct{} { + streamFinished := make(chan struct{}) + + go func() { + defer close(streamFinished) + err := streamrpc.Call( + ctx, + dial, + "/Service/Blocking", + &gitalypb.TestStreamRequest{}, + func(c net.Conn) error { + return nil + }, + ) + require.EqualError(t, err, errBlockingRPC.Error()) + }() + return streamFinished + } + waitUntilFailure := func(conn *grpc.ClientConn) { for { err := conn.Invoke(ctx, "/Service/Quick", &healthpb.HealthCheckRequest{}, &healthpb.HealthCheckRequest{}) @@ -134,44 +178,85 @@ func TestGitalyServerFactory_closeOrder(t *testing.T) { } } + waitUntilStreamRPCFailure := func(dial streamrpc.DialFunc) { + for { + err := streamrpc.Call( + ctx, + dial, + "/Service/Quick", + &gitalypb.TestStreamRequest{}, + func(c net.Conn) error { + return nil + }, + ) + if err != nil && err.Error() == errQuickRPC.Error() 
{ + continue + } + + require.Error(t, err) + break + } + } + var internalConn, externalConn *grpc.ClientConn var internalIsBlocking, externalIsBlocking chan struct{} + + var internalStreamRPCDial, externalStreamRPCDial streamrpc.DialFunc + var internalStreamRPCIsBlocking, externalStreamRPCIsBlocking chan struct{} + var releaseInternalBlock, releaseExternalBlock chan struct{} + var releaseInternalStreamRPCBlock, releaseExternalStreamRPCBlock chan struct{} + for _, builder := range []struct { - createServer func() *grpc.Server - conn **grpc.ClientConn - isBlocking *chan struct{} - releaseBlock *chan struct{} + createServer func() (*grpc.Server, *streamrpc.Server) + conn **grpc.ClientConn + isBlocking *chan struct{} + releaseBlock *chan struct{} + streamRPCIsBlocking *chan struct{} + streamRPCReleaseBlock *chan struct{} + streamRPCDial *streamrpc.DialFunc }{ { - createServer: func() *grpc.Server { - server, _, err := sf.CreateInternal() + createServer: func() (*grpc.Server, *streamrpc.Server) { + server, streamRPCServer, err := sf.CreateInternal() require.NoError(t, err) - return server + return server, streamRPCServer }, - conn: &internalConn, - isBlocking: &internalIsBlocking, - releaseBlock: &releaseInternalBlock, + conn: &internalConn, + isBlocking: &internalIsBlocking, + releaseBlock: &releaseInternalBlock, + streamRPCIsBlocking: &internalStreamRPCIsBlocking, + streamRPCReleaseBlock: &releaseInternalStreamRPCBlock, + streamRPCDial: &internalStreamRPCDial, }, { - createServer: func() *grpc.Server { - server, _, err := sf.CreateExternal(false) + createServer: func() (*grpc.Server, *streamrpc.Server) { + server, streamRPCServer, err := sf.CreateExternal(false) require.NoError(t, err) - return server + return server, streamRPCServer }, - conn: &externalConn, - isBlocking: &externalIsBlocking, - releaseBlock: &releaseExternalBlock, + conn: &externalConn, + isBlocking: &externalIsBlocking, + releaseBlock: &releaseExternalBlock, + streamRPCIsBlocking: 
&externalStreamRPCIsBlocking, + streamRPCReleaseBlock: &releaseExternalStreamRPCBlock, + streamRPCDial: &externalStreamRPCDial, }, } { - server := builder.createServer() + server, streamRPCServer := builder.createServer() releaseBlock := make(chan struct{}) *builder.releaseBlock = releaseBlock + streamRPCReleaseBlock := make(chan struct{}) + *builder.streamRPCReleaseBlock = streamRPCReleaseBlock + isBlocking := make(chan struct{}) *builder.isBlocking = isBlocking + streamRPCIsBlocking := make(chan struct{}) + *builder.streamRPCIsBlocking = streamRPCIsBlocking + server.RegisterService(&grpc.ServiceDesc{ ServiceName: "Service", Methods: []grpc.MethodDesc{ @@ -193,6 +278,28 @@ func TestGitalyServerFactory_closeOrder(t *testing.T) { HandlerType: (*interface{})(nil), }, server) + streamRPCServer.RegisterService(&grpc.ServiceDesc{ + ServiceName: "Service", + Methods: []grpc.MethodDesc{ + { + MethodName: "Quick", + Handler: func(interface{}, context.Context, func(interface{}) error, grpc.UnaryServerInterceptor) (interface{}, error) { + return nil, errQuickRPC + }, + }, + { + MethodName: "Blocking", + Handler: func(interface{}, context.Context, func(interface{}) error, grpc.UnaryServerInterceptor) (interface{}, error) { + close(streamRPCIsBlocking) + _, _ = streamrpc.AcceptConnection(ctx) + <-streamRPCReleaseBlock + return nil, errBlockingRPC + }, + }, + }, + HandlerType: (*interface{})(nil), + }, streamRPCServer) + ln, err := net.Listen("tcp", "localhost:0") require.NoError(t, err) defer ln.Close() @@ -201,15 +308,21 @@ func TestGitalyServerFactory_closeOrder(t *testing.T) { *builder.conn, err = grpc.DialContext(ctx, ln.Addr().String(), grpc.WithInsecure()) require.NoError(t, err) + + *builder.streamRPCDial = streamrpc.DialNet("tcp://" + ln.Addr().String()) } // both servers should be up and accepting RPCs invokeQuick(externalConn, true) invokeQuick(internalConn, true) + streamRPCCallQuick(internalStreamRPCDial, true) + streamRPCCallQuick(externalStreamRPCDial, true) // 
invoke a blocking RPC on the external server to block the graceful shutdown invokeBlocking(externalConn) <-externalIsBlocking + streamRPCCallBlocking(externalStreamRPCDial) + <-externalStreamRPCIsBlocking shutdownCompeleted := make(chan struct{}) go func() { @@ -220,35 +333,45 @@ func TestGitalyServerFactory_closeOrder(t *testing.T) { // wait until the graceful shutdown is in progress and new RPCs are no longer accepted on the // external servers waitUntilFailure(externalConn) + waitUntilStreamRPCFailure(externalStreamRPCDial) // internal sockets should still accept RPCs even if external sockets are gracefully closing. invokeQuick(internalConn, true) + streamRPCCallQuick(internalStreamRPCDial, true) // block on the internal server internalBlockingRPCFinished := invokeBlocking(internalConn) <-internalIsBlocking + internalBlockingStreamRPCFinished := streamRPCCallBlocking(internalStreamRPCDial) + <-internalStreamRPCIsBlocking // release the external server's blocking RPC so the graceful shutdown can complete and proceed to // shutting down the internal servers. 
close(releaseExternalBlock) + close(releaseExternalStreamRPCBlock) // wait until the graceful shutdown is in progress and new RPCs are no longer accepted on the internal // servers waitUntilFailure(internalConn) + waitUntilStreamRPCFailure(externalStreamRPCDial) // neither internal nor external servers should be accepting new RPCs anymore invokeQuick(externalConn, false) invokeQuick(internalConn, false) + streamRPCCallQuick(internalStreamRPCDial, false) + streamRPCCallQuick(externalStreamRPCDial, false) // wait until the blocking rpc has successfully completed close(releaseInternalBlock) <-internalBlockingRPCFinished + close(releaseInternalStreamRPCBlock) + <-internalBlockingStreamRPCFinished // wait until the graceful shutdown completes <-shutdownCompeleted } -func check(t *testing.T, ctx context.Context, sf *GitalyServerFactory, cfg config.Cfg, repo *gitalypb.Repository, schema, addr string) healthpb.HealthClient { +func check(t *testing.T, ctx context.Context, sf *GitalyServerFactory, cfg config.Cfg, repo *gitalypb.Repository, schema, addr string) (healthpb.HealthClient, func() ([]byte, []byte, error)) { t.Helper() var grpcConn *grpc.ClientConn @@ -308,11 +431,14 @@ func check(t *testing.T, ctx context.Context, sf *GitalyServerFactory, cfg confi require.Equal(t, healthpb.HealthCheckResponse_SERVING, resp.Status) // Make a streamRPC call - in, out, err := checkStreamRPC(t, streamRPCDial, repo) + streamRPCCall := func() ([]byte, []byte, error) { + return checkStreamRPC(t, streamRPCDial, repo) + } + in, out, err := streamRPCCall() require.NoError(t, err) require.Equal(t, in, out, "byte stream works") - return healthClient + return healthClient, streamRPCCall } func registerStreamRPCServers(t *testing.T, srv *streamrpc.Server, cfg config.Cfg) { diff --git a/internal/streamrpc/rpc_test.go b/internal/streamrpc/rpc_test.go index c93448036..8e974ae11 100644 --- a/internal/streamrpc/rpc_test.go +++ b/internal/streamrpc/rpc_test.go @@ -290,11 +290,80 @@ func 
TestCall_credentials(t *testing.T) { require.Equal(t, inputs, receivedValues) } +func TestServer_Stop(t *testing.T) { + setup := func() (*Server, func() error, chan struct{}, chan struct{}) { + isBlocking := make(chan struct{}) + releaseBlock := make(chan struct{}) + + handler := func(ctx context.Context, in *testpb.StreamRequest) (*emptypb.Empty, error) { + close(isBlocking) + _, err := AcceptConnection(ctx) + <-releaseBlock + return nil, err + } + + streamRPCServer := NewServer() + testpb.RegisterTestServer(streamRPCServer, &server{testHandler: handler}) + call := func() error { + return Call( + context.Background(), + startServer(t, streamRPCServer, nil), + "/test.streamrpc.Test/Stream", + &testpb.StreamRequest{}, + func(c net.Conn) error { + return nil + }, + ) + } + return streamRPCServer, call, isBlocking, releaseBlock + } + + t.Run("normal_stop", func(t *testing.T) { + streamRPCServer, call, isBlocking, releaseBlock := setup() + + callErrors := make(chan error, 1) + go func() { callErrors <- call() }() + + <-isBlocking + streamRPCServer.Stop() + close(releaseBlock) + + require.Error(t, <-callErrors) + require.Error(t, call()) + }) + + t.Run("graceful_stop", func(t *testing.T) { + streamRPCServer, call, isBlocking, releaseBlock := setup() + + callErrors := make(chan error) + go func() { callErrors <- call() }() + + <-isBlocking + shutdownFinished := make(chan struct{}) + go func() { + streamRPCServer.GracefulStop() + close(shutdownFinished) + }() + close(releaseBlock) + + require.NoError(t, <-callErrors) + <-shutdownFinished + require.Error(t, call()) + }) +} + func startServer(t *testing.T, s *Server, th testHandler) DialFunc { t.Helper() - testpb.RegisterTestServer(s, &server{testHandler: th}) + if th != nil { + testpb.RegisterTestServer(s, &server{testHandler: th}) + } client, server := socketPair(t) go func() { _ = s.Handle(server) }() + t.Cleanup(func() { + client.Close() + server.Close() + }) + return func(time.Duration) (net.Conn, error) { return 
client, nil } } diff --git a/internal/streamrpc/server.go b/internal/streamrpc/server.go index 2c4771af5..021cdffc4 100644 --- a/internal/streamrpc/server.go +++ b/internal/streamrpc/server.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net" + "sync" "time" "github.com/golang/protobuf/proto" @@ -19,6 +20,10 @@ var _ grpc.ServiceRegistrar = &Server{} type Server struct { methods map[string]*method interceptor grpc.UnaryServerInterceptor + stopped bool + sessions map[*serverSession]bool + handleMu sync.Mutex + handleCond *sync.Cond } type method struct { @@ -34,8 +39,10 @@ type ServerOption func(*Server) // grpc-go RegisterFooServer functions. func NewServer(opts ...ServerOption) *Server { s := &Server{ - methods: make(map[string]*method), + methods: make(map[string]*method), + sessions: make(map[*serverSession]bool), } + s.handleCond = sync.NewCond(&s.handleMu) for _, o := range opts { o(s) } @@ -61,22 +68,81 @@ func (s *Server) UseInterceptor(interceptor grpc.UnaryServerInterceptor) { s.interceptor = interceptor } +func (s *Server) addSession(session *serverSession) error { + s.handleMu.Lock() + if s.stopped { + s.handleMu.Unlock() + return fmt.Errorf("streamrpc: server already stopped") + } + s.sessions[session] = true + s.handleMu.Unlock() + + s.handleCond.Broadcast() + + return nil +} + +func (s *Server) removeSession(session *serverSession) { + s.handleMu.Lock() + session.C.Close() + delete(s.sessions, session) + s.handleMu.Unlock() + + s.handleCond.Broadcast() +} + +// Stop stops StreamRPC server. It immediately stops all in-flight sessions, +// and prevents any further call in the future. +func (s *Server) Stop() { + s.handleMu.Lock() + defer s.handleMu.Unlock() + + if s.stopped { + return + } + + for session := range s.sessions { + session.C.Close() + delete(s.sessions, session) + } + s.stopped = true +} + +// GracefulStop stops StreamRPC server gracefully. It prevents the server from +// accepting new calls, and blocks until all pending calls finish. 
+func (s *Server) GracefulStop() { + s.handleMu.Lock() + defer s.handleMu.Unlock() + + if s.stopped { + return + } + + for len(s.sessions) > 0 { + s.handleCond.Wait() + } + s.stopped = true +} + // Handle handles an incoming network connection with the StreamRPC // protocol. It is intended to be called from a net.Listener.Accept loop // (or something equivalent). func (s *Server) Handle(c net.Conn) error { - defer c.Close() - deadline := time.Now().Add(defaultHandshakeTimeout) + session := &serverSession{ + C: c, + deadline: deadline, + } + if err := s.addSession(session); err != nil { + return err + } + defer s.removeSession(session) + req, err := recvFrame(c, deadline) if err != nil { return err } - session := &serverSession{ - c: c, - deadline: deadline, - } if err := s.handleSession(session, req); err != nil { return session.reject(err) } @@ -133,7 +199,7 @@ func AcceptConnection(ctx context.Context) (net.Conn, error) { // serverSession wraps an incoming connection whose handshake has not // been completed yet. type serverSession struct { - c net.Conn + C net.Conn accepted bool deadline time.Time } @@ -146,11 +212,11 @@ func (ss *serverSession) Accept() (net.Conn, error) { } ss.accepted = true - if err := sendFrame(ss.c, nil, ss.deadline); err != nil { + if err := sendFrame(ss.C, nil, ss.deadline); err != nil { return nil, fmt.Errorf("accept session: %w", err) } - return ss.c, nil + return ss.C, nil } func (ss *serverSession) reject(err error) error { @@ -163,5 +229,5 @@ func (ss *serverSession) reject(err error) error { return fmt.Errorf("mashal response: %w", err) } - return sendFrame(ss.c, buf, ss.deadline) + return sendFrame(ss.C, buf, ss.deadline) } |