diff options
author | Patrick Steinhardt <psteinhardt@gitlab.com> | 2022-01-11 13:29:32 +0300 |
---|---|---|
committer | Patrick Steinhardt <psteinhardt@gitlab.com> | 2022-01-14 17:25:03 +0300 |
commit | 44f4a68d69d59b98bf5d94c8f9fe1c62f370529e (patch) | |
tree | 325f3068ca1e0b4f5a44ffb3581869b6aa6a7219 | |
parent | 3b4390ac2ed5b9dbcc9efd3f8222a373d8db47bb (diff) |
cmd/gitaly-wrapper: Refactor reading of PID file to be testable
It's hard to test the logic which reads the PID file because it depends
on environment variables to be set. Refactor the function to instead
accept the PID file path as parameter and add a bunch of tests to verify
it works as expected.
While at it, renames the function to `readPIDFile()` to more clearly
show what it does.
-rw-r--r-- | cmd/gitaly-wrapper/main.go | 19 | ||||
-rw-r--r-- | cmd/gitaly-wrapper/main_test.go | 37 |
2 files changed, 40 insertions, 16 deletions
diff --git a/cmd/gitaly-wrapper/main.go b/cmd/gitaly-wrapper/main.go index 8cb0c3941..a70855bf5 100644 --- a/cmd/gitaly-wrapper/main.go +++ b/cmd/gitaly-wrapper/main.go @@ -40,12 +40,13 @@ func main() { logger := log.Default().WithField("wrapper", os.Getpid()) logger.Info("Wrapper started") - if pidFile() == "" { + pidFilePath := os.Getenv(bootstrap.EnvPidFile) + if pidFilePath == "" { logger.Fatalf("missing pid file ENV variable %q", bootstrap.EnvPidFile) } + logger.WithField("pid_file", pidFilePath).Info("finding gitaly") - logger.WithField("pid_file", pidFile()).Info("finding gitaly") - gitaly, err := findGitaly() + gitaly, err := findGitaly(pidFilePath) if err != nil && !isRecoverable(err) { logger.WithError(err).Fatal("find gitaly") } else if err != nil { @@ -83,8 +84,8 @@ func isRecoverable(err error) bool { return os.IsNotExist(err) || errors.As(err, &numError) } -func findGitaly() (*os.Process, error) { - pid, err := getPid() +func findGitaly(pidFilePath string) (*os.Process, error) { + pid, err := readPIDFile(pidFilePath) if err != nil { return nil, err } @@ -146,8 +147,8 @@ func forwardSignals(gitaly *os.Process, log *logrus.Entry) { signal.Notify(sigs) } -func getPid() (int, error) { - data, err := os.ReadFile(pidFile()) +func readPIDFile(pidFilePath string) (int, error) { + data, err := os.ReadFile(pidFilePath) if err != nil { return 0, err } @@ -176,10 +177,6 @@ func isGitaly(p *os.Process, gitalyBin string) bool { return false } -func pidFile() string { - return os.Getenv(bootstrap.EnvPidFile) -} - func jsonLogging() bool { enabled, _ := env.GetBool(envJSONLogging, false) return enabled diff --git a/cmd/gitaly-wrapper/main_test.go b/cmd/gitaly-wrapper/main_test.go index abca0aa09..8664c0bcb 100644 --- a/cmd/gitaly-wrapper/main_test.go +++ b/cmd/gitaly-wrapper/main_test.go @@ -9,7 +9,6 @@ import ( "testing" "github.com/stretchr/testify/require" - "gitlab.com/gitlab-org/gitaly/v14/internal/bootstrap" "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper" ) @@ -20,9 +19,6 @@ func TestStolenPid(t *testing.T) { pidFile, err := os.Create(filepath.Join(tempDir, "pidfile")) require.NoError(t, err) - cleanup := testhelper.ModifyEnvironment(t, bootstrap.EnvPidFile, pidFile.Name()) - defer cleanup() - ctx, cancel := testhelper.Context() defer cancel() @@ -37,7 +33,7 @@ func TestStolenPid(t *testing.T) { require.NoError(t, err) require.NoError(t, pidFile.Close()) - tail, err := findGitaly() + tail, err := findGitaly(pidFile.Name()) require.NoError(t, err) require.NotNil(t, tail) require.Equal(t, cmd.Process.Pid, tail.Pid) @@ -83,3 +79,34 @@ func TestIsRecoverable(t *testing.T) { }) } } + +func TestReadPIDFile(t *testing.T) { + t.Run("nonexistent", func(t *testing.T) { + _, err := readPIDFile("does-not-exist") + require.True(t, os.IsNotExist(err)) + }) + + t.Run("empty", func(t *testing.T) { + path := filepath.Join(testhelper.TempDir(t), "pid") + require.NoError(t, os.WriteFile(path, nil, 0o644)) + _, err := readPIDFile(path) + _, expectedErr := strconv.Atoi("") + require.Equal(t, expectedErr, err) + }) + + t.Run("invalid contents", func(t *testing.T) { + path := filepath.Join(testhelper.TempDir(t), "pid") + require.NoError(t, os.WriteFile(path, []byte("invalid"), 0o644)) + _, err := readPIDFile(path) + _, expectedErr := strconv.Atoi("invalid") + require.Equal(t, expectedErr, err) + }) + + t.Run("valid", func(t *testing.T) { + path := filepath.Join(testhelper.TempDir(t), "pid") + require.NoError(t, os.WriteFile(path, []byte("12345"), 0o644)) + pid, err := readPIDFile(path) + require.NoError(t, err) + require.Equal(t, 12345, pid) + }) +} |