Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'internal/gitaly/hook/sidechannel.go')
-rw-r--r--internal/gitaly/hook/sidechannel.go109
1 files changed, 109 insertions, 0 deletions
diff --git a/internal/gitaly/hook/sidechannel.go b/internal/gitaly/hook/sidechannel.go
new file mode 100644
index 000000000..18381237c
--- /dev/null
+++ b/internal/gitaly/hook/sidechannel.go
@@ -0,0 +1,109 @@
+package hook
+
+import (
+ "context"
+ "fmt"
+ "io/ioutil"
+ "net"
+ "os"
+ "path"
+ "time"
+
+ gitaly_metadata "gitlab.com/gitlab-org/gitaly/v14/internal/metadata"
+ "google.golang.org/grpc/metadata"
+)
+
+const (
+ sidechannelHeader = "gitaly-sidechannel-socket"
+ sidechannelSocket = "sidechannel"
+)
+
+type errInvalidSidechannelAddress struct{ string }
+
+func (e *errInvalidSidechannelAddress) Error() string {
+ return fmt.Sprintf("invalid side channel address: %q", e.string)
+}
+
+// GetSidechannel looks for a sidechannel address in an incoming context
+// and establishes a connection if it finds an address.
+func GetSidechannel(ctx context.Context) (net.Conn, error) {
+ address := gitaly_metadata.GetValue(ctx, sidechannelHeader)
+ if path.Base(address) != sidechannelSocket {
+ return nil, &errInvalidSidechannelAddress{address}
+ }
+
+ return net.DialTimeout("unix", address, time.Second)
+}
+
+// SetupSidechannel creates a sidechannel listener in a tempdir and
+// launches a goroutine that will run the callback if the listener
+// receives a connection. The address of the listener is stored in the
+// returned context, so that the caller can propagate it to a server. The
+// caller must Close the SidechannelWaiter to prevent resource leaks.
+func SetupSidechannel(ctx context.Context, callback func(*net.UnixConn) error) (context.Context, *SidechannelWaiter, error) {
+ socketDir, err := ioutil.TempDir("", "gitaly")
+ if err != nil {
+ return nil, nil, err
+ }
+
+ address := path.Join(socketDir, sidechannelSocket)
+ l, err := net.ListenUnix("unix", &net.UnixAddr{Net: "unix", Name: address})
+ if err != nil {
+ return nil, nil, err
+ }
+
+ wt := &SidechannelWaiter{
+ errC: make(chan error),
+ socketDir: socketDir,
+ listener: l,
+ }
+ go wt.run(callback)
+
+ ctx = metadata.AppendToOutgoingContext(ctx, sidechannelHeader, address)
+ return ctx, wt, nil
+}
+
+// SidechannelWaiter provides cleanup and error propagation for a
+// sidechannel callback.
+type SidechannelWaiter struct {
+ errC chan error
+ socketDir string
+ listener *net.UnixListener
+}
+
+func (wt *SidechannelWaiter) run(callback func(*net.UnixConn) error) {
+ defer close(wt.errC)
+
+ wt.errC <- func() error {
+ c, err := wt.listener.AcceptUnix()
+ if err != nil {
+ return err
+ }
+ defer c.Close()
+
+ return callback(c)
+ }()
+}
+
+// Close cleans up sidechannel resources. If the callback is already
+// running, Close will block until the callback is done.
+func (wt *SidechannelWaiter) Close() error {
+ // Run all cleanup actions _before_ checking errors, so that we cannot
+ // forget one.
+ cleanupErrors := []error{
+ wt.listener.Close(),
+ os.RemoveAll(wt.socketDir),
+ wt.Wait(),
+ }
+
+ for _, err := range cleanupErrors {
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// Wait waits for the callback to run and returns its error value.
+func (wt *SidechannelWaiter) Wait() error { return <-wt.errC }