gitlab.com/gitlab-org/gitaly.git

author    Patrick Steinhardt <psteinhardt@gitlab.com>  2020-09-07 10:29:19 +0300
committer Patrick Steinhardt <psteinhardt@gitlab.com>  2020-09-07 11:16:11 +0300
commit    86c5d708df567f605fe44f48389c504b9f170b3d (patch)
tree      33274ea1d9cc584cce18df4763294c2582ab0960 /internal/gitaly/rubyserver
parent    8a0d83b496eaff459d55bc6825428f648059e8d9 (diff)

gitaly: Move Gitaly-specific code into `internal/gitaly`

Since the introduction of Praefect, our code layout started to become
confusing: while Praefect code lives in `internal/praefect`, Gitaly-specific
code is all over the place and not neatly singled out. This makes it hard at
times to tell Praefect- and Gitaly-specific code apart from generic code.

To improve the situation, this commit thus moves most of the server-specific
code into a new `internal/gitaly` package. Currently, this is the
`internal/config`, `internal/server`, `internal/service` and
`internal/rubyserver` packages, which are all main components of Gitaly.

The move was realized with the following script:

    #!/bin/sh
    mkdir -p internal/gitaly
    git mv internal/{config,server,service,rubyserver} internal/gitaly/
    find . -name '*.go' -exec sed -i \
        -e 's|gitlab-org/gitaly/internal/rubyserver|gitlab-org/gitaly/internal/gitaly/rubyserver|' \
        -e 's|gitlab-org/gitaly/internal/server|gitlab-org/gitaly/internal/gitaly/server|' \
        -e 's|gitlab-org/gitaly/internal/service|gitlab-org/gitaly/internal/gitaly/service|' \
        -e 's|gitlab-org/gitaly/internal/config|gitlab-org/gitaly/internal/gitaly/config|' {} \;

In addition to that, some minor adjustments were needed for tests which used
relative paths.
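
To illustrate the effect of the rewrite (a sketch only; the actual edits were
produced mechanically by the sed expressions above), an import block that
references the moved packages changes like this:

    // before this commit
    import (
        "gitlab.com/gitlab-org/gitaly/internal/config"
        "gitlab.com/gitlab-org/gitaly/internal/rubyserver"
    )

    // after this commit
    import (
        "gitlab.com/gitlab-org/gitaly/internal/gitaly/config"
        "gitlab.com/gitlab-org/gitaly/internal/gitaly/rubyserver"
    )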
Diffstat (limited to 'internal/gitaly/rubyserver')
-rw-r--r--  internal/gitaly/rubyserver/balancer/balancer.go | 250
-rw-r--r--  internal/gitaly/rubyserver/balancer/balancer_test.go | 238
-rw-r--r--  internal/gitaly/rubyserver/balancer/pool.go | 59
-rw-r--r--  internal/gitaly/rubyserver/concurrency_test.go | 98
-rw-r--r--  internal/gitaly/rubyserver/health.go | 36
-rw-r--r--  internal/gitaly/rubyserver/health_test.go | 31
-rw-r--r--  internal/gitaly/rubyserver/proxy.go | 148
-rw-r--r--  internal/gitaly/rubyserver/proxy_test.go | 61
-rw-r--r--  internal/gitaly/rubyserver/rubyserver.go | 277
-rw-r--r--  internal/gitaly/rubyserver/rubyserver_test.go | 85
-rw-r--r--  internal/gitaly/rubyserver/stopwatch.go | 35
-rw-r--r--  internal/gitaly/rubyserver/testhelper_test.go | 26
-rw-r--r--  internal/gitaly/rubyserver/worker.go | 226
-rw-r--r--  internal/gitaly/rubyserver/worker_test.go | 239
14 files changed, 1809 insertions(+), 0 deletions(-)
diff --git a/internal/gitaly/rubyserver/balancer/balancer.go b/internal/gitaly/rubyserver/balancer/balancer.go
new file mode 100644
index 000000000..c2f1a8ae7
--- /dev/null
+++ b/internal/gitaly/rubyserver/balancer/balancer.go
@@ -0,0 +1,250 @@
+package balancer
+
+// In this package we manage a global pool of addresses for gitaly-ruby,
+// accessed via the gitaly-ruby:// scheme. The interface consists of the
+// AddAddress and RemoveAddress methods. RemoveAddress returns a boolean
+// indicating whether the address was removed; this is intended to give
+// back-pressure against repeated process restarts.
+//
+// The gitaly-ruby:// scheme exists because that is the way we can
+// interact with the internal client-side loadbalancer of grpc-go. A URL
+// for this scheme would be gitaly-ruby://foobar. For gitaly-ruby://
+// URL's, the host and port are ignored. So gitaly-ruby://foobar is
+// actually a working, valid address.
+//
+// Strictly speaking this package implements a gRPC 'Resolver'. This
+// resolver feeds address list updates to a gRPC 'balancer' which
+// interacts with the gRPC client connection machinery. A resolver
+// consists of a Builder which returns Resolver instances. Our Builder
+// manages the address pool and notifies its Resolver instances of
+// changes, which they then propagate into the gRPC library.
+//
+
+import (
+ "time"
+
+ "google.golang.org/grpc/resolver"
+)
+
+var (
+ lbBuilder = newBuilder()
+)
+
+func init() {
+ resolver.Register(lbBuilder)
+}
+
+const (
+ // DefaultRemoveDelay is the minimum time between successive address removals.
+ DefaultRemoveDelay = 1 * time.Minute
+)
+
+// AddAddress adds the address of a gitaly-ruby instance to the load
+// balancer.
+func AddAddress(a string) {
+ lbBuilder.addAddress <- a
+}
+
+// RemoveAddress removes the address of a gitaly-ruby instance from the
+// load balancer. Returns false if the pool is too small to remove the
+// address.
+func RemoveAddress(addr string) bool {
+ ok := make(chan bool)
+ lbBuilder.removeAddress <- addressRemoval{ok: ok, addr: addr}
+ return <-ok
+}
+
+type addressRemoval struct {
+ addr string
+ ok chan<- bool
+}
+
+type addressUpdate struct {
+ addrs []resolver.Address
+ next chan struct{}
+}
+
+type config struct {
+ numAddrs int
+ removeDelay time.Duration
+}
+
+type builder struct {
+ addAddress chan string
+ removeAddress chan addressRemoval
+ addressUpdates chan addressUpdate
+ configUpdate chan config
+
+ // testingTriggerRestart is for testing only. It causes b.monitor(...) to
+ // re-execute.
+ testingTriggerRestart chan struct{}
+}
+
+// ConfigureBuilder changes the configuration of the global balancer
+// instance. All calls that interact with the balancer will block until
+// ConfigureBuilder has been called at least once.
+func ConfigureBuilder(numAddrs int, removeDelay time.Duration) {
+ cfg := config{
+ numAddrs: numAddrs,
+ removeDelay: removeDelay,
+ }
+
+ if cfg.removeDelay <= 0 {
+ cfg.removeDelay = DefaultRemoveDelay
+ }
+ if numAddrs <= 0 {
+ panic("numAddrs must be at least 1")
+ }
+
+ lbBuilder.configUpdate <- cfg
+}
+
+func newBuilder() *builder {
+ b := &builder{
+ addAddress: make(chan string),
+ removeAddress: make(chan addressRemoval),
+ addressUpdates: make(chan addressUpdate),
+ configUpdate: make(chan config),
+ testingTriggerRestart: make(chan struct{}),
+ }
+ go b.monitor()
+
+ return b
+}
+
+// Scheme is the name of the address scheme that makes gRPC select this resolver.
+const Scheme = "gitaly-ruby"
+
+func (*builder) Scheme() string { return Scheme }
+
+// Build ignores its resolver.Target argument. That means it does not
+// care what "address" the caller wants to resolve. We always resolve to
+// the current list of addresses for local gitaly-ruby processes.
+func (b *builder) Build(_ resolver.Target, cc resolver.ClientConn, _ resolver.BuildOption) (resolver.Resolver, error) {
+ //nolint:staticcheck // There is no obvious way to use UpdateState() without completely replacing the current configuration
+ cc.NewServiceConfig(`{"LoadBalancingPolicy":"round_robin"}`)
+ return newGitalyResolver(cc, b.addressUpdates), nil
+}
+
+// monitor serves address list requests and handles address updates.
+func (b *builder) monitor() {
+ p := newPool()
+ notify := make(chan struct{})
+ cfg := <-b.configUpdate
+
+ // At this point, there has been no previous removal command yet, so the
+ // "last removal" is undefined. We want it to default to "long enough
+ // ago".
+ lastRemoval := time.Now().Add(-1 * time.Hour)
+
+ // This channel is intentionally nil so that our 'select' below won't
+ // send messages to it. We do this to prevent sending out invalid (empty)
+ // messages during boot.
+ var addressUpdates chan addressUpdate
+
+ for {
+ au := addressUpdate{next: notify}
+ for _, a := range p.activeAddrs() {
+ au.addrs = append(au.addrs, resolver.Address{Addr: a})
+ }
+
+ if len(au.addrs) > 0 && addressUpdates == nil {
+ // Start listening for address update requests
+ addressUpdates = b.addressUpdates
+ }
+
+ select {
+ case addressUpdates <- au:
+ // We have served an address update request
+ case addr := <-b.addAddress:
+ p.add(addr)
+
+ notify = broadcast(notify)
+ case removal := <-b.removeAddress:
+ if time.Since(lastRemoval) < cfg.removeDelay || p.activeSize() < cfg.numAddrs-1 {
+ removal.ok <- false
+ break
+ }
+
+ if !p.remove(removal.addr) {
+ removal.ok <- false
+ break
+ }
+
+ removal.ok <- true
+ lastRemoval = time.Now()
+ notify = broadcast(notify)
+ case cfg = <-b.configUpdate:
+ // We have received a config update
+ case <-b.testingTriggerRestart:
+ go b.monitor()
+ b.configUpdate <- cfg
+ return
+ }
+ }
+}
+
+// broadcast returns a fresh channel because a channel can only be closed once
+func broadcast(ch chan struct{}) chan struct{} {
+ close(ch)
+ return make(chan struct{})
+}
+
+// gitalyResolver propagates address list updates to a
+// resolver.ClientConn instance
+type gitalyResolver struct {
+ clientConn resolver.ClientConn
+
+ started chan struct{}
+ done chan struct{}
+ resolveNow chan struct{}
+ addressUpdates chan addressUpdate
+}
+
+func newGitalyResolver(cc resolver.ClientConn, auCh chan addressUpdate) *gitalyResolver {
+ r := &gitalyResolver{
+ started: make(chan struct{}),
+ done: make(chan struct{}),
+ resolveNow: make(chan struct{}),
+ addressUpdates: auCh,
+ clientConn: cc,
+ }
+ go r.monitor()
+
+ // Don't return until we have sent at least one address update. This is
+ // meant to avoid panics inside the grpc-go library.
+ <-r.started
+
+ return r
+}
+
+func (r *gitalyResolver) ResolveNow(resolver.ResolveNowOption) {
+ r.resolveNow <- struct{}{}
+}
+
+func (r *gitalyResolver) Close() {
+ close(r.done)
+}
+
+func (r *gitalyResolver) monitor() {
+ notify := r.sendUpdate()
+ close(r.started)
+
+ for {
+ select {
+ case <-notify:
+ notify = r.sendUpdate()
+ case <-r.resolveNow:
+ notify = r.sendUpdate()
+ case <-r.done:
+ return
+ }
+ }
+}
+
+func (r *gitalyResolver) sendUpdate() chan struct{} {
+ au := <-r.addressUpdates
+ //nolint:staticcheck // There is no obvious way to use UpdateState() without completely replacing the current configuration
+ r.clientConn.NewAddress(au.addrs)
+ return au.next
+}
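
A minimal usage sketch of the balancer package above (not part of this commit;
the socket paths are hypothetical and assume gitaly-ruby workers are already
listening on them):

    package main

    import (
        "context"
        "log"
        "net"

        "gitlab.com/gitlab-org/gitaly/internal/gitaly/rubyserver/balancer"
        "google.golang.org/grpc"
    )

    func main() {
        // Configure first: calls into the balancer block until this has happened.
        balancer.ConfigureBuilder(2, 0) // two addresses, default removal delay

        balancer.AddAddress("/tmp/gitaly-ruby.0") // becomes the standby
        balancer.AddAddress("/tmp/gitaly-ruby.1") // becomes active

        // Host and port of the target are ignored; the scheme selects the resolver.
        conn, err := grpc.Dial(
            balancer.Scheme+":///gitaly-ruby",
            grpc.WithInsecure(),
            // Addresses handed out by the balancer are Unix socket paths.
            grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
                d := net.Dialer{}
                return d.DialContext(ctx, "unix", addr)
            }),
        )
        if err != nil {
            log.Fatal(err)
        }
        defer conn.Close()

        // Removals are rate limited and refused if the active pool would shrink too far.
        if !balancer.RemoveAddress("/tmp/gitaly-ruby.1") {
            log.Print("balancer refused the removal")
        }
    }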
diff --git a/internal/gitaly/rubyserver/balancer/balancer_test.go b/internal/gitaly/rubyserver/balancer/balancer_test.go
new file mode 100644
index 000000000..269a37ece
--- /dev/null
+++ b/internal/gitaly/rubyserver/balancer/balancer_test.go
@@ -0,0 +1,238 @@
+package balancer
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+ "google.golang.org/grpc/resolver"
+)
+
+func TestServiceConfig(t *testing.T) {
+ configureBuilderTest(3)
+
+ tcc := &testClientConn{}
+ lbBuilder.Build(resolver.Target{}, tcc, resolver.BuildOption{})
+
+ configUpdates := tcc.ConfigUpdates()
+ require.Len(t, configUpdates, 1, "expect exactly one config update")
+
+ svcConfig := struct{ LoadBalancingPolicy string }{}
+ require.NoError(t, json.NewDecoder(strings.NewReader(configUpdates[0])).Decode(&svcConfig))
+ require.Equal(t, "round_robin", svcConfig.LoadBalancingPolicy)
+}
+
+func TestAddressUpdatesSmallestPool(t *testing.T) {
+ // The smallest number of addresses is 2: 1 standby, and 1 active.
+ addrs := configureBuilderTest(2)
+
+ tcc := &testClientConn{}
+ lbBuilder.Build(resolver.Target{}, tcc, resolver.BuildOption{})
+
+ // Simulate some random updates
+ RemoveAddress(addrs[0])
+ RemoveAddress(addrs[0])
+ AddAddress(addrs[0])
+ RemoveAddress(addrs[1])
+ RemoveAddress(addrs[0])
+ AddAddress(addrs[1])
+ AddAddress(addrs[1])
+ RemoveAddress(addrs[1])
+ RemoveAddress(addrs[1])
+ RemoveAddress(addrs[1])
+ RemoveAddress(addrs[0])
+ AddAddress(addrs[0])
+
+ addrUpdates := tcc.AddrUpdates()
+ require.True(t, len(addrUpdates) > 0, "expected at least one address update")
+
+ expectedActive := len(addrs) - 1 // subtract 1 for the standby
+ for _, update := range addrUpdates {
+ require.Equal(t, expectedActive, len(update))
+ }
+}
+
+func TestAddressUpdatesRoundRobinPool(t *testing.T) {
+ // With 3 addresses in the pool, 2 will be active.
+ addrs := configureBuilderTest(3)
+
+ tcc := &testClientConn{}
+ lbBuilder.Build(resolver.Target{}, tcc, resolver.BuildOption{})
+
+ // Simulate some random updates
+ RemoveAddress(addrs[0])
+ RemoveAddress(addrs[0])
+ RemoveAddress(addrs[2])
+ AddAddress(addrs[0])
+ RemoveAddress(addrs[1])
+ RemoveAddress(addrs[0])
+ AddAddress(addrs[2])
+ AddAddress(addrs[1])
+ AddAddress(addrs[1])
+ RemoveAddress(addrs[1])
+ RemoveAddress(addrs[2])
+ RemoveAddress(addrs[1])
+ AddAddress(addrs[1])
+ RemoveAddress(addrs[2])
+ RemoveAddress(addrs[1])
+ RemoveAddress(addrs[0])
+ AddAddress(addrs[0])
+
+ addrUpdates := tcc.AddrUpdates()
+ require.True(t, len(addrUpdates) > 0, "expected at least one address update")
+
+ expectedActive := len(addrs) - 1 // subtract 1 for the standby
+ for _, update := range addrUpdates {
+ require.Equal(t, expectedActive, len(update))
+ }
+}
+
+func TestRemovals(t *testing.T) {
+ okActions := []action{
+ {add: "foo"},
+ {add: "bar"},
+ {add: "qux"},
+ {remove: "bar"},
+ {add: "baz"},
+ {remove: "foo"},
+ }
+ numAddr := 3
+ removeDelay := 1 * time.Millisecond
+ ConfigureBuilder(numAddr, removeDelay)
+
+ testCases := []struct {
+ desc string
+ actions []action
+ lastFails bool
+ delay time.Duration
+ }{
+ {
+ desc: "add then remove",
+ actions: okActions,
+ delay: 2 * removeDelay,
+ },
+ {
+ desc: "add then remove but too fast",
+ actions: okActions,
+ lastFails: true,
+ delay: 0,
+ },
+ {
+ desc: "remove one address too many",
+ actions: append(okActions, action{remove: "qux"}),
+ lastFails: true,
+ delay: 2 * removeDelay,
+ },
+ {
+ desc: "remove unknown address",
+ actions: []action{
+ {add: "foo"},
+ {add: "qux"},
+ {add: "baz"},
+ {remove: "bar"},
+ },
+ lastFails: true,
+ delay: 2 * removeDelay,
+ },
+ {
+ // This relies on the implementation detail that the first address added
+ // to the balancer is the standby. The standby cannot be removed.
+ desc: "remove standby address",
+ actions: []action{
+ {add: "foo"},
+ {add: "qux"},
+ {add: "baz"},
+ {remove: "foo"},
+ },
+ lastFails: true,
+ delay: 2 * removeDelay,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ lbBuilder.testingTriggerRestart <- struct{}{}
+
+ for i, a := range tc.actions {
+ if a.add != "" {
+ AddAddress(a.add)
+ } else {
+ if tc.delay > 0 {
+ time.Sleep(tc.delay)
+ }
+
+ expected := true
+ if i+1 == len(tc.actions) && tc.lastFails {
+ expected = false
+ }
+
+ require.Equal(t, expected, RemoveAddress(a.remove), "expected result from removing %q", a.remove)
+ }
+ }
+ })
+ }
+}
+
+type action struct {
+ add string
+ remove string
+}
+
+type testClientConn struct {
+ resolver.ClientConn
+
+ addrUpdates [][]resolver.Address
+ configUpdates []string
+ mu sync.Mutex
+}
+
+func (tcc *testClientConn) NewAddress(addresses []resolver.Address) {
+ tcc.mu.Lock()
+ defer tcc.mu.Unlock()
+
+ tcc.addrUpdates = append(tcc.addrUpdates, addresses)
+}
+
+func (tcc *testClientConn) NewServiceConfig(serviceConfig string) {
+ tcc.mu.Lock()
+ defer tcc.mu.Unlock()
+
+ tcc.configUpdates = append(tcc.configUpdates, serviceConfig)
+}
+
+func (tcc *testClientConn) AddrUpdates() [][]resolver.Address {
+ tcc.mu.Lock()
+ defer tcc.mu.Unlock()
+
+ return tcc.addrUpdates
+}
+
+func (tcc *testClientConn) ConfigUpdates() []string {
+ tcc.mu.Lock()
+ defer tcc.mu.Unlock()
+
+ return tcc.configUpdates
+}
+
+func (tcc *testClientConn) UpdateState(state resolver.State) {}
+
+// configureBuilderTest reconfigures the global builder and pre-populates
+// it with addresses. It returns the list of addresses it added.
+func configureBuilderTest(numAddrs int) []string {
+ delay := 1 * time.Millisecond
+ ConfigureBuilder(numAddrs, delay)
+ lbBuilder.testingTriggerRestart <- struct{}{}
+
+ var addrs []string
+ for i := 0; i < numAddrs; i++ {
+ a := fmt.Sprintf("test.%d", i)
+ AddAddress(a)
+ addrs = append(addrs, a)
+ }
+
+ return addrs
+}
diff --git a/internal/gitaly/rubyserver/balancer/pool.go b/internal/gitaly/rubyserver/balancer/pool.go
new file mode 100644
index 000000000..f34097990
--- /dev/null
+++ b/internal/gitaly/rubyserver/balancer/pool.go
@@ -0,0 +1,59 @@
+package balancer
+
+func newPool() *pool {
+ return &pool{active: make(map[string]struct{})}
+}
+
+// pool is a set that keeps one address (element) set aside as a standby.
+// This data structure is not thread safe.
+type pool struct {
+ standby string
+ active map[string]struct{}
+}
+
+// add is idempotent. If there is no standby address yet, addr becomes
+// the standby.
+func (p *pool) add(addr string) {
+ if _, ok := p.active[addr]; ok || p.standby == addr {
+ return
+ }
+
+ if p.standby == "" {
+ p.standby = addr
+ return
+ }
+
+ p.active[addr] = struct{}{}
+}
+
+func (p *pool) activeSize() int {
+ return len(p.active)
+}
+
+// remove tries to remove addr from the active addresses. If addr is not
+// known or not active, remove returns false.
+func (p *pool) remove(addr string) bool {
+ if _, ok := p.active[addr]; !ok || p.standby == "" {
+ return false
+ }
+
+ delete(p.active, addr)
+
+ // Promote the standby to an active address
+ p.active[p.standby] = struct{}{}
+ p.standby = ""
+
+ return true
+}
+
+// activeAddrs returns the currently active addresses as a list. The
+// order is not deterministic.
+func (p *pool) activeAddrs() []string {
+ var addrs []string
+
+ for a := range p.active {
+ addrs = append(addrs, a)
+ }
+
+ return addrs
+}
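
A hypothetical in-package test (not part of this commit) that walks through the
standby behaviour described in the comments above:

    package balancer

    import "testing"

    func TestPoolStandbySketch(t *testing.T) {
        p := newPool()
        p.add("a") // the first address becomes the standby
        p.add("b") // later addresses become active

        if p.activeSize() != 1 {
            t.Fatalf("expected 1 active address, got %d", p.activeSize())
        }

        // Removing an active address consumes the standby: "a" gets promoted.
        if !p.remove("b") {
            t.Fatal("expected removal of an active address to succeed")
        }

        // Without a standby left, further removals are refused.
        if p.remove("a") {
            t.Fatal("expected removal to fail when no standby is available")
        }
    }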
diff --git a/internal/gitaly/rubyserver/concurrency_test.go b/internal/gitaly/rubyserver/concurrency_test.go
new file mode 100644
index 000000000..1472ec389
--- /dev/null
+++ b/internal/gitaly/rubyserver/concurrency_test.go
@@ -0,0 +1,98 @@
+package rubyserver
+
+import (
+ "fmt"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitaly/internal/gitaly/config"
+ "gitlab.com/gitlab-org/gitaly/internal/testhelper"
+ "google.golang.org/grpc/codes"
+ healthpb "google.golang.org/grpc/health/grpc_health_v1"
+ "google.golang.org/grpc/status"
+)
+
+func waitPing(s *Server) error {
+ var err error
+ for start := time.Now(); time.Since(start) < ConnectTimeout; time.Sleep(100 * time.Millisecond) {
+ err = makeRequest(s)
+ if err == nil {
+ return nil
+ }
+ }
+ return err
+}
+
+// This benchmark lets you see what happens when you throw a lot of
+// concurrent traffic at gitaly-ruby.
+func BenchmarkConcurrency(b *testing.B) {
+ config.Config.Ruby.NumWorkers = 2
+
+ s := &Server{}
+ require.NoError(b, s.Start())
+ defer s.Stop()
+
+ // Warm-up: wait for gitaly-ruby to boot
+ if err := waitPing(s); err != nil {
+ b.Fatal(err)
+ }
+
+ concurrency := 100
+ b.Run(fmt.Sprintf("concurrency %d", concurrency), func(b *testing.B) {
+ errCh := make(chan error)
+ errCount := make(chan int)
+ go func() {
+ count := 0
+ for err := range errCh {
+ b.Log(err)
+ count++
+ }
+ errCount <- count
+ }()
+
+ wg := &sync.WaitGroup{}
+ for i := 0; i < concurrency; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ for j := 0; j < 1000; j++ {
+ err := makeRequest(s)
+ if err != nil {
+ errCh <- err
+ }
+
+ switch status.Code(err) {
+ case codes.Unavailable:
+ return
+ case codes.DeadlineExceeded:
+ return
+ }
+ }
+ }()
+ }
+
+ wg.Wait()
+ close(errCh)
+
+ if count := <-errCount; count != 0 {
+ b.Fatalf("received %d errors", count)
+ }
+ })
+}
+
+func makeRequest(s *Server) error {
+ ctx, cancel := testhelper.Context(testhelper.ContextWithTimeout(time.Second))
+ defer cancel()
+
+ conn, err := s.getConnection(ctx)
+ if err != nil {
+ return err
+ }
+
+ client := healthpb.NewHealthClient(conn)
+ _, err = client.Check(ctx, &healthpb.HealthCheckRequest{})
+ return err
+}
diff --git a/internal/gitaly/rubyserver/health.go b/internal/gitaly/rubyserver/health.go
new file mode 100644
index 000000000..cf12a8e63
--- /dev/null
+++ b/internal/gitaly/rubyserver/health.go
@@ -0,0 +1,36 @@
+package rubyserver
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "time"
+
+ "google.golang.org/grpc"
+ healthpb "google.golang.org/grpc/health/grpc_health_v1"
+)
+
+func ping(address string) error {
+ conn, err := grpc.Dial(
+ address,
+ grpc.WithInsecure(),
+ // Use a custom dialer to ensure that we don't experience
+ // issues in environments that have proxy configurations
+ // https://gitlab.com/gitlab-org/gitaly/merge_requests/1072#note_140408512
+ grpc.WithContextDialer(func(ctx context.Context, addr string) (conn net.Conn, err error) {
+ d := net.Dialer{}
+ return d.DialContext(ctx, "unix", addr)
+ }),
+ )
+ if err != nil {
+ return fmt.Errorf("failed to connect to gitaly-ruby worker: %v", err)
+ }
+ defer conn.Close()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+ defer cancel()
+
+ client := healthpb.NewHealthClient(conn)
+ _, err = client.Check(ctx, &healthpb.HealthCheckRequest{})
+ return err
+}
diff --git a/internal/gitaly/rubyserver/health_test.go b/internal/gitaly/rubyserver/health_test.go
new file mode 100644
index 000000000..c074435e8
--- /dev/null
+++ b/internal/gitaly/rubyserver/health_test.go
@@ -0,0 +1,31 @@
+package rubyserver
+
+import (
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestPingSuccess(t *testing.T) {
+ s := &Server{}
+ require.NoError(t, s.Start())
+ defer s.Stop()
+
+ require.True(t, len(s.workers) > 0, "expected at least one worker in server")
+ w := s.workers[0]
+
+ var pingErr error
+ for start := time.Now(); time.Since(start) < ConnectTimeout; time.Sleep(100 * time.Millisecond) {
+ pingErr = ping(w.address)
+ if pingErr == nil {
+ break
+ }
+ }
+
+ require.NoError(t, pingErr, "health check should pass")
+}
+
+func TestPingFail(t *testing.T) {
+ require.Error(t, ping("fake address"), "health check should fail")
+}
diff --git a/internal/gitaly/rubyserver/proxy.go b/internal/gitaly/rubyserver/proxy.go
new file mode 100644
index 000000000..17c64c6ba
--- /dev/null
+++ b/internal/gitaly/rubyserver/proxy.go
@@ -0,0 +1,148 @@
+package rubyserver
+
+import (
+ "context"
+ "io"
+ "os"
+ "strings"
+
+ "gitlab.com/gitlab-org/gitaly/internal/helper"
+ praefect_metadata "gitlab.com/gitlab-org/gitaly/internal/praefect/metadata"
+ "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+ "google.golang.org/grpc/metadata"
+)
+
+// Headers prefixed with this string get whitelisted automatically
+const rubyFeaturePrefix = "gitaly-feature-ruby-"
+
+const (
+ storagePathHeader = "gitaly-storage-path"
+ repoPathHeader = "gitaly-repo-path"
+ glRepositoryHeader = "gitaly-gl-repository"
+ repoAltDirsHeader = "gitaly-repo-alt-dirs"
+)
+
+// SetHeadersWithoutRepoCheck adds headers that tell gitaly-ruby the full
+// path to the repository. It is not an error if the repository does not
+// yet exist. This can be used on RPC calls that will create a
+// repository.
+func SetHeadersWithoutRepoCheck(ctx context.Context, repo *gitalypb.Repository) (context.Context, error) {
+ return setHeaders(ctx, repo, false)
+}
+
+// SetHeaders adds headers that tell gitaly-ruby the full path to the repository.
+func SetHeaders(ctx context.Context, repo *gitalypb.Repository) (context.Context, error) {
+ return setHeaders(ctx, repo, true)
+}
+
+func setHeaders(ctx context.Context, repo *gitalypb.Repository, mustExist bool) (context.Context, error) {
+ storagePath, err := helper.GetStorageByName(repo.GetStorageName())
+ if err != nil {
+ return nil, err
+ }
+
+ var repoPath string
+ if mustExist {
+ repoPath, err = helper.GetRepoPath(repo)
+ } else {
+ repoPath, err = helper.GetPath(repo)
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ repoAltDirs := repo.GetGitAlternateObjectDirectories()
+ repoAltDirs = append(repoAltDirs, repo.GetGitObjectDirectory())
+ repoAltDirsCombined := strings.Join(repoAltDirs, string(os.PathListSeparator))
+
+ md := metadata.Pairs(
+ storagePathHeader, storagePath,
+ repoPathHeader, repoPath,
+ glRepositoryHeader, repo.GlRepository,
+ repoAltDirsHeader, repoAltDirsCombined,
+ )
+
+ // While it looks weird that we're extracting and then re-injecting the
+ // Praefect server info into the context, `PraefectFromContext()` will
+ // also resolve connection information from the context's peer info.
+ // Thus the re-injected connection info will contain resolved addresses.
+ if praefectServer, err := praefect_metadata.PraefectFromContext(ctx); err == nil {
+ ctx, err = praefectServer.Inject(ctx)
+ if err != nil {
+ return nil, err
+ }
+ } else if err != praefect_metadata.ErrPraefectServerNotFound {
+ return nil, err
+ }
+
+ // list of http/2 headers that will be forwarded as-is to gitaly-ruby
+ proxyHeaderWhitelist := []string{
+ "gitaly-servers",
+ praefect_metadata.TransactionMetadataKey,
+ praefect_metadata.PraefectMetadataKey,
+ }
+
+ if inMD, ok := metadata.FromIncomingContext(ctx); ok {
+ // Automatically whitelist any Ruby-specific feature flag
+ for header := range inMD {
+ if strings.HasPrefix(header, rubyFeaturePrefix) {
+ proxyHeaderWhitelist = append(proxyHeaderWhitelist, header)
+ }
+ }
+
+ for _, header := range proxyHeaderWhitelist {
+ for _, v := range inMD[header] {
+ md = metadata.Join(md, metadata.Pairs(header, v))
+ }
+ }
+ }
+
+ newCtx := metadata.NewOutgoingContext(ctx, md)
+ return newCtx, nil
+}
+
+// Proxy calls recvSend until it receives an error. The error is returned
+// to the caller unless it is io.EOF.
+func Proxy(recvSend func() error) (err error) {
+ for err == nil {
+ err = recvSend()
+ }
+
+ if err == io.EOF {
+ err = nil
+ }
+ return err
+}
+
+// CloseSender captures the CloseSend method from gRPC streams.
+type CloseSender interface {
+ CloseSend() error
+}
+
+// ProxyBidi works like Proxy but runs multiple callbacks simultaneously.
+// It returns immediately if proxying one of the callbacks fails. If the
+// response stream is done, ProxyBidi returns immediately without waiting
+// for the client stream to finish proxying.
+func ProxyBidi(requestFunc func() error, requestStream CloseSender, responseFunc func() error) error {
+ requestErr := make(chan error, 1)
+ go func() {
+ requestErr <- Proxy(requestFunc)
+ }()
+
+ responseErr := make(chan error, 1)
+ go func() {
+ responseErr <- Proxy(responseFunc)
+ }()
+
+ for {
+ select {
+ case err := <-requestErr:
+ if err != nil {
+ return err
+ }
+ requestStream.CloseSend()
+ case err := <-responseErr:
+ return err
+ }
+ }
+}
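
As a hedged sketch of how Proxy is meant to be used (the helper below is
hypothetical, not part of this commit): pump messages from a gitaly-ruby client
stream into the Gitaly server stream until the Ruby side returns io.EOF, which
Proxy swallows.

    package rubyserver

    import "google.golang.org/grpc"

    // proxyResponsesSketch forwards every message received from the gitaly-ruby
    // stream to the Gitaly client until the Ruby stream is exhausted.
    func proxyResponsesSketch(rubyStream grpc.ClientStream, serverStream grpc.ServerStream, newMsg func() interface{}) error {
        return Proxy(func() error {
            msg := newMsg()
            if err := rubyStream.RecvMsg(msg); err != nil {
                return err // io.EOF terminates the loop and is reported as nil
            }
            return serverStream.SendMsg(msg)
        })
    }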
diff --git a/internal/gitaly/rubyserver/proxy_test.go b/internal/gitaly/rubyserver/proxy_test.go
new file mode 100644
index 000000000..7233a57d9
--- /dev/null
+++ b/internal/gitaly/rubyserver/proxy_test.go
@@ -0,0 +1,61 @@
+package rubyserver
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitaly/internal/testhelper"
+ "google.golang.org/grpc/metadata"
+)
+
+func TestSetHeadersBlocksUnknownMetadata(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ otherKey := "unknown-key"
+ otherValue := "test-value"
+ inCtx := metadata.NewIncomingContext(ctx, metadata.Pairs(otherKey, otherValue))
+
+ outCtx, err := SetHeaders(inCtx, testRepo)
+ require.NoError(t, err)
+
+ outMd, ok := metadata.FromOutgoingContext(outCtx)
+ require.True(t, ok, "outgoing context should have metadata")
+
+ _, ok = outMd[otherKey]
+ require.False(t, ok, "outgoing MD should not contain non-whitelisted key")
+}
+
+func TestSetHeadersPreservesWhitelistedMetadata(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ key := "gitaly-servers"
+ value := "test-value"
+ inCtx := metadata.NewIncomingContext(ctx, metadata.Pairs(key, value))
+
+ outCtx, err := SetHeaders(inCtx, testRepo)
+ require.NoError(t, err)
+
+ outMd, ok := metadata.FromOutgoingContext(outCtx)
+ require.True(t, ok, "outgoing context should have metadata")
+
+ require.Equal(t, []string{value}, outMd[key], "outgoing MD should contain whitelisted key")
+}
+
+func TestRubyFeatureHeaders(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ key := "gitaly-feature-ruby-test-feature"
+ value := "true"
+ inCtx := metadata.NewIncomingContext(ctx, metadata.Pairs(key, value))
+
+ outCtx, err := SetHeaders(inCtx, testRepo)
+ require.NoError(t, err)
+
+ outMd, ok := metadata.FromOutgoingContext(outCtx)
+ require.True(t, ok, "outgoing context should have metadata")
+
+ require.Equal(t, []string{value}, outMd[key], "outgoing MD should contain whitelisted feature key")
+}
diff --git a/internal/gitaly/rubyserver/rubyserver.go b/internal/gitaly/rubyserver/rubyserver.go
new file mode 100644
index 000000000..27a02edff
--- /dev/null
+++ b/internal/gitaly/rubyserver/rubyserver.go
@@ -0,0 +1,277 @@
+package rubyserver
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "os"
+ "path"
+ "path/filepath"
+ "strconv"
+ "sync"
+ "time"
+
+ grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
+ grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
+ "gitlab.com/gitlab-org/gitaly/internal/command"
+ "gitlab.com/gitlab-org/gitaly/internal/git/hooks"
+ "gitlab.com/gitlab-org/gitaly/internal/gitaly/config"
+ "gitlab.com/gitlab-org/gitaly/internal/gitaly/rubyserver/balancer"
+ "gitlab.com/gitlab-org/gitaly/internal/gitlabshell"
+ "gitlab.com/gitlab-org/gitaly/internal/helper"
+ "gitlab.com/gitlab-org/gitaly/internal/supervisor"
+ "gitlab.com/gitlab-org/gitaly/internal/version"
+ "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+ "gitlab.com/gitlab-org/gitaly/streamio"
+ grpccorrelation "gitlab.com/gitlab-org/labkit/correlation/grpc"
+ grpctracing "gitlab.com/gitlab-org/labkit/tracing/grpc"
+ "google.golang.org/grpc"
+)
+
+var (
+ // ConnectTimeout is the timeout for establishing a connection to the gitaly-ruby process.
+ ConnectTimeout = 40 * time.Second
+)
+
+func init() {
+ timeout64, err := strconv.ParseInt(os.Getenv("GITALY_RUBY_CONNECT_TIMEOUT"), 10, 32)
+ if err == nil && timeout64 > 0 {
+ ConnectTimeout = time.Duration(timeout64) * time.Second
+ }
+}
+
+func socketPath(id int) string {
+ socketDir := config.InternalSocketDir()
+ if socketDir == "" {
+ panic("internal socket directory is missing")
+ }
+
+ return filepath.Join(socketDir, fmt.Sprintf("ruby.%d", id))
+}
+
+// Server represents a gitaly-ruby helper process.
+type Server struct {
+ startOnce sync.Once
+ startErr error
+ workers []*worker
+ clientConnMu sync.Mutex
+ clientConn *grpc.ClientConn
+}
+
+// Stop shuts down the gitaly-ruby helper process and cleans up resources.
+func (s *Server) Stop() {
+ if s != nil {
+ s.clientConnMu.Lock()
+ defer s.clientConnMu.Unlock()
+ if s.clientConn != nil {
+ s.clientConn.Close()
+ }
+
+ for _, w := range s.workers {
+ w.stopMonitor()
+ w.Process.Stop()
+ }
+ }
+}
+
+// Start spawns the Ruby server.
+func (s *Server) Start() error {
+ s.startOnce.Do(func() { s.startErr = s.start() })
+ return s.startErr
+}
+
+func (s *Server) start() error {
+ wd, err := os.Getwd()
+ if err != nil {
+ return err
+ }
+
+ cfg := config.Config
+
+ gitlabshellEnv, err := gitlabshell.Env()
+ if err != nil {
+ return err
+ }
+
+ env := append(
+ os.Environ(),
+ "GITALY_RUBY_GIT_BIN_PATH="+command.GitPath(),
+ fmt.Sprintf("GITALY_RUBY_WRITE_BUFFER_SIZE=%d", streamio.WriteBufferSize),
+ fmt.Sprintf("GITALY_RUBY_MAX_COMMIT_OR_TAG_MESSAGE_SIZE=%d", helper.MaxCommitOrTagMessageSize),
+ "GITALY_RUBY_GITALY_BIN_DIR="+cfg.BinDir,
+ "GITALY_RUBY_DIR="+cfg.Ruby.Dir,
+ "GITALY_VERSION="+version.GetVersion(),
+ "GITALY_GIT_HOOKS_DIR="+hooks.Path(),
+ "GITALY_SOCKET="+config.GitalyInternalSocketPath(),
+ "GITALY_TOKEN="+cfg.Auth.Token,
+ "GITALY_RUGGED_GIT_CONFIG_SEARCH_PATH="+cfg.Ruby.RuggedGitConfigSearchPath)
+ env = append(env, gitlabshellEnv...)
+
+ env = append(env, command.GitEnv...)
+
+ if dsn := cfg.Logging.RubySentryDSN; dsn != "" {
+ env = append(env, "SENTRY_DSN="+dsn)
+ }
+
+ if sentryEnvironment := cfg.Logging.Sentry.Environment; sentryEnvironment != "" {
+ env = append(env, "SENTRY_ENVIRONMENT="+sentryEnvironment)
+ }
+
+ gitalyRuby := path.Join(cfg.Ruby.Dir, "bin", "gitaly-ruby")
+
+ numWorkers := cfg.Ruby.NumWorkers
+ balancer.ConfigureBuilder(numWorkers, 0)
+
+ for i := 0; i < numWorkers; i++ {
+ name := fmt.Sprintf("gitaly-ruby.%d", i)
+ socketPath := socketPath(i)
+
+ // Use 'ruby-cd' to make sure gitaly-ruby has the same working directory
+ // as the current process. This is a hack to sort-of support relative
+ // Unix socket paths.
+ args := []string{"bundle", "exec", "bin/ruby-cd", wd, gitalyRuby, strconv.Itoa(os.Getpid()), socketPath}
+
+ events := make(chan supervisor.Event)
+ check := func() error { return ping(socketPath) }
+ p, err := supervisor.New(name, env, args, cfg.Ruby.Dir, cfg.Ruby.MaxRSS, events, check)
+ if err != nil {
+ return err
+ }
+
+ s.workers = append(s.workers, newWorker(p, socketPath, events, false))
+ }
+
+ return nil
+}
+
+// CommitServiceClient returns a CommitServiceClient instance that is
+// configured to connect to the running Ruby server. This assumes Start()
+// has been called already.
+func (s *Server) CommitServiceClient(ctx context.Context) (gitalypb.CommitServiceClient, error) {
+ conn, err := s.getConnection(ctx)
+ return gitalypb.NewCommitServiceClient(conn), err
+}
+
+// DiffServiceClient returns a DiffServiceClient instance that is
+// configured to connect to the running Ruby server. This assumes Start()
+// has been called already.
+func (s *Server) DiffServiceClient(ctx context.Context) (gitalypb.DiffServiceClient, error) {
+ conn, err := s.getConnection(ctx)
+ return gitalypb.NewDiffServiceClient(conn), err
+}
+
+// RefServiceClient returns a RefServiceClient instance that is
+// configured to connect to the running Ruby server. This assumes Start()
+// has been called already.
+func (s *Server) RefServiceClient(ctx context.Context) (gitalypb.RefServiceClient, error) {
+ conn, err := s.getConnection(ctx)
+ return gitalypb.NewRefServiceClient(conn), err
+}
+
+// OperationServiceClient returns an OperationServiceClient instance that is
+// configured to connect to the running Ruby server. This assumes Start()
+// has been called already.
+func (s *Server) OperationServiceClient(ctx context.Context) (gitalypb.OperationServiceClient, error) {
+ conn, err := s.getConnection(ctx)
+ return gitalypb.NewOperationServiceClient(conn), err
+}
+
+// RepositoryServiceClient returns a RepositoryServiceClient instance that is
+// configured to connect to the running Ruby server. This assumes Start()
+// has been called already.
+func (s *Server) RepositoryServiceClient(ctx context.Context) (gitalypb.RepositoryServiceClient, error) {
+ conn, err := s.getConnection(ctx)
+ return gitalypb.NewRepositoryServiceClient(conn), err
+}
+
+// WikiServiceClient returns a WikiServiceClient instance that is
+// configured to connect to the running Ruby server. This assumes Start()
+// has been called already.
+func (s *Server) WikiServiceClient(ctx context.Context) (gitalypb.WikiServiceClient, error) {
+ conn, err := s.getConnection(ctx)
+ return gitalypb.NewWikiServiceClient(conn), err
+}
+
+// ConflictsServiceClient returns a ConflictsServiceClient instance that is
+// configured to connect to the running Ruby server. This assumes Start()
+// has been called already.
+func (s *Server) ConflictsServiceClient(ctx context.Context) (gitalypb.ConflictsServiceClient, error) {
+ conn, err := s.getConnection(ctx)
+ return gitalypb.NewConflictsServiceClient(conn), err
+}
+
+// RemoteServiceClient returns a RemoteServiceClient instance that is
+// configured to connect to the running Ruby server. This assumes Start()
+// has been called already.
+func (s *Server) RemoteServiceClient(ctx context.Context) (gitalypb.RemoteServiceClient, error) {
+ conn, err := s.getConnection(ctx)
+ return gitalypb.NewRemoteServiceClient(conn), err
+}
+
+// BlobServiceClient returns a BlobServiceClient instance that is
+// configured to connect to the running Ruby server. This assumes Start()
+// has been called already.
+func (s *Server) BlobServiceClient(ctx context.Context) (gitalypb.BlobServiceClient, error) {
+ conn, err := s.getConnection(ctx)
+ return gitalypb.NewBlobServiceClient(conn), err
+}
+
+func (s *Server) getConnection(ctx context.Context) (*grpc.ClientConn, error) {
+ s.clientConnMu.Lock()
+ conn := s.clientConn
+ s.clientConnMu.Unlock()
+
+ if conn != nil {
+ return conn, nil
+ }
+
+ return s.createConnection(ctx)
+}
+
+func (s *Server) createConnection(ctx context.Context) (*grpc.ClientConn, error) {
+ s.clientConnMu.Lock()
+ defer s.clientConnMu.Unlock()
+
+ if conn := s.clientConn; conn != nil {
+ return conn, nil
+ }
+
+ dialCtx, cancel := context.WithTimeout(ctx, ConnectTimeout)
+ defer cancel()
+
+ conn, err := grpc.DialContext(dialCtx, balancer.Scheme+":///gitaly-ruby", dialOptions()...)
+ if err != nil {
+ return nil, fmt.Errorf("failed to connect to gitaly-ruby worker: %v", err)
+ }
+
+ s.clientConn = conn
+ return s.clientConn, nil
+}
+
+func dialOptions() []grpc.DialOption {
+ return []grpc.DialOption{
+ grpc.WithBlock(), // With this we get retries. Without, connections fail fast.
+ grpc.WithInsecure(),
+ // Use a custom dialer to ensure that we don't experience
+ // issues in environments that have proxy configurations
+ // https://gitlab.com/gitlab-org/gitaly/merge_requests/1072#note_140408512
+ grpc.WithContextDialer(func(ctx context.Context, addr string) (conn net.Conn, err error) {
+ d := net.Dialer{}
+ return d.DialContext(ctx, "unix", addr)
+ }),
+ grpc.WithUnaryInterceptor(
+ grpc_middleware.ChainUnaryClient(
+ grpc_prometheus.UnaryClientInterceptor,
+ grpctracing.UnaryClientTracingInterceptor(),
+ grpccorrelation.UnaryClientCorrelationInterceptor(),
+ ),
+ ),
+ grpc.WithStreamInterceptor(
+ grpc_middleware.ChainStreamClient(
+ grpc_prometheus.StreamClientInterceptor,
+ grpctracing.StreamClientTracingInterceptor(),
+ grpccorrelation.StreamClientCorrelationInterceptor(),
+ ),
+ ),
+ }
+}
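
For context, a hypothetical caller of the exported API above (it assumes
config.Config has been loaded so gitaly-ruby can actually start; the repository
values are made up):

    package main

    import (
        "context"
        "log"

        "gitlab.com/gitlab-org/gitaly/internal/gitaly/rubyserver"
        "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
    )

    func main() {
        s := &rubyserver.Server{}
        if err := s.Start(); err != nil {
            log.Fatal(err)
        }
        defer s.Stop()

        // Hypothetical repository; SetHeaders adds the storage and repository
        // path headers that gitaly-ruby needs to locate it on disk.
        repo := &gitalypb.Repository{StorageName: "default", RelativePath: "foo/bar.git"}

        ctx, err := rubyserver.SetHeaders(context.Background(), repo)
        if err != nil {
            log.Fatal(err)
        }

        client, err := s.CommitServiceClient(ctx)
        if err != nil {
            log.Fatal(err)
        }
        _ = client // issue CommitService RPCs against gitaly-ruby from here
    }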
diff --git a/internal/gitaly/rubyserver/rubyserver_test.go b/internal/gitaly/rubyserver/rubyserver_test.go
new file mode 100644
index 000000000..be3cbd477
--- /dev/null
+++ b/internal/gitaly/rubyserver/rubyserver_test.go
@@ -0,0 +1,85 @@
+package rubyserver
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "gitlab.com/gitlab-org/gitaly/internal/testhelper"
+ "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+ "google.golang.org/grpc/codes"
+)
+
+func TestStopSafe(t *testing.T) {
+ badServers := []*Server{
+ nil,
+ &Server{},
+ }
+
+ for _, bs := range badServers {
+ bs.Stop()
+ }
+}
+
+func TestSetHeaders(t *testing.T) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ testCases := []struct {
+ desc string
+ repo *gitalypb.Repository
+ errType codes.Code
+ setter func(context.Context, *gitalypb.Repository) (context.Context, error)
+ }{
+ {
+ desc: "SetHeaders invalid storage",
+ repo: &gitalypb.Repository{StorageName: "foo", RelativePath: "bar.git"},
+ errType: codes.InvalidArgument,
+ setter: SetHeaders,
+ },
+ {
+ desc: "SetHeaders invalid rel path",
+ repo: &gitalypb.Repository{StorageName: testRepo.StorageName, RelativePath: "bar.git"},
+ errType: codes.NotFound,
+ setter: SetHeaders,
+ },
+ {
+ desc: "SetHeaders OK",
+ repo: testRepo,
+ errType: codes.OK,
+ setter: SetHeaders,
+ },
+ {
+ desc: "SetHeadersWithoutRepoCheck invalid storage",
+ repo: &gitalypb.Repository{StorageName: "foo", RelativePath: "bar.git"},
+ errType: codes.InvalidArgument,
+ setter: SetHeadersWithoutRepoCheck,
+ },
+ {
+ desc: "SetHeadersWithoutRepoCheck invalid relative path",
+ repo: &gitalypb.Repository{StorageName: testRepo.StorageName, RelativePath: "bar.git"},
+ errType: codes.OK,
+ setter: SetHeadersWithoutRepoCheck,
+ },
+ {
+ desc: "SetHeadersWithoutRepoCheck OK",
+ repo: testRepo,
+ errType: codes.OK,
+ setter: SetHeadersWithoutRepoCheck,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ clientCtx, err := tc.setter(ctx, tc.repo)
+
+ if tc.errType != codes.OK {
+ testhelper.RequireGrpcError(t, err, tc.errType)
+ assert.Nil(t, clientCtx)
+ } else {
+ assert.NoError(t, err)
+ assert.NotNil(t, clientCtx)
+ }
+ })
+ }
+}
diff --git a/internal/gitaly/rubyserver/stopwatch.go b/internal/gitaly/rubyserver/stopwatch.go
new file mode 100644
index 000000000..b082cb155
--- /dev/null
+++ b/internal/gitaly/rubyserver/stopwatch.go
@@ -0,0 +1,35 @@
+package rubyserver
+
+import (
+ "time"
+)
+
+type stopwatch struct {
+ t1 time.Time
+ t2 time.Time
+ running bool
+}
+
+// mark records the current time and starts the stopwatch if it is not already running
+func (st *stopwatch) mark() {
+ st.t2 = time.Now()
+
+ if !st.running {
+ st.t1 = st.t2
+ st.running = true
+ }
+}
+
+// reset stops the stopwatch and returns it to zero
+func (st *stopwatch) reset() {
+ st.running = false
+}
+
+// elapsed returns the time elapsed between the first and last 'mark'
+func (st *stopwatch) elapsed() time.Duration {
+ if !st.running {
+ return time.Duration(0)
+ }
+
+ return st.t2.Sub(st.t1)
+}
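
A small hypothetical illustration of the stopwatch semantics above (not part of
this commit):

    package rubyserver

    import "time"

    // stopwatchSketch shows that elapsed() spans the first and the most recent
    // mark() while the watch is running, and that reset() zeroes it again.
    func stopwatchSketch() time.Duration {
        sw := &stopwatch{}
        sw.mark()                         // first mark starts the watch
        time.Sleep(10 * time.Millisecond) // simulated work
        sw.mark()                         // extends the measured span
        d := sw.elapsed()                 // roughly 10ms
        sw.reset()                        // elapsed() would now report 0
        return d
    }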
diff --git a/internal/gitaly/rubyserver/testhelper_test.go b/internal/gitaly/rubyserver/testhelper_test.go
new file mode 100644
index 000000000..66575e48e
--- /dev/null
+++ b/internal/gitaly/rubyserver/testhelper_test.go
@@ -0,0 +1,26 @@
+package rubyserver
+
+import (
+ "os"
+ "testing"
+
+ "gitlab.com/gitlab-org/gitaly/internal/testhelper"
+ "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+)
+
+var (
+ testRepo *gitalypb.Repository
+)
+
+func TestMain(m *testing.M) {
+ testhelper.Configure()
+ os.Exit(testMain(m))
+}
+
+func testMain(m *testing.M) int {
+ defer testhelper.MustHaveNoChildProcess()
+
+ testRepo = testhelper.TestRepository()
+
+ return m.Run()
+}
diff --git a/internal/gitaly/rubyserver/worker.go b/internal/gitaly/rubyserver/worker.go
new file mode 100644
index 000000000..9fc13f2cd
--- /dev/null
+++ b/internal/gitaly/rubyserver/worker.go
@@ -0,0 +1,226 @@
+package rubyserver
+
+import (
+ "fmt"
+ "syscall"
+ "time"
+
+ "github.com/prometheus/client_golang/prometheus"
+ log "github.com/sirupsen/logrus"
+ "gitlab.com/gitlab-org/gitaly/internal/gitaly/config"
+ "gitlab.com/gitlab-org/gitaly/internal/gitaly/rubyserver/balancer"
+ "gitlab.com/gitlab-org/gitaly/internal/supervisor"
+)
+
+var (
+ terminationCounter = prometheus.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gitaly_ruby_memory_terminations_total",
+ Help: "Number of times gitaly-ruby has been terminated because of excessive memory use.",
+ },
+ []string{"name"},
+ )
+)
+
+func init() {
+ prometheus.MustRegister(terminationCounter)
+}
+
+// worker observes the event stream of a supervised process and restarts
+// it if necessary, in cooperation with the balancer.
+type worker struct {
+ *supervisor.Process
+ address string
+ events <-chan supervisor.Event
+ shutdown chan struct{}
+ monitorDone chan struct{}
+
+ // This is for testing only, so that we can inject a fake balancer
+ balancerUpdate chan balancerProxy
+
+ testing bool
+}
+
+func newWorker(p *supervisor.Process, address string, events <-chan supervisor.Event, testing bool) *worker {
+ w := &worker{
+ Process: p,
+ address: address,
+ events: events,
+ shutdown: make(chan struct{}),
+ monitorDone: make(chan struct{}),
+ balancerUpdate: make(chan balancerProxy),
+ testing: testing,
+ }
+ go w.monitor()
+
+ bal := defaultBalancer{}
+ w.balancerUpdate <- bal
+
+ // When we return from this function, requests may start coming in. If
+ // there are no addresses in the balancer when the first request comes in
+ // we can get a panic from grpc-go. So before returning, we ensure the
+ // current address has been added to the balancer.
+ bal.AddAddress(w.address)
+
+ return w
+}
+
+type balancerProxy interface {
+ AddAddress(string)
+ RemoveAddress(string) bool
+}
+
+type defaultBalancer struct{}
+
+func (defaultBalancer) AddAddress(s string) { balancer.AddAddress(s) }
+func (defaultBalancer) RemoveAddress(s string) bool { return balancer.RemoveAddress(s) }
+
+var (
+ // Ignore health checks for the current process after it just restarted
+ healthRestartCoolOff = 5 * time.Minute
+ // Health considered bad after sustained failed health checks
+ healthRestartDelay = 1 * time.Minute
+)
+
+func (w *worker) monitor() {
+ swMem := &stopwatch{}
+ swHealth := &stopwatch{}
+ lastRestart := time.Now()
+ currentPid := 0
+ bal := <-w.balancerUpdate
+
+ for {
+ nextEvent:
+ select {
+ case e := <-w.events:
+ switch e.Type {
+ case supervisor.Up:
+ if badPid(e.Pid) {
+ w.logBadEvent(e)
+ break nextEvent
+ }
+
+ if e.Pid == currentPid {
+ // Ignore repeated events to avoid constantly resetting our internal
+ // state.
+ break nextEvent
+ }
+
+ bal.AddAddress(w.address)
+ currentPid = e.Pid
+
+ swMem.reset()
+ swHealth.reset()
+ lastRestart = time.Now()
+ case supervisor.MemoryHigh:
+ if badPid(e.Pid) {
+ w.logBadEvent(e)
+ break nextEvent
+ }
+
+ if e.Pid != currentPid {
+ break nextEvent
+ }
+
+ swMem.mark()
+ if swMem.elapsed() <= config.Config.Ruby.RestartDelay.Duration() {
+ break nextEvent
+ }
+
+ // It is crucial to check the return value of RemoveAddress. If we don't
+ // we may leave the system without the capacity to make gitaly-ruby
+ // requests.
+ if bal.RemoveAddress(w.address) {
+ w.logPid(currentPid).Info("removed gitaly-ruby worker from balancer due to high memory")
+ go w.waitTerminate(currentPid)
+ swMem.reset()
+ }
+ case supervisor.MemoryLow:
+ if badPid(e.Pid) {
+ w.logBadEvent(e)
+ break nextEvent
+ }
+
+ if e.Pid != currentPid {
+ break nextEvent
+ }
+
+ swMem.reset()
+ case supervisor.HealthOK:
+ swHealth.reset()
+ case supervisor.HealthBad:
+ if time.Since(lastRestart) <= healthRestartCoolOff {
+ // Ignore health checks for a while after the supervised process restarted
+ break nextEvent
+ }
+
+ w.log().WithError(e.Error).Warn("gitaly-ruby worker health check failed")
+
+ swHealth.mark()
+ if swHealth.elapsed() <= healthRestartDelay {
+ break nextEvent
+ }
+
+ if bal.RemoveAddress(w.address) {
+ w.logPid(currentPid).Info("removed gitaly-ruby worker from balancer due to sustained failing health checks")
+ go w.waitTerminate(currentPid)
+ swHealth.reset()
+ }
+ default:
+ panic(fmt.Sprintf("unknown state %v", e.Type))
+ }
+ case bal = <-w.balancerUpdate:
+ // For testing only.
+ case <-w.shutdown:
+ close(w.monitorDone)
+ return
+ }
+ }
+}
+
+func (w *worker) stopMonitor() {
+ close(w.shutdown)
+ <-w.monitorDone
+}
+
+func badPid(pid int) bool {
+ return pid <= 0
+}
+
+func (w *worker) log() *log.Entry {
+ return log.WithFields(log.Fields{
+ "worker.name": w.Name,
+ })
+}
+
+func (w *worker) logPid(pid int) *log.Entry {
+ return w.log().WithFields(log.Fields{
+ "worker.pid": pid,
+ })
+}
+
+func (w *worker) logBadEvent(e supervisor.Event) {
+ w.log().WithFields(log.Fields{
+ "worker.event": e,
+ }).Error("monitor state machine received bad event")
+}
+
+func (w *worker) waitTerminate(pid int) {
+ if w.testing {
+ return
+ }
+
+ // Wait for in-flight requests to reach the worker before we slam the
+ // door in their face.
+ time.Sleep(1 * time.Minute)
+
+ terminationCounter.WithLabelValues(w.Name).Inc()
+
+ w.logPid(pid).Info("sending SIGTERM")
+ syscall.Kill(pid, syscall.SIGTERM)
+
+ time.Sleep(config.Config.Ruby.GracefulRestartTimeout.Duration())
+
+ w.logPid(pid).Info("sending SIGKILL")
+ syscall.Kill(pid, syscall.SIGKILL)
+}
diff --git a/internal/gitaly/rubyserver/worker_test.go b/internal/gitaly/rubyserver/worker_test.go
new file mode 100644
index 000000000..c043b7709
--- /dev/null
+++ b/internal/gitaly/rubyserver/worker_test.go
@@ -0,0 +1,239 @@
+package rubyserver
+
+import (
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitaly/internal/gitaly/config"
+ "gitlab.com/gitlab-org/gitaly/internal/supervisor"
+)
+
+func TestWorker(t *testing.T) {
+ restartDelay := 10 * time.Millisecond
+
+ defer func(old time.Duration) {
+ config.Config.Ruby.RestartDelay = config.Duration(old)
+ }(config.Config.Ruby.RestartDelay.Duration())
+ config.Config.Ruby.RestartDelay = config.Duration(restartDelay)
+
+ events := make(chan supervisor.Event)
+ addr := "the address"
+ w := newWorker(&supervisor.Process{Name: "testing"}, addr, events, true)
+ defer w.stopMonitor()
+
+ t.Log("ignore health failures during startup")
+ mustIgnore(t, w, func() { events <- healthBadEvent() })
+
+ firstPid := 123
+
+ t.Log("register first PID as 'up'")
+ mustAdd(t, w, addr, func() { events <- upEvent(firstPid) })
+
+ t.Log("ignore repeated up event")
+ mustIgnore(t, w, func() { events <- upEvent(firstPid) })
+
+ t.Log("send mem high events but too fast to trigger restart")
+ for i := 0; i < 5; i++ {
+ mustIgnore(t, w, func() { events <- memHighEvent(firstPid) })
+ }
+
+ t.Log("mem low resets mem high counter")
+ mustIgnore(t, w, func() { events <- memLowEvent(firstPid) })
+
+ t.Log("send mem high events but too fast to trigger restart")
+ for i := 0; i < 5; i++ {
+ mustIgnore(t, w, func() { events <- memHighEvent(firstPid) })
+ }
+
+ time.Sleep(2 * restartDelay)
+ t.Log("this mem high should push us over the threshold")
+ mustRemove(t, w, addr, func() { events <- memHighEvent(firstPid) })
+
+ t.Log("ignore health failures during startup")
+ mustIgnore(t, w, func() { events <- healthBadEvent() })
+
+ secondPid := 456
+ t.Log("registering a new PID")
+ mustAdd(t, w, addr, func() { events <- upEvent(secondPid) })
+
+ t.Log("ignore mem high events for the previous pid")
+ mustIgnore(t, w, func() { events <- memHighEvent(firstPid) })
+ time.Sleep(2 * restartDelay)
+ t.Log("ignore mem high also after restart delay has expired")
+ mustIgnore(t, w, func() { events <- memHighEvent(firstPid) })
+
+ t.Log("start high memory timer")
+ mustIgnore(t, w, func() { events <- memHighEvent(secondPid) })
+
+ t.Log("ignore mem low event for wrong pid")
+ mustIgnore(t, w, func() { events <- memLowEvent(firstPid) })
+
+ t.Log("send mem high count over the threshold")
+ time.Sleep(2 * restartDelay)
+ mustRemove(t, w, addr, func() { events <- memHighEvent(secondPid) })
+}
+
+func TestWorkerHealthChecks(t *testing.T) {
+ restartDelay := 10 * time.Millisecond
+
+ defer func(old time.Duration) {
+ healthRestartDelay = old
+ }(healthRestartDelay)
+ healthRestartDelay = restartDelay
+
+ defer func(old time.Duration) {
+ healthRestartCoolOff = old
+ }(healthRestartCoolOff)
+ healthRestartCoolOff = restartDelay
+
+ events := make(chan supervisor.Event)
+ addr := "the address"
+ w := newWorker(&supervisor.Process{Name: "testing"}, addr, events, true)
+ defer w.stopMonitor()
+
+ t.Log("ignore health failures during startup")
+ mustIgnore(t, w, func() { events <- healthBadEvent() })
+
+ firstPid := 123
+
+ t.Log("register first PID as 'up'")
+ mustAdd(t, w, addr, func() { events <- upEvent(firstPid) })
+
+ t.Log("still ignore health failures during startup")
+ mustIgnore(t, w, func() { events <- healthBadEvent() })
+
+ time.Sleep(2 * restartDelay)
+
+ t.Log("waited long enough, this health check should start health timer")
+ mustIgnore(t, w, func() { events <- healthBadEvent() })
+
+ time.Sleep(2 * restartDelay)
+
+ t.Log("this second failed health check should trigger failover")
+ mustRemove(t, w, addr, func() { events <- healthBadEvent() })
+
+ t.Log("ignore extra health failures")
+ mustIgnore(t, w, func() { events <- healthBadEvent() })
+}
+
+func mustIgnore(t *testing.T, w *worker, f func()) {
+ nothing := &nothingBalancer{t}
+ w.balancerUpdate <- nothing
+ t.Log("executing function that should be ignored by balancer")
+ f()
+ // This second balancer update is used to synchronize with the monitor
+ // goroutine. When the channel send finishes, we know the event we sent
+ // before must have been processed.
+ w.balancerUpdate <- nothing
+}
+
+func mustAdd(t *testing.T, w *worker, addr string, f func()) {
+ add := newAdd(t, addr)
+ w.balancerUpdate <- add
+ t.Log("executing function that should lead to balancer.AddAddress")
+ f()
+ add.wait()
+}
+
+func mustRemove(t *testing.T, w *worker, addr string, f func()) {
+ remove := newRemove(t, addr)
+ w.balancerUpdate <- remove
+ t.Log("executing function that should lead to balancer.RemoveAddress")
+ f()
+ remove.wait()
+}
+
+func waitFail(t *testing.T, done chan struct{}) {
+ select {
+ case <-time.After(1 * time.Second):
+ t.Fatal("timeout waiting for balancer method call")
+ case <-done:
+ }
+}
+
+func upEvent(pid int) supervisor.Event {
+ return supervisor.Event{Type: supervisor.Up, Pid: pid}
+}
+
+func memHighEvent(pid int) supervisor.Event {
+ return supervisor.Event{Type: supervisor.MemoryHigh, Pid: pid}
+}
+
+func memLowEvent(pid int) supervisor.Event {
+ return supervisor.Event{Type: supervisor.MemoryLow, Pid: pid}
+}
+
+func healthBadEvent() supervisor.Event {
+ return supervisor.Event{Type: supervisor.HealthBad, Error: errors.New("test bad health")}
+}
+
+func newAdd(t *testing.T, addr string) *addBalancer {
+ return &addBalancer{
+ t: t,
+ addr: addr,
+ done: make(chan struct{}),
+ }
+}
+
+type addBalancer struct {
+ addr string
+ t *testing.T
+ done chan struct{}
+}
+
+func (ab *addBalancer) RemoveAddress(string) bool {
+ ab.t.Fatal("unexpected RemoveAddress call")
+ return false
+}
+
+func (ab *addBalancer) AddAddress(s string) {
+ require.Equal(ab.t, ab.addr, s, "addBalancer expected AddAddress argument")
+ close(ab.done)
+}
+
+func (ab *addBalancer) wait() {
+ waitFail(ab.t, ab.done)
+}
+
+func newRemove(t *testing.T, addr string) *removeBalancer {
+ return &removeBalancer{
+ t: t,
+ addr: addr,
+ done: make(chan struct{}),
+ }
+}
+
+type removeBalancer struct {
+ addr string
+ t *testing.T
+ done chan struct{}
+}
+
+func (rb *removeBalancer) RemoveAddress(s string) bool {
+ require.Equal(rb.t, rb.addr, s, "removeBalancer expected RemoveAddress argument")
+ close(rb.done)
+ return true
+}
+
+func (rb *removeBalancer) AddAddress(s string) {
+ rb.t.Fatal("unexpected AddAddress call")
+}
+
+func (rb *removeBalancer) wait() {
+ waitFail(rb.t, rb.done)
+}
+
+type nothingBalancer struct {
+ t *testing.T
+}
+
+func (nb *nothingBalancer) RemoveAddress(s string) bool {
+ nb.t.Fatal("unexpected RemoveAddress call")
+ return true
+}
+
+func (nb *nothingBalancer) AddAddress(s string) {
+ nb.t.Fatal("unexpected AddAddress call")
+}