Welcome to mirror list, hosted at ThFree Co, Russian Federation.

sidechannel.go « hook « gitaly « internal - gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 26c86628fc05af8f6cccd9ea998d65af522ee250 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
package hook

import (
	"context"
	"errors"
	"fmt"
	"io/fs"
	"net"
	"os"
	"path"
	"path/filepath"
	"time"

	"gitlab.com/gitlab-org/gitaly/internal/git"
	gitaly_metadata "gitlab.com/gitlab-org/gitaly/internal/metadata"
	"google.golang.org/grpc/metadata"
)

// sidechannelHeader is the gRPC metadata key that carries the sidechannel
// socket address from SetupSidechannel to GetSidechannel.
const sidechannelHeader = "gitaly-sidechannel-socket"

// sidechannelSocket is the file name of the listening Unix socket inside each
// temporary sidechannel directory. GetSidechannel refuses to dial any address
// whose final path element differs from this name.
const sidechannelSocket = "sidechannel"

// errInvalidSidechannelAddress is returned by GetSidechannel when the address
// found in the context does not end in the expected socket file name. The
// embedded string holds the offending address.
type errInvalidSidechannelAddress struct{ string }

// Error implements the error interface, quoting the rejected address so that
// empty or whitespace-only values remain visible.
func (ia *errInvalidSidechannelAddress) Error() string {
	msg := fmt.Sprintf("invalid side channel address: %q", ia.string)
	return msg
}

// GetSidechannel reads the sidechannel socket address stored under
// sidechannelHeader in the incoming context and dials it over a Unix socket,
// giving up after one second. Addresses whose final path element is not
// sidechannelSocket are rejected with an *errInvalidSidechannelAddress
// without any dial attempt.
func GetSidechannel(ctx context.Context) (net.Conn, error) {
	socketPath := gitaly_metadata.GetValue(ctx, sidechannelHeader)

	// Only dial addresses that name the socket file SetupSidechannel creates;
	// anything else is refused outright.
	if path.Base(socketPath) == sidechannelSocket {
		return net.DialTimeout("unix", socketPath, time.Second)
	}

	return nil, &errInvalidSidechannelAddress{socketPath}
}

// SetupSidechannel creates a sidechannel listener in a temporary directory and
// launches a goroutine that will run callback once the listener receives a
// connection. The listener's address is appended to the outgoing metadata of
// the returned context under the sidechannelHeader key, so that the caller can
// propagate it to a server. The caller must Close the returned
// SidechannelWaiter to prevent resource leaks.
//
// When payload.RuntimeDir is set, the temporary directory is created below
// "<RuntimeDir>/chan.d". That directory is created if missing, but RuntimeDir
// itself must already exist. Otherwise the temporary directory is created in
// the system's default temporary directory with a "gitaly" name prefix.
func SetupSidechannel(ctx context.Context, payload git.HooksPayload, callback func(*net.UnixConn) error) (_ context.Context, _ *SidechannelWaiter, err error) {
	// An empty parent directory makes os.MkdirTemp use the system's default
	// temporary directory.
	sidechannelDir, sidechannelName := "", "gitaly*"

	// If there is a runtime directory we try to create a sidechannel directory in there that
	// will hold all the temporary sidechannel subdirectories. Otherwise, we fall back to create
	// the sidechannel directory in the system's temporary directory.
	if payload.RuntimeDir != "" {
		// Plain assignment, not `:=`: redeclaring sidechannelDir in this scope
		// would shadow the outer variable. os.MkdirTemp below would then still
		// see "" and silently ignore the runtime directory.
		sidechannelDir = filepath.Join(payload.RuntimeDir, "chan.d")

		// Note that we don't use `os.MkdirAll()` here: we don't want to accidentally create
		// the full directory hierarchy, and the assumption is that the runtime directory
		// must exist already. An existing chan.d is expected, because it holds
		// the subdirectories of every sidechannel.
		if err := os.Mkdir(sidechannelDir, 0o700); err != nil && !errors.Is(err, fs.ErrExist) {
			return nil, nil, err
		}

		sidechannelName = "*"
	}

	socketDir, err := os.MkdirTemp(sidechannelDir, sidechannelName)
	if err != nil {
		return nil, nil, err
	}
	// Any failure from here on must not leave the per-call directory behind.
	// Removal is best effort: the caller needs the original error, not this one.
	defer func() {
		if err != nil {
			_ = os.RemoveAll(socketDir)
		}
	}()

	address := path.Join(socketDir, sidechannelSocket)
	l, err := net.ListenUnix("unix", &net.UnixAddr{Net: "unix", Name: address})
	if err != nil {
		return nil, nil, err
	}

	wt := &SidechannelWaiter{
		errC:      make(chan error),
		socketDir: socketDir,
		listener:  l,
	}
	go wt.run(callback)

	ctx = metadata.AppendToOutgoingContext(ctx, sidechannelHeader, address)
	return ctx, wt, nil
}

// SidechannelWaiter provides cleanup and error propagation for a sidechannel
// callback. It owns the listening Unix socket and the temporary directory that
// contains it. A background goroutine (run) accepts at most one connection and
// reports a single result on errC.
type SidechannelWaiter struct {
	errC      chan error
	socketDir string
	listener  *net.UnixListener
}

// run accepts one connection, removes the socket directory, and hands the
// connection to callback. Exactly one value is sent on errC: the accept
// error, the removal error, or the callback's return value. errC is then
// closed. Because errC is unbuffered, run blocks until Wait (or Close)
// receives that value. The accepted connection is closed before the result
// is sent.
func (wt *SidechannelWaiter) run(callback func(*net.UnixConn) error) {
	defer close(wt.errC)

	acceptAndServe := func() error {
		conn, acceptErr := wt.listener.AcceptUnix()
		if acceptErr != nil {
			return acceptErr
		}
		defer conn.Close()

		// Delete the socket directory as soon as a peer has connected, so
		// nothing is left on disk even if the process exits before Close runs.
		if rmErr := os.RemoveAll(wt.socketDir); rmErr != nil {
			return rmErr
		}

		return callback(conn)
	}

	wt.errC <- acceptAndServe()
}

// Close releases the sidechannel resources. If the callback is already
// running, Close blocks until it returns.
//
// Close performs three steps in a fixed order:
//  1. Close the listener, which unblocks run if it is still accepting.
//  2. Remove the socket directory.
//  3. Wait for run.
//
// All three steps run before any error is examined, so no step can be
// skipped. The first non-nil error in that order is returned. If no connection
// was ever accepted, step 3 yields the error from the interrupted AcceptUnix.
func (wt *SidechannelWaiter) Close() error {
	closeErr := wt.listener.Close()
	removeErr := os.RemoveAll(wt.socketDir)
	waitErr := wt.Wait()

	switch {
	case closeErr != nil:
		return closeErr
	case removeErr != nil:
		return removeErr
	default:
		return waitErr
	}
}

// Wait blocks until run has delivered its result and returns it. After the
// result has been consumed (by Wait or Close), errC is closed and further
// calls return nil.
func (wt *SidechannelWaiter) Wait() error {
	result := <-wt.errC
	return result
}