diff options
author | Quang-Minh Nguyen <qmnguyen@gitlab.com> | 2021-07-07 14:34:47 +0300 |
---|---|---|
committer | Quang-Minh Nguyen <qmnguyen@gitlab.com> | 2021-07-08 09:42:43 +0300 |
commit | 54a06f5fe50bde8b90064c35a7f6667113ac2045 (patch) | |
tree | 0943b6eba0bb07385fa54dbebb0cf5dee00bfd22 | |
parent | 6896481abe7e8a5c888d45b261d38fec0a420175 (diff) |
Implement StreamRPC Proxyqmnguyen0711/1127-add-streamrpc-proxying-support-to-praefect
-rw-r--r-- | cmd/praefect/main.go | 13 | ||||
-rw-r--r-- | internal/praefect/server.go | 26 | ||||
-rw-r--r-- | internal/praefect/server_factory.go | 4 | ||||
-rw-r--r-- | internal/praefect/streamrpc_proxy.go | 160 | ||||
-rw-r--r-- | internal/streamrpc/common.go | 13 | ||||
-rw-r--r-- | internal/streamrpc/handshaker.go | 4 | ||||
-rw-r--r-- | internal/streamrpc/server.go | 18 |
7 files changed, 222 insertions, 16 deletions
diff --git a/cmd/praefect/main.go b/cmd/praefect/main.go index fde28f0ef..038cffc59 100644 --- a/cmd/praefect/main.go +++ b/cmd/praefect/main.go @@ -304,11 +304,11 @@ func run(cfgs []starter.Config, conf config.Config) error { clientHandshaker := backchannel.NewClientHandshaker(logger, praefect.NewBackchannelServerFactory(logger, transaction.NewServer(transactionManager))) assignmentStore := praefect.NewDisabledAssignmentStore(conf.StorageNames()) var ( - nodeManager nodes.Manager - healthChecker praefect.HealthChecker - nodeSet praefect.NodeSet - router praefect.Router - primaryGetter praefect.PrimaryGetter + nodeManager nodes.Manager + healthChecker praefect.HealthChecker + nodeSet praefect.NodeSet + router praefect.Router + primaryGetter praefect.PrimaryGetter ) if conf.Failover.ElectionStrategy == config.ElectionStrategyPerRepository { nodeSet, err = praefect.DialNodes(ctx, conf.VirtualStorages, protoregistry.GitalyProtoPreregistered, errTracker, clientHandshaker) @@ -381,6 +381,8 @@ func run(cfgs []starter.Config, conf config.Config) error { protoregistry.GitalyProtoPreregistered, ) + streamRPCProxy = praefect.NewStreamRPCProxy(router) + repl = praefect.NewReplMgr( logger, conf.VirtualStorageNames(), @@ -404,6 +406,7 @@ func run(cfgs []starter.Config, conf config.Config) error { protoregistry.GitalyProtoPreregistered, nodeSet.Connections(), primaryGetter, + streamRPCProxy, ) ) metricsCollectors = append(metricsCollectors, transactionManager, coordinator, repl) diff --git a/internal/praefect/server.go b/internal/praefect/server.go index 42166a07b..00bd65d6e 100644 --- a/internal/praefect/server.go +++ b/internal/praefect/server.go @@ -30,10 +30,13 @@ import ( "gitlab.com/gitlab-org/gitaly/v14/internal/praefect/service/transaction" "gitlab.com/gitlab-org/gitaly/v14/internal/praefect/transactions" "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb" + "gitlab.com/gitlab-org/gitaly/v14/internal/listenmux" + "gitlab.com/gitlab-org/gitaly/v14/internal/streamrpc" grpccorrelation "gitlab.com/gitlab-org/labkit/correlation/grpc" grpctracing "gitlab.com/gitlab-org/labkit/tracing/grpc" "google.golang.org/grpc" "google.golang.org/grpc/health" + "google.golang.org/grpc/credentials/insecure" healthpb "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/keepalive" ) @@ -87,6 +90,7 @@ func NewGRPCServer( assignmentStore AssignmentStore, conns Connections, primaryGetter PrimaryGetter, + streamRPCProxy *StreamRPCProxy, grpcOpts ...grpc.ServerOption, ) *grpc.Server { streamInterceptors := []grpc.StreamServerInterceptor{ @@ -105,21 +109,29 @@ func NewGRPCServer( // converted to errors and logged panichandler.StreamPanicHandler, } + unaryInterceptors := grpc_middleware.ChainUnaryServer( + append( + commonUnaryServerInterceptors(logger), + middleware.MethodTypeUnaryInterceptor(registry), + auth.UnaryServerInterceptor(conf.Auth), + )..., + ) if conf.Failover.ElectionStrategy == config.ElectionStrategyPerRepository { streamInterceptors = append(streamInterceptors, RepositoryExistsStreamInterceptor(rs)) } + lm := listenmux.New(insecure.NewCredentials()) + lm.Register(streamrpc.NewServerHandshaker( + streamRPCProxy, + unaryInterceptors, + )) + grpcOpts = append(grpcOpts, proxyRequiredOpts(director)...) + grpcOpts = append(grpcOpts, grpc.Creds(lm)) grpcOpts = append(grpcOpts, []grpc.ServerOption{ grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(streamInterceptors...)), - grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( - append( - commonUnaryServerInterceptors(logger), - middleware.MethodTypeUnaryInterceptor(registry), - auth.UnaryServerInterceptor(conf.Auth), - )..., - )), + grpc.UnaryInterceptor(unaryInterceptors), grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{ MinTime: 20 * time.Second, PermitWithoutStream: true, diff --git a/internal/praefect/server_factory.go b/internal/praefect/server_factory.go index e27661b0e..1ca3fa072 100644 --- a/internal/praefect/server_factory.go +++ b/internal/praefect/server_factory.go @@ -30,6 +30,7 @@ func NewServerFactory( registry *protoregistry.Registry, conns Connections, primaryGetter PrimaryGetter, + streamRPCProxy *StreamRPCProxy, ) *ServerFactory { return &ServerFactory{ conf: conf, @@ -43,6 +44,7 @@ func NewServerFactory( registry: registry, conns: conns, primaryGetter: primaryGetter, + streamRPCProxy: streamRPCProxy, } } @@ -61,6 +63,7 @@ type ServerFactory struct { secure, insecure []*grpc.Server conns Connections primaryGetter PrimaryGetter + streamRPCProxy *StreamRPCProxy } // Serve starts serving on the provided listener with newly created grpc.Server @@ -132,6 +135,7 @@ func (s *ServerFactory) createGRPC(grpcOpts ...grpc.ServerOption) *grpc.Server { s.assignmentStore, s.conns, s.primaryGetter, + s.streamRPCProxy, grpcOpts..., ) } diff --git a/internal/praefect/streamrpc_proxy.go b/internal/praefect/streamrpc_proxy.go new file mode 100644 index 000000000..2bb1848ed --- /dev/null +++ b/internal/praefect/streamrpc_proxy.go @@ -0,0 +1,160 @@ +package praefect + +import ( + "time" + "net" + "fmt" + "context" + "encoding/json" + "io" + "strings" + "sync/atomic" + "golang.org/x/sync/errgroup" + "github.com/golang/protobuf/proto" + "gitlab.com/gitlab-org/gitaly/v14/internal/streamrpc" + "gitlab.com/gitlab-org/gitaly/v14/internal/praefect/protoregistry" + "gitlab.com/gitlab-org/gitaly/v14/internal/helper" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc" +) + +// StreamRPCProxy should implement streamrpc.StreamRPCHandler interface +type StreamRPCProxy struct { + interceptor grpc.UnaryServerInterceptor + router Router +} + +func NewStreamRPCProxy(router Router) *StreamRPCProxy { + return &StreamRPCProxy{ + router: router, + } +} + +func (proxy *StreamRPCProxy) SetInterceptor(interceptor grpc.UnaryServerInterceptor) { + proxy.interceptor = interceptor +} + +func (proxy *StreamRPCProxy) Interceptor() grpc.UnaryServerInterceptor { + return proxy.interceptor; +} + +type proxySession struct { + c net.Conn + deadline time.Time +} + +func (proxy *StreamRPCProxy) Handle(c net.Conn) { + defer c.Close() + + deadline := time.Now().Add(streamrpc.DefaultHandshakeTimeout) + req, err := streamrpc.RecvFrame(c, deadline) + if err != nil { + return + } + + session := &proxySession{ + c: c, + deadline: deadline, + } + + if err := proxy.handleSession(session, req); err != nil { + _ = session.reject(err) + } + +} +func (proxy *StreamRPCProxy) handleSession(session *proxySession, reqBytes []byte) error { + req := &streamrpc.HandshakeRequest{} + if err := json.Unmarshal(reqBytes, req); err != nil { + return err + } + method, err := protoregistry.GitalyProtoPreregistered.LookupMethod(req.Method) + if err != nil { + return err + } + message, err := method.UnmarshalRequestProto(req.Message) + if err != nil { + return err + } + if method.Scope != protoregistry.ScopeRepository { + return fmt.Errorf("StreamRPC Proxy only supports registry scope at the moment") + } + if method.Operation != protoregistry.OpAccessor { + return fmt.Errorf("StreamRPC Proxy only supports accessor operation at the moment") + } + + targetRepo, err := method.TargetRepo(message) + if err != nil { + return helper.ErrInvalidArgument(fmt.Errorf("repo scoped: %w", err)) + } + + ctx, cancel := proxyContext(session, req) + defer cancel() + + // Trigger all Praefect server interceptor + _, err = proxy.interceptor( + ctx, message, &grpc.UnaryServerInfo{ FullMethod: method.FullMethodName() }, + func(ctx context.Context, req interface{}) (interface{}, error) { + // The coordinator also validates target repo. I'm skipping it now + // Also, other things like loggings? + node, err := proxy.router.RouteRepositoryAccessor( + ctx, + targetRepo.StorageName, + targetRepo.GetRelativePath(), + false, // ForcePrimary. Okay, depending on the call, i'm skipping it now^ + ) + if err != nil { + return nil, err + } + return nil, session.forward(ctx, node, method, message) + }, + ) + + return err +} + +func (session *proxySession) forward(ctx context.Context, node RouterNode, method protoregistry.MethodInfo, message proto.Message) error { + parts := strings.Split(node.Connection.Target(), ":") + if len(parts) != 2 { + return fmt.Errorf("Invalid endpoing: %s", node.Connection.Target()) + } + schema := parts[0] + addr := parts[1] + return streamrpc.Call( + ctx, // Okay, the context is messed up here + streamrpc.DialNet(schema, addr), + method.FullMethodName(), + message, + func(dst net.Conn) error { + defer dst.Close() + + // After the secondary call to the upstream passes the handshaking + // phase, we signal the client the call suceeds. Then we copy all + // data from clients to write to the connection of upstream + if err := streamrpc.SendFrame(session.c, nil, session.deadline); err != nil { + return fmt.Errorf("accept session: %w", err) + } + + // Okay, a really really naive proxy implementation. Just forward data, no error handled + go io.Copy(dst, session.c) + if _, err := io.Copy(session.c, dst); err != nil { + return err + } + return session.c.Close() + }, + ) +} + +func proxyContext(session *proxySession, req *streamrpc.HandshakeRequest) (context.Context, func()) { + ctx := context.Background() + ctx = metadata.NewIncomingContext(ctx, req.Metadata) + return context.WithCancel(ctx) +} + +func (s *proxySession) reject(err error) error { + buf, err := json.Marshal(&streamrpc.HandshakeResponse{Error: err.Error()}) + if err != nil { + return fmt.Errorf("mashal response: %w", err) + } + + return streamrpc.SendFrame(s.c, buf, s.deadline) +} diff --git a/internal/streamrpc/common.go b/internal/streamrpc/common.go index 7ee91b245..4c220fa58 100644 --- a/internal/streamrpc/common.go +++ b/internal/streamrpc/common.go @@ -14,11 +14,14 @@ type request struct { Message []byte Metadata map[string][]string } +type HandshakeRequest request // Should rename the above one. Clean up later type response struct{ Error string } +type HandshakeResponse response // Should rename the above one. Clean up later const ( defaultHandshakeTimeout = 10 * time.Second + DefaultHandshakeTimeout = defaultHandshakeTimeout // Should rename the above one. Clean up later // The frames exchanged during the handshake have a uint32 length prefix // so their theoretical maximum size is 4GB. We don't want to allow that @@ -31,6 +34,11 @@ var ( errFrameTooLarge = errors.New("frame too large") ) +// Should sendFrame. Clean up later +func SendFrame(c net.Conn, frame []byte, deadline time.Time) error { + return sendFrame(c, frame, deadline) +} + func sendFrame(c net.Conn, frame []byte, deadline time.Time) error { if len(frame) > maxFrameSize { return errFrameTooLarge @@ -45,6 +53,11 @@ func sendFrame(c net.Conn, frame []byte, deadline time.Time) error { return errAsync(deadline, func() (int, error) { return c.Write(frame) }) } +// Should recvFrame. Clean up later +func RecvFrame(c net.Conn, deadline time.Time) ([]byte, error) { + return recvFrame(c, deadline) +} + func recvFrame(c net.Conn, deadline time.Time) ([]byte, error) { header := make([]byte, 4) if err := errAsync(deadline, func() (int, error) { return io.ReadFull(c, header) }); err != nil { diff --git a/internal/streamrpc/handshaker.go b/internal/streamrpc/handshaker.go index bfe826a43..8c2051624 100644 --- a/internal/streamrpc/handshaker.go +++ b/internal/streamrpc/handshaker.go @@ -88,14 +88,14 @@ func DialTLS(network, address string, cfg *tls.Config) DialFunc { // ServerHandshaker implements the server side handshake of the multiplexed connection. type ServerHandshaker struct { - server *Server + server StreamRPCHandler } // NewServerHandshaker returns an implementation of streamrpc server // handshaker. The provided TransportCredentials are handshaked prior to // initializing the multiplexing session. This handshaker Gitaly's unary server // interceptors into the interceptor chain of input StreamRPC server. -func NewServerHandshaker(server *Server, interceptorChain grpc.UnaryServerInterceptor) *ServerHandshaker { +func NewServerHandshaker(server StreamRPCHandler, interceptorChain grpc.UnaryServerInterceptor) *ServerHandshaker { WithServerInterceptor(interceptorChain)(server) return &ServerHandshaker{ diff --git a/internal/streamrpc/server.go b/internal/streamrpc/server.go index 9a4a83a18..69678604c 100644 --- a/internal/streamrpc/server.go +++ b/internal/streamrpc/server.go @@ -15,6 +15,12 @@ import ( var _ grpc.ServiceRegistrar = &Server{} +type StreamRPCHandler interface { + SetInterceptor(grpc.UnaryServerInterceptor) + Interceptor() grpc.UnaryServerInterceptor + Handle(net.Conn) +} + // Server handles network connections and routes them to StreamRPC handlers. type Server struct { methods map[string]*method @@ -28,11 +34,11 @@ type method struct { // ServerOption is an abstraction that lets you pass 0 or more server // options to NewServer. -type ServerOption func(*Server) +type ServerOption func(StreamRPCHandler) // WithServerInterceptor adds a unary gRPC server interceptor. func WithServerInterceptor(interceptor grpc.UnaryServerInterceptor) ServerOption { - return func(s *Server) { s.interceptor = interceptor } + return func(s StreamRPCHandler) { s.SetInterceptor(interceptor) } } // NewServer returns a new StreamRPC server. You can pass the result to @@ -47,6 +53,14 @@ func NewServer(opts ...ServerOption) *Server { return s } +func (s *Server) SetInterceptor(interceptor grpc.UnaryServerInterceptor) { + s.interceptor = interceptor; +} + +func (s *Server) Interceptor() grpc.UnaryServerInterceptor { + return s.interceptor; +} + // RegisterService implements grpc.ServiceRegistrar. It makes it possible // to pass a *Server to grpc-go foopb.RegisterFooServer functions as the // first argument. |