Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorQuang-Minh Nguyen <qmnguyen@gitlab.com>2021-07-28 09:34:06 +0300
committerQuang-Minh Nguyen <qmnguyen@gitlab.com>2021-07-28 09:34:24 +0300
commit261b3a0a6759a7c1859d3f10d0370c8861c42afc (patch)
treeed147648317277079aba63793aa77728ab8d302e
parenta9540f961bde8fa51872532d4f9dda3e4b4bb0ac (diff)
Combine listenmux and streamrpc into a publicly exposed TestStream RPCqmnguyen0711/1151-combine-listenmux-and-streamrpc-into-a-publicly-exposed-test-rpc
- Add a new RPC TestStream to gitaly/proto, which uses the StreamRPC protocol - Put all the pieces together so that Gitaly can accept StreamRPC calls and handle the TestStream RPC over TCP, TLS, Unix sockets - The first iteration should include metrics, logging and authentication Changelog: added
-rw-r--r--cmd/gitaly-ssh/auth_test.go4
-rw-r--r--cmd/gitaly/main.go37
-rw-r--r--internal/gitaly/server/auth_test.go82
-rw-r--r--internal/gitaly/server/server.go88
-rw-r--r--internal/gitaly/server/server_factory.go29
-rw-r--r--internal/gitaly/server/server_factory_test.go198
-rw-r--r--internal/gitaly/service/setup/register.go16
-rw-r--r--internal/gitaly/service/teststream/server.go37
-rw-r--r--internal/gitaly/service/teststream/server_test.go115
-rw-r--r--internal/streamrpc/handshaker.go137
-rw-r--r--internal/streamrpc/rpc_test.go34
-rw-r--r--internal/streamrpc/server.go11
-rw-r--r--proto/go/gitalypb/protolist.go1
-rw-r--r--proto/go/gitalypb/teststream.pb.go173
-rw-r--r--proto/go/gitalypb/teststream_grpc.pb.go102
-rw-r--r--proto/teststream.proto22
-rw-r--r--ruby/proto/gitaly.rb2
-rw-r--r--ruby/proto/gitaly/teststream_pb.rb20
-rw-r--r--ruby/proto/gitaly/teststream_services_pb.rb22
19 files changed, 954 insertions, 176 deletions
diff --git a/cmd/gitaly-ssh/auth_test.go b/cmd/gitaly-ssh/auth_test.go
index 777b0757d..cce031c4c 100644
--- a/cmd/gitaly-ssh/auth_test.go
+++ b/cmd/gitaly-ssh/auth_test.go
@@ -23,6 +23,7 @@ import (
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/service/setup"
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/transaction"
"gitlab.com/gitlab-org/gitaly/v14/internal/gitlab"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/streamrpc"
"gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
"gitlab.com/gitlab-org/gitaly/v14/internal/testhelper/testcfg"
"gitlab.com/gitlab-org/gitaly/v14/internal/testhelper/testserver"
@@ -152,7 +153,8 @@ func runServer(t *testing.T, secure bool, cfg config.Cfg, connectionType string,
hookManager := hook.NewManager(locator, txManager, gitlab.NewMockClient(), cfg)
gitCmdFactory := git.NewExecCommandFactory(cfg)
diskCache := cache.New(cfg, locator)
- srv, err := server.New(secure, cfg, testhelper.DiscardTestEntry(t), registry, diskCache)
+ streamRPCServer := streamrpc.NewServer()
+ srv, err := server.New(secure, cfg, testhelper.DiscardTestEntry(t), registry, diskCache, streamRPCServer)
require.NoError(t, err)
setup.RegisterAll(srv, &service.Dependencies{
Cfg: cfg,
diff --git a/cmd/gitaly/main.go b/cmd/gitaly/main.go
index dd4243003..be48a6c20 100644
--- a/cmd/gitaly/main.go
+++ b/cmd/gitaly/main.go
@@ -29,6 +29,7 @@ import (
"gitlab.com/gitlab-org/gitaly/v14/internal/gitlab"
glog "gitlab.com/gitlab-org/gitaly/v14/internal/log"
"gitlab.com/gitlab-org/gitaly/v14/internal/storage"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/streamrpc"
"gitlab.com/gitlab-org/gitaly/v14/internal/tempdir"
"gitlab.com/gitlab-org/gitaly/v14/internal/version"
"gitlab.com/gitlab-org/labkit/monitoring"
@@ -196,6 +197,19 @@ func run(cfg config.Cfg) error {
}
defer rubySrv.Stop()
+ deps := &service.Dependencies{
+ Cfg: cfg,
+ RubyServer: rubySrv,
+ GitalyHookManager: hookManager,
+ TransactionManager: transactionManager,
+ StorageLocator: locator,
+ ClientPool: conns,
+ GitCmdFactory: gitCmdFactory,
+ Linguist: ling,
+ CatfileCache: catfileCache,
+ DiskCache: diskCache,
+ }
+
for _, c := range []starter.Config{
{Name: starter.Unix, Addr: cfg.SocketPath, HandoverOnUpgrade: true},
{Name: starter.Unix, Addr: cfg.GitalyInternalSocketPath(), HandoverOnUpgrade: false},
@@ -206,32 +220,23 @@ func run(cfg config.Cfg) error {
continue
}
- var srv *grpc.Server
+ var grpcSrv *grpc.Server
+ var srpcSrv *streamrpc.Server
if c.HandoverOnUpgrade {
- srv, err = gitalyServerFactory.CreateExternal(c.IsSecure())
+ grpcSrv, srpcSrv, err = gitalyServerFactory.CreateExternal(c.IsSecure())
if err != nil {
return fmt.Errorf("create external gRPC server: %w", err)
}
} else {
- srv, err = gitalyServerFactory.CreateInternal()
+ grpcSrv, srpcSrv, err = gitalyServerFactory.CreateInternal()
if err != nil {
return fmt.Errorf("create internal gRPC server: %w", err)
}
}
- setup.RegisterAll(srv, &service.Dependencies{
- Cfg: cfg,
- RubyServer: rubySrv,
- GitalyHookManager: hookManager,
- TransactionManager: transactionManager,
- StorageLocator: locator,
- ClientPool: conns,
- GitCmdFactory: gitCmdFactory,
- Linguist: ling,
- CatfileCache: catfileCache,
- DiskCache: diskCache,
- })
- b.RegisterStarter(starter.New(c, srv))
+ setup.RegisterAll(grpcSrv, deps)
+ setup.RegisterAll(srpcSrv, deps)
+ b.RegisterStarter(starter.New(c, grpcSrv))
}
if addr := cfg.PrometheusListenAddr; addr != "" {
diff --git a/internal/gitaly/server/auth_test.go b/internal/gitaly/server/auth_test.go
index 35959828e..6dee3c821 100644
--- a/internal/gitaly/server/auth_test.go
+++ b/internal/gitaly/server/auth_test.go
@@ -25,8 +25,10 @@ import (
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/hook"
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/service"
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/service/setup"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/service/teststream"
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/transaction"
"gitlab.com/gitlab-org/gitaly/v14/internal/gitlab"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/streamrpc"
"gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
"gitlab.com/gitlab-org/gitaly/v14/internal/testhelper/testcfg"
"gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
@@ -86,30 +88,54 @@ func TestTLSSanity(t *testing.T) {
func TestAuthFailures(t *testing.T) {
testCases := []struct {
- desc string
- opts []grpc.DialOption
- code codes.Code
+ desc string
+ creds credentials.PerRPCCredentials
+ code codes.Code
+ message string
}{
- {desc: "no auth", opts: nil, code: codes.Unauthenticated},
{
- desc: "invalid auth",
- opts: []grpc.DialOption{grpc.WithPerRPCCredentials(brokenAuth{})},
- code: codes.Unauthenticated,
+ desc: "no auth",
+ creds: nil,
+ code: codes.Unauthenticated,
+ message: "rpc error: code = Unauthenticated desc = authentication required",
},
{
- desc: "wrong secret",
- opts: []grpc.DialOption{grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2("foobar"))},
- code: codes.PermissionDenied,
+ desc: "invalid auth",
+ creds: brokenAuth{},
+ code: codes.Unauthenticated,
+ message: "rpc error: code = Unauthenticated desc = authentication required",
+ },
+ {
+ desc: "wrong secret",
+ creds: gitalyauth.RPCCredentialsV2("foobar"),
+ code: codes.PermissionDenied,
+ message: "rpc error: code = PermissionDenied desc = permission denied",
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
- serverSocketPath := runServer(t, config.Cfg{Auth: auth.Config{Token: "quxbaz"}})
- connOpts := append(tc.opts, grpc.WithInsecure())
+ cfg, repo, _ := testcfg.BuildWithRepo(t, testcfg.WithBase(config.Cfg{
+ Auth: auth.Config{Token: "quxbaz"},
+ }))
+ serverSocketPath := runServer(t, cfg)
+
+ // Make a healthcheck gRPC call
+ connOpts := []grpc.DialOption{grpc.WithInsecure()}
+ if tc.creds != nil {
+ connOpts = append(connOpts, grpc.WithPerRPCCredentials(tc.creds))
+ }
conn, err := dial(serverSocketPath, connOpts)
require.NoError(t, err, tc.desc)
t.Cleanup(func() { conn.Close() })
testhelper.RequireGrpcError(t, healthCheck(conn), tc.code)
+
+ // // Make a streamRPC call
+ var callOpts []streamrpc.CallOption
+ if tc.creds != nil {
+ callOpts = append(callOpts, streamrpc.WithCredentials(tc.creds))
+ }
+ _, _, err = checkStreamRPC(t, streamrpc.DialNet(serverSocketPath), repo, callOpts...)
+ require.EqualError(t, err, tc.message)
})
}
}
@@ -119,40 +145,54 @@ func TestAuthSuccess(t *testing.T) {
testCases := []struct {
desc string
- opts []grpc.DialOption
+ creds credentials.PerRPCCredentials
required bool
token string
}{
{desc: "no auth, not required"},
{
desc: "v2 correct auth, not required",
- opts: []grpc.DialOption{grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2(token))},
+ creds: gitalyauth.RPCCredentialsV2(token),
token: token,
},
{
desc: "v2 incorrect auth, not required",
- opts: []grpc.DialOption{grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2("incorrect"))},
+ creds: gitalyauth.RPCCredentialsV2("incorrect"),
token: token,
},
{
desc: "v2 correct auth, required",
- opts: []grpc.DialOption{grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2(token))},
+ creds: gitalyauth.RPCCredentialsV2(token),
token: token,
required: true,
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
- cfg := testcfg.Build(t, testcfg.WithBase(config.Cfg{
+ cfg, repo, _ := testcfg.BuildWithRepo(t, testcfg.WithBase(config.Cfg{
Auth: auth.Config{Token: tc.token, Transitioning: !tc.required},
}))
serverSocketPath := runServer(t, cfg)
- connOpts := append(tc.opts, grpc.WithInsecure())
+
+ // Make a healthcheck gRPC call
+ connOpts := []grpc.DialOption{grpc.WithInsecure()}
+ if tc.creds != nil {
+ connOpts = append(connOpts, grpc.WithPerRPCCredentials(tc.creds))
+ }
conn, err := dial(serverSocketPath, connOpts)
require.NoError(t, err, tc.desc)
t.Cleanup(func() { conn.Close() })
assert.NoError(t, healthCheck(conn), tc.desc)
+
+ // // Make a streamRPC call
+ var callOpts []streamrpc.CallOption
+ if tc.creds != nil {
+ callOpts = append(callOpts, streamrpc.WithCredentials(tc.creds))
+ }
+ in, out, err := checkStreamRPC(t, streamrpc.DialNet(serverSocketPath), repo, callOpts...)
+ require.NoError(t, err)
+ require.Equal(t, in, out)
})
}
}
@@ -201,8 +241,10 @@ func runServer(t *testing.T, cfg config.Cfg) string {
gitCmdFactory := git.NewExecCommandFactory(cfg)
catfileCache := catfile.NewCache(cfg)
diskCache := cache.New(cfg, locator)
+ streamRPCServer := streamrpc.NewServer()
+ gitalypb.RegisterTestStreamServiceServer(streamRPCServer, teststream.NewServer(locator))
- srv, err := New(false, cfg, testhelper.DiscardTestEntry(t), registry, diskCache)
+ srv, err := New(false, cfg, testhelper.DiscardTestEntry(t), registry, diskCache, streamRPCServer)
require.NoError(t, err)
setup.RegisterAll(srv, &service.Dependencies{
@@ -236,7 +278,7 @@ func runSecureServer(t *testing.T, cfg config.Cfg) string {
conns := client.NewPool()
t.Cleanup(func() { conns.Close() })
- srv, err := New(true, cfg, testhelper.DiscardTestEntry(t), backchannel.NewRegistry(), cache.New(cfg, config.NewLocator(cfg)))
+ srv, err := New(true, cfg, testhelper.DiscardTestEntry(t), backchannel.NewRegistry(), cache.New(cfg, config.NewLocator(cfg)), streamrpc.NewServer())
require.NoError(t, err)
healthpb.RegisterHealthServer(srv, health.NewServer())
diff --git a/internal/gitaly/server/server.go b/internal/gitaly/server/server.go
index ffdf55c10..82b48f958 100644
--- a/internal/gitaly/server/server.go
+++ b/internal/gitaly/server/server.go
@@ -28,6 +28,7 @@ import (
"gitlab.com/gitlab-org/gitaly/v14/internal/middleware/panichandler"
"gitlab.com/gitlab-org/gitaly/v14/internal/middleware/sentryhandler"
"gitlab.com/gitlab-org/gitaly/v14/internal/praefect/protoregistry"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/streamrpc"
grpccorrelation "gitlab.com/gitlab-org/labkit/correlation/grpc"
grpctracing "gitlab.com/gitlab-org/labkit/tracing/grpc"
"google.golang.org/grpc"
@@ -74,6 +75,7 @@ func New(
logrusEntry *log.Entry,
registry *backchannel.Registry,
cacheInvalidator diskcache.Invalidator,
+ streamRPCServer *streamrpc.Server,
) (*grpc.Server, error) {
ctxTagOpts := []grpcmwtags.Option{
grpcmwtags.WithFieldExtractorForInitialReq(fieldextractors.FieldExtractor),
@@ -96,53 +98,63 @@ func New(
})
}
+ serverStreamInterceptorChain := grpcmw.ChainStreamServer(
+ grpcmwtags.StreamServerInterceptor(ctxTagOpts...),
+ grpccorrelation.StreamServerCorrelationInterceptor(), // Must be above the metadata handler
+ metadatahandler.StreamInterceptor,
+ grpcprometheus.StreamServerInterceptor,
+ commandstatshandler.StreamInterceptor,
+ grpcmwlogrus.StreamServerInterceptor(logrusEntry,
+ grpcmwlogrus.WithTimestampFormat(gitalylog.LogTimestampFormat),
+ grpcmwlogrus.WithMessageProducer(commandstatshandler.CommandStatsMessageProducer)),
+ sentryhandler.StreamLogHandler,
+ cancelhandler.Stream, // Should be below LogHandler
+ auth.StreamServerInterceptor(cfg.Auth),
+ lh.StreamInterceptor(), // Should be below auth handler to prevent v2 hmac tokens from timing out while queued
+ grpctracing.StreamServerTracingInterceptor(),
+ cache.StreamInvalidator(cacheInvalidator, protoregistry.GitalyProtoPreregistered),
+ // Panic handler should remain last so that application panics will be
+ // converted to errors and logged
+ panichandler.StreamPanicHandler,
+ )
+
+ serverUnaryInterceptorChain := grpcmw.ChainUnaryServer(
+ grpcmwtags.UnaryServerInterceptor(ctxTagOpts...),
+ grpccorrelation.UnaryServerCorrelationInterceptor(), // Must be above the metadata handler
+ metadatahandler.UnaryInterceptor,
+ grpcprometheus.UnaryServerInterceptor,
+ commandstatshandler.UnaryInterceptor,
+ grpcmwlogrus.UnaryServerInterceptor(logrusEntry,
+ grpcmwlogrus.WithTimestampFormat(gitalylog.LogTimestampFormat),
+ grpcmwlogrus.WithMessageProducer(commandstatshandler.CommandStatsMessageProducer)),
+ sentryhandler.UnaryLogHandler,
+ cancelhandler.Unary, // Should be below LogHandler
+ auth.UnaryServerInterceptor(cfg.Auth),
+ lh.UnaryInterceptor(), // Should be below auth handler to prevent v2 hmac tokens from timing out while queued
+ grpctracing.UnaryServerTracingInterceptor(),
+ cache.UnaryInvalidator(cacheInvalidator, protoregistry.GitalyProtoPreregistered),
+ // Panic handler should remain last so that application panics will be
+ // converted to errors and logged
+ panichandler.UnaryPanicHandler,
+ )
+
+ streamRPCServer.UseInterceptor(serverUnaryInterceptorChain)
+
lm := listenmux.New(transportCredentials)
lm.Register(backchannel.NewServerHandshaker(
logrusEntry,
registry,
[]grpc.DialOption{client.UnaryInterceptor()},
))
+ lm.Register(streamrpc.NewServerHandshaker(
+ streamRPCServer,
+ gitalylog.Default(),
+ ))
opts := []grpc.ServerOption{
grpc.Creds(lm),
- grpc.StreamInterceptor(grpcmw.ChainStreamServer(
- grpcmwtags.StreamServerInterceptor(ctxTagOpts...),
- grpccorrelation.StreamServerCorrelationInterceptor(), // Must be above the metadata handler
- metadatahandler.StreamInterceptor,
- grpcprometheus.StreamServerInterceptor,
- commandstatshandler.StreamInterceptor,
- grpcmwlogrus.StreamServerInterceptor(logrusEntry,
- grpcmwlogrus.WithTimestampFormat(gitalylog.LogTimestampFormat),
- grpcmwlogrus.WithMessageProducer(commandstatshandler.CommandStatsMessageProducer)),
- sentryhandler.StreamLogHandler,
- cancelhandler.Stream, // Should be below LogHandler
- auth.StreamServerInterceptor(cfg.Auth),
- lh.StreamInterceptor(), // Should be below auth handler to prevent v2 hmac tokens from timing out while queued
- grpctracing.StreamServerTracingInterceptor(),
- cache.StreamInvalidator(cacheInvalidator, protoregistry.GitalyProtoPreregistered),
- // Panic handler should remain last so that application panics will be
- // converted to errors and logged
- panichandler.StreamPanicHandler,
- )),
- grpc.UnaryInterceptor(grpcmw.ChainUnaryServer(
- grpcmwtags.UnaryServerInterceptor(ctxTagOpts...),
- grpccorrelation.UnaryServerCorrelationInterceptor(), // Must be above the metadata handler
- metadatahandler.UnaryInterceptor,
- grpcprometheus.UnaryServerInterceptor,
- commandstatshandler.UnaryInterceptor,
- grpcmwlogrus.UnaryServerInterceptor(logrusEntry,
- grpcmwlogrus.WithTimestampFormat(gitalylog.LogTimestampFormat),
- grpcmwlogrus.WithMessageProducer(commandstatshandler.CommandStatsMessageProducer)),
- sentryhandler.UnaryLogHandler,
- cancelhandler.Unary, // Should be below LogHandler
- auth.UnaryServerInterceptor(cfg.Auth),
- lh.UnaryInterceptor(), // Should be below auth handler to prevent v2 hmac tokens from timing out while queued
- grpctracing.UnaryServerTracingInterceptor(),
- cache.UnaryInvalidator(cacheInvalidator, protoregistry.GitalyProtoPreregistered),
- // Panic handler should remain last so that application panics will be
- // converted to errors and logged
- panichandler.UnaryPanicHandler,
- )),
+ grpc.StreamInterceptor(serverStreamInterceptorChain),
+ grpc.UnaryInterceptor(serverUnaryInterceptorChain),
grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
MinTime: 20 * time.Second,
PermitWithoutStream: true,
diff --git a/internal/gitaly/server/server_factory.go b/internal/gitaly/server/server_factory.go
index 2b2d9e73d..bae8fad63 100644
--- a/internal/gitaly/server/server_factory.go
+++ b/internal/gitaly/server/server_factory.go
@@ -15,6 +15,7 @@ import (
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/config"
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/maintenance"
"gitlab.com/gitlab-org/gitaly/v14/internal/helper"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/streamrpc"
"gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
"google.golang.org/grpc"
)
@@ -133,26 +134,30 @@ func (s *GitalyServerFactory) GracefulStop() {
}
}
-// CreateExternal creates a new external gRPC server. The external servers are closed
+// CreateExternal creates a new external gRPC server and StreamRPC server. The external servers are closed
// before the internal servers when gracefully shutting down.
-func (s *GitalyServerFactory) CreateExternal(secure bool) (*grpc.Server, error) {
- server, err := New(secure, s.cfg, s.logger, s.registry, s.cacheInvalidator)
+func (s *GitalyServerFactory) CreateExternal(secure bool) (*grpc.Server, *streamrpc.Server, error) {
+ streamRPCServer := streamrpc.NewServer()
+ grpcServer, err := New(secure, s.cfg, s.logger, s.registry, s.cacheInvalidator, streamRPCServer)
if err != nil {
- return nil, err
+ return nil, nil, err
}
- s.externalServers = append(s.externalServers, server)
- return server, nil
+ s.externalServers = append(s.externalServers, grpcServer)
+
+ return grpcServer, streamRPCServer, nil
}
-// CreateInternal creates a new internal gRPC server. Internal servers are closed
+// CreateInternal creates a new internal gRPC server and StreamRPC server. Internal servers are closed
// after the external ones when gracefully shutting down.
-func (s *GitalyServerFactory) CreateInternal() (*grpc.Server, error) {
- server, err := New(false, s.cfg, s.logger, s.registry, s.cacheInvalidator)
+func (s *GitalyServerFactory) CreateInternal() (*grpc.Server, *streamrpc.Server, error) {
+ streamRPCServer := streamrpc.NewServer()
+ grpcServer, err := New(false, s.cfg, s.logger, s.registry, s.cacheInvalidator, streamRPCServer)
if err != nil {
- return nil, err
+ return nil, nil, err
}
- s.internalServers = append(s.internalServers, server)
- return server, nil
+ s.internalServers = append(s.internalServers, grpcServer)
+
+ return grpcServer, streamRPCServer, nil
}
diff --git a/internal/gitaly/server/server_factory_test.go b/internal/gitaly/server/server_factory_test.go
index c851c32d7..34c9c46c3 100644
--- a/internal/gitaly/server/server_factory_test.go
+++ b/internal/gitaly/server/server_factory_test.go
@@ -1,10 +1,14 @@
package server
import (
+ "bytes"
"context"
"crypto/tls"
"crypto/x509"
"errors"
+ "io"
+ "io/ioutil"
+ "math/rand"
"net"
"os"
"testing"
@@ -16,8 +20,11 @@ import (
"gitlab.com/gitlab-org/gitaly/v14/internal/bootstrap/starter"
"gitlab.com/gitlab-org/gitaly/v14/internal/cache"
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/config"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/service/teststream"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/streamrpc"
"gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
"gitlab.com/gitlab-org/gitaly/v14/internal/testhelper/testcfg"
+ "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
@@ -30,71 +37,24 @@ func TestGitalyServerFactory(t *testing.T) {
ctx, cancel := testhelper.Context()
defer cancel()
- checkHealth := func(t *testing.T, sf *GitalyServerFactory, schema, addr string) healthpb.HealthClient {
- t.Helper()
-
- var cc *grpc.ClientConn
- if schema == starter.TLS {
- listener, err := net.Listen(starter.TCP, addr)
- require.NoError(t, err)
- t.Cleanup(func() { listener.Close() })
-
- srv, err := sf.CreateExternal(true)
- require.NoError(t, err)
- healthpb.RegisterHealthServer(srv, health.NewServer())
- go srv.Serve(listener)
-
- certPool, err := x509.SystemCertPool()
- require.NoError(t, err)
-
- pem := testhelper.MustReadFile(t, sf.cfg.TLS.CertPath)
- require.True(t, certPool.AppendCertsFromPEM(pem))
-
- creds := credentials.NewTLS(&tls.Config{
- RootCAs: certPool,
- MinVersion: tls.VersionTLS12,
- })
-
- cc, err = grpc.DialContext(ctx, listener.Addr().String(), grpc.WithTransportCredentials(creds))
- require.NoError(t, err)
- } else {
- listener, err := net.Listen(schema, addr)
- require.NoError(t, err)
- t.Cleanup(func() { listener.Close() })
-
- srv, err := sf.CreateExternal(false)
- require.NoError(t, err)
- healthpb.RegisterHealthServer(srv, health.NewServer())
- go srv.Serve(listener)
-
- endpoint, err := starter.ComposeEndpoint(schema, listener.Addr().String())
- require.NoError(t, err)
-
- cc, err = client.Dial(endpoint, nil)
- require.NoError(t, err)
- }
-
- t.Cleanup(func() { cc.Close() })
-
- healthClient := healthpb.NewHealthClient(cc)
+ t.Run("insecure over TCP", func(t *testing.T) {
+ cfg, repo, _ := testcfg.BuildWithRepo(t)
+ sf := NewGitalyServerFactory(cfg, testhelper.DiscardTestEntry(t), backchannel.NewRegistry(), cache.New(cfg, config.NewLocator(cfg)))
- resp, err := healthClient.Check(ctx, &healthpb.HealthCheckRequest{})
- require.NoError(t, err)
- require.Equal(t, healthpb.HealthCheckResponse_SERVING, resp.Status)
- return healthClient
- }
+ check(t, ctx, sf, cfg, repo, starter.TCP, "localhost:0")
+ })
- t.Run("insecure", func(t *testing.T) {
- cfg := testcfg.Build(t)
+ t.Run("insecure over Unix Socket", func(t *testing.T) {
+ cfg, repo, _ := testcfg.BuildWithRepo(t)
sf := NewGitalyServerFactory(cfg, testhelper.DiscardTestEntry(t), backchannel.NewRegistry(), cache.New(cfg, config.NewLocator(cfg)))
- checkHealth(t, sf, starter.TCP, "localhost:0")
+ check(t, ctx, sf, cfg, repo, starter.Unix, testhelper.GetTemporaryGitalySocketFileName(t))
})
t.Run("secure", func(t *testing.T) {
certFile, keyFile := testhelper.GenerateCerts(t)
- cfg := testcfg.Build(t, testcfg.WithBase(config.Cfg{TLS: config.TLS{
+ cfg, repo, _ := testcfg.BuildWithRepo(t, testcfg.WithBase(config.Cfg{TLS: config.TLS{
CertPath: certFile,
KeyPath: keyFile,
}}))
@@ -102,20 +62,20 @@ func TestGitalyServerFactory(t *testing.T) {
sf := NewGitalyServerFactory(cfg, testhelper.DiscardTestEntry(t), backchannel.NewRegistry(), cache.New(cfg, config.NewLocator(cfg)))
t.Cleanup(sf.Stop)
- checkHealth(t, sf, starter.TLS, "localhost:0")
+ check(t, ctx, sf, cfg, repo, starter.TLS, "localhost:0")
})
t.Run("all services must be stopped", func(t *testing.T) {
- cfg := testcfg.Build(t)
+ cfg, repo, _ := testcfg.BuildWithRepo(t)
sf := NewGitalyServerFactory(cfg, testhelper.DiscardTestEntry(t), backchannel.NewRegistry(), cache.New(cfg, config.NewLocator(cfg)))
t.Cleanup(sf.Stop)
- tcpHealthClient := checkHealth(t, sf, starter.TCP, "localhost:0")
+ tcpHealthClient := check(t, ctx, sf, cfg, repo, starter.TCP, "localhost:0")
socket := testhelper.GetTemporaryGitalySocketFileName(t)
t.Cleanup(func() { require.NoError(t, os.RemoveAll(socket)) })
- socketHealthClient := checkHealth(t, sf, starter.Unix, socket)
+ socketHealthClient := check(t, ctx, sf, cfg, repo, starter.Unix, socket)
sf.GracefulStop() // stops all started servers(listeners)
@@ -185,7 +145,7 @@ func TestGitalyServerFactory_closeOrder(t *testing.T) {
}{
{
createServer: func() *grpc.Server {
- server, err := sf.CreateInternal()
+ server, _, err := sf.CreateInternal()
require.NoError(t, err)
return server
},
@@ -195,7 +155,7 @@ func TestGitalyServerFactory_closeOrder(t *testing.T) {
},
{
createServer: func() *grpc.Server {
- server, err := sf.CreateExternal(false)
+ server, _, err := sf.CreateExternal(false)
require.NoError(t, err)
return server
},
@@ -287,3 +247,117 @@ func TestGitalyServerFactory_closeOrder(t *testing.T) {
// wait until the graceful shutdown completes
<-shutdownCompeleted
}
+
+func check(t *testing.T, ctx context.Context, sf *GitalyServerFactory, cfg config.Cfg, repo *gitalypb.Repository, schema, addr string) healthpb.HealthClient {
+ t.Helper()
+
+ var grpcConn *grpc.ClientConn
+ var streamRPCDial streamrpc.DialFunc
+
+ if schema == starter.TLS {
+ listener, err := net.Listen(starter.TCP, addr)
+ require.NoError(t, err)
+ t.Cleanup(func() { listener.Close() })
+
+ grpcSrv, srpcSrv, err := sf.CreateExternal(true)
+ require.NoError(t, err)
+ healthpb.RegisterHealthServer(grpcSrv, health.NewServer())
+ registerStreamRPCServers(t, srpcSrv, cfg)
+ go grpcSrv.Serve(listener)
+
+ certPool, err := x509.SystemCertPool()
+ require.NoError(t, err)
+
+ pem := testhelper.MustReadFile(t, sf.cfg.TLS.CertPath)
+ require.True(t, certPool.AppendCertsFromPEM(pem))
+
+ tlsConf := &tls.Config{
+ RootCAs: certPool,
+ MinVersion: tls.VersionTLS12,
+ }
+ creds := credentials.NewTLS(tlsConf)
+
+ streamRPCDial = streamrpc.DialTLS(listener.Addr().String(), tlsConf)
+ grpcConn, err = grpc.DialContext(ctx, listener.Addr().String(), grpc.WithTransportCredentials(creds))
+ require.NoError(t, err)
+ } else {
+ listener, err := net.Listen(schema, addr)
+ require.NoError(t, err)
+ t.Cleanup(func() { listener.Close() })
+
+ grpcSrv, srpcSrv, err := sf.CreateExternal(false)
+ require.NoError(t, err)
+ healthpb.RegisterHealthServer(grpcSrv, health.NewServer())
+ registerStreamRPCServers(t, srpcSrv, cfg)
+ go grpcSrv.Serve(listener)
+
+ endpoint, err := starter.ComposeEndpoint(schema, listener.Addr().String())
+ require.NoError(t, err)
+
+ streamRPCDial = streamrpc.DialNet(endpoint)
+ grpcConn, err = client.Dial(endpoint, nil)
+ require.NoError(t, err)
+ }
+
+ // Make a healthcheck gRPC call
+ t.Cleanup(func() { grpcConn.Close() })
+ healthClient := healthpb.NewHealthClient(grpcConn)
+
+ resp, err := healthClient.Check(ctx, &healthpb.HealthCheckRequest{})
+ require.NoError(t, err)
+ require.Equal(t, healthpb.HealthCheckResponse_SERVING, resp.Status)
+
+ // Make a streamRPC call
+ in, out, err := checkStreamRPC(t, streamRPCDial, repo)
+ require.NoError(t, err)
+ require.Equal(t, in, out, "byte stream works")
+
+ return healthClient
+}
+
+func registerStreamRPCServers(t *testing.T, srv *streamrpc.Server, cfg config.Cfg) {
+ gitalypb.RegisterTestStreamServiceServer(srv, teststream.NewServer(config.NewLocator(cfg)))
+}
+
+func checkStreamRPC(t *testing.T, dial streamrpc.DialFunc, repo *gitalypb.Repository, opts ...streamrpc.CallOption) ([]byte, []byte, error) {
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ const size = 1024 * 1024
+
+ in := make([]byte, size)
+ _, err := rand.Read(in)
+ require.NoError(t, err)
+
+ var out []byte
+ require.NotEqual(t, in, out)
+
+ err = streamrpc.Call(
+ ctx,
+ dial,
+ "/gitaly.TestStreamService/TestStream",
+ &gitalypb.TestStreamRequest{
+ Repository: repo,
+ Size: size,
+ },
+ func(c net.Conn) error {
+ errC := make(chan error, 1)
+ go func() {
+ var err error
+ out, err = ioutil.ReadAll(c)
+ errC <- err
+ }()
+
+ if _, err := io.Copy(c, bytes.NewReader(in)); err != nil {
+ return err
+ }
+ if err := <-errC; err != nil {
+ return err
+ }
+
+ return nil
+ },
+ opts...,
+ )
+ return in, out, err
+}
diff --git a/internal/gitaly/service/setup/register.go b/internal/gitaly/service/setup/register.go
index 31859d843..827906f3a 100644
--- a/internal/gitaly/service/setup/register.go
+++ b/internal/gitaly/service/setup/register.go
@@ -21,6 +21,7 @@ import (
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/service/server"
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/service/smarthttp"
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/service/ssh"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/service/teststream"
"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/service/wiki"
"gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
"google.golang.org/grpc"
@@ -51,8 +52,8 @@ var (
)
)
-// RegisterAll will register all the known gRPC services on the provided gRPC service instance.
-func RegisterAll(srv *grpc.Server, deps *service.Dependencies) {
+// RegisterAll will register all the known gRPC + StreamRPC services
+func RegisterAll(srv grpc.ServiceRegistrar, deps *service.Dependencies) {
gitalypb.RegisterBlobServiceServer(srv, blob.NewServer(
deps.GetCfg(),
deps.GetLocator(),
@@ -143,6 +144,13 @@ func RegisterAll(srv *grpc.Server, deps *service.Dependencies) {
gitalypb.RegisterInternalGitalyServer(srv, internalgitaly.NewServer(deps.GetCfg().Storages))
healthpb.RegisterHealthServer(srv, health.NewServer())
- reflection.Register(srv)
- grpcprometheus.Register(srv)
+
+ gitalypb.RegisterTestStreamServiceServer(srv, teststream.NewServer(
+ deps.GetLocator(),
+ ))
+
+ if gs, ok := srv.(*grpc.Server); ok {
+ reflection.Register(gs)
+ grpcprometheus.Register(gs)
+ }
}
diff --git a/internal/gitaly/service/teststream/server.go b/internal/gitaly/service/teststream/server.go
new file mode 100644
index 000000000..51256b226
--- /dev/null
+++ b/internal/gitaly/service/teststream/server.go
@@ -0,0 +1,37 @@
+package teststream
+
+import (
+ "context"
+ "io"
+
+ "gitlab.com/gitlab-org/gitaly/v14/internal/storage"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/streamrpc"
+ "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
+ emptypb "google.golang.org/protobuf/types/known/emptypb"
+)
+
+type server struct {
+ gitalypb.UnimplementedTestStreamServiceServer
+ locator storage.Locator
+}
+
+func (s *server) TestStream(ctx context.Context, request *gitalypb.TestStreamRequest) (*emptypb.Empty, error) {
+ if _, err := s.locator.GetRepoPath(request.Repository); err != nil {
+ return nil, err
+ }
+
+ c, err := streamrpc.AcceptConnection(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ _, err = io.CopyN(c, c, request.Size)
+ return nil, err
+}
+
+// NewServer creates a new instance of a grpc ServerServiceServer
+func NewServer(locator storage.Locator) gitalypb.TestStreamServiceServer {
+ return &server{
+ locator: locator,
+ }
+}
diff --git a/internal/gitaly/service/teststream/server_test.go b/internal/gitaly/service/teststream/server_test.go
new file mode 100644
index 000000000..8051cf787
--- /dev/null
+++ b/internal/gitaly/service/teststream/server_test.go
@@ -0,0 +1,115 @@
+package teststream
+
+import (
+ "bytes"
+ "io"
+ "io/ioutil"
+ "math/rand"
+ "net"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/service"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/streamrpc"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper/testcfg"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper/testserver"
+ "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
+ "google.golang.org/grpc"
+)
+
+func TestTestStreamPingPong(t *testing.T) {
+ const size = 1024 * 1024
+
+ addr, repo := runGitalyServer(t)
+
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ in := make([]byte, size)
+ _, err := rand.Read(in)
+ require.NoError(t, err)
+
+ var out []byte
+ require.NotEqual(t, in, out)
+ require.NoError(t, streamrpc.Call(
+ ctx,
+ streamrpc.DialNet(addr),
+ "/gitaly.TestStreamService/TestStream",
+ &gitalypb.TestStreamRequest{
+ Repository: repo,
+ Size: size,
+ },
+ func(c net.Conn) error {
+ errC := make(chan error, 1)
+ go func() {
+ var err error
+ out, err = ioutil.ReadAll(c)
+ errC <- err
+ }()
+
+ if _, err := io.Copy(c, bytes.NewReader(in)); err != nil {
+ return err
+ }
+ if err := <-errC; err != nil {
+ return err
+ }
+
+ return nil
+ },
+ ))
+
+ require.Equal(t, in, out, "byte stream works")
+}
+
+func TestTestStreamPingPongWithInvalidRepo(t *testing.T) {
+ addr, repo := runGitalyServer(t)
+
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ err := streamrpc.Call(
+ ctx,
+ streamrpc.DialNet(addr),
+ "/gitaly.TestStreamService/TestStream",
+ &gitalypb.TestStreamRequest{
+ Repository: &gitalypb.Repository{
+ StorageName: repo.StorageName,
+ RelativePath: "@hashed/94/00/notexist.git",
+ GlRepository: repo.GlRepository,
+ GlProjectPath: repo.GlProjectPath,
+ },
+ Size: 1024 * 1024,
+ },
+ func(c net.Conn) error {
+ panic("Should not reach here")
+ },
+ )
+
+ require.Error(t, err)
+ require.Contains(
+ t, err.Error(),
+ "rpc error: code = NotFound desc = GetRepoPath: not a git repository",
+ )
+}
+
+func runGitalyServer(t *testing.T) (string, *gitalypb.Repository) {
+ t.Helper()
+ testhelper.Configure()
+
+ cfg, repo, _ := testcfg.BuildWithRepo(t)
+
+ addr := testserver.RunGitalyServer(
+ t, cfg, nil,
+ func(srv grpc.ServiceRegistrar, deps *service.Dependencies) {
+ gitalypb.RegisterTestStreamServiceServer(srv, NewServer(deps.GetLocator()))
+ },
+ // TODO: At the moment, stream RPC doesn't work well with Praefect,
+ // hence we have to disable Praefect. We can remove this option after
+ // https://gitlab.com/gitlab-com/gl-infra/scalability/-/issues/1127 is
+ // done
+ testserver.WithDisablePraefect(),
+ )
+
+ return addr, repo
+}
diff --git a/internal/streamrpc/handshaker.go b/internal/streamrpc/handshaker.go
new file mode 100644
index 000000000..54548af86
--- /dev/null
+++ b/internal/streamrpc/handshaker.go
@@ -0,0 +1,137 @@
+package streamrpc
+
+import (
+ "crypto/tls"
+ "fmt"
+ "net"
+ "time"
+
+ "github.com/sirupsen/logrus"
+ "gitlab.com/gitlab-org/gitaly/v14/internal/bootstrap/starter"
+ "google.golang.org/grpc/credentials"
+)
+
+// The magic bytes used for classification by listenmux
+var magicBytes = []byte("streamrpc00")
+
+// DialNet lets Call initiate unencrypted connections. They tend to be used
+// with Gitaly's listenmux multiplexer only. After the connection is
+// established, streamrpc's 11-byte magic bytes are written into the wire.
+// Listemmux peeks into these magic bytes and redirects the request to
+// StreamRPC server.
+// Please visit internal/listenmux/mux.go for more information
+func DialNet(address string) DialFunc {
+ return func(t time.Duration) (net.Conn, error) {
+ endpoint, err := starter.ParseEndpoint(address)
+ if err != nil {
+ return nil, err
+ }
+
+ // Dial-only deadline
+ deadline := time.Now().Add(t)
+
+ dialer := &net.Dialer{Deadline: deadline}
+ conn, err := dialer.Dial(endpoint.Name, endpoint.Addr)
+ if err != nil {
+ return nil, err
+ }
+
+ if err = conn.SetDeadline(deadline); err != nil {
+ return nil, err
+ }
+ // Write the magic bytes on the connection so the server knows we're
+ // about to initiate a multiplexing session.
+ if _, err := conn.Write(magicBytes); err != nil {
+ return nil, fmt.Errorf("streamrpc client: write backchannel magic bytes: %w", err)
+ }
+
+ // Reset deadline of tls connection for later stages
+ if err = conn.SetDeadline(time.Time{}); err != nil {
+ return nil, err
+ }
+
+ return conn, nil
+ }
+}
+
+// DialTLS lets Call initiate TLS connections. Similar to DialNet, the
+// connections are used for listenmux multiplexer. There are 3 steps involving:
+// - TCP handshake
+// - TLS handshake
+// - Write streamrpc magic bytes
+func DialTLS(address string, cfg *tls.Config) DialFunc {
+ return func(t time.Duration) (net.Conn, error) {
+ // Dial-only deadline
+ deadline := time.Now().Add(t)
+
+ dialer := &net.Dialer{Deadline: deadline}
+ tlsConn, err := tls.DialWithDialer(dialer, "tcp", address, cfg)
+ if err != nil {
+ return nil, err
+ }
+
+ err = tlsConn.SetDeadline(deadline)
+ if err != nil {
+ return nil, err
+ }
+ // Write the magic bytes on the connection so the server knows we're
+ // about to initiate a multiplexing session.
+ if _, err := tlsConn.Write(magicBytes); err != nil {
+ return nil, fmt.Errorf("streamrpc client: write backchannel magic bytes: %w", err)
+ }
+
+ // Reset deadline of tls connection for later stages
+ if err = tlsConn.SetDeadline(time.Time{}); err != nil {
+ return nil, err
+ }
+
+ return tlsConn, nil
+ }
+}
+
+// ServerHandshaker implements the server side handshake of the multiplexed connection.
+type ServerHandshaker struct {
+ server *Server
+ logger logrus.FieldLogger
+}
+
+// NewServerHandshaker returns an implementation of streamrpc server
+// handshaker. The provided TransportCredentials are handshaked prior to
+// initializing the multiplexing session. This handshaker Gitaly's unary server
+// interceptors into the interceptor chain of input StreamRPC server.
+func NewServerHandshaker(server *Server, logger logrus.FieldLogger) *ServerHandshaker {
+ return &ServerHandshaker{
+ server: server,
+ logger: logger,
+ }
+}
+
+// Magic is used by listenmux to retrieve the magic string for
+// streamrpc connections.
+func (s *ServerHandshaker) Magic() string { return string(magicBytes) }
+
+// Handshake "steals" the request from Gitaly's main gRPC server during
+// connection handshaking phase. Listenmux depends on the first 11-byte magic
+// bytes sent by the client, and invoke StreamRPC handshaker accordingly. The
+// request is then handled by stream RPC server, and skipped by Gitaly gRPC
+// server.
+func (s *ServerHandshaker) Handshake(conn net.Conn, authInfo credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error) {
+ if err := conn.SetDeadline(time.Time{}); err != nil {
+ return nil, nil, err
+ }
+
+ go func() {
+ if err := s.server.Handle(conn); err != nil {
+ s.logger.WithError(err).Error("streamrpc: handle call")
+ }
+ }()
+ // At this point, the connection is already closed. If the
+ // TransportCredentials continues its code path, gRPC constructs a HTTP2
+ // server transport to handle the connection. Eventually, it fails and logs
+ // several warnings and errors even though the stream RPC call is
+ // successful.
+ // Fortunately, gRPC has credentials.ErrConnDispatched, indicating that the
+ // connection is already dispatched out of gRPC. gRPC should leave it alone
+ // and exit in peace.
+ return nil, nil, credentials.ErrConnDispatched
+}
diff --git a/internal/streamrpc/rpc_test.go b/internal/streamrpc/rpc_test.go
index 850838465..c93448036 100644
--- a/internal/streamrpc/rpc_test.go
+++ b/internal/streamrpc/rpc_test.go
@@ -185,18 +185,20 @@ func TestCall_serverMiddleware(t *testing.T) {
)
interceptorDone := make(chan struct{})
+ server := NewServer()
+ server.UseInterceptor(func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
+ defer close(interceptorDone)
+ middlewareMethod = info.FullMethod
+ receivedField = req.(*testpb.StreamRequest).StringField
+ if md, ok := metadata.FromIncomingContext(ctx); ok {
+ receivedValues = md[testKey]
+ }
+ return handler(ctx, req)
+ })
dial := startServer(
t,
- NewServer(WithServerInterceptor(func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
- defer close(interceptorDone)
- middlewareMethod = info.FullMethod
- receivedField = req.(*testpb.StreamRequest).StringField
- if md, ok := metadata.FromIncomingContext(ctx); ok {
- receivedValues = md[testKey]
- }
- return handler(ctx, req)
- })),
+ server,
func(ctx context.Context, in *testpb.StreamRequest) (*emptypb.Empty, error) {
_, err := AcceptConnection(ctx)
return nil, err
@@ -219,15 +221,11 @@ func TestCall_serverMiddleware(t *testing.T) {
}
func TestCall_serverMiddlewareReject(t *testing.T) {
- dial := startServer(
- t,
- NewServer(WithServerInterceptor(func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
- return nil, errors.New("middleware says no")
- })),
- func(ctx context.Context, in *testpb.StreamRequest) (*emptypb.Empty, error) {
- panic("never reached")
- },
- )
+ server := NewServer()
+ server.UseInterceptor(func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
+ return nil, errors.New("middleware says no")
+ })
+ dial := startServer(t, server, func(ctx context.Context, in *testpb.StreamRequest) (*emptypb.Empty, error) { panic("never reached") })
err := Call(
context.Background(),
diff --git a/internal/streamrpc/server.go b/internal/streamrpc/server.go
index 502764fd4..2c4771af5 100644
--- a/internal/streamrpc/server.go
+++ b/internal/streamrpc/server.go
@@ -30,11 +30,6 @@ type method struct {
// options to NewServer.
type ServerOption func(*Server)
-// WithServerInterceptor adds a unary gRPC server interceptor.
-func WithServerInterceptor(interceptor grpc.UnaryServerInterceptor) ServerOption {
- return func(s *Server) { s.interceptor = interceptor }
-}
-
// NewServer returns a new StreamRPC server. You can pass the result to
// grpc-go RegisterFooServer functions.
func NewServer(opts ...ServerOption) *Server {
@@ -60,6 +55,12 @@ func (s *Server) RegisterService(sd *grpc.ServiceDesc, impl interface{}) {
}
}
+// UseInterceptor adds a unary gRPC server interceptor for the StreamRPC
+// server to use.
+func (s *Server) UseInterceptor(interceptor grpc.UnaryServerInterceptor) {
+ s.interceptor = interceptor
+}
+
// Handle handles an incoming network connection with the StreamRPC
// protocol. It is intended to be called from a net.Listener.Accept loop
// (or something equivalent).
diff --git a/proto/go/gitalypb/protolist.go b/proto/go/gitalypb/protolist.go
index a15916f70..9d26e24a7 100644
--- a/proto/go/gitalypb/protolist.go
+++ b/proto/go/gitalypb/protolist.go
@@ -23,6 +23,7 @@ var GitalyProtos = []string{
"shared.proto",
"smarthttp.proto",
"ssh.proto",
+ "teststream.proto",
"transaction.proto",
"wiki.proto",
}
diff --git a/proto/go/gitalypb/teststream.pb.go b/proto/go/gitalypb/teststream.pb.go
new file mode 100644
index 000000000..d6a68abce
--- /dev/null
+++ b/proto/go/gitalypb/teststream.pb.go
@@ -0,0 +1,173 @@
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// versions:
+// protoc-gen-go v1.26.0
+// protoc v3.17.3
+// source: teststream.proto
+
+package gitalypb
+
+import (
+ protoreflect "google.golang.org/protobuf/reflect/protoreflect"
+ protoimpl "google.golang.org/protobuf/runtime/protoimpl"
+ emptypb "google.golang.org/protobuf/types/known/emptypb"
+ reflect "reflect"
+ sync "sync"
+)
+
+const (
+ // Verify that this generated code is sufficiently up-to-date.
+ _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
+ // Verify that runtime/protoimpl is sufficiently up-to-date.
+ _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
+)
+
+type TestStreamRequest struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+
+ Repository *Repository `protobuf:"bytes,1,opt,name=repository,proto3" json:"repository,omitempty"`
+ Size int64 `protobuf:"varint,2,opt,name=size,proto3" json:"size,omitempty"`
+}
+
+func (x *TestStreamRequest) Reset() {
+ *x = TestStreamRequest{}
+ if protoimpl.UnsafeEnabled {
+ mi := &file_teststream_proto_msgTypes[0]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+ }
+}
+
+func (x *TestStreamRequest) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*TestStreamRequest) ProtoMessage() {}
+
+func (x *TestStreamRequest) ProtoReflect() protoreflect.Message {
+ mi := &file_teststream_proto_msgTypes[0]
+ if protoimpl.UnsafeEnabled && x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use TestStreamRequest.ProtoReflect.Descriptor instead.
+func (*TestStreamRequest) Descriptor() ([]byte, []int) {
+ return file_teststream_proto_rawDescGZIP(), []int{0}
+}
+
+func (x *TestStreamRequest) GetRepository() *Repository {
+ if x != nil {
+ return x.Repository
+ }
+ return nil
+}
+
+func (x *TestStreamRequest) GetSize() int64 {
+ if x != nil {
+ return x.Size
+ }
+ return 0
+}
+
+var File_teststream_proto protoreflect.FileDescriptor
+
+var file_teststream_proto_rawDesc = []byte{
+ 0x0a, 0x10, 0x74, 0x65, 0x73, 0x74, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x2e, 0x70, 0x72, 0x6f,
+ 0x74, 0x6f, 0x12, 0x06, 0x67, 0x69, 0x74, 0x61, 0x6c, 0x79, 0x1a, 0x0a, 0x6c, 0x69, 0x6e, 0x74,
+ 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x0c, 0x73, 0x68, 0x61, 0x72, 0x65, 0x64, 0x2e, 0x70,
+ 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1b, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f,
+ 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x65, 0x6d, 0x70, 0x74, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74,
+ 0x6f, 0x22, 0x61, 0x0a, 0x11, 0x54, 0x65, 0x73, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x52,
+ 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x38, 0x0a, 0x0a, 0x72, 0x65, 0x70, 0x6f, 0x73, 0x69,
+ 0x74, 0x6f, 0x72, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x67, 0x69, 0x74,
+ 0x61, 0x6c, 0x79, 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x6f, 0x72, 0x79, 0x42, 0x04,
+ 0x98, 0xc6, 0x2c, 0x01, 0x52, 0x0a, 0x72, 0x65, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x6f, 0x72, 0x79,
+ 0x12, 0x12, 0x0a, 0x04, 0x73, 0x69, 0x7a, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04,
+ 0x73, 0x69, 0x7a, 0x65, 0x32, 0x5c, 0x0a, 0x11, 0x54, 0x65, 0x73, 0x74, 0x53, 0x74, 0x72, 0x65,
+ 0x61, 0x6d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x47, 0x0a, 0x0a, 0x54, 0x65, 0x73,
+ 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x19, 0x2e, 0x67, 0x69, 0x74, 0x61, 0x6c, 0x79,
+ 0x2e, 0x54, 0x65, 0x73, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x52, 0x65, 0x71, 0x75, 0x65,
+ 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74,
+ 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x06, 0xfa, 0x97, 0x28, 0x02,
+ 0x08, 0x02, 0x42, 0x34, 0x5a, 0x32, 0x67, 0x69, 0x74, 0x6c, 0x61, 0x62, 0x2e, 0x63, 0x6f, 0x6d,
+ 0x2f, 0x67, 0x69, 0x74, 0x6c, 0x61, 0x62, 0x2d, 0x6f, 0x72, 0x67, 0x2f, 0x67, 0x69, 0x74, 0x61,
+ 0x6c, 0x79, 0x2f, 0x76, 0x31, 0x34, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x67, 0x6f, 0x2f,
+ 0x67, 0x69, 0x74, 0x61, 0x6c, 0x79, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
+}
+
+var (
+ file_teststream_proto_rawDescOnce sync.Once
+ file_teststream_proto_rawDescData = file_teststream_proto_rawDesc
+)
+
+func file_teststream_proto_rawDescGZIP() []byte {
+ file_teststream_proto_rawDescOnce.Do(func() {
+ file_teststream_proto_rawDescData = protoimpl.X.CompressGZIP(file_teststream_proto_rawDescData)
+ })
+ return file_teststream_proto_rawDescData
+}
+
+var file_teststream_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
+var file_teststream_proto_goTypes = []interface{}{
+ (*TestStreamRequest)(nil), // 0: gitaly.TestStreamRequest
+ (*Repository)(nil), // 1: gitaly.Repository
+ (*emptypb.Empty)(nil), // 2: google.protobuf.Empty
+}
+var file_teststream_proto_depIdxs = []int32{
+ 1, // 0: gitaly.TestStreamRequest.repository:type_name -> gitaly.Repository
+ 0, // 1: gitaly.TestStreamService.TestStream:input_type -> gitaly.TestStreamRequest
+ 2, // 2: gitaly.TestStreamService.TestStream:output_type -> google.protobuf.Empty
+ 2, // [2:3] is the sub-list for method output_type
+ 1, // [1:2] is the sub-list for method input_type
+ 1, // [1:1] is the sub-list for extension type_name
+ 1, // [1:1] is the sub-list for extension extendee
+ 0, // [0:1] is the sub-list for field type_name
+}
+
+func init() { file_teststream_proto_init() }
+func file_teststream_proto_init() {
+ if File_teststream_proto != nil {
+ return
+ }
+ file_lint_proto_init()
+ file_shared_proto_init()
+ if !protoimpl.UnsafeEnabled {
+ file_teststream_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*TestStreamRequest); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ }
+ type x struct{}
+ out := protoimpl.TypeBuilder{
+ File: protoimpl.DescBuilder{
+ GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
+ RawDescriptor: file_teststream_proto_rawDesc,
+ NumEnums: 0,
+ NumMessages: 1,
+ NumExtensions: 0,
+ NumServices: 1,
+ },
+ GoTypes: file_teststream_proto_goTypes,
+ DependencyIndexes: file_teststream_proto_depIdxs,
+ MessageInfos: file_teststream_proto_msgTypes,
+ }.Build()
+ File_teststream_proto = out.File
+ file_teststream_proto_rawDesc = nil
+ file_teststream_proto_goTypes = nil
+ file_teststream_proto_depIdxs = nil
+}
diff --git a/proto/go/gitalypb/teststream_grpc.pb.go b/proto/go/gitalypb/teststream_grpc.pb.go
new file mode 100644
index 000000000..e8b4c5fcb
--- /dev/null
+++ b/proto/go/gitalypb/teststream_grpc.pb.go
@@ -0,0 +1,102 @@
+// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
+
+package gitalypb
+
+import (
+ context "context"
+ grpc "google.golang.org/grpc"
+ codes "google.golang.org/grpc/codes"
+ status "google.golang.org/grpc/status"
+ emptypb "google.golang.org/protobuf/types/known/emptypb"
+)
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the grpc package it is being compiled against.
+// Requires gRPC-Go v1.32.0 or later.
+const _ = grpc.SupportPackageIsVersion7
+
+// TestStreamServiceClient is the client API for TestStreamService service.
+//
+// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
+type TestStreamServiceClient interface {
+ TestStream(ctx context.Context, in *TestStreamRequest, opts ...grpc.CallOption) (*emptypb.Empty, error)
+}
+
+type testStreamServiceClient struct {
+ cc grpc.ClientConnInterface
+}
+
+func NewTestStreamServiceClient(cc grpc.ClientConnInterface) TestStreamServiceClient {
+ return &testStreamServiceClient{cc}
+}
+
+func (c *testStreamServiceClient) TestStream(ctx context.Context, in *TestStreamRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) {
+ out := new(emptypb.Empty)
+ err := c.cc.Invoke(ctx, "/gitaly.TestStreamService/TestStream", in, out, opts...)
+ if err != nil {
+ return nil, err
+ }
+ return out, nil
+}
+
+// TestStreamServiceServer is the server API for TestStreamService service.
+// All implementations must embed UnimplementedTestStreamServiceServer
+// for forward compatibility
+type TestStreamServiceServer interface {
+ TestStream(context.Context, *TestStreamRequest) (*emptypb.Empty, error)
+ mustEmbedUnimplementedTestStreamServiceServer()
+}
+
+// UnimplementedTestStreamServiceServer must be embedded to have forward compatible implementations.
+type UnimplementedTestStreamServiceServer struct {
+}
+
+func (UnimplementedTestStreamServiceServer) TestStream(context.Context, *TestStreamRequest) (*emptypb.Empty, error) {
+ return nil, status.Errorf(codes.Unimplemented, "method TestStream not implemented")
+}
+func (UnimplementedTestStreamServiceServer) mustEmbedUnimplementedTestStreamServiceServer() {}
+
+// UnsafeTestStreamServiceServer may be embedded to opt out of forward compatibility for this service.
+// Use of this interface is not recommended, as added methods to TestStreamServiceServer will
+// result in compilation errors.
+type UnsafeTestStreamServiceServer interface {
+ mustEmbedUnimplementedTestStreamServiceServer()
+}
+
+func RegisterTestStreamServiceServer(s grpc.ServiceRegistrar, srv TestStreamServiceServer) {
+ s.RegisterService(&TestStreamService_ServiceDesc, srv)
+}
+
+func _TestStreamService_TestStream_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+ in := new(TestStreamRequest)
+ if err := dec(in); err != nil {
+ return nil, err
+ }
+ if interceptor == nil {
+ return srv.(TestStreamServiceServer).TestStream(ctx, in)
+ }
+ info := &grpc.UnaryServerInfo{
+ Server: srv,
+ FullMethod: "/gitaly.TestStreamService/TestStream",
+ }
+ handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+ return srv.(TestStreamServiceServer).TestStream(ctx, req.(*TestStreamRequest))
+ }
+ return interceptor(ctx, in, info, handler)
+}
+
+// TestStreamService_ServiceDesc is the grpc.ServiceDesc for TestStreamService service.
+// It's only intended for direct use with grpc.RegisterService,
+// and not to be introspected or modified (even as a copy)
+var TestStreamService_ServiceDesc = grpc.ServiceDesc{
+ ServiceName: "gitaly.TestStreamService",
+ HandlerType: (*TestStreamServiceServer)(nil),
+ Methods: []grpc.MethodDesc{
+ {
+ MethodName: "TestStream",
+ Handler: _TestStreamService_TestStream_Handler,
+ },
+ },
+ Streams: []grpc.StreamDesc{},
+ Metadata: "teststream.proto",
+}
diff --git a/proto/teststream.proto b/proto/teststream.proto
new file mode 100644
index 000000000..734047887
--- /dev/null
+++ b/proto/teststream.proto
@@ -0,0 +1,22 @@
+syntax = "proto3";
+
+package gitaly;
+
+option go_package = "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb";
+
+import "lint.proto";
+import "shared.proto";
+import "google/protobuf/empty.proto";
+
+service TestStreamService {
+ rpc TestStream(TestStreamRequest) returns (google.protobuf.Empty) {
+ option (op_type) = {
+ op: ACCESSOR
+ };
+ }
+}
+
+message TestStreamRequest {
+ Repository repository = 1 [(target_repository)=true];
+ int64 size = 2;
+}
diff --git a/ruby/proto/gitaly.rb b/ruby/proto/gitaly.rb
index 9c80cea63..9cfff81cf 100644
--- a/ruby/proto/gitaly.rb
+++ b/ruby/proto/gitaly.rb
@@ -37,6 +37,8 @@ require 'gitaly/smarthttp_services_pb'
require 'gitaly/ssh_services_pb'
+require 'gitaly/teststream_services_pb'
+
require 'gitaly/transaction_services_pb'
require 'gitaly/wiki_services_pb'
diff --git a/ruby/proto/gitaly/teststream_pb.rb b/ruby/proto/gitaly/teststream_pb.rb
new file mode 100644
index 000000000..d75050f04
--- /dev/null
+++ b/ruby/proto/gitaly/teststream_pb.rb
@@ -0,0 +1,20 @@
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: teststream.proto
+
+require 'google/protobuf'
+
+require 'lint_pb'
+require 'shared_pb'
+require 'google/protobuf/empty_pb'
+Google::Protobuf::DescriptorPool.generated_pool.build do
+ add_file("teststream.proto", :syntax => :proto3) do
+ add_message "gitaly.TestStreamRequest" do
+ optional :repository, :message, 1, "gitaly.Repository"
+ optional :size, :int64, 2
+ end
+ end
+end
+
+module Gitaly
+ TestStreamRequest = ::Google::Protobuf::DescriptorPool.generated_pool.lookup("gitaly.TestStreamRequest").msgclass
+end
diff --git a/ruby/proto/gitaly/teststream_services_pb.rb b/ruby/proto/gitaly/teststream_services_pb.rb
new file mode 100644
index 000000000..c74774b03
--- /dev/null
+++ b/ruby/proto/gitaly/teststream_services_pb.rb
@@ -0,0 +1,22 @@
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# Source: teststream.proto for package 'gitaly'
+
+require 'grpc'
+require 'teststream_pb'
+
+module Gitaly
+ module TestStreamService
+ class Service
+
+ include ::GRPC::GenericService
+
+ self.marshal_class_method = :encode
+ self.unmarshal_class_method = :decode
+ self.service_name = 'gitaly.TestStreamService'
+
+ rpc :TestStream, ::Gitaly::TestStreamRequest, ::Google::Protobuf::Empty
+ end
+
+ Stub = Service.rpc_stub_class
+ end
+end