Welcome to mirror list, hosted at ThFree Co, Russian Federation.

registry.go « sidechannel « internal - gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 3148cb05dd0e5e652c56dd2b2b418457d049cd2a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
package sidechannel

import (
	"errors"
	"fmt"
	"net"
	"sync"
	"time"
)

// sidechannelID is the type of ID used to differentiate sidechannel connections
// in the same registry. IDs are handed out sequentially by Registry.Register.
type sidechannelID int64

// Registry manages sidechannel connections. It allows the RPC
// handlers to wait for the secondary incoming connection made by the client.
//
// Each call to Register allocates a fresh ID and a Waiter; receive later
// matches an incoming connection to the Waiter with that ID. A Registry must
// be created with NewRegistry so that the waiters map is initialized.
type Registry struct {
	// nextID is the ID that the next call to Register will hand out.
	nextID  sidechannelID
	// waiters holds the Waiters that are still pending, keyed by ID. An entry
	// is removed when receive delivers a connection to it or when the Waiter
	// is de-registered via removeWaiter.
	waiters map[sidechannelID]*Waiter
	// mu guards nextID and waiters.
	mu      sync.Mutex
}

// Waiter lets the caller wait until a connection with a matching ID is pushed
// into the registry, and then executes the callback with that connection.
type Waiter struct {
	// id is the key under which this Waiter is stored in registry.waiters.
	id       sidechannelID
	// registry is the Registry that created this Waiter; Close uses it to
	// de-register.
	registry *Registry
	// err holds the result of run(). It is written by the goroutine started in
	// Register before done is closed, and must only be read after done is
	// closed.
	err      error
	// done is closed once run() has returned.
	done     chan struct{}
	// accept carries the matched connection from receive to run(). It is
	// closed instead (so run() reads a nil connection) when the Waiter is
	// de-registered before a connection arrives.
	accept   chan net.Conn
	// callback is invoked by run() with the accepted connection.
	callback func(*ClientConn) error
}

// NewRegistry creates an empty Registry that is ready for use. The first
// Waiter it registers receives ID 0.
func NewRegistry() *Registry {
	r := new(Registry)
	r.waiters = make(map[sidechannelID]*Waiter)
	return r
}

// Register adds a new Waiter to the registry and returns it. The Waiter gets
// a fresh ID, and a goroutine is started that blocks until receive delivers
// the connection for that ID. Once it arrives, the goroutine acknowledges the
// connection and invokes callback with it in that goroutine.
//
// The callback's result is reported by the Waiter's Close method. If Close
// de-registers the Waiter before any connection arrives, Close reports
// ErrCallbackDidNotRun instead. Callers should always call Close. Until a
// connection arrives or Close is called, the goroutine stays blocked.
func (s *Registry) Register(callback func(*ClientConn) error) *Waiter {
	s.mu.Lock()
	defer s.mu.Unlock()

	id := s.nextID
	s.nextID++

	w := &Waiter{
		id:       id,
		registry: s,
		done:     make(chan struct{}),
		accept:   make(chan net.Conn),
		callback: callback,
	}
	s.waiters[id] = w

	go func() {
		// Store the result before closing done so that Close observes it.
		defer close(w.done)
		w.err = w.run()
	}()

	return w
}

// receive hands conn to the waiter registered under id.
//
// If a waiter with that ID exists, it is removed from the registry and conn is
// delivered through the waiter's accept channel. From then on the waiter owns
// conn and closes it when its run goroutine returns. If no such waiter exists
// (it was never registered, already received a connection, or was
// de-registered), conn is closed immediately and an error is returned.
func (s *Registry) receive(id sidechannelID, conn net.Conn) error {
	s.mu.Lock()
	waiter, exist := s.waiters[id]
	if exist {
		delete(s.waiters, id)
	}
	s.mu.Unlock()

	if !exist {
		// Best effort: the lookup failure is the error the caller needs.
		_ = conn.Close()
		return fmt.Errorf("sidechannel registry: ID not registered")
	}

	// The send deliberately happens outside the lock. Once the entry has been
	// deleted above, neither another receive nor removeWaiter can touch this
	// waiter's accept channel. removeWaiter only closes accept for waiters
	// still in the map, so this send cannot hit a closed channel.
	//
	// The waiter's goroutine is already blocked reading accept, so the send
	// completes as soon as it is scheduled. Not holding the registry mutex
	// meanwhile keeps concurrent Register/receive/Close calls from stalling on
	// this handoff.
	waiter.accept <- conn
	return nil
}

// removeWaiter de-registers waiter if it is still pending. Closing its accept
// channel wakes the waiter's run goroutine with a nil connection, which makes
// it report ErrCallbackDidNotRun. If receive has already claimed the waiter,
// this is a no-op.
func (s *Registry) removeWaiter(waiter *Waiter) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if _, pending := s.waiters[waiter.id]; !pending {
		return
	}
	delete(s.waiters, waiter.id)
	close(waiter.accept)
}

// waiting reports how many waiters are currently registered, i.e. have
// neither received a connection nor been de-registered.
func (s *Registry) waiting() int {
	s.mu.Lock()
	n := len(s.waiters)
	s.mu.Unlock()
	return n
}

var (
	// ErrCallbackDidNotRun is the result a Waiter reports when it was
	// de-registered before any sidechannel connection arrived, so its
	// callback never executed. This can happen if the server chose not to
	// use the sidechannel.
	ErrCallbackDidNotRun = errors.New("sidechannel: callback de-registered without having run")
)

// run waits for the connection that receive hands over through w.accept.
//
// A nil connection is what reading the closed accept channel yields after
// removeWaiter. In that case the callback is skipped and ErrCallbackDidNotRun
// is returned.
//
// Otherwise the connection is acknowledged by writing "ok" under a write
// deadline of sidechannelTimeout. The deadline is then cleared and the
// callback is invoked with the wrapped connection. The connection is closed
// when run returns.
func (w *Waiter) run() error {
	conn := <-w.accept
	if conn == nil {
		return ErrCallbackDidNotRun
	}
	defer conn.Close()

	// Handshake: each step only runs if the previous one succeeded, and the
	// first failure is returned unchanged.
	err := conn.SetWriteDeadline(time.Now().Add(sidechannelTimeout))
	if err == nil {
		_, err = conn.Write([]byte("ok"))
	}
	if err == nil {
		// Lift the deadline so it does not constrain writes made by the callback.
		err = conn.SetWriteDeadline(time.Time{})
	}
	if err != nil {
		return err
	}

	return w.callback(newClientConn(conn))
}

// Close de-registers the waiter, blocks until its goroutine has finished, and
// returns that goroutine's result:
//   - if a connection was delivered before de-registration, Close waits for
//     the handshake and callback to complete and returns the callback's error
//     (or the handshake error, if acknowledging the connection failed);
//   - otherwise the callback never runs and Close returns
//     ErrCallbackDidNotRun.
//
// Calling Close more than once is harmless: later calls find nothing to
// de-register and return the same result.
func (w *Waiter) Close() error {
	// Must happen before waiting on done: if no connection has arrived,
	// closing accept (done by removeWaiter) is what lets run() return.
	w.registry.removeWaiter(w)
	// done is closed by the goroutine started in Register after it stores
	// run()'s result in w.err, so reading w.err afterwards is race-free.
	<-w.done
	return w.err
}