diff options
author | Jacob Vosmaer <jacob@gitlab.com> | 2021-09-23 15:55:47 +0300 |
---|---|---|
committer | Jacob Vosmaer <jacob@gitlab.com> | 2021-09-30 18:27:04 +0300 |
commit | 098ee38438822b534723e76bad0bc3a5ba8baadb (patch) | |
tree | 161eb60fbe26134ed1e156160e4ec072098466c6 /internal | |
parent | 1ec2f7682f9aad976a323881c4437228f4e7c201 (diff) |
sidechannel: remove waiter.Wait()
Before this change, the client-side sidechannel.Waiter object had a
Close() and a Wait() method. These methods were the same except
Close() would _sometimes_ block, and Wait() would _always_ block
waiting for the sidechannel callback to run. This does not make sense
when implementing a sidechannel proxy (like we need for Praefect)
because how does the proxy know if it needs to call Wait() or Close()?
This change adds an extra message to the sidechannel protocol: the
client sends two bytes "ok" to the server after accepting the
sidechannel connection. This creates a new synchronization point. Now,
RPC's that want the behavior of Wait() can implement it by first
blocking in the client waiting for a gRPC message from the server, and
then calling Waiter.Close(). This makes proxying easier because a
proxy can now just call Close() after the gRPC call is done.
This is a breaking change to the protocol but that is still OK because
we have no RPC's yet that use the protocol.
Changelog: other
Diffstat (limited to 'internal')
-rw-r--r-- | internal/gitaly/service/smarthttp/upload_pack_test.go | 4 | ||||
-rw-r--r-- | internal/sidechannel/registry.go | 49 | ||||
-rw-r--r-- | internal/sidechannel/registry_test.go | 24 | ||||
-rw-r--r-- | internal/sidechannel/sidechannel.go | 9 | ||||
-rw-r--r-- | internal/sidechannel/sidechannel_test.go | 2 |
5 files changed, 61 insertions, 27 deletions
diff --git a/internal/gitaly/service/smarthttp/upload_pack_test.go b/internal/gitaly/service/smarthttp/upload_pack_test.go index 1d9a9012a..506b1e34f 100644 --- a/internal/gitaly/service/smarthttp/upload_pack_test.go +++ b/internal/gitaly/service/smarthttp/upload_pack_test.go @@ -518,7 +518,7 @@ func makePostUploadPackWithSidechannelRequest(ctx context.Context, t *testing.T, return <-errC }) - defer testhelper.MustClose(t, waiter) + defer waiter.Close() rpcRequest := &gitalypb.PostUploadPackWithSidechannelRequest{ Repository: in.GetRepository(), @@ -527,7 +527,7 @@ func makePostUploadPackWithSidechannelRequest(ctx context.Context, t *testing.T, } _, err := client.PostUploadPackWithSidechannel(ctxOut, rpcRequest) if err == nil { - require.NoError(t, waiter.Wait()) + require.NoError(t, waiter.Close()) } return responseBuffer, err diff --git a/internal/sidechannel/registry.go b/internal/sidechannel/registry.go index 1b96e8c88..9a5145f1c 100644 --- a/internal/sidechannel/registry.go +++ b/internal/sidechannel/registry.go @@ -1,9 +1,11 @@ package sidechannel import ( + "errors" "fmt" "net" "sync" + "time" ) // sidechannelID is the type of ID used to differeniate sidechannel connections @@ -23,7 +25,8 @@ type Registry struct { type Waiter struct { id sidechannelID registry *Registry - errC chan error + err error + done chan struct{} accept chan net.Conn callback func(*ClientConn) error } @@ -47,13 +50,16 @@ func (s *Registry) Register(callback func(*ClientConn) error) *Waiter { waiter := &Waiter{ id: s.nextID, registry: s, - errC: make(chan error), + done: make(chan struct{}), accept: make(chan net.Conn), callback: callback, } s.nextID++ - go waiter.run() + go func() { + defer close(waiter.done) + waiter.err = waiter.run() + }() s.waiters[waiter.id] = waiter return waiter } @@ -99,23 +105,34 @@ func (s *Registry) waiting() int { return len(s.waiters) } -func (w *Waiter) run() { - defer close(w.errC) +var ErrCallbackDidNotRun = errors.New("sidechannel: callback de-registered without having run") + +func (w *Waiter) run() error { + conn := <-w.accept + if conn == nil { + return ErrCallbackDidNotRun + } + defer conn.Close() - if conn := <-w.accept; conn != nil { - defer conn.Close() - w.errC <- w.callback(newClientConn(conn)) + if err := conn.SetWriteDeadline(time.Now().Add(sidechannelTimeout)); err != nil { + return err + } + if _, err := conn.Write([]byte("ok")); err != nil { + return err } + if err := conn.SetWriteDeadline(time.Time{}); err != nil { + return err + } + + return w.callback(newClientConn(conn)) } -// Close cleans the waiter, removes it from the registry. If the callback is -// executing, this method is blocked until the callback is done. +// Close de-registers the callback. If the callback got triggered, +// Close() will return its error return value. If the callback has not +// started by the time Close() is called, Close() returns +// ErrCallbackDidNotRun. func (w *Waiter) Close() error { w.registry.removeWaiter(w) - return <-w.errC -} - -// Wait waits until either the callback is executed, or the waiter is closed -func (w *Waiter) Wait() error { - return <-w.errC + <-w.done + return w.err } diff --git a/internal/sidechannel/registry_test.go b/internal/sidechannel/registry_test.go index 4db4e6ac5..032534a81 100644 --- a/internal/sidechannel/registry_test.go +++ b/internal/sidechannel/registry_test.go @@ -33,17 +33,18 @@ func TestRegistry(t *testing.T) { close(triggerCallback) - require.NoError(t, waiter.Wait()) + require.NoError(t, waiter.Close()) requireConnClosed(t, client) }) - t.Run("pull connections successfully", func(t *testing.T) { + t.Run("receive connections successfully", func(t *testing.T) { wg := sync.WaitGroup{} - var servers []*ServerConn + servers := make([]net.Conn, N) for i := 0; i < N; i++ { client, server := socketPair(t) - servers = append(servers, newServerConn(server)) + servers[i] = server + defer server.Close() wg.Add(1) go func(i int) { @@ -57,7 +58,7 @@ func TestRegistry(t *testing.T) { defer waiter.Close() require.NoError(t, registry.receive(waiter.id, client)) - require.NoError(t, waiter.Wait()) + require.NoError(t, waiter.Close()) requireConnClosed(t, client) wg.Done() @@ -65,7 +66,14 @@ func TestRegistry(t *testing.T) { } for i := 0; i < N; i++ { - out, err := io.ReadAll(servers[i]) + // Read registry confirmation + buf := make([]byte, 2) + _, err := io.ReadFull(servers[i], buf) + require.NoError(t, err) + require.Equal(t, "ok", string(buf)) + + // Read data written by callback + out, err := io.ReadAll(newServerConn(servers[i])) require.NoError(t, err) require.Equal(t, strconv.Itoa(i), string(out)) } @@ -74,7 +82,7 @@ func TestRegistry(t *testing.T) { require.Equal(t, 0, registry.waiting()) }) - t.Run("push connection to non-existing ID", func(t *testing.T) { + t.Run("receive connection for non-existing ID", func(t *testing.T) { client, _ := socketPair(t) err := registry.receive(registry.nextID+1, client) require.EqualError(t, err, "sidechannel registry: ID not registered") @@ -83,7 +91,7 @@ func TestRegistry(t *testing.T) { t.Run("pre-maturely close the waiter", func(t *testing.T) { waiter := registry.Register(func(conn *ClientConn) error { panic("never execute") }) - require.NoError(t, waiter.Close()) + require.Equal(t, ErrCallbackDidNotRun, waiter.Close()) require.Equal(t, 0, registry.waiting()) }) } diff --git a/internal/sidechannel/sidechannel.go b/internal/sidechannel/sidechannel.go index a2083fb04..4886603e7 100644 --- a/internal/sidechannel/sidechannel.go +++ b/internal/sidechannel/sidechannel.go @@ -4,6 +4,7 @@ import ( "context" "encoding/binary" "fmt" + "io" "net" "strconv" "time" @@ -64,6 +65,14 @@ func OpenSidechannel(ctx context.Context) (_ *ServerConn, err error) { return nil, fmt.Errorf("sidechannel: write stream id: %w", err) } + buf := make([]byte, 2) + if _, err := io.ReadFull(stream, buf); err != nil { + return nil, fmt.Errorf("sidechannel: receive confirmation: %w", err) + } + if string(buf) != "ok" { + return nil, fmt.Errorf("sidechannel: expected ok, got %q", buf) + } + if err := stream.SetDeadline(time.Time{}); err != nil { return nil, err } diff --git a/internal/sidechannel/sidechannel_test.go b/internal/sidechannel/sidechannel_test.go index f1d625504..3ad925e61 100644 --- a/internal/sidechannel/sidechannel_test.go +++ b/internal/sidechannel/sidechannel_test.go @@ -197,7 +197,7 @@ func call(ctx context.Context, conn *grpc.ClientConn, registry *Registry, handle return err } - if err := waiter.Wait(); err != nil { + if err := waiter.Close(); err != nil { return err } |