diff options
author | Alessio Caiazza <acaiazza@gitlab.com> | 2019-06-04 12:02:54 +0300 |
---|---|---|
committer | Jacob Vosmaer <jacob@gitlab.com> | 2019-06-04 12:02:54 +0300 |
commit | 8065040919302aa6cdefc978078a95610a0747f5 (patch) | |
tree | ae4660e3a808b001f0c7e61b5428bc6338ba1776 /internal | |
parent | 9855c59db137a607a60366d1ad23a3834a2e9de7 (diff) |
Wait for all the socket on graceful restart
We terminate gitaly on the first failure from a listening socket,
but on a graceful restart we must wait for all the socket to properly
terminate active connections.
This is a complete refactoring that introduces the bootstrap package
with a proper test coverage
Diffstat (limited to 'internal')
-rw-r--r-- | internal/bootstrap/bootstrap.go | 180 | ||||
-rw-r--r-- | internal/bootstrap/bootstrap_test.go | 324 | ||||
-rw-r--r-- | internal/bootstrap/server_factory.go | 93 |
3 files changed, 597 insertions, 0 deletions
diff --git a/internal/bootstrap/bootstrap.go b/internal/bootstrap/bootstrap.go new file mode 100644 index 000000000..bfb58066c --- /dev/null +++ b/internal/bootstrap/bootstrap.go @@ -0,0 +1,180 @@ +package bootstrap + +import ( + "fmt" + "net" + "os" + "os/signal" + "syscall" + "time" + + "github.com/cloudflare/tableflip" + log "github.com/sirupsen/logrus" + "gitlab.com/gitlab-org/gitaly/internal/config" +) + +// Bootstrap handles graceful upgrades +type Bootstrap struct { + // StopAction will be invoked during a graceful stop. It must wait until the shutdown is completed + StopAction func() + + upgrader upgrader + listenFunc ListenFunc + errChan chan error + starters []Starter +} + +type upgrader interface { + Exit() <-chan struct{} + HasParent() bool + Ready() error + Upgrade() error +} + +// New performs tableflip initialization +// pidFile is optional, if provided it will always contain the current process PID +// upgradesEnabled controls the upgrade process on SIGHUP signal +// +// first boot: +// * gitaly starts as usual, we will refer to it as p1 +// * New will build a tableflip.Upgrader, we will refer to it as upg +// * sockets and files must be opened with upg.Fds +// * p1 will trap SIGHUP and invoke upg.Upgrade() +// * when ready to accept incoming connections p1 will call upg.Ready() +// * upg.Exit() channel will be closed when an upgrades completed successfully and the process must terminate +// +// graceful upgrade: +// * user replaces gitaly binary and/or config file +// * user sends SIGHUP to p1 +// * p1 will fork and exec the new gitaly, we will refer to it as p2 +// * from now on p1 will ignore other SIGHUP +// * if p2 terminates with a non-zero exit code, SIGHUP handling will be restored +// * p2 will follow the "first boot" sequence but upg.Fds will provide sockets and files from p1, when available +// * when p2 invokes upg.Ready() all the shared file descriptors not claimed by p2 will be closed +// * upg.Exit() channel in p1 will be closed now and p1 can gracefully terminate already accepted connections +// * upgrades cannot starts again if p1 and p2 are both running, an hard termination should be scheduled to overcome +// freezes during a graceful shutdown +func New(pidFile string, upgradesEnabled bool) (*Bootstrap, error) { + // PIDFile is optional, if provided tableflip will keep it updated + upg, err := tableflip.New(tableflip.Options{PIDFile: pidFile}) + if err != nil { + return nil, err + } + + return _new(upg, upg.Fds.Listen, upgradesEnabled) +} + +func _new(upg upgrader, listenFunc ListenFunc, upgradesEnabled bool) (*Bootstrap, error) { + if upgradesEnabled { + go func() { + sig := make(chan os.Signal, 1) + signal.Notify(sig, syscall.SIGHUP) + + for range sig { + err := upg.Upgrade() + if err != nil { + log.WithError(err).Error("Upgrade failed") + continue + } + + log.Info("Upgrade succeeded") + } + }() + } + + return &Bootstrap{ + upgrader: upg, + listenFunc: listenFunc, + }, nil +} + +// ListenFunc is a net.Listener factory +type ListenFunc func(net, addr string) (net.Listener, error) + +// Starter is function to initialize a net.Listener +// it receives a ListenFunc to be used for net.Listener creation and a chan<- error to signal runtime errors +// It must serve incoming connections asynchronously and signal errors on the channel +// the return value is for setup errors +type Starter func(ListenFunc, chan<- error) error + +func (b *Bootstrap) isFirstBoot() bool { return !b.upgrader.HasParent() } + +// RegisterStarter adds a new starter +func (b *Bootstrap) RegisterStarter(starter Starter) { + b.starters = append(b.starters, starter) +} + +// Start will invoke all the registered starters and wait asynchronously for runtime errors +// in case a Starter fails then the error is returned and the function is aborted +func (b *Bootstrap) Start() error { + b.errChan = make(chan error, len(b.starters)) + + for _, start := range b.starters { + if err := start(b.listen, b.errChan); err != nil { + return err + } + } + + return nil +} + +// Wait will signal process readiness to the parent and than wait for an exit condition +// SIGTERM, SIGINT and a runtime error will trigger an immediate shutdown +// in case of an upgrade there will be a grace period to complete the ongoing requests +func (b *Bootstrap) Wait() error { + signals := []os.Signal{syscall.SIGTERM, syscall.SIGINT} + immediateShutdown := make(chan os.Signal, len(signals)) + signal.Notify(immediateShutdown, signals...) + + if err := b.upgrader.Ready(); err != nil { + return err + } + + var err error + select { + case <-b.upgrader.Exit(): + // this is the old process and a graceful upgrade is in progress + // the new process signaled its readiness and we started a graceful stop + // however no further upgrades can be started until this process is running + // we set a grace period and then we force a termination. + waitError := b.waitGracePeriod(immediateShutdown) + + err = fmt.Errorf("graceful upgrade: %v", waitError) + case s := <-immediateShutdown: + err = fmt.Errorf("received signal %q", s) + case err = <-b.errChan: + } + + return err +} + +func (b *Bootstrap) waitGracePeriod(kill <-chan os.Signal) error { + log.WithField("graceful_restart_timeout", config.Config.GracefulRestartTimeout).Warn("starting grace period") + + allServersDone := make(chan struct{}) + go func() { + if b.StopAction != nil { + b.StopAction() + } + close(allServersDone) + }() + + select { + case <-time.After(config.Config.GracefulRestartTimeout): + return fmt.Errorf("grace period expired") + case <-kill: + return fmt.Errorf("force shutdown") + case <-allServersDone: + return fmt.Errorf("completed") + } +} + +func (b *Bootstrap) listen(network, path string) (net.Listener, error) { + if network == "unix" && b.isFirstBoot() { + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + return nil, err + } + } + + return b.listenFunc(network, path) +} diff --git a/internal/bootstrap/bootstrap_test.go b/internal/bootstrap/bootstrap_test.go new file mode 100644 index 000000000..78cba7ac9 --- /dev/null +++ b/internal/bootstrap/bootstrap_test.go @@ -0,0 +1,324 @@ +package bootstrap + +import ( + "context" + "fmt" + "io/ioutil" + "net" + "net/http" + "os" + "path" + "strconv" + "syscall" + "testing" + "time" + + "gitlab.com/gitlab-org/gitaly/internal/config" + + "github.com/stretchr/testify/require" +) + +var testConfigGracefulRestartTimeout = 2 * time.Second + +type mockUpgrader struct { + exit chan struct{} + hasParent bool +} + +func (m *mockUpgrader) Exit() <-chan struct{} { + return m.exit +} + +func (m *mockUpgrader) HasParent() bool { + return m.hasParent +} + +func (m *mockUpgrader) Ready() error { return nil } + +func (m *mockUpgrader) Upgrade() error { + // to upgrade we close the exit channel + close(m.exit) + return nil +} + +type testServer struct { + server *http.Server + listeners map[string]net.Listener + url string +} + +func (s *testServer) slowRequest(duration time.Duration) <-chan error { + done := make(chan error) + + go func() { + r, err := http.Get(fmt.Sprintf("%sslow?seconds=%d", s.url, int(duration.Seconds()))) + if r != nil { + r.Body.Close() + } + + done <- err + }() + + return done +} + +func TestCreateUnixListener(t *testing.T) { + socketPath := path.Join(os.TempDir(), "gitaly-test-unix-socket") + if err := os.Remove(socketPath); err != nil { + require.True(t, os.IsNotExist(err), "cannot delete dangling socket: %v", err) + } + + // simulate a dangling socket + require.NoError(t, ioutil.WriteFile(socketPath, nil, 0755)) + + listen := func(network, addr string) (net.Listener, error) { + require.Equal(t, "unix", network) + require.Equal(t, socketPath, addr) + + return net.Listen(network, addr) + } + u := &mockUpgrader{} + b, err := _new(u, listen, false) + require.NoError(t, err) + + // first boot + l, err := b.listen("unix", socketPath) + require.NoError(t, err, "failed to bind on first boot") + require.NoError(t, l.Close()) + + // simulate binding during an upgrade + u.hasParent = true + l, err = b.listen("unix", socketPath) + require.NoError(t, err, "failed to bind on upgrade") + require.NoError(t, l.Close()) +} + +func waitWithTimeout(t *testing.T, waitCh <-chan error, timeout time.Duration) error { + select { + case <-time.After(timeout): + t.Fatal("time out waiting for waitCh") + case waitErr := <-waitCh: + return waitErr + } + + return nil +} + +func TestImmediateTerminationOnSocketError(t *testing.T) { + b, server := makeBootstrap(t) + + waitCh := make(chan error) + go func() { waitCh <- b.Wait() }() + + require.NoError(t, server.listeners["tcp"].Close(), "Closing first listener") + + err := waitWithTimeout(t, waitCh, 1*time.Second) + require.Error(t, err) + require.Contains(t, err.Error(), "use of closed network connection") +} + +func TestImmediateTerminationOnSignal(t *testing.T) { + for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} { + t.Run(sig.String(), func(t *testing.T) { + b, server := makeBootstrap(t) + + done := server.slowRequest(3 * time.Minute) + + waitCh := make(chan error) + go func() { waitCh <- b.Wait() }() + + // make sure we are inside b.Wait() or we'll kill the test suite + time.Sleep(100 * time.Millisecond) + + self, err := os.FindProcess(os.Getpid()) + require.NoError(t, err) + require.NoError(t, self.Signal(sig)) + + waitErr := waitWithTimeout(t, waitCh, 1*time.Second) + require.Error(t, waitErr) + require.Contains(t, waitErr.Error(), "received signal") + require.Contains(t, waitErr.Error(), sig.String()) + + server.server.Close() + + require.Error(t, <-done) + }) + } +} + +func TestGracefulTerminationStuck(t *testing.T) { + b, server := makeBootstrap(t) + + err := testGracefulUpdate(t, server, b, testConfigGracefulRestartTimeout+(1*time.Second), nil) + require.Contains(t, err.Error(), "grace period expired") +} + +func TestGracefulTerminationWithSignals(t *testing.T) { + self, err := os.FindProcess(os.Getpid()) + require.NoError(t, err) + + for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} { + t.Run(sig.String(), func(t *testing.T) { + b, server := makeBootstrap(t) + + err := testGracefulUpdate(t, server, b, 1*time.Second, func() { + require.NoError(t, self.Signal(sig)) + }) + require.Contains(t, err.Error(), "force shutdown") + }) + } +} + +func TestGracefulTerminationServerErrors(t *testing.T) { + b, server := makeBootstrap(t) + + done := make(chan error, 1) + // This is a simulation of receiving a listener error during waitGracePeriod + b.StopAction = func() { + // we close the unix listener in order to test that the shutdown will not fail, but it keep waiting for the TCP request + require.NoError(t, server.listeners["unix"].Close()) + + // we start a new TCP request that if faster than the grace period + req := server.slowRequest(config.Config.GracefulRestartTimeout / 2) + done <- <-req + close(done) + + server.server.Shutdown(context.Background()) + } + + err := testGracefulUpdate(t, server, b, testConfigGracefulRestartTimeout+(1*time.Second), nil) + require.Contains(t, err.Error(), "grace period expired") + + require.NoError(t, <-done) +} + +func TestGracefulTermination(t *testing.T) { + b, server := makeBootstrap(t) + + // Using server.Close we bypass the graceful shutdown faking a completed shutdown + b.StopAction = func() { server.server.Close() } + + err := testGracefulUpdate(t, server, b, 1*time.Second, nil) + require.Contains(t, err.Error(), "completed") +} + +func testGracefulUpdate(t *testing.T, server *testServer, b *Bootstrap, waitTimeout time.Duration, duringGracePeriodCallback func()) error { + defer func(oldVal time.Duration) { + config.Config.GracefulRestartTimeout = oldVal + }(config.Config.GracefulRestartTimeout) + config.Config.GracefulRestartTimeout = testConfigGracefulRestartTimeout + + waitCh := make(chan error) + go func() { waitCh <- b.Wait() }() + + // Start a slow request to keep the old server from shutting down immediately. + req := server.slowRequest(2 * config.Config.GracefulRestartTimeout) + + // make sure slow request is being handled + time.Sleep(100 * time.Millisecond) + + // Simulate an upgrade request after entering into the blocking b.Wait() and during the slowRequest execution + b.upgrader.Upgrade() + + if duringGracePeriodCallback != nil { + // make sure we are on the grace period + time.Sleep(100 * time.Millisecond) + + duringGracePeriodCallback() + } + + waitErr := waitWithTimeout(t, waitCh, waitTimeout) + require.Error(t, waitErr) + require.Contains(t, waitErr.Error(), "graceful upgrade") + + server.server.Close() + + clientErr := waitWithTimeout(t, req, 1*time.Second) + require.Error(t, clientErr, "slow request not terminated after the grace period") + + return waitErr +} + +func makeBootstrap(t *testing.T) (*Bootstrap, *testServer) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(200) + }) + mux.HandleFunc("/slow", func(w http.ResponseWriter, r *http.Request) { + sec, err := strconv.Atoi(r.URL.Query().Get("seconds")) + require.NoError(t, err) + + t.Logf("Serving a slow request for %d seconds", sec) + time.Sleep(time.Duration(sec) * time.Second) + + w.WriteHeader(200) + }) + + s := http.Server{Handler: mux} + u := &mockUpgrader{exit: make(chan struct{})} + + b, err := _new(u, net.Listen, false) + require.NoError(t, err) + + b.StopAction = func() { s.Shutdown(context.Background()) } + + listeners := make(map[string]net.Listener) + start := func(network, address string) Starter { + return func(listen ListenFunc, errors chan<- error) error { + l, err := listen(network, address) + if err != nil { + return err + } + listeners[network] = l + + go func() { + errors <- s.Serve(l) + }() + + return nil + } + } + + for network, address := range map[string]string{ + "tcp": "127.0.0.1:0", + "unix": path.Join(os.TempDir(), "gitaly-test-unix-socket"), + } { + b.RegisterStarter(start(network, address)) + } + + require.NoError(t, b.Start()) + require.Equal(t, 2, len(listeners)) + + // test connection + testAllListeners(t, listeners) + + addr := listeners["tcp"].Addr() + url := fmt.Sprintf("http://%s/", addr.String()) + + return b, &testServer{ + server: &s, + listeners: listeners, + url: url, + } +} + +func testAllListeners(t *testing.T, listeners map[string]net.Listener) { + for network, listener := range listeners { + addr := listener.Addr().String() + + // overriding Client.Transport.Dial we can connect to TCP and UNIX sockets + client := &http.Client{ + Transport: &http.Transport{ + Dial: func(_, _ string) (net.Conn, error) { + return net.Dial(network, addr) + }, + }, + } + + // we don't need a real address because we forced it on Dial + r, err := client.Get("http://fakeHost/") + require.NoError(t, err) + r.Body.Close() + require.Equal(t, 200, r.StatusCode) + } +} diff --git a/internal/bootstrap/server_factory.go b/internal/bootstrap/server_factory.go new file mode 100644 index 000000000..f1bfa7624 --- /dev/null +++ b/internal/bootstrap/server_factory.go @@ -0,0 +1,93 @@ +package bootstrap + +import ( + "net" + "sync" + + log "github.com/sirupsen/logrus" + "gitlab.com/gitlab-org/gitaly/internal/rubyserver" + "gitlab.com/gitlab-org/gitaly/internal/server" + "google.golang.org/grpc" +) + +type serverFactory struct { + ruby *rubyserver.Server + secure, insecure *grpc.Server +} + +// GracefulStoppableServer allows to serve contents on a net.Listener, Stop serving and performing a GracefulStop +type GracefulStoppableServer interface { + GracefulStop() + Stop() + Serve(l net.Listener, secure bool) error +} + +// NewServerFactory initializes a rubyserver and then lazily initializes both secure and insecure grpc.Server +func NewServerFactory() (GracefulStoppableServer, error) { + ruby, err := rubyserver.Start() + if err != nil { + log.Error("start ruby server") + + return nil, err + } + + return &serverFactory{ruby: ruby}, nil +} + +func (s *serverFactory) Stop() { + for _, srv := range s.all() { + srv.Stop() + } + + s.ruby.Stop() +} + +func (s *serverFactory) GracefulStop() { + wg := sync.WaitGroup{} + + for _, srv := range s.all() { + wg.Add(1) + + go func(s *grpc.Server) { + s.GracefulStop() + wg.Done() + }(srv) + } + + wg.Wait() +} + +func (s *serverFactory) Serve(l net.Listener, secure bool) error { + srv := s.get(secure) + + return srv.Serve(l) +} + +func (s *serverFactory) get(secure bool) *grpc.Server { + if secure { + if s.secure == nil { + s.secure = server.NewSecure(s.ruby) + } + + return s.secure + } + + if s.insecure == nil { + s.insecure = server.NewInsecure(s.ruby) + } + + return s.insecure +} + +func (s *serverFactory) all() []*grpc.Server { + var servers []*grpc.Server + if s.secure != nil { + servers = append(servers, s.secure) + } + + if s.insecure != nil { + servers = append(servers, s.insecure) + } + + return servers +} |