Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJacob Vosmaer <jacob@gitlab.com>2021-09-23 15:55:47 +0300
committerJacob Vosmaer <jacob@gitlab.com>2021-09-30 18:27:04 +0300
commit098ee38438822b534723e76bad0bc3a5ba8baadb (patch)
tree161eb60fbe26134ed1e156160e4ec072098466c6 /internal
parent1ec2f7682f9aad976a323881c4437228f4e7c201 (diff)
sidechannel: remove waiter.Wait()
Before this change, the client-side sidechannel.Waiter object had a Close() and a Wait() method. These methods were the same except Close() would _sometimes_ block, and Wait() would _always_ block waiting for the sidechannel callback to run. This does not make sense when implementing a sidechannel proxy (like we need for Praefect) because how does the proxy know if it needs to call Wait() or Close()? This change adds an extra message to the sidechannel protocol: the client sends two bytes "ok" to the server after accepting the sidechannel connection. This creates a new synchronization point. Now, RPC's that want the behavior of Wait() can implement it by first blocking in the client waiting for a gRPC message from the server, and then calling Waiter.Close(). This makes proxying easier because a proxy can now just call Close() after the gRPC call is done. This is a breaking change to the protocol but that is still OK because we have no RPC's yet that use the protocol. Changelog: other
Diffstat (limited to 'internal')
-rw-r--r--internal/gitaly/service/smarthttp/upload_pack_test.go4
-rw-r--r--internal/sidechannel/registry.go49
-rw-r--r--internal/sidechannel/registry_test.go24
-rw-r--r--internal/sidechannel/sidechannel.go9
-rw-r--r--internal/sidechannel/sidechannel_test.go2
5 files changed, 61 insertions, 27 deletions
diff --git a/internal/gitaly/service/smarthttp/upload_pack_test.go b/internal/gitaly/service/smarthttp/upload_pack_test.go
index 1d9a9012a..506b1e34f 100644
--- a/internal/gitaly/service/smarthttp/upload_pack_test.go
+++ b/internal/gitaly/service/smarthttp/upload_pack_test.go
@@ -518,7 +518,7 @@ func makePostUploadPackWithSidechannelRequest(ctx context.Context, t *testing.T,
return <-errC
})
- defer testhelper.MustClose(t, waiter)
+ defer waiter.Close()
rpcRequest := &gitalypb.PostUploadPackWithSidechannelRequest{
Repository: in.GetRepository(),
@@ -527,7 +527,7 @@ func makePostUploadPackWithSidechannelRequest(ctx context.Context, t *testing.T,
}
_, err := client.PostUploadPackWithSidechannel(ctxOut, rpcRequest)
if err == nil {
- require.NoError(t, waiter.Wait())
+ require.NoError(t, waiter.Close())
}
return responseBuffer, err
diff --git a/internal/sidechannel/registry.go b/internal/sidechannel/registry.go
index 1b96e8c88..9a5145f1c 100644
--- a/internal/sidechannel/registry.go
+++ b/internal/sidechannel/registry.go
@@ -1,9 +1,11 @@
package sidechannel
import (
+ "errors"
"fmt"
"net"
"sync"
+ "time"
)
// sidechannelID is the type of ID used to differeniate sidechannel connections
@@ -23,7 +25,8 @@ type Registry struct {
type Waiter struct {
id sidechannelID
registry *Registry
- errC chan error
+ err error
+ done chan struct{}
accept chan net.Conn
callback func(*ClientConn) error
}
@@ -47,13 +50,16 @@ func (s *Registry) Register(callback func(*ClientConn) error) *Waiter {
waiter := &Waiter{
id: s.nextID,
registry: s,
- errC: make(chan error),
+ done: make(chan struct{}),
accept: make(chan net.Conn),
callback: callback,
}
s.nextID++
- go waiter.run()
+ go func() {
+ defer close(waiter.done)
+ waiter.err = waiter.run()
+ }()
s.waiters[waiter.id] = waiter
return waiter
}
@@ -99,23 +105,34 @@ func (s *Registry) waiting() int {
return len(s.waiters)
}
-func (w *Waiter) run() {
- defer close(w.errC)
+var ErrCallbackDidNotRun = errors.New("sidechannel: callback de-registered without having run")
+
+func (w *Waiter) run() error {
+ conn := <-w.accept
+ if conn == nil {
+ return ErrCallbackDidNotRun
+ }
+ defer conn.Close()
- if conn := <-w.accept; conn != nil {
- defer conn.Close()
- w.errC <- w.callback(newClientConn(conn))
+ if err := conn.SetWriteDeadline(time.Now().Add(sidechannelTimeout)); err != nil {
+ return err
+ }
+ if _, err := conn.Write([]byte("ok")); err != nil {
+ return err
}
+ if err := conn.SetWriteDeadline(time.Time{}); err != nil {
+ return err
+ }
+
+ return w.callback(newClientConn(conn))
}
-// Close cleans the waiter, removes it from the registry. If the callback is
-// executing, this method is blocked until the callback is done.
+// Close de-registers the callback. If the callback got triggered,
+// Close() will return its error return value. If the callback has not
+// started by the time Close() is called, Close() returns
+// ErrCallbackDidNotRun.
func (w *Waiter) Close() error {
w.registry.removeWaiter(w)
- return <-w.errC
-}
-
-// Wait waits until either the callback is executed, or the waiter is closed
-func (w *Waiter) Wait() error {
- return <-w.errC
+ <-w.done
+ return w.err
}
diff --git a/internal/sidechannel/registry_test.go b/internal/sidechannel/registry_test.go
index 4db4e6ac5..032534a81 100644
--- a/internal/sidechannel/registry_test.go
+++ b/internal/sidechannel/registry_test.go
@@ -33,17 +33,18 @@ func TestRegistry(t *testing.T) {
close(triggerCallback)
- require.NoError(t, waiter.Wait())
+ require.NoError(t, waiter.Close())
requireConnClosed(t, client)
})
- t.Run("pull connections successfully", func(t *testing.T) {
+ t.Run("receive connections successfully", func(t *testing.T) {
wg := sync.WaitGroup{}
- var servers []*ServerConn
+ servers := make([]net.Conn, N)
for i := 0; i < N; i++ {
client, server := socketPair(t)
- servers = append(servers, newServerConn(server))
+ servers[i] = server
+ defer server.Close()
wg.Add(1)
go func(i int) {
@@ -57,7 +58,7 @@ func TestRegistry(t *testing.T) {
defer waiter.Close()
require.NoError(t, registry.receive(waiter.id, client))
- require.NoError(t, waiter.Wait())
+ require.NoError(t, waiter.Close())
requireConnClosed(t, client)
wg.Done()
@@ -65,7 +66,14 @@ func TestRegistry(t *testing.T) {
}
for i := 0; i < N; i++ {
- out, err := io.ReadAll(servers[i])
+ // Read registry confirmation
+ buf := make([]byte, 2)
+ _, err := io.ReadFull(servers[i], buf)
+ require.NoError(t, err)
+ require.Equal(t, "ok", string(buf))
+
+ // Read data written by callback
+ out, err := io.ReadAll(newServerConn(servers[i]))
require.NoError(t, err)
require.Equal(t, strconv.Itoa(i), string(out))
}
@@ -74,7 +82,7 @@ func TestRegistry(t *testing.T) {
require.Equal(t, 0, registry.waiting())
})
- t.Run("push connection to non-existing ID", func(t *testing.T) {
+ t.Run("receive connection for non-existing ID", func(t *testing.T) {
client, _ := socketPair(t)
err := registry.receive(registry.nextID+1, client)
require.EqualError(t, err, "sidechannel registry: ID not registered")
@@ -83,7 +91,7 @@ func TestRegistry(t *testing.T) {
t.Run("pre-maturely close the waiter", func(t *testing.T) {
waiter := registry.Register(func(conn *ClientConn) error { panic("never execute") })
- require.NoError(t, waiter.Close())
+ require.Equal(t, ErrCallbackDidNotRun, waiter.Close())
require.Equal(t, 0, registry.waiting())
})
}
diff --git a/internal/sidechannel/sidechannel.go b/internal/sidechannel/sidechannel.go
index a2083fb04..4886603e7 100644
--- a/internal/sidechannel/sidechannel.go
+++ b/internal/sidechannel/sidechannel.go
@@ -4,6 +4,7 @@ import (
"context"
"encoding/binary"
"fmt"
+ "io"
"net"
"strconv"
"time"
@@ -64,6 +65,14 @@ func OpenSidechannel(ctx context.Context) (_ *ServerConn, err error) {
return nil, fmt.Errorf("sidechannel: write stream id: %w", err)
}
+ buf := make([]byte, 2)
+ if _, err := io.ReadFull(stream, buf); err != nil {
+ return nil, fmt.Errorf("sidechannel: receive confirmation: %w", err)
+ }
+ if string(buf) != "ok" {
+ return nil, fmt.Errorf("sidechannel: expected ok, got %q", buf)
+ }
+
if err := stream.SetDeadline(time.Time{}); err != nil {
return nil, err
}
diff --git a/internal/sidechannel/sidechannel_test.go b/internal/sidechannel/sidechannel_test.go
index f1d625504..3ad925e61 100644
--- a/internal/sidechannel/sidechannel_test.go
+++ b/internal/sidechannel/sidechannel_test.go
@@ -197,7 +197,7 @@ func call(ctx context.Context, conn *grpc.ClientConn, registry *Registry, handle
return err
}
- if err := waiter.Wait(); err != nil {
+ if err := waiter.Close(); err != nil {
return err
}