Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorQuang-Minh Nguyen <qmnguyen@gitlab.com>2021-07-26 11:51:04 +0300
committerQuang-Minh Nguyen <qmnguyen@gitlab.com>2021-07-29 14:15:06 +0300
commit930152aafbbb9860c9cb232f177152cc019e82ae (patch)
tree9de7e347775f96eb2fa69545bf60fec913f1334e
parentbf55b6fc1a0c9bba1e9d1af29cf4ccab168d30e0 (diff)
Re-implement Stream RPC server handling using sidechannel registryqmnguyen0711/re-implement-streamrpc
In https://gitlab.com/gitlab-com/gl-infra/scalability/-/issues/1051#note_634147375, we came up with a new alternative design for Stream RPC. In general, we'd add additional RPCs (PostUploadPackStream, PackObjectsHookStream, SSHUploadPackStream) to return a streaming token to the client once the verifications are done and they are ready to launch git to begin streaming the response. The RPC handler would wait for StreamRPC connection registry. When the client receives the streaming token. It then dials to the listening port of Gitaly server. Gitaly server detects tthe connection, pushes the connection to StreamRPC connection registry. As a result, the stream RPC handler is unlocked, and continue to stream data on pulled raw TCP connection.
-rw-r--r--internal/streamrpcs/handshaker.go63
-rw-r--r--internal/streamrpcs/registry.go157
-rw-r--r--internal/streamrpcs/registry_test.go166
-rw-r--r--internal/streamrpcs/rpc.go218
-rw-r--r--internal/streamrpcs/rpc_test.go374
-rw-r--r--internal/streamrpcs/testdata/test.pb.go166
-rw-r--r--internal/streamrpcs/testdata/test.proto16
-rw-r--r--internal/streamrpcs/testdata/test_grpc.pb.go134
-rw-r--r--proto/go/gitalypb/protolist.go1
-rw-r--r--proto/go/gitalypb/streamrpc.pb.go155
-rw-r--r--proto/streamrpc.proto12
-rw-r--r--ruby/proto/gitaly/streamrpc_pb.rb18
12 files changed, 1480 insertions, 0 deletions
diff --git a/internal/streamrpcs/handshaker.go b/internal/streamrpcs/handshaker.go
new file mode 100644
index 000000000..a545f2739
--- /dev/null
+++ b/internal/streamrpcs/handshaker.go
@@ -0,0 +1,63 @@
+package streamrpcs
+
+import (
+ "io"
+ "net"
+ "time"
+
+ "github.com/sirupsen/logrus"
+ "google.golang.org/grpc/credentials"
+)
+
+// The magic bytes used for classification by listenmux
+var magicBytes = []byte("streamrpc00")
+
+// ServerHandshaker implements the server side handshake of the multiplexed connection.
+type ServerHandshaker struct {
+ logger logrus.FieldLogger
+}
+
+// NewServerHandshaker returns an implementation of streamrpc server
+// handshaker. The provided TransportCredentials are handshaked prior to
+// initializing the multiplexing session. This handshaker Gitaly's unary server
+// interceptors into the interceptor chain of input StreamRPC server.
+func NewServerHandshaker(logger logrus.FieldLogger) *ServerHandshaker {
+ return &ServerHandshaker{
+ logger: logger,
+ }
+}
+
+// Magic is used by listenmux to retrieve the magic string for
+// streamrpc connections.
+func (s *ServerHandshaker) Magic() string { return string(magicBytes) }
+
+// Handshake "steals" the request from Gitaly's main gRPC server during
+// connection handshaking phase. Listenmux depends on the first 11-byte magic
+// bytes sent by the client, and invoke StreamRPC handshaker accordingly. The
+// request is then handled by stream RPC server, and skipped by Gitaly gRPC
+// server.
+func (s *ServerHandshaker) Handshake(conn net.Conn, authInfo credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error) {
+ if err := conn.SetDeadline(time.Time{}); err != nil {
+ return nil, nil, err
+ }
+
+ token := make([]byte, TokenSizeBytes)
+ _, err := io.ReadFull(conn, token)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ if err = globalRegistry.Push(string(token), conn); err != nil {
+ return nil, nil, err
+ }
+
+ // At this point, the connection is already closed. If the
+ // TransportCredentials continues its code path, gRPC constructs a HTTP2
+ // server transport to handle the connection. Eventually, it fails and logs
+ // several warnings and errors even though the stream RPC call is
+ // successful.
+ // Fortunately, gRPC has credentials.ErrConnDispatched, indicating that the
+ // connection is already dispatched out of gRPC. gRPC should leave it alone
+ // and exit in peace.
+ return nil, nil, credentials.ErrConnDispatched
+}
diff --git a/internal/streamrpcs/registry.go b/internal/streamrpcs/registry.go
new file mode 100644
index 000000000..c5af2942e
--- /dev/null
+++ b/internal/streamrpcs/registry.go
@@ -0,0 +1,157 @@
+package streamrpcs
+
+import (
+ "fmt"
+ "net"
+ "sync"
+ "time"
+
+ "github.com/google/uuid"
+)
+
+// globalRegistry is a single per-process registry. The RPC handlers are
+// expected to use this registry to wait for incoming sidechannel connection
+var globalRegistry *Registry
+
+func init() {
+ globalRegistry = NewRegistry()
+}
+
+// Registry manages StreamRPC sidechannel connections. It allows the RPC
+// handlers to wait for the secondary incoming connection made by the client.
+// In details:
+// - We'd add additional RPCs (PostUploadPackStream, PackObjectsHookStream,
+// SSHUploadPackStream) to return a streaming token to the client once the
+// verifications are done and they are ready to launch git to begin streaming
+// the response.
+// - The RPC handler waits for StreamRPC connection registry
+// - Client receives the streaming token. It then dials to the listening port of Gitaly server.
+// - Listenmux handles the connection. It validates the connection, and pushes
+// to StreamRPC connection registry; then exists without error.
+// - The RPC handler receives the connection once it's accepted, and passes it to git.
+// - Git runs and streams directly to the client over the TCP.
+// - Once git returns, the gRPC handler returns normally with success/error to the client.
+type Registry struct {
+ waiters map[string]*Waiter
+ stopped bool
+ mu sync.Mutex
+}
+
+// Waiter lets the caller waits until a connection with matched token is pushed
+// into the registry.
+type Waiter struct {
+ Token string
+ err error
+ c chan net.Conn
+ timeout *time.Timer
+}
+
+// NewRegistry returns a new Registry instance
+func NewRegistry() *Registry {
+ return &Registry{
+ waiters: make(map[string]*Waiter),
+ }
+}
+
+// Register registers the caller into the waiting list. The caller must provide
+// a deadline for this operation. The caller receives a Waiter struct with
+// associated token. The caller is expected to be blocked by waiter.Wait().
+// After the connection arrives, or the deadline exceeds, the waiter struct is
+// removed from the registry automatically and the caller is unblocked.
+func (s *Registry) Register(deadline time.Time) (*Waiter, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if s.stopped {
+ return nil, fmt.Errorf("stream rpc registry: register already stopped")
+ }
+
+ waiter := &Waiter{
+ Token: s.generateToken(),
+ c: make(chan net.Conn, 1),
+ }
+ waiter.timeout = time.AfterFunc(time.Until(deadline), func() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.removeWaiter(waiter, fmt.Errorf("stream rpc registry: timeout exceeds"))
+ })
+
+ s.waiters[waiter.Token] = waiter
+ return waiter, nil
+}
+
+// Push pushes a connection with an pre-registered token into the registry. The
+// caller is unlocked immediately, and the waiter is removed from the registry.
+// If there isn't any waiting caller, this function still exists, the caller
+// can pulls the connection later through waiter struct.
+func (s *Registry) Push(token string, conn net.Conn) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if s.stopped {
+ return fmt.Errorf("stream rpc registry: register already stopped")
+ }
+
+ var waiter *Waiter
+ var exist bool
+ if waiter, exist = s.waiters[token]; !exist {
+ return fmt.Errorf("stream rpc registry: connection not registered")
+ }
+
+ waiter.c <- conn
+ s.removeWaiter(waiter, nil)
+
+ return nil
+}
+
+// Stop immedicately removes all waiters from the registry and prevent any
+// Register/Push operations in the future.
+func (s *Registry) Stop() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.stopped = true
+ for _, waiter := range s.waiters {
+ s.removeWaiter(waiter, fmt.Errorf("stream rpc registry: register already stopped"))
+ }
+}
+
+// Waiting returns the number of recent waiters the register is managing
+func (s *Registry) Waiting() int {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ return len(s.waiters)
+}
+
+func (s *Registry) removeWaiter(waiter *Waiter, err error) {
+ if err != nil {
+ waiter.err = err
+ }
+ waiter.timeout.Stop()
+ close(waiter.c)
+ delete(s.waiters, waiter.Token)
+}
+
+// TokenSizeBytes indicates the size of stringify token. In this case, UUID
+// consists of 32 characters and 4 dashes
+const TokenSizeBytes = 36
+
+// generateToken generates a unique token to be used as the hash key for
+// waiter. UUID is a good choice to generate a random unique token. It's size
+// is deterministic, well randomlized, and fairly distributed.
+func (s *Registry) generateToken() string {
+ for {
+ token := uuid.New().String()
+ if _, exist := s.waiters[token]; !exist {
+ return token
+ }
+ }
+}
+
+// Wait blocks the caller until a matched connection arrives, or waiter
+// deadline exceeds, or the registry stops.
+func (waiter *Waiter) Wait() (net.Conn, error) {
+ return <-waiter.c, waiter.err
+}
diff --git a/internal/streamrpcs/registry_test.go b/internal/streamrpcs/registry_test.go
new file mode 100644
index 000000000..2fc2aa0cc
--- /dev/null
+++ b/internal/streamrpcs/registry_test.go
@@ -0,0 +1,166 @@
+package streamrpcs
+
+import (
+ "io/ioutil"
+ "net"
+ "os"
+ "strconv"
+ "sync"
+ "syscall"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestRegistry(t *testing.T) {
+ registry := NewRegistry()
+
+ t.Run("pull connections successfully", func(t *testing.T) {
+ wg := sync.WaitGroup{}
+ var readers []net.Conn
+
+ for i := 0; i < 10; i++ {
+ reader, writer := socketPair(t)
+ readers = append(readers, reader)
+
+ waiter, err := registry.Register(time.Now().Add(100 * time.Millisecond))
+ require.NoError(t, err)
+
+ wg.Add(1)
+ go func(id int) {
+ conn, err := waiter.Wait()
+ require.NoError(t, err)
+ require.NotNil(t, conn)
+
+ _, err = conn.Write([]byte(strconv.Itoa(id)))
+ require.NoError(t, err)
+
+ conn.Close()
+ wg.Done()
+ }(i)
+
+ go func() {
+ err := registry.Push(waiter.Token, writer)
+ require.NoError(t, err)
+ }()
+ }
+
+ wg.Wait()
+ for i := 0; i < 10; i++ {
+ out, err := ioutil.ReadAll(readers[i])
+ require.NoError(t, err)
+ require.Equal(t, string(out), strconv.Itoa(i))
+ }
+
+ require.Equal(t, registry.Waiting(), 0)
+ })
+
+ t.Run("timeout while pulling connections", func(t *testing.T) {
+ waiter, err := registry.Register(time.Now().Add(1 * time.Millisecond))
+ require.NoError(t, err)
+ require.Equal(t, registry.Waiting(), 1)
+
+ conn, err := waiter.Wait()
+ require.Nil(t, conn)
+ require.EqualError(t, err, "stream rpc registry: timeout exceeds")
+
+ require.Equal(t, registry.Waiting(), 0)
+ })
+
+ t.Run("push without having a waiting caller", func(t *testing.T) {
+ waiter, err := registry.Register(time.Now().Add(1 * time.Millisecond))
+ require.NoError(t, err)
+
+ _, writer := socketPair(t)
+ err = registry.Push(waiter.Token, writer)
+ require.NoError(t, err)
+ require.Equal(t, registry.Waiting(), 0)
+
+ conn, err := waiter.Wait()
+ require.NoError(t, err)
+ require.Equal(t, conn, writer)
+ })
+
+ t.Run("push connection to non-existing connection", func(t *testing.T) {
+ waiter, err := registry.Register(time.Now().Add(1 * time.Millisecond))
+ require.NoError(t, err)
+
+ _, writer := socketPair(t)
+ err = registry.Push("not exsting token", writer)
+ require.EqualError(t, err, "stream rpc registry: connection not registered")
+ require.Equal(t, registry.Waiting(), 1)
+
+ err = registry.Push(waiter.Token, writer)
+ require.NoError(t, err)
+ require.Equal(t, registry.Waiting(), 0)
+ })
+
+ t.Run("pull connection twice", func(t *testing.T) {
+ waiter, err := registry.Register(time.Now().Add(1 * time.Millisecond))
+ require.NoError(t, err)
+
+ _, writer := socketPair(t)
+
+ err = registry.Push(waiter.Token, writer)
+ require.NoError(t, err)
+
+ conn, err := waiter.Wait()
+ require.NotNil(t, conn)
+ require.NoError(t, err)
+
+ conn, err = waiter.Wait()
+ require.Nil(t, conn) // Not blocking. Channel already closed
+ require.NoError(t, err)
+
+ require.Equal(t, registry.Waiting(), 0)
+ })
+
+ t.Run("stop registry", func(t *testing.T) {
+ errors := make(chan error)
+
+ for i := 0; i < 10; i++ {
+ waiter, err := registry.Register(time.Now().Add(100 * time.Millisecond))
+ require.NoError(t, err)
+
+ go func() {
+ _, err := waiter.Wait()
+ errors <- err
+ }()
+ }
+ require.Equal(t, registry.Waiting(), 10)
+
+ registry.Stop()
+ require.Equal(t, registry.Waiting(), 0)
+
+ for i := 0; i < 10; i++ {
+ require.EqualError(t, <-errors, "stream rpc registry: register already stopped")
+ }
+
+ _, err := registry.Register(time.Now().Add(100 * time.Millisecond))
+ require.EqualError(t, err, "stream rpc registry: register already stopped")
+
+ _, writer := socketPair(t)
+ err = registry.Push("token", writer)
+ require.EqualError(t, err, "stream rpc registry: register already stopped")
+ })
+}
+
+func socketPair(t *testing.T) (net.Conn, net.Conn) {
+ t.Helper()
+
+ fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
+ require.NoError(t, err)
+
+ conns := make([]net.Conn, 2)
+ for i, fd := range fds[:] {
+ f := os.NewFile(uintptr(fd), "socket pair")
+ c, err := net.FileConn(f)
+ require.NoError(t, err)
+ require.NoError(t, f.Close())
+ t.Cleanup(func() { c.Close() })
+ conns[i] = c
+ }
+
+ return conns[0], conns[1]
+}
diff --git a/internal/streamrpcs/rpc.go b/internal/streamrpcs/rpc.go
new file mode 100644
index 000000000..0efb08254
--- /dev/null
+++ b/internal/streamrpcs/rpc.go
@@ -0,0 +1,218 @@
+package streamrpcs
+
+import (
+ "context"
+ "crypto/tls"
+ "fmt"
+ "io"
+ "net"
+ "time"
+
+ "gitlab.com/gitlab-org/gitaly/v14/internal/bootstrap/starter"
+ "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/metadata"
+)
+
+const sidechannelIdentityKey = "Gitaly-SideChannel-Identity"
+const defaultSidechannelTimeout = 10 * time.Second
+const maxClientRetryAttempts = 3
+
+// DialFunc is a method making a secondary sidechannel connection to a
+// token-proteted address
+type DialFunc func(time.Time, string, string) (net.Conn, error)
+
+func identity(ctx context.Context) string {
+ md, ok := metadata.FromIncomingContext(ctx)
+ if !ok {
+ return ""
+ }
+ values := md.Get(sidechannelIdentityKey)
+ if len(values) == 0 {
+ return ""
+ }
+ return values[0]
+}
+
+// DialNet lets Call initiate unencrypted connections. They tend to be used
+// with Gitaly's listenmux multiplexer only. After the connection is
+// established, streamrpc's 11-byte magic bytes are written into the wire.
+// Listemmux peeks into these magic bytes and redirects the request to
+// StreamRPC server.
+// Please visit internal/listenmux/mux.go for more information
+func DialNet() DialFunc {
+ return func(deadline time.Time, address string, token string) (net.Conn, error) {
+ endpoint, err := starter.ParseEndpoint(address)
+ if err != nil {
+ return nil, err
+ }
+
+ dialer := &net.Dialer{Deadline: deadline}
+ conn, err := dialer.Dial(endpoint.Name, endpoint.Addr)
+ if err != nil {
+ return nil, err
+ }
+
+ if err = conn.SetDeadline(deadline); err != nil {
+ return nil, err
+ }
+ // Write the magic bytes on the connection so the server knows we're
+ // about to initiate a multiplexing session.
+ if _, err := conn.Write(magicBytes); err != nil {
+ return nil, fmt.Errorf("streamrpc client: write magic bytes: %w", err)
+ }
+
+ // Write the stream token into the wire. This token lets the server
+ // matches waiting RPC handler
+ if _, err := conn.Write([]byte(token)); err != nil {
+ return nil, fmt.Errorf("streamrpc client: write stream token: %w", err)
+ }
+
+ // Reset deadline of tls connection for later stages
+ if err = conn.SetDeadline(time.Time{}); err != nil {
+ return nil, err
+ }
+
+ return conn, nil
+ }
+}
+
+// DialTLS lets Call initiate TLS connections. Similar to DialNet, the
+// connections are used for listenmux multiplexer. There are 3 steps involving:
+// - TCP handshake
+// - TLS handshake
+// - Write streamrpc magic bytes
+func DialTLS(cfg *tls.Config) DialFunc {
+ return func(deadline time.Time, address string, token string) (net.Conn, error) {
+ dialer := &net.Dialer{Deadline: deadline}
+ tlsConn, err := tls.DialWithDialer(dialer, "tcp", address, cfg)
+ if err != nil {
+ return nil, err
+ }
+
+ err = tlsConn.SetDeadline(deadline)
+ if err != nil {
+ return nil, err
+ }
+ // Write the magic bytes on the connection so the server knows we're
+ // about to initiate a multiplexing session.
+ if _, err := tlsConn.Write(magicBytes); err != nil {
+ return nil, fmt.Errorf("streamrpc client: write backchannel magic bytes: %w", err)
+ }
+
+ // Write the stream token into the wire. This token lets the server
+ // matches waiting RPC handler
+ if _, err := tlsConn.Write([]byte(token)); err != nil {
+ return nil, fmt.Errorf("streamrpc client: write stream token: %w", err)
+ }
+
+ // Reset deadline of tls connection for later stages
+ if err = tlsConn.SetDeadline(time.Time{}); err != nil {
+ return nil, err
+ }
+
+ return tlsConn, nil
+ }
+}
+
+// Call enables the client to make gRPC calls to the server. While handling the
+// RPC call, a sidechannel TCP connection is established between client and
+// server over the same listening gRPC port. This allows the clients and
+// servers exchange information over the raw TCP connection without the
+// overhead of gRPC. This consists of some steps:
+//
+// - Client preares the client stream via `handshake` func. It may send as many
+// requests to the server as it wants. A typical use case is to send repository
+// information for validation beforehand.
+// - After the `handshake` func exits, this method waits for StreamToken
+// response from the server.
+// - This method establishes a sidechannel TCP connection to gRPC server. This
+// can be done thanks to listenmux multiplexer.
+// - The raw connection is given back to client `handler` func for further data
+// exchange.
+// - This method waits until the stream is closed.
+//
+// As we are making two sequential calls with sub small steps, a lot of things
+// may happen. One notable case is that the secondary dial may fail during
+// deployment when a new Gitaly process is spawn. Therefore we should retry
+// multiple times if any step fails.
+func Call(ctx context.Context, addr string, handshake func(context.Context) (grpc.ClientStream, error), dial DialFunc, handler func(net.Conn) error) (finalError error) {
+ doCall := func() (err error) {
+ var stream grpc.ClientStream
+ var streamToken gitalypb.StreamToken
+
+ // Make the first call(s). Let the caller preare the request data.
+ ctx = metadata.AppendToOutgoingContext(ctx, sidechannelIdentityKey, addr)
+ if stream, err = handshake(ctx); err != nil {
+ return err
+ }
+
+ // We don't need to send any further information
+ if err = stream.CloseSend(); err != nil {
+ return err
+ }
+
+ // Wait for stream token from the server
+ if err = stream.RecvMsg(&streamToken); err != nil {
+ return err
+ }
+
+ // Make the secondary call to the same address, with the token received
+ // from the server
+ deadline, ok := ctx.Deadline()
+ if !ok {
+ deadline = time.Now().Add(defaultSidechannelTimeout)
+ }
+
+ conn, err := dial(deadline, addr, streamToken.Token)
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ // Delegate the raw connection to the caller
+ if err = handler(conn); err != nil {
+ return err
+ }
+
+ // The server should return, and close the streaming RPC.
+ err = stream.RecvMsg(&streamToken)
+ if err == nil {
+ return fmt.Errorf("streamrpc client: expected server stream closed")
+ } else if err == io.EOF {
+ return nil
+ } else {
+ return err
+ }
+ }
+ for i := 0; i < maxClientRetryAttempts; i++ {
+ finalError = doCall()
+ if finalError == nil {
+ break
+ }
+ }
+ return finalError
+}
+
+// AcceptConnection blocks the RPC handlers until the sidechannel TCP connection arrives.
+func AcceptConnection(ctx context.Context, stream grpc.ServerStream) (net.Conn, error) {
+ deadline, ok := ctx.Deadline()
+ if !ok {
+ deadline = time.Now().Add(defaultSidechannelTimeout)
+ }
+
+ waiter, err := globalRegistry.Register(deadline)
+ if err != nil {
+ return nil, err
+ }
+
+ err = stream.SendMsg(&gitalypb.StreamToken{
+ Cookie: identity(ctx),
+ Token: waiter.Token,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return waiter.Wait()
+}
diff --git a/internal/streamrpcs/rpc_test.go b/internal/streamrpcs/rpc_test.go
new file mode 100644
index 000000000..faab72643
--- /dev/null
+++ b/internal/streamrpcs/rpc_test.go
@@ -0,0 +1,374 @@
+package streamrpcs
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "io"
+ "io/ioutil"
+ "math/rand"
+ "net"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ gitalyauth "gitlab.com/gitlab-org/gitaly/v14/auth"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/bootstrap/starter"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/listenmux"
+ gitalylog "gitlab.com/gitlab-org/gitaly/v14/internal/log"
+ testpb "gitlab.com/gitlab-org/gitaly/v14/internal/streamrpcs/testdata"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/credentials/insecure"
+ "google.golang.org/grpc/metadata"
+ "google.golang.org/grpc/status"
+)
+
+func TestCall(t *testing.T) {
+ const blobSize = 1024 * 1024
+
+ var receivedField string
+
+ in := make([]byte, blobSize)
+ _, err := rand.Read(in)
+ require.NoError(t, err)
+
+ var out []byte
+ require.NotEqual(t, in, out)
+
+ client, addr := startServer(
+ t,
+ func(stream testpb.Test_StreamServer) error {
+ request, err := stream.Recv()
+ require.NoError(t, err)
+
+ receivedField = request.StringField
+ conn, err := AcceptConnection(stream.Context(), stream)
+ if err != nil {
+ return err
+ }
+
+ if _, err = io.CopyN(conn, conn, blobSize); err != nil {
+ return err
+ }
+
+ return conn.Close()
+ },
+ )
+
+ ctx := context.Background()
+ require.NoError(t, Call(
+ ctx, addr, handshake(client), DialNet(),
+ func(conn net.Conn) error {
+ errC := make(chan error, 1)
+ go func() {
+ var err error
+ out, err = ioutil.ReadAll(conn)
+ errC <- err
+ }()
+
+ _, err = io.Copy(conn, bytes.NewReader(in))
+ require.NoError(t, err)
+ require.NoError(t, <-errC)
+
+ return nil
+ },
+ ))
+
+ require.Equal(t, "hello world", receivedField, "request propagates")
+ require.Equal(t, in, out, "byte stream works")
+}
+
+func TestCall_serverError(t *testing.T) {
+ client, addr := startServer(
+ t,
+ func(stream testpb.Test_StreamServer) error {
+ _, err := stream.Recv()
+ if err != nil {
+ return err
+ }
+
+ conn, err := AcceptConnection(stream.Context(), stream)
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ return errors.New("this is the server error")
+ },
+ )
+ ctx := context.Background()
+ require.EqualError(t, Call(
+ ctx, addr, handshake(client), DialNet(),
+ func(conn net.Conn) error { return nil },
+ ), "rpc error: code = Unknown desc = this is the server error")
+}
+
+func TestCall_serverMiddleware(t *testing.T) {
+ const (
+ testKey = "testkey"
+ testValue = "testvalue"
+ testMethod = "/test.streamrpc.Test/Stream"
+ )
+
+ var (
+ middlewareMethod string
+ receivedValues []string
+ )
+
+ interceptorDone := make(chan struct{})
+
+ client, addr := startServer(
+ t,
+ func(stream testpb.Test_StreamServer) error {
+ _, err := stream.Recv()
+ if err != nil {
+ return err
+ }
+
+ conn, err := AcceptConnection(stream.Context(), stream)
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ return nil
+ },
+ grpc.StreamInterceptor(func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+ defer close(interceptorDone)
+ middlewareMethod = info.FullMethod
+ if md, ok := metadata.FromIncomingContext(ss.Context()); ok {
+ receivedValues = md[testKey]
+ }
+ return handler(srv, ss)
+ }),
+ )
+
+ ctx := metadata.AppendToOutgoingContext(context.Background(), testKey, testValue)
+ require.NoError(t, Call(
+ ctx, addr, handshake(client), DialNet(),
+ func(conn net.Conn) error { return nil },
+ ))
+
+ <-interceptorDone
+ require.Equal(t, testMethod, middlewareMethod, "server middleware sees correct method")
+ require.Equal(t, []string{testValue}, receivedValues, "server middleware sees context metadata")
+}
+
+func TestCall_serverMiddlewareReject(t *testing.T) {
+ client, addr := startServer(
+ t,
+ func(stream testpb.Test_StreamServer) error {
+ panic("never reached")
+ },
+ grpc.StreamInterceptor(func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+ return status.Errorf(codes.PermissionDenied, "permission denied")
+ }),
+ )
+
+ ctx := context.Background()
+ testhelper.RequireGrpcError(t, Call(
+ ctx, addr, handshake(client), DialNet(),
+ func(conn net.Conn) error { return nil },
+ ), codes.PermissionDenied)
+}
+
+func TestCall_credentials(t *testing.T) {
+ var receivedValue string
+ interceptorDone := make(chan struct{})
+
+ _, addr := startServer(
+ t,
+ func(stream testpb.Test_StreamServer) error {
+ defer close(interceptorDone)
+ ctx := stream.Context()
+
+ if md, ok := metadata.FromIncomingContext(ctx); ok {
+ receivedValue = md.Get("authorization")[0]
+ }
+
+ _, err := stream.Recv()
+ if err != nil {
+ return err
+ }
+
+ conn, err := AcceptConnection(ctx, stream)
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ return nil
+ },
+ )
+
+ endpoint, _ := starter.ParseEndpoint(addr)
+ conn, err := grpc.Dial(
+ endpoint.Addr,
+ grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2("foobar")),
+ grpc.WithInsecure(),
+ )
+ require.NoError(t, err)
+
+ client := testpb.NewTestClient(conn)
+
+ ctx := context.Background()
+ require.NoError(t, Call(
+ ctx, addr, handshake(client), DialNet(),
+ func(conn net.Conn) error { return nil },
+ ), codes.PermissionDenied)
+
+ <-interceptorDone
+ require.Contains(t, receivedValue, "Bearer v2.")
+}
+
+func TestCall_clientRetries(t *testing.T) {
+ t.Run("error before receiving the first request", func(t *testing.T) {
+ failure := 2
+ client, addr := startServer(
+ t,
+ func(stream testpb.Test_StreamServer) error {
+ if failure > 0 {
+ failure--
+ return errors.New("server rejected")
+ }
+ _, err := stream.Recv()
+ if err != nil {
+ return err
+ }
+
+ conn, err := AcceptConnection(stream.Context(), stream)
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ return nil
+ },
+ )
+ ctx := context.Background()
+ require.NoError(t, Call(
+ ctx, addr, handshake(client), DialNet(),
+ func(conn net.Conn) error { return nil },
+ ))
+ require.Zero(t, failure)
+ })
+
+ t.Run("error before waiting for the connection", func(t *testing.T) {
+ failure := 2
+ client, addr := startServer(
+ t,
+ func(stream testpb.Test_StreamServer) error {
+ _, err := stream.Recv()
+ if err != nil {
+ return err
+ }
+
+ if failure > 0 {
+ failure--
+ return errors.New("server closed unexpected")
+ }
+
+ conn, err := AcceptConnection(stream.Context(), stream)
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ return nil
+ },
+ )
+ ctx := context.Background()
+ require.NoError(t, Call(
+ ctx, addr, handshake(client), DialNet(),
+ func(conn net.Conn) error { return nil },
+ ))
+ require.Zero(t, failure)
+ })
+
+ t.Run("error after connection establishment", func(t *testing.T) {
+ failure := 2
+ client, addr := startServer(
+ t,
+ func(stream testpb.Test_StreamServer) error {
+ _, err := stream.Recv()
+ if err != nil {
+ return err
+ }
+
+ conn, err := AcceptConnection(stream.Context(), stream)
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ if failure > 0 {
+ failure--
+ return errors.New("server closed unexpected")
+ }
+ return nil
+ },
+ )
+ ctx := context.Background()
+ require.NoError(t, Call(
+ ctx, addr, handshake(client), DialNet(),
+ func(conn net.Conn) error { return nil },
+ ))
+ require.Zero(t, failure)
+ })
+}
+
+func startServer(t *testing.T, th testHandler, opts ...grpc.ServerOption) (testpb.TestClient, string) {
+ t.Helper()
+
+ transportCredentials := insecure.NewCredentials()
+ lm := listenmux.New(transportCredentials)
+ lm.Register(NewServerHandshaker(
+ gitalylog.Default(),
+ ))
+ opts = append(opts, grpc.Creds(lm))
+
+ s := grpc.NewServer(opts...)
+ t.Cleanup(func() { s.Stop() })
+
+ handler := &server{testHandler: th}
+ testpb.RegisterTestServer(s, handler)
+
+ lis, err := net.Listen("tcp", "localhost:0")
+ require.NoError(t, err)
+ t.Cleanup(func() { lis.Close() })
+
+ go func() { s.Serve(lis) }()
+
+ conn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure())
+ require.NoError(t, err)
+ t.Cleanup(func() { conn.Close() })
+
+ client := testpb.NewTestClient(conn)
+
+ return client, "tcp://" + lis.Addr().String()
+}
+
+func handshake(client testpb.TestClient) func(context.Context) (grpc.ClientStream, error) {
+ return func(ctx context.Context) (grpc.ClientStream, error) {
+ stream, err := client.Stream(ctx)
+ if err != nil {
+ return stream, err
+ }
+ if err = stream.Send(&testpb.StreamRequest{StringField: "hello world"}); err != nil {
+ return stream, err
+ }
+ return stream, nil
+ }
+}
+
+type testHandler func(stream testpb.Test_StreamServer) error
+
+type server struct {
+ testpb.UnimplementedTestServer
+ testHandler
+}
+
+func (s *server) Stream(stream testpb.Test_StreamServer) error {
+ return s.testHandler(stream)
+}
diff --git a/internal/streamrpcs/testdata/test.pb.go b/internal/streamrpcs/testdata/test.pb.go
new file mode 100644
index 000000000..0fe144f48
--- /dev/null
+++ b/internal/streamrpcs/testdata/test.pb.go
@@ -0,0 +1,166 @@
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// versions:
+// protoc-gen-go v1.26.0
+// protoc v3.17.3
+// source: streamrpc/testdata/test.proto
+
+package testdata
+
+import (
+ gitalypb "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
+ protoreflect "google.golang.org/protobuf/reflect/protoreflect"
+ protoimpl "google.golang.org/protobuf/runtime/protoimpl"
+ reflect "reflect"
+ sync "sync"
+)
+
+const (
+ // Verify that this generated code is sufficiently up-to-date.
+ _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
+ // Verify that runtime/protoimpl is sufficiently up-to-date.
+ _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
+)
+
+type StreamRequest struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+
+ Fail bool `protobuf:"varint,1,opt,name=fail,proto3" json:"fail,omitempty"`
+ StringField string `protobuf:"bytes,2,opt,name=string_field,json=stringField,proto3" json:"string_field,omitempty"`
+}
+
+func (x *StreamRequest) Reset() {
+ *x = StreamRequest{}
+ if protoimpl.UnsafeEnabled {
+ mi := &file_streamrpc_testdata_test_proto_msgTypes[0]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+ }
+}
+
+func (x *StreamRequest) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*StreamRequest) ProtoMessage() {}
+
+func (x *StreamRequest) ProtoReflect() protoreflect.Message {
+ mi := &file_streamrpc_testdata_test_proto_msgTypes[0]
+ if protoimpl.UnsafeEnabled && x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use StreamRequest.ProtoReflect.Descriptor instead.
+func (*StreamRequest) Descriptor() ([]byte, []int) {
+ return file_streamrpc_testdata_test_proto_rawDescGZIP(), []int{0}
+}
+
+func (x *StreamRequest) GetFail() bool {
+ if x != nil {
+ return x.Fail
+ }
+ return false
+}
+
+func (x *StreamRequest) GetStringField() string {
+ if x != nil {
+ return x.StringField
+ }
+ return ""
+}
+
+var File_streamrpc_testdata_test_proto protoreflect.FileDescriptor
+
+var file_streamrpc_testdata_test_proto_rawDesc = []byte{
+ 0x0a, 0x1d, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x72, 0x70, 0x63, 0x2f, 0x74, 0x65, 0x73, 0x74,
+ 0x64, 0x61, 0x74, 0x61, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12,
+ 0x0e, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x72, 0x70, 0x63, 0x1a,
+ 0x0f, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x72, 0x70, 0x63, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f,
+ 0x22, 0x46, 0x0a, 0x0d, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
+ 0x74, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x61, 0x69, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52,
+ 0x04, 0x66, 0x61, 0x69, 0x6c, 0x12, 0x21, 0x0a, 0x0c, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x5f,
+ 0x66, 0x69, 0x65, 0x6c, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x73, 0x74, 0x72,
+ 0x69, 0x6e, 0x67, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x32, 0x4a, 0x0a, 0x04, 0x54, 0x65, 0x73, 0x74,
+ 0x12, 0x42, 0x0a, 0x06, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x1d, 0x2e, 0x74, 0x65, 0x73,
+ 0x74, 0x2e, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x72, 0x70, 0x63, 0x2e, 0x53, 0x74, 0x72, 0x65,
+ 0x61, 0x6d, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x67, 0x69, 0x74, 0x61,
+ 0x6c, 0x79, 0x2e, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0x00,
+ 0x28, 0x01, 0x30, 0x01, 0x42, 0x3e, 0x5a, 0x3c, 0x67, 0x69, 0x74, 0x6c, 0x61, 0x62, 0x2e, 0x63,
+ 0x6f, 0x6d, 0x2f, 0x67, 0x69, 0x74, 0x6c, 0x61, 0x62, 0x2d, 0x6f, 0x72, 0x67, 0x2f, 0x67, 0x69,
+ 0x74, 0x61, 0x6c, 0x79, 0x2f, 0x76, 0x31, 0x34, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61,
+ 0x6c, 0x2f, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x72, 0x70, 0x63, 0x2f, 0x74, 0x65, 0x73, 0x74,
+ 0x64, 0x61, 0x74, 0x61, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
+}
+
+var (
+ file_streamrpc_testdata_test_proto_rawDescOnce sync.Once
+ file_streamrpc_testdata_test_proto_rawDescData = file_streamrpc_testdata_test_proto_rawDesc
+)
+
+func file_streamrpc_testdata_test_proto_rawDescGZIP() []byte {
+ file_streamrpc_testdata_test_proto_rawDescOnce.Do(func() {
+ file_streamrpc_testdata_test_proto_rawDescData = protoimpl.X.CompressGZIP(file_streamrpc_testdata_test_proto_rawDescData)
+ })
+ return file_streamrpc_testdata_test_proto_rawDescData
+}
+
+var file_streamrpc_testdata_test_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
+var file_streamrpc_testdata_test_proto_goTypes = []interface{}{
+ (*StreamRequest)(nil), // 0: test.streamrpc.StreamRequest
+ (*gitalypb.StreamToken)(nil), // 1: gitaly.StreamToken
+}
+var file_streamrpc_testdata_test_proto_depIdxs = []int32{
+ 0, // 0: test.streamrpc.Test.Stream:input_type -> test.streamrpc.StreamRequest
+ 1, // 1: test.streamrpc.Test.Stream:output_type -> gitaly.StreamToken
+ 1, // [1:2] is the sub-list for method output_type
+ 0, // [0:1] is the sub-list for method input_type
+ 0, // [0:0] is the sub-list for extension type_name
+ 0, // [0:0] is the sub-list for extension extendee
+ 0, // [0:0] is the sub-list for field type_name
+}
+
+func init() { file_streamrpc_testdata_test_proto_init() }
+func file_streamrpc_testdata_test_proto_init() {
+ if File_streamrpc_testdata_test_proto != nil {
+ return
+ }
+ if !protoimpl.UnsafeEnabled {
+ file_streamrpc_testdata_test_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*StreamRequest); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ }
+ type x struct{}
+ out := protoimpl.TypeBuilder{
+ File: protoimpl.DescBuilder{
+ GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
+ RawDescriptor: file_streamrpc_testdata_test_proto_rawDesc,
+ NumEnums: 0,
+ NumMessages: 1,
+ NumExtensions: 0,
+ NumServices: 1,
+ },
+ GoTypes: file_streamrpc_testdata_test_proto_goTypes,
+ DependencyIndexes: file_streamrpc_testdata_test_proto_depIdxs,
+ MessageInfos: file_streamrpc_testdata_test_proto_msgTypes,
+ }.Build()
+ File_streamrpc_testdata_test_proto = out.File
+ file_streamrpc_testdata_test_proto_rawDesc = nil
+ file_streamrpc_testdata_test_proto_goTypes = nil
+ file_streamrpc_testdata_test_proto_depIdxs = nil
+}
diff --git a/internal/streamrpcs/testdata/test.proto b/internal/streamrpcs/testdata/test.proto
new file mode 100644
index 000000000..cba0a3afa
--- /dev/null
+++ b/internal/streamrpcs/testdata/test.proto
@@ -0,0 +1,16 @@
+syntax = "proto3";
+
+package test.streamrpc;
+
+option go_package = "gitlab.com/gitlab-org/gitaly/v14/internal/streamrpc/testdata";
+
+import "streamrpc.proto";
+
+service Test {
+ rpc Stream(stream StreamRequest) returns (stream gitaly.StreamToken) {}
+}
+
+message StreamRequest {
+ bool fail = 1;
+ string string_field = 2;
+}
diff --git a/internal/streamrpcs/testdata/test_grpc.pb.go b/internal/streamrpcs/testdata/test_grpc.pb.go
new file mode 100644
index 000000000..4f34e5230
--- /dev/null
+++ b/internal/streamrpcs/testdata/test_grpc.pb.go
@@ -0,0 +1,134 @@
+// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
+
+package testdata
+
+import (
+ context "context"
+ gitalypb "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
+ grpc "google.golang.org/grpc"
+ codes "google.golang.org/grpc/codes"
+ status "google.golang.org/grpc/status"
+)
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the grpc package it is being compiled against.
+// Requires gRPC-Go v1.32.0 or later.
+const _ = grpc.SupportPackageIsVersion7
+
+// TestClient is the client API for Test service.
+//
+// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
+type TestClient interface {
+ Stream(ctx context.Context, opts ...grpc.CallOption) (Test_StreamClient, error)
+}
+
+type testClient struct {
+ cc grpc.ClientConnInterface
+}
+
+func NewTestClient(cc grpc.ClientConnInterface) TestClient {
+ return &testClient{cc}
+}
+
+func (c *testClient) Stream(ctx context.Context, opts ...grpc.CallOption) (Test_StreamClient, error) {
+ stream, err := c.cc.NewStream(ctx, &Test_ServiceDesc.Streams[0], "/test.streamrpc.Test/Stream", opts...)
+ if err != nil {
+ return nil, err
+ }
+ x := &testStreamClient{stream}
+ return x, nil
+}
+
+type Test_StreamClient interface {
+ Send(*StreamRequest) error
+ Recv() (*gitalypb.StreamToken, error)
+ grpc.ClientStream
+}
+
+type testStreamClient struct {
+ grpc.ClientStream
+}
+
+func (x *testStreamClient) Send(m *StreamRequest) error {
+ return x.ClientStream.SendMsg(m)
+}
+
+func (x *testStreamClient) Recv() (*gitalypb.StreamToken, error) {
+ m := new(gitalypb.StreamToken)
+ if err := x.ClientStream.RecvMsg(m); err != nil {
+ return nil, err
+ }
+ return m, nil
+}
+
+// TestServer is the server API for Test service.
+// All implementations must embed UnimplementedTestServer
+// for forward compatibility
+type TestServer interface {
+ Stream(Test_StreamServer) error
+ mustEmbedUnimplementedTestServer()
+}
+
+// UnimplementedTestServer must be embedded to have forward compatible implementations.
+type UnimplementedTestServer struct {
+}
+
+func (UnimplementedTestServer) Stream(Test_StreamServer) error {
+ return status.Errorf(codes.Unimplemented, "method Stream not implemented")
+}
+func (UnimplementedTestServer) mustEmbedUnimplementedTestServer() {}
+
+// UnsafeTestServer may be embedded to opt out of forward compatibility for this service.
+// Use of this interface is not recommended, as added methods to TestServer will
+// result in compilation errors.
+type UnsafeTestServer interface {
+ mustEmbedUnimplementedTestServer()
+}
+
+func RegisterTestServer(s grpc.ServiceRegistrar, srv TestServer) {
+ s.RegisterService(&Test_ServiceDesc, srv)
+}
+
+func _Test_Stream_Handler(srv interface{}, stream grpc.ServerStream) error {
+ return srv.(TestServer).Stream(&testStreamServer{stream})
+}
+
+type Test_StreamServer interface {
+ Send(*gitalypb.StreamToken) error
+ Recv() (*StreamRequest, error)
+ grpc.ServerStream
+}
+
+type testStreamServer struct {
+ grpc.ServerStream
+}
+
+func (x *testStreamServer) Send(m *gitalypb.StreamToken) error {
+ return x.ServerStream.SendMsg(m)
+}
+
+func (x *testStreamServer) Recv() (*StreamRequest, error) {
+ m := new(StreamRequest)
+ if err := x.ServerStream.RecvMsg(m); err != nil {
+ return nil, err
+ }
+ return m, nil
+}
+
+// Test_ServiceDesc is the grpc.ServiceDesc for Test service.
+// It's only intended for direct use with grpc.RegisterService,
+// and not to be introspected or modified (even as a copy)
+var Test_ServiceDesc = grpc.ServiceDesc{
+ ServiceName: "test.streamrpc.Test",
+ HandlerType: (*TestServer)(nil),
+ Methods: []grpc.MethodDesc{},
+ Streams: []grpc.StreamDesc{
+ {
+ StreamName: "Stream",
+ Handler: _Test_Stream_Handler,
+ ServerStreams: true,
+ ClientStreams: true,
+ },
+ },
+ Metadata: "streamrpc/testdata/test.proto",
+}
diff --git a/proto/go/gitalypb/protolist.go b/proto/go/gitalypb/protolist.go
index a15916f70..4ee2a1be2 100644
--- a/proto/go/gitalypb/protolist.go
+++ b/proto/go/gitalypb/protolist.go
@@ -23,6 +23,7 @@ var GitalyProtos = []string{
"shared.proto",
"smarthttp.proto",
"ssh.proto",
+ "streamrpc.proto",
"transaction.proto",
"wiki.proto",
}
diff --git a/proto/go/gitalypb/streamrpc.pb.go b/proto/go/gitalypb/streamrpc.pb.go
new file mode 100644
index 000000000..c2a498e98
--- /dev/null
+++ b/proto/go/gitalypb/streamrpc.pb.go
@@ -0,0 +1,155 @@
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// versions:
+// protoc-gen-go v1.26.0
+// protoc v3.17.3
+// source: streamrpc.proto
+
+package gitalypb
+
+import (
+ protoreflect "google.golang.org/protobuf/reflect/protoreflect"
+ protoimpl "google.golang.org/protobuf/runtime/protoimpl"
+ reflect "reflect"
+ sync "sync"
+)
+
+const (
+ // Verify that this generated code is sufficiently up-to-date.
+ _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
+ // Verify that runtime/protoimpl is sufficiently up-to-date.
+ _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
+)
+
+type StreamToken struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+
+ Cookie string `protobuf:"bytes,1,opt,name=cookie,proto3" json:"cookie,omitempty"`
+ Token string `protobuf:"bytes,2,opt,name=token,proto3" json:"token,omitempty"`
+}
+
+func (x *StreamToken) Reset() {
+ *x = StreamToken{}
+ if protoimpl.UnsafeEnabled {
+ mi := &file_streamrpc_proto_msgTypes[0]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+ }
+}
+
+func (x *StreamToken) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*StreamToken) ProtoMessage() {}
+
+func (x *StreamToken) ProtoReflect() protoreflect.Message {
+ mi := &file_streamrpc_proto_msgTypes[0]
+ if protoimpl.UnsafeEnabled && x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use StreamToken.ProtoReflect.Descriptor instead.
+func (*StreamToken) Descriptor() ([]byte, []int) {
+ return file_streamrpc_proto_rawDescGZIP(), []int{0}
+}
+
+func (x *StreamToken) GetCookie() string {
+ if x != nil {
+ return x.Cookie
+ }
+ return ""
+}
+
+func (x *StreamToken) GetToken() string {
+ if x != nil {
+ return x.Token
+ }
+ return ""
+}
+
+var File_streamrpc_proto protoreflect.FileDescriptor
+
+var file_streamrpc_proto_rawDesc = []byte{
+ 0x0a, 0x0f, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x72, 0x70, 0x63, 0x2e, 0x70, 0x72, 0x6f, 0x74,
+ 0x6f, 0x12, 0x06, 0x67, 0x69, 0x74, 0x61, 0x6c, 0x79, 0x1a, 0x0a, 0x6c, 0x69, 0x6e, 0x74, 0x2e,
+ 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x3b, 0x0a, 0x0b, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x54,
+ 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x63, 0x6f, 0x6f, 0x6b, 0x69, 0x65, 0x18, 0x01,
+ 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x63, 0x6f, 0x6f, 0x6b, 0x69, 0x65, 0x12, 0x14, 0x0a, 0x05,
+ 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x74, 0x6f, 0x6b,
+ 0x65, 0x6e, 0x42, 0x34, 0x5a, 0x32, 0x67, 0x69, 0x74, 0x6c, 0x61, 0x62, 0x2e, 0x63, 0x6f, 0x6d,
+ 0x2f, 0x67, 0x69, 0x74, 0x6c, 0x61, 0x62, 0x2d, 0x6f, 0x72, 0x67, 0x2f, 0x67, 0x69, 0x74, 0x61,
+ 0x6c, 0x79, 0x2f, 0x76, 0x31, 0x34, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x67, 0x6f, 0x2f,
+ 0x67, 0x69, 0x74, 0x61, 0x6c, 0x79, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
+}
+
+var (
+ file_streamrpc_proto_rawDescOnce sync.Once
+ file_streamrpc_proto_rawDescData = file_streamrpc_proto_rawDesc
+)
+
+func file_streamrpc_proto_rawDescGZIP() []byte {
+ file_streamrpc_proto_rawDescOnce.Do(func() {
+ file_streamrpc_proto_rawDescData = protoimpl.X.CompressGZIP(file_streamrpc_proto_rawDescData)
+ })
+ return file_streamrpc_proto_rawDescData
+}
+
+var file_streamrpc_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
+var file_streamrpc_proto_goTypes = []interface{}{
+ (*StreamToken)(nil), // 0: gitaly.StreamToken
+}
+var file_streamrpc_proto_depIdxs = []int32{
+ 0, // [0:0] is the sub-list for method output_type
+ 0, // [0:0] is the sub-list for method input_type
+ 0, // [0:0] is the sub-list for extension type_name
+ 0, // [0:0] is the sub-list for extension extendee
+ 0, // [0:0] is the sub-list for field type_name
+}
+
+func init() { file_streamrpc_proto_init() }
+func file_streamrpc_proto_init() {
+ if File_streamrpc_proto != nil {
+ return
+ }
+ file_lint_proto_init()
+ if !protoimpl.UnsafeEnabled {
+ file_streamrpc_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*StreamToken); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ }
+ type x struct{}
+ out := protoimpl.TypeBuilder{
+ File: protoimpl.DescBuilder{
+ GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
+ RawDescriptor: file_streamrpc_proto_rawDesc,
+ NumEnums: 0,
+ NumMessages: 1,
+ NumExtensions: 0,
+ NumServices: 0,
+ },
+ GoTypes: file_streamrpc_proto_goTypes,
+ DependencyIndexes: file_streamrpc_proto_depIdxs,
+ MessageInfos: file_streamrpc_proto_msgTypes,
+ }.Build()
+ File_streamrpc_proto = out.File
+ file_streamrpc_proto_rawDesc = nil
+ file_streamrpc_proto_goTypes = nil
+ file_streamrpc_proto_depIdxs = nil
+}
diff --git a/proto/streamrpc.proto b/proto/streamrpc.proto
new file mode 100644
index 000000000..26eae00ae
--- /dev/null
+++ b/proto/streamrpc.proto
@@ -0,0 +1,12 @@
+syntax = "proto3";
+
+package gitaly;
+
+option go_package = "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb";
+
+import "lint.proto";
+
+message StreamToken {
+ string cookie = 1;
+ string token = 2;
+}
diff --git a/ruby/proto/gitaly/streamrpc_pb.rb b/ruby/proto/gitaly/streamrpc_pb.rb
new file mode 100644
index 000000000..f8f7563de
--- /dev/null
+++ b/ruby/proto/gitaly/streamrpc_pb.rb
@@ -0,0 +1,18 @@
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: streamrpc.proto
+
+require 'google/protobuf'
+
+require 'lint_pb'
+Google::Protobuf::DescriptorPool.generated_pool.build do
+ add_file("streamrpc.proto", :syntax => :proto3) do
+ add_message "gitaly.StreamToken" do
+ optional :cookie, :string, 1
+ optional :token, :string, 2
+ end
+ end
+end
+
+module Gitaly
+ StreamToken = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gitaly.StreamToken").msgclass
+end