Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/cmd
diff options
context:
space:
mode:
authorAlessio Caiazza <acaiazza@gitlab.com>2019-04-05 19:07:22 +0300
committerZeger-Jan van de Weg <git@zjvandeweg.nl>2019-04-05 19:07:22 +0300
commit21f9326edb73c8f9a3a3a51e7d5f07b122712350 (patch)
tree84480c437ff78adbe8814ea475d19441cfe91a8f /cmd
parent34c93abeaad6e2900bc05e0a76bb02e7d6b9e383 (diff)
Zero downtime deployment
Diffstat (limited to 'cmd')
-rw-r--r--cmd/gitaly-wrapper/main.go147
-rw-r--r--cmd/gitaly/bootstrap.go205
-rw-r--r--cmd/gitaly/bootstrap_test.go141
-rw-r--r--cmd/gitaly/main.go108
4 files changed, 512 insertions, 89 deletions
diff --git a/cmd/gitaly-wrapper/main.go b/cmd/gitaly-wrapper/main.go
new file mode 100644
index 000000000..646661b47
--- /dev/null
+++ b/cmd/gitaly-wrapper/main.go
@@ -0,0 +1,147 @@
+package main
+
+import (
+ "fmt"
+ "io/ioutil"
+ "os"
+ "os/exec"
+ "os/signal"
+ "strconv"
+ "syscall"
+ "time"
+
+ "github.com/sirupsen/logrus"
+ "gitlab.com/gitlab-org/gitaly/internal/config"
+)
+
+const (
+ envJSONLogging = "WRAPPER_JSON_LOGGING"
+)
+
+// main locates the gitaly process recorded in the pid file or, when none is
+// alive, spawns a new one from the command line arguments. It then forwards
+// incoming signals to that process and blocks until the process is gone.
+func main() {
+	if jsonLogging() {
+		logrus.SetFormatter(&logrus.JSONFormatter{})
+	}
+
+	if len(os.Args) < 2 {
+		logrus.Fatalf("usage: %s forking_binary [args]", os.Args[0])
+	}
+
+	gitalyBin := os.Args[1]
+	gitalyArgs := os.Args[2:]
+
+	log := logrus.WithField("wrapper", os.Getpid())
+	log.Info("Wrapper started")
+
+	pidPath := pidFile()
+	if pidPath == "" {
+		log.Fatalf("missing pid file ENV variable %q", config.EnvPidFile)
+	}
+
+	log.WithField("pid_file", pidPath).Info("finding gitaly")
+	gitaly, err := findGitaly()
+	if err != nil {
+		log.WithError(err).Fatal("find gitaly")
+	}
+
+	if gitaly == nil {
+		// Nothing to adopt: start a fresh gitaly.
+		log.Info("spawning a process")
+
+		gitaly, err = spawnGitaly(gitalyBin, gitalyArgs)
+		if err != nil {
+			log.WithError(err).Fatal("spawn gitaly")
+		}
+	} else {
+		log.Info("adopting a process")
+	}
+
+	log = log.WithField("gitaly", gitaly.Pid)
+	log.Info("monitoring gitaly")
+
+	forwardSignals(gitaly, log)
+
+	// Poll once per second until the monitored process can no longer be signalled.
+	for isAlive(gitaly) {
+		time.Sleep(time.Second)
+	}
+
+	log.Error("wrapper for gitaly shutting down")
+}
+
+// findGitaly looks up the process whose pid is stored in the pid file.
+//
+// It returns (nil, nil) when there is nothing to adopt: the pid file does not
+// exist, it holds a non-positive pid, or the recorded process is not alive.
+// Any other failure to read or parse the pid file is returned as an error.
+func findGitaly() (*os.Process, error) {
+	pid, err := getPid()
+	if err != nil {
+		if os.IsNotExist(err) {
+			// No pid file means there is no previous gitaly to adopt.
+			return nil, nil
+		}
+		return nil, err
+	}
+
+	// A pid <= 0 never names a single process; probing or signalling a negative
+	// pid may address a whole process group, so refuse to adopt it.
+	if pid <= 0 {
+		return nil, nil
+	}
+
+	// os.FindProcess on unix does not return an error if the process does not
+	// exist, so liveness is probed separately below.
+	gitaly, err := os.FindProcess(pid)
+	if err != nil {
+		return nil, err
+	}
+
+	if !isAlive(gitaly) {
+		return nil, nil
+	}
+
+	return gitaly, nil
+}
+
+// spawnGitaly starts bin with args as a child process that shares the
+// wrapper's stdin, stdout and stderr. The child's environment is the wrapper's
+// own plus config.EnvUpgradesEnabled set to "true".
+func spawnGitaly(bin string, args []string) (*os.Process, error) {
+	child := exec.Command(bin, args...)
+	child.Stdin, child.Stdout, child.Stderr = os.Stdin, os.Stdout, os.Stderr
+	child.Env = append(os.Environ(), fmt.Sprintf("%s=true", config.EnvUpgradesEnabled))
+
+	if err := child.Start(); err != nil {
+		return nil, err
+	}
+
+	// This Wait is crucial. Without it we cannot detect if the command we just
+	// spawned has crashed.
+	go child.Wait()
+
+	return child.Process, nil
+}
+
+// forwardSignals relays every signal delivered to the wrapper to the gitaly
+// process from a background goroutine, logging each forwarded signal and any
+// delivery failure.
+func forwardSignals(gitaly *os.Process, log *logrus.Entry) {
+	sigs := make(chan os.Signal, 1)
+
+	relay := func() {
+		for sig := range sigs {
+			entry := log.WithField("signal", sig)
+			entry.Warning("forwarding signal")
+
+			err := gitaly.Signal(sig)
+			if err == nil {
+				continue
+			}
+
+			entry.WithError(err).Error("can't forward the signal")
+		}
+	}
+	go relay()
+
+	// With no signals listed, Notify relays all incoming signals to sigs.
+	signal.Notify(sigs)
+}
+
+// getPid reads the pid file and parses its entire content as a decimal
+// integer. Read errors and parse errors are returned unchanged.
+func getPid() (int, error) {
+	raw, readErr := ioutil.ReadFile(pidFile())
+	if readErr != nil {
+		return 0, readErr
+	}
+
+	pid, parseErr := strconv.Atoi(string(raw))
+	return pid, parseErr
+}
+
+// isAlive reports whether signal 0 can be delivered to p.
+//
+// After p exits, and after it gets reaped, sending the signal fails. It is
+// crucial that p gets reaped:
+//   - if p was spawned by the current process, the goroutine that does cmd.Wait() reaps it;
+//   - if p was spawned by someone else we rely on them to reap it, or on p to become an
+//     orphan, in which case p should get reaped by the OS (PID 1).
+func isAlive(p *os.Process) bool {
+	err := p.Signal(syscall.Signal(0))
+	return err == nil
+}
+
+// pidFile returns the value of the config.EnvPidFile environment variable,
+// which is the empty string when the variable is unset.
+func pidFile() string {
+	pidPath := os.Getenv(config.EnvPidFile)
+	return pidPath
+}
+
+// jsonLogging reports whether the envJSONLogging environment variable is set
+// to exactly "true".
+func jsonLogging() bool {
+	value := os.Getenv(envJSONLogging)
+	return value == "true"
+}
diff --git a/cmd/gitaly/bootstrap.go b/cmd/gitaly/bootstrap.go
new file mode 100644
index 000000000..d3b464052
--- /dev/null
+++ b/cmd/gitaly/bootstrap.go
@@ -0,0 +1,205 @@
+package main
+
+import (
+ "fmt"
+ "net"
+ "os"
+ "os/signal"
+ "syscall"
+ "time"
+
+ "github.com/cloudflare/tableflip"
+ log "github.com/sirupsen/logrus"
+ "gitlab.com/gitlab-org/gitaly/internal/config"
+ "gitlab.com/gitlab-org/gitaly/internal/connectioncounter"
+ "gitlab.com/gitlab-org/gitaly/internal/rubyserver"
+ "gitlab.com/gitlab-org/gitaly/internal/server"
+ "google.golang.org/grpc"
+)
+
+// bootstrap couples a tableflip.Upgrader with the listeners gitaly serves on
+// and the channel on which server goroutines report the result of Serve.
+type bootstrap struct {
+	*tableflip.Upgrader
+
+	// insecureListeners collects the unix and plain tcp listeners and
+	// secureListeners the tls ones; both are filled in by listen.
+	insecureListeners, secureListeners []net.Listener
+
+	// serversErrors receives the value returned by each Serve call.
+	serversErrors chan error
+}
+
+// newBootstrap performs tableflip initialization.
+//
+// first boot:
+// * gitaly starts as usual, we will refer to it as p1
+// * newBootstrap will build a tableflip.Upgrader, we will refer to it as upg
+// * sockets and files must be opened with upg.Fds
+// * p1 will trap SIGHUP and invoke upg.Upgrade()
+// * when ready to accept incoming connections p1 will call upg.Ready()
+// * upg.Exit() channel will be closed when an upgrade has completed successfully and the process must terminate
+//
+// graceful upgrade:
+// * user replaces gitaly binary and/or config file
+// * user sends SIGHUP to p1
+// * p1 will fork and exec the new gitaly, we will refer to it as p2
+// * from now on p1 will ignore other SIGHUP
+// * if p2 terminates with a non-zero exit code, SIGHUP handling will be restored
+// * p2 will follow the "first boot" sequence but upg.Fds will provide sockets and files from p1, when available
+// * when p2 invokes upg.Ready() all the shared file descriptors not claimed by p2 will be closed
+// * upg.Exit() channel in p1 will be closed now and p1 can gracefully terminate already accepted connections
+// * upgrades cannot start again while p1 and p2 are both running; a hard termination should be scheduled to
+//   overcome freezes during a graceful shutdown
+//
+// pidFile is optional; if provided tableflip will keep it updated. SIGHUP is
+// trapped only when upgradesEnabled is true.
+func newBootstrap(pidFile string, upgradesEnabled bool) (*bootstrap, error) {
+	upg, err := tableflip.New(tableflip.Options{PIDFile: pidFile})
+	if err != nil {
+		return nil, err
+	}
+
+	if upgradesEnabled {
+		// Register for SIGHUP before starting the goroutine. If registration
+		// happened inside it, a SIGHUP arriving before the goroutine was
+		// scheduled would still get the default action and end the process.
+		sig := make(chan os.Signal, 1)
+		signal.Notify(sig, syscall.SIGHUP)
+
+		go func() {
+			for range sig {
+				if err := upg.Upgrade(); err != nil {
+					log.WithError(err).Error("Upgrade failed")
+					continue
+				}
+
+				log.Info("Upgrade succeeded")
+			}
+		}()
+	}
+
+	return &bootstrap{Upgrader: upg}, nil
+}
+
+// listen opens one listener for each of SocketPath, ListenAddr and
+// TLSListenAddr that is non-empty in config.Config. The unix and plain tcp
+// listeners are appended to b.insecureListeners, the tls one to
+// b.secureListeners. It then allocates b.serversErrors with room for one
+// value per listener. The first listener that cannot be opened aborts the
+// setup and its error is returned.
+func (b *bootstrap) listen() error {
+	if unixPath := config.Config.SocketPath; unixPath != "" {
+		unixListener, err := b.createUnixListener(unixPath)
+		if err != nil {
+			return err
+		}
+
+		log.WithField("address", unixPath).Info("listening on unix socket")
+		b.insecureListeners = append(b.insecureListeners, unixListener)
+	}
+
+	if tcpAddr := config.Config.ListenAddr; tcpAddr != "" {
+		tcpListener, err := b.Fds.Listen("tcp", tcpAddr)
+		if err != nil {
+			return err
+		}
+
+		log.WithField("address", tcpAddr).Info("listening at tcp address")
+		b.insecureListeners = append(b.insecureListeners, connectioncounter.New("tcp", tcpListener))
+	}
+
+	if tlsAddr := config.Config.TLSListenAddr; tlsAddr != "" {
+		tlsListener, err := b.Fds.Listen("tcp", tlsAddr)
+		if err != nil {
+			return err
+		}
+
+		b.secureListeners = append(b.secureListeners, connectioncounter.New("tls", tlsListener))
+	}
+
+	// One buffered slot per listener.
+	listenerCount := len(b.insecureListeners) + len(b.secureListeners)
+	b.serversErrors = make(chan error, listenerCount)
+
+	return nil
+}
+
+// prometheusListener logs the configured prometheus address and opens a tcp
+// listener on it through b.Fds.
+func (b *bootstrap) prometheusListener() (net.Listener, error) {
+	addr := config.Config.PrometheusListenAddr
+	log.WithField("address", addr).Info("starting prometheus listener")
+
+	return b.Fds.Listen("tcp", addr)
+}
+
+func (b *bootstrap) run() {
+ signals := []os.Signal{syscall.SIGTERM, syscall.SIGINT}
+ done := make(chan os.Signal, len(signals))
+ signal.Notify(done, signals...)
+
+ ruby, err := rubyserver.Start()
+ if err != nil {
+ log.WithError(err).Error("start ruby server")
+ return
+ }
+ defer ruby.Stop()
+
+ if len(b.insecureListeners) > 0 {
+ insecureServer := server.NewInsecure(ruby)
+ defer insecureServer.Stop()
+
+ serve(insecureServer, b.insecureListeners, b.Exit(), b.serversErrors)
+ }
+
+ if len(b.secureListeners) > 0 {
+ secureServer := server.NewSecure(ruby)
+ defer secureServer.Stop()
+
+ serve(secureServer, b.secureListeners, b.Exit(), b.serversErrors)
+ }
+
+ if err := b.Ready(); err != nil {
+ log.WithError(err).Error("incomplete bootstrap")
+ return
+ }
+
+ select {
+ case <-b.Exit():
+ // this is the old process and a graceful upgrade is in progress
+ // the new process signaled its readiness and we started a graceful stop
+ // however no further upgrades can be started until this process is running
+ // we set a grace period and then we force a termination.
+ b.waitGracePeriod(done)
+
+ err = fmt.Errorf("graceful upgrade")
+ case s := <-done:
+ err = fmt.Errorf("received signal %q", s)
+ case err = <-b.serversErrors:
+ }
+
+ log.WithError(err).Error("terminating")
+}
+
+// waitGracePeriod blocks until the first of: config.Config.GracefulRestartTimeout
+// elapsing, a signal arriving on kill, or a value arriving on b.serversErrors.
+// It logs which of the three ended the wait.
+func (b *bootstrap) waitGracePeriod(kill <-chan os.Signal) {
+	timeout := config.Config.GracefulRestartTimeout
+	log.WithField("graceful_restart_timeout", timeout).Warn("starting grace period")
+
+	expired := time.After(timeout)
+
+	select {
+	case <-expired:
+		log.Error("old process stuck on termination. Grace period expired.")
+	case <-kill:
+		log.Error("force shutdown")
+	case <-b.serversErrors:
+		log.Info("graceful stop completed")
+	}
+}
+
+// createUnixListener opens a unix listener on socketPath through b.Fds and
+// wraps it with connectioncounter. A process without a parent first removes
+// any file already present at socketPath. On failure it returns a nil
+// listener together with the error.
+func (b *bootstrap) createUnixListener(socketPath string) (net.Listener, error) {
+	if !b.HasParent() {
+		// During an upgrade the unix socket already exists, and if we deleted
+		// it tableflip would not create a new one. Only a process without a
+		// parent removes what is left at socketPath.
+		if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) {
+			return nil, err
+		}
+	}
+
+	l, err := b.Fds.Listen("unix", socketPath)
+	if err != nil {
+		// Do not wrap a nil listener: callers get a nil net.Listener on failure.
+		return nil, err
+	}
+
+	return connectioncounter.New("unix", l), nil
+}
+
+// serve starts server.Serve on every listener, each in its own goroutine that
+// sends the value returned by Serve to errors. A further goroutine waits on
+// done and then calls server.GracefulStop().
+func serve(server *grpc.Server, listeners []net.Listener, done <-chan struct{}, errors chan<- error) {
+	stopWhenDone := func() {
+		<-done
+
+		server.GracefulStop()
+	}
+	go stopWhenDone()
+
+	for i := range listeners {
+		// The listener is passed as a function argument so that every
+		// goroutine works on its own value rather than on the loop variable.
+		go func(l net.Listener) {
+			errors <- server.Serve(l)
+		}(listeners[i])
+	}
+}
diff --git a/cmd/gitaly/bootstrap_test.go b/cmd/gitaly/bootstrap_test.go
new file mode 100644
index 000000000..f6fdf6f90
--- /dev/null
+++ b/cmd/gitaly/bootstrap_test.go
@@ -0,0 +1,141 @@
+package main
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net"
+ "net/http"
+ "os"
+ "path"
+ "strconv"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+// b is global because tableflip does not allow initializing more than one Upgrader per process
+var b *bootstrap
+var socketPath = path.Join(os.TempDir(), "test-unix-socket")
+
+// TestMain helps testing bootstrap.
+// When invoked directly it behaves like a normal go test run. When a test performs an upgrade, the child process
+// skips the test suite and starts a pid HTTP server on socketPath instead.
+func TestMain(m *testing.M) {
+	var err error
+	if b, err = newBootstrap("", true); err != nil {
+		panic(err)
+	}
+
+	// Without a parent this is a regular test run: execute the suite and exit.
+	if !b.HasParent() {
+		os.Exit(m.Run())
+	}
+
+	// From here on we are the child of a test suite that triggered an upgrade.
+	listener, err := b.createUnixListener(socketPath)
+	if err != nil {
+		panic(err)
+	}
+
+	if err = b.Ready(); err != nil {
+		panic(err)
+	}
+
+	done := make(chan struct{})
+	srv := startPidServer(done, listener)
+
+	deadline := time.After(2 * time.Minute)
+
+	select {
+	case <-done:
+		// the pid server stopped; nothing left to do
+	case <-deadline:
+		srv.Close()
+		panic("safeguard against zombie process")
+	}
+}
+
+// TestCreateUnixListener checks that createUnixListener succeeds when a regular file is left at socketPath and
+// that, after an upgrade, a request on the socket is answered by a process other than the current one.
+func TestCreateUnixListener(t *testing.T) {
+	// simulate a dangling socket: clear socketPath, then leave a regular file in its place
+	if rmErr := os.Remove(socketPath); rmErr != nil {
+		require.True(t, os.IsNotExist(rmErr), "cannot delete dangling socket: %v", rmErr)
+	}
+
+	dangling, err := os.OpenFile(socketPath, os.O_CREATE, 0755)
+	require.NoError(t, err)
+	require.NoError(t, dangling.Close())
+
+	require.NoError(t, ioutil.WriteFile(socketPath, nil, 0755))
+
+	listener, err := b.createUnixListener(socketPath)
+	require.NoError(t, err)
+
+	served := make(chan struct{})
+	srv := startPidServer(served, listener)
+	defer srv.Close()
+
+	require.NoError(t, b.Ready(), "not ready")
+
+	ownPid := os.Getpid()
+
+	answeredBy, err := askPid()
+	require.NoError(t, err)
+	require.Equal(t, ownPid, answeredBy)
+
+	// trigger an upgrade and wait for the child to signal readiness
+	require.NoError(t, b.Upgrade(), "upgrade failed")
+	<-b.Exit()
+	// close our own server and wait for its Serve goroutine to finish before asking again
+	require.NoError(t, srv.Close())
+	<-served
+
+	answeredBy, err = askPid()
+	require.NoError(t, err)
+	require.NotEqual(t, ownPid, answeredBy, "this request must be handled by the children")
+}
+
+func askPid() (int, error) {
+ client := &http.Client{
+ Transport: &http.Transport{
+ DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
+ return net.Dial("unix", socketPath)
+ },
+ },
+ }
+
+ response, err := client.Get("http://unix")
+ if err != nil {
+ return 0, err
+ }
+ defer response.Body.Close()
+
+ pid, err := ioutil.ReadAll(response.Body)
+ if err != nil {
+ return 0, err
+ }
+
+ return strconv.Atoi(string(pid))
+}
+
+// startPidServer serves HTTP on l from a background goroutine, answering every request with the current PID, and
+// closes done once Serve has returned. When running in a child process (b.HasParent()), handling a request also
+// schedules the server's own Close one second later.
+func startPidServer(done chan<- struct{}, l net.Listener) *http.Server {
+	mux := http.NewServeMux()
+	srv := &http.Server{Handler: mux}
+
+	pidHandler := func(w http.ResponseWriter, _ *http.Request) {
+		io.WriteString(w, fmt.Sprint(os.Getpid()))
+
+		if !b.HasParent() {
+			return
+		}
+
+		time.AfterFunc(time.Second, func() { srv.Close() })
+	}
+	mux.HandleFunc("/", pidHandler)
+
+	serveAndSignal := func() {
+		err := srv.Serve(l)
+		if err != http.ErrServerClosed {
+			fmt.Printf("Serve error: %v", err)
+		}
+		close(done)
+	}
+	go serveAndSignal()
+
+	return srv
+}
diff --git a/cmd/gitaly/main.go b/cmd/gitaly/main.go
index 3be1e38e8..10149365f 100644
--- a/cmd/gitaly/main.go
+++ b/cmd/gitaly/main.go
@@ -3,20 +3,15 @@ package main
import (
"flag"
"fmt"
- "net"
"net/http"
"os"
- "os/signal"
- "syscall"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
log "github.com/sirupsen/logrus"
"gitlab.com/gitlab-org/gitaly/internal/config"
- "gitlab.com/gitlab-org/gitaly/internal/connectioncounter"
"gitlab.com/gitlab-org/gitaly/internal/git"
"gitlab.com/gitlab-org/gitaly/internal/linguist"
- "gitlab.com/gitlab-org/gitaly/internal/rubyserver"
"gitlab.com/gitlab-org/gitaly/internal/server"
"gitlab.com/gitlab-org/gitaly/internal/tempdir"
"gitlab.com/gitlab-org/gitaly/internal/version"
@@ -81,6 +76,14 @@ func main() {
flag.Usage = flagUsage
flag.Parse()
+ // gitaly-wrapper is supposed to set config.EnvUpgradesEnabled in order to enable graceful upgrades
+ _, isWrapped := os.LookupEnv(config.EnvUpgradesEnabled)
+ b, err := newBootstrap(os.Getenv(config.EnvPidFile), isWrapped)
+ if err != nil {
+ log.WithError(err).Fatal("init bootstrap")
+ }
+ defer b.Stop()
+
// If invoked with -version
if *flagVersion {
fmt.Println(version.GetVersionString())
@@ -108,103 +111,30 @@ func main() {
tempdir.StartCleaning()
- var insecureListeners []net.Listener
- var secureListeners []net.Listener
-
- if socketPath := config.Config.SocketPath; socketPath != "" {
- l, err := createUnixListener(socketPath)
- if err != nil {
- log.WithError(err).Fatal("configure unix listener")
- }
- log.WithField("address", socketPath).Info("listening on unix socket")
- insecureListeners = append(insecureListeners, l)
- }
-
- if addr := config.Config.ListenAddr; addr != "" {
- l, err := net.Listen("tcp", addr)
- if err != nil {
- log.WithError(err).Fatal("configure tcp listener")
- }
-
- log.WithField("address", addr).Info("listening at tcp address")
- insecureListeners = append(insecureListeners, connectioncounter.New("tcp", l))
+ if err = b.listen(); err != nil {
+ log.WithError(err).Fatal("bootstrap failed")
}
- if addr := config.Config.TLSListenAddr; addr != "" {
- tlsListener, err := net.Listen("tcp", addr)
+ if config.Config.PrometheusListenAddr != "" {
+ l, err := b.prometheusListener()
if err != nil {
- log.WithError(err).Fatal("configure tls listener")
+ log.WithError(err).Fatal("configure prometheus listener")
}
- secureListeners = append(secureListeners, connectioncounter.New("tls", tlsListener))
- }
-
- if config.Config.PrometheusListenAddr != "" {
- log.WithField("address", config.Config.PrometheusListenAddr).Info("Starting prometheus listener")
promMux := http.NewServeMux()
promMux.Handle("/metrics", promhttp.Handler())
server.AddPprofHandlers(promMux)
go func() {
- http.ListenAndServe(config.Config.PrometheusListenAddr, promMux)
+ err = http.Serve(l, promMux)
+ if err != nil {
+ log.WithError(err).Fatal("Unable to serve prometheus")
+ }
}()
}
- log.WithError(run(insecureListeners, secureListeners)).Fatal("shutting down")
-}
-
-func createUnixListener(socketPath string) (net.Listener, error) {
- if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) {
- return nil, err
- }
- l, err := net.Listen("unix", socketPath)
- return connectioncounter.New("unix", l), err
-}
-
-// Inside here we can use deferred functions. This is needed because
-// log.Fatal bypasses deferred functions.
-func run(insecureListeners, secureListeners []net.Listener) error {
- signals := []os.Signal{syscall.SIGTERM, syscall.SIGINT}
- termCh := make(chan os.Signal, len(signals))
- signal.Notify(termCh, signals...)
-
- ruby, err := rubyserver.Start()
- if err != nil {
- return err
- }
- defer ruby.Stop()
-
- serverErrors := make(chan error, len(insecureListeners)+len(secureListeners))
- if len(insecureListeners) > 0 {
- insecureServer := server.NewInsecure(ruby)
- defer insecureServer.Stop()
-
- for _, listener := range insecureListeners {
- // Must pass the listener as a function argument because there is a race
- // between 'go' and 'for'.
- go func(l net.Listener) {
- serverErrors <- insecureServer.Serve(l)
- }(listener)
- }
- }
-
- if len(secureListeners) > 0 {
- secureServer := server.NewSecure(ruby)
- defer secureServer.Stop()
-
- for _, listener := range secureListeners {
- go func(l net.Listener) {
- serverErrors <- secureServer.Serve(l)
- }(listener)
- }
- }
-
- select {
- case s := <-termCh:
- err = fmt.Errorf("received signal %q", s)
- case err = <-serverErrors:
- }
+ b.run()
- return err
+ log.Fatal("shutting down")
}