Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorToon Claes <toon@gitlab.com>2022-06-23 15:55:32 +0300
committerToon Claes <toon@gitlab.com>2022-06-23 15:55:32 +0300
commitc4bb98b282fe8e9e5434b5d20debb8fbf278f1c2 (patch)
tree50da8a0d46f85b10589f50a9391ea7f91e918660
parentf8c0438541950ab8a554240b38dc61b006607d74 (diff)
parentd270a521ec5e39f336bd269def7aaf74d9f9ce9c (diff)
Merge branch 'pks-ssh-tests-emulate-client-side' into 'master'
ssh: Extend tests to emulate client-side failure conditions with sidechannel See merge request gitlab-org/gitaly!4615
-rw-r--r--cmd/gitaly-lfs-smudge/main_test.go8
-rw-r--r--internal/cgroups/cgroups.go4
-rw-r--r--internal/cgroups/noop.go4
-rw-r--r--internal/cgroups/v1_linux.go6
-rw-r--r--internal/cgroups/v1_linux_test.go25
-rw-r--r--internal/command/command.go355
-rw-r--r--internal/command/command_test.go336
-rw-r--r--internal/command/option.go94
-rwxr-xr-xinternal/command/testdata/stderr_binary_null.sh4
-rwxr-xr-xinternal/command/testdata/stderr_many_lines.sh9
-rwxr-xr-xinternal/command/testdata/stderr_max_bytes_edge_case.sh20
-rwxr-xr-xinternal/command/testdata/stderr_repeat_a.sh6
-rwxr-xr-xinternal/command/testdata/stderr_script.sh7
-rw-r--r--internal/command/testhelper_test.go11
-rw-r--r--internal/git/catfile/object_info_reader.go2
-rw-r--r--internal/git/catfile/object_reader.go2
-rw-r--r--internal/git/command_factory.go18
-rw-r--r--internal/git/command_factory_cgroup_test.go4
-rw-r--r--internal/git/command_options.go28
-rw-r--r--internal/git/objectpool/fetch.go2
-rw-r--r--internal/git/packfile/index.go2
-rw-r--r--internal/git/pktline/read_monitor.go27
-rw-r--r--internal/git/pktline/read_monitor_test.go7
-rw-r--r--internal/git/updateref/updateref.go2
-rw-r--r--internal/git2go/executor.go10
-rw-r--r--internal/gitaly/hook/custom.go10
-rw-r--r--internal/gitaly/linguist/linguist.go8
-rw-r--r--internal/gitaly/service/repository/archive.go4
-rw-r--r--internal/gitaly/service/repository/backup_custom_hooks.go2
-rw-r--r--internal/gitaly/service/repository/create_repository_from_snapshot.go2
-rw-r--r--internal/gitaly/service/repository/replicate.go4
-rw-r--r--internal/gitaly/service/repository/restore_custom_hooks.go4
-rw-r--r--internal/gitaly/service/repository/size.go2
-rw-r--r--internal/gitaly/service/ssh/monitor_stdin_command.go8
-rw-r--r--internal/gitaly/service/ssh/upload_pack.go14
-rw-r--r--internal/gitaly/service/ssh/upload_pack_test.go233
36 files changed, 828 insertions, 456 deletions
diff --git a/cmd/gitaly-lfs-smudge/main_test.go b/cmd/gitaly-lfs-smudge/main_test.go
index 37814a34b..6154bcd25 100644
--- a/cmd/gitaly-lfs-smudge/main_test.go
+++ b/cmd/gitaly-lfs-smudge/main_test.go
@@ -184,10 +184,10 @@ func TestGitalyLFSSmudge(t *testing.T) {
var stdout, stderr bytes.Buffer
cmd, err := command.New(ctx,
exec.Command(binary),
- tc.stdin,
- &stdout,
- &stderr,
- env...,
+ command.WithStdin(tc.stdin),
+ command.WithStdout(&stdout),
+ command.WithStderr(&stderr),
+ command.WithEnvironment(env),
)
require.NoError(t, err)
diff --git a/internal/cgroups/cgroups.go b/internal/cgroups/cgroups.go
index c3e148bba..b2209b509 100644
--- a/internal/cgroups/cgroups.go
+++ b/internal/cgroups/cgroups.go
@@ -13,8 +13,8 @@ type Manager interface {
// It is expected to be called once at Gitaly startup from any
// instance of the Manager.
Setup() error
- // AddCommand adds a Command to a cgroup
- AddCommand(*command.Command, repository.GitRepo) error
+ // AddCommand adds a Command to a cgroup.
+ AddCommand(*command.Command, repository.GitRepo) (string, error)
// Cleanup cleans up cgroups created in Setup.
// It is expected to be called once at Gitaly shutdown from any
// instance of the Manager.
diff --git a/internal/cgroups/noop.go b/internal/cgroups/noop.go
index 399bd4931..feaa0d6ef 100644
--- a/internal/cgroups/noop.go
+++ b/internal/cgroups/noop.go
@@ -15,8 +15,8 @@ func (cg *NoopManager) Setup() error {
}
//nolint: revive,stylecheck // This is unintentionally missing documentation.
-func (cg *NoopManager) AddCommand(cmd *command.Command, repo repository.GitRepo) error {
- return nil
+func (cg *NoopManager) AddCommand(cmd *command.Command, repo repository.GitRepo) (string, error) {
+ return "", nil
}
//nolint: revive,stylecheck // This is unintentionally missing documentation.
diff --git a/internal/cgroups/v1_linux.go b/internal/cgroups/v1_linux.go
index e4b72f28b..3fa5da4c3 100644
--- a/internal/cgroups/v1_linux.go
+++ b/internal/cgroups/v1_linux.go
@@ -102,7 +102,7 @@ func (cg *CGroupV1Manager) Setup() error {
func (cg *CGroupV1Manager) AddCommand(
cmd *command.Command,
repo repository.GitRepo,
-) error {
+) (string, error) {
var key string
if repo == nil {
key = strings.Join(cmd.Args(), "/")
@@ -117,9 +117,7 @@ func (cg *CGroupV1Manager) AddCommand(
groupID := uint(checksum) % cg.cfg.Repositories.Count
cgroupPath := cg.repoPath(int(groupID))
- cmd.SetCgroupPath(cgroupPath)
-
- return cg.addToCgroup(cmd.Pid(), cgroupPath)
+ return cgroupPath, cg.addToCgroup(cmd.Pid(), cgroupPath)
}
func (cg *CGroupV1Manager) addToCgroup(pid int, cgroupPath string) error {
diff --git a/internal/cgroups/v1_linux_test.go b/internal/cgroups/v1_linux_test.go
index 844b06225..aaa095db0 100644
--- a/internal/cgroups/v1_linux_test.go
+++ b/internal/cgroups/v1_linux_test.go
@@ -81,7 +81,7 @@ func TestAddCommand(t *testing.T) {
ctx := testhelper.Context(t)
cmd1 := exec.Command("ls", "-hal", ".")
- cmd2, err := command.New(ctx, cmd1, nil, nil, nil)
+ cmd2, err := command.New(ctx, cmd1)
require.NoError(t, err)
require.NoError(t, cmd2.Wait())
@@ -91,7 +91,8 @@ func TestAddCommand(t *testing.T) {
}
t.Run("without a repository", func(t *testing.T) {
- require.NoError(t, v1Manager2.AddCommand(cmd2, nil))
+ _, err := v1Manager2.AddCommand(cmd2, nil)
+ require.NoError(t, err)
checksum := crc32.ChecksumIEEE([]byte(strings.Join(cmd2.Args(), "/")))
groupID := uint(checksum) % config.Repositories.Count
@@ -109,7 +110,8 @@ func TestAddCommand(t *testing.T) {
})
t.Run("with a repository", func(t *testing.T) {
- require.NoError(t, v1Manager2.AddCommand(cmd2, repo))
+ _, err := v1Manager2.AddCommand(cmd2, repo)
+ require.NoError(t, err)
checksum := crc32.ChecksumIEEE([]byte(strings.Join([]string{
"default",
@@ -150,6 +152,8 @@ func TestCleanup(t *testing.T) {
}
func TestMetrics(t *testing.T) {
+ t.Parallel()
+
mock := newMock(t)
repo := &gitalypb.Repository{
StorageName: "default",
@@ -167,22 +171,23 @@ func TestMetrics(t *testing.T) {
mock.setupMockCgroupFiles(t, v1Manager1, 2)
require.NoError(t, v1Manager1.Setup())
- ctx := testhelper.Context(t)
+ ctx := testhelper.Context(t)
logger, hook := test.NewNullLogger()
logger.SetLevel(logrus.DebugLevel)
ctx = ctxlogrus.ToContext(ctx, logrus.NewEntry(logger))
- cmd, err := command.New(ctx, exec.Command("ls", "-hal", "."), nil, nil, nil)
+ cmd, err := command.New(ctx, exec.Command("ls", "-hal", "."), command.WithCgroup(v1Manager1, repo))
require.NoError(t, err)
- gitCmd1, err := command.New(ctx, exec.Command("ls", "-hal", "."), nil, nil, nil)
+ gitCmd1, err := command.New(ctx, exec.Command("ls", "-hal", "."), command.WithCgroup(v1Manager1, repo))
require.NoError(t, err)
- gitCmd2, err := command.New(ctx, exec.Command("ls", "-hal", "."), nil, nil, nil)
+ gitCmd2, err := command.New(ctx, exec.Command("ls", "-hal", "."), command.WithCgroup(v1Manager1, repo))
require.NoError(t, err)
+ defer func() {
+ require.NoError(t, gitCmd2.Wait())
+ }()
- require.NoError(t, v1Manager1.AddCommand(cmd, repo))
- require.NoError(t, v1Manager1.AddCommand(gitCmd1, repo))
- require.NoError(t, v1Manager1.AddCommand(gitCmd2, repo))
+ require.NoError(t, err)
require.NoError(t, cmd.Wait())
require.NoError(t, gitCmd1.Wait())
diff --git a/internal/command/command.go b/internal/command/command.go
index 5602ee24e..abf008f3d 100644
--- a/internal/command/command.go
+++ b/internal/command/command.go
@@ -20,6 +20,7 @@ import (
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/sirupsen/logrus"
"gitlab.com/gitlab-org/gitaly/v15/internal/command/commandcounter"
+ "gitlab.com/gitlab-org/gitaly/v15/internal/metadata/featureflag"
"gitlab.com/gitlab-org/labkit/tracing"
)
@@ -73,40 +74,43 @@ var (
},
[]string{"grpc_service", "grpc_method", "cmd"},
)
-)
-// exportedEnvVars contains a list of environment variables
-// that are always exported to child processes on spawn
-var exportedEnvVars = []string{
- "HOME",
- "PATH",
- "LD_LIBRARY_PATH",
- "TZ",
-
- // Export git tracing variables for easier debugging
- "GIT_TRACE",
- "GIT_TRACE_PACK_ACCESS",
- "GIT_TRACE_PACKET",
- "GIT_TRACE_PERFORMANCE",
- "GIT_TRACE_SETUP",
-
- // GIT_EXEC_PATH tells Git where to find its binaries. This must be exported especially in
- // the case where we use bundled Git executables given that we cannot rely on a complete Git
- // installation in that case.
- "GIT_EXEC_PATH",
-
- // Git HTTP proxy settings: https://git-scm.com/docs/git-config#git-config-httpproxy
- "all_proxy",
- "http_proxy",
- "HTTP_PROXY",
- "https_proxy",
- "HTTPS_PROXY",
- // libcurl settings: https://curl.haxx.se/libcurl/c/CURLOPT_NOPROXY.html
- "no_proxy",
- "NO_PROXY",
-}
+ // exportedEnvVars contains a list of environment variables
+ // that are always exported to child processes on spawn
+ exportedEnvVars = []string{
+ "HOME",
+ "PATH",
+ "LD_LIBRARY_PATH",
+ "TZ",
+
+ // Export git tracing variables for easier debugging
+ "GIT_TRACE",
+ "GIT_TRACE_PACK_ACCESS",
+ "GIT_TRACE_PACKET",
+ "GIT_TRACE_PERFORMANCE",
+ "GIT_TRACE_SETUP",
+
+ // GIT_EXEC_PATH tells Git where to find its binaries. This must be exported
+ // especially in the case where we use bundled Git executables given that we cannot
+ // rely on a complete Git installation in that case.
+ "GIT_EXEC_PATH",
+
+ // Git HTTP proxy settings:
+ // https://git-scm.com/docs/git-config#git-config-httpproxy
+ "all_proxy",
+ "http_proxy",
+ "HTTP_PROXY",
+ "https_proxy",
+ "HTTPS_PROXY",
+ // libcurl settings: https://curl.haxx.se/libcurl/c/CURLOPT_NOPROXY.html
+ "no_proxy",
+ "NO_PROXY",
+ }
-var envInjector = tracing.NewEnvInjector()
+ // envInjector is responsible for injecting environment variables required for tracing into
+ // the child process.
+ envInjector = tracing.NewEnvInjector()
+)
const (
// maxStderrBytes is at most how many bytes will be written to stderr
@@ -130,6 +134,8 @@ type Command struct {
waitError error
waitOnce sync.Once
+ finalizer func(*Command)
+
span opentracing.Span
metricsCmd string
@@ -137,75 +143,9 @@ type Command struct {
cgroupPath string
}
-type stdinSentinel struct{}
-
-func (stdinSentinel) Read([]byte) (int, error) {
- return 0, errors.New("stdin sentinel should not be read from")
-}
-
-// SetupStdin instructs New() to configure the stdin pipe of the command it is
-// creating. This allows you call Write() on the command as if it is an ordinary
-// io.Writer, sending data directly to the stdin of the process.
-//
-// You should not call Read() on this value - it is strictly for configuration!
-var SetupStdin io.Reader = stdinSentinel{}
-
-// Read calls Read() on the stdout pipe of the command.
-func (c *Command) Read(p []byte) (int, error) {
- if c.reader == nil {
- panic("command has no reader")
- }
-
- return c.reader.Read(p)
-}
-
-// Write calls Write() on the stdin pipe of the command.
-func (c *Command) Write(p []byte) (int, error) {
- if c.writer == nil {
- panic("command has no writer")
- }
-
- return c.writer.Write(p)
-}
-
-// Wait calls Wait() on the exec.Cmd instance inside the command. This
-// blocks until the command has finished and reports the command exit
-// status via the error return value. Use ExitStatus to get the integer
-// exit status from the error returned by Wait().
-func (c *Command) Wait() error {
- c.waitOnce.Do(c.wait)
-
- return c.waitError
-}
-
-// SetCgroupPath sets the cgroup path for logging
-func (c *Command) SetCgroupPath(path string) {
- c.cgroupPath = path
-}
-
-// SetMetricsCmd overrides the "cmd" label used in metrics
-func (c *Command) SetMetricsCmd(metricsCmd string) {
- c.metricsCmd = metricsCmd
-}
-
-// SetMetricsSubCmd sets the "subcmd" label used in metrics
-func (c *Command) SetMetricsSubCmd(metricsSubCmd string) {
- c.metricsSubCmd = metricsSubCmd
-}
-
-type contextWithoutDonePanic string
-
-var getSpawnTokenAcquiringSeconds = func(t time.Time) float64 {
- return time.Since(t).Seconds()
-}
-
-// New creates a Command from an exec.Cmd. On success, the Command
-// contains a running subprocess. When ctx is canceled the embedded
-// process will be terminated and reaped automatically.
-//
-// If stdin is specified as SetupStdin, you will be able to write to the stdin
-// of the subprocess by calling Write() on the returned Command.
-func New(ctx context.Context, cmd *exec.Cmd, stdin io.Reader, stdout, stderr io.Writer, env ...string) (*Command, error) {
+// New creates a Command from an exec.Cmd. On success, the Command contains a running subprocess.
+// When ctx is canceled the embedded process will be terminated and reaped automatically.
+func New(ctx context.Context, cmd *exec.Cmd, opts ...Option) (*Command, error) {
if ctx.Done() == nil {
panic(contextWithoutDonePanic("command spawned with context without Done() channel"))
}
@@ -214,6 +154,11 @@ func New(ctx context.Context, cmd *exec.Cmd, stdin io.Reader, stdout, stderr io.
return nil, err
}
+ var cfg config
+ for _, opt := range opts {
+ opt(&cfg)
+ }
+
span, ctx := opentracing.StartSpanFromContext(
ctx,
cmd.Path,
@@ -243,16 +188,19 @@ func New(ctx context.Context, cmd *exec.Cmd, stdin io.Reader, stdout, stderr io.
}()
command := &Command{
- cmd: cmd,
- startTime: time.Now(),
- context: ctx,
- span: span,
+ cmd: cmd,
+ startTime: time.Now(),
+ context: ctx,
+ span: span,
+ finalizer: cfg.finalizer,
+ metricsCmd: cfg.commandName,
+ metricsSubCmd: cfg.subcommandName,
}
// Export allowed environment variables as set in the Gitaly process.
cmd.Env = AllowedEnvironment(os.Environ())
// Append environment variables explicitly requested by the caller.
- cmd.Env = append(cmd.Env, env...)
+ cmd.Env = append(cmd.Env, cfg.environment...)
// And finally inject environment variables required for tracing into the command.
cmd.Env = envInjector(ctx, cmd.Env)
@@ -261,83 +209,112 @@ func New(ctx context.Context, cmd *exec.Cmd, stdin io.Reader, stdout, stderr io.
// Three possible values for stdin:
// * nil - Go implicitly uses /dev/null
- // * SetupStdin - configure with cmd.StdinPipe(), allowing Write() to work
+ // * stdinSentinel - configure with cmd.StdinPipe(), allowing Write() to work
// * Another io.Reader - becomes cmd.Stdin. Write() will not work
- if stdin == SetupStdin {
+ if _, ok := cfg.stdin.(stdinSentinel); ok {
pipe, err := cmd.StdinPipe()
if err != nil {
- return nil, fmt.Errorf("GitCommand: stdin: %v", err)
+ return nil, fmt.Errorf("creating stdin pipe: %w", err)
}
+
command.writer = pipe
- } else if stdin != nil {
- cmd.Stdin = stdin
+ } else if cfg.stdin != nil {
+ cmd.Stdin = cfg.stdin
}
- if stdout != nil {
+ if cfg.stdout != nil {
// We don't assign a reader if an stdout override was passed. We assume
// output is going to be directly handled by the caller.
- cmd.Stdout = stdout
+ cmd.Stdout = cfg.stdout
} else {
pipe, err := cmd.StdoutPipe()
if err != nil {
- return nil, fmt.Errorf("GitCommand: stdout: %v", err)
+ return nil, fmt.Errorf("creating stdout pipe: %w", err)
}
+
command.reader = pipe
}
- if stderr != nil {
- cmd.Stderr = stderr
+ if cfg.stderr != nil {
+ cmd.Stderr = cfg.stderr
} else {
command.stderrBuffer, err = newStderrBuffer(maxStderrBytes, maxStderrLineLength, []byte("\n"))
if err != nil {
- return nil, fmt.Errorf("GitCommand: failed to create stderr buffer: %v", err)
+ return nil, fmt.Errorf("creating stderr buffer: %w", err)
}
+
cmd.Stderr = command.stderrBuffer
}
if err := cmd.Start(); err != nil {
- return nil, fmt.Errorf("GitCommand: start %v: %v", cmd.Args, err)
+ return nil, fmt.Errorf("starting process %v: %w", cmd.Args, err)
}
- inFlightCommandGauge.Inc()
- // The goroutine below is responsible for terminating and reaping the
- // process when ctx is canceled.
+ inFlightCommandGauge.Inc()
commandcounter.Increment()
- go func() {
- <-ctx.Done()
- if process := cmd.Process; process != nil && process.Pid > 0 {
- //nolint:errcheck // TODO: do we want to report errors?
- // Send SIGTERM to the process group of cmd
- syscall.Kill(-process.Pid, syscall.SIGTERM)
- }
+ // The goroutine below is responsible for terminating and reaping the process when ctx is
+ // canceled. While we must ensure that it does run when `cmd.Start()` was successful, it
+ // must not run before have fully set up the command. Otherwise, we may end up with racy
+ // access patterns when the context gets terminated early.
+ //
+ // We thus defer spawning the Goroutine.
+ defer func() {
+ go func() {
+ <-ctx.Done()
+
+ if process := cmd.Process; process != nil && process.Pid > 0 {
+ //nolint:errcheck // TODO: do we want to report errors?
+ // Send SIGTERM to the process group of cmd
+ syscall.Kill(-process.Pid, syscall.SIGTERM)
+ }
- // We do not care for any potential erorr code, but just want to make sure that the
- // subprocess gets properly killed and processed.
- _ = command.Wait()
+ // We do not care for any potential error code, but just want to make sure that the
+ // subprocess gets properly killed and processed.
+ _ = command.Wait()
+ }()
}()
+ if featureflag.RunCommandsInCGroup.IsEnabled(ctx) && cfg.cgroupsManager != nil {
+ cgroupPath, err := cfg.cgroupsManager.AddCommand(command, cfg.cgroupsRepo)
+ if err != nil {
+ return nil, err
+ }
+
+ command.cgroupPath = cgroupPath
+ }
+
logPid = cmd.Process.Pid
return command, nil
}
-// AllowedEnvironment filters the given slice of environment variables and
-// returns all variables which are allowed per the variables defined above.
-// This is useful for constructing a base environment in which a command can be
-// run.
-func AllowedEnvironment(envs []string) []string {
- var filtered []string
+// Read calls Read() on the stdout pipe of the command.
+func (c *Command) Read(p []byte) (int, error) {
+ if c.reader == nil {
+ panic("command has no reader")
+ }
- for _, env := range envs {
- for _, exportedEnv := range exportedEnvVars {
- if strings.HasPrefix(env, exportedEnv+"=") {
- filtered = append(filtered, env)
- }
- }
+ return c.reader.Read(p)
+}
+
+// Write calls Write() on the stdin pipe of the command.
+func (c *Command) Write(p []byte) (int, error) {
+ if c.writer == nil {
+ panic("command has no writer")
}
- return filtered
+ return c.writer.Write(p)
+}
+
+// Wait calls Wait() on the exec.Cmd instance inside the command. This
+// blocks until the command has finished and reports the command exit
+// status via the error return value. Use ExitStatus to get the integer
+// exit status from the error returned by Wait().
+func (c *Command) Wait() error {
+ c.waitOnce.Do(c.wait)
+
+ return c.waitError
}
// This function should never be called directly, use Wait().
@@ -366,21 +343,10 @@ func (c *Command) wait() {
// counter again. So we instead do it here to accelerate the process, even though it's less
// idiomatic.
commandcounter.Decrement()
-}
-
-// ExitStatus will return the exit-code from an error returned by Wait().
-func ExitStatus(err error) (int, bool) {
- exitError, ok := err.(*exec.ExitError)
- if !ok {
- return 0, false
- }
- waitStatus, ok := exitError.Sys().(syscall.WaitStatus)
- if !ok {
- return 0, false
+ if c.finalizer != nil {
+ c.finalizer(c)
}
-
- return waitStatus.ExitStatus(), true
}
func (c *Command) logProcessComplete() {
@@ -480,6 +446,66 @@ func (c *Command) logProcessComplete() {
c.span.Finish()
}
+// Args is an accessor for the command arguments
+func (c *Command) Args() []string {
+ return c.cmd.Args
+}
+
+// Env is an accessor for the environment variables
+func (c *Command) Env() []string {
+ return c.cmd.Env
+}
+
+// Pid is an accessor for the pid
+func (c *Command) Pid() int {
+ return c.cmd.Process.Pid
+}
+
+type contextWithoutDonePanic string
+
+var getSpawnTokenAcquiringSeconds = func(t time.Time) float64 {
+ return time.Since(t).Seconds()
+}
+
+type stdinSentinel struct{}
+
+func (stdinSentinel) Read([]byte) (int, error) {
+ return 0, errors.New("stdin sentinel should not be read from")
+}
+
+// AllowedEnvironment filters the given slice of environment variables and
+// returns all variables which are allowed per the variables defined above.
+// This is useful for constructing a base environment in which a command can be
+// run.
+func AllowedEnvironment(envs []string) []string {
+ var filtered []string
+
+ for _, env := range envs {
+ for _, exportedEnv := range exportedEnvVars {
+ if strings.HasPrefix(env, exportedEnv+"=") {
+ filtered = append(filtered, env)
+ }
+ }
+ }
+
+ return filtered
+}
+
+// ExitStatus will return the exit-code from an error returned by Wait().
+func ExitStatus(err error) (int, bool) {
+ exitError, ok := err.(*exec.ExitError)
+ if !ok {
+ return 0, false
+ }
+
+ waitStatus, ok := exitError.Sys().(syscall.WaitStatus)
+ if !ok {
+ return 0, false
+ }
+
+ return waitStatus.ExitStatus(), true
+}
+
func methodFromContext(ctx context.Context) (service string, method string) {
tags := grpcmwtags.Extract(ctx)
ctxValue := tags.Values()["grpc.request.fullMethod"]
@@ -514,18 +540,3 @@ func checkNullArgv(cmd *exec.Cmd) error {
return nil
}
-
-// Args is an accessor for the command arguments
-func (c *Command) Args() []string {
- return c.cmd.Args
-}
-
-// Env is an accessor for the environment variables
-func (c *Command) Env() []string {
- return c.cmd.Env
-}
-
-// Pid is an accessor for the pid
-func (c *Command) Pid() int {
- return c.cmd.Process.Pid
-}
diff --git a/internal/command/command_test.go b/internal/command/command_test.go
index 7611e10c6..1f6eae3b0 100644
--- a/internal/command/command_test.go
+++ b/internal/command/command_test.go
@@ -6,7 +6,7 @@ import (
"fmt"
"io"
"os/exec"
- "regexp"
+ "path/filepath"
"runtime"
"strings"
"testing"
@@ -19,30 +19,30 @@ import (
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitaly/v15/internal/git/repository"
"gitlab.com/gitlab-org/gitaly/v15/internal/testhelper"
)
-func TestMain(m *testing.M) {
- testhelper.Run(m)
-}
+func TestNew_environment(t *testing.T) {
+ t.Parallel()
-func TestNewCommandExtraEnv(t *testing.T) {
ctx := testhelper.Context(t)
extraVar := "FOOBAR=123456"
- buff := &bytes.Buffer{}
- cmd, err := New(ctx, exec.Command("/usr/bin/env"), nil, buff, nil, extraVar)
+
+ var buf bytes.Buffer
+ cmd, err := New(ctx, exec.Command("/usr/bin/env"), WithStdout(&buf), WithEnvironment([]string{extraVar}))
require.NoError(t, err)
require.NoError(t, cmd.Wait())
- require.Contains(t, strings.Split(buff.String(), "\n"), extraVar)
+ require.Contains(t, strings.Split(buf.String(), "\n"), extraVar)
}
-func TestNewCommandExportedEnv(t *testing.T) {
+func TestNew_exportedEnvironment(t *testing.T) {
ctx := testhelper.Context(t)
- testCases := []struct {
+ for _, tc := range []struct {
key string
value string
}{
@@ -110,9 +110,7 @@ func TestNewCommandExportedEnv(t *testing.T) {
key: "NO_PROXY",
value: "https://excluded:5000",
},
- }
-
- for _, tc := range testCases {
+ } {
t.Run(tc.key, func(t *testing.T) {
if tc.key == "LD_LIBRARY_PATH" && runtime.GOOS == "darwin" {
t.Skip("System Integrity Protection prevents using dynamic linker (dyld) environment variables on macOS. https://apple.co/2XDH4iC")
@@ -120,50 +118,41 @@ func TestNewCommandExportedEnv(t *testing.T) {
testhelper.ModifyEnvironment(t, tc.key, tc.value)
- buff := &bytes.Buffer{}
- cmd, err := New(ctx, exec.Command("/usr/bin/env"), nil, buff, nil)
+ var buf bytes.Buffer
+ cmd, err := New(ctx, exec.Command("/usr/bin/env"), WithStdout(&buf))
require.NoError(t, err)
require.NoError(t, cmd.Wait())
expectedEnv := fmt.Sprintf("%s=%s", tc.key, tc.value)
- require.Contains(t, strings.Split(buff.String(), "\n"), expectedEnv)
+ require.Contains(t, strings.Split(buf.String(), "\n"), expectedEnv)
})
}
}
-func TestNewCommandUnexportedEnv(t *testing.T) {
+func TestNew_unexportedEnv(t *testing.T) {
ctx := testhelper.Context(t)
unexportedEnvKey, unexportedEnvVal := "GITALY_UNEXPORTED_ENV", "foobar"
testhelper.ModifyEnvironment(t, unexportedEnvKey, unexportedEnvVal)
- buff := &bytes.Buffer{}
- cmd, err := New(ctx, exec.Command("/usr/bin/env"), nil, buff, nil)
-
+ var buf bytes.Buffer
+ cmd, err := New(ctx, exec.Command("/usr/bin/env"), WithStdout(&buf))
require.NoError(t, err)
require.NoError(t, cmd.Wait())
- require.NotContains(t, strings.Split(buff.String(), "\n"), fmt.Sprintf("%s=%s", unexportedEnvKey, unexportedEnvVal))
+ require.NotContains(t, strings.Split(buf.String(), "\n"), fmt.Sprintf("%s=%s", unexportedEnvKey, unexportedEnvVal))
}
-func TestRejectEmptyContextDone(t *testing.T) {
- defer func() {
- p := recover()
- if p == nil {
- t.Error("expected panic, got none")
- return
- }
-
- if _, ok := p.(contextWithoutDonePanic); !ok {
- panic(p)
- }
- }()
+func TestNew_rejectContextWithoutDone(t *testing.T) {
+ t.Parallel()
- _, err := New(testhelper.ContextWithoutCancel(), exec.Command("true"), nil, nil, nil)
- require.NoError(t, err)
+ require.PanicsWithValue(t, contextWithoutDonePanic("command spawned with context without Done() channel"), func() {
+ _, err := New(testhelper.ContextWithoutCancel(), exec.Command("true"))
+ require.NoError(t, err)
+ })
}
-func TestNewCommandTimeout(t *testing.T) {
+func TestNew_spawnTimeout(t *testing.T) {
ctx := testhelper.Context(t)
defer func(ch chan struct{}, t time.Duration) {
@@ -177,39 +166,33 @@ func TestNewCommandTimeout(t *testing.T) {
spawnTimeout := 200 * time.Millisecond
spawnConfig.Timeout = spawnTimeout
- testDeadline := time.After(1 * time.Second)
tick := time.After(spawnTimeout / 2)
errCh := make(chan error)
go func() {
- _, err := New(ctx, exec.Command("true"), nil, nil, nil)
+ _, err := New(ctx, exec.Command("true"))
errCh <- err
}()
- var err error
- timePassed := false
-
-wait:
- for {
- select {
- case err = <-errCh:
- break wait
- case <-tick:
- timePassed = true
- case <-testDeadline:
- t.Fatal("test timed out")
- }
+ select {
+ case <-errCh:
+ require.FailNow(t, "expected spawning to be delayed")
+ case <-tick:
+ // This is the happy case: we expect spawning of the command to be delayed by up to
+ // 200ms until it finally times out.
}
- require.True(t, timePassed, "time must have passed")
- require.Error(t, err)
- require.Contains(t, err.Error(), "process spawn timed out after")
+ // And after some time we expect that spawning of the command fails due to the configured
+ // timeout.
+ require.Equal(t, fmt.Errorf("process spawn timed out after 200ms"), <-errCh)
}
-func TestCommand_Wait_interrupts_after_context_cancellation(t *testing.T) {
+func TestCommand_Wait_contextCancellationKillsCommand(t *testing.T) {
+ t.Parallel()
+
ctx, cancel := context.WithCancel(testhelper.Context(t))
- cmd, err := New(ctx, exec.CommandContext(ctx, "sleep", "1h"), nil, nil, nil)
+ cmd, err := New(ctx, exec.CommandContext(ctx, "sleep", "1h"))
require.NoError(t, err)
// Cancel the command early.
@@ -222,184 +205,243 @@ func TestCommand_Wait_interrupts_after_context_cancellation(t *testing.T) {
require.Equal(t, -1, s)
}
-func TestNewCommandWithSetupStdin(t *testing.T) {
- ctx := testhelper.Context(t)
+func TestNew_setupStdin(t *testing.T) {
+ t.Parallel()
- value := "Test value"
- output := bytes.NewBuffer(nil)
+ ctx := testhelper.Context(t)
- cmd, err := New(ctx, exec.Command("cat"), SetupStdin, nil, nil)
- require.NoError(t, err)
+ stdin := "Test value"
- _, err = fmt.Fprintf(cmd, "%s", value)
+ var buf bytes.Buffer
+ cmd, err := New(ctx, exec.Command("cat"), WithSetupStdin(), WithStdout(&buf))
require.NoError(t, err)
- // The output of the `cat` subprocess should exactly match its input
- _, err = io.CopyN(output, cmd, int64(len(value)))
+ _, err = fmt.Fprintf(cmd, "%s", stdin)
require.NoError(t, err)
- require.Equal(t, value, output.String())
require.NoError(t, cmd.Wait())
+ require.Equal(t, stdin, buf.String())
}
-func TestNewCommandNullInArg(t *testing.T) {
+func TestCommand_read(t *testing.T) {
+ t.Parallel()
+
ctx := testhelper.Context(t)
- _, err := New(ctx, exec.Command("sh", "-c", "hello\x00world"), nil, nil, nil)
- require.Error(t, err)
- require.EqualError(t, err, `detected null byte in command argument "hello\x00world"`)
+ cmd, err := New(ctx, exec.Command("echo", "test value"))
+ require.NoError(t, err)
+
+ output, err := io.ReadAll(cmd)
+ require.NoError(t, err)
+ require.Equal(t, "test value\n", string(output))
+
+ require.NoError(t, cmd.Wait())
}
-func TestNewNonExistent(t *testing.T) {
+func TestNew_nulByteInArgument(t *testing.T) {
+ t.Parallel()
+
ctx := testhelper.Context(t)
- cmd, err := New(ctx, exec.Command("command-non-existent"), nil, nil, nil)
+ cmd, err := New(ctx, exec.Command("sh", "-c", "hello\x00world"))
+ require.Equal(t, fmt.Errorf("detected null byte in command argument %q", "hello\x00world"), err)
require.Nil(t, cmd)
- require.Error(t, err)
}
-func TestCommandStdErr(t *testing.T) {
+func TestNew_missingBinary(t *testing.T) {
+ t.Parallel()
+
ctx := testhelper.Context(t)
- var stdout, stderr bytes.Buffer
- expectedMessage := `hello world\nhello world\nhello world\nhello world\nhello world\n`
+ cmd, err := New(ctx, exec.Command("command-non-existent"))
+ require.EqualError(t, err, "starting process [command-non-existent]: exec: \"command-non-existent\": executable file not found in $PATH")
+ require.Nil(t, cmd)
+}
+
+func TestCommand_stderrLogging(t *testing.T) {
+ t.Parallel()
- logger := logrus.New()
- logger.SetOutput(&stderr)
+ binaryPath := testhelper.WriteExecutable(t, filepath.Join(testhelper.TempDir(t), "script"), []byte(`#!/bin/bash
+ for i in {1..5}
+ do
+ echo 'hello world' 1>&2
+ done
+ exit 1
+ `))
+ logger, hook := test.NewNullLogger()
+ ctx := testhelper.Context(t)
ctx = ctxlogrus.ToContext(ctx, logrus.NewEntry(logger))
- cmd, err := New(ctx, exec.Command("./testdata/stderr_script.sh"), nil, &stdout, nil)
+ var stdout bytes.Buffer
+ cmd, err := New(ctx, exec.Command(binaryPath), WithStdout(&stdout))
require.NoError(t, err)
- require.Error(t, cmd.Wait())
- assert.Empty(t, stdout.Bytes())
- require.Equal(t, expectedMessage, extractLastMessage(stderr.String()))
+ require.EqualError(t, cmd.Wait(), "exit status 1")
+ require.Empty(t, stdout.Bytes())
+ require.Equal(t, strings.Repeat("hello world\n", 5), hook.LastEntry().Message)
}
-func TestCommandStdErrLargeOutput(t *testing.T) {
- ctx := testhelper.Context(t)
-
- var stdout, stderr bytes.Buffer
+func TestCommand_stderrLoggingTruncation(t *testing.T) {
+ t.Parallel()
- logger := logrus.New()
- logger.SetOutput(&stderr)
+ binaryPath := testhelper.WriteExecutable(t, filepath.Join(testhelper.TempDir(t), "script"), []byte(`#!/bin/bash
+ for i in {1..1000}
+ do
+ printf '%06d zzzzzzzzzz\n' $i >&2
+ done
+ exit 1
+ `))
+ logger, hook := test.NewNullLogger()
+ ctx := testhelper.Context(t)
ctx = ctxlogrus.ToContext(ctx, logrus.NewEntry(logger))
- cmd, err := New(ctx, exec.Command("./testdata/stderr_many_lines.sh"), nil, &stdout, nil)
+ var stdout bytes.Buffer
+ cmd, err := New(ctx, exec.Command(binaryPath), WithStdout(&stdout))
require.NoError(t, err)
- require.Error(t, cmd.Wait())
- assert.Empty(t, stdout.Bytes())
- msg := strings.ReplaceAll(extractLastMessage(stderr.String()), "\\n", "\n")
- require.LessOrEqual(t, len(msg), maxStderrBytes)
+ require.Error(t, cmd.Wait())
+ require.Empty(t, stdout.Bytes())
+ require.Len(t, hook.LastEntry().Message, maxStderrBytes)
}
-func TestCommandStdErrBinaryNullBytes(t *testing.T) {
- ctx := testhelper.Context(t)
-
- var stdout, stderr bytes.Buffer
+func TestCommand_stderrLoggingWithNulBytes(t *testing.T) {
+ t.Parallel()
- logger := logrus.New()
- logger.SetOutput(&stderr)
+ binaryPath := testhelper.WriteExecutable(t, filepath.Join(testhelper.TempDir(t), "script"), []byte(`#!/bin/bash
+ dd if=/dev/zero bs=1000 count=1000 status=none >&2
+ exit 1
+ `))
+ logger, hook := test.NewNullLogger()
+ ctx := testhelper.Context(t)
ctx = ctxlogrus.ToContext(ctx, logrus.NewEntry(logger))
- cmd, err := New(ctx, exec.Command("./testdata/stderr_binary_null.sh"), nil, &stdout, nil)
+ var stdout bytes.Buffer
+ cmd, err := New(ctx, exec.Command(binaryPath), WithStdout(&stdout))
require.NoError(t, err)
- require.Error(t, cmd.Wait())
- assert.Empty(t, stdout.Bytes())
- msg := strings.SplitN(extractLastMessage(stderr.String()), "\\n", 2)[0]
- require.Equal(t, strings.Repeat("\\x00", maxStderrLineLength), msg)
+ require.Error(t, cmd.Wait())
+ require.Empty(t, stdout.Bytes())
+ require.Equal(t, strings.Repeat("\x00", maxStderrLineLength), hook.LastEntry().Message)
}
-func TestCommandStdErrLongLine(t *testing.T) {
- ctx := testhelper.Context(t)
-
- var stdout, stderr bytes.Buffer
+func TestCommand_stderrLoggingLongLine(t *testing.T) {
+ t.Parallel()
- logger := logrus.New()
- logger.SetOutput(&stderr)
+ binaryPath := testhelper.WriteExecutable(t, filepath.Join(testhelper.TempDir(t), "script"), []byte(`#!/bin/bash
+ printf 'a%.0s' {1..8192} >&2
+ printf '\n' >&2
+ printf 'b%.0s' {1..8192} >&2
+ exit 1
+ `))
+ logger, hook := test.NewNullLogger()
+ ctx := testhelper.Context(t)
ctx = ctxlogrus.ToContext(ctx, logrus.NewEntry(logger))
- cmd, err := New(ctx, exec.Command("./testdata/stderr_repeat_a.sh"), nil, &stdout, nil)
+ var stdout bytes.Buffer
+ cmd, err := New(ctx, exec.Command(binaryPath), WithStdout(&stdout))
require.NoError(t, err)
- require.Error(t, cmd.Wait())
- assert.Empty(t, stdout.Bytes())
- require.Contains(t, stderr.String(), fmt.Sprintf("%s\\n%s", strings.Repeat("a", maxStderrLineLength), strings.Repeat("b", maxStderrLineLength)))
+ require.Error(t, cmd.Wait())
+ require.Empty(t, stdout.Bytes())
+ require.Equal(t,
+ strings.Join([]string{
+ strings.Repeat("a", maxStderrLineLength),
+ strings.Repeat("b", maxStderrLineLength),
+ }, "\n"),
+ hook.LastEntry().Message,
+ )
}
-func TestCommandStdErrMaxBytes(t *testing.T) {
- ctx := testhelper.Context(t)
-
- var stdout, stderr bytes.Buffer
-
- logger := logrus.New()
- logger.SetOutput(&stderr)
+func TestCommand_stderrLoggingMaxBytes(t *testing.T) {
+ t.Parallel()
+
+ binaryPath := testhelper.WriteExecutable(t, filepath.Join(testhelper.TempDir(t), "script"), []byte(`#!/bin/bash
+ # This script is used to test that a command writes at most maxBytes to stderr. It
+ # simulates the edge case where the logwriter has already written MaxStderrBytes-1
+ # (9999) bytes
+
+ # This edge case happens when 9999 bytes are written. To simulate this,
+ # stderr_max_bytes_edge_case has 4 lines of the following format:
+ #
+ # line1: 3333 bytes long
+ # line2: 3331 bytes
+ # line3: 3331 bytes
+ # line4: 1 byte
+ #
+ # The first 3 lines sum up to 9999 bytes written, since we write a 2-byte escaped
+ # "\n" for each \n we see. The 4th line can be any data.
+
+ printf 'a%.0s' {1..3333} >&2
+ printf '\n' >&2
+ printf 'a%.0s' {1..3331} >&2
+ printf '\n' >&2
+ printf 'a%.0s' {1..3331} >&2
+ printf '\na\n' >&2
+ exit 1
+ `))
+ logger, hook := test.NewNullLogger()
+ ctx := testhelper.Context(t)
ctx = ctxlogrus.ToContext(ctx, logrus.NewEntry(logger))
- cmd, err := New(ctx, exec.Command("./testdata/stderr_max_bytes_edge_case.sh"), nil, &stdout, nil)
+ var stdout bytes.Buffer
+ cmd, err := New(ctx, exec.Command(binaryPath), WithStdout(&stdout))
require.NoError(t, err)
require.Error(t, cmd.Wait())
- assert.Empty(t, stdout.Bytes())
- message := extractLastMessage(stderr.String())
- require.Equal(t, maxStderrBytes, len(strings.ReplaceAll(message, "\\n", "\n")))
+ require.Empty(t, stdout.Bytes())
+ require.Len(t, hook.LastEntry().Message, maxStderrBytes)
}
-var logMsgRegex = regexp.MustCompile(`msg="(.+?)"`)
-
-func extractLastMessage(logMessage string) string {
- subMatchesAll := logMsgRegex.FindAllStringSubmatch(logMessage, -1)
- if len(subMatchesAll) < 1 {
- return ""
- }
-
- subMatches := subMatchesAll[len(subMatchesAll)-1]
- if len(subMatches) != 2 {
- return ""
- }
+type mockCgroupManager string
- return subMatches[1]
+func (m mockCgroupManager) AddCommand(*Command, repository.GitRepo) (string, error) {
+ return string(m), nil
}
func TestCommand_logMessage(t *testing.T) {
+ t.Parallel()
+
logger, hook := test.NewNullLogger()
logger.SetLevel(logrus.DebugLevel)
ctx := ctxlogrus.ToContext(testhelper.Context(t), logrus.NewEntry(logger))
- cmd, err := New(ctx, exec.Command("echo", "hello world"), nil, nil, nil)
+ cmd, err := New(ctx, exec.Command("echo", "hello world"),
+ WithCgroup(mockCgroupManager("/sys/fs/cgroup/1"), nil),
+ )
require.NoError(t, err)
- cgroupPath := "/sys/fs/cgroup/1"
- cmd.SetCgroupPath(cgroupPath)
require.NoError(t, cmd.Wait())
logEntry := hook.LastEntry()
assert.Equal(t, cmd.Pid(), logEntry.Data["pid"])
assert.Equal(t, []string{"echo", "hello world"}, logEntry.Data["args"])
assert.Equal(t, 0, logEntry.Data["command.exitCode"])
- assert.Equal(t, cgroupPath, logEntry.Data["command.cgroup_path"])
+ assert.Equal(t, "/sys/fs/cgroup/1", logEntry.Data["command.cgroup_path"])
}
-func TestNewCommandSpawnTokenMetrics(t *testing.T) {
- spawnTokenAcquiringSeconds.Reset()
+func TestNew_commandSpawnTokenMetrics(t *testing.T) {
+ defer func(old func(time.Time) float64) {
+ getSpawnTokenAcquiringSeconds = old
+ }(getSpawnTokenAcquiringSeconds)
- ctx := testhelper.Context(t)
getSpawnTokenAcquiringSeconds = func(t time.Time) float64 {
return 1
}
+ spawnTokenAcquiringSeconds.Reset()
+
+ ctx := testhelper.Context(t)
+
tags := grpcmwtags.NewTags()
tags.Set("grpc.request.fullMethod", "/test.Service/TestRPC")
ctx = grpcmwtags.SetInContext(ctx, tags)
- cmd, err := New(ctx, exec.Command("echo", "goodbye, cruel world."), nil, nil, nil)
+ cmd, err := New(ctx, exec.Command("echo", "goodbye, cruel world."))
require.NoError(t, err)
require.NoError(t, cmd.Wait())
@@ -408,7 +450,7 @@ func TestNewCommandSpawnTokenMetrics(t *testing.T) {
# TYPE gitaly_command_spawn_token_acquiring_seconds_total counter
gitaly_command_spawn_token_acquiring_seconds_total{cmd="echo",grpc_method="TestRPC",grpc_service="test.Service"} 1
`
- assert.NoError(
+ require.NoError(
t,
testutil.CollectAndCompare(
spawnTokenAcquiringSeconds,
diff --git a/internal/command/option.go b/internal/command/option.go
new file mode 100644
index 000000000..08acf0f33
--- /dev/null
+++ b/internal/command/option.go
@@ -0,0 +1,94 @@
+package command
+
+import (
+ "io"
+
+ "gitlab.com/gitlab-org/gitaly/v15/internal/git/repository"
+)
+
+type config struct {
+ stdin io.Reader
+ stdout io.Writer
+ stderr io.Writer
+ environment []string
+
+ finalizer func(*Command)
+
+ commandName string
+ subcommandName string
+
+ cgroupsManager CgroupsManager
+ cgroupsRepo repository.GitRepo
+}
+
+// Option is an option that can be passed to `New()` for controlling how the command is being
+// created.
+type Option func(cfg *config)
+
+// WithStdin sets up the command to read from the given reader.
+func WithStdin(stdin io.Reader) Option {
+ return func(cfg *config) {
+ cfg.stdin = stdin
+ }
+}
+
+// WithSetupStdin instructs New() to configure the stdin pipe of the command it is creating. This
+// allows you call Write() on the command as if it is an ordinary io.Writer, sending data directly
+// to the stdin of the process.
+func WithSetupStdin() Option {
+ return func(cfg *config) {
+ cfg.stdin = stdinSentinel{}
+ }
+}
+
+// WithStdout sets up the command to write standard output to the given writer.
+func WithStdout(stdout io.Writer) Option {
+ return func(cfg *config) {
+ cfg.stdout = stdout
+ }
+}
+
+// WithStderr sets up the command to write standard error to the given writer.
+func WithStderr(stderr io.Writer) Option {
+ return func(cfg *config) {
+ cfg.stderr = stderr
+ }
+}
+
+// WithEnvironment sets up environment variables that shall be set for the command.
+func WithEnvironment(environment []string) Option {
+ return func(cfg *config) {
+ cfg.environment = environment
+ }
+}
+
+// WithCommandName overrides the "cmd" and "subcmd" label used in metrics.
+func WithCommandName(commandName, subcommandName string) Option {
+ return func(cfg *config) {
+ cfg.commandName = commandName
+ cfg.subcommandName = subcommandName
+ }
+}
+
+// CgroupsManager is a subset of the `cgroups.Manager` interface. We need to replicate it here to
+// avoid a cyclic dependency between both packages.
+type CgroupsManager interface {
+ AddCommand(*Command, repository.GitRepo) (string, error)
+}
+
+// WithCgroup adds the spawned command to a Cgroup. The bucket used will be derived from the
+// command's arguments and/or from the repository.
+func WithCgroup(cgroupsManager CgroupsManager, repo repository.GitRepo) Option {
+ return func(cfg *config) {
+ cfg.cgroupsManager = cgroupsManager
+ cfg.cgroupsRepo = repo
+ }
+}
+
+// WithFinalizer sets up the finalizer to be run when the command is being wrapped up. It will be
+// called after `Wait()` has returned.
+func WithFinalizer(finalizer func(*Command)) Option {
+ return func(cfg *config) {
+ cfg.finalizer = finalizer
+ }
+}
diff --git a/internal/command/testdata/stderr_binary_null.sh b/internal/command/testdata/stderr_binary_null.sh
deleted file mode 100755
index e8e1d0baa..000000000
--- a/internal/command/testdata/stderr_binary_null.sh
+++ /dev/null
@@ -1,4 +0,0 @@
-#!/bin/bash
-
-dd if=/dev/zero bs=1000 count=1000 >&2
-exit 1
diff --git a/internal/command/testdata/stderr_many_lines.sh b/internal/command/testdata/stderr_many_lines.sh
deleted file mode 100755
index e8b51b646..000000000
--- a/internal/command/testdata/stderr_many_lines.sh
+++ /dev/null
@@ -1,9 +0,0 @@
-#!/bin/bash
-
-let x=0
-while [ $x -lt 100010 ]
-do
- let x=x+1
- printf '%06d zzzzzzzzzz\n' $x >&2
-done
-exit 1
diff --git a/internal/command/testdata/stderr_max_bytes_edge_case.sh b/internal/command/testdata/stderr_max_bytes_edge_case.sh
deleted file mode 100755
index 6e1454966..000000000
--- a/internal/command/testdata/stderr_max_bytes_edge_case.sh
+++ /dev/null
@@ -1,20 +0,0 @@
-#!/bin/bash
-
-# This script is used to test that a command writes at most maxBytes to stderr. It simulates the
-# edge case where the logwriter has already written MaxStderrBytes-1 (9999) bytes
-
-# This edge case happens when 9999 bytes are written. To simulate this, stderr_max_bytes_edge_case has 4 lines of the following format:
-# line1: 3333 bytes long
-# line2: 3331 bytes
-# line3: 3331 bytes
-# line4: 1 byte
-# The first 3 lines sum up to 9999 bytes written, since we write a 2-byte escaped `\n` for each \n we see.
-# The 4th line can be any data.
-
-printf 'a%.0s' {1..3333} >&2
-printf '\n' >&2
-printf 'a%.0s' {1..3331} >&2
-printf '\n' >&2
-printf 'a%.0s' {1..3331} >&2
-printf '\na\n' >&2
-exit 1
diff --git a/internal/command/testdata/stderr_repeat_a.sh b/internal/command/testdata/stderr_repeat_a.sh
deleted file mode 100755
index 8463929ae..000000000
--- a/internal/command/testdata/stderr_repeat_a.sh
+++ /dev/null
@@ -1,6 +0,0 @@
-#!/bin/bash
-
-printf 'a%.0s' {1..8192} >&2
-printf '\n' >&2
-printf 'b%.0s' {1..8192} >&2
-exit 1
diff --git a/internal/command/testdata/stderr_script.sh b/internal/command/testdata/stderr_script.sh
deleted file mode 100755
index 57abd97d0..000000000
--- a/internal/command/testdata/stderr_script.sh
+++ /dev/null
@@ -1,7 +0,0 @@
-#!/bin/bash
-
-for i in {1..5}
-do
- echo 'hello world' 1>&2
-done
-exit 1
diff --git a/internal/command/testhelper_test.go b/internal/command/testhelper_test.go
new file mode 100644
index 000000000..e0398c2bd
--- /dev/null
+++ b/internal/command/testhelper_test.go
@@ -0,0 +1,11 @@
+package command
+
+import (
+ "testing"
+
+ "gitlab.com/gitlab-org/gitaly/v15/internal/testhelper"
+)
+
+func TestMain(m *testing.M) {
+ testhelper.Run(m)
+}
diff --git a/internal/git/catfile/object_info_reader.go b/internal/git/catfile/object_info_reader.go
index f391bd128..84857ce70 100644
--- a/internal/git/catfile/object_info_reader.go
+++ b/internal/git/catfile/object_info_reader.go
@@ -146,7 +146,7 @@ func newObjectInfoReader(
git.Flag{Name: "--buffer"},
},
},
- git.WithStdin(command.SetupStdin),
+ git.WithSetupStdin(),
)
if err != nil {
return nil, err
diff --git a/internal/git/catfile/object_reader.go b/internal/git/catfile/object_reader.go
index 4810a8207..516050301 100644
--- a/internal/git/catfile/object_reader.go
+++ b/internal/git/catfile/object_reader.go
@@ -67,7 +67,7 @@ func newObjectReader(
git.Flag{Name: "--buffer"},
},
},
- git.WithStdin(command.SetupStdin),
+ git.WithSetupStdin(),
)
if err != nil {
return nil, err
diff --git a/internal/git/command_factory.go b/internal/git/command_factory.go
index 2f2ded567..86918985f 100644
--- a/internal/git/command_factory.go
+++ b/internal/git/command_factory.go
@@ -17,7 +17,6 @@ import (
"gitlab.com/gitlab-org/gitaly/v15/internal/gitaly/config"
"gitlab.com/gitlab-org/gitaly/v15/internal/gitaly/storage"
"gitlab.com/gitlab-org/gitaly/v15/internal/log"
- "gitlab.com/gitlab-org/gitaly/v15/internal/metadata/featureflag"
)
// CommandFactory is designed to create and run git commands in a protected and fully managed manner.
@@ -341,7 +340,7 @@ func (cf *ExecCommandFactory) GitVersion(ctx context.Context) (Version, error) {
// Furthermore, note that we're not using `newCommand()` but instead hand-craft the command.
// This is required to avoid a cyclic dependency when we need to check the version in
// `newCommand()` itself.
- cmd, err := command.New(ctx, exec.Command(execEnv.BinaryPath, "version"), nil, nil, nil, execEnv.EnvironmentVariables...)
+ cmd, err := command.New(ctx, exec.Command(execEnv.BinaryPath, "version"), command.WithEnvironment(execEnv.EnvironmentVariables))
if err != nil {
return Version{}, fmt.Errorf("spawning version command: %w", err)
}
@@ -397,19 +396,16 @@ func (cf *ExecCommandFactory) newCommand(ctx context.Context, repo repository.Gi
execCommand := exec.Command(execEnv.BinaryPath, args...)
execCommand.Dir = dir
- command, err := command.New(ctx, execCommand, config.stdin, config.stdout, config.stderr, env...)
+ command, err := command.New(ctx, execCommand, append(
+ config.commandOpts,
+ command.WithEnvironment(env),
+ command.WithCommandName("git", sc.Subcommand()),
+ command.WithCgroup(cf.cgroupsManager, repo),
+ )...)
if err != nil {
return nil, err
}
- command.SetMetricsSubCmd(sc.Subcommand())
-
- if featureflag.RunCommandsInCGroup.IsEnabled(ctx) {
- if err := cf.cgroupsManager.AddCommand(command, repo); err != nil {
- return nil, err
- }
- }
-
return command, nil
}
diff --git a/internal/git/command_factory_cgroup_test.go b/internal/git/command_factory_cgroup_test.go
index 69c3126bf..e504880b6 100644
--- a/internal/git/command_factory_cgroup_test.go
+++ b/internal/git/command_factory_cgroup_test.go
@@ -24,9 +24,9 @@ func (m *mockCgroupsManager) Setup() error {
return nil
}
-func (m *mockCgroupsManager) AddCommand(c *command.Command, repo repository.GitRepo) error {
+func (m *mockCgroupsManager) AddCommand(c *command.Command, repo repository.GitRepo) (string, error) {
m.commands = append(m.commands, c)
- return nil
+ return "", nil
}
func (m *mockCgroupsManager) Cleanup() error {
diff --git a/internal/git/command_options.go b/internal/git/command_options.go
index 44aecd810..044ae754d 100644
--- a/internal/git/command_options.go
+++ b/internal/git/command_options.go
@@ -9,6 +9,7 @@ import (
"regexp"
"strings"
+ "gitlab.com/gitlab-org/gitaly/v15/internal/command"
"gitlab.com/gitlab-org/gitaly/v15/internal/gitaly/config"
"gitlab.com/gitlab-org/gitaly/v15/internal/gitaly/storage"
"gitlab.com/gitlab-org/gitaly/v15/internal/helper"
@@ -168,9 +169,7 @@ func ConvertConfigOptions(options []string) ([]ConfigPair, error) {
type cmdCfg struct {
env []string
globals []GlobalOption
- stdin io.Reader
- stdout io.Writer
- stderr io.Writer
+ commandOpts []command.Option
hooksConfigured bool
}
@@ -181,7 +180,15 @@ type CmdOpt func(context.Context, config.Cfg, CommandFactory, *cmdCfg) error
// command suitable for `Write()`ing to.
func WithStdin(r io.Reader) CmdOpt {
return func(_ context.Context, _ config.Cfg, _ CommandFactory, c *cmdCfg) error {
- c.stdin = r
+ c.commandOpts = append(c.commandOpts, command.WithStdin(r))
+ return nil
+ }
+}
+
+// WithSetupStdin sets up the command so that it can be `Write()`en to.
+func WithSetupStdin() CmdOpt {
+ return func(_ context.Context, _ config.Cfg, _ CommandFactory, c *cmdCfg) error {
+ c.commandOpts = append(c.commandOpts, command.WithSetupStdin())
return nil
}
}
@@ -189,7 +196,7 @@ func WithStdin(r io.Reader) CmdOpt {
// WithStdout sets the command's stdout.
func WithStdout(w io.Writer) CmdOpt {
return func(_ context.Context, _ config.Cfg, _ CommandFactory, c *cmdCfg) error {
- c.stdout = w
+ c.commandOpts = append(c.commandOpts, command.WithStdout(w))
return nil
}
}
@@ -197,7 +204,7 @@ func WithStdout(w io.Writer) CmdOpt {
// WithStderr sets the command's stderr.
func WithStderr(w io.Writer) CmdOpt {
return func(_ context.Context, _ config.Cfg, _ CommandFactory, c *cmdCfg) error {
- c.stderr = w
+ c.commandOpts = append(c.commandOpts, command.WithStderr(w))
return nil
}
}
@@ -302,3 +309,12 @@ func withInternalFetch(req repoScopedRequest, withSidechannel bool) func(ctx con
return nil
}
}
+
+// WithFinalizer sets up the finalizer to be run when the command is being wrapped up. It will be
+// called after `Wait()` has returned.
+func WithFinalizer(finalizer func(*command.Command)) CmdOpt {
+ return func(_ context.Context, _ config.Cfg, _ CommandFactory, c *cmdCfg) error {
+ c.commandOpts = append(c.commandOpts, command.WithFinalizer(finalizer))
+ return nil
+ }
+}
diff --git a/internal/git/objectpool/fetch.go b/internal/git/objectpool/fetch.go
index 6aedcc7e5..549522809 100644
--- a/internal/git/objectpool/fetch.go
+++ b/internal/git/objectpool/fetch.go
@@ -226,7 +226,7 @@ func (o *ObjectPool) logStats(ctx context.Context, when string) error {
func sizeDir(ctx context.Context, dir string) (int64, error) {
// du -k reports size in KB
- cmd, err := command.New(ctx, exec.Command("du", "-sk", dir), nil, nil, nil)
+ cmd, err := command.New(ctx, exec.Command("du", "-sk", dir))
if err != nil {
return 0, err
}
diff --git a/internal/git/packfile/index.go b/internal/git/packfile/index.go
index 2aa7f4be7..e8915e1d3 100644
--- a/internal/git/packfile/index.go
+++ b/internal/git/packfile/index.go
@@ -93,7 +93,7 @@ func ReadIndex(idxPath string) (*Index, error) {
return nil, err
}
- showIndex, err := command.New(ctx, exec.Command("git", "show-index"), f, nil, nil)
+ showIndex, err := command.New(ctx, exec.Command("git", "show-index"), command.WithStdin(f))
if err != nil {
return nil, err
}
diff --git a/internal/git/pktline/read_monitor.go b/internal/git/pktline/read_monitor.go
index e288b2530..9b371a85f 100644
--- a/internal/git/pktline/read_monitor.go
+++ b/internal/git/pktline/read_monitor.go
@@ -40,27 +40,22 @@ type ReadMonitor struct {
// to the pipe. The stream will be monitored for a pktline-formatted packet
// matching pkt. If it isn't seen within the timeout, cancelFn will be called.
//
-// Resources will be freed when the context is done, but you should close the
-// returned *os.File earlier if possible.
-func NewReadMonitor(ctx context.Context, r io.Reader) (*os.File, *ReadMonitor, error) {
+// The returned function will release allocated resources. You must make sure to call this
+// function.
+func NewReadMonitor(ctx context.Context, r io.Reader) (*os.File, *ReadMonitor, func(), error) {
pr, pw, err := os.Pipe()
if err != nil {
- return nil, nil, err
+ return nil, nil, nil, err
}
- // Ensure all resources are closed once the context is done
- go func() {
- <-ctx.Done()
-
- pr.Close()
- pw.Close()
- }()
-
return pr, &ReadMonitor{
- pr: pr,
- pw: pw,
- underlying: r,
- }, nil
+ pr: pr,
+ pw: pw,
+ underlying: r,
+ }, func() {
+ pr.Close()
+ pw.Close()
+ }, nil
}
// Monitor should be called at most once. It scans the stream, looking for the
diff --git a/internal/git/pktline/read_monitor_test.go b/internal/git/pktline/read_monitor_test.go
index e78bdd8c5..65a0ec3ab 100644
--- a/internal/git/pktline/read_monitor_test.go
+++ b/internal/git/pktline/read_monitor_test.go
@@ -24,7 +24,7 @@ func TestReadMonitorTimeout(t *testing.T) {
waitPipeR, // this pipe reader lets us block the multi reader
)
- r, monitor, err := NewReadMonitor(ctx, in)
+ r, monitor, cleanup, err := NewReadMonitor(ctx, in)
require.NoError(t, err)
timeoutTicker := helper.NewManualTicker()
@@ -41,6 +41,8 @@ func TestReadMonitorTimeout(t *testing.T) {
require.Equal(t, ctx.Err(), context.Canceled)
require.True(t, elapsed < time.Second, "Expected context to be cancelled quickly, but it was not")
+ cleanup()
+
// Verify that pipe is closed
_, err = io.ReadAll(r)
require.Error(t, err)
@@ -61,8 +63,9 @@ func TestReadMonitorSuccess(t *testing.T) {
strings.NewReader(postTimeoutPayload),
)
- r, monitor, err := NewReadMonitor(ctx, in)
+ r, monitor, cleanup, err := NewReadMonitor(ctx, in)
require.NoError(t, err)
+ defer cleanup()
timeoutTicker := helper.NewManualTicker()
diff --git a/internal/git/updateref/updateref.go b/internal/git/updateref/updateref.go
index a00f78c3e..cc97513cc 100644
--- a/internal/git/updateref/updateref.go
+++ b/internal/git/updateref/updateref.go
@@ -63,7 +63,7 @@ func New(ctx context.Context, repo git.RepositoryExecutor, opts ...UpdaterOpt) (
Flags: []git.Option{git.Flag{Name: "-z"}, git.Flag{Name: "--stdin"}},
},
txOption,
- git.WithStdin(command.SetupStdin),
+ git.WithSetupStdin(),
git.WithStderr(&stderr),
)
if err != nil {
diff --git a/internal/git2go/executor.go b/internal/git2go/executor.go
index 75d451fe0..c1e220f43 100644
--- a/internal/git2go/executor.go
+++ b/internal/git2go/executor.go
@@ -88,13 +88,17 @@ func (b *Executor) run(ctx context.Context, repo repository.GitRepo, stdin io.Re
}, args...)
var stdout bytes.Buffer
- cmd, err := command.New(ctx, exec.Command(b.binaryPath, args...), stdin, &stdout, log, env...)
+ cmd, err := command.New(ctx, exec.Command(b.binaryPath, args...),
+ command.WithStdin(stdin),
+ command.WithStdout(&stdout),
+ command.WithStderr(log),
+ command.WithEnvironment(env),
+ command.WithCommandName("gitaly-git2go", subcmd),
+ )
if err != nil {
return nil, err
}
- cmd.SetMetricsSubCmd(subcmd)
-
if err := cmd.Wait(); err != nil {
return nil, err
}
diff --git a/internal/gitaly/hook/custom.go b/internal/gitaly/hook/custom.go
index 327e014de..2c819dd3e 100644
--- a/internal/gitaly/hook/custom.go
+++ b/internal/gitaly/hook/custom.go
@@ -68,13 +68,17 @@ func (m *GitLabHookManager) newCustomHooksExecutor(repo *gitalypb.Repository, ho
for _, hookFile := range hookFiles {
cmd := exec.Command(hookFile, args...)
cmd.Dir = repoPath
- c, err := command.New(ctx, cmd, bytes.NewReader(stdinBytes), stdout, stderr, env...)
+ c, err := command.New(ctx, cmd,
+ command.WithStdin(bytes.NewReader(stdinBytes)),
+ command.WithStdout(stdout),
+ command.WithStderr(stderr),
+ command.WithEnvironment(env),
+ command.WithCommandName("gitaly-hooks", hookName),
+ )
if err != nil {
return err
}
- c.SetMetricsSubCmd(hookName)
-
if err = c.Wait(); err != nil {
// Custom hook errors need to be handled specially when we update
// refs via updateref.UpdaterWithHooks: their stdout and stderr must
diff --git a/internal/gitaly/linguist/linguist.go b/internal/gitaly/linguist/linguist.go
index 4ac9abfd2..da085b58a 100644
--- a/internal/gitaly/linguist/linguist.go
+++ b/internal/gitaly/linguist/linguist.go
@@ -95,14 +95,14 @@ func (inst *Instance) startGitLinguist(ctx context.Context, repoPath string, com
cmd := exec.Command(bundle, "exec", "bin/gitaly-linguist", "--repository="+repoPath, "--commit="+commitID)
cmd.Dir = inst.cfg.Ruby.Dir
- internalCmd, err := command.New(ctx, cmd, nil, nil, nil, env.AllowedRubyEnvironment(os.Environ())...)
+ internalCmd, err := command.New(ctx, cmd,
+ command.WithEnvironment(env.AllowedRubyEnvironment(os.Environ())),
+ command.WithCommandName("git-linguist", "stats"),
+ )
if err != nil {
return nil, fmt.Errorf("creating command: %w", err)
}
- internalCmd.SetMetricsCmd("git-linguist")
- internalCmd.SetMetricsSubCmd("stats")
-
return internalCmd, nil
}
diff --git a/internal/gitaly/service/repository/archive.go b/internal/gitaly/service/repository/archive.go
index 86250e29b..959064237 100644
--- a/internal/gitaly/service/repository/archive.go
+++ b/internal/gitaly/service/repository/archive.go
@@ -230,7 +230,9 @@ func (s *server) handleArchive(ctx context.Context, p archiveParams) error {
}
if p.compressCmd != nil {
- command, err := command.New(ctx, p.compressCmd, archiveCommand, p.writer, nil)
+ command, err := command.New(ctx, p.compressCmd,
+ command.WithStdin(archiveCommand), command.WithStdout(p.writer),
+ )
if err != nil {
return err
}
diff --git a/internal/gitaly/service/repository/backup_custom_hooks.go b/internal/gitaly/service/repository/backup_custom_hooks.go
index da774f89b..e434e6f65 100644
--- a/internal/gitaly/service/repository/backup_custom_hooks.go
+++ b/internal/gitaly/service/repository/backup_custom_hooks.go
@@ -30,7 +30,7 @@ func (s *server) BackupCustomHooks(in *gitalypb.BackupCustomHooksRequest, stream
ctx := stream.Context()
tar := exec.Command("tar", "-c", "-f", "-", "-C", repoPath, customHooksDir)
- cmd, err := command.New(ctx, tar, nil, writer, nil)
+ cmd, err := command.New(ctx, tar, command.WithStdout(writer))
if err != nil {
return status.Errorf(codes.Internal, "%v", err)
}
diff --git a/internal/gitaly/service/repository/create_repository_from_snapshot.go b/internal/gitaly/service/repository/create_repository_from_snapshot.go
index 021bb2cec..f665faa87 100644
--- a/internal/gitaly/service/repository/create_repository_from_snapshot.go
+++ b/internal/gitaly/service/repository/create_repository_from_snapshot.go
@@ -67,7 +67,7 @@ func untar(ctx context.Context, path string, in *gitalypb.CreateRepositoryFromSn
return status.Errorf(codes.Internal, "HTTP server: %v", rsp.Status)
}
- cmd, err := command.New(ctx, exec.Command("tar", "-C", path, "-xvf", "-"), rsp.Body, nil, nil)
+ cmd, err := command.New(ctx, exec.Command("tar", "-C", path, "-xvf", "-"), command.WithStdin(rsp.Body))
if err != nil {
return err
}
diff --git a/internal/gitaly/service/repository/replicate.go b/internal/gitaly/service/repository/replicate.go
index 851ba8b4c..6c66bcd74 100644
--- a/internal/gitaly/service/repository/replicate.go
+++ b/internal/gitaly/service/repository/replicate.go
@@ -188,7 +188,9 @@ func (s *server) extractSnapshot(ctx context.Context, source, target *gitalypb.R
}
stderr := &bytes.Buffer{}
- cmd, err := command.New(ctx, exec.Command("tar", "-C", targetPath, "-xvf", "-"), snapshotReader, nil, stderr)
+ cmd, err := command.New(ctx, exec.Command("tar", "-C", targetPath, "-xvf", "-"),
+ command.WithStdin(snapshotReader), command.WithStderr(stderr),
+ )
if err != nil {
return fmt.Errorf("create tar command: %w", err)
}
diff --git a/internal/gitaly/service/repository/restore_custom_hooks.go b/internal/gitaly/service/repository/restore_custom_hooks.go
index 54d73ddc0..572181ffe 100644
--- a/internal/gitaly/service/repository/restore_custom_hooks.go
+++ b/internal/gitaly/service/repository/restore_custom_hooks.go
@@ -62,7 +62,7 @@ func (s *server) RestoreCustomHooks(stream gitalypb.RepositoryService_RestoreCus
}
ctx := stream.Context()
- cmd, err := command.New(ctx, exec.Command("tar", cmdArgs...), reader, nil, nil)
+ cmd, err := command.New(ctx, exec.Command("tar", cmdArgs...), command.WithStdin(reader))
if err != nil {
return status.Errorf(codes.Internal, "RestoreCustomHooks: Could not untar custom hooks tar %v", err)
}
@@ -149,7 +149,7 @@ func (s *server) restoreCustomHooksWithVoting(stream gitalypb.RepositoryService_
customHooksDir,
}
- cmd, err := command.New(ctx, exec.Command("tar", cmdArgs...), reader, nil, nil)
+ cmd, err := command.New(ctx, exec.Command("tar", cmdArgs...), command.WithStdin(reader))
if err != nil {
return helper.ErrInternalf("RestoreCustomHooks: Could not untar custom hooks tar %w", err)
}
diff --git a/internal/gitaly/service/repository/size.go b/internal/gitaly/service/repository/size.go
index d8476274a..95bf22402 100644
--- a/internal/gitaly/service/repository/size.go
+++ b/internal/gitaly/service/repository/size.go
@@ -67,7 +67,7 @@ func (s *server) GetObjectDirectorySize(ctx context.Context, in *gitalypb.GetObj
}
func getPathSize(ctx context.Context, path string) int64 {
- cmd, err := command.New(ctx, exec.Command("du", "-sk", path), nil, nil, nil)
+ cmd, err := command.New(ctx, exec.Command("du", "-sk", path))
if err != nil {
ctxlogrus.Extract(ctx).WithError(err).Warn("ignoring du command error")
return 0
diff --git a/internal/gitaly/service/ssh/monitor_stdin_command.go b/internal/gitaly/service/ssh/monitor_stdin_command.go
index 3d78218ba..5bad1df27 100644
--- a/internal/gitaly/service/ssh/monitor_stdin_command.go
+++ b/internal/gitaly/service/ssh/monitor_stdin_command.go
@@ -11,16 +11,20 @@ import (
)
func monitorStdinCommand(ctx context.Context, gitCmdFactory git.CommandFactory, stdin io.Reader, stdout, stderr io.Writer, sc git.SubCmd, opts ...git.CmdOpt) (*command.Command, *pktline.ReadMonitor, error) {
- stdinPipe, monitor, err := pktline.NewReadMonitor(ctx, stdin)
+ stdinPipe, monitor, cleanup, err := pktline.NewReadMonitor(ctx, stdin)
if err != nil {
return nil, nil, fmt.Errorf("create monitor: %v", err)
}
cmd, err := gitCmdFactory.NewWithoutRepo(ctx, sc, append([]git.CmdOpt{
- git.WithStdin(stdinPipe), git.WithStdout(stdout), git.WithStderr(stderr),
+ git.WithStdin(stdinPipe),
+ git.WithStdout(stdout),
+ git.WithStderr(stderr),
+ git.WithFinalizer(func(*command.Command) { cleanup() }),
}, opts...)...)
stdinPipe.Close() // this now belongs to cmd
if err != nil {
+ cleanup()
return nil, nil, fmt.Errorf("start cmd: %v", err)
}
diff --git a/internal/gitaly/service/ssh/upload_pack.go b/internal/gitaly/service/ssh/upload_pack.go
index 1b0ef278f..89e67e9a9 100644
--- a/internal/gitaly/service/ssh/upload_pack.go
+++ b/internal/gitaly/service/ssh/upload_pack.go
@@ -98,10 +98,14 @@ func (s *server) sshUploadPack(ctx context.Context, req sshUploadPackRequest, st
return 0, err
}
+ var wg sync.WaitGroup
pr, pw := io.Pipe()
- defer pw.Close()
+ defer func() {
+ pw.Close()
+ wg.Wait()
+ }()
+
stdin = io.TeeReader(stdin, pw)
- wg := sync.WaitGroup{}
wg.Add(1)
go func() {
@@ -146,16 +150,10 @@ func (s *server) sshUploadPack(ctx context.Context, req sshUploadPackRequest, st
go monitor.Monitor(ctx, pktline.PktDone(), timeoutTicker, cancelCtx)
if err := cmd.Wait(); err != nil {
- pw.Close()
- wg.Wait()
-
status, _ := command.ExitStatus(err)
return status, fmt.Errorf("cmd wait: %w, stderr: %q", err, stderrBuilder.String())
}
- pw.Close()
- wg.Wait()
-
ctxlogrus.Extract(ctx).WithField("response_bytes", stdoutCounter.N).Info("request details")
return 0, nil
diff --git a/internal/gitaly/service/ssh/upload_pack_test.go b/internal/gitaly/service/ssh/upload_pack_test.go
index dc4d36268..113cd816a 100644
--- a/internal/gitaly/service/ssh/upload_pack_test.go
+++ b/internal/gitaly/service/ssh/upload_pack_test.go
@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"fmt"
+ "io"
"os"
"os/exec"
"path/filepath"
@@ -14,17 +15,21 @@ import (
"github.com/prometheus/client_golang/prometheus"
promtest "github.com/prometheus/client_golang/prometheus/testutil"
"github.com/stretchr/testify/require"
+ gitalyauth "gitlab.com/gitlab-org/gitaly/v15/auth"
"gitlab.com/gitlab-org/gitaly/v15/internal/git"
"gitlab.com/gitlab-org/gitaly/v15/internal/git/gittest"
"gitlab.com/gitlab-org/gitaly/v15/internal/gitaly/config"
"gitlab.com/gitlab-org/gitaly/v15/internal/helper"
"gitlab.com/gitlab-org/gitaly/v15/internal/helper/text"
"gitlab.com/gitlab-org/gitaly/v15/internal/metadata/featureflag"
+ "gitlab.com/gitlab-org/gitaly/v15/internal/sidechannel"
"gitlab.com/gitlab-org/gitaly/v15/internal/testhelper"
"gitlab.com/gitlab-org/gitaly/v15/internal/testhelper/testcfg"
"gitlab.com/gitlab-org/gitaly/v15/internal/testhelper/testserver"
"gitlab.com/gitlab-org/gitaly/v15/proto/go/gitalypb"
+ "google.golang.org/grpc"
"google.golang.org/grpc/codes"
+ "google.golang.org/grpc/credentials/insecure"
"google.golang.org/protobuf/encoding/protojson"
)
@@ -125,6 +130,234 @@ func testUploadPackTimeout(t *testing.T, opts ...testcfg.Option) {
})
}
+func TestUploadPackWithSidechannel_client(t *testing.T) {
+ t.Parallel()
+
+ cfg := testcfg.Build(t)
+ cfg.SocketPath = runSSHServer(t, cfg)
+
+ repo, repoPath := gittest.CreateRepository(testhelper.Context(t), t, cfg, gittest.CreateRepositoryConfig{
+ Seed: gittest.SeedGitLabTest,
+ })
+ commitID := gittest.Exec(t, cfg, "-C", repoPath, "rev-parse", "HEAD^{commit}")
+
+ registry := sidechannel.NewRegistry()
+ clientHandshaker := sidechannel.NewClientHandshaker(testhelper.NewDiscardingLogEntry(t), registry)
+ conn, err := grpc.Dial(cfg.SocketPath,
+ grpc.WithTransportCredentials(clientHandshaker.ClientHandshake(insecure.NewCredentials())),
+ grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2(cfg.Auth.Token)),
+ )
+ require.NoError(t, err)
+
+ client := gitalypb.NewSSHServiceClient(conn)
+ defer testhelper.MustClose(t, conn)
+
+ for _, tc := range []struct {
+ desc string
+ request *gitalypb.SSHUploadPackWithSidechannelRequest
+ client func(clientConn *sidechannel.ClientConn, cancelContext func()) error
+ expectedErr error
+ expectedResponse *gitalypb.SSHUploadPackWithSidechannelResponse
+ }{
+ {
+ desc: "successful clone",
+ request: &gitalypb.SSHUploadPackWithSidechannelRequest{
+ Repository: repo,
+ },
+ client: func(clientConn *sidechannel.ClientConn, _ func()) error {
+ gittest.WritePktlineString(t, clientConn, "want "+text.ChompBytes(commitID)+" multi_ack\n")
+ gittest.WritePktlineFlush(t, clientConn)
+ gittest.WritePktlineString(t, clientConn, "done\n")
+
+ require.NoError(t, clientConn.CloseWrite())
+
+ return nil
+ },
+ expectedResponse: &gitalypb.SSHUploadPackWithSidechannelResponse{},
+ },
+ {
+ desc: "successful clone with protocol v2",
+ request: &gitalypb.SSHUploadPackWithSidechannelRequest{
+ Repository: repo,
+ GitProtocol: git.ProtocolV2,
+ },
+ client: func(clientConn *sidechannel.ClientConn, _ func()) error {
+ gittest.WritePktlineString(t, clientConn, "command=fetch\n")
+ gittest.WritePktlineString(t, clientConn, "agent=git/2.36.1\n")
+ gittest.WritePktlineString(t, clientConn, "object-format=sha1\n")
+ gittest.WritePktlineDelim(t, clientConn)
+ gittest.WritePktlineString(t, clientConn, "want "+text.ChompBytes(commitID)+"\n")
+ gittest.WritePktlineString(t, clientConn, "done\n")
+ gittest.WritePktlineFlush(t, clientConn)
+
+ require.NoError(t, clientConn.CloseWrite())
+
+ return nil
+ },
+ expectedResponse: &gitalypb.SSHUploadPackWithSidechannelResponse{},
+ },
+ {
+ desc: "client talks protocol v0 but v2 is requested",
+ request: &gitalypb.SSHUploadPackWithSidechannelRequest{
+ Repository: repo,
+ GitProtocol: git.ProtocolV2,
+ },
+ client: func(clientConn *sidechannel.ClientConn, _ func()) error {
+ gittest.WritePktlineString(t, clientConn, "want "+text.ChompBytes(commitID)+" multi_ack\n")
+ gittest.WritePktlineFlush(t, clientConn)
+ gittest.WritePktlineString(t, clientConn, "done\n")
+
+ require.NoError(t, clientConn.CloseWrite())
+
+ return nil
+ },
+ expectedErr: helper.ErrInternalf(
+ "cmd wait: exit status 128, stderr: %q",
+ "fatal: unknown capability 'want 1e292f8fedd741b75372e19097c76d327140c312 multi_ack'\n",
+ ),
+ },
+ {
+ desc: "client talks protocol v2 but v0 is requested",
+ request: &gitalypb.SSHUploadPackWithSidechannelRequest{
+ Repository: repo,
+ },
+ client: func(clientConn *sidechannel.ClientConn, _ func()) error {
+ gittest.WritePktlineString(t, clientConn, "command=fetch\n")
+ gittest.WritePktlineString(t, clientConn, "agent=git/2.36.1\n")
+ gittest.WritePktlineString(t, clientConn, "object-format=sha1\n")
+ gittest.WritePktlineDelim(t, clientConn)
+ gittest.WritePktlineString(t, clientConn, "want "+text.ChompBytes(commitID)+"\n")
+ gittest.WritePktlineString(t, clientConn, "done\n")
+ gittest.WritePktlineFlush(t, clientConn)
+
+ require.NoError(t, clientConn.CloseWrite())
+
+ return nil
+ },
+ expectedErr: helper.ErrInternalf(
+ "cmd wait: exit status 128, stderr: %q",
+ "fatal: git upload-pack: protocol error, expected to get object ID, not 'command=fetch'\n",
+ ),
+ },
+ {
+ desc: "missing input",
+ request: &gitalypb.SSHUploadPackWithSidechannelRequest{
+ Repository: repo,
+ GitProtocol: git.ProtocolV2,
+ },
+ client: func(clientConn *sidechannel.ClientConn, _ func()) error {
+ require.NoError(t, clientConn.CloseWrite())
+ return nil
+ },
+ expectedResponse: &gitalypb.SSHUploadPackWithSidechannelResponse{},
+ },
+ {
+ desc: "short write",
+ request: &gitalypb.SSHUploadPackWithSidechannelRequest{
+ Repository: repo,
+ GitProtocol: git.ProtocolV2,
+ },
+ client: func(clientConn *sidechannel.ClientConn, _ func()) error {
+ gittest.WritePktlineString(t, clientConn, "command=fetch\n")
+
+ _, err := io.WriteString(clientConn, "0011agent")
+ require.NoError(t, err)
+ require.NoError(t, clientConn.CloseWrite())
+
+ return nil
+ },
+ expectedErr: helper.ErrInternalf("cmd wait: exit status 128, stderr: %q", "fatal: the remote end hung up unexpectedly\n"),
+ },
+ {
+ desc: "garbage",
+ request: &gitalypb.SSHUploadPackWithSidechannelRequest{
+ Repository: repo,
+ GitProtocol: git.ProtocolV2,
+ },
+ client: func(clientConn *sidechannel.ClientConn, _ func()) error {
+ gittest.WritePktlineString(t, clientConn, "foobar")
+ require.NoError(t, clientConn.CloseWrite())
+ return nil
+ },
+ expectedErr: helper.ErrInternalf("cmd wait: exit status 128, stderr: %q", "fatal: unknown capability 'foobar'\n"),
+ },
+ {
+ desc: "close and cancellation",
+ request: &gitalypb.SSHUploadPackWithSidechannelRequest{
+ Repository: repo,
+ GitProtocol: git.ProtocolV2,
+ },
+ client: func(clientConn *sidechannel.ClientConn, cancelContext func()) error {
+ gittest.WritePktlineString(t, clientConn, "command=fetch\n")
+ gittest.WritePktlineString(t, clientConn, "agent=git/2.36.1\n")
+
+ require.NoError(t, clientConn.CloseWrite())
+ cancelContext()
+
+ return nil
+ },
+ expectedErr: helper.ErrCanceled(context.Canceled),
+ },
+ {
+ desc: "cancellation and close",
+ request: &gitalypb.SSHUploadPackWithSidechannelRequest{
+ Repository: repo,
+ GitProtocol: git.ProtocolV2,
+ },
+ client: func(clientConn *sidechannel.ClientConn, cancelContext func()) error {
+ gittest.WritePktlineString(t, clientConn, "command=fetch\n")
+ gittest.WritePktlineString(t, clientConn, "agent=git/2.36.1\n")
+
+ cancelContext()
+ require.NoError(t, clientConn.CloseWrite())
+
+ return nil
+ },
+ expectedErr: helper.ErrCanceled(context.Canceled),
+ },
+ {
+ desc: "cancellation without close",
+ request: &gitalypb.SSHUploadPackWithSidechannelRequest{
+ Repository: repo,
+ GitProtocol: git.ProtocolV2,
+ },
+ client: func(clientConn *sidechannel.ClientConn, cancelContext func()) error {
+ gittest.WritePktlineString(t, clientConn, "command=fetch\n")
+ gittest.WritePktlineString(t, clientConn, "agent=git/2.36.1\n")
+
+ cancelContext()
+
+ return nil
+ },
+ expectedErr: helper.ErrCanceled(context.Canceled),
+ },
+ } {
+ t.Run(tc.desc, func(t *testing.T) {
+ ctx, cancel := context.WithCancel(testhelper.Context(t))
+
+ ctx, waiter := sidechannel.RegisterSidechannel(ctx, registry, func(clientConn *sidechannel.ClientConn) (returnedErr error) {
+ errCh := make(chan error, 1)
+ go func() {
+ _, err := io.Copy(io.Discard, clientConn)
+ errCh <- err
+ }()
+ defer func() {
+ if err := <-errCh; err != nil && returnedErr == nil {
+ returnedErr = err
+ }
+ }()
+
+ return tc.client(clientConn, cancel)
+ })
+ defer testhelper.MustClose(t, waiter)
+
+ response, err := client.SSHUploadPackWithSidechannel(ctx, tc.request)
+ testhelper.RequireGrpcError(t, tc.expectedErr, err)
+ testhelper.ProtoEqual(t, tc.expectedResponse, response)
+ })
+ }
+}
+
func requireFailedSSHStream(t *testing.T, recv func() (int32, error)) {
done := make(chan struct{})
var code int32