Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--changelogs/unreleased/jc-mult-node-write.yml5
-rw-r--r--internal/praefect/coordinator.go124
-rw-r--r--internal/praefect/coordinator_test.go32
-rw-r--r--internal/praefect/grpc-proxy/proxy/director.go30
-rw-r--r--internal/praefect/grpc-proxy/proxy/examples_test.go7
-rw-r--r--internal/praefect/grpc-proxy/proxy/handler.go185
-rw-r--r--internal/praefect/grpc-proxy/proxy/handler_test.go15
-rw-r--r--internal/praefect/grpc-proxy/proxy/peeker.go28
-rw-r--r--internal/praefect/grpc-proxy/proxy/peeker_test.go11
-rw-r--r--internal/praefect/server_test.go212
10 files changed, 497 insertions, 152 deletions
diff --git a/changelogs/unreleased/jc-mult-node-write.yml b/changelogs/unreleased/jc-mult-node-write.yml
new file mode 100644
index 000000000..0bf485bb3
--- /dev/null
+++ b/changelogs/unreleased/jc-mult-node-write.yml
@@ -0,0 +1,5 @@
+---
+title: Multi node write
+merge_request: 2208
+author:
+type: added
diff --git a/internal/praefect/coordinator.go b/internal/praefect/coordinator.go
index 953918c41..698060961 100644
--- a/internal/praefect/coordinator.go
+++ b/internal/praefect/coordinator.go
@@ -72,7 +72,6 @@ type grpcCall struct {
fullMethodName string
methodInfo protoregistry.MethodInfo
msg proto.Message
- peeker proxy.StreamModifier
targetRepo *gitalypb.Repository
}
@@ -145,31 +144,32 @@ func (c *Coordinator) accessorStreamParameters(ctx context.Context, call grpcCal
}
storage := node.GetStorage()
- if err := c.rewriteStorageForRepositoryMessage(call.methodInfo, call.msg, call.peeker, storage); err != nil {
+ b, err := rewrittenRepositoryMessage(call.methodInfo, call.msg, storage)
+ if err != nil {
return nil, fmt.Errorf("accessor call: rewrite storage: %w", err)
}
metrics.ReadDistribution.WithLabelValues(virtualStorage, storage).Inc()
- return proxy.NewStreamParameters(ctx, node.GetConnection(), nil, nil), nil
+ return proxy.NewStreamParameters(proxy.Destination{
+ Ctx: helper.IncomingToOutgoing(ctx),
+ Conn: node.GetConnection(),
+ Msg: b,
+ }, nil, nil, nil), nil
}
-func (c *Coordinator) injectTransaction(ctx context.Context, node nodes.Node) (context.Context, func(), error) {
- // We currently only handle single-node-transactions for the primary,
- // so we just blindly call this single node "primary".
- nodeName := "primary"
-
- transactionID, cancel, err := c.txMgr.RegisterTransaction(ctx, []string{nodeName})
+func (c *Coordinator) injectTransaction(ctx context.Context, transactionID uint64, storageName string) (context.Context, error) {
+ ctx, err := metadata.InjectTransaction(ctx, transactionID, storageName)
if err != nil {
- return nil, nil, err
+ return nil, err
}
- ctx, err = metadata.InjectTransaction(ctx, transactionID, nodeName)
- if err != nil {
- return nil, nil, err
- }
+ return ctx, nil
+}
- return ctx, cancel, nil
+var transactionRPCs = map[string]struct{}{
+ "/gitaly.SmartHTTPService/PostReceivePack": {},
+ "/gitaly.SSHService/SSHReceivePack": {},
}
func (c *Coordinator) mutatorStreamParameters(ctx context.Context, call grpcCall, targetRepo *gitalypb.Repository) (*proxy.StreamParameters, error) {
@@ -184,7 +184,8 @@ func (c *Coordinator) mutatorStreamParameters(ctx context.Context, call grpcCall
return nil, helper.ErrPreconditionFailed(ReadOnlyStorageError(call.targetRepo.GetStorageName()))
}
- if err = c.rewriteStorageForRepositoryMessage(call.methodInfo, call.msg, call.peeker, shard.Primary.GetStorage()); err != nil {
+ primaryMessage, err := rewrittenRepositoryMessage(call.methodInfo, call.msg, shard.Primary.GetStorage())
+ if err != nil {
return nil, fmt.Errorf("mutator call: rewrite storage: %w", err)
}
@@ -195,18 +196,57 @@ func (c *Coordinator) mutatorStreamParameters(ctx context.Context, call grpcCall
var finalizers []func()
- if featureflag.IsEnabled(ctx, featureflag.ReferenceTransactions) {
- var transactionCleanup func()
- ctx, transactionCleanup, err = c.injectTransaction(ctx, shard.Primary)
+ primaryDest := proxy.Destination{
+ Ctx: helper.IncomingToOutgoing(ctx),
+ Conn: shard.Primary.GetConnection(),
+ Msg: primaryMessage,
+ }
+
+ var secondaryDests []proxy.Destination
+
+ if _, ok := transactionRPCs[call.fullMethodName]; ok && featureflag.IsEnabled(ctx, featureflag.ReferenceTransactions) {
+ var nodeStorages []string
+
+ for _, node := range append(shard.Secondaries, shard.Primary) {
+ nodeStorages = append(nodeStorages, node.GetStorage())
+ }
+
+ transactionID, transactionCleanup, err := c.txMgr.RegisterTransaction(ctx, nodeStorages)
+ if err != nil {
+ return nil, fmt.Errorf("registering transactions: %w", err)
+ }
+
+ injectedCtx, err := c.injectTransaction(ctx, transactionID, shard.Primary.GetStorage())
if err != nil {
return nil, err
}
+
+ primaryDest.Ctx = helper.IncomingToOutgoing(injectedCtx)
+
finalizers = append(finalizers, transactionCleanup)
- }
- finalizers = append(finalizers, c.createReplicaJobs(ctx, virtualStorage, call.targetRepo, shard.Primary, shard.Secondaries, change, params))
+ for _, secondary := range shard.Secondaries {
+ secondaryMsg, err := rewrittenRepositoryMessage(call.methodInfo, call.msg, secondary.GetStorage())
+ if err != nil {
+ return nil, err
+ }
- return proxy.NewStreamParameters(ctx, shard.Primary.GetConnection(), func() {
+ injectedCtx, err := c.injectTransaction(ctx, transactionID, secondary.GetStorage())
+ if err != nil {
+ return nil, err
+ }
+
+ secondaryDests = append(secondaryDests, proxy.Destination{
+ Ctx: helper.IncomingToOutgoing(injectedCtx),
+ Conn: secondary.GetConnection(),
+ Msg: secondaryMsg,
+ })
+ }
+ } else {
+ finalizers = append(finalizers, c.createReplicaJobs(ctx, virtualStorage, call.targetRepo, shard.Primary, shard.Secondaries, change, params))
+ }
+
+ return proxy.NewStreamParameters(primaryDest, secondaryDests, func() {
for _, finalizer := range finalizers {
finalizer()
}
@@ -214,7 +254,7 @@ func (c *Coordinator) mutatorStreamParameters(ctx context.Context, call grpcCall
}
// streamDirector determines which downstream servers receive requests
-func (c *Coordinator) StreamDirector(ctx context.Context, fullMethodName string, peeker proxy.StreamModifier) (*proxy.StreamParameters, error) {
+func (c *Coordinator) StreamDirector(ctx context.Context, fullMethodName string, peeker proxy.StreamPeeker) (*proxy.StreamParameters, error) {
// For phase 1, we need to route messages based on the storage location
// to the appropriate Gitaly node.
ctxlogrus.Extract(ctx).Debugf("Stream director received method %s", fullMethodName)
@@ -224,7 +264,12 @@ func (c *Coordinator) StreamDirector(ctx context.Context, fullMethodName string,
return nil, err
}
- m, err := protoMessageFromPeeker(mi, peeker)
+ payload, err := peeker.Peek()
+ if err != nil {
+ return nil, err
+ }
+
+ m, err := protoMessage(mi, payload)
if err != nil {
return nil, err
}
@@ -243,8 +288,8 @@ func (c *Coordinator) StreamDirector(ctx context.Context, fullMethodName string,
fullMethodName: fullMethodName,
methodInfo: mi,
msg: m,
- peeker: peeker,
- targetRepo: targetRepo},
+ targetRepo: targetRepo,
+ },
)
if err != nil {
if errors.Is(err, nodes.ErrVirtualStorageNotExist) {
@@ -266,13 +311,17 @@ func (c *Coordinator) StreamDirector(ctx context.Context, fullMethodName string,
return nil, err
}
- return proxy.NewStreamParameters(ctx, shard.Primary.GetConnection(), func() {}, nil), nil
+ return proxy.NewStreamParameters(proxy.Destination{
+ Ctx: helper.IncomingToOutgoing(ctx),
+ Conn: shard.Primary.GetConnection(),
+ Msg: payload,
+ }, nil, func() {}, nil), nil
}
-func (c *Coordinator) rewriteStorageForRepositoryMessage(mi protoregistry.MethodInfo, m proto.Message, peeker proxy.StreamModifier, storage string) error {
+func rewrittenRepositoryMessage(mi protoregistry.MethodInfo, m proto.Message, storage string) ([]byte, error) {
targetRepo, err := mi.TargetRepo(m)
if err != nil {
- return helper.ErrInvalidArgument(err)
+ return nil, helper.ErrInvalidArgument(err)
}
// rewrite storage name
@@ -280,7 +329,7 @@ func (c *Coordinator) rewriteStorageForRepositoryMessage(mi protoregistry.Method
additionalRepo, ok, err := mi.AdditionalRepo(m)
if err != nil {
- return helper.ErrInvalidArgument(err)
+ return nil, helper.ErrInvalidArgument(err)
}
if ok {
@@ -289,22 +338,13 @@ func (c *Coordinator) rewriteStorageForRepositoryMessage(mi protoregistry.Method
b, err := proxy.NewCodec().Marshal(m)
if err != nil {
- return err
- }
-
- if err = peeker.Modify(b); err != nil {
- return err
+ return nil, err
}
- return nil
+ return b, nil
}
-func protoMessageFromPeeker(mi protoregistry.MethodInfo, peeker proxy.StreamModifier) (proto.Message, error) {
- frame, err := peeker.Peek()
- if err != nil {
- return nil, err
- }
-
+func protoMessage(mi protoregistry.MethodInfo, frame []byte) (proto.Message, error) {
m, err := mi.UnmarshalRequestProto(frame)
if err != nil {
return nil, err
diff --git a/internal/praefect/coordinator_test.go b/internal/praefect/coordinator_test.go
index 33109c7cf..b3a3f7994 100644
--- a/internal/praefect/coordinator_test.go
+++ b/internal/praefect/coordinator_test.go
@@ -201,16 +201,16 @@ func TestStreamDirectorMutator(t *testing.T) {
peeker := &mockPeeker{frame}
streamParams, err := coordinator.StreamDirector(correlation.ContextWithCorrelation(ctx, "my-correlation-id"), fullMethod, peeker)
require.NoError(t, err)
- require.Equal(t, primaryAddress, streamParams.Conn().Target())
+ require.Equal(t, primaryAddress, streamParams.Primary().Conn.Target())
- md, ok := metadata.FromOutgoingContext(streamParams.Context())
+ md, ok := metadata.FromOutgoingContext(streamParams.Primary().Ctx)
require.True(t, ok)
require.Contains(t, md, "praefect-server")
mi, err := coordinator.registry.LookupMethod(fullMethod)
require.NoError(t, err)
- m, err := protoMessageFromPeeker(mi, peeker)
+ m, err := mi.UnmarshalRequestProto(streamParams.Primary().Msg)
require.NoError(t, err)
rewrittenTargetRepo, err := mi.TargetRepo(m)
@@ -294,9 +294,9 @@ func TestStreamDirectorAccessor(t *testing.T) {
peeker := &mockPeeker{frame: frame}
streamParams, err := coordinator.StreamDirector(correlation.ContextWithCorrelation(ctx, "my-correlation-id"), fullMethod, peeker)
require.NoError(t, err)
- require.Equal(t, gitalyAddress, streamParams.Conn().Target())
+ require.Equal(t, gitalyAddress, streamParams.Primary().Conn.Target())
- md, ok := metadata.FromOutgoingContext(streamParams.Context())
+ md, ok := metadata.FromOutgoingContext(streamParams.Primary().Ctx)
require.True(t, ok)
require.Contains(t, md, "praefect-server")
@@ -305,7 +305,7 @@ func TestStreamDirectorAccessor(t *testing.T) {
require.Equal(t, protoregistry.ScopeRepository, mi.Scope, "method must be repository scoped")
require.Equal(t, protoregistry.OpAccessor, mi.Operation, "method must be an accessor")
- m, err := protoMessageFromPeeker(mi, peeker)
+ m, err := mi.UnmarshalRequestProto(streamParams.Primary().Msg)
require.NoError(t, err)
rewrittenTargetRepo, err := mi.TargetRepo(m)
@@ -373,9 +373,9 @@ func TestCoordinatorStreamDirector_distributesReads(t *testing.T) {
peeker := &mockPeeker{frame: frame}
streamParams, err := coordinator.StreamDirector(correlation.ContextWithCorrelation(ctx, "my-correlation-id"), fullMethod, peeker)
require.NoError(t, err)
- require.Equal(t, secondaryNodeConf.Address, streamParams.Conn().Target(), "must be redirected to secondary")
+ require.Equal(t, secondaryNodeConf.Address, streamParams.Primary().Conn.Target(), "must be redirected to secondary")
- md, ok := metadata.FromOutgoingContext(streamParams.Context())
+ md, ok := metadata.FromOutgoingContext(streamParams.Primary().Ctx)
require.True(t, ok)
require.Contains(t, md, "praefect-server")
@@ -383,7 +383,7 @@ func TestCoordinatorStreamDirector_distributesReads(t *testing.T) {
require.NoError(t, err)
require.Equal(t, protoregistry.OpAccessor, mi.Operation, "method must be an accessor")
- m, err := protoMessageFromPeeker(mi, peeker)
+ m, err := protoMessage(mi, streamParams.Primary().Msg)
require.NoError(t, err)
rewrittenTargetRepo, err := mi.TargetRepo(m)
@@ -413,9 +413,9 @@ func TestCoordinatorStreamDirector_distributesReads(t *testing.T) {
peeker := &mockPeeker{frame: frame}
streamParams, err := coordinator.StreamDirector(correlation.ContextWithCorrelation(ctx, "my-correlation-id"), fullMethod, peeker)
require.NoError(t, err)
- require.Equal(t, primaryNodeConf.Address, streamParams.Conn().Target(), "must be redirected to primary")
+ require.Equal(t, primaryNodeConf.Address, streamParams.Primary().Conn.Target(), "must be redirected to primary")
- md, ok := metadata.FromOutgoingContext(streamParams.Context())
+ md, ok := metadata.FromOutgoingContext(streamParams.Primary().Ctx)
require.True(t, ok)
require.Contains(t, md, "praefect-server")
@@ -423,7 +423,7 @@ func TestCoordinatorStreamDirector_distributesReads(t *testing.T) {
require.NoError(t, err)
require.Equal(t, protoregistry.OpAccessor, mi.Operation, "method must be an accessor")
- m, err := protoMessageFromPeeker(mi, peeker)
+ m, err := protoMessage(mi, streamParams.Primary().Msg)
require.NoError(t, err)
rewrittenTargetRepo, err := mi.TargetRepo(m)
@@ -440,9 +440,9 @@ func TestCoordinatorStreamDirector_distributesReads(t *testing.T) {
peeker := &mockPeeker{frame: frame}
streamParams, err := coordinator.StreamDirector(correlation.ContextWithCorrelation(ctx, "my-correlation-id"), fullMethod, peeker)
require.NoError(t, err)
- require.Equal(t, primaryNodeConf.Address, streamParams.Conn().Target(), "must be redirected to primary")
+ require.Equal(t, primaryNodeConf.Address, streamParams.Primary().Conn.Target(), "must be redirected to primary")
- md, ok := metadata.FromOutgoingContext(streamParams.Context())
+ md, ok := metadata.FromOutgoingContext(streamParams.Primary().Ctx)
require.True(t, ok)
require.Contains(t, md, "praefect-server")
@@ -450,7 +450,7 @@ func TestCoordinatorStreamDirector_distributesReads(t *testing.T) {
require.NoError(t, err)
require.Equal(t, protoregistry.OpMutator, mi.Operation, "method must be a mutator")
- m, err := protoMessageFromPeeker(mi, peeker)
+ m, err := protoMessage(mi, streamParams.Primary().Msg)
require.NoError(t, err)
rewrittenTargetRepo, err := mi.TargetRepo(m)
@@ -546,7 +546,7 @@ func TestAbsentCorrelationID(t *testing.T) {
peeker := &mockPeeker{frame}
streamParams, err := coordinator.StreamDirector(ctx, fullMethod, peeker)
require.NoError(t, err)
- require.Equal(t, primaryAddress, streamParams.Conn().Target())
+ require.Equal(t, primaryAddress, streamParams.Primary().Conn.Target())
replEventWait.Add(1) // expected only one event to be created
// must be run as it adds replication events to the queue
diff --git a/internal/praefect/grpc-proxy/proxy/director.go b/internal/praefect/grpc-proxy/proxy/director.go
index 50a0ee63f..9c613dd11 100644
--- a/internal/praefect/grpc-proxy/proxy/director.go
+++ b/internal/praefect/grpc-proxy/proxy/director.go
@@ -6,7 +6,6 @@ package proxy
import (
"context"
- "gitlab.com/gitlab-org/gitaly/internal/helper"
"google.golang.org/grpc"
)
@@ -23,35 +22,40 @@ import (
// are invoked. So decisions around authorization, monitoring etc. are better to be handled there.
//
// See the rather rich example.
-type StreamDirector func(ctx context.Context, fullMethodName string, peeker StreamModifier) (*StreamParameters, error)
+type StreamDirector func(ctx context.Context, fullMethodName string, peeker StreamPeeker) (*StreamParameters, error)
// StreamParameters encapsulates streaming parameters the praefect coordinator returns to the
// proxy handler
type StreamParameters struct {
- ctx context.Context
- conn *grpc.ClientConn
+ primary Destination
reqFinalizer func()
callOptions []grpc.CallOption
+ secondaries []Destination
+}
+
+// Destination contains a client connection as well as a rewritten protobuf message
+type Destination struct {
+ Ctx context.Context
+ Conn *grpc.ClientConn
+ Msg []byte
}
// NewStreamParameters returns a new instance of StreamParameters
-func NewStreamParameters(ctx context.Context, conn *grpc.ClientConn, reqFinalizer func(), callOpts []grpc.CallOption) *StreamParameters {
+func NewStreamParameters(primary Destination, secondaries []Destination, reqFinalizer func(), callOpts []grpc.CallOption) *StreamParameters {
return &StreamParameters{
- ctx: helper.IncomingToOutgoing(ctx),
- conn: conn,
+ primary: primary,
+ secondaries: secondaries,
reqFinalizer: reqFinalizer,
callOptions: callOpts,
}
}
-// Context returns the outgoing context
-func (s *StreamParameters) Context() context.Context {
- return s.ctx
+func (s *StreamParameters) Primary() Destination {
+ return s.primary
}
-// Conn returns a grpc client connection
-func (s *StreamParameters) Conn() *grpc.ClientConn {
- return s.conn
+func (s *StreamParameters) Secondaries() []Destination {
+ return s.secondaries
}
// RequestFinalizer calls the request finalizer
diff --git a/internal/praefect/grpc-proxy/proxy/examples_test.go b/internal/praefect/grpc-proxy/proxy/examples_test.go
index 0f4050884..eb5508de5 100644
--- a/internal/praefect/grpc-proxy/proxy/examples_test.go
+++ b/internal/praefect/grpc-proxy/proxy/examples_test.go
@@ -11,6 +11,7 @@ import (
"context"
"strings"
+ "gitlab.com/gitlab-org/gitaly/internal/helper"
"gitlab.com/gitlab-org/gitaly/internal/praefect/grpc-proxy/proxy"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
@@ -40,7 +41,7 @@ func ExampleTransparentHandler() {
// Provide sa simple example of a director that shields internal services and dials a staging or production backend.
// This is a *very naive* implementation that creates a new connection on every request. Consider using pooling.
func ExampleStreamDirector() {
- director = func(ctx context.Context, fullMethodName string, _ proxy.StreamModifier) (*proxy.StreamParameters, error) {
+ director = func(ctx context.Context, fullMethodName string, _ proxy.StreamPeeker) (*proxy.StreamParameters, error) {
// Make sure we never forward internal services.
if strings.HasPrefix(fullMethodName, "/com.example.internal.") {
return nil, status.Errorf(codes.Unimplemented, "Unknown method")
@@ -51,10 +52,10 @@ func ExampleStreamDirector() {
if val, exists := md[":authority"]; exists && val[0] == "staging.api.example.com" {
// Make sure we use DialContext so the dialing can be cancelled/time out together with the context.
conn, err := grpc.DialContext(ctx, "api-service.staging.svc.local", grpc.WithDefaultCallOptions(grpc.ForceCodec(proxy.NewCodec())))
- return proxy.NewStreamParameters(ctx, conn, nil, nil), err
+ return proxy.NewStreamParameters(proxy.Destination{Conn: conn, Ctx: helper.IncomingToOutgoing(ctx)}, nil, nil, nil), err
} else if val, exists := md[":authority"]; exists && val[0] == "api.example.com" {
conn, err := grpc.DialContext(ctx, "api-service.prod.svc.local", grpc.WithDefaultCallOptions(grpc.ForceCodec(proxy.NewCodec())))
- return proxy.NewStreamParameters(ctx, conn, nil, nil), err
+ return proxy.NewStreamParameters(proxy.Destination{Conn: conn, Ctx: helper.IncomingToOutgoing(ctx)}, nil, nil, nil), err
}
}
return nil, status.Errorf(codes.Unimplemented, "Unknown method")
diff --git a/internal/praefect/grpc-proxy/proxy/handler.go b/internal/praefect/grpc-proxy/proxy/handler.go
index 5ea376e07..f226488c7 100644
--- a/internal/praefect/grpc-proxy/proxy/handler.go
+++ b/internal/praefect/grpc-proxy/proxy/handler.go
@@ -9,9 +9,11 @@ package proxy
import (
"context"
+ "errors"
"io"
"gitlab.com/gitlab-org/gitaly/internal/middleware/sentryhandler"
+ "golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
@@ -88,6 +90,12 @@ type handler struct {
director StreamDirector
}
+type streamAndMsg struct {
+ grpc.ClientStream
+ msg []byte
+ cancel func()
+}
+
// handler is where the real magic of proxying happens.
// It is invoked like any gRPC server stream and uses the gRPC server framing to get and receive bytes from the wire,
// forwarding it to a ClientStream established against the relevant ClientConn.
@@ -108,38 +116,66 @@ func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error
defer params.RequestFinalizer()
- clientCtx, clientCancel := context.WithCancel(params.Context())
+ clientCtx, clientCancel := context.WithCancel(params.Primary().Ctx)
defer clientCancel()
// TODO(mwitkow): Add a `forwarded` header to metadata, https://en.wikipedia.org/wiki/X-Forwarded-For.
- clientStream, err := grpc.NewClientStream(clientCtx, clientStreamDescForProxying, params.Conn(), fullMethodName, params.CallOptions()...)
+
+ primaryClientStream, err := grpc.NewClientStream(clientCtx, clientStreamDescForProxying, params.Primary().Conn, fullMethodName, params.CallOptions()...)
if err != nil {
return err
}
+
+ primaryStream := streamAndMsg{
+ ClientStream: primaryClientStream,
+ msg: params.Primary().Msg,
+ cancel: clientCancel,
+ }
+
+ var secondaryStreams []streamAndMsg
+ for _, conn := range params.Secondaries() {
+ clientCtx, clientCancel := context.WithCancel(conn.Ctx)
+ defer clientCancel()
+
+ secondaryClientStream, err := grpc.NewClientStream(clientCtx, clientStreamDescForProxying, conn.Conn, fullMethodName, params.CallOptions()...)
+ if err != nil {
+ return err
+ }
+ secondaryStreams = append(secondaryStreams, streamAndMsg{
+ ClientStream: secondaryClientStream,
+ msg: conn.Msg,
+ cancel: clientCancel,
+ })
+ }
+
// Explicitly *do not close* s2cErrChan and c2sErrChan, otherwise the select below will not terminate.
// Channels do not have to be closed, it is just a control flow mechanism, see
// https://groups.google.com/forum/#!msg/golang-nuts/pZwdYRGxCIk/qpbHxRRPJdUJ
- s2cErrChan := s.forwardServerToClient(serverStream, clientStream, peeker.consumedStream)
- c2sErrChan := s.forwardClientToServer(clientStream, serverStream)
- // We don't know which side is going to stop sending first, so we need a select between the two.
- for i := 0; i < 2; i++ {
+ s2cErrChan := s.forwardServerToClients(serverStream, append(secondaryStreams, primaryStream))
+ c2sErrChan := s.forwardClientToServer(primaryClientStream, serverStream)
+ secondaryErrChan := receiveSecondaryStreams(secondaryStreams)
+
+ // We don't know whether the server, primary, or secondaries will stop sending first, so we need a select between them
+ for {
select {
- case s2cErr := <-s2cErrChan:
- if s2cErr == io.EOF {
- // this is the happy case where the sender has encountered io.EOF, and won't be sending anymore./
- // the clientStream>serverStream may continue pumping though.
- clientStream.CloseSend()
- } else {
- // however, we may have gotten a receive error (stream disconnected, a read error etc) in which case we need
+ case s2cErr, ok := <-s2cErrChan:
+ if !ok {
+ continue
+ }
+ if s2cErr != nil {
+ // we may have gotten a receive error (stream disconnected, a read error etc) in which case we need
// to cancel the clientStream to the backend, let all of its goroutines be freed up by the CancelFunc and
// exit with an error to the stack
- clientCancel()
+
+ for _, stream := range append(secondaryStreams, primaryStream) {
+ stream.cancel()
+ }
return status.Errorf(codes.Internal, "failed proxying s2c: %v", s2cErr)
}
case c2sErr := <-c2sErrChan:
// This happens when the clientStream has nothing else to offer (io.EOF), returned a gRPC error. In those two
// cases we may have received Trailers as part of the call. In case of other errors (stream closed) the trailers
// will be nil.
- trailer := clientStream.Trailer()
+ trailer := primaryClientStream.Trailer()
serverStream.SetTrailer(trailer)
// c2sErr will contain RPC error from client code. If not io.EOF return the RPC error as server stream error.
if c2sErr != io.EOF {
@@ -149,10 +185,48 @@ func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error
}
return c2sErr
}
+
+ secondaryErr, ok := <-secondaryErrChan
+ if !ok {
+ return status.Error(codes.Internal, "failed proxying to secondary")
+ }
+ if secondaryErr != nil {
+ return status.Errorf(codes.Internal, "failed proxying to secondary: %v", secondaryErr)
+ }
+
return nil
}
}
- return status.Errorf(codes.Internal, "gRPC proxying should never reach this stage.")
+}
+
+// receiveSecondaryStreams reads from the client streams of the secondaries and drops the message
+// but returns an error to the channel if it encounters a non io.EOF error
+func receiveSecondaryStreams(srcs []streamAndMsg) chan error {
+ ret := make(chan error, 1)
+
+ go func() {
+ var g errgroup.Group
+ defer close(ret)
+
+ for _, src := range srcs {
+ src := src // rescoping for goroutine
+ g.Go(func() error {
+ for {
+ if err := src.RecvMsg(&frame{}); err != nil {
+ if errors.Is(err, io.EOF) {
+ return nil
+ }
+
+ src.cancel()
+ return err
+ }
+ }
+ })
+ }
+
+ ret <- g.Wait()
+ }()
+ return ret
}
func (s *handler) forwardClientToServer(src grpc.ClientStream, dst grpc.ServerStream) chan error {
@@ -187,41 +261,68 @@ func (s *handler) forwardClientToServer(src grpc.ClientStream, dst grpc.ServerSt
return ret
}
-func (s *handler) forwardServerToClient(src grpc.ServerStream, dst grpc.ClientStream, consumedStream *partialStream) chan error {
- ret := make(chan error, 1)
- go func() {
- // send any consumed/peeked frames first
- for _, frame := range consumedStream.frames {
- if frame == nil {
- // It is possible for peeked frames to be empty. This most likely
- // occurs when the server stream returns an error before the desired
- // number of frames can be peeked
+func forwardConsumedToClient(dst grpc.ClientStream, frameChan <-chan *frame) error {
+ for f := range frameChan {
+ if err := dst.SendMsg(f); err != nil {
+ if errors.Is(err, io.EOF) {
break
}
- if err := dst.SendMsg(frame); err != nil {
- ret <- err
- return
- }
+ return err
}
+ }
+
+ // all messages redirected
+ return dst.CloseSend()
+}
+
+func (s *handler) forwardServerToClients(src grpc.ServerStream, dsts []streamAndMsg) chan error {
+ ret := make(chan error, 1)
+ go func() {
+ var g errgroup.Group
+ defer close(ret)
- // we may have encountered an error earlier while peeking
- if consumedStream.err != nil {
- ret <- consumedStream.err
- return
+ frameChans := make([]chan<- *frame, 0, len(dsts))
+
+ for _, dst := range dsts {
+ dst := dst
+ frameChan := make(chan *frame, 16)
+ frameChan <- &frame{payload: dst.msg} // send re-written message
+ frameChans = append(frameChans, frameChan)
+
+ g.Go(func() error { return forwardConsumedToClient(dst, frameChan) })
}
- // resume two-way stream after peeked messages
- f := &frame{}
- for i := 0; ; i++ {
- if err := src.RecvMsg(f); err != nil {
- ret <- err // this can be io.EOF which is happy case
- break
- }
- if err := dst.SendMsg(f); err != nil {
+ for {
+ if err := consumeServerAndForward(src, frameChans); err != nil {
+ if errors.Is(err, io.EOF) {
+ break
+ }
+
ret <- err
- break
+ return
}
}
+
+ ret <- g.Wait()
}()
return ret
}
+
+func consumeServerAndForward(src grpc.ServerStream, frameChans []chan<- *frame) error {
+ f := &frame{}
+
+ if err := src.RecvMsg(f); err != nil {
+ for _, frameChan := range frameChans {
+ // signal no more data to redirect
+ close(frameChan)
+ }
+
+ return err // this can be io.EOF which is happy case
+ }
+
+ for _, frameChan := range frameChans {
+ frameChan <- f
+ }
+
+ return nil
+}
diff --git a/internal/praefect/grpc-proxy/proxy/handler_test.go b/internal/praefect/grpc-proxy/proxy/handler_test.go
index 267ed1e85..f9ef180a1 100644
--- a/internal/praefect/grpc-proxy/proxy/handler_test.go
+++ b/internal/praefect/grpc-proxy/proxy/handler_test.go
@@ -25,6 +25,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"gitlab.com/gitlab-org/gitaly/client"
+ "gitlab.com/gitlab-org/gitaly/internal/helper"
"gitlab.com/gitlab-org/gitaly/internal/helper/fieldextractors"
"gitlab.com/gitlab-org/gitaly/internal/middleware/sentryhandler"
"gitlab.com/gitlab-org/gitaly/internal/praefect/grpc-proxy/proxy"
@@ -252,15 +253,21 @@ func (s *ProxyHappySuite) SetupSuite() {
// Setup of the proxy's Director.
s.serverClientConn, err = grpc.Dial(s.serverListener.Addr().String(), grpc.WithInsecure(), grpc.WithDefaultCallOptions(grpc.ForceCodec(proxy.NewCodec())))
require.NoError(s.T(), err, "must not error on deferred client Dial")
- director := func(ctx context.Context, fullName string, _ proxy.StreamModifier) (*proxy.StreamParameters, error) {
+ director := func(ctx context.Context, fullName string, peeker proxy.StreamPeeker) (*proxy.StreamParameters, error) {
+ payload, err := peeker.Peek()
+ if err != nil {
+ return nil, err
+ }
+
md, ok := metadata.FromIncomingContext(ctx)
if ok {
if _, exists := md[rejectingMdKey]; exists {
- return proxy.NewStreamParameters(ctx, nil, nil, nil), status.Errorf(codes.PermissionDenied, "testing rejection")
+ return proxy.NewStreamParameters(proxy.Destination{Ctx: helper.IncomingToOutgoing(ctx), Msg: payload}, nil, nil, nil), status.Errorf(codes.PermissionDenied, "testing rejection")
}
}
+
// Explicitly copy the metadata, otherwise the tests will fail.
- return proxy.NewStreamParameters(ctx, s.serverClientConn, nil, nil), nil
+ return proxy.NewStreamParameters(proxy.Destination{Ctx: helper.IncomingToOutgoing(ctx), Conn: s.serverClientConn, Msg: payload}, nil, nil, nil), nil
}
s.proxy = grpc.NewServer(
@@ -326,7 +333,7 @@ func TestRegisterStreamHandlers(t *testing.T) {
server := grpc.NewServer(
grpc.CustomCodec(proxy.NewCodec()),
- grpc.UnknownServiceHandler(proxy.TransparentHandler(func(ctx context.Context, fullMethodName string, peeker proxy.StreamModifier) (*proxy.StreamParameters, error) {
+ grpc.UnknownServiceHandler(proxy.TransparentHandler(func(ctx context.Context, fullMethodName string, peeker proxy.StreamPeeker) (*proxy.StreamParameters, error) {
return nil, directorCalledError
})),
)
diff --git a/internal/praefect/grpc-proxy/proxy/peeker.go b/internal/praefect/grpc-proxy/proxy/peeker.go
index 1d1e02df5..459468dbe 100644
--- a/internal/praefect/grpc-proxy/proxy/peeker.go
+++ b/internal/praefect/grpc-proxy/proxy/peeker.go
@@ -2,28 +2,21 @@ package proxy
import (
"errors"
- "fmt"
"google.golang.org/grpc"
)
// StreamModifier abstracts away the gRPC stream being forwarded so that it can
// be inspected and modified.
-type StreamModifier interface {
+type StreamPeeker interface {
// Peek allows a director to peek one message into the request stream without
// removing those messages from the stream that will be forwarded to
// the backend server.
Peek() (frame []byte, _ error)
-
- // Modify replaces the peeked payload in the stream with the provided frame.
- // If no payload was peeked, an error will be returned.
- // Note: Modify cannot be called after the director returns.
- Modify(payload []byte) error
}
type partialStream struct {
frames []*frame // frames encountered in partial stream
- err error // error returned by partial stream
}
type peeker struct {
@@ -55,10 +48,6 @@ func (p peeker) Peek() ([]byte, error) {
return payloads[0], nil
}
-func (p peeker) Modify(payload []byte) error {
- return p.modify([][]byte{payload})
-}
-
func (p peeker) peek(n uint) ([][]byte, error) {
if n < 1 {
return nil, ErrInvalidPeekCount
@@ -70,8 +59,7 @@ func (p peeker) peek(n uint) ([][]byte, error) {
for i := 0; i < len(p.consumedStream.frames); i++ {
f := &frame{}
if err := p.srcStream.RecvMsg(f); err != nil {
- p.consumedStream.err = err
- break
+ return nil, err
}
p.consumedStream.frames[i] = f
peekedFrames[i] = f.payload
@@ -79,15 +67,3 @@ func (p peeker) peek(n uint) ([][]byte, error) {
return peekedFrames, nil
}
-
-func (p peeker) modify(payloads [][]byte) error {
- if len(payloads) != len(p.consumedStream.frames) {
- return fmt.Errorf("replacement frames count %d does not match consumed frames count %d", len(payloads), len(p.consumedStream.frames))
- }
-
- for i, payload := range payloads {
- p.consumedStream.frames[i].payload = payload
- }
-
- return nil
-}
diff --git a/internal/praefect/grpc-proxy/proxy/peeker_test.go b/internal/praefect/grpc-proxy/proxy/peeker_test.go
index b5d882ff4..823cb50f6 100644
--- a/internal/praefect/grpc-proxy/proxy/peeker_test.go
+++ b/internal/praefect/grpc-proxy/proxy/peeker_test.go
@@ -9,6 +9,7 @@ import (
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitaly/internal/helper"
"gitlab.com/gitlab-org/gitaly/internal/praefect/grpc-proxy/proxy"
testservice "gitlab.com/gitlab-org/gitaly/internal/praefect/grpc-proxy/testdata"
)
@@ -26,7 +27,7 @@ func TestStreamPeeking(t *testing.T) {
pingReqSent := &testservice.PingRequest{Value: "hi"}
// director will peek into stream before routing traffic
- director := func(ctx context.Context, fullMethodName string, peeker proxy.StreamModifier) (*proxy.StreamParameters, error) {
+ director := func(ctx context.Context, fullMethodName string, peeker proxy.StreamPeeker) (*proxy.StreamParameters, error) {
t.Logf("director routing method %s to backend", fullMethodName)
peekedMsg, err := peeker.Peek()
@@ -37,7 +38,7 @@ func TestStreamPeeking(t *testing.T) {
require.NoError(t, err)
require.True(t, proto.Equal(pingReqSent, peekedRequest), "expected to be the same")
- return proxy.NewStreamParameters(ctx, backendCC, nil, nil), nil
+ return proxy.NewStreamParameters(proxy.Destination{Ctx: helper.IncomingToOutgoing(ctx), Conn: backendCC, Msg: peekedMsg}, nil, nil, nil), nil
}
pingResp := &testservice.PingResponse{
@@ -85,7 +86,7 @@ func TestStreamInjecting(t *testing.T) {
newValue := "bye"
// director will peek into stream and change some frames
- director := func(ctx context.Context, fullMethodName string, peeker proxy.StreamModifier) (*proxy.StreamParameters, error) {
+ director := func(ctx context.Context, fullMethodName string, peeker proxy.StreamPeeker) (*proxy.StreamParameters, error) {
t.Logf("modifying request for method %s", fullMethodName)
peekedMsg, err := peeker.Peek()
@@ -100,9 +101,7 @@ func TestStreamInjecting(t *testing.T) {
newPayload, err := proto.Marshal(peekedRequest)
require.NoError(t, err)
- require.NoError(t, peeker.Modify(newPayload))
-
- return proxy.NewStreamParameters(ctx, backendCC, nil, nil), nil
+ return proxy.NewStreamParameters(proxy.Destination{Ctx: helper.IncomingToOutgoing(ctx), Conn: backendCC, Msg: newPayload}, nil, nil, nil), nil
}
pingResp := &testservice.PingResponse{
diff --git a/internal/praefect/server_test.go b/internal/praefect/server_test.go
index 9f59145ec..2ea5a5ef8 100644
--- a/internal/praefect/server_test.go
+++ b/internal/praefect/server_test.go
@@ -1,8 +1,12 @@
package praefect
import (
+ "bytes"
"context"
+ "errors"
+ "io"
"io/ioutil"
+ "net"
"os"
"path/filepath"
"strings"
@@ -19,13 +23,21 @@ import (
"gitlab.com/gitlab-org/gitaly/internal/git"
"gitlab.com/gitlab-org/gitaly/internal/helper"
"gitlab.com/gitlab-org/gitaly/internal/helper/text"
+ "gitlab.com/gitlab-org/gitaly/internal/metadata/featureflag"
"gitlab.com/gitlab-org/gitaly/internal/praefect/config"
"gitlab.com/gitlab-org/gitaly/internal/praefect/datastore"
+ "gitlab.com/gitlab-org/gitaly/internal/praefect/grpc-proxy/proxy"
"gitlab.com/gitlab-org/gitaly/internal/praefect/mock"
+ "gitlab.com/gitlab-org/gitaly/internal/praefect/nodes"
+ "gitlab.com/gitlab-org/gitaly/internal/praefect/protoregistry"
+ "gitlab.com/gitlab-org/gitaly/internal/praefect/transactions"
"gitlab.com/gitlab-org/gitaly/internal/testhelper"
+ "gitlab.com/gitlab-org/gitaly/internal/testhelper/promtest"
"gitlab.com/gitlab-org/gitaly/internal/version"
"gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+ "google.golang.org/grpc"
"google.golang.org/grpc/health/grpc_health_v1"
+ "google.golang.org/grpc/reflection"
)
func TestServerRouteServerAccessor(t *testing.T) {
@@ -635,6 +647,206 @@ func TestRepoRename(t *testing.T) {
defer func() { require.NoError(t, os.RemoveAll(expNewPath2)) }()
}
+type mockSmartHttp struct {
+ m sync.Mutex
+ methodsCalled map[string]int
+}
+
+func (m *mockSmartHttp) InfoRefsUploadPack(req *gitalypb.InfoRefsRequest, stream gitalypb.SmartHTTPService_InfoRefsUploadPackServer) error {
+ m.m.Lock()
+ defer m.m.Unlock()
+ if m.methodsCalled == nil {
+ m.methodsCalled = make(map[string]int)
+ }
+
+ m.methodsCalled["InfoRefsUploadPack"] += 1
+
+ stream.Send(&gitalypb.InfoRefsResponse{})
+ return nil
+}
+
+func (m *mockSmartHttp) InfoRefsReceivePack(req *gitalypb.InfoRefsRequest, stream gitalypb.SmartHTTPService_InfoRefsReceivePackServer) error {
+ m.m.Lock()
+ defer m.m.Unlock()
+ if m.methodsCalled == nil {
+ m.methodsCalled = make(map[string]int)
+ }
+
+ m.methodsCalled["InfoRefsReceivePack"] += 1
+
+ stream.Send(&gitalypb.InfoRefsResponse{})
+ return nil
+}
+
+func (m *mockSmartHttp) PostUploadPack(stream gitalypb.SmartHTTPService_PostUploadPackServer) error {
+ m.m.Lock()
+ defer m.m.Unlock()
+ if m.methodsCalled == nil {
+ m.methodsCalled = make(map[string]int)
+ }
+
+ m.methodsCalled["PostUploadPack"] += 1
+
+ stream.Send(&gitalypb.PostUploadPackResponse{})
+ return nil
+}
+
+func (m *mockSmartHttp) PostReceivePack(stream gitalypb.SmartHTTPService_PostReceivePackServer) error {
+ m.m.Lock()
+ defer m.m.Unlock()
+ if m.methodsCalled == nil {
+ m.methodsCalled = make(map[string]int)
+ }
+
+ m.methodsCalled["PostReceivePack"] += 1
+
+ var err error
+ var req *gitalypb.PostReceivePackRequest
+ for {
+ req, err = stream.Recv()
+ if err != nil {
+ if errors.Is(err, io.EOF) {
+ break
+ }
+ return helper.ErrInternal(err)
+ }
+
+ if err := stream.Send(&gitalypb.PostReceivePackResponse{Data: req.GetData()}); err != nil {
+ return helper.ErrInternal(err)
+ }
+ }
+
+ return nil
+}
+
+func (m *mockSmartHttp) Called(method string) int {
+ m.m.Lock()
+ defer m.m.Unlock()
+
+ return m.methodsCalled[method]
+}
+
+func newGrpcServer(t *testing.T, srv gitalypb.SmartHTTPServiceServer) (string, *grpc.Server) {
+ grpcSrv := testhelper.NewTestGrpcServer(t, nil, nil)
+ socketPath := testhelper.GetTemporaryGitalySocketFileName()
+
+ gitalypb.RegisterSmartHTTPServiceServer(grpcSrv, srv)
+ reflection.Register(grpcSrv)
+
+ listener, err := net.Listen("unix", socketPath)
+ require.NoError(t, err)
+
+ go func() { grpcSrv.Serve(listener) }()
+
+ return socketPath, grpcSrv
+}
+
+func TestProxyWrites(t *testing.T) {
+ smartHttp0, smartHttp1, smartHttp2 := &mockSmartHttp{}, &mockSmartHttp{}, &mockSmartHttp{}
+
+ socket0, srv0 := newGrpcServer(t, smartHttp0)
+ defer srv0.Stop()
+ socket1, srv1 := newGrpcServer(t, smartHttp1)
+ defer srv1.Stop()
+ socket2, srv2 := newGrpcServer(t, smartHttp2)
+ defer srv2.Stop()
+
+ conf := config.Config{
+ VirtualStorages: []*config.VirtualStorage{
+ {
+ Name: "default",
+ Nodes: []*config.Node{
+ {
+ DefaultPrimary: true,
+ Storage: "praefect-internal-0",
+ Address: "unix://" + socket0,
+ },
+ {
+ Storage: "praefect-internal-1",
+ Address: "unix://" + socket1,
+ },
+ {
+ Storage: "praefect-internal-2",
+ Address: "unix://" + socket2,
+ },
+ },
+ },
+ },
+ }
+
+ queue := datastore.NewMemoryReplicationEventQueue(conf)
+ entry := testhelper.DiscardTestEntry(t)
+
+ nodeMgr, err := nodes.NewManager(entry, conf, nil, queue, promtest.NewMockHistogramVec())
+ require.NoError(t, err)
+ txMgr := transactions.NewManager()
+
+ coordinator := NewCoordinator(queue, nodeMgr, txMgr, conf, protoregistry.GitalyProtoPreregistered)
+
+ server := grpc.NewServer(
+ grpc.CustomCodec(proxy.NewCodec()),
+ grpc.UnknownServiceHandler(proxy.TransparentHandler(coordinator.StreamDirector)),
+ )
+
+ socket := testhelper.GetTemporaryGitalySocketFileName()
+ listener, err := net.Listen("unix", socket)
+ require.NoError(t, err)
+
+ go server.Serve(listener)
+ defer server.Stop()
+
+ client, _ := newSmartHTTPClient(t, "unix://"+socket)
+
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ ctx = featureflag.OutgoingCtxWithFeatureFlag(ctx, featureflag.ReferenceTransactions)
+
+ testRepo, _, cleanup := testhelper.NewTestRepo(t)
+ defer cleanup()
+
+ stream, err := client.PostReceivePack(ctx)
+ require.NoError(t, err)
+
+ payload := "some pack data"
+ for i := 0; i < 10; i++ {
+ require.NoError(t, stream.Send(&gitalypb.PostReceivePackRequest{
+ Repository: testRepo,
+ Data: []byte(payload),
+ }))
+ }
+
+ require.NoError(t, stream.CloseSend())
+
+ var receivedData bytes.Buffer
+ for {
+ resp, err := stream.Recv()
+ if err != nil {
+ if err == io.EOF {
+ break
+ }
+ require.FailNowf(t, "unexpected non io.EOF error: %v", err.Error())
+ }
+
+ _, err = receivedData.Write(resp.GetData())
+ require.NoError(t, err)
+ }
+
+ assert.Equal(t, 1, smartHttp0.Called("PostReceivePack"))
+ assert.Equal(t, 1, smartHttp1.Called("PostReceivePack"))
+ assert.Equal(t, 1, smartHttp2.Called("PostReceivePack"))
+ assert.Equal(t, bytes.Repeat([]byte(payload), 10), receivedData.Bytes())
+}
+
+func newSmartHTTPClient(t *testing.T, serverSocketPath string) (gitalypb.SmartHTTPServiceClient, *grpc.ClientConn) {
+ t.Helper()
+
+ conn, err := grpc.Dial(serverSocketPath, grpc.WithInsecure())
+ require.NoError(t, err)
+
+ return gitalypb.NewSmartHTTPServiceClient(conn), conn
+}
+
func tempStoragePath(t testing.TB) string {
p, err := ioutil.TempDir("", t.Name())
require.NoError(t, err)