diff options
-rw-r--r-- | cmd/gitaly/main.go | 6 | ||||
-rw-r--r-- | cmd/praefect/main.go | 5 | ||||
-rw-r--r-- | internal/bootstrap/bootstrap.go | 18 | ||||
-rw-r--r-- | internal/bootstrap/bootstrap_test.go | 450 |
4 files changed, 230 insertions, 249 deletions
diff --git a/cmd/gitaly/main.go b/cmd/gitaly/main.go index db7c73a06..1571c1d73 100644 --- a/cmd/gitaly/main.go +++ b/cmd/gitaly/main.go @@ -29,6 +29,7 @@ import ( "gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/storage" "gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/transaction" "gitlab.com/gitlab-org/gitaly/v14/internal/gitlab" + "gitlab.com/gitlab-org/gitaly/v14/internal/helper" glog "gitlab.com/gitlab-org/gitaly/v14/internal/log" "gitlab.com/gitlab-org/gitaly/v14/internal/streamcache" "gitlab.com/gitlab-org/gitaly/v14/internal/tempdir" @@ -301,5 +302,8 @@ func run(cfg config.Cfg) error { } }() - return b.Wait(cfg.GracefulRestartTimeout.Duration(), gitalyServerFactory.GracefulStop) + gracefulStopTicker := helper.NewTimerTicker(cfg.GracefulRestartTimeout.Duration()) + defer gracefulStopTicker.Stop() + + return b.Wait(gracefulStopTicker, gitalyServerFactory.GracefulStop) } diff --git a/cmd/praefect/main.go b/cmd/praefect/main.go index 56f56ab9a..d444aff88 100644 --- a/cmd/praefect/main.go +++ b/cmd/praefect/main.go @@ -514,7 +514,10 @@ func run( logger.Warn(`Repository cleanup background task disabled as "repositories_cleanup.run_interval" is not set or 0.`) } - return b.Wait(conf.GracefulStopTimeout.Duration(), srvFactory.GracefulStop) + gracefulStopTicker := helper.NewTimerTicker(conf.GracefulStopTimeout.Duration()) + defer gracefulStopTicker.Stop() + + return b.Wait(gracefulStopTicker, srvFactory.GracefulStop) } func getStarterConfigs(conf config.Config) ([]starter.Config, error) { diff --git a/internal/bootstrap/bootstrap.go b/internal/bootstrap/bootstrap.go index f1d355926..3cb08d6fa 100644 --- a/internal/bootstrap/bootstrap.go +++ b/internal/bootstrap/bootstrap.go @@ -6,10 +6,10 @@ import ( "os" "os/signal" "syscall" - "time" "github.com/cloudflare/tableflip" log "github.com/sirupsen/logrus" + "gitlab.com/gitlab-org/gitaly/v14/internal/helper" "gitlab.com/gitlab-org/gitaly/v14/internal/helper/env" "golang.org/x/sys/unix" ) @@ -29,7 +29,7 @@ type Listener interface { // Start starts all registered starters to accept connections. Start() error // Wait terminates all registered starters. - Wait(gracefulTimeout time.Duration, stopAction func()) error + Wait(gracePeriodTicker helper.Ticker, stopAction func()) error } // Bootstrap handles graceful upgrades @@ -160,7 +160,7 @@ func (b *Bootstrap) Start() error { // SIGTERM, SIGINT and a runtime error will trigger an immediate shutdown // in case of an upgrade there will be a grace period to complete the ongoing requests // stopAction will be invoked during a graceful stop. It must wait until the shutdown is completed. -func (b *Bootstrap) Wait(gracefulTimeout time.Duration, stopAction func()) error { +func (b *Bootstrap) Wait(gracePeriodTicker helper.Ticker, stopAction func()) error { signals := []os.Signal{syscall.SIGTERM, syscall.SIGINT} immediateShutdown := make(chan os.Signal, len(signals)) signal.Notify(immediateShutdown, signals...) @@ -176,7 +176,7 @@ func (b *Bootstrap) Wait(gracefulTimeout time.Duration, stopAction func()) error // the new process signaled its readiness and we started a graceful stop // however no further upgrades can be started until this process is running // we set a grace period and then we force a termination. - waitError := b.waitGracePeriod(gracefulTimeout, immediateShutdown, stopAction) + waitError := b.waitGracePeriod(gracePeriodTicker, immediateShutdown, stopAction) err = fmt.Errorf("graceful upgrade: %v", waitError) case s := <-immediateShutdown: @@ -188,8 +188,8 @@ func (b *Bootstrap) Wait(gracefulTimeout time.Duration, stopAction func()) error return err } -func (b *Bootstrap) waitGracePeriod(gracefulTimeout time.Duration, kill <-chan os.Signal, stopAction func()) error { - log.WithField("graceful_timeout", gracefulTimeout).Warn("starting grace period") +func (b *Bootstrap) waitGracePeriod(gracePeriodTicker helper.Ticker, kill <-chan os.Signal, stopAction func()) error { + log.Warn("starting grace period") allServersDone := make(chan struct{}) go func() { @@ -199,8 +199,10 @@ func (b *Bootstrap) waitGracePeriod(gracefulTimeout time.Duration, kill <-chan o close(allServersDone) }() + gracePeriodTicker.Reset() + select { - case <-time.After(gracefulTimeout): + case <-gracePeriodTicker.C(): return fmt.Errorf("grace period expired") case <-kill: return fmt.Errorf("force shutdown") @@ -249,7 +251,7 @@ func (n *Noop) Start() error { } // Wait terminates all registered starters. -func (n *Noop) Wait(_ time.Duration, stopAction func()) error { +func (n *Noop) Wait(_ helper.Ticker, stopAction func()) error { select { case <-n.shutdown: if stopAction != nil { diff --git a/internal/bootstrap/bootstrap_test.go b/internal/bootstrap/bootstrap_test.go index 8890f25c8..872a62ec0 100644 --- a/internal/bootstrap/bootstrap_test.go +++ b/internal/bootstrap/bootstrap_test.go @@ -2,231 +2,227 @@ package bootstrap import ( "context" - "errors" "fmt" - "io" "net" - "net/http" "os" "path/filepath" - "strconv" "syscall" "testing" - "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitaly/v14/internal/helper" "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper" ) -type mockUpgrader struct { - exit chan struct{} - hasParent bool -} - -func (m *mockUpgrader) Exit() <-chan struct{} { - return m.exit -} - -func (m *mockUpgrader) Stop() {} - -func (m *mockUpgrader) HasParent() bool { - return m.hasParent -} - -func (m *mockUpgrader) Ready() error { return nil } - -func (m *mockUpgrader) Upgrade() error { - // to upgrade we close the exit channel - close(m.exit) - return nil -} - -type testServer struct { - t *testing.T - ctx context.Context - server *http.Server - listeners map[string]net.Listener - url string -} - -func (s *testServer) slowRequest(duration time.Duration) <-chan error { - done := make(chan error) - - go func() { - request, err := http.NewRequestWithContext(s.ctx, http.MethodGet, fmt.Sprintf("%sslow?seconds=%d", s.url, int(duration.Seconds())), nil) - require.NoError(s.t, err) - - response, err := http.DefaultClient.Do(request) - if response != nil { - _, err := io.Copy(io.Discard, response.Body) - require.NoError(s.t, err) - require.NoError(s.t, response.Body.Close()) - } - - done <- err - }() - - return done -} - -func TestCreateUnixListener(t *testing.T) { - tempDir := testhelper.TempDir(t) - - socketPath := filepath.Join(tempDir, "gitaly-test-unix-socket") - if err := os.Remove(socketPath); err != nil { - require.True(t, os.IsNotExist(err), "cannot delete dangling socket: %v", err) - } - - // simulate a dangling socket - require.NoError(t, os.WriteFile(socketPath, nil, 0o755)) - - listen := func(network, addr string) (net.Listener, error) { - require.Equal(t, "unix", network) - require.Equal(t, socketPath, addr) +func TestBootstrap_unixListener(t *testing.T) { + for _, tc := range []struct { + desc string + hasParent bool + preexistingSocket bool + expectSocketExists bool + }{ + { + desc: "no parent, no preexisting socket", + hasParent: false, + preexistingSocket: false, + expectSocketExists: false, + }, + { + desc: "no parent, preexisting socket", + hasParent: false, + preexistingSocket: true, + // On first boot, the bootstrapper is expected to remove any preexisting + // sockets. + expectSocketExists: false, + }, + { + desc: "parent, no preexisting socket", + hasParent: true, + preexistingSocket: false, + expectSocketExists: false, + }, + { + desc: "parent, preexisting socket", + hasParent: true, + preexistingSocket: true, + // When we do have a parent, then we cannot remove the socket or otherwise + // we might impact the parent process's ability to serve requests. + expectSocketExists: true, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + tempDir := testhelper.TempDir(t) + socketPath := filepath.Join(tempDir, "gitaly-test-unix-socket") + + sentinel := &mockListener{} + listen := func(network, addr string) (net.Listener, error) { + require.Equal(t, "unix", network) + require.Equal(t, socketPath, addr) + if tc.expectSocketExists { + require.FileExists(t, socketPath) + } else { + require.NoFileExists(t, socketPath) + } + + return sentinel, nil + } - return net.Listen(network, addr) - } - u := &mockUpgrader{} - b, err := _new(u, listen, false) - require.NoError(t, err) + upgrader := &mockUpgrader{ + hasParent: tc.hasParent, + } - // first boot - l, err := b.listen("unix", socketPath) - require.NoError(t, err, "failed to bind on first boot") - require.NoError(t, l.Close()) + b, err := _new(upgrader, listen, false) + require.NoError(t, err) - // simulate binding during an upgrade - u.hasParent = true - l, err = b.listen("unix", socketPath) - require.NoError(t, err, "failed to bind on upgrade") - require.NoError(t, l.Close()) -} + if tc.preexistingSocket { + require.NoError(t, os.WriteFile(socketPath, nil, 0o755)) + } -func waitWithTimeout(t *testing.T, waitCh <-chan error, timeout time.Duration) error { - select { - case <-time.After(timeout): - t.Fatal("time out waiting for waitCh") - case waitErr := <-waitCh: - return waitErr + listener, err := b.listen("unix", socketPath) + require.NoError(t, err) + require.Equal(t, sentinel, listener) + }) } - - return nil } -func TestImmediateTerminationOnSocketError(t *testing.T) { +func TestBootstrap_listenerError(t *testing.T) { ctx, cancel := testhelper.Context() defer cancel() - b, server, stopAction := makeBootstrap(t, ctx) + b, upgrader, listeners := setup(t, ctx) waitCh := make(chan error) - go func() { waitCh <- b.Wait(2*time.Second, stopAction) }() + go func() { waitCh <- b.Wait(helper.NewManualTicker(), nil) }() + + // Signal readiness, but don't start the upgrade. Like this, we can close the listener in a + // raceless manner and wait for the error to propagate. + upgrader.readyCh <- nil - require.NoError(t, server.listeners["tcp"].Close(), "Closing first listener") + // Inject a listener error. + listeners["tcp"].errorCh <- assert.AnError - err := waitWithTimeout(t, waitCh, 1*time.Second) - require.Error(t, err) - require.True(t, errors.Is(err, net.ErrClosed), "expected closed connection error, got %T: %q", err, err) + require.Equal(t, assert.AnError, <-waitCh) } -func TestImmediateTerminationOnSignal(t *testing.T) { +func TestBootstrap_signal(t *testing.T) { for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} { t.Run(sig.String(), func(t *testing.T) { ctx, cancel := testhelper.Context() defer cancel() - b, server, stopAction := makeBootstrap(t, ctx) - - done := server.slowRequest(3 * time.Minute) + b, upgrader, _ := setup(t, ctx) waitCh := make(chan error) - go func() { waitCh <- b.Wait(2*time.Second, stopAction) }() + go func() { waitCh <- b.Wait(helper.NewManualTicker(), nil) }() - // make sure we are inside b.Wait() or we'll kill the test suite - time.Sleep(100 * time.Millisecond) + // Start the upgrade, but don't unblock `Exit()` such that we'll be blocked + // waiting on the parent. + upgrader.readyCh <- nil + // We can now kill ourselves. This signal should be retrieved by `Wait()`, + // which would then return an error. self, err := os.FindProcess(os.Getpid()) require.NoError(t, err) require.NoError(t, self.Signal(sig)) - waitErr := waitWithTimeout(t, waitCh, 1*time.Second) - require.Error(t, waitErr) - require.Contains(t, waitErr.Error(), "received signal") - require.Contains(t, waitErr.Error(), sig.String()) - - server.server.Close() - - require.Error(t, <-done) + require.Equal(t, fmt.Errorf("received signal %q", sig), <-waitCh) }) } } -func TestGracefulTerminationStuck(t *testing.T) { +func TestBootstrap_gracefulTerminationStuck(t *testing.T) { ctx, cancel := testhelper.Context() defer cancel() - b, server, stopAction := makeBootstrap(t, ctx) + b, upgrader, _ := setup(t, ctx) - err := testGracefulUpdate(t, server, b, 3*time.Second, 2*time.Second, nil, stopAction) - require.Contains(t, err.Error(), "grace period expired") -} + gracePeriodTicker := helper.NewManualTicker() -func TestGracefulTerminationWithSignals(t *testing.T) { - self, err := os.FindProcess(os.Getpid()) - require.NoError(t, err) + doneCh := make(chan struct{}) + err := performUpgrade(t, b, upgrader, gracePeriodTicker, nil, func() { + defer close(doneCh) + + gracePeriodTicker.Tick() + + // We block on context cancellation here, which essentially means that this won't + // terminate and thus the graceful termination will be stuck. + <-ctx.Done() + }) + require.Equal(t, fmt.Errorf("graceful upgrade: grace period expired"), err) + cancel() + <-doneCh +} + +func TestBootstrap_gracefulTerminationWithSignals(t *testing.T) { for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} { t.Run(sig.String(), func(t *testing.T) { ctx, cancel := testhelper.Context() defer cancel() - b, server, stopAction := makeBootstrap(t, ctx) - err := testGracefulUpdate(t, server, b, 1*time.Second, 2*time.Second, func() { + b, upgrader, _ := setup(t, ctx) + + doneCh := make(chan struct{}) + err := performUpgrade(t, b, upgrader, helper.NewManualTicker(), func() { + self, err := os.FindProcess(os.Getpid()) + require.NoError(t, err) require.NoError(t, self.Signal(sig)) - }, stopAction) - require.Contains(t, err.Error(), "force shutdown") + }, func() { + defer close(doneCh) + // Block the upgrade indefinitely such that we can be sure that the + // signal was processed. + <-ctx.Done() + }) + require.Equal(t, fmt.Errorf("graceful upgrade: force shutdown"), err) + + cancel() + <-doneCh }) } } -func TestGracefulTerminationServerErrors(t *testing.T) { +func TestBootstrap_gracefulTerminationTimeoutWithListenerError(t *testing.T) { ctx, cancel := testhelper.Context() defer cancel() - b, server, _ := makeBootstrap(t, ctx) - done := make(chan error, 1) - // This is a simulation of receiving a listener error during waitGracePeriod - stopAction := func() { - // we close the unix listener in order to test that the shutdown will not fail, but it keep waiting for the TCP request - require.NoError(t, server.listeners["unix"].Close()) + b, upgrader, listeners := setup(t, ctx) - // we start a new TCP request that if faster than the grace period - req := server.slowRequest(time.Second) - done <- <-req - close(done) + gracePeriodTicker := helper.NewManualTicker() - require.NoError(t, server.server.Shutdown(context.Background())) - } + doneCh := make(chan struct{}) + err := performUpgrade(t, b, upgrader, gracePeriodTicker, nil, func() { + defer close(doneCh) + + // We inject an error into the Unix socket to assert that this won't kill the server + // immediately, but waits for the TCP connection to terminate as expected. + listeners["unix"].errorCh <- assert.AnError + + gracePeriodTicker.Tick() - err := testGracefulUpdate(t, server, b, 3*time.Second, 2*time.Second, nil, stopAction) - require.Contains(t, err.Error(), "grace period expired") + // We block on context cancellation here, which essentially means that this won't + // terminate. + <-ctx.Done() + }) + require.Equal(t, fmt.Errorf("graceful upgrade: grace period expired"), err) - require.NoError(t, <-done) + cancel() + <-doneCh } -func TestGracefulTermination(t *testing.T) { +func TestBootstrap_gracefulTermination(t *testing.T) { ctx, cancel := testhelper.Context() defer cancel() - b, server, _ := makeBootstrap(t, ctx) - // Using server.Close we bypass the graceful shutdown faking a completed shutdown - stopAction := func() { server.server.Close() } + b, upgrader, _ := setup(t, ctx) - err := testGracefulUpdate(t, server, b, 1*time.Second, 2*time.Second, nil, stopAction) - require.Contains(t, err.Error(), "completed") + require.Equal(t, + fmt.Errorf("graceful upgrade: completed"), + performUpgrade(t, b, upgrader, helper.NewManualTicker(), nil, nil), + ) } -func TestPortReuse(t *testing.T) { +func TestBootstrap_portReuse(t *testing.T) { b, err := New() require.NoError(t, err) @@ -244,84 +240,54 @@ func TestPortReuse(t *testing.T) { b.upgrader.Stop() } -func testGracefulUpdate(t *testing.T, server *testServer, b *Bootstrap, waitTimeout, gracefulWait time.Duration, duringGracePeriodCallback func(), stopAction func()) error { +func performUpgrade( + t *testing.T, + b *Bootstrap, + upgrader *mockUpgrader, + gracePeriodTicker helper.Ticker, + duringGracePeriodCallback func(), + stopAction func(), +) error { waitCh := make(chan error) - go func() { waitCh <- b.Wait(gracefulWait, stopAction) }() - - // Start a slow request to keep the old server from shutting down immediately. - req := server.slowRequest(2 * gracefulWait) - - // make sure slow request is being handled - time.Sleep(100 * time.Millisecond) + go func() { waitCh <- b.Wait(gracePeriodTicker, stopAction) }() - // Simulate an upgrade request after entering into the blocking b.Wait() and during the slowRequest execution - require.NoError(t, b.upgrader.Upgrade()) + // Simulate an upgrade request after entering into the blocking b.Wait() and during the + // slowRequest execution + upgrader.readyCh <- nil + upgrader.exitCh <- struct{}{} + // We know that `exitCh` has been consumed, so we're now in the grace period where we wait + // for the old server to exit. if duringGracePeriodCallback != nil { - // make sure we are on the grace period - time.Sleep(100 * time.Millisecond) - duringGracePeriodCallback() } - waitErr := waitWithTimeout(t, waitCh, waitTimeout) - require.Error(t, waitErr) - require.Contains(t, waitErr.Error(), "graceful upgrade") - - server.server.Close() - - clientErr := waitWithTimeout(t, req, 1*time.Second) - require.Error(t, clientErr, "slow request not terminated after the grace period") - - return waitErr + return <-waitCh } -func makeBootstrap(t *testing.T, ctx context.Context) (*Bootstrap, *testServer, func()) { - mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(200) - }) - mux.HandleFunc("/slow", func(w http.ResponseWriter, r *http.Request) { - sec, err := strconv.Atoi(r.URL.Query().Get("seconds")) - require.NoError(t, err) - - select { - case <-ctx.Done(): - case <-time.After(time.Duration(sec) * time.Second): - } - - w.WriteHeader(200) - }) - - s := http.Server{Handler: mux} - t.Cleanup(func() { testhelper.MustClose(t, &s) }) - u := &mockUpgrader{exit: make(chan struct{})} +func setup(t *testing.T, ctx context.Context) (*Bootstrap, *mockUpgrader, mockListeners) { + u := &mockUpgrader{ + exitCh: make(chan struct{}), + readyCh: make(chan error), + } b, err := _new(u, net.Listen, false) require.NoError(t, err) - listeners := make(map[string]net.Listener) + listeners := mockListeners{} start := func(network, address string) Starter { - return func(listen ListenFunc, errors chan<- error) error { - l, err := listen(network, address) - if err != nil { - return err - } - listeners[network] = l - - go func() { - errors <- s.Serve(l) - }() + listeners[network] = &mockListener{} + return func(listen ListenFunc, errors chan<- error) error { + listeners[network].errorCh = errors + listeners[network].listening = true return nil } } - tempDir := testhelper.TempDir(t) - for network, address := range map[string]string{ "tcp": "127.0.0.1:0", - "unix": filepath.Join(tempDir, "gitaly-test-unix-socket"), + "unix": "some-socket", } { b.RegisterStarter(start(network, address)) } @@ -329,44 +295,50 @@ func makeBootstrap(t *testing.T, ctx context.Context) (*Bootstrap, *testServer, require.NoError(t, b.Start()) require.Equal(t, 2, len(listeners)) - // test connection - testAllListeners(t, ctx, listeners) + for _, listener := range listeners { + require.True(t, listener.listening) + } - addr := listeners["tcp"].Addr() - url := fmt.Sprintf("http://%s/", addr.String()) + return b, u, listeners +} - return b, &testServer{ - t: t, - ctx: ctx, - server: &s, - listeners: listeners, - url: url, - }, func() { require.NoError(t, s.Shutdown(context.Background())) } +type mockUpgrader struct { + exitCh chan struct{} + readyCh chan error + hasParent bool } -func testAllListeners(t *testing.T, ctx context.Context, listeners map[string]net.Listener) { - for network, listener := range listeners { - addr := listener.Addr().String() - - // overriding Client.Transport.Dial we can connect to TCP and UNIX sockets - client := &http.Client{ - Transport: &http.Transport{ - Dial: func(_, _ string) (net.Conn, error) { - return net.Dial(network, addr) - }, - }, - } +func (m *mockUpgrader) Exit() <-chan struct{} { + return m.exitCh +} - request, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://fakeHost/", nil) - require.NoError(t, err) +func (m *mockUpgrader) Stop() {} - r, err := client.Do(request) - require.NoError(t, err) +func (m *mockUpgrader) HasParent() bool { + return m.hasParent +} - _, err = io.Copy(io.Discard, r.Body) - require.NoError(t, err) - require.NoError(t, r.Body.Close()) +func (m *mockUpgrader) Ready() error { + return <-m.readyCh +} - require.Equal(t, 200, r.StatusCode) - } +func (m *mockUpgrader) Upgrade() error { + // To upgrade, we send a message on the exit channel. Like this, we can assert that the exit + // signal has been consumed given that we'd otherwise block forever. + m.exitCh <- struct{}{} + return nil } + +type mockListener struct { + net.Listener + errorCh chan<- error + closed bool + listening bool +} + +func (m *mockListener) Close() error { + m.closed = true + return nil +} + +type mockListeners map[string]*mockListener |