Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlessio Caiazza <acaiazza@gitlab.com>2019-06-04 12:02:54 +0300
committerJacob Vosmaer <jacob@gitlab.com>2019-06-04 12:02:54 +0300
commit8065040919302aa6cdefc978078a95610a0747f5 (patch)
treeae4660e3a808b001f0c7e61b5428bc6338ba1776 /internal/bootstrap
parent9855c59db137a607a60366d1ad23a3834a2e9de7 (diff)
Wait for all the sockets on graceful restart
We terminate gitaly on the first failure from a listening socket, but on a graceful restart we must wait for all the sockets to properly terminate their active connections. This is a complete refactoring that introduces the bootstrap package with proper test coverage.
Diffstat (limited to 'internal/bootstrap')
-rw-r--r--internal/bootstrap/bootstrap.go180
-rw-r--r--internal/bootstrap/bootstrap_test.go324
-rw-r--r--internal/bootstrap/server_factory.go93
3 files changed, 597 insertions, 0 deletions
diff --git a/internal/bootstrap/bootstrap.go b/internal/bootstrap/bootstrap.go
new file mode 100644
index 000000000..bfb58066c
--- /dev/null
+++ b/internal/bootstrap/bootstrap.go
@@ -0,0 +1,180 @@
+package bootstrap
+
+import (
+ "fmt"
+ "net"
+ "os"
+ "os/signal"
+ "syscall"
+ "time"
+
+ "github.com/cloudflare/tableflip"
+ log "github.com/sirupsen/logrus"
+ "gitlab.com/gitlab-org/gitaly/internal/config"
+)
+
+// Bootstrap handles graceful upgrades
+type Bootstrap struct {
+ // StopAction will be invoked during a graceful stop. It must wait until the shutdown is completed
+ StopAction func()
+
+ upgrader upgrader
+ listenFunc ListenFunc
+ errChan chan error
+ starters []Starter
+}
+
// upgrader is the set of upgrade operations Bootstrap depends on. Keeping it
// as an interface allows tests to substitute a mock implementation.
type upgrader interface {
	// HasParent reports whether this process has a parent process, i.e. it
	// is not a first boot.
	HasParent() bool
	// Ready signals that this process is ready.
	Ready() error
	// Upgrade requests a graceful upgrade.
	Upgrade() error
	// Exit returns the channel that Wait watches to start the graceful stop.
	Exit() <-chan struct{}
}
+
+// New performs tableflip initialization
+// pidFile is optional, if provided it will always contain the current process PID
+// upgradesEnabled controls the upgrade process on SIGHUP signal
+//
+// first boot:
+// * gitaly starts as usual, we will refer to it as p1
+// * New will build a tableflip.Upgrader, we will refer to it as upg
+// * sockets and files must be opened with upg.Fds
+// * p1 will trap SIGHUP and invoke upg.Upgrade()
+// * when ready to accept incoming connections p1 will call upg.Ready()
+// * upg.Exit() channel will be closed when an upgrades completed successfully and the process must terminate
+//
+// graceful upgrade:
+// * user replaces gitaly binary and/or config file
+// * user sends SIGHUP to p1
+// * p1 will fork and exec the new gitaly, we will refer to it as p2
+// * from now on p1 will ignore other SIGHUP
+// * if p2 terminates with a non-zero exit code, SIGHUP handling will be restored
+// * p2 will follow the "first boot" sequence but upg.Fds will provide sockets and files from p1, when available
+// * when p2 invokes upg.Ready() all the shared file descriptors not claimed by p2 will be closed
+// * upg.Exit() channel in p1 will be closed now and p1 can gracefully terminate already accepted connections
+// * upgrades cannot starts again if p1 and p2 are both running, an hard termination should be scheduled to overcome
+// freezes during a graceful shutdown
+func New(pidFile string, upgradesEnabled bool) (*Bootstrap, error) {
+ // PIDFile is optional, if provided tableflip will keep it updated
+ upg, err := tableflip.New(tableflip.Options{PIDFile: pidFile})
+ if err != nil {
+ return nil, err
+ }
+
+ return _new(upg, upg.Fds.Listen, upgradesEnabled)
+}
+
+func _new(upg upgrader, listenFunc ListenFunc, upgradesEnabled bool) (*Bootstrap, error) {
+ if upgradesEnabled {
+ go func() {
+ sig := make(chan os.Signal, 1)
+ signal.Notify(sig, syscall.SIGHUP)
+
+ for range sig {
+ err := upg.Upgrade()
+ if err != nil {
+ log.WithError(err).Error("Upgrade failed")
+ continue
+ }
+
+ log.Info("Upgrade succeeded")
+ }
+ }()
+ }
+
+ return &Bootstrap{
+ upgrader: upg,
+ listenFunc: listenFunc,
+ }, nil
+}
+
// ListenFunc is a net.Listener factory. It receives the network name and the
// address to listen on, mirroring the signature of net.Listen.
type ListenFunc func(network, addr string) (net.Listener, error)
+
+// Starter is function to initialize a net.Listener
+// it receives a ListenFunc to be used for net.Listener creation and a chan<- error to signal runtime errors
+// It must serve incoming connections asynchronously and signal errors on the channel
+// the return value is for setup errors
+type Starter func(ListenFunc, chan<- error) error
+
+func (b *Bootstrap) isFirstBoot() bool { return !b.upgrader.HasParent() }
+
+// RegisterStarter adds a new starter
+func (b *Bootstrap) RegisterStarter(starter Starter) {
+ b.starters = append(b.starters, starter)
+}
+
+// Start will invoke all the registered starters and wait asynchronously for runtime errors
+// in case a Starter fails then the error is returned and the function is aborted
+func (b *Bootstrap) Start() error {
+ b.errChan = make(chan error, len(b.starters))
+
+ for _, start := range b.starters {
+ if err := start(b.listen, b.errChan); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// Wait will signal process readiness to the parent and than wait for an exit condition
+// SIGTERM, SIGINT and a runtime error will trigger an immediate shutdown
+// in case of an upgrade there will be a grace period to complete the ongoing requests
+func (b *Bootstrap) Wait() error {
+ signals := []os.Signal{syscall.SIGTERM, syscall.SIGINT}
+ immediateShutdown := make(chan os.Signal, len(signals))
+ signal.Notify(immediateShutdown, signals...)
+
+ if err := b.upgrader.Ready(); err != nil {
+ return err
+ }
+
+ var err error
+ select {
+ case <-b.upgrader.Exit():
+ // this is the old process and a graceful upgrade is in progress
+ // the new process signaled its readiness and we started a graceful stop
+ // however no further upgrades can be started until this process is running
+ // we set a grace period and then we force a termination.
+ waitError := b.waitGracePeriod(immediateShutdown)
+
+ err = fmt.Errorf("graceful upgrade: %v", waitError)
+ case s := <-immediateShutdown:
+ err = fmt.Errorf("received signal %q", s)
+ case err = <-b.errChan:
+ }
+
+ return err
+}
+
+func (b *Bootstrap) waitGracePeriod(kill <-chan os.Signal) error {
+ log.WithField("graceful_restart_timeout", config.Config.GracefulRestartTimeout).Warn("starting grace period")
+
+ allServersDone := make(chan struct{})
+ go func() {
+ if b.StopAction != nil {
+ b.StopAction()
+ }
+ close(allServersDone)
+ }()
+
+ select {
+ case <-time.After(config.Config.GracefulRestartTimeout):
+ return fmt.Errorf("grace period expired")
+ case <-kill:
+ return fmt.Errorf("force shutdown")
+ case <-allServersDone:
+ return fmt.Errorf("completed")
+ }
+}
+
+func (b *Bootstrap) listen(network, path string) (net.Listener, error) {
+ if network == "unix" && b.isFirstBoot() {
+ if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
+ return nil, err
+ }
+ }
+
+ return b.listenFunc(network, path)
+}
diff --git a/internal/bootstrap/bootstrap_test.go b/internal/bootstrap/bootstrap_test.go
new file mode 100644
index 000000000..78cba7ac9
--- /dev/null
+++ b/internal/bootstrap/bootstrap_test.go
@@ -0,0 +1,324 @@
+package bootstrap
+
+import (
+ "context"
+ "fmt"
+ "io/ioutil"
+ "net"
+ "net/http"
+ "os"
+ "path"
+ "strconv"
+ "syscall"
+ "testing"
+ "time"
+
+ "gitlab.com/gitlab-org/gitaly/internal/config"
+
+ "github.com/stretchr/testify/require"
+)
+
// testConfigGracefulRestartTimeout is the graceful restart timeout installed
// into the configuration while a graceful update test runs. It is never
// modified, hence a constant.
const testConfigGracefulRestartTimeout = 2 * time.Second
+
// mockUpgrader is a test double for the upgrader interface.
type mockUpgrader struct {
	// hasParent is the value reported by HasParent.
	hasParent bool
	// exit is the channel handed out by Exit and closed by Upgrade.
	exit chan struct{}
}

// Exit returns the exit channel as receive-only.
func (u *mockUpgrader) Exit() <-chan struct{} {
	return u.exit
}

// HasParent returns the configured hasParent flag.
func (u *mockUpgrader) HasParent() bool {
	return u.hasParent
}

// Ready always succeeds.
func (u *mockUpgrader) Ready() error {
	return nil
}

// Upgrade simulates an upgrade by closing the exit channel.
func (u *mockUpgrader) Upgrade() error {
	close(u.exit)

	return nil
}
+
// testServer bundles the HTTP server used by the tests with its listeners.
type testServer struct {
	// url is the base URL (with trailing slash) used to issue requests.
	url string
	// server is the HTTP server serving on all listeners.
	server *http.Server
	// listeners maps a network name to the listener created for it.
	listeners map[string]net.Listener
}
+
+func (s *testServer) slowRequest(duration time.Duration) <-chan error {
+ done := make(chan error)
+
+ go func() {
+ r, err := http.Get(fmt.Sprintf("%sslow?seconds=%d", s.url, int(duration.Seconds())))
+ if r != nil {
+ r.Body.Close()
+ }
+
+ done <- err
+ }()
+
+ return done
+}
+
+func TestCreateUnixListener(t *testing.T) {
+ socketPath := path.Join(os.TempDir(), "gitaly-test-unix-socket")
+ if err := os.Remove(socketPath); err != nil {
+ require.True(t, os.IsNotExist(err), "cannot delete dangling socket: %v", err)
+ }
+
+ // simulate a dangling socket
+ require.NoError(t, ioutil.WriteFile(socketPath, nil, 0755))
+
+ listen := func(network, addr string) (net.Listener, error) {
+ require.Equal(t, "unix", network)
+ require.Equal(t, socketPath, addr)
+
+ return net.Listen(network, addr)
+ }
+ u := &mockUpgrader{}
+ b, err := _new(u, listen, false)
+ require.NoError(t, err)
+
+ // first boot
+ l, err := b.listen("unix", socketPath)
+ require.NoError(t, err, "failed to bind on first boot")
+ require.NoError(t, l.Close())
+
+ // simulate binding during an upgrade
+ u.hasParent = true
+ l, err = b.listen("unix", socketPath)
+ require.NoError(t, err, "failed to bind on upgrade")
+ require.NoError(t, l.Close())
+}
+
// waitWithTimeout returns the next value received from waitCh. If nothing
// arrives within timeout, the test is failed immediately.
func waitWithTimeout(t *testing.T, waitCh <-chan error, timeout time.Duration) error {
	// Report a timeout at the caller's line rather than inside this helper.
	t.Helper()

	select {
	case waitErr := <-waitCh:
		return waitErr
	case <-time.After(timeout):
		t.Fatal("time out waiting for waitCh")

		// Not reached: t.Fatal stops the calling goroutine. The return only
		// satisfies the compiler.
		return nil
	}
}
+
+func TestImmediateTerminationOnSocketError(t *testing.T) {
+ b, server := makeBootstrap(t)
+
+ waitCh := make(chan error)
+ go func() { waitCh <- b.Wait() }()
+
+ require.NoError(t, server.listeners["tcp"].Close(), "Closing first listener")
+
+ err := waitWithTimeout(t, waitCh, 1*time.Second)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "use of closed network connection")
+}
+
+func TestImmediateTerminationOnSignal(t *testing.T) {
+ for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} {
+ t.Run(sig.String(), func(t *testing.T) {
+ b, server := makeBootstrap(t)
+
+ done := server.slowRequest(3 * time.Minute)
+
+ waitCh := make(chan error)
+ go func() { waitCh <- b.Wait() }()
+
+ // make sure we are inside b.Wait() or we'll kill the test suite
+ time.Sleep(100 * time.Millisecond)
+
+ self, err := os.FindProcess(os.Getpid())
+ require.NoError(t, err)
+ require.NoError(t, self.Signal(sig))
+
+ waitErr := waitWithTimeout(t, waitCh, 1*time.Second)
+ require.Error(t, waitErr)
+ require.Contains(t, waitErr.Error(), "received signal")
+ require.Contains(t, waitErr.Error(), sig.String())
+
+ server.server.Close()
+
+ require.Error(t, <-done)
+ })
+ }
+}
+
+func TestGracefulTerminationStuck(t *testing.T) {
+ b, server := makeBootstrap(t)
+
+ err := testGracefulUpdate(t, server, b, testConfigGracefulRestartTimeout+(1*time.Second), nil)
+ require.Contains(t, err.Error(), "grace period expired")
+}
+
+func TestGracefulTerminationWithSignals(t *testing.T) {
+ self, err := os.FindProcess(os.Getpid())
+ require.NoError(t, err)
+
+ for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} {
+ t.Run(sig.String(), func(t *testing.T) {
+ b, server := makeBootstrap(t)
+
+ err := testGracefulUpdate(t, server, b, 1*time.Second, func() {
+ require.NoError(t, self.Signal(sig))
+ })
+ require.Contains(t, err.Error(), "force shutdown")
+ })
+ }
+}
+
+func TestGracefulTerminationServerErrors(t *testing.T) {
+ b, server := makeBootstrap(t)
+
+ done := make(chan error, 1)
+ // This is a simulation of receiving a listener error during waitGracePeriod
+ b.StopAction = func() {
+ // we close the unix listener in order to test that the shutdown will not fail, but it keep waiting for the TCP request
+ require.NoError(t, server.listeners["unix"].Close())
+
+ // we start a new TCP request that if faster than the grace period
+ req := server.slowRequest(config.Config.GracefulRestartTimeout / 2)
+ done <- <-req
+ close(done)
+
+ server.server.Shutdown(context.Background())
+ }
+
+ err := testGracefulUpdate(t, server, b, testConfigGracefulRestartTimeout+(1*time.Second), nil)
+ require.Contains(t, err.Error(), "grace period expired")
+
+ require.NoError(t, <-done)
+}
+
+func TestGracefulTermination(t *testing.T) {
+ b, server := makeBootstrap(t)
+
+ // Using server.Close we bypass the graceful shutdown faking a completed shutdown
+ b.StopAction = func() { server.server.Close() }
+
+ err := testGracefulUpdate(t, server, b, 1*time.Second, nil)
+ require.Contains(t, err.Error(), "completed")
+}
+
+func testGracefulUpdate(t *testing.T, server *testServer, b *Bootstrap, waitTimeout time.Duration, duringGracePeriodCallback func()) error {
+ defer func(oldVal time.Duration) {
+ config.Config.GracefulRestartTimeout = oldVal
+ }(config.Config.GracefulRestartTimeout)
+ config.Config.GracefulRestartTimeout = testConfigGracefulRestartTimeout
+
+ waitCh := make(chan error)
+ go func() { waitCh <- b.Wait() }()
+
+ // Start a slow request to keep the old server from shutting down immediately.
+ req := server.slowRequest(2 * config.Config.GracefulRestartTimeout)
+
+ // make sure slow request is being handled
+ time.Sleep(100 * time.Millisecond)
+
+ // Simulate an upgrade request after entering into the blocking b.Wait() and during the slowRequest execution
+ b.upgrader.Upgrade()
+
+ if duringGracePeriodCallback != nil {
+ // make sure we are on the grace period
+ time.Sleep(100 * time.Millisecond)
+
+ duringGracePeriodCallback()
+ }
+
+ waitErr := waitWithTimeout(t, waitCh, waitTimeout)
+ require.Error(t, waitErr)
+ require.Contains(t, waitErr.Error(), "graceful upgrade")
+
+ server.server.Close()
+
+ clientErr := waitWithTimeout(t, req, 1*time.Second)
+ require.Error(t, clientErr, "slow request not terminated after the grace period")
+
+ return waitErr
+}
+
+func makeBootstrap(t *testing.T) (*Bootstrap, *testServer) {
+ mux := http.NewServeMux()
+ mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(200)
+ })
+ mux.HandleFunc("/slow", func(w http.ResponseWriter, r *http.Request) {
+ sec, err := strconv.Atoi(r.URL.Query().Get("seconds"))
+ require.NoError(t, err)
+
+ t.Logf("Serving a slow request for %d seconds", sec)
+ time.Sleep(time.Duration(sec) * time.Second)
+
+ w.WriteHeader(200)
+ })
+
+ s := http.Server{Handler: mux}
+ u := &mockUpgrader{exit: make(chan struct{})}
+
+ b, err := _new(u, net.Listen, false)
+ require.NoError(t, err)
+
+ b.StopAction = func() { s.Shutdown(context.Background()) }
+
+ listeners := make(map[string]net.Listener)
+ start := func(network, address string) Starter {
+ return func(listen ListenFunc, errors chan<- error) error {
+ l, err := listen(network, address)
+ if err != nil {
+ return err
+ }
+ listeners[network] = l
+
+ go func() {
+ errors <- s.Serve(l)
+ }()
+
+ return nil
+ }
+ }
+
+ for network, address := range map[string]string{
+ "tcp": "127.0.0.1:0",
+ "unix": path.Join(os.TempDir(), "gitaly-test-unix-socket"),
+ } {
+ b.RegisterStarter(start(network, address))
+ }
+
+ require.NoError(t, b.Start())
+ require.Equal(t, 2, len(listeners))
+
+ // test connection
+ testAllListeners(t, listeners)
+
+ addr := listeners["tcp"].Addr()
+ url := fmt.Sprintf("http://%s/", addr.String())
+
+ return b, &testServer{
+ server: &s,
+ listeners: listeners,
+ url: url,
+ }
+}
+
+func testAllListeners(t *testing.T, listeners map[string]net.Listener) {
+ for network, listener := range listeners {
+ addr := listener.Addr().String()
+
+ // overriding Client.Transport.Dial we can connect to TCP and UNIX sockets
+ client := &http.Client{
+ Transport: &http.Transport{
+ Dial: func(_, _ string) (net.Conn, error) {
+ return net.Dial(network, addr)
+ },
+ },
+ }
+
+ // we don't need a real address because we forced it on Dial
+ r, err := client.Get("http://fakeHost/")
+ require.NoError(t, err)
+ r.Body.Close()
+ require.Equal(t, 200, r.StatusCode)
+ }
+}
diff --git a/internal/bootstrap/server_factory.go b/internal/bootstrap/server_factory.go
new file mode 100644
index 000000000..f1bfa7624
--- /dev/null
+++ b/internal/bootstrap/server_factory.go
@@ -0,0 +1,93 @@
+package bootstrap
+
+import (
+ "net"
+ "sync"
+
+ log "github.com/sirupsen/logrus"
+ "gitlab.com/gitlab-org/gitaly/internal/rubyserver"
+ "gitlab.com/gitlab-org/gitaly/internal/server"
+ "google.golang.org/grpc"
+)
+
+type serverFactory struct {
+ ruby *rubyserver.Server
+ secure, insecure *grpc.Server
+}
+
// GracefulStoppableServer can serve on a net.Listener and be stopped either
// immediately or gracefully.
type GracefulStoppableServer interface {
	// Serve serves on l; secure selects between the secure and the insecure
	// server.
	Serve(l net.Listener, secure bool) error
	// Stop stops serving.
	Stop()
	// GracefulStop performs a graceful stop.
	GracefulStop()
}
+
+// NewServerFactory initializes a rubyserver and then lazily initializes both secure and insecure grpc.Server
+func NewServerFactory() (GracefulStoppableServer, error) {
+ ruby, err := rubyserver.Start()
+ if err != nil {
+ log.Error("start ruby server")
+
+ return nil, err
+ }
+
+ return &serverFactory{ruby: ruby}, nil
+}
+
+func (s *serverFactory) Stop() {
+ for _, srv := range s.all() {
+ srv.Stop()
+ }
+
+ s.ruby.Stop()
+}
+
+func (s *serverFactory) GracefulStop() {
+ wg := sync.WaitGroup{}
+
+ for _, srv := range s.all() {
+ wg.Add(1)
+
+ go func(s *grpc.Server) {
+ s.GracefulStop()
+ wg.Done()
+ }(srv)
+ }
+
+ wg.Wait()
+}
+
+func (s *serverFactory) Serve(l net.Listener, secure bool) error {
+ srv := s.get(secure)
+
+ return srv.Serve(l)
+}
+
+func (s *serverFactory) get(secure bool) *grpc.Server {
+ if secure {
+ if s.secure == nil {
+ s.secure = server.NewSecure(s.ruby)
+ }
+
+ return s.secure
+ }
+
+ if s.insecure == nil {
+ s.insecure = server.NewInsecure(s.ruby)
+ }
+
+ return s.insecure
+}
+
+func (s *serverFactory) all() []*grpc.Server {
+ var servers []*grpc.Server
+ if s.secure != nil {
+ servers = append(servers, s.secure)
+ }
+
+ if s.insecure != nil {
+ servers = append(servers, s.insecure)
+ }
+
+ return servers
+}