Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJacob Vosmaer <jacob@gitlab.com>2019-06-04 12:02:54 +0300
committerJacob Vosmaer <jacob@gitlab.com>2019-06-04 12:02:54 +0300
commit95d95ecdcb7e84223c9fa9ac2a37cbd6bd246fae (patch)
treeae4660e3a808b001f0c7e61b5428bc6338ba1776
parent9855c59db137a607a60366d1ad23a3834a2e9de7 (diff)
parent8065040919302aa6cdefc978078a95610a0747f5 (diff)
Merge branch 'more-graceful' into 'master'
Wait for all the sockets on graceful restart Closes #1610 See merge request gitlab-org/gitaly!1190
-rw-r--r--changelogs/unreleased/more-graceful.yml5
-rw-r--r--cmd/gitaly/bootstrap.go205
-rw-r--r--cmd/gitaly/bootstrap_test.go141
-rw-r--r--cmd/gitaly/main.go65
-rw-r--r--cmd/gitaly/starter_config.go45
-rw-r--r--cmd/gitaly/starter_config_test.go38
-rw-r--r--internal/bootstrap/bootstrap.go180
-rw-r--r--internal/bootstrap/bootstrap_test.go324
-rw-r--r--internal/bootstrap/server_factory.go93
9 files changed, 732 insertions, 364 deletions
diff --git a/changelogs/unreleased/more-graceful.yml b/changelogs/unreleased/more-graceful.yml
new file mode 100644
index 000000000..0b1d1e917
--- /dev/null
+++ b/changelogs/unreleased/more-graceful.yml
@@ -0,0 +1,5 @@
+---
+title: Wait for all the sockets to terminate during a graceful restart
+merge_request: 1190
+author:
+type: fixed
diff --git a/cmd/gitaly/bootstrap.go b/cmd/gitaly/bootstrap.go
deleted file mode 100644
index d3b464052..000000000
--- a/cmd/gitaly/bootstrap.go
+++ /dev/null
@@ -1,205 +0,0 @@
-package main
-
-import (
- "fmt"
- "net"
- "os"
- "os/signal"
- "syscall"
- "time"
-
- "github.com/cloudflare/tableflip"
- log "github.com/sirupsen/logrus"
- "gitlab.com/gitlab-org/gitaly/internal/config"
- "gitlab.com/gitlab-org/gitaly/internal/connectioncounter"
- "gitlab.com/gitlab-org/gitaly/internal/rubyserver"
- "gitlab.com/gitlab-org/gitaly/internal/server"
- "google.golang.org/grpc"
-)
-
-type bootstrap struct {
- *tableflip.Upgrader
-
- insecureListeners []net.Listener
- secureListeners []net.Listener
-
- serversErrors chan error
-}
-
-// newBootstrap performs tableflip initialization
-//
-// first boot:
-// * gitaly starts as usual, we will refer to it as p1
-// * newBootstrap will build a tableflip.Upgrader, we will refer to it as upg
-// * sockets and files must be opened with upg.Fds
-// * p1 will trap SIGHUP and invoke upg.Upgrade()
-// * when ready to accept incoming connections p1 will call upg.Ready()
-// * upg.Exit() channel will be closed when an upgrades completed successfully and the process must terminate
-//
-// graceful upgrade:
-// * user replaces gitaly binary and/or config file
-// * user sends SIGHUP to p1
-// * p1 will fork and exec the new gitaly, we will refer to it as p2
-// * from now on p1 will ignore other SIGHUP
-// * if p2 terminates with a non-zero exit code, SIGHUP handling will be restored
-// * p2 will follow the "first boot" sequence but upg.Fds will provide sockets and files from p1, when available
-// * when p2 invokes upg.Ready() all the shared file descriptors not claimed by p2 will be closed
-// * upg.Exit() channel in p1 will be closed now and p1 can gracefully terminate already accepted connections
-// * upgrades cannot starts again if p1 and p2 are both running, an hard termination should be scheduled to overcome
-// freezes during a graceful shutdown
-func newBootstrap(pidFile string, upgradesEnabled bool) (*bootstrap, error) {
- // PIDFile is optional, if provided tableflip will keep it updated
- upg, err := tableflip.New(tableflip.Options{PIDFile: pidFile})
- if err != nil {
- return nil, err
- }
-
- if upgradesEnabled {
- go func() {
- sig := make(chan os.Signal, 1)
- signal.Notify(sig, syscall.SIGHUP)
-
- for range sig {
- err := upg.Upgrade()
- if err != nil {
- log.WithError(err).Error("Upgrade failed")
- continue
- }
-
- log.Info("Upgrade succeeded")
- }
- }()
- }
-
- return &bootstrap{Upgrader: upg}, nil
-}
-
-func (b *bootstrap) listen() error {
- if socketPath := config.Config.SocketPath; socketPath != "" {
- l, err := b.createUnixListener(socketPath)
- if err != nil {
- return err
- }
-
- log.WithField("address", socketPath).Info("listening on unix socket")
- b.insecureListeners = append(b.insecureListeners, l)
- }
-
- if addr := config.Config.ListenAddr; addr != "" {
- l, err := b.Fds.Listen("tcp", addr)
- if err != nil {
- return err
- }
-
- log.WithField("address", addr).Info("listening at tcp address")
- b.insecureListeners = append(b.insecureListeners, connectioncounter.New("tcp", l))
- }
-
- if addr := config.Config.TLSListenAddr; addr != "" {
- tlsListener, err := b.Fds.Listen("tcp", addr)
- if err != nil {
- return err
- }
-
- b.secureListeners = append(b.secureListeners, connectioncounter.New("tls", tlsListener))
- }
-
- b.serversErrors = make(chan error, len(b.insecureListeners)+len(b.secureListeners))
-
- return nil
-}
-
-func (b *bootstrap) prometheusListener() (net.Listener, error) {
- log.WithField("address", config.Config.PrometheusListenAddr).Info("starting prometheus listener")
-
- return b.Fds.Listen("tcp", config.Config.PrometheusListenAddr)
-}
-
-func (b *bootstrap) run() {
- signals := []os.Signal{syscall.SIGTERM, syscall.SIGINT}
- done := make(chan os.Signal, len(signals))
- signal.Notify(done, signals...)
-
- ruby, err := rubyserver.Start()
- if err != nil {
- log.WithError(err).Error("start ruby server")
- return
- }
- defer ruby.Stop()
-
- if len(b.insecureListeners) > 0 {
- insecureServer := server.NewInsecure(ruby)
- defer insecureServer.Stop()
-
- serve(insecureServer, b.insecureListeners, b.Exit(), b.serversErrors)
- }
-
- if len(b.secureListeners) > 0 {
- secureServer := server.NewSecure(ruby)
- defer secureServer.Stop()
-
- serve(secureServer, b.secureListeners, b.Exit(), b.serversErrors)
- }
-
- if err := b.Ready(); err != nil {
- log.WithError(err).Error("incomplete bootstrap")
- return
- }
-
- select {
- case <-b.Exit():
- // this is the old process and a graceful upgrade is in progress
- // the new process signaled its readiness and we started a graceful stop
- // however no further upgrades can be started until this process is running
- // we set a grace period and then we force a termination.
- b.waitGracePeriod(done)
-
- err = fmt.Errorf("graceful upgrade")
- case s := <-done:
- err = fmt.Errorf("received signal %q", s)
- case err = <-b.serversErrors:
- }
-
- log.WithError(err).Error("terminating")
-}
-
-func (b *bootstrap) waitGracePeriod(kill <-chan os.Signal) {
- log.WithField("graceful_restart_timeout", config.Config.GracefulRestartTimeout).Warn("starting grace period")
-
- select {
- case <-time.After(config.Config.GracefulRestartTimeout):
- log.Error("old process stuck on termination. Grace period expired.")
- case <-kill:
- log.Error("force shutdown")
- case <-b.serversErrors:
- log.Info("graceful stop completed")
- }
-}
-
-func (b *bootstrap) createUnixListener(socketPath string) (net.Listener, error) {
- if !b.HasParent() {
- // During an update the unix socket exists and if we delete it tableflip will not create a new one
- if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) {
- return nil, err
- }
- }
-
- l, err := b.Fds.Listen("unix", socketPath)
- return connectioncounter.New("unix", l), err
-}
-
-func serve(server *grpc.Server, listeners []net.Listener, done <-chan struct{}, errors chan<- error) {
- go func() {
- <-done
-
- server.GracefulStop()
- }()
-
- for _, listener := range listeners {
- // Must pass the listener as a function argument because there is a race
- // between 'go' and 'for'.
- go func(l net.Listener) {
- errors <- server.Serve(l)
- }(listener)
- }
-}
diff --git a/cmd/gitaly/bootstrap_test.go b/cmd/gitaly/bootstrap_test.go
deleted file mode 100644
index f6fdf6f90..000000000
--- a/cmd/gitaly/bootstrap_test.go
+++ /dev/null
@@ -1,141 +0,0 @@
-package main
-
-import (
- "context"
- "fmt"
- "io"
- "io/ioutil"
- "net"
- "net/http"
- "os"
- "path"
- "strconv"
- "testing"
- "time"
-
- "github.com/stretchr/testify/require"
-)
-
-// b is global because tableflip do not allow to init more than one Upgrader per process
-var b *bootstrap
-var socketPath = path.Join(os.TempDir(), "test-unix-socket")
-
-// TestMain helps testing bootstrap.
-// When invoked directly it behaves like a normal go test, but if a test performs an upgrade the children will
-// avoid the test suite and start a pid HTTP server on socketPath
-func TestMain(m *testing.M) {
- var err error
- b, err = newBootstrap("", true)
- if err != nil {
- panic(err)
- }
-
- if !b.HasParent() {
- // Execute test suite if there is no parent.
- os.Exit(m.Run())
- }
-
- // this is a test suite that triggered an upgrade, we are in the children here
- l, err := b.createUnixListener(socketPath)
- if err != nil {
- panic(err)
- }
-
- if err := b.Ready(); err != nil {
- panic(err)
- }
-
- done := make(chan struct{})
- srv := startPidServer(done, l)
-
- select {
- case <-done:
- //no op
- case <-time.After(2 * time.Minute):
- srv.Close()
- panic("safeguard against zombie process")
- }
-}
-
-func TestCreateUnixListener(t *testing.T) {
- // simulate a dangling socket
- if err := os.Remove(socketPath); err != nil {
- require.True(t, os.IsNotExist(err), "cannot delete dangling socket: %v", err)
- }
-
- file, err := os.OpenFile(socketPath, os.O_CREATE, 0755)
- require.NoError(t, err)
- require.NoError(t, file.Close())
-
- require.NoError(t, ioutil.WriteFile(socketPath, nil, 0755))
-
- l, err := b.createUnixListener(socketPath)
- require.NoError(t, err)
-
- done := make(chan struct{})
- srv := startPidServer(done, l)
- defer srv.Close()
-
- require.NoError(t, b.Ready(), "not ready")
-
- myPid, err := askPid()
- require.NoError(t, err)
- require.Equal(t, os.Getpid(), myPid)
-
- // we trigger an upgrade and wait for children readiness
- require.NoError(t, b.Upgrade(), "upgrade failed")
- <-b.Exit()
- require.NoError(t, srv.Close())
- <-done
-
- childPid, err := askPid()
- require.NoError(t, err)
- require.NotEqual(t, os.Getpid(), childPid, "this request must be handled by the children")
-}
-
-func askPid() (int, error) {
- client := &http.Client{
- Transport: &http.Transport{
- DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
- return net.Dial("unix", socketPath)
- },
- },
- }
-
- response, err := client.Get("http://unix")
- if err != nil {
- return 0, err
- }
- defer response.Body.Close()
-
- pid, err := ioutil.ReadAll(response.Body)
- if err != nil {
- return 0, err
- }
-
- return strconv.Atoi(string(pid))
-}
-
-// startPidServer starts an HTTP server that returns the current PID, if running on a children it will kill itself after serving
-// the first client
-func startPidServer(done chan<- struct{}, l net.Listener) *http.Server {
- mux := http.NewServeMux()
- srv := &http.Server{Handler: mux}
-
- mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
- io.WriteString(w, fmt.Sprint(os.Getpid()))
-
- if b.HasParent() {
- time.AfterFunc(1*time.Second, func() { srv.Close() })
- }
- })
-
- go func() {
- if err := srv.Serve(l); err != http.ErrServerClosed {
- fmt.Printf("Serve error: %v", err)
- }
- close(done)
- }()
-
- return srv
-}
diff --git a/cmd/gitaly/main.go b/cmd/gitaly/main.go
index 10149365f..55e0fcc46 100644
--- a/cmd/gitaly/main.go
+++ b/cmd/gitaly/main.go
@@ -9,6 +9,7 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
log "github.com/sirupsen/logrus"
+ "gitlab.com/gitlab-org/gitaly/internal/bootstrap"
"gitlab.com/gitlab-org/gitaly/internal/config"
"gitlab.com/gitlab-org/gitaly/internal/git"
"gitlab.com/gitlab-org/gitaly/internal/linguist"
@@ -78,11 +79,10 @@ func main() {
// gitaly-wrapper is supposed to set config.EnvUpgradesEnabled in order to enable graceful upgrades
_, isWrapped := os.LookupEnv(config.EnvUpgradesEnabled)
- b, err := newBootstrap(os.Getenv(config.EnvPidFile), isWrapped)
+ b, err := bootstrap.New(os.Getenv(config.EnvPidFile), isWrapped)
if err != nil {
log.WithError(err).Fatal("init bootstrap")
}
- defer b.Stop()
// If invoked with -version
if *flagVersion {
@@ -111,30 +111,59 @@ func main() {
tempdir.StartCleaning()
- if err = b.listen(); err != nil {
- log.WithError(err).Fatal("bootstrap failed")
+ log.WithError(run(b)).Error("shutting down")
+}
+
+// Inside here we can use deferred functions. This is needed because
+// log.Fatal bypasses deferred functions.
+func run(b *bootstrap.Bootstrap) error {
+ servers, err := bootstrap.NewServerFactory()
+ if err != nil {
+ return err
}
+ defer servers.Stop()
- if config.Config.PrometheusListenAddr != "" {
- l, err := b.prometheusListener()
- if err != nil {
- log.WithError(err).Fatal("configure prometheus listener")
- }
+ b.StopAction = servers.GracefulStop
- promMux := http.NewServeMux()
- promMux.Handle("/metrics", promhttp.Handler())
+ for _, c := range []starterConfig{
+ {unix, config.Config.SocketPath},
+ {tcp, config.Config.ListenAddr},
+ {tls, config.Config.TLSListenAddr},
+ } {
+ if c.addr == "" {
+ continue
+ }
- server.AddPprofHandlers(promMux)
+ b.RegisterStarter(gitalyStarter(c, servers))
+ }
- go func() {
- err = http.Serve(l, promMux)
+ if addr := config.Config.PrometheusListenAddr; addr != "" {
+ b.RegisterStarter(func(listen bootstrap.ListenFunc, _ chan<- error) error {
+ l, err := listen("tcp", addr)
if err != nil {
- log.WithError(err).Fatal("Unable to serve prometheus")
+ return err
}
- }()
+
+ log.WithField("address", addr).Info("starting prometheus listener")
+
+ promMux := http.NewServeMux()
+ promMux.Handle("/metrics", promhttp.Handler())
+
+ server.AddPprofHandlers(promMux)
+
+ go func() {
+ if err := http.Serve(l, promMux); err != nil {
+ log.WithError(err).Error("Unable to serve prometheus")
+ }
+ }()
+
+ return nil
+ })
}
- b.run()
+ if err := b.Start(); err != nil {
+ return fmt.Errorf("unable to start the bootstrap: %v", err)
+ }
- log.Fatal("shutting down")
+ return b.Wait()
}
diff --git a/cmd/gitaly/starter_config.go b/cmd/gitaly/starter_config.go
new file mode 100644
index 000000000..889855815
--- /dev/null
+++ b/cmd/gitaly/starter_config.go
@@ -0,0 +1,45 @@
+package main
+
+import (
+ "github.com/sirupsen/logrus"
+ "gitlab.com/gitlab-org/gitaly/internal/bootstrap"
+)
+
+const (
+ tcp string = "tcp"
+ tls string = "tls"
+ unix string = "unix"
+)
+
+type starterConfig struct {
+ name, addr string
+}
+
+func (s *starterConfig) isSecure() bool {
+ return s.name == tls
+}
+
+func (s *starterConfig) family() string {
+ if s.isSecure() {
+ return tcp
+ }
+
+ return s.name
+}
+
+func gitalyStarter(cfg starterConfig, servers bootstrap.GracefulStoppableServer) bootstrap.Starter {
+ return func(listen bootstrap.ListenFunc, errCh chan<- error) error {
+ l, err := listen(cfg.family(), cfg.addr)
+ if err != nil {
+ return err
+ }
+
+ logrus.WithField("address", cfg.addr).Infof("listening at %s address", cfg.name)
+
+ go func() {
+ errCh <- servers.Serve(l, cfg.isSecure())
+ }()
+
+ return nil
+ }
+}
diff --git a/cmd/gitaly/starter_config_test.go b/cmd/gitaly/starter_config_test.go
new file mode 100644
index 000000000..377880f25
--- /dev/null
+++ b/cmd/gitaly/starter_config_test.go
@@ -0,0 +1,38 @@
+package main
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestIsSecure(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ secure bool
+ }{
+ {"tcp", false},
+ {"unix", false},
+ {"tls", true},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ conf := starterConfig{name: test.name}
+ require.Equal(t, test.secure, conf.isSecure())
+ })
+ }
+}
+
+func TestFamily(t *testing.T) {
+ for _, test := range []struct {
+ name, family string
+ }{
+ {"tcp", "tcp"},
+ {"unix", "unix"},
+ {"tls", "tcp"},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ conf := starterConfig{name: test.name}
+ require.Equal(t, test.family, conf.family())
+ })
+ }
+}
diff --git a/internal/bootstrap/bootstrap.go b/internal/bootstrap/bootstrap.go
new file mode 100644
index 000000000..bfb58066c
--- /dev/null
+++ b/internal/bootstrap/bootstrap.go
@@ -0,0 +1,180 @@
+package bootstrap
+
+import (
+ "fmt"
+ "net"
+ "os"
+ "os/signal"
+ "syscall"
+ "time"
+
+ "github.com/cloudflare/tableflip"
+ log "github.com/sirupsen/logrus"
+ "gitlab.com/gitlab-org/gitaly/internal/config"
+)
+
+// Bootstrap handles graceful upgrades
+type Bootstrap struct {
+ // StopAction will be invoked during a graceful stop. It must wait until the shutdown is completed
+ StopAction func()
+
+ upgrader upgrader
+ listenFunc ListenFunc
+ errChan chan error
+ starters []Starter
+}
+
+type upgrader interface {
+ Exit() <-chan struct{}
+ HasParent() bool
+ Ready() error
+ Upgrade() error
+}
+
+// New performs tableflip initialization
+// pidFile is optional, if provided it will always contain the current process PID
+// upgradesEnabled controls the upgrade process on SIGHUP signal
+//
+// first boot:
+// * gitaly starts as usual, we will refer to it as p1
+// * New will build a tableflip.Upgrader, we will refer to it as upg
+// * sockets and files must be opened with upg.Fds
+// * p1 will trap SIGHUP and invoke upg.Upgrade()
+// * when ready to accept incoming connections p1 will call upg.Ready()
+// * upg.Exit() channel will be closed when an upgrade completes successfully and the process must terminate
+//
+// graceful upgrade:
+// * user replaces gitaly binary and/or config file
+// * user sends SIGHUP to p1
+// * p1 will fork and exec the new gitaly, we will refer to it as p2
+// * from now on p1 will ignore other SIGHUP
+// * if p2 terminates with a non-zero exit code, SIGHUP handling will be restored
+// * p2 will follow the "first boot" sequence but upg.Fds will provide sockets and files from p1, when available
+// * when p2 invokes upg.Ready() all the shared file descriptors not claimed by p2 will be closed
+// * upg.Exit() channel in p1 will be closed now and p1 can gracefully terminate already accepted connections
+// * upgrades cannot start again while p1 and p2 are both running; a hard termination should be scheduled to overcome
+// freezes during a graceful shutdown
+func New(pidFile string, upgradesEnabled bool) (*Bootstrap, error) {
+ // PIDFile is optional, if provided tableflip will keep it updated
+ upg, err := tableflip.New(tableflip.Options{PIDFile: pidFile})
+ if err != nil {
+ return nil, err
+ }
+
+ return _new(upg, upg.Fds.Listen, upgradesEnabled)
+}
+
+func _new(upg upgrader, listenFunc ListenFunc, upgradesEnabled bool) (*Bootstrap, error) {
+ if upgradesEnabled {
+ go func() {
+ sig := make(chan os.Signal, 1)
+ signal.Notify(sig, syscall.SIGHUP)
+
+ for range sig {
+ err := upg.Upgrade()
+ if err != nil {
+ log.WithError(err).Error("Upgrade failed")
+ continue
+ }
+
+ log.Info("Upgrade succeeded")
+ }
+ }()
+ }
+
+ return &Bootstrap{
+ upgrader: upg,
+ listenFunc: listenFunc,
+ }, nil
+}
+
+// ListenFunc is a net.Listener factory
+type ListenFunc func(net, addr string) (net.Listener, error)
+
+// Starter is a function to initialize a net.Listener
+// it receives a ListenFunc to be used for net.Listener creation and a chan<- error to signal runtime errors
+// It must serve incoming connections asynchronously and signal errors on the channel
+// the return value is for setup errors
+type Starter func(ListenFunc, chan<- error) error
+
+func (b *Bootstrap) isFirstBoot() bool { return !b.upgrader.HasParent() }
+
+// RegisterStarter adds a new starter
+func (b *Bootstrap) RegisterStarter(starter Starter) {
+ b.starters = append(b.starters, starter)
+}
+
+// Start will invoke all the registered starters and wait asynchronously for runtime errors
+// in case a Starter fails then the error is returned and the function is aborted
+func (b *Bootstrap) Start() error {
+ b.errChan = make(chan error, len(b.starters))
+
+ for _, start := range b.starters {
+ if err := start(b.listen, b.errChan); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// Wait will signal process readiness to the parent and then wait for an exit condition
+// SIGTERM, SIGINT and a runtime error will trigger an immediate shutdown
+// in case of an upgrade there will be a grace period to complete the ongoing requests
+func (b *Bootstrap) Wait() error {
+ signals := []os.Signal{syscall.SIGTERM, syscall.SIGINT}
+ immediateShutdown := make(chan os.Signal, len(signals))
+ signal.Notify(immediateShutdown, signals...)
+
+ if err := b.upgrader.Ready(); err != nil {
+ return err
+ }
+
+ var err error
+ select {
+ case <-b.upgrader.Exit():
+ // this is the old process and a graceful upgrade is in progress
+ // the new process signaled its readiness and we started a graceful stop
+ // however no further upgrades can be started while this process is running
+ // we set a grace period and then we force a termination.
+ waitError := b.waitGracePeriod(immediateShutdown)
+
+ err = fmt.Errorf("graceful upgrade: %v", waitError)
+ case s := <-immediateShutdown:
+ err = fmt.Errorf("received signal %q", s)
+ case err = <-b.errChan:
+ }
+
+ return err
+}
+
+func (b *Bootstrap) waitGracePeriod(kill <-chan os.Signal) error {
+ log.WithField("graceful_restart_timeout", config.Config.GracefulRestartTimeout).Warn("starting grace period")
+
+ allServersDone := make(chan struct{})
+ go func() {
+ if b.StopAction != nil {
+ b.StopAction()
+ }
+ close(allServersDone)
+ }()
+
+ select {
+ case <-time.After(config.Config.GracefulRestartTimeout):
+ return fmt.Errorf("grace period expired")
+ case <-kill:
+ return fmt.Errorf("force shutdown")
+ case <-allServersDone:
+ return fmt.Errorf("completed")
+ }
+}
+
+func (b *Bootstrap) listen(network, path string) (net.Listener, error) {
+ if network == "unix" && b.isFirstBoot() {
+ if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
+ return nil, err
+ }
+ }
+
+ return b.listenFunc(network, path)
+}
diff --git a/internal/bootstrap/bootstrap_test.go b/internal/bootstrap/bootstrap_test.go
new file mode 100644
index 000000000..78cba7ac9
--- /dev/null
+++ b/internal/bootstrap/bootstrap_test.go
@@ -0,0 +1,324 @@
+package bootstrap
+
+import (
+ "context"
+ "fmt"
+ "io/ioutil"
+ "net"
+ "net/http"
+ "os"
+ "path"
+ "strconv"
+ "syscall"
+ "testing"
+ "time"
+
+ "gitlab.com/gitlab-org/gitaly/internal/config"
+
+ "github.com/stretchr/testify/require"
+)
+
+var testConfigGracefulRestartTimeout = 2 * time.Second
+
+type mockUpgrader struct {
+ exit chan struct{}
+ hasParent bool
+}
+
+func (m *mockUpgrader) Exit() <-chan struct{} {
+ return m.exit
+}
+
+func (m *mockUpgrader) HasParent() bool {
+ return m.hasParent
+}
+
+func (m *mockUpgrader) Ready() error { return nil }
+
+func (m *mockUpgrader) Upgrade() error {
+ // to upgrade we close the exit channel
+ close(m.exit)
+ return nil
+}
+
+type testServer struct {
+ server *http.Server
+ listeners map[string]net.Listener
+ url string
+}
+
+func (s *testServer) slowRequest(duration time.Duration) <-chan error {
+ done := make(chan error)
+
+ go func() {
+ r, err := http.Get(fmt.Sprintf("%sslow?seconds=%d", s.url, int(duration.Seconds())))
+ if r != nil {
+ r.Body.Close()
+ }
+
+ done <- err
+ }()
+
+ return done
+}
+
+func TestCreateUnixListener(t *testing.T) {
+ socketPath := path.Join(os.TempDir(), "gitaly-test-unix-socket")
+ if err := os.Remove(socketPath); err != nil {
+ require.True(t, os.IsNotExist(err), "cannot delete dangling socket: %v", err)
+ }
+
+ // simulate a dangling socket
+ require.NoError(t, ioutil.WriteFile(socketPath, nil, 0755))
+
+ listen := func(network, addr string) (net.Listener, error) {
+ require.Equal(t, "unix", network)
+ require.Equal(t, socketPath, addr)
+
+ return net.Listen(network, addr)
+ }
+ u := &mockUpgrader{}
+ b, err := _new(u, listen, false)
+ require.NoError(t, err)
+
+ // first boot
+ l, err := b.listen("unix", socketPath)
+ require.NoError(t, err, "failed to bind on first boot")
+ require.NoError(t, l.Close())
+
+ // simulate binding during an upgrade
+ u.hasParent = true
+ l, err = b.listen("unix", socketPath)
+ require.NoError(t, err, "failed to bind on upgrade")
+ require.NoError(t, l.Close())
+}
+
+func waitWithTimeout(t *testing.T, waitCh <-chan error, timeout time.Duration) error {
+ select {
+ case <-time.After(timeout):
+ t.Fatal("time out waiting for waitCh")
+ case waitErr := <-waitCh:
+ return waitErr
+ }
+
+ return nil
+}
+
+func TestImmediateTerminationOnSocketError(t *testing.T) {
+ b, server := makeBootstrap(t)
+
+ waitCh := make(chan error)
+ go func() { waitCh <- b.Wait() }()
+
+ require.NoError(t, server.listeners["tcp"].Close(), "Closing first listener")
+
+ err := waitWithTimeout(t, waitCh, 1*time.Second)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "use of closed network connection")
+}
+
+func TestImmediateTerminationOnSignal(t *testing.T) {
+ for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} {
+ t.Run(sig.String(), func(t *testing.T) {
+ b, server := makeBootstrap(t)
+
+ done := server.slowRequest(3 * time.Minute)
+
+ waitCh := make(chan error)
+ go func() { waitCh <- b.Wait() }()
+
+ // make sure we are inside b.Wait() or we'll kill the test suite
+ time.Sleep(100 * time.Millisecond)
+
+ self, err := os.FindProcess(os.Getpid())
+ require.NoError(t, err)
+ require.NoError(t, self.Signal(sig))
+
+ waitErr := waitWithTimeout(t, waitCh, 1*time.Second)
+ require.Error(t, waitErr)
+ require.Contains(t, waitErr.Error(), "received signal")
+ require.Contains(t, waitErr.Error(), sig.String())
+
+ server.server.Close()
+
+ require.Error(t, <-done)
+ })
+ }
+}
+
+func TestGracefulTerminationStuck(t *testing.T) {
+ b, server := makeBootstrap(t)
+
+ err := testGracefulUpdate(t, server, b, testConfigGracefulRestartTimeout+(1*time.Second), nil)
+ require.Contains(t, err.Error(), "grace period expired")
+}
+
+func TestGracefulTerminationWithSignals(t *testing.T) {
+ self, err := os.FindProcess(os.Getpid())
+ require.NoError(t, err)
+
+ for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} {
+ t.Run(sig.String(), func(t *testing.T) {
+ b, server := makeBootstrap(t)
+
+ err := testGracefulUpdate(t, server, b, 1*time.Second, func() {
+ require.NoError(t, self.Signal(sig))
+ })
+ require.Contains(t, err.Error(), "force shutdown")
+ })
+ }
+}
+
+func TestGracefulTerminationServerErrors(t *testing.T) {
+ b, server := makeBootstrap(t)
+
+ done := make(chan error, 1)
+ // This is a simulation of receiving a listener error during waitGracePeriod
+ b.StopAction = func() {
+ // we close the unix listener in order to test that the shutdown will not fail, but it keeps waiting for the TCP request
+ require.NoError(t, server.listeners["unix"].Close())
+
+ // we start a new TCP request that is faster than the grace period
+ req := server.slowRequest(config.Config.GracefulRestartTimeout / 2)
+ done <- <-req
+ close(done)
+
+ server.server.Shutdown(context.Background())
+ }
+
+ err := testGracefulUpdate(t, server, b, testConfigGracefulRestartTimeout+(1*time.Second), nil)
+ require.Contains(t, err.Error(), "grace period expired")
+
+ require.NoError(t, <-done)
+}
+
+func TestGracefulTermination(t *testing.T) {
+ b, server := makeBootstrap(t)
+
+ // Using server.Close we bypass the graceful shutdown faking a completed shutdown
+ b.StopAction = func() { server.server.Close() }
+
+ err := testGracefulUpdate(t, server, b, 1*time.Second, nil)
+ require.Contains(t, err.Error(), "completed")
+}
+
+func testGracefulUpdate(t *testing.T, server *testServer, b *Bootstrap, waitTimeout time.Duration, duringGracePeriodCallback func()) error {
+ defer func(oldVal time.Duration) {
+ config.Config.GracefulRestartTimeout = oldVal
+ }(config.Config.GracefulRestartTimeout)
+ config.Config.GracefulRestartTimeout = testConfigGracefulRestartTimeout
+
+ waitCh := make(chan error)
+ go func() { waitCh <- b.Wait() }()
+
+ // Start a slow request to keep the old server from shutting down immediately.
+ req := server.slowRequest(2 * config.Config.GracefulRestartTimeout)
+
+ // make sure slow request is being handled
+ time.Sleep(100 * time.Millisecond)
+
+ // Simulate an upgrade request after entering into the blocking b.Wait() and during the slowRequest execution
+ b.upgrader.Upgrade()
+
+ if duringGracePeriodCallback != nil {
+ // make sure we are on the grace period
+ time.Sleep(100 * time.Millisecond)
+
+ duringGracePeriodCallback()
+ }
+
+ waitErr := waitWithTimeout(t, waitCh, waitTimeout)
+ require.Error(t, waitErr)
+ require.Contains(t, waitErr.Error(), "graceful upgrade")
+
+ server.server.Close()
+
+ clientErr := waitWithTimeout(t, req, 1*time.Second)
+ require.Error(t, clientErr, "slow request not terminated after the grace period")
+
+ return waitErr
+}
+
+func makeBootstrap(t *testing.T) (*Bootstrap, *testServer) {
+ mux := http.NewServeMux()
+ mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(200)
+ })
+ mux.HandleFunc("/slow", func(w http.ResponseWriter, r *http.Request) {
+ sec, err := strconv.Atoi(r.URL.Query().Get("seconds"))
+ require.NoError(t, err)
+
+ t.Logf("Serving a slow request for %d seconds", sec)
+ time.Sleep(time.Duration(sec) * time.Second)
+
+ w.WriteHeader(200)
+ })
+
+ s := http.Server{Handler: mux}
+ u := &mockUpgrader{exit: make(chan struct{})}
+
+ b, err := _new(u, net.Listen, false)
+ require.NoError(t, err)
+
+ b.StopAction = func() { s.Shutdown(context.Background()) }
+
+ listeners := make(map[string]net.Listener)
+ start := func(network, address string) Starter {
+ return func(listen ListenFunc, errors chan<- error) error {
+ l, err := listen(network, address)
+ if err != nil {
+ return err
+ }
+ listeners[network] = l
+
+ go func() {
+ errors <- s.Serve(l)
+ }()
+
+ return nil
+ }
+ }
+
+ for network, address := range map[string]string{
+ "tcp": "127.0.0.1:0",
+ "unix": path.Join(os.TempDir(), "gitaly-test-unix-socket"),
+ } {
+ b.RegisterStarter(start(network, address))
+ }
+
+ require.NoError(t, b.Start())
+ require.Equal(t, 2, len(listeners))
+
+ // test connection
+ testAllListeners(t, listeners)
+
+ addr := listeners["tcp"].Addr()
+ url := fmt.Sprintf("http://%s/", addr.String())
+
+ return b, &testServer{
+ server: &s,
+ listeners: listeners,
+ url: url,
+ }
+}
+
+func testAllListeners(t *testing.T, listeners map[string]net.Listener) {
+ for network, listener := range listeners {
+ addr := listener.Addr().String()
+
+ // overriding Client.Transport.Dial we can connect to TCP and UNIX sockets
+ client := &http.Client{
+ Transport: &http.Transport{
+ Dial: func(_, _ string) (net.Conn, error) {
+ return net.Dial(network, addr)
+ },
+ },
+ }
+
+ // we don't need a real address because we forced it on Dial
+ r, err := client.Get("http://fakeHost/")
+ require.NoError(t, err)
+ r.Body.Close()
+ require.Equal(t, 200, r.StatusCode)
+ }
+}
diff --git a/internal/bootstrap/server_factory.go b/internal/bootstrap/server_factory.go
new file mode 100644
index 000000000..f1bfa7624
--- /dev/null
+++ b/internal/bootstrap/server_factory.go
@@ -0,0 +1,93 @@
+package bootstrap
+
+import (
+ "net"
+ "sync"
+
+ log "github.com/sirupsen/logrus"
+ "gitlab.com/gitlab-org/gitaly/internal/rubyserver"
+ "gitlab.com/gitlab-org/gitaly/internal/server"
+ "google.golang.org/grpc"
+)
+
+type serverFactory struct {
+ ruby *rubyserver.Server
+ secure, insecure *grpc.Server
+}
+
+// GracefulStoppableServer can serve content on a net.Listener and be stopped either via Stop or via GracefulStop
+type GracefulStoppableServer interface {
+ GracefulStop()
+ Stop()
+ Serve(l net.Listener, secure bool) error
+}
+
+// NewServerFactory initializes a rubyserver and then lazily initializes both secure and insecure grpc.Server
+func NewServerFactory() (GracefulStoppableServer, error) {
+ ruby, err := rubyserver.Start()
+ if err != nil {
+ log.Error("start ruby server")
+
+ return nil, err
+ }
+
+ return &serverFactory{ruby: ruby}, nil
+}
+
+func (s *serverFactory) Stop() {
+ for _, srv := range s.all() {
+ srv.Stop()
+ }
+
+ s.ruby.Stop()
+}
+
+func (s *serverFactory) GracefulStop() {
+ wg := sync.WaitGroup{}
+
+ for _, srv := range s.all() {
+ wg.Add(1)
+
+ go func(s *grpc.Server) {
+ s.GracefulStop()
+ wg.Done()
+ }(srv)
+ }
+
+ wg.Wait()
+}
+
+func (s *serverFactory) Serve(l net.Listener, secure bool) error {
+ srv := s.get(secure)
+
+ return srv.Serve(l)
+}
+
+func (s *serverFactory) get(secure bool) *grpc.Server {
+ if secure {
+ if s.secure == nil {
+ s.secure = server.NewSecure(s.ruby)
+ }
+
+ return s.secure
+ }
+
+ if s.insecure == nil {
+ s.insecure = server.NewInsecure(s.ruby)
+ }
+
+ return s.insecure
+}
+
+func (s *serverFactory) all() []*grpc.Server {
+ var servers []*grpc.Server
+ if s.secure != nil {
+ servers = append(servers, s.secure)
+ }
+
+ if s.insecure != nil {
+ servers = append(servers, s.insecure)
+ }
+
+ return servers
+}