diff options
author | Toon Claes <toon@gitlab.com> | 2022-06-23 15:55:32 +0300 |
---|---|---|
committer | Toon Claes <toon@gitlab.com> | 2022-06-23 15:55:32 +0300 |
commit | c4bb98b282fe8e9e5434b5d20debb8fbf278f1c2 (patch) | |
tree | 50da8a0d46f85b10589f50a9391ea7f91e918660 | |
parent | f8c0438541950ab8a554240b38dc61b006607d74 (diff) | |
parent | d270a521ec5e39f336bd269def7aaf74d9f9ce9c (diff) |
Merge branch 'pks-ssh-tests-emulate-client-side' into 'master'
ssh: Extend tests to emulate client-side failure conditions with sidechannel
See merge request gitlab-org/gitaly!4615
36 files changed, 828 insertions, 456 deletions
diff --git a/cmd/gitaly-lfs-smudge/main_test.go b/cmd/gitaly-lfs-smudge/main_test.go index 37814a34b..6154bcd25 100644 --- a/cmd/gitaly-lfs-smudge/main_test.go +++ b/cmd/gitaly-lfs-smudge/main_test.go @@ -184,10 +184,10 @@ func TestGitalyLFSSmudge(t *testing.T) { var stdout, stderr bytes.Buffer cmd, err := command.New(ctx, exec.Command(binary), - tc.stdin, - &stdout, - &stderr, - env..., + command.WithStdin(tc.stdin), + command.WithStdout(&stdout), + command.WithStderr(&stderr), + command.WithEnvironment(env), ) require.NoError(t, err) diff --git a/internal/cgroups/cgroups.go b/internal/cgroups/cgroups.go index c3e148bba..b2209b509 100644 --- a/internal/cgroups/cgroups.go +++ b/internal/cgroups/cgroups.go @@ -13,8 +13,8 @@ type Manager interface { // It is expected to be called once at Gitaly startup from any // instance of the Manager. Setup() error - // AddCommand adds a Command to a cgroup - AddCommand(*command.Command, repository.GitRepo) error + // AddCommand adds a Command to a cgroup. + AddCommand(*command.Command, repository.GitRepo) (string, error) // Cleanup cleans up cgroups created in Setup. // It is expected to be called once at Gitaly shutdown from any // instance of the Manager. diff --git a/internal/cgroups/noop.go b/internal/cgroups/noop.go index 399bd4931..feaa0d6ef 100644 --- a/internal/cgroups/noop.go +++ b/internal/cgroups/noop.go @@ -15,8 +15,8 @@ func (cg *NoopManager) Setup() error { } //nolint: revive,stylecheck // This is unintentionally missing documentation. -func (cg *NoopManager) AddCommand(cmd *command.Command, repo repository.GitRepo) error { - return nil +func (cg *NoopManager) AddCommand(cmd *command.Command, repo repository.GitRepo) (string, error) { + return "", nil } //nolint: revive,stylecheck // This is unintentionally missing documentation. diff --git a/internal/cgroups/v1_linux.go b/internal/cgroups/v1_linux.go index e4b72f28b..3fa5da4c3 100644 --- a/internal/cgroups/v1_linux.go +++ b/internal/cgroups/v1_linux.go @@ -102,7 +102,7 @@ func (cg *CGroupV1Manager) Setup() error { func (cg *CGroupV1Manager) AddCommand( cmd *command.Command, repo repository.GitRepo, -) error { +) (string, error) { var key string if repo == nil { key = strings.Join(cmd.Args(), "/") @@ -117,9 +117,7 @@ func (cg *CGroupV1Manager) AddCommand( groupID := uint(checksum) % cg.cfg.Repositories.Count cgroupPath := cg.repoPath(int(groupID)) - cmd.SetCgroupPath(cgroupPath) - - return cg.addToCgroup(cmd.Pid(), cgroupPath) + return cgroupPath, cg.addToCgroup(cmd.Pid(), cgroupPath) } func (cg *CGroupV1Manager) addToCgroup(pid int, cgroupPath string) error { diff --git a/internal/cgroups/v1_linux_test.go b/internal/cgroups/v1_linux_test.go index 844b06225..aaa095db0 100644 --- a/internal/cgroups/v1_linux_test.go +++ b/internal/cgroups/v1_linux_test.go @@ -81,7 +81,7 @@ func TestAddCommand(t *testing.T) { ctx := testhelper.Context(t) cmd1 := exec.Command("ls", "-hal", ".") - cmd2, err := command.New(ctx, cmd1, nil, nil, nil) + cmd2, err := command.New(ctx, cmd1) require.NoError(t, err) require.NoError(t, cmd2.Wait()) @@ -91,7 +91,8 @@ func TestAddCommand(t *testing.T) { } t.Run("without a repository", func(t *testing.T) { - require.NoError(t, v1Manager2.AddCommand(cmd2, nil)) + _, err := v1Manager2.AddCommand(cmd2, nil) + require.NoError(t, err) checksum := crc32.ChecksumIEEE([]byte(strings.Join(cmd2.Args(), "/"))) groupID := uint(checksum) % config.Repositories.Count @@ -109,7 +110,8 @@ func TestAddCommand(t *testing.T) { }) t.Run("with a repository", func(t *testing.T) { - require.NoError(t, v1Manager2.AddCommand(cmd2, repo)) + _, err := v1Manager2.AddCommand(cmd2, repo) + require.NoError(t, err) checksum := crc32.ChecksumIEEE([]byte(strings.Join([]string{ "default", @@ -150,6 +152,8 @@ func TestCleanup(t *testing.T) { } func TestMetrics(t *testing.T) { + t.Parallel() + mock := newMock(t) repo := &gitalypb.Repository{ StorageName: "default", @@ -167,22 +171,23 @@ func TestMetrics(t *testing.T) { mock.setupMockCgroupFiles(t, v1Manager1, 2) require.NoError(t, v1Manager1.Setup()) - ctx := testhelper.Context(t) + ctx := testhelper.Context(t) logger, hook := test.NewNullLogger() logger.SetLevel(logrus.DebugLevel) ctx = ctxlogrus.ToContext(ctx, logrus.NewEntry(logger)) - cmd, err := command.New(ctx, exec.Command("ls", "-hal", "."), nil, nil, nil) + cmd, err := command.New(ctx, exec.Command("ls", "-hal", "."), command.WithCgroup(v1Manager1, repo)) require.NoError(t, err) - gitCmd1, err := command.New(ctx, exec.Command("ls", "-hal", "."), nil, nil, nil) + gitCmd1, err := command.New(ctx, exec.Command("ls", "-hal", "."), command.WithCgroup(v1Manager1, repo)) require.NoError(t, err) - gitCmd2, err := command.New(ctx, exec.Command("ls", "-hal", "."), nil, nil, nil) + gitCmd2, err := command.New(ctx, exec.Command("ls", "-hal", "."), command.WithCgroup(v1Manager1, repo)) require.NoError(t, err) + defer func() { + require.NoError(t, gitCmd2.Wait()) + }() - require.NoError(t, v1Manager1.AddCommand(cmd, repo)) - require.NoError(t, v1Manager1.AddCommand(gitCmd1, repo)) - require.NoError(t, v1Manager1.AddCommand(gitCmd2, repo)) + require.NoError(t, err) require.NoError(t, cmd.Wait()) require.NoError(t, gitCmd1.Wait()) diff --git a/internal/command/command.go b/internal/command/command.go index 5602ee24e..abf008f3d 100644 --- a/internal/command/command.go +++ b/internal/command/command.go @@ -20,6 +20,7 @@ import ( "github.com/prometheus/client_golang/prometheus/promauto" "github.com/sirupsen/logrus" "gitlab.com/gitlab-org/gitaly/v15/internal/command/commandcounter" + "gitlab.com/gitlab-org/gitaly/v15/internal/metadata/featureflag" "gitlab.com/gitlab-org/labkit/tracing" ) @@ -73,40 +74,43 @@ var ( }, []string{"grpc_service", "grpc_method", "cmd"}, ) -) -// exportedEnvVars contains a list of environment variables -// that are always exported to child processes on spawn -var exportedEnvVars = []string{ - "HOME", - "PATH", - "LD_LIBRARY_PATH", - "TZ", - - // Export git tracing variables for easier debugging - "GIT_TRACE", - "GIT_TRACE_PACK_ACCESS", - "GIT_TRACE_PACKET", - "GIT_TRACE_PERFORMANCE", - "GIT_TRACE_SETUP", - - // GIT_EXEC_PATH tells Git where to find its binaries. This must be exported especially in - // the case where we use bundled Git executables given that we cannot rely on a complete Git - // installation in that case. - "GIT_EXEC_PATH", - - // Git HTTP proxy settings: https://git-scm.com/docs/git-config#git-config-httpproxy - "all_proxy", - "http_proxy", - "HTTP_PROXY", - "https_proxy", - "HTTPS_PROXY", - // libcurl settings: https://curl.haxx.se/libcurl/c/CURLOPT_NOPROXY.html - "no_proxy", - "NO_PROXY", -} + // exportedEnvVars contains a list of environment variables + // that are always exported to child processes on spawn + exportedEnvVars = []string{ + "HOME", + "PATH", + "LD_LIBRARY_PATH", + "TZ", + + // Export git tracing variables for easier debugging + "GIT_TRACE", + "GIT_TRACE_PACK_ACCESS", + "GIT_TRACE_PACKET", + "GIT_TRACE_PERFORMANCE", + "GIT_TRACE_SETUP", + + // GIT_EXEC_PATH tells Git where to find its binaries. This must be exported + // especially in the case where we use bundled Git executables given that we cannot + // rely on a complete Git installation in that case. + "GIT_EXEC_PATH", + + // Git HTTP proxy settings: + // https://git-scm.com/docs/git-config#git-config-httpproxy + "all_proxy", + "http_proxy", + "HTTP_PROXY", + "https_proxy", + "HTTPS_PROXY", + // libcurl settings: https://curl.haxx.se/libcurl/c/CURLOPT_NOPROXY.html + "no_proxy", + "NO_PROXY", + } -var envInjector = tracing.NewEnvInjector() + // envInjector is responsible for injecting environment variables required for tracing into + // the child process. + envInjector = tracing.NewEnvInjector() +) const ( // maxStderrBytes is at most how many bytes will be written to stderr @@ -130,6 +134,8 @@ type Command struct { waitError error waitOnce sync.Once + finalizer func(*Command) + span opentracing.Span metricsCmd string @@ -137,75 +143,9 @@ type Command struct { cgroupPath string } -type stdinSentinel struct{} - -func (stdinSentinel) Read([]byte) (int, error) { - return 0, errors.New("stdin sentinel should not be read from") -} - -// SetupStdin instructs New() to configure the stdin pipe of the command it is -// creating. This allows you call Write() on the command as if it is an ordinary -// io.Writer, sending data directly to the stdin of the process. -// -// You should not call Read() on this value - it is strictly for configuration! -var SetupStdin io.Reader = stdinSentinel{} - -// Read calls Read() on the stdout pipe of the command. -func (c *Command) Read(p []byte) (int, error) { - if c.reader == nil { - panic("command has no reader") - } - - return c.reader.Read(p) -} - -// Write calls Write() on the stdin pipe of the command. -func (c *Command) Write(p []byte) (int, error) { - if c.writer == nil { - panic("command has no writer") - } - - return c.writer.Write(p) -} - -// Wait calls Wait() on the exec.Cmd instance inside the command. This -// blocks until the command has finished and reports the command exit -// status via the error return value. Use ExitStatus to get the integer -// exit status from the error returned by Wait(). -func (c *Command) Wait() error { - c.waitOnce.Do(c.wait) - - return c.waitError -} - -// SetCgroupPath sets the cgroup path for logging -func (c *Command) SetCgroupPath(path string) { - c.cgroupPath = path -} - -// SetMetricsCmd overrides the "cmd" label used in metrics -func (c *Command) SetMetricsCmd(metricsCmd string) { - c.metricsCmd = metricsCmd -} - -// SetMetricsSubCmd sets the "subcmd" label used in metrics -func (c *Command) SetMetricsSubCmd(metricsSubCmd string) { - c.metricsSubCmd = metricsSubCmd -} - -type contextWithoutDonePanic string - -var getSpawnTokenAcquiringSeconds = func(t time.Time) float64 { - return time.Since(t).Seconds() -} - -// New creates a Command from an exec.Cmd. On success, the Command -// contains a running subprocess. When ctx is canceled the embedded -// process will be terminated and reaped automatically. -// -// If stdin is specified as SetupStdin, you will be able to write to the stdin -// of the subprocess by calling Write() on the returned Command. -func New(ctx context.Context, cmd *exec.Cmd, stdin io.Reader, stdout, stderr io.Writer, env ...string) (*Command, error) { +// New creates a Command from an exec.Cmd. On success, the Command contains a running subprocess. +// When ctx is canceled the embedded process will be terminated and reaped automatically. +func New(ctx context.Context, cmd *exec.Cmd, opts ...Option) (*Command, error) { if ctx.Done() == nil { panic(contextWithoutDonePanic("command spawned with context without Done() channel")) } @@ -214,6 +154,11 @@ func New(ctx context.Context, cmd *exec.Cmd, stdin io.Reader, stdout, stderr io. return nil, err } + var cfg config + for _, opt := range opts { + opt(&cfg) + } + span, ctx := opentracing.StartSpanFromContext( ctx, cmd.Path, @@ -243,16 +188,19 @@ func New(ctx context.Context, cmd *exec.Cmd, stdin io.Reader, stdout, stderr io. }() command := &Command{ - cmd: cmd, - startTime: time.Now(), - context: ctx, - span: span, + cmd: cmd, + startTime: time.Now(), + context: ctx, + span: span, + finalizer: cfg.finalizer, + metricsCmd: cfg.commandName, + metricsSubCmd: cfg.subcommandName, } // Export allowed environment variables as set in the Gitaly process. cmd.Env = AllowedEnvironment(os.Environ()) // Append environment variables explicitly requested by the caller. - cmd.Env = append(cmd.Env, env...) + cmd.Env = append(cmd.Env, cfg.environment...) // And finally inject environment variables required for tracing into the command. cmd.Env = envInjector(ctx, cmd.Env) @@ -261,83 +209,112 @@ func New(ctx context.Context, cmd *exec.Cmd, stdin io.Reader, stdout, stderr io. // Three possible values for stdin: // * nil - Go implicitly uses /dev/null - // * SetupStdin - configure with cmd.StdinPipe(), allowing Write() to work + // * stdinSentinel - configure with cmd.StdinPipe(), allowing Write() to work // * Another io.Reader - becomes cmd.Stdin. Write() will not work - if stdin == SetupStdin { + if _, ok := cfg.stdin.(stdinSentinel); ok { pipe, err := cmd.StdinPipe() if err != nil { - return nil, fmt.Errorf("GitCommand: stdin: %v", err) + return nil, fmt.Errorf("creating stdin pipe: %w", err) } + command.writer = pipe - } else if stdin != nil { - cmd.Stdin = stdin + } else if cfg.stdin != nil { + cmd.Stdin = cfg.stdin } - if stdout != nil { + if cfg.stdout != nil { // We don't assign a reader if an stdout override was passed. We assume // output is going to be directly handled by the caller. - cmd.Stdout = stdout + cmd.Stdout = cfg.stdout } else { pipe, err := cmd.StdoutPipe() if err != nil { - return nil, fmt.Errorf("GitCommand: stdout: %v", err) + return nil, fmt.Errorf("creating stdout pipe: %w", err) } + command.reader = pipe } - if stderr != nil { - cmd.Stderr = stderr + if cfg.stderr != nil { + cmd.Stderr = cfg.stderr } else { command.stderrBuffer, err = newStderrBuffer(maxStderrBytes, maxStderrLineLength, []byte("\n")) if err != nil { - return nil, fmt.Errorf("GitCommand: failed to create stderr buffer: %v", err) + return nil, fmt.Errorf("creating stderr buffer: %w", err) } + cmd.Stderr = command.stderrBuffer } if err := cmd.Start(); err != nil { - return nil, fmt.Errorf("GitCommand: start %v: %v", cmd.Args, err) + return nil, fmt.Errorf("starting process %v: %w", cmd.Args, err) } - inFlightCommandGauge.Inc() - // The goroutine below is responsible for terminating and reaping the - // process when ctx is canceled. + inFlightCommandGauge.Inc() commandcounter.Increment() - go func() { - <-ctx.Done() - if process := cmd.Process; process != nil && process.Pid > 0 { - //nolint:errcheck // TODO: do we want to report errors? - // Send SIGTERM to the process group of cmd - syscall.Kill(-process.Pid, syscall.SIGTERM) - } + // The goroutine below is responsible for terminating and reaping the process when ctx is + // canceled. While we must ensure that it does run when `cmd.Start()` was successful, it + // must not run before have fully set up the command. Otherwise, we may end up with racy + // access patterns when the context gets terminated early. + // + // We thus defer spawning the Goroutine. + defer func() { + go func() { + <-ctx.Done() + + if process := cmd.Process; process != nil && process.Pid > 0 { + //nolint:errcheck // TODO: do we want to report errors? + // Send SIGTERM to the process group of cmd + syscall.Kill(-process.Pid, syscall.SIGTERM) + } - // We do not care for any potential erorr code, but just want to make sure that the - // subprocess gets properly killed and processed. - _ = command.Wait() + // We do not care for any potential error code, but just want to make sure that the + // subprocess gets properly killed and processed. + _ = command.Wait() + }() }() + if featureflag.RunCommandsInCGroup.IsEnabled(ctx) && cfg.cgroupsManager != nil { + cgroupPath, err := cfg.cgroupsManager.AddCommand(command, cfg.cgroupsRepo) + if err != nil { + return nil, err + } + + command.cgroupPath = cgroupPath + } + logPid = cmd.Process.Pid return command, nil } -// AllowedEnvironment filters the given slice of environment variables and -// returns all variables which are allowed per the variables defined above. -// This is useful for constructing a base environment in which a command can be -// run. -func AllowedEnvironment(envs []string) []string { - var filtered []string +// Read calls Read() on the stdout pipe of the command. +func (c *Command) Read(p []byte) (int, error) { + if c.reader == nil { + panic("command has no reader") + } - for _, env := range envs { - for _, exportedEnv := range exportedEnvVars { - if strings.HasPrefix(env, exportedEnv+"=") { - filtered = append(filtered, env) - } - } + return c.reader.Read(p) +} + +// Write calls Write() on the stdin pipe of the command. +func (c *Command) Write(p []byte) (int, error) { + if c.writer == nil { + panic("command has no writer") } - return filtered + return c.writer.Write(p) +} + +// Wait calls Wait() on the exec.Cmd instance inside the command. This +// blocks until the command has finished and reports the command exit +// status via the error return value. Use ExitStatus to get the integer +// exit status from the error returned by Wait(). +func (c *Command) Wait() error { + c.waitOnce.Do(c.wait) + + return c.waitError } // This function should never be called directly, use Wait(). @@ -366,21 +343,10 @@ func (c *Command) wait() { // counter again. So we instead do it here to accelerate the process, even though it's less // idiomatic. commandcounter.Decrement() -} - -// ExitStatus will return the exit-code from an error returned by Wait(). -func ExitStatus(err error) (int, bool) { - exitError, ok := err.(*exec.ExitError) - if !ok { - return 0, false - } - waitStatus, ok := exitError.Sys().(syscall.WaitStatus) - if !ok { - return 0, false + if c.finalizer != nil { + c.finalizer(c) } - - return waitStatus.ExitStatus(), true } func (c *Command) logProcessComplete() { @@ -480,6 +446,66 @@ func (c *Command) logProcessComplete() { c.span.Finish() } +// Args is an accessor for the command arguments +func (c *Command) Args() []string { + return c.cmd.Args +} + +// Env is an accessor for the environment variables +func (c *Command) Env() []string { + return c.cmd.Env +} + +// Pid is an accessor for the pid +func (c *Command) Pid() int { + return c.cmd.Process.Pid +} + +type contextWithoutDonePanic string + +var getSpawnTokenAcquiringSeconds = func(t time.Time) float64 { + return time.Since(t).Seconds() +} + +type stdinSentinel struct{} + +func (stdinSentinel) Read([]byte) (int, error) { + return 0, errors.New("stdin sentinel should not be read from") +} + +// AllowedEnvironment filters the given slice of environment variables and +// returns all variables which are allowed per the variables defined above. +// This is useful for constructing a base environment in which a command can be +// run. +func AllowedEnvironment(envs []string) []string { + var filtered []string + + for _, env := range envs { + for _, exportedEnv := range exportedEnvVars { + if strings.HasPrefix(env, exportedEnv+"=") { + filtered = append(filtered, env) + } + } + } + + return filtered +} + +// ExitStatus will return the exit-code from an error returned by Wait(). +func ExitStatus(err error) (int, bool) { + exitError, ok := err.(*exec.ExitError) + if !ok { + return 0, false + } + + waitStatus, ok := exitError.Sys().(syscall.WaitStatus) + if !ok { + return 0, false + } + + return waitStatus.ExitStatus(), true +} + func methodFromContext(ctx context.Context) (service string, method string) { tags := grpcmwtags.Extract(ctx) ctxValue := tags.Values()["grpc.request.fullMethod"] @@ -514,18 +540,3 @@ func checkNullArgv(cmd *exec.Cmd) error { return nil } - -// Args is an accessor for the command arguments -func (c *Command) Args() []string { - return c.cmd.Args -} - -// Env is an accessor for the environment variables -func (c *Command) Env() []string { - return c.cmd.Env -} - -// Pid is an accessor for the pid -func (c *Command) Pid() int { - return c.cmd.Process.Pid -} diff --git a/internal/command/command_test.go b/internal/command/command_test.go index 7611e10c6..1f6eae3b0 100644 --- a/internal/command/command_test.go +++ b/internal/command/command_test.go @@ -6,7 +6,7 @@ import ( "fmt" "io" "os/exec" - "regexp" + "path/filepath" "runtime" "strings" "testing" @@ -19,30 +19,30 @@ import ( "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitaly/v15/internal/git/repository" "gitlab.com/gitlab-org/gitaly/v15/internal/testhelper" ) -func TestMain(m *testing.M) { - testhelper.Run(m) -} +func TestNew_environment(t *testing.T) { + t.Parallel() -func TestNewCommandExtraEnv(t *testing.T) { ctx := testhelper.Context(t) extraVar := "FOOBAR=123456" - buff := &bytes.Buffer{} - cmd, err := New(ctx, exec.Command("/usr/bin/env"), nil, buff, nil, extraVar) + + var buf bytes.Buffer + cmd, err := New(ctx, exec.Command("/usr/bin/env"), WithStdout(&buf), WithEnvironment([]string{extraVar})) require.NoError(t, err) require.NoError(t, cmd.Wait()) - require.Contains(t, strings.Split(buff.String(), "\n"), extraVar) + require.Contains(t, strings.Split(buf.String(), "\n"), extraVar) } -func TestNewCommandExportedEnv(t *testing.T) { +func TestNew_exportedEnvironment(t *testing.T) { ctx := testhelper.Context(t) - testCases := []struct { + for _, tc := range []struct { key string value string }{ @@ -110,9 +110,7 @@ func TestNewCommandExportedEnv(t *testing.T) { key: "NO_PROXY", value: "https://excluded:5000", }, - } - - for _, tc := range testCases { + } { t.Run(tc.key, func(t *testing.T) { if tc.key == "LD_LIBRARY_PATH" && runtime.GOOS == "darwin" { t.Skip("System Integrity Protection prevents using dynamic linker (dyld) environment variables on macOS. https://apple.co/2XDH4iC") @@ -120,50 +118,41 @@ func TestNewCommandExportedEnv(t *testing.T) { testhelper.ModifyEnvironment(t, tc.key, tc.value) - buff := &bytes.Buffer{} - cmd, err := New(ctx, exec.Command("/usr/bin/env"), nil, buff, nil) + var buf bytes.Buffer + cmd, err := New(ctx, exec.Command("/usr/bin/env"), WithStdout(&buf)) require.NoError(t, err) require.NoError(t, cmd.Wait()) expectedEnv := fmt.Sprintf("%s=%s", tc.key, tc.value) - require.Contains(t, strings.Split(buff.String(), "\n"), expectedEnv) + require.Contains(t, strings.Split(buf.String(), "\n"), expectedEnv) }) } } -func TestNewCommandUnexportedEnv(t *testing.T) { +func TestNew_unexportedEnv(t *testing.T) { ctx := testhelper.Context(t) unexportedEnvKey, unexportedEnvVal := "GITALY_UNEXPORTED_ENV", "foobar" testhelper.ModifyEnvironment(t, unexportedEnvKey, unexportedEnvVal) - buff := &bytes.Buffer{} - cmd, err := New(ctx, exec.Command("/usr/bin/env"), nil, buff, nil) - + var buf bytes.Buffer + cmd, err := New(ctx, exec.Command("/usr/bin/env"), WithStdout(&buf)) require.NoError(t, err) require.NoError(t, cmd.Wait()) - require.NotContains(t, strings.Split(buff.String(), "\n"), fmt.Sprintf("%s=%s", unexportedEnvKey, unexportedEnvVal)) + require.NotContains(t, strings.Split(buf.String(), "\n"), fmt.Sprintf("%s=%s", unexportedEnvKey, unexportedEnvVal)) } -func TestRejectEmptyContextDone(t *testing.T) { - defer func() { - p := recover() - if p == nil { - t.Error("expected panic, got none") - return - } - - if _, ok := p.(contextWithoutDonePanic); !ok { - panic(p) - } - }() +func TestNew_rejectContextWithoutDone(t *testing.T) { + t.Parallel() - _, err := New(testhelper.ContextWithoutCancel(), exec.Command("true"), nil, nil, nil) - require.NoError(t, err) + require.PanicsWithValue(t, contextWithoutDonePanic("command spawned with context without Done() channel"), func() { + _, err := New(testhelper.ContextWithoutCancel(), exec.Command("true")) + require.NoError(t, err) + }) } -func TestNewCommandTimeout(t *testing.T) { +func TestNew_spawnTimeout(t *testing.T) { ctx := testhelper.Context(t) defer func(ch chan struct{}, t time.Duration) { @@ -177,39 +166,33 @@ func TestNewCommandTimeout(t *testing.T) { spawnTimeout := 200 * time.Millisecond spawnConfig.Timeout = spawnTimeout - testDeadline := time.After(1 * time.Second) tick := time.After(spawnTimeout / 2) errCh := make(chan error) go func() { - _, err := New(ctx, exec.Command("true"), nil, nil, nil) + _, err := New(ctx, exec.Command("true")) errCh <- err }() - var err error - timePassed := false - -wait: - for { - select { - case err = <-errCh: - break wait - case <-tick: - timePassed = true - case <-testDeadline: - t.Fatal("test timed out") - } + select { + case <-errCh: + require.FailNow(t, "expected spawning to be delayed") + case <-tick: + // This is the happy case: we expect spawning of the command to be delayed by up to + // 200ms until it finally times out. } - require.True(t, timePassed, "time must have passed") - require.Error(t, err) - require.Contains(t, err.Error(), "process spawn timed out after") + // And after some time we expect that spawning of the command fails due to the configured + // timeout. + require.Equal(t, fmt.Errorf("process spawn timed out after 200ms"), <-errCh) } -func TestCommand_Wait_interrupts_after_context_cancellation(t *testing.T) { +func TestCommand_Wait_contextCancellationKillsCommand(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(testhelper.Context(t)) - cmd, err := New(ctx, exec.CommandContext(ctx, "sleep", "1h"), nil, nil, nil) + cmd, err := New(ctx, exec.CommandContext(ctx, "sleep", "1h")) require.NoError(t, err) // Cancel the command early. @@ -222,184 +205,243 @@ func TestCommand_Wait_interrupts_after_context_cancellation(t *testing.T) { require.Equal(t, -1, s) } -func TestNewCommandWithSetupStdin(t *testing.T) { - ctx := testhelper.Context(t) +func TestNew_setupStdin(t *testing.T) { + t.Parallel() - value := "Test value" - output := bytes.NewBuffer(nil) + ctx := testhelper.Context(t) - cmd, err := New(ctx, exec.Command("cat"), SetupStdin, nil, nil) - require.NoError(t, err) + stdin := "Test value" - _, err = fmt.Fprintf(cmd, "%s", value) + var buf bytes.Buffer + cmd, err := New(ctx, exec.Command("cat"), WithSetupStdin(), WithStdout(&buf)) require.NoError(t, err) - // The output of the `cat` subprocess should exactly match its input - _, err = io.CopyN(output, cmd, int64(len(value))) + _, err = fmt.Fprintf(cmd, "%s", stdin) require.NoError(t, err) - require.Equal(t, value, output.String()) require.NoError(t, cmd.Wait()) + require.Equal(t, stdin, buf.String()) } -func TestNewCommandNullInArg(t *testing.T) { +func TestCommand_read(t *testing.T) { + t.Parallel() + ctx := testhelper.Context(t) - _, err := New(ctx, exec.Command("sh", "-c", "hello\x00world"), nil, nil, nil) - require.Error(t, err) - require.EqualError(t, err, `detected null byte in command argument "hello\x00world"`) + cmd, err := New(ctx, exec.Command("echo", "test value")) + require.NoError(t, err) + + output, err := io.ReadAll(cmd) + require.NoError(t, err) + require.Equal(t, "test value\n", string(output)) + + require.NoError(t, cmd.Wait()) } -func TestNewNonExistent(t *testing.T) { +func TestNew_nulByteInArgument(t *testing.T) { + t.Parallel() + ctx := testhelper.Context(t) - cmd, err := New(ctx, exec.Command("command-non-existent"), nil, nil, nil) + cmd, err := New(ctx, exec.Command("sh", "-c", "hello\x00world")) + require.Equal(t, fmt.Errorf("detected null byte in command argument %q", "hello\x00world"), err) require.Nil(t, cmd) - require.Error(t, err) } -func TestCommandStdErr(t *testing.T) { +func TestNew_missingBinary(t *testing.T) { + t.Parallel() + ctx := testhelper.Context(t) - var stdout, stderr bytes.Buffer - expectedMessage := `hello world\nhello world\nhello world\nhello world\nhello world\n` + cmd, err := New(ctx, exec.Command("command-non-existent")) + require.EqualError(t, err, "starting process [command-non-existent]: exec: \"command-non-existent\": executable file not found in $PATH") + require.Nil(t, cmd) +} + +func TestCommand_stderrLogging(t *testing.T) { + t.Parallel() - logger := logrus.New() - logger.SetOutput(&stderr) + binaryPath := testhelper.WriteExecutable(t, filepath.Join(testhelper.TempDir(t), "script"), []byte(`#!/bin/bash + for i in {1..5} + do + echo 'hello world' 1>&2 + done + exit 1 + `)) + logger, hook := test.NewNullLogger() + ctx := testhelper.Context(t) ctx = ctxlogrus.ToContext(ctx, logrus.NewEntry(logger)) - cmd, err := New(ctx, exec.Command("./testdata/stderr_script.sh"), nil, &stdout, nil) + var stdout bytes.Buffer + cmd, err := New(ctx, exec.Command(binaryPath), WithStdout(&stdout)) require.NoError(t, err) - require.Error(t, cmd.Wait()) - assert.Empty(t, stdout.Bytes()) - require.Equal(t, expectedMessage, extractLastMessage(stderr.String())) + require.EqualError(t, cmd.Wait(), "exit status 1") + require.Empty(t, stdout.Bytes()) + require.Equal(t, strings.Repeat("hello world\n", 5), hook.LastEntry().Message) } -func TestCommandStdErrLargeOutput(t *testing.T) { - ctx := testhelper.Context(t) - - var stdout, stderr bytes.Buffer +func TestCommand_stderrLoggingTruncation(t *testing.T) { + t.Parallel() - logger := logrus.New() - logger.SetOutput(&stderr) + binaryPath := testhelper.WriteExecutable(t, filepath.Join(testhelper.TempDir(t), "script"), []byte(`#!/bin/bash + for i in {1..1000} + do + printf '%06d zzzzzzzzzz\n' $i >&2 + done + exit 1 + `)) + logger, hook := test.NewNullLogger() + ctx := testhelper.Context(t) ctx = ctxlogrus.ToContext(ctx, logrus.NewEntry(logger)) - cmd, err := New(ctx, exec.Command("./testdata/stderr_many_lines.sh"), nil, &stdout, nil) + var stdout bytes.Buffer + cmd, err := New(ctx, exec.Command(binaryPath), WithStdout(&stdout)) require.NoError(t, err) - require.Error(t, cmd.Wait()) - assert.Empty(t, stdout.Bytes()) - msg := strings.ReplaceAll(extractLastMessage(stderr.String()), "\\n", "\n") - require.LessOrEqual(t, len(msg), maxStderrBytes) + require.Error(t, cmd.Wait()) + require.Empty(t, stdout.Bytes()) + require.Len(t, hook.LastEntry().Message, maxStderrBytes) } -func TestCommandStdErrBinaryNullBytes(t *testing.T) { - ctx := testhelper.Context(t) - - var stdout, stderr bytes.Buffer +func TestCommand_stderrLoggingWithNulBytes(t *testing.T) { + t.Parallel() - logger := logrus.New() - logger.SetOutput(&stderr) + binaryPath := testhelper.WriteExecutable(t, filepath.Join(testhelper.TempDir(t), "script"), []byte(`#!/bin/bash + dd if=/dev/zero bs=1000 count=1000 status=none >&2 + exit 1 + `)) + logger, hook := test.NewNullLogger() + ctx := testhelper.Context(t) ctx = ctxlogrus.ToContext(ctx, logrus.NewEntry(logger)) - cmd, err := New(ctx, exec.Command("./testdata/stderr_binary_null.sh"), nil, &stdout, nil) + var stdout bytes.Buffer + cmd, err := New(ctx, exec.Command(binaryPath), WithStdout(&stdout)) require.NoError(t, err) - require.Error(t, cmd.Wait()) - assert.Empty(t, stdout.Bytes()) - msg := strings.SplitN(extractLastMessage(stderr.String()), "\\n", 2)[0] - require.Equal(t, strings.Repeat("\\x00", maxStderrLineLength), msg) + require.Error(t, cmd.Wait()) + require.Empty(t, stdout.Bytes()) + require.Equal(t, strings.Repeat("\x00", maxStderrLineLength), hook.LastEntry().Message) } -func TestCommandStdErrLongLine(t *testing.T) { - ctx := testhelper.Context(t) - - var stdout, stderr bytes.Buffer +func TestCommand_stderrLoggingLongLine(t *testing.T) { + t.Parallel() - logger := logrus.New() - logger.SetOutput(&stderr) + binaryPath := testhelper.WriteExecutable(t, filepath.Join(testhelper.TempDir(t), "script"), []byte(`#!/bin/bash + printf 'a%.0s' {1..8192} >&2 + printf '\n' >&2 + printf 'b%.0s' {1..8192} >&2 + exit 1 + `)) + logger, hook := test.NewNullLogger() + ctx := testhelper.Context(t) ctx = ctxlogrus.ToContext(ctx, logrus.NewEntry(logger)) - cmd, err := New(ctx, exec.Command("./testdata/stderr_repeat_a.sh"), nil, &stdout, nil) + var stdout bytes.Buffer + cmd, err := New(ctx, exec.Command(binaryPath), WithStdout(&stdout)) require.NoError(t, err) - require.Error(t, cmd.Wait()) - assert.Empty(t, stdout.Bytes()) - require.Contains(t, stderr.String(), fmt.Sprintf("%s\\n%s", strings.Repeat("a", maxStderrLineLength), strings.Repeat("b", maxStderrLineLength))) + require.Error(t, cmd.Wait()) + require.Empty(t, stdout.Bytes()) + require.Equal(t, + strings.Join([]string{ + strings.Repeat("a", maxStderrLineLength), + strings.Repeat("b", maxStderrLineLength), + }, "\n"), + hook.LastEntry().Message, + ) } -func TestCommandStdErrMaxBytes(t *testing.T) { - ctx := testhelper.Context(t) - - var stdout, stderr bytes.Buffer - - logger := logrus.New() - logger.SetOutput(&stderr) +func TestCommand_stderrLoggingMaxBytes(t *testing.T) { + t.Parallel() + + binaryPath := testhelper.WriteExecutable(t, filepath.Join(testhelper.TempDir(t), "script"), []byte(`#!/bin/bash + # This script is used to test that a command writes at most maxBytes to stderr. It + # simulates the edge case where the logwriter has already written MaxStderrBytes-1 + # (9999) bytes + + # This edge case happens when 9999 bytes are written. To simulate this, + # stderr_max_bytes_edge_case has 4 lines of the following format: + # + # line1: 3333 bytes long + # line2: 3331 bytes + # line3: 3331 bytes + # line4: 1 byte + # + # The first 3 lines sum up to 9999 bytes written, since we write a 2-byte escaped + # "\n" for each \n we see. The 4th line can be any data. + + printf 'a%.0s' {1..3333} >&2 + printf '\n' >&2 + printf 'a%.0s' {1..3331} >&2 + printf '\n' >&2 + printf 'a%.0s' {1..3331} >&2 + printf '\na\n' >&2 + exit 1 + `)) + logger, hook := test.NewNullLogger() + ctx := testhelper.Context(t) ctx = ctxlogrus.ToContext(ctx, logrus.NewEntry(logger)) - cmd, err := New(ctx, exec.Command("./testdata/stderr_max_bytes_edge_case.sh"), nil, &stdout, nil) + var stdout bytes.Buffer + cmd, err := New(ctx, exec.Command(binaryPath), WithStdout(&stdout)) require.NoError(t, err) require.Error(t, cmd.Wait()) - assert.Empty(t, stdout.Bytes()) - message := extractLastMessage(stderr.String()) - require.Equal(t, maxStderrBytes, len(strings.ReplaceAll(message, "\\n", "\n"))) + require.Empty(t, stdout.Bytes()) + require.Len(t, hook.LastEntry().Message, maxStderrBytes) } -var logMsgRegex = regexp.MustCompile(`msg="(.+?)"`) - -func extractLastMessage(logMessage string) string { - subMatchesAll := logMsgRegex.FindAllStringSubmatch(logMessage, -1) - if len(subMatchesAll) < 1 { - return "" - } - - subMatches := subMatchesAll[len(subMatchesAll)-1] - if len(subMatches) != 2 { - return "" - } +type mockCgroupManager string - return subMatches[1] +func (m mockCgroupManager) AddCommand(*Command, repository.GitRepo) (string, error) { + return string(m), nil } func TestCommand_logMessage(t *testing.T) { + t.Parallel() + logger, hook := test.NewNullLogger() logger.SetLevel(logrus.DebugLevel) ctx := ctxlogrus.ToContext(testhelper.Context(t), logrus.NewEntry(logger)) - cmd, err := New(ctx, exec.Command("echo", "hello world"), nil, nil, nil) + cmd, err := New(ctx, exec.Command("echo", "hello world"), + WithCgroup(mockCgroupManager("/sys/fs/cgroup/1"), nil), + ) require.NoError(t, err) - cgroupPath := "/sys/fs/cgroup/1" - cmd.SetCgroupPath(cgroupPath) require.NoError(t, cmd.Wait()) logEntry := hook.LastEntry() assert.Equal(t, cmd.Pid(), logEntry.Data["pid"]) assert.Equal(t, []string{"echo", "hello world"}, logEntry.Data["args"]) assert.Equal(t, 0, logEntry.Data["command.exitCode"]) - assert.Equal(t, cgroupPath, logEntry.Data["command.cgroup_path"]) + assert.Equal(t, "/sys/fs/cgroup/1", logEntry.Data["command.cgroup_path"]) } -func TestNewCommandSpawnTokenMetrics(t *testing.T) { - spawnTokenAcquiringSeconds.Reset() +func TestNew_commandSpawnTokenMetrics(t *testing.T) { + defer func(old func(time.Time) float64) { + getSpawnTokenAcquiringSeconds = old + }(getSpawnTokenAcquiringSeconds) - ctx := testhelper.Context(t) getSpawnTokenAcquiringSeconds = func(t time.Time) float64 { return 1 } + spawnTokenAcquiringSeconds.Reset() + + ctx := testhelper.Context(t) + tags := grpcmwtags.NewTags() tags.Set("grpc.request.fullMethod", "/test.Service/TestRPC") ctx = grpcmwtags.SetInContext(ctx, tags) - cmd, err := New(ctx, exec.Command("echo", "goodbye, cruel world."), nil, nil, nil) + cmd, err := New(ctx, exec.Command("echo", "goodbye, cruel world.")) require.NoError(t, err) require.NoError(t, cmd.Wait()) @@ -408,7 +450,7 @@ func TestNewCommandSpawnTokenMetrics(t *testing.T) { # TYPE gitaly_command_spawn_token_acquiring_seconds_total counter gitaly_command_spawn_token_acquiring_seconds_total{cmd="echo",grpc_method="TestRPC",grpc_service="test.Service"} 1 ` - assert.NoError( + require.NoError( t, testutil.CollectAndCompare( spawnTokenAcquiringSeconds, diff --git a/internal/command/option.go b/internal/command/option.go new file mode 100644 index 000000000..08acf0f33 --- /dev/null +++ b/internal/command/option.go @@ -0,0 +1,94 @@ +package command + +import ( + "io" + + "gitlab.com/gitlab-org/gitaly/v15/internal/git/repository" +) + +type config struct { + stdin io.Reader + stdout io.Writer + stderr io.Writer + environment []string + + finalizer func(*Command) + + commandName string + subcommandName string + + cgroupsManager CgroupsManager + cgroupsRepo repository.GitRepo +} + +// Option is an option that can be passed to `New()` for controlling how the command is being +// created. +type Option func(cfg *config) + +// WithStdin sets up the command to read from the given reader. +func WithStdin(stdin io.Reader) Option { + return func(cfg *config) { + cfg.stdin = stdin + } +} + +// WithSetupStdin instructs New() to configure the stdin pipe of the command it is creating. This +// allows you call Write() on the command as if it is an ordinary io.Writer, sending data directly +// to the stdin of the process. +func WithSetupStdin() Option { + return func(cfg *config) { + cfg.stdin = stdinSentinel{} + } +} + +// WithStdout sets up the command to write standard output to the given writer. +func WithStdout(stdout io.Writer) Option { + return func(cfg *config) { + cfg.stdout = stdout + } +} + +// WithStderr sets up the command to write standard error to the given writer. +func WithStderr(stderr io.Writer) Option { + return func(cfg *config) { + cfg.stderr = stderr + } +} + +// WithEnvironment sets up environment variables that shall be set for the command. +func WithEnvironment(environment []string) Option { + return func(cfg *config) { + cfg.environment = environment + } +} + +// WithCommandName overrides the "cmd" and "subcmd" label used in metrics. +func WithCommandName(commandName, subcommandName string) Option { + return func(cfg *config) { + cfg.commandName = commandName + cfg.subcommandName = subcommandName + } +} + +// CgroupsManager is a subset of the `cgroups.Manager` interface. We need to replicate it here to +// avoid a cyclic dependency between both packages. +type CgroupsManager interface { + AddCommand(*Command, repository.GitRepo) (string, error) +} + +// WithCgroup adds the spawned command to a Cgroup. The bucket used will be derived from the +// command's arguments and/or from the repository. +func WithCgroup(cgroupsManager CgroupsManager, repo repository.GitRepo) Option { + return func(cfg *config) { + cfg.cgroupsManager = cgroupsManager + cfg.cgroupsRepo = repo + } +} + +// WithFinalizer sets up the finalizer to be run when the command is being wrapped up. It will be +// called after `Wait()` has returned. +func WithFinalizer(finalizer func(*Command)) Option { + return func(cfg *config) { + cfg.finalizer = finalizer + } +} diff --git a/internal/command/testdata/stderr_binary_null.sh b/internal/command/testdata/stderr_binary_null.sh deleted file mode 100755 index e8e1d0baa..000000000 --- a/internal/command/testdata/stderr_binary_null.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/bin/bash - -dd if=/dev/zero bs=1000 count=1000 >&2 -exit 1 diff --git a/internal/command/testdata/stderr_many_lines.sh b/internal/command/testdata/stderr_many_lines.sh deleted file mode 100755 index e8b51b646..000000000 --- a/internal/command/testdata/stderr_many_lines.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash - -let x=0 -while [ $x -lt 100010 ] -do - let x=x+1 - printf '%06d zzzzzzzzzz\n' $x >&2 -done -exit 1 diff --git a/internal/command/testdata/stderr_max_bytes_edge_case.sh b/internal/command/testdata/stderr_max_bytes_edge_case.sh deleted file mode 100755 index 6e1454966..000000000 --- a/internal/command/testdata/stderr_max_bytes_edge_case.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -# This script is used to test that a command writes at most maxBytes to stderr. It simulates the -# edge case where the logwriter has already written MaxStderrBytes-1 (9999) bytes - -# This edge case happens when 9999 bytes are written. To simulate this, stderr_max_bytes_edge_case has 4 lines of the following format: -# line1: 3333 bytes long -# line2: 3331 bytes -# line3: 3331 bytes -# line4: 1 byte -# The first 3 lines sum up to 9999 bytes written, since we write a 2-byte escaped `\n` for each \n we see. -# The 4th line can be any data. - -printf 'a%.0s' {1..3333} >&2 -printf '\n' >&2 -printf 'a%.0s' {1..3331} >&2 -printf '\n' >&2 -printf 'a%.0s' {1..3331} >&2 -printf '\na\n' >&2 -exit 1 diff --git a/internal/command/testdata/stderr_repeat_a.sh b/internal/command/testdata/stderr_repeat_a.sh deleted file mode 100755 index 8463929ae..000000000 --- a/internal/command/testdata/stderr_repeat_a.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash - -printf 'a%.0s' {1..8192} >&2 -printf '\n' >&2 -printf 'b%.0s' {1..8192} >&2 -exit 1 diff --git a/internal/command/testdata/stderr_script.sh b/internal/command/testdata/stderr_script.sh deleted file mode 100755 index 57abd97d0..000000000 --- a/internal/command/testdata/stderr_script.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash - -for i in {1..5} -do - echo 'hello world' 1>&2 -done -exit 1 diff --git a/internal/command/testhelper_test.go b/internal/command/testhelper_test.go new file mode 100644 index 000000000..e0398c2bd --- /dev/null +++ b/internal/command/testhelper_test.go @@ -0,0 +1,11 @@ +package command + +import ( + "testing" + + "gitlab.com/gitlab-org/gitaly/v15/internal/testhelper" +) + +func TestMain(m *testing.M) { + testhelper.Run(m) +} diff --git a/internal/git/catfile/object_info_reader.go b/internal/git/catfile/object_info_reader.go index f391bd128..84857ce70 100644 --- a/internal/git/catfile/object_info_reader.go +++ b/internal/git/catfile/object_info_reader.go @@ -146,7 +146,7 @@ func newObjectInfoReader( git.Flag{Name: "--buffer"}, }, }, - git.WithStdin(command.SetupStdin), + git.WithSetupStdin(), ) if err != nil { return nil, err diff --git a/internal/git/catfile/object_reader.go b/internal/git/catfile/object_reader.go index 4810a8207..516050301 100644 --- a/internal/git/catfile/object_reader.go +++ b/internal/git/catfile/object_reader.go @@ -67,7 +67,7 @@ func newObjectReader( git.Flag{Name: "--buffer"}, }, }, - git.WithStdin(command.SetupStdin), + git.WithSetupStdin(), ) if err != nil { return nil, err diff --git a/internal/git/command_factory.go b/internal/git/command_factory.go index 2f2ded567..86918985f 100644 --- a/internal/git/command_factory.go +++ b/internal/git/command_factory.go @@ -17,7 +17,6 @@ import ( "gitlab.com/gitlab-org/gitaly/v15/internal/gitaly/config" "gitlab.com/gitlab-org/gitaly/v15/internal/gitaly/storage" "gitlab.com/gitlab-org/gitaly/v15/internal/log" - "gitlab.com/gitlab-org/gitaly/v15/internal/metadata/featureflag" ) // CommandFactory is designed to create and run git commands in a protected and fully managed manner. @@ -341,7 +340,7 @@ func (cf *ExecCommandFactory) GitVersion(ctx context.Context) (Version, error) { // Furthermore, note that we're not using `newCommand()` but instead hand-craft the command. // This is required to avoid a cyclic dependency when we need to check the version in // `newCommand()` itself. - cmd, err := command.New(ctx, exec.Command(execEnv.BinaryPath, "version"), nil, nil, nil, execEnv.EnvironmentVariables...) + cmd, err := command.New(ctx, exec.Command(execEnv.BinaryPath, "version"), command.WithEnvironment(execEnv.EnvironmentVariables)) if err != nil { return Version{}, fmt.Errorf("spawning version command: %w", err) } @@ -397,19 +396,16 @@ func (cf *ExecCommandFactory) newCommand(ctx context.Context, repo repository.Gi execCommand := exec.Command(execEnv.BinaryPath, args...) execCommand.Dir = dir - command, err := command.New(ctx, execCommand, config.stdin, config.stdout, config.stderr, env...) + command, err := command.New(ctx, execCommand, append( + config.commandOpts, + command.WithEnvironment(env), + command.WithCommandName("git", sc.Subcommand()), + command.WithCgroup(cf.cgroupsManager, repo), + )...) if err != nil { return nil, err } - command.SetMetricsSubCmd(sc.Subcommand()) - - if featureflag.RunCommandsInCGroup.IsEnabled(ctx) { - if err := cf.cgroupsManager.AddCommand(command, repo); err != nil { - return nil, err - } - } - return command, nil } diff --git a/internal/git/command_factory_cgroup_test.go b/internal/git/command_factory_cgroup_test.go index 69c3126bf..e504880b6 100644 --- a/internal/git/command_factory_cgroup_test.go +++ b/internal/git/command_factory_cgroup_test.go @@ -24,9 +24,9 @@ func (m *mockCgroupsManager) Setup() error { return nil } -func (m *mockCgroupsManager) AddCommand(c *command.Command, repo repository.GitRepo) error { +func (m *mockCgroupsManager) AddCommand(c *command.Command, repo repository.GitRepo) (string, error) { m.commands = append(m.commands, c) - return nil + return "", nil } func (m *mockCgroupsManager) Cleanup() error { diff --git a/internal/git/command_options.go b/internal/git/command_options.go index 44aecd810..044ae754d 100644 --- a/internal/git/command_options.go +++ b/internal/git/command_options.go @@ -9,6 +9,7 @@ import ( "regexp" "strings" + "gitlab.com/gitlab-org/gitaly/v15/internal/command" "gitlab.com/gitlab-org/gitaly/v15/internal/gitaly/config" "gitlab.com/gitlab-org/gitaly/v15/internal/gitaly/storage" "gitlab.com/gitlab-org/gitaly/v15/internal/helper" @@ -168,9 +169,7 @@ func ConvertConfigOptions(options []string) ([]ConfigPair, error) { type cmdCfg struct { env []string globals []GlobalOption - stdin io.Reader - stdout io.Writer - stderr io.Writer + commandOpts []command.Option hooksConfigured bool } @@ -181,7 +180,15 @@ type CmdOpt func(context.Context, config.Cfg, CommandFactory, *cmdCfg) error // command suitable for `Write()`ing to. func WithStdin(r io.Reader) CmdOpt { return func(_ context.Context, _ config.Cfg, _ CommandFactory, c *cmdCfg) error { - c.stdin = r + c.commandOpts = append(c.commandOpts, command.WithStdin(r)) + return nil + } +} + +// WithSetupStdin sets up the command so that it can be `Write()`en to. +func WithSetupStdin() CmdOpt { + return func(_ context.Context, _ config.Cfg, _ CommandFactory, c *cmdCfg) error { + c.commandOpts = append(c.commandOpts, command.WithSetupStdin()) return nil } } @@ -189,7 +196,7 @@ func WithStdin(r io.Reader) CmdOpt { // WithStdout sets the command's stdout. func WithStdout(w io.Writer) CmdOpt { return func(_ context.Context, _ config.Cfg, _ CommandFactory, c *cmdCfg) error { - c.stdout = w + c.commandOpts = append(c.commandOpts, command.WithStdout(w)) return nil } } @@ -197,7 +204,7 @@ func WithStdout(w io.Writer) CmdOpt { // WithStderr sets the command's stderr. func WithStderr(w io.Writer) CmdOpt { return func(_ context.Context, _ config.Cfg, _ CommandFactory, c *cmdCfg) error { - c.stderr = w + c.commandOpts = append(c.commandOpts, command.WithStderr(w)) return nil } } @@ -302,3 +309,12 @@ func withInternalFetch(req repoScopedRequest, withSidechannel bool) func(ctx con return nil } } + +// WithFinalizer sets up the finalizer to be run when the command is being wrapped up. It will be +// called after `Wait()` has returned. +func WithFinalizer(finalizer func(*command.Command)) CmdOpt { + return func(_ context.Context, _ config.Cfg, _ CommandFactory, c *cmdCfg) error { + c.commandOpts = append(c.commandOpts, command.WithFinalizer(finalizer)) + return nil + } +} diff --git a/internal/git/objectpool/fetch.go b/internal/git/objectpool/fetch.go index 6aedcc7e5..549522809 100644 --- a/internal/git/objectpool/fetch.go +++ b/internal/git/objectpool/fetch.go @@ -226,7 +226,7 @@ func (o *ObjectPool) logStats(ctx context.Context, when string) error { func sizeDir(ctx context.Context, dir string) (int64, error) { // du -k reports size in KB - cmd, err := command.New(ctx, exec.Command("du", "-sk", dir), nil, nil, nil) + cmd, err := command.New(ctx, exec.Command("du", "-sk", dir)) if err != nil { return 0, err } diff --git a/internal/git/packfile/index.go b/internal/git/packfile/index.go index 2aa7f4be7..e8915e1d3 100644 --- a/internal/git/packfile/index.go +++ b/internal/git/packfile/index.go @@ -93,7 +93,7 @@ func ReadIndex(idxPath string) (*Index, error) { return nil, err } - showIndex, err := command.New(ctx, exec.Command("git", "show-index"), f, nil, nil) + showIndex, err := command.New(ctx, exec.Command("git", "show-index"), command.WithStdin(f)) if err != nil { return nil, err } diff --git a/internal/git/pktline/read_monitor.go b/internal/git/pktline/read_monitor.go index e288b2530..9b371a85f 100644 --- a/internal/git/pktline/read_monitor.go +++ b/internal/git/pktline/read_monitor.go @@ -40,27 +40,22 @@ type ReadMonitor struct { // to the pipe. The stream will be monitored for a pktline-formatted packet // matching pkt. If it isn't seen within the timeout, cancelFn will be called. // -// Resources will be freed when the context is done, but you should close the -// returned *os.File earlier if possible. -func NewReadMonitor(ctx context.Context, r io.Reader) (*os.File, *ReadMonitor, error) { +// The returned function will release allocated resources. You must make sure to call this +// function. +func NewReadMonitor(ctx context.Context, r io.Reader) (*os.File, *ReadMonitor, func(), error) { pr, pw, err := os.Pipe() if err != nil { - return nil, nil, err + return nil, nil, nil, err } - // Ensure all resources are closed once the context is done - go func() { - <-ctx.Done() - - pr.Close() - pw.Close() - }() - return pr, &ReadMonitor{ - pr: pr, - pw: pw, - underlying: r, - }, nil + pr: pr, + pw: pw, + underlying: r, + }, func() { + pr.Close() + pw.Close() + }, nil } // Monitor should be called at most once. It scans the stream, looking for the diff --git a/internal/git/pktline/read_monitor_test.go b/internal/git/pktline/read_monitor_test.go index e78bdd8c5..65a0ec3ab 100644 --- a/internal/git/pktline/read_monitor_test.go +++ b/internal/git/pktline/read_monitor_test.go @@ -24,7 +24,7 @@ func TestReadMonitorTimeout(t *testing.T) { waitPipeR, // this pipe reader lets us block the multi reader ) - r, monitor, err := NewReadMonitor(ctx, in) + r, monitor, cleanup, err := NewReadMonitor(ctx, in) require.NoError(t, err) timeoutTicker := helper.NewManualTicker() @@ -41,6 +41,8 @@ func TestReadMonitorTimeout(t *testing.T) { require.Equal(t, ctx.Err(), context.Canceled) require.True(t, elapsed < time.Second, "Expected context to be cancelled quickly, but it was not") + cleanup() + // Verify that pipe is closed _, err = io.ReadAll(r) require.Error(t, err) @@ -61,8 +63,9 @@ func TestReadMonitorSuccess(t *testing.T) { strings.NewReader(postTimeoutPayload), ) - r, monitor, err := NewReadMonitor(ctx, in) + r, monitor, cleanup, err := NewReadMonitor(ctx, in) require.NoError(t, err) + defer cleanup() timeoutTicker := helper.NewManualTicker() diff --git a/internal/git/updateref/updateref.go b/internal/git/updateref/updateref.go index a00f78c3e..cc97513cc 100644 --- a/internal/git/updateref/updateref.go +++ b/internal/git/updateref/updateref.go @@ -63,7 +63,7 @@ func New(ctx context.Context, repo git.RepositoryExecutor, opts ...UpdaterOpt) ( Flags: []git.Option{git.Flag{Name: "-z"}, git.Flag{Name: "--stdin"}}, }, txOption, - git.WithStdin(command.SetupStdin), + git.WithSetupStdin(), git.WithStderr(&stderr), ) if err != nil { diff --git a/internal/git2go/executor.go b/internal/git2go/executor.go index 75d451fe0..c1e220f43 100644 --- a/internal/git2go/executor.go +++ b/internal/git2go/executor.go @@ -88,13 +88,17 @@ func (b *Executor) run(ctx context.Context, repo repository.GitRepo, stdin io.Re }, args...) var stdout bytes.Buffer - cmd, err := command.New(ctx, exec.Command(b.binaryPath, args...), stdin, &stdout, log, env...) + cmd, err := command.New(ctx, exec.Command(b.binaryPath, args...), + command.WithStdin(stdin), + command.WithStdout(&stdout), + command.WithStderr(log), + command.WithEnvironment(env), + command.WithCommandName("gitaly-git2go", subcmd), + ) if err != nil { return nil, err } - cmd.SetMetricsSubCmd(subcmd) - if err := cmd.Wait(); err != nil { return nil, err } diff --git a/internal/gitaly/hook/custom.go b/internal/gitaly/hook/custom.go index 327e014de..2c819dd3e 100644 --- a/internal/gitaly/hook/custom.go +++ b/internal/gitaly/hook/custom.go @@ -68,13 +68,17 @@ func (m *GitLabHookManager) newCustomHooksExecutor(repo *gitalypb.Repository, ho for _, hookFile := range hookFiles { cmd := exec.Command(hookFile, args...) cmd.Dir = repoPath - c, err := command.New(ctx, cmd, bytes.NewReader(stdinBytes), stdout, stderr, env...) + c, err := command.New(ctx, cmd, + command.WithStdin(bytes.NewReader(stdinBytes)), + command.WithStdout(stdout), + command.WithStderr(stderr), + command.WithEnvironment(env), + command.WithCommandName("gitaly-hooks", hookName), + ) if err != nil { return err } - c.SetMetricsSubCmd(hookName) - if err = c.Wait(); err != nil { // Custom hook errors need to be handled specially when we update // refs via updateref.UpdaterWithHooks: their stdout and stderr must diff --git a/internal/gitaly/linguist/linguist.go b/internal/gitaly/linguist/linguist.go index 4ac9abfd2..da085b58a 100644 --- a/internal/gitaly/linguist/linguist.go +++ b/internal/gitaly/linguist/linguist.go @@ -95,14 +95,14 @@ func (inst *Instance) startGitLinguist(ctx context.Context, repoPath string, com cmd := exec.Command(bundle, "exec", "bin/gitaly-linguist", "--repository="+repoPath, "--commit="+commitID) cmd.Dir = inst.cfg.Ruby.Dir - internalCmd, err := command.New(ctx, cmd, nil, nil, nil, env.AllowedRubyEnvironment(os.Environ())...) + internalCmd, err := command.New(ctx, cmd, + command.WithEnvironment(env.AllowedRubyEnvironment(os.Environ())), + command.WithCommandName("git-linguist", "stats"), + ) if err != nil { return nil, fmt.Errorf("creating command: %w", err) } - internalCmd.SetMetricsCmd("git-linguist") - internalCmd.SetMetricsSubCmd("stats") - return internalCmd, nil } diff --git a/internal/gitaly/service/repository/archive.go b/internal/gitaly/service/repository/archive.go index 86250e29b..959064237 100644 --- a/internal/gitaly/service/repository/archive.go +++ b/internal/gitaly/service/repository/archive.go @@ -230,7 +230,9 @@ func (s *server) handleArchive(ctx context.Context, p archiveParams) error { } if p.compressCmd != nil { - command, err := command.New(ctx, p.compressCmd, archiveCommand, p.writer, nil) + command, err := command.New(ctx, p.compressCmd, + command.WithStdin(archiveCommand), command.WithStdout(p.writer), + ) if err != nil { return err } diff --git a/internal/gitaly/service/repository/backup_custom_hooks.go b/internal/gitaly/service/repository/backup_custom_hooks.go index da774f89b..e434e6f65 100644 --- a/internal/gitaly/service/repository/backup_custom_hooks.go +++ b/internal/gitaly/service/repository/backup_custom_hooks.go @@ -30,7 +30,7 @@ func (s *server) BackupCustomHooks(in *gitalypb.BackupCustomHooksRequest, stream ctx := stream.Context() tar := exec.Command("tar", "-c", "-f", "-", "-C", repoPath, customHooksDir) - cmd, err := command.New(ctx, tar, nil, writer, nil) + cmd, err := command.New(ctx, tar, command.WithStdout(writer)) if err != nil { return status.Errorf(codes.Internal, "%v", err) } diff --git a/internal/gitaly/service/repository/create_repository_from_snapshot.go b/internal/gitaly/service/repository/create_repository_from_snapshot.go index 021bb2cec..f665faa87 100644 --- a/internal/gitaly/service/repository/create_repository_from_snapshot.go +++ b/internal/gitaly/service/repository/create_repository_from_snapshot.go @@ -67,7 +67,7 @@ func untar(ctx context.Context, path string, in *gitalypb.CreateRepositoryFromSn return status.Errorf(codes.Internal, "HTTP server: %v", rsp.Status) } - cmd, err := command.New(ctx, exec.Command("tar", "-C", path, "-xvf", "-"), rsp.Body, nil, nil) + cmd, err := command.New(ctx, exec.Command("tar", "-C", path, "-xvf", "-"), command.WithStdin(rsp.Body)) if err != nil { return err } diff --git a/internal/gitaly/service/repository/replicate.go b/internal/gitaly/service/repository/replicate.go index 851ba8b4c..6c66bcd74 100644 --- a/internal/gitaly/service/repository/replicate.go +++ b/internal/gitaly/service/repository/replicate.go @@ -188,7 +188,9 @@ func (s *server) extractSnapshot(ctx context.Context, source, target *gitalypb.R } stderr := &bytes.Buffer{} - cmd, err := command.New(ctx, exec.Command("tar", "-C", targetPath, "-xvf", "-"), snapshotReader, nil, stderr) + cmd, err := command.New(ctx, exec.Command("tar", "-C", targetPath, "-xvf", "-"), + command.WithStdin(snapshotReader), command.WithStderr(stderr), + ) if err != nil { return fmt.Errorf("create tar command: %w", err) } diff --git a/internal/gitaly/service/repository/restore_custom_hooks.go b/internal/gitaly/service/repository/restore_custom_hooks.go index 54d73ddc0..572181ffe 100644 --- a/internal/gitaly/service/repository/restore_custom_hooks.go +++ b/internal/gitaly/service/repository/restore_custom_hooks.go @@ -62,7 +62,7 @@ func (s *server) RestoreCustomHooks(stream gitalypb.RepositoryService_RestoreCus } ctx := stream.Context() - cmd, err := command.New(ctx, exec.Command("tar", cmdArgs...), reader, nil, nil) + cmd, err := command.New(ctx, exec.Command("tar", cmdArgs...), command.WithStdin(reader)) if err != nil { return status.Errorf(codes.Internal, "RestoreCustomHooks: Could not untar custom hooks tar %v", err) } @@ -149,7 +149,7 @@ func (s *server) restoreCustomHooksWithVoting(stream gitalypb.RepositoryService_ customHooksDir, } - cmd, err := command.New(ctx, exec.Command("tar", cmdArgs...), reader, nil, nil) + cmd, err := command.New(ctx, exec.Command("tar", cmdArgs...), command.WithStdin(reader)) if err != nil { return helper.ErrInternalf("RestoreCustomHooks: Could not untar custom hooks tar %w", err) } diff --git a/internal/gitaly/service/repository/size.go b/internal/gitaly/service/repository/size.go index d8476274a..95bf22402 100644 --- a/internal/gitaly/service/repository/size.go +++ b/internal/gitaly/service/repository/size.go @@ -67,7 +67,7 @@ func (s *server) GetObjectDirectorySize(ctx context.Context, in *gitalypb.GetObj } func getPathSize(ctx context.Context, path string) int64 { - cmd, err := command.New(ctx, exec.Command("du", "-sk", path), nil, nil, nil) + cmd, err := command.New(ctx, exec.Command("du", "-sk", path)) if err != nil { ctxlogrus.Extract(ctx).WithError(err).Warn("ignoring du command error") return 0 diff --git a/internal/gitaly/service/ssh/monitor_stdin_command.go b/internal/gitaly/service/ssh/monitor_stdin_command.go index 3d78218ba..5bad1df27 100644 --- a/internal/gitaly/service/ssh/monitor_stdin_command.go +++ b/internal/gitaly/service/ssh/monitor_stdin_command.go @@ -11,16 +11,20 @@ import ( ) func monitorStdinCommand(ctx context.Context, gitCmdFactory git.CommandFactory, stdin io.Reader, stdout, stderr io.Writer, sc git.SubCmd, opts ...git.CmdOpt) (*command.Command, *pktline.ReadMonitor, error) { - stdinPipe, monitor, err := pktline.NewReadMonitor(ctx, stdin) + stdinPipe, monitor, cleanup, err := pktline.NewReadMonitor(ctx, stdin) if err != nil { return nil, nil, fmt.Errorf("create monitor: %v", err) } cmd, err := gitCmdFactory.NewWithoutRepo(ctx, sc, append([]git.CmdOpt{ - git.WithStdin(stdinPipe), git.WithStdout(stdout), git.WithStderr(stderr), + git.WithStdin(stdinPipe), + git.WithStdout(stdout), + git.WithStderr(stderr), + git.WithFinalizer(func(*command.Command) { cleanup() }), }, opts...)...) stdinPipe.Close() // this now belongs to cmd if err != nil { + cleanup() return nil, nil, fmt.Errorf("start cmd: %v", err) } diff --git a/internal/gitaly/service/ssh/upload_pack.go b/internal/gitaly/service/ssh/upload_pack.go index 1b0ef278f..89e67e9a9 100644 --- a/internal/gitaly/service/ssh/upload_pack.go +++ b/internal/gitaly/service/ssh/upload_pack.go @@ -98,10 +98,14 @@ func (s *server) sshUploadPack(ctx context.Context, req sshUploadPackRequest, st return 0, err } + var wg sync.WaitGroup pr, pw := io.Pipe() - defer pw.Close() + defer func() { + pw.Close() + wg.Wait() + }() + stdin = io.TeeReader(stdin, pw) - wg := sync.WaitGroup{} wg.Add(1) go func() { @@ -146,16 +150,10 @@ func (s *server) sshUploadPack(ctx context.Context, req sshUploadPackRequest, st go monitor.Monitor(ctx, pktline.PktDone(), timeoutTicker, cancelCtx) if err := cmd.Wait(); err != nil { - pw.Close() - wg.Wait() - status, _ := command.ExitStatus(err) return status, fmt.Errorf("cmd wait: %w, stderr: %q", err, stderrBuilder.String()) } - pw.Close() - wg.Wait() - ctxlogrus.Extract(ctx).WithField("response_bytes", stdoutCounter.N).Info("request details") return 0, nil diff --git a/internal/gitaly/service/ssh/upload_pack_test.go b/internal/gitaly/service/ssh/upload_pack_test.go index dc4d36268..113cd816a 100644 --- a/internal/gitaly/service/ssh/upload_pack_test.go +++ b/internal/gitaly/service/ssh/upload_pack_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "io" "os" "os/exec" "path/filepath" @@ -14,17 +15,21 @@ import ( "github.com/prometheus/client_golang/prometheus" promtest "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/require" + gitalyauth "gitlab.com/gitlab-org/gitaly/v15/auth" "gitlab.com/gitlab-org/gitaly/v15/internal/git" "gitlab.com/gitlab-org/gitaly/v15/internal/git/gittest" "gitlab.com/gitlab-org/gitaly/v15/internal/gitaly/config" "gitlab.com/gitlab-org/gitaly/v15/internal/helper" "gitlab.com/gitlab-org/gitaly/v15/internal/helper/text" "gitlab.com/gitlab-org/gitaly/v15/internal/metadata/featureflag" + "gitlab.com/gitlab-org/gitaly/v15/internal/sidechannel" "gitlab.com/gitlab-org/gitaly/v15/internal/testhelper" "gitlab.com/gitlab-org/gitaly/v15/internal/testhelper/testcfg" "gitlab.com/gitlab-org/gitaly/v15/internal/testhelper/testserver" "gitlab.com/gitlab-org/gitaly/v15/proto/go/gitalypb" + "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/protobuf/encoding/protojson" ) @@ -125,6 +130,234 @@ func testUploadPackTimeout(t *testing.T, opts ...testcfg.Option) { }) } +func TestUploadPackWithSidechannel_client(t *testing.T) { + t.Parallel() + + cfg := testcfg.Build(t) + cfg.SocketPath = runSSHServer(t, cfg) + + repo, repoPath := gittest.CreateRepository(testhelper.Context(t), t, cfg, gittest.CreateRepositoryConfig{ + Seed: gittest.SeedGitLabTest, + }) + commitID := gittest.Exec(t, cfg, "-C", repoPath, "rev-parse", "HEAD^{commit}") + + registry := sidechannel.NewRegistry() + clientHandshaker := sidechannel.NewClientHandshaker(testhelper.NewDiscardingLogEntry(t), registry) + conn, err := grpc.Dial(cfg.SocketPath, + grpc.WithTransportCredentials(clientHandshaker.ClientHandshake(insecure.NewCredentials())), + grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2(cfg.Auth.Token)), + ) + require.NoError(t, err) + + client := gitalypb.NewSSHServiceClient(conn) + defer testhelper.MustClose(t, conn) + + for _, tc := range []struct { + desc string + request *gitalypb.SSHUploadPackWithSidechannelRequest + client func(clientConn *sidechannel.ClientConn, cancelContext func()) error + expectedErr error + expectedResponse *gitalypb.SSHUploadPackWithSidechannelResponse + }{ + { + desc: "successful clone", + request: &gitalypb.SSHUploadPackWithSidechannelRequest{ + Repository: repo, + }, + client: func(clientConn *sidechannel.ClientConn, _ func()) error { + gittest.WritePktlineString(t, clientConn, "want "+text.ChompBytes(commitID)+" multi_ack\n") + gittest.WritePktlineFlush(t, clientConn) + gittest.WritePktlineString(t, clientConn, "done\n") + + require.NoError(t, clientConn.CloseWrite()) + + return nil + }, + expectedResponse: &gitalypb.SSHUploadPackWithSidechannelResponse{}, + }, + { + desc: "successful clone with protocol v2", + request: &gitalypb.SSHUploadPackWithSidechannelRequest{ + Repository: repo, + GitProtocol: git.ProtocolV2, + }, + client: func(clientConn *sidechannel.ClientConn, _ func()) error { + gittest.WritePktlineString(t, clientConn, "command=fetch\n") + gittest.WritePktlineString(t, clientConn, "agent=git/2.36.1\n") + gittest.WritePktlineString(t, clientConn, "object-format=sha1\n") + gittest.WritePktlineDelim(t, clientConn) + gittest.WritePktlineString(t, clientConn, "want "+text.ChompBytes(commitID)+"\n") + gittest.WritePktlineString(t, clientConn, "done\n") + gittest.WritePktlineFlush(t, clientConn) + + require.NoError(t, clientConn.CloseWrite()) + + return nil + }, + expectedResponse: &gitalypb.SSHUploadPackWithSidechannelResponse{}, + }, + { + desc: "client talks protocol v0 but v2 is requested", + request: &gitalypb.SSHUploadPackWithSidechannelRequest{ + Repository: repo, + GitProtocol: git.ProtocolV2, + }, + client: func(clientConn *sidechannel.ClientConn, _ func()) error { + gittest.WritePktlineString(t, clientConn, "want "+text.ChompBytes(commitID)+" multi_ack\n") + gittest.WritePktlineFlush(t, clientConn) + gittest.WritePktlineString(t, clientConn, "done\n") + + require.NoError(t, clientConn.CloseWrite()) + + return nil + }, + expectedErr: helper.ErrInternalf( + "cmd wait: exit status 128, stderr: %q", + "fatal: unknown capability 'want 1e292f8fedd741b75372e19097c76d327140c312 multi_ack'\n", + ), + }, + { + desc: "client talks protocol v2 but v0 is requested", + request: &gitalypb.SSHUploadPackWithSidechannelRequest{ + Repository: repo, + }, + client: func(clientConn *sidechannel.ClientConn, _ func()) error { + gittest.WritePktlineString(t, clientConn, "command=fetch\n") + gittest.WritePktlineString(t, clientConn, "agent=git/2.36.1\n") + gittest.WritePktlineString(t, clientConn, "object-format=sha1\n") + gittest.WritePktlineDelim(t, clientConn) + gittest.WritePktlineString(t, clientConn, "want "+text.ChompBytes(commitID)+"\n") + gittest.WritePktlineString(t, clientConn, "done\n") + gittest.WritePktlineFlush(t, clientConn) + + require.NoError(t, clientConn.CloseWrite()) + + return nil + }, + expectedErr: helper.ErrInternalf( + "cmd wait: exit status 128, stderr: %q", + "fatal: git upload-pack: protocol error, expected to get object ID, not 'command=fetch'\n", + ), + }, + { + desc: "missing input", + request: &gitalypb.SSHUploadPackWithSidechannelRequest{ + Repository: repo, + GitProtocol: git.ProtocolV2, + }, + client: func(clientConn *sidechannel.ClientConn, _ func()) error { + require.NoError(t, clientConn.CloseWrite()) + return nil + }, + expectedResponse: &gitalypb.SSHUploadPackWithSidechannelResponse{}, + }, + { + desc: "short write", + request: &gitalypb.SSHUploadPackWithSidechannelRequest{ + Repository: repo, + GitProtocol: git.ProtocolV2, + }, + client: func(clientConn *sidechannel.ClientConn, _ func()) error { + gittest.WritePktlineString(t, clientConn, "command=fetch\n") + + _, err := io.WriteString(clientConn, "0011agent") + require.NoError(t, err) + require.NoError(t, clientConn.CloseWrite()) + + return nil + }, + expectedErr: helper.ErrInternalf("cmd wait: exit status 128, stderr: %q", "fatal: the remote end hung up unexpectedly\n"), + }, + { + desc: "garbage", + request: &gitalypb.SSHUploadPackWithSidechannelRequest{ + Repository: repo, + GitProtocol: git.ProtocolV2, + }, + client: func(clientConn *sidechannel.ClientConn, _ func()) error { + gittest.WritePktlineString(t, clientConn, "foobar") + require.NoError(t, clientConn.CloseWrite()) + return nil + }, + expectedErr: helper.ErrInternalf("cmd wait: exit status 128, stderr: %q", "fatal: unknown capability 'foobar'\n"), + }, + { + desc: "close and cancellation", + request: &gitalypb.SSHUploadPackWithSidechannelRequest{ + Repository: repo, + GitProtocol: git.ProtocolV2, + }, + client: func(clientConn *sidechannel.ClientConn, cancelContext func()) error { + gittest.WritePktlineString(t, clientConn, "command=fetch\n") + gittest.WritePktlineString(t, clientConn, "agent=git/2.36.1\n") + + require.NoError(t, clientConn.CloseWrite()) + cancelContext() + + return nil + }, + expectedErr: helper.ErrCanceled(context.Canceled), + }, + { + desc: "cancellation and close", + request: &gitalypb.SSHUploadPackWithSidechannelRequest{ + Repository: repo, + GitProtocol: git.ProtocolV2, + }, + client: func(clientConn *sidechannel.ClientConn, cancelContext func()) error { + gittest.WritePktlineString(t, clientConn, "command=fetch\n") + gittest.WritePktlineString(t, clientConn, "agent=git/2.36.1\n") + + cancelContext() + require.NoError(t, clientConn.CloseWrite()) + + return nil + }, + expectedErr: helper.ErrCanceled(context.Canceled), + }, + { + desc: "cancellation without close", + request: &gitalypb.SSHUploadPackWithSidechannelRequest{ + Repository: repo, + GitProtocol: git.ProtocolV2, + }, + client: func(clientConn *sidechannel.ClientConn, cancelContext func()) error { + gittest.WritePktlineString(t, clientConn, "command=fetch\n") + gittest.WritePktlineString(t, clientConn, "agent=git/2.36.1\n") + + cancelContext() + + return nil + }, + expectedErr: helper.ErrCanceled(context.Canceled), + }, + } { + t.Run(tc.desc, func(t *testing.T) { + ctx, cancel := context.WithCancel(testhelper.Context(t)) + + ctx, waiter := sidechannel.RegisterSidechannel(ctx, registry, func(clientConn *sidechannel.ClientConn) (returnedErr error) { + errCh := make(chan error, 1) + go func() { + _, err := io.Copy(io.Discard, clientConn) + errCh <- err + }() + defer func() { + if err := <-errCh; err != nil && returnedErr == nil { + returnedErr = err + } + }() + + return tc.client(clientConn, cancel) + }) + defer testhelper.MustClose(t, waiter) + + response, err := client.SSHUploadPackWithSidechannel(ctx, tc.request) + testhelper.RequireGrpcError(t, tc.expectedErr, err) + testhelper.ProtoEqual(t, tc.expectedResponse, response) + }) + } +} + func requireFailedSSHStream(t *testing.T, recv func() (int32, error)) { done := make(chan struct{}) var code int32 |