Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPatrick Steinhardt <psteinhardt@gitlab.com>2022-01-11 13:29:32 +0300
committerPatrick Steinhardt <psteinhardt@gitlab.com>2022-01-14 17:25:03 +0300
commit44f4a68d69d59b98bf5d94c8f9fe1c62f370529e (patch)
tree325f3068ca1e0b4f5a44ffb3581869b6aa6a7219
parent3b4390ac2ed5b9dbcc9efd3f8222a373d8db47bb (diff)
cmd/gitaly-wrapper: Refactor reading of PID file to be testable
It's hard to test the logic which reads the PID file because it depends on environment variables to be set. Refactor the function to instead accept the PID file path as parameter and add a bunch of tests to verify it works as expected. While at it, renames the function to `readPIDFile()` to more clearly show what it does.
-rw-r--r--cmd/gitaly-wrapper/main.go19
-rw-r--r--cmd/gitaly-wrapper/main_test.go37
2 files changed, 40 insertions, 16 deletions
diff --git a/cmd/gitaly-wrapper/main.go b/cmd/gitaly-wrapper/main.go
index 8cb0c3941..a70855bf5 100644
--- a/cmd/gitaly-wrapper/main.go
+++ b/cmd/gitaly-wrapper/main.go
@@ -40,12 +40,13 @@ func main() {
logger := log.Default().WithField("wrapper", os.Getpid())
logger.Info("Wrapper started")
- if pidFile() == "" {
+ pidFilePath := os.Getenv(bootstrap.EnvPidFile)
+ if pidFilePath == "" {
logger.Fatalf("missing pid file ENV variable %q", bootstrap.EnvPidFile)
}
+ logger.WithField("pid_file", pidFilePath).Info("finding gitaly")
- logger.WithField("pid_file", pidFile()).Info("finding gitaly")
- gitaly, err := findGitaly()
+ gitaly, err := findGitaly(pidFilePath)
if err != nil && !isRecoverable(err) {
logger.WithError(err).Fatal("find gitaly")
} else if err != nil {
@@ -83,8 +84,8 @@ func isRecoverable(err error) bool {
return os.IsNotExist(err) || errors.As(err, &numError)
}
-func findGitaly() (*os.Process, error) {
- pid, err := getPid()
+func findGitaly(pidFilePath string) (*os.Process, error) {
+ pid, err := readPIDFile(pidFilePath)
if err != nil {
return nil, err
}
@@ -146,8 +147,8 @@ func forwardSignals(gitaly *os.Process, log *logrus.Entry) {
signal.Notify(sigs)
}
-func getPid() (int, error) {
- data, err := os.ReadFile(pidFile())
+func readPIDFile(pidFilePath string) (int, error) {
+ data, err := os.ReadFile(pidFilePath)
if err != nil {
return 0, err
}
@@ -176,10 +177,6 @@ func isGitaly(p *os.Process, gitalyBin string) bool {
return false
}
-func pidFile() string {
- return os.Getenv(bootstrap.EnvPidFile)
-}
-
func jsonLogging() bool {
enabled, _ := env.GetBool(envJSONLogging, false)
return enabled
diff --git a/cmd/gitaly-wrapper/main_test.go b/cmd/gitaly-wrapper/main_test.go
index abca0aa09..8664c0bcb 100644
--- a/cmd/gitaly-wrapper/main_test.go
+++ b/cmd/gitaly-wrapper/main_test.go
@@ -9,7 +9,6 @@ import (
"testing"
"github.com/stretchr/testify/require"
- "gitlab.com/gitlab-org/gitaly/v14/internal/bootstrap"
"gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
)
@@ -20,9 +19,6 @@ func TestStolenPid(t *testing.T) {
pidFile, err := os.Create(filepath.Join(tempDir, "pidfile"))
require.NoError(t, err)
- cleanup := testhelper.ModifyEnvironment(t, bootstrap.EnvPidFile, pidFile.Name())
- defer cleanup()
-
ctx, cancel := testhelper.Context()
defer cancel()
@@ -37,7 +33,7 @@ func TestStolenPid(t *testing.T) {
require.NoError(t, err)
require.NoError(t, pidFile.Close())
- tail, err := findGitaly()
+ tail, err := findGitaly(pidFile.Name())
require.NoError(t, err)
require.NotNil(t, tail)
require.Equal(t, cmd.Process.Pid, tail.Pid)
@@ -83,3 +79,34 @@ func TestIsRecoverable(t *testing.T) {
})
}
}
+
+func TestReadPIDFile(t *testing.T) {
+ t.Run("nonexistent", func(t *testing.T) {
+ _, err := readPIDFile("does-not-exist")
+ require.True(t, os.IsNotExist(err))
+ })
+
+ t.Run("empty", func(t *testing.T) {
+ path := filepath.Join(testhelper.TempDir(t), "pid")
+ require.NoError(t, os.WriteFile(path, nil, 0o644))
+ _, err := readPIDFile(path)
+ _, expectedErr := strconv.Atoi("")
+ require.Equal(t, expectedErr, err)
+ })
+
+ t.Run("invalid contents", func(t *testing.T) {
+ path := filepath.Join(testhelper.TempDir(t), "pid")
+ require.NoError(t, os.WriteFile(path, []byte("invalid"), 0o644))
+ _, err := readPIDFile(path)
+ _, expectedErr := strconv.Atoi("invalid")
+ require.Equal(t, expectedErr, err)
+ })
+
+ t.Run("valid", func(t *testing.T) {
+ path := filepath.Join(testhelper.TempDir(t), "pid")
+ require.NoError(t, os.WriteFile(path, []byte("12345"), 0o644))
+ pid, err := readPIDFile(path)
+ require.NoError(t, err)
+ require.Equal(t, 12345, pid)
+ })
+}