diff options
-rw-r--r-- | changelogs/unreleased/more-graceful.yml | 5 | ||||
-rw-r--r-- | cmd/gitaly/bootstrap.go | 205 | ||||
-rw-r--r-- | cmd/gitaly/bootstrap_test.go | 141 | ||||
-rw-r--r-- | cmd/gitaly/main.go | 65 | ||||
-rw-r--r-- | cmd/gitaly/starter_config.go | 45 | ||||
-rw-r--r-- | cmd/gitaly/starter_config_test.go | 38 | ||||
-rw-r--r-- | internal/bootstrap/bootstrap.go | 180 | ||||
-rw-r--r-- | internal/bootstrap/bootstrap_test.go | 324 | ||||
-rw-r--r-- | internal/bootstrap/server_factory.go | 93 |
9 files changed, 732 insertions, 364 deletions
diff --git a/changelogs/unreleased/more-graceful.yml b/changelogs/unreleased/more-graceful.yml new file mode 100644 index 000000000..0b1d1e917 --- /dev/null +++ b/changelogs/unreleased/more-graceful.yml @@ -0,0 +1,5 @@ +--- +title: Wait for all the socket to terminate during a graceful restart +merge_request: 1190 +author: +type: fixed diff --git a/cmd/gitaly/bootstrap.go b/cmd/gitaly/bootstrap.go deleted file mode 100644 index d3b464052..000000000 --- a/cmd/gitaly/bootstrap.go +++ /dev/null @@ -1,205 +0,0 @@ -package main - -import ( - "fmt" - "net" - "os" - "os/signal" - "syscall" - "time" - - "github.com/cloudflare/tableflip" - log "github.com/sirupsen/logrus" - "gitlab.com/gitlab-org/gitaly/internal/config" - "gitlab.com/gitlab-org/gitaly/internal/connectioncounter" - "gitlab.com/gitlab-org/gitaly/internal/rubyserver" - "gitlab.com/gitlab-org/gitaly/internal/server" - "google.golang.org/grpc" -) - -type bootstrap struct { - *tableflip.Upgrader - - insecureListeners []net.Listener - secureListeners []net.Listener - - serversErrors chan error -} - -// newBootstrap performs tableflip initialization -// -// first boot: -// * gitaly starts as usual, we will refer to it as p1 -// * newBootstrap will build a tableflip.Upgrader, we will refer to it as upg -// * sockets and files must be opened with upg.Fds -// * p1 will trap SIGHUP and invoke upg.Upgrade() -// * when ready to accept incoming connections p1 will call upg.Ready() -// * upg.Exit() channel will be closed when an upgrades completed successfully and the process must terminate -// -// graceful upgrade: -// * user replaces gitaly binary and/or config file -// * user sends SIGHUP to p1 -// * p1 will fork and exec the new gitaly, we will refer to it as p2 -// * from now on p1 will ignore other SIGHUP -// * if p2 terminates with a non-zero exit code, SIGHUP handling will be restored -// * p2 will follow the "first boot" sequence but upg.Fds will provide sockets and files from p1, when available -// * when p2 invokes upg.Ready() all the shared file descriptors not claimed by p2 will be closed -// * upg.Exit() channel in p1 will be closed now and p1 can gracefully terminate already accepted connections -// * upgrades cannot starts again if p1 and p2 are both running, an hard termination should be scheduled to overcome -// freezes during a graceful shutdown -func newBootstrap(pidFile string, upgradesEnabled bool) (*bootstrap, error) { - // PIDFile is optional, if provided tableflip will keep it updated - upg, err := tableflip.New(tableflip.Options{PIDFile: pidFile}) - if err != nil { - return nil, err - } - - if upgradesEnabled { - go func() { - sig := make(chan os.Signal, 1) - signal.Notify(sig, syscall.SIGHUP) - - for range sig { - err := upg.Upgrade() - if err != nil { - log.WithError(err).Error("Upgrade failed") - continue - } - - log.Info("Upgrade succeeded") - } - }() - } - - return &bootstrap{Upgrader: upg}, nil -} - -func (b *bootstrap) listen() error { - if socketPath := config.Config.SocketPath; socketPath != "" { - l, err := b.createUnixListener(socketPath) - if err != nil { - return err - } - - log.WithField("address", socketPath).Info("listening on unix socket") - b.insecureListeners = append(b.insecureListeners, l) - } - - if addr := config.Config.ListenAddr; addr != "" { - l, err := b.Fds.Listen("tcp", addr) - if err != nil { - return err - } - - log.WithField("address", addr).Info("listening at tcp address") - b.insecureListeners = append(b.insecureListeners, connectioncounter.New("tcp", l)) - } - - if addr := config.Config.TLSListenAddr; addr != "" { - tlsListener, err := b.Fds.Listen("tcp", addr) - if err != nil { - return err - } - - b.secureListeners = append(b.secureListeners, connectioncounter.New("tls", tlsListener)) - } - - b.serversErrors = make(chan error, len(b.insecureListeners)+len(b.secureListeners)) - - return nil -} - -func (b *bootstrap) prometheusListener() (net.Listener, error) { - log.WithField("address", config.Config.PrometheusListenAddr).Info("starting prometheus listener") - - return b.Fds.Listen("tcp", config.Config.PrometheusListenAddr) -} - -func (b *bootstrap) run() { - signals := []os.Signal{syscall.SIGTERM, syscall.SIGINT} - done := make(chan os.Signal, len(signals)) - signal.Notify(done, signals...) - - ruby, err := rubyserver.Start() - if err != nil { - log.WithError(err).Error("start ruby server") - return - } - defer ruby.Stop() - - if len(b.insecureListeners) > 0 { - insecureServer := server.NewInsecure(ruby) - defer insecureServer.Stop() - - serve(insecureServer, b.insecureListeners, b.Exit(), b.serversErrors) - } - - if len(b.secureListeners) > 0 { - secureServer := server.NewSecure(ruby) - defer secureServer.Stop() - - serve(secureServer, b.secureListeners, b.Exit(), b.serversErrors) - } - - if err := b.Ready(); err != nil { - log.WithError(err).Error("incomplete bootstrap") - return - } - - select { - case <-b.Exit(): - // this is the old process and a graceful upgrade is in progress - // the new process signaled its readiness and we started a graceful stop - // however no further upgrades can be started until this process is running - // we set a grace period and then we force a termination. - b.waitGracePeriod(done) - - err = fmt.Errorf("graceful upgrade") - case s := <-done: - err = fmt.Errorf("received signal %q", s) - case err = <-b.serversErrors: - } - - log.WithError(err).Error("terminating") -} - -func (b *bootstrap) waitGracePeriod(kill <-chan os.Signal) { - log.WithField("graceful_restart_timeout", config.Config.GracefulRestartTimeout).Warn("starting grace period") - - select { - case <-time.After(config.Config.GracefulRestartTimeout): - log.Error("old process stuck on termination. Grace period expired.") - case <-kill: - log.Error("force shutdown") - case <-b.serversErrors: - log.Info("graceful stop completed") - } -} - -func (b *bootstrap) createUnixListener(socketPath string) (net.Listener, error) { - if !b.HasParent() { - // During an update the unix socket exists and if we delete it tableflip will not create a new one - if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) { - return nil, err - } - } - - l, err := b.Fds.Listen("unix", socketPath) - return connectioncounter.New("unix", l), err -} - -func serve(server *grpc.Server, listeners []net.Listener, done <-chan struct{}, errors chan<- error) { - go func() { - <-done - - server.GracefulStop() - }() - - for _, listener := range listeners { - // Must pass the listener as a function argument because there is a race - // between 'go' and 'for'. - go func(l net.Listener) { - errors <- server.Serve(l) - }(listener) - } -} diff --git a/cmd/gitaly/bootstrap_test.go b/cmd/gitaly/bootstrap_test.go deleted file mode 100644 index f6fdf6f90..000000000 --- a/cmd/gitaly/bootstrap_test.go +++ /dev/null @@ -1,141 +0,0 @@ -package main - -import ( - "context" - "fmt" - "io" - "io/ioutil" - "net" - "net/http" - "os" - "path" - "strconv" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -// b is global because tableflip do not allow to init more than one Upgrader per process -var b *bootstrap -var socketPath = path.Join(os.TempDir(), "test-unix-socket") - -// TestMain helps testing bootstrap. -// When invoked directly it behaves like a normal go test, but if a test performs an upgrade the children will -// avoid the test suite and start a pid HTTP server on socketPath -func TestMain(m *testing.M) { - var err error - b, err = newBootstrap("", true) - if err != nil { - panic(err) - } - - if !b.HasParent() { - // Execute test suite if there is no parent. - os.Exit(m.Run()) - } - - // this is a test suite that triggered an upgrade, we are in the children here - l, err := b.createUnixListener(socketPath) - if err != nil { - panic(err) - } - - if err := b.Ready(); err != nil { - panic(err) - } - - done := make(chan struct{}) - srv := startPidServer(done, l) - - select { - case <-done: - //no op - case <-time.After(2 * time.Minute): - srv.Close() - panic("safeguard against zombie process") - } -} - -func TestCreateUnixListener(t *testing.T) { - // simulate a dangling socket - if err := os.Remove(socketPath); err != nil { - require.True(t, os.IsNotExist(err), "cannot delete dangling socket: %v", err) - } - - file, err := os.OpenFile(socketPath, os.O_CREATE, 0755) - require.NoError(t, err) - require.NoError(t, file.Close()) - - require.NoError(t, ioutil.WriteFile(socketPath, nil, 0755)) - - l, err := b.createUnixListener(socketPath) - require.NoError(t, err) - - done := make(chan struct{}) - srv := startPidServer(done, l) - defer srv.Close() - - require.NoError(t, b.Ready(), "not ready") - - myPid, err := askPid() - require.NoError(t, err) - require.Equal(t, os.Getpid(), myPid) - - // we trigger an upgrade and wait for children readiness - require.NoError(t, b.Upgrade(), "upgrade failed") - <-b.Exit() - require.NoError(t, srv.Close()) - <-done - - childPid, err := askPid() - require.NoError(t, err) - require.NotEqual(t, os.Getpid(), childPid, "this request must be handled by the children") -} - -func askPid() (int, error) { - client := &http.Client{ - Transport: &http.Transport{ - DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { - return net.Dial("unix", socketPath) - }, - }, - } - - response, err := client.Get("http://unix") - if err != nil { - return 0, err - } - defer response.Body.Close() - - pid, err := ioutil.ReadAll(response.Body) - if err != nil { - return 0, err - } - - return strconv.Atoi(string(pid)) -} - -// startPidServer starts an HTTP server that returns the current PID, if running on a children it will kill itself after serving -// the first client -func startPidServer(done chan<- struct{}, l net.Listener) *http.Server { - mux := http.NewServeMux() - srv := &http.Server{Handler: mux} - - mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) { - io.WriteString(w, fmt.Sprint(os.Getpid())) - - if b.HasParent() { - time.AfterFunc(1*time.Second, func() { srv.Close() }) - } - }) - - go func() { - if err := srv.Serve(l); err != http.ErrServerClosed { - fmt.Printf("Serve error: %v", err) - } - close(done) - }() - - return srv -} diff --git a/cmd/gitaly/main.go b/cmd/gitaly/main.go index 10149365f..55e0fcc46 100644 --- a/cmd/gitaly/main.go +++ b/cmd/gitaly/main.go @@ -9,6 +9,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" log "github.com/sirupsen/logrus" + "gitlab.com/gitlab-org/gitaly/internal/bootstrap" "gitlab.com/gitlab-org/gitaly/internal/config" "gitlab.com/gitlab-org/gitaly/internal/git" "gitlab.com/gitlab-org/gitaly/internal/linguist" @@ -78,11 +79,10 @@ func main() { // gitaly-wrapper is supposed to set config.EnvUpgradesEnabled in order to enable graceful upgrades _, isWrapped := os.LookupEnv(config.EnvUpgradesEnabled) - b, err := newBootstrap(os.Getenv(config.EnvPidFile), isWrapped) + b, err := bootstrap.New(os.Getenv(config.EnvPidFile), isWrapped) if err != nil { log.WithError(err).Fatal("init bootstrap") } - defer b.Stop() // If invoked with -version if *flagVersion { @@ -111,30 +111,59 @@ func main() { tempdir.StartCleaning() - if err = b.listen(); err != nil { - log.WithError(err).Fatal("bootstrap failed") + log.WithError(run(b)).Error("shutting down") +} + +// Inside here we can use deferred functions. This is needed because +// log.Fatal bypasses deferred functions. +func run(b *bootstrap.Bootstrap) error { + servers, err := bootstrap.NewServerFactory() + if err != nil { + return err } + defer servers.Stop() - if config.Config.PrometheusListenAddr != "" { - l, err := b.prometheusListener() - if err != nil { - log.WithError(err).Fatal("configure prometheus listener") - } + b.StopAction = servers.GracefulStop - promMux := http.NewServeMux() - promMux.Handle("/metrics", promhttp.Handler()) + for _, c := range []starterConfig{ + {unix, config.Config.SocketPath}, + {tcp, config.Config.ListenAddr}, + {tls, config.Config.TLSListenAddr}, + } { + if c.addr == "" { + continue + } - server.AddPprofHandlers(promMux) + b.RegisterStarter(gitalyStarter(c, servers)) + } - go func() { - err = http.Serve(l, promMux) + if addr := config.Config.PrometheusListenAddr; addr != "" { + b.RegisterStarter(func(listen bootstrap.ListenFunc, _ chan<- error) error { + l, err := listen("tcp", addr) if err != nil { - log.WithError(err).Fatal("Unable to serve prometheus") + return err } - }() + + log.WithField("address", addr).Info("starting prometheus listener") + + promMux := http.NewServeMux() + promMux.Handle("/metrics", promhttp.Handler()) + + server.AddPprofHandlers(promMux) + + go func() { + if err := http.Serve(l, promMux); err != nil { + log.WithError(err).Error("Unable to serve prometheus") + } + }() + + return nil + }) } - b.run() + if err := b.Start(); err != nil { + return fmt.Errorf("unable to start the bootstrap: %v", err) + } - log.Fatal("shutting down") + return b.Wait() } diff --git a/cmd/gitaly/starter_config.go b/cmd/gitaly/starter_config.go new file mode 100644 index 000000000..889855815 --- /dev/null +++ b/cmd/gitaly/starter_config.go @@ -0,0 +1,45 @@ +package main + +import ( + "github.com/sirupsen/logrus" + "gitlab.com/gitlab-org/gitaly/internal/bootstrap" +) + +const ( + tcp string = "tcp" + tls string = "tls" + unix string = "unix" +) + +type starterConfig struct { + name, addr string +} + +func (s *starterConfig) isSecure() bool { + return s.name == tls +} + +func (s *starterConfig) family() string { + if s.isSecure() { + return tcp + } + + return s.name +} + +func gitalyStarter(cfg starterConfig, servers bootstrap.GracefulStoppableServer) bootstrap.Starter { + return func(listen bootstrap.ListenFunc, errCh chan<- error) error { + l, err := listen(cfg.family(), cfg.addr) + if err != nil { + return err + } + + logrus.WithField("address", cfg.addr).Infof("listening at %s address", cfg.name) + + go func() { + errCh <- servers.Serve(l, cfg.isSecure()) + }() + + return nil + } +} diff --git a/cmd/gitaly/starter_config_test.go b/cmd/gitaly/starter_config_test.go new file mode 100644 index 000000000..377880f25 --- /dev/null +++ b/cmd/gitaly/starter_config_test.go @@ -0,0 +1,38 @@ +package main + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsSecure(t *testing.T) { + for _, test := range []struct { + name string + secure bool + }{ + {"tcp", false}, + {"unix", false}, + {"tls", true}, + } { + t.Run(test.name, func(t *testing.T) { + conf := starterConfig{name: test.name} + require.Equal(t, test.secure, conf.isSecure()) + }) + } +} + +func TestFamily(t *testing.T) { + for _, test := range []struct { + name, family string + }{ + {"tcp", "tcp"}, + {"unix", "unix"}, + {"tls", "tcp"}, + } { + t.Run(test.name, func(t *testing.T) { + conf := starterConfig{name: test.name} + require.Equal(t, test.family, conf.family()) + }) + } +} diff --git a/internal/bootstrap/bootstrap.go b/internal/bootstrap/bootstrap.go new file mode 100644 index 000000000..bfb58066c --- /dev/null +++ b/internal/bootstrap/bootstrap.go @@ -0,0 +1,180 @@ +package bootstrap + +import ( + "fmt" + "net" + "os" + "os/signal" + "syscall" + "time" + + "github.com/cloudflare/tableflip" + log "github.com/sirupsen/logrus" + "gitlab.com/gitlab-org/gitaly/internal/config" +) + +// Bootstrap handles graceful upgrades +type Bootstrap struct { + // StopAction will be invoked during a graceful stop. It must wait until the shutdown is completed + StopAction func() + + upgrader upgrader + listenFunc ListenFunc + errChan chan error + starters []Starter +} + +type upgrader interface { + Exit() <-chan struct{} + HasParent() bool + Ready() error + Upgrade() error +} + +// New performs tableflip initialization +// pidFile is optional, if provided it will always contain the current process PID +// upgradesEnabled controls the upgrade process on SIGHUP signal +// +// first boot: +// * gitaly starts as usual, we will refer to it as p1 +// * New will build a tableflip.Upgrader, we will refer to it as upg +// * sockets and files must be opened with upg.Fds +// * p1 will trap SIGHUP and invoke upg.Upgrade() +// * when ready to accept incoming connections p1 will call upg.Ready() +// * upg.Exit() channel will be closed when an upgrades completed successfully and the process must terminate +// +// graceful upgrade: +// * user replaces gitaly binary and/or config file +// * user sends SIGHUP to p1 +// * p1 will fork and exec the new gitaly, we will refer to it as p2 +// * from now on p1 will ignore other SIGHUP +// * if p2 terminates with a non-zero exit code, SIGHUP handling will be restored +// * p2 will follow the "first boot" sequence but upg.Fds will provide sockets and files from p1, when available +// * when p2 invokes upg.Ready() all the shared file descriptors not claimed by p2 will be closed +// * upg.Exit() channel in p1 will be closed now and p1 can gracefully terminate already accepted connections +// * upgrades cannot starts again if p1 and p2 are both running, an hard termination should be scheduled to overcome +// freezes during a graceful shutdown +func New(pidFile string, upgradesEnabled bool) (*Bootstrap, error) { + // PIDFile is optional, if provided tableflip will keep it updated + upg, err := tableflip.New(tableflip.Options{PIDFile: pidFile}) + if err != nil { + return nil, err + } + + return _new(upg, upg.Fds.Listen, upgradesEnabled) +} + +func _new(upg upgrader, listenFunc ListenFunc, upgradesEnabled bool) (*Bootstrap, error) { + if upgradesEnabled { + go func() { + sig := make(chan os.Signal, 1) + signal.Notify(sig, syscall.SIGHUP) + + for range sig { + err := upg.Upgrade() + if err != nil { + log.WithError(err).Error("Upgrade failed") + continue + } + + log.Info("Upgrade succeeded") + } + }() + } + + return &Bootstrap{ + upgrader: upg, + listenFunc: listenFunc, + }, nil +} + +// ListenFunc is a net.Listener factory +type ListenFunc func(net, addr string) (net.Listener, error) + +// Starter is function to initialize a net.Listener +// it receives a ListenFunc to be used for net.Listener creation and a chan<- error to signal runtime errors +// It must serve incoming connections asynchronously and signal errors on the channel +// the return value is for setup errors +type Starter func(ListenFunc, chan<- error) error + +func (b *Bootstrap) isFirstBoot() bool { return !b.upgrader.HasParent() } + +// RegisterStarter adds a new starter +func (b *Bootstrap) RegisterStarter(starter Starter) { + b.starters = append(b.starters, starter) +} + +// Start will invoke all the registered starters and wait asynchronously for runtime errors +// in case a Starter fails then the error is returned and the function is aborted +func (b *Bootstrap) Start() error { + b.errChan = make(chan error, len(b.starters)) + + for _, start := range b.starters { + if err := start(b.listen, b.errChan); err != nil { + return err + } + } + + return nil +} + +// Wait will signal process readiness to the parent and than wait for an exit condition +// SIGTERM, SIGINT and a runtime error will trigger an immediate shutdown +// in case of an upgrade there will be a grace period to complete the ongoing requests +func (b *Bootstrap) Wait() error { + signals := []os.Signal{syscall.SIGTERM, syscall.SIGINT} + immediateShutdown := make(chan os.Signal, len(signals)) + signal.Notify(immediateShutdown, signals...) + + if err := b.upgrader.Ready(); err != nil { + return err + } + + var err error + select { + case <-b.upgrader.Exit(): + // this is the old process and a graceful upgrade is in progress + // the new process signaled its readiness and we started a graceful stop + // however no further upgrades can be started until this process is running + // we set a grace period and then we force a termination. + waitError := b.waitGracePeriod(immediateShutdown) + + err = fmt.Errorf("graceful upgrade: %v", waitError) + case s := <-immediateShutdown: + err = fmt.Errorf("received signal %q", s) + case err = <-b.errChan: + } + + return err +} + +func (b *Bootstrap) waitGracePeriod(kill <-chan os.Signal) error { + log.WithField("graceful_restart_timeout", config.Config.GracefulRestartTimeout).Warn("starting grace period") + + allServersDone := make(chan struct{}) + go func() { + if b.StopAction != nil { + b.StopAction() + } + close(allServersDone) + }() + + select { + case <-time.After(config.Config.GracefulRestartTimeout): + return fmt.Errorf("grace period expired") + case <-kill: + return fmt.Errorf("force shutdown") + case <-allServersDone: + return fmt.Errorf("completed") + } +} + +func (b *Bootstrap) listen(network, path string) (net.Listener, error) { + if network == "unix" && b.isFirstBoot() { + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + return nil, err + } + } + + return b.listenFunc(network, path) +} diff --git a/internal/bootstrap/bootstrap_test.go b/internal/bootstrap/bootstrap_test.go new file mode 100644 index 000000000..78cba7ac9 --- /dev/null +++ b/internal/bootstrap/bootstrap_test.go @@ -0,0 +1,324 @@ +package bootstrap + +import ( + "context" + "fmt" + "io/ioutil" + "net" + "net/http" + "os" + "path" + "strconv" + "syscall" + "testing" + "time" + + "gitlab.com/gitlab-org/gitaly/internal/config" + + "github.com/stretchr/testify/require" +) + +var testConfigGracefulRestartTimeout = 2 * time.Second + +type mockUpgrader struct { + exit chan struct{} + hasParent bool +} + +func (m *mockUpgrader) Exit() <-chan struct{} { + return m.exit +} + +func (m *mockUpgrader) HasParent() bool { + return m.hasParent +} + +func (m *mockUpgrader) Ready() error { return nil } + +func (m *mockUpgrader) Upgrade() error { + // to upgrade we close the exit channel + close(m.exit) + return nil +} + +type testServer struct { + server *http.Server + listeners map[string]net.Listener + url string +} + +func (s *testServer) slowRequest(duration time.Duration) <-chan error { + done := make(chan error) + + go func() { + r, err := http.Get(fmt.Sprintf("%sslow?seconds=%d", s.url, int(duration.Seconds()))) + if r != nil { + r.Body.Close() + } + + done <- err + }() + + return done +} + +func TestCreateUnixListener(t *testing.T) { + socketPath := path.Join(os.TempDir(), "gitaly-test-unix-socket") + if err := os.Remove(socketPath); err != nil { + require.True(t, os.IsNotExist(err), "cannot delete dangling socket: %v", err) + } + + // simulate a dangling socket + require.NoError(t, ioutil.WriteFile(socketPath, nil, 0755)) + + listen := func(network, addr string) (net.Listener, error) { + require.Equal(t, "unix", network) + require.Equal(t, socketPath, addr) + + return net.Listen(network, addr) + } + u := &mockUpgrader{} + b, err := _new(u, listen, false) + require.NoError(t, err) + + // first boot + l, err := b.listen("unix", socketPath) + require.NoError(t, err, "failed to bind on first boot") + require.NoError(t, l.Close()) + + // simulate binding during an upgrade + u.hasParent = true + l, err = b.listen("unix", socketPath) + require.NoError(t, err, "failed to bind on upgrade") + require.NoError(t, l.Close()) +} + +func waitWithTimeout(t *testing.T, waitCh <-chan error, timeout time.Duration) error { + select { + case <-time.After(timeout): + t.Fatal("time out waiting for waitCh") + case waitErr := <-waitCh: + return waitErr + } + + return nil +} + +func TestImmediateTerminationOnSocketError(t *testing.T) { + b, server := makeBootstrap(t) + + waitCh := make(chan error) + go func() { waitCh <- b.Wait() }() + + require.NoError(t, server.listeners["tcp"].Close(), "Closing first listener") + + err := waitWithTimeout(t, waitCh, 1*time.Second) + require.Error(t, err) + require.Contains(t, err.Error(), "use of closed network connection") +} + +func TestImmediateTerminationOnSignal(t *testing.T) { + for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} { + t.Run(sig.String(), func(t *testing.T) { + b, server := makeBootstrap(t) + + done := server.slowRequest(3 * time.Minute) + + waitCh := make(chan error) + go func() { waitCh <- b.Wait() }() + + // make sure we are inside b.Wait() or we'll kill the test suite + time.Sleep(100 * time.Millisecond) + + self, err := os.FindProcess(os.Getpid()) + require.NoError(t, err) + require.NoError(t, self.Signal(sig)) + + waitErr := waitWithTimeout(t, waitCh, 1*time.Second) + require.Error(t, waitErr) + require.Contains(t, waitErr.Error(), "received signal") + require.Contains(t, waitErr.Error(), sig.String()) + + server.server.Close() + + require.Error(t, <-done) + }) + } +} + +func TestGracefulTerminationStuck(t *testing.T) { + b, server := makeBootstrap(t) + + err := testGracefulUpdate(t, server, b, testConfigGracefulRestartTimeout+(1*time.Second), nil) + require.Contains(t, err.Error(), "grace period expired") +} + +func TestGracefulTerminationWithSignals(t *testing.T) { + self, err := os.FindProcess(os.Getpid()) + require.NoError(t, err) + + for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} { + t.Run(sig.String(), func(t *testing.T) { + b, server := makeBootstrap(t) + + err := testGracefulUpdate(t, server, b, 1*time.Second, func() { + require.NoError(t, self.Signal(sig)) + }) + require.Contains(t, err.Error(), "force shutdown") + }) + } +} + +func TestGracefulTerminationServerErrors(t *testing.T) { + b, server := makeBootstrap(t) + + done := make(chan error, 1) + // This is a simulation of receiving a listener error during waitGracePeriod + b.StopAction = func() { + // we close the unix listener in order to test that the shutdown will not fail, but it keep waiting for the TCP request + require.NoError(t, server.listeners["unix"].Close()) + + // we start a new TCP request that if faster than the grace period + req := server.slowRequest(config.Config.GracefulRestartTimeout / 2) + done <- <-req + close(done) + + server.server.Shutdown(context.Background()) + } + + err := testGracefulUpdate(t, server, b, testConfigGracefulRestartTimeout+(1*time.Second), nil) + require.Contains(t, err.Error(), "grace period expired") + + require.NoError(t, <-done) +} + +func TestGracefulTermination(t *testing.T) { + b, server := makeBootstrap(t) + + // Using server.Close we bypass the graceful shutdown faking a completed shutdown + b.StopAction = func() { server.server.Close() } + + err := testGracefulUpdate(t, server, b, 1*time.Second, nil) + require.Contains(t, err.Error(), "completed") +} + +func testGracefulUpdate(t *testing.T, server *testServer, b *Bootstrap, waitTimeout time.Duration, duringGracePeriodCallback func()) error { + defer func(oldVal time.Duration) { + config.Config.GracefulRestartTimeout = oldVal + }(config.Config.GracefulRestartTimeout) + config.Config.GracefulRestartTimeout = testConfigGracefulRestartTimeout + + waitCh := make(chan error) + go func() { waitCh <- b.Wait() }() + + // Start a slow request to keep the old server from shutting down immediately. + req := server.slowRequest(2 * config.Config.GracefulRestartTimeout) + + // make sure slow request is being handled + time.Sleep(100 * time.Millisecond) + + // Simulate an upgrade request after entering into the blocking b.Wait() and during the slowRequest execution + b.upgrader.Upgrade() + + if duringGracePeriodCallback != nil { + // make sure we are on the grace period + time.Sleep(100 * time.Millisecond) + + duringGracePeriodCallback() + } + + waitErr := waitWithTimeout(t, waitCh, waitTimeout) + require.Error(t, waitErr) + require.Contains(t, waitErr.Error(), "graceful upgrade") + + server.server.Close() + + clientErr := waitWithTimeout(t, req, 1*time.Second) + require.Error(t, clientErr, "slow request not terminated after the grace period") + + return waitErr +} + +func makeBootstrap(t *testing.T) (*Bootstrap, *testServer) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(200) + }) + mux.HandleFunc("/slow", func(w http.ResponseWriter, r *http.Request) { + sec, err := strconv.Atoi(r.URL.Query().Get("seconds")) + require.NoError(t, err) + + t.Logf("Serving a slow request for %d seconds", sec) + time.Sleep(time.Duration(sec) * time.Second) + + w.WriteHeader(200) + }) + + s := http.Server{Handler: mux} + u := &mockUpgrader{exit: make(chan struct{})} + + b, err := _new(u, net.Listen, false) + require.NoError(t, err) + + b.StopAction = func() { s.Shutdown(context.Background()) } + + listeners := make(map[string]net.Listener) + start := func(network, address string) Starter { + return func(listen ListenFunc, errors chan<- error) error { + l, err := listen(network, address) + if err != nil { + return err + } + listeners[network] = l + + go func() { + errors <- s.Serve(l) + }() + + return nil + } + } + + for network, address := range map[string]string{ + "tcp": "127.0.0.1:0", + "unix": path.Join(os.TempDir(), "gitaly-test-unix-socket"), + } { + b.RegisterStarter(start(network, address)) + } + + require.NoError(t, b.Start()) + require.Equal(t, 2, len(listeners)) + + // test connection + testAllListeners(t, listeners) + + addr := listeners["tcp"].Addr() + url := fmt.Sprintf("http://%s/", addr.String()) + + return b, &testServer{ + server: &s, + listeners: listeners, + url: url, + } +} + +func testAllListeners(t *testing.T, listeners map[string]net.Listener) { + for network, listener := range listeners { + addr := listener.Addr().String() + + // overriding Client.Transport.Dial we can connect to TCP and UNIX sockets + client := &http.Client{ + Transport: &http.Transport{ + Dial: func(_, _ string) (net.Conn, error) { + return net.Dial(network, addr) + }, + }, + } + + // we don't need a real address because we forced it on Dial + r, err := client.Get("http://fakeHost/") + require.NoError(t, err) + r.Body.Close() + require.Equal(t, 200, r.StatusCode) + } +} diff --git a/internal/bootstrap/server_factory.go b/internal/bootstrap/server_factory.go new file mode 100644 index 000000000..f1bfa7624 --- /dev/null +++ b/internal/bootstrap/server_factory.go @@ -0,0 +1,93 @@ +package bootstrap + +import ( + "net" + "sync" + + log "github.com/sirupsen/logrus" + "gitlab.com/gitlab-org/gitaly/internal/rubyserver" + "gitlab.com/gitlab-org/gitaly/internal/server" + "google.golang.org/grpc" +) + +type serverFactory struct { + ruby *rubyserver.Server + secure, insecure *grpc.Server +} + +// GracefulStoppableServer allows to serve contents on a net.Listener, Stop serving and performing a GracefulStop +type GracefulStoppableServer interface { + GracefulStop() + Stop() + Serve(l net.Listener, secure bool) error +} + +// NewServerFactory initializes a rubyserver and then lazily initializes both secure and insecure grpc.Server +func NewServerFactory() (GracefulStoppableServer, error) { + ruby, err := rubyserver.Start() + if err != nil { + log.Error("start ruby server") + + return nil, err + } + + return &serverFactory{ruby: ruby}, nil +} + +func (s *serverFactory) Stop() { + for _, srv := range s.all() { + srv.Stop() + } + + s.ruby.Stop() +} + +func (s *serverFactory) GracefulStop() { + wg := sync.WaitGroup{} + + for _, srv := range s.all() { + wg.Add(1) + + go func(s *grpc.Server) { + s.GracefulStop() + wg.Done() + }(srv) + } + + wg.Wait() +} + +func (s *serverFactory) Serve(l net.Listener, secure bool) error { + srv := s.get(secure) + + return srv.Serve(l) +} + +func (s *serverFactory) get(secure bool) *grpc.Server { + if secure { + if s.secure == nil { + s.secure = server.NewSecure(s.ruby) + } + + return s.secure + } + + if s.insecure == nil { + s.insecure = server.NewInsecure(s.ruby) + } + + return s.insecure +} + +func (s *serverFactory) all() []*grpc.Server { + var servers []*grpc.Server + if s.secure != nil { + servers = append(servers, s.secure) + } + + if s.insecure != nil { + servers = append(servers, s.insecure) + } + + return servers +} |