diff options
Diffstat (limited to 'internal/sidechannel/registry.go')
-rw-r--r-- | internal/sidechannel/registry.go | 56 |
1 files changed, 18 insertions, 38 deletions
diff --git a/internal/sidechannel/registry.go b/internal/sidechannel/registry.go index 3148cb05d..32985a9a1 100644 --- a/internal/sidechannel/registry.go +++ b/internal/sidechannel/registry.go @@ -1,11 +1,9 @@ package sidechannel import ( - "errors" "fmt" "net" "sync" - "time" ) // sidechannelID is the type of ID used to differeniate sidechannel connections @@ -25,10 +23,9 @@ type Registry struct { type Waiter struct { id sidechannelID registry *Registry - err error - done chan struct{} + errC chan error accept chan net.Conn - callback func(*ClientConn) error + callback func(net.Conn) error } // NewRegistry returns a new Registry instance @@ -43,23 +40,20 @@ func NewRegistry() *Registry { // connection arrives, the callback function is executed with arrived // connection in a new goroutine. The caller receives execution result via // waiter.Wait(). -func (s *Registry) Register(callback func(*ClientConn) error) *Waiter { +func (s *Registry) Register(callback func(net.Conn) error) *Waiter { s.mu.Lock() defer s.mu.Unlock() waiter := &Waiter{ id: s.nextID, registry: s, - done: make(chan struct{}), + errC: make(chan error), accept: make(chan net.Conn), callback: callback, } s.nextID++ - go func() { - defer close(waiter.done) - waiter.err = waiter.run() - }() + go waiter.run() s.waiters[waiter.id] = waiter return waiter } @@ -105,37 +99,23 @@ func (s *Registry) waiting() int { return len(s.waiters) } -// ErrCallbackDidNotRun indicates that a sidechannel callback was -// de-registered without having run. This can happen if the server chose -// not to use the sidechannel. -var ErrCallbackDidNotRun = errors.New("sidechannel: callback de-registered without having run") - -func (w *Waiter) run() error { - conn := <-w.accept - if conn == nil { - return ErrCallbackDidNotRun - } - defer conn.Close() +func (w *Waiter) run() { + defer close(w.errC) - if err := conn.SetWriteDeadline(time.Now().Add(sidechannelTimeout)); err != nil { - return err - } - if _, err := conn.Write([]byte("ok")); err != nil { - return err + if conn := <-w.accept; conn != nil { + defer conn.Close() + w.errC <- w.callback(conn) } - if err := conn.SetWriteDeadline(time.Time{}); err != nil { - return err - } - - return w.callback(newClientConn(conn)) } -// Close de-registers the callback. If the callback got triggered, -// Close() will return its error return value. If the callback has not -// started by the time Close() is called, Close() returns -// ErrCallbackDidNotRun. +// Close cleans the waiter, removes it from the registry. If the callback is +// executing, this method is blocked until the callback is done. func (w *Waiter) Close() error { w.registry.removeWaiter(w) - <-w.done - return w.err + return <-w.errC +} + +// Wait waits until either the callback is executed, or the waiter is closed +func (w *Waiter) Wait() error { + return <-w.errC } |