Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJacob Vosmaer <jacob@gitlab.com>2021-07-01 19:12:36 +0300
committerJacob Vosmaer <jacob@gitlab.com>2021-07-01 19:12:36 +0300
commit4ebb87a3334cbd51a203d3fe462c7cd588d347da (patch)
treeea98a318af0788b2d0f9e6caf201805b60848cef
parentbca44e6c708e26c67137febd5fe2a932d9406e2b (diff)
Sketch streamrpc "not found" handlerjv-streamrpc-praefect
-rw-r--r--internal/streamrpc/server.go45
1 files changed, 41 insertions, 4 deletions
diff --git a/internal/streamrpc/server.go b/internal/streamrpc/server.go
index 9a4a83a18..de79957ab 100644
--- a/internal/streamrpc/server.go
+++ b/internal/streamrpc/server.go
@@ -19,6 +19,7 @@ var _ grpc.ServiceRegistrar = &Server{}
type Server struct {
methods map[string]*method
interceptor grpc.UnaryServerInterceptor
+ notFound NotFoundHandler
}
type method struct {
@@ -40,6 +41,10 @@ func WithServerInterceptor(interceptor grpc.UnaryServerInterceptor) ServerOption
func NewServer(opts ...ServerOption) *Server {
s := &Server{
methods: make(map[string]*method),
+ interceptor: func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
+ return handler(ctx, req)
+ },
+ notFound: defaultNotFoundHandler{},
}
for _, o := range opts {
o(s)
@@ -87,14 +92,14 @@ func (s *Server) handleSession(session *serverSession, reqBytes []byte) error {
return err
}
+ ctx, cancel := serverContext(session, req)
+ defer cancel()
+
method, ok := s.methods[req.Method]
if !ok {
- return fmt.Errorf("method not found: %s", req.Method)
+ return s.handleNotFound(ctx, req.Method, req.Message)
}
- ctx, cancel := serverContext(session, req)
- defer cancel()
-
if _, err := method.Handler(
method.implementation,
ctx,
@@ -107,6 +112,23 @@ func (s *Server) handleSession(session *serverSession, reqBytes []byte) error {
return nil
}
+func (s *Server) handleNotFound(ctx context.Context, method string, reqBytes []byte) error {
+ msg, err := s.notFound.Decode(method, reqBytes)
+ if err != nil {
+ return err
+ }
+
+ _, err = s.interceptor(
+ ctx,
+ msg,
+ &grpc.UnaryServerInfo{FullMethod: method},
+ func(ctx context.Context, req interface{}) (interface{}, error) {
+ return nil, s.notFound.Handle(ctx, method, req)
+ },
+ )
+ return err
+}
+
func serverContext(session *serverSession, req *request) (context.Context, func()) {
ctx := context.Background()
ctx = context.WithValue(ctx, theSessionKey, session)
@@ -164,3 +186,18 @@ func (ss *serverSession) reject(err error) error {
return sendFrame(ss.c, buf, ss.deadline)
}
+
+type NotFoundHandler interface {
+ Decode(method string, buf []byte) (proto.Message, error)
+ Handle(ctx context.Context, method string, req interface{}) error
+}
+
+type defaultNotFoundHandler struct{}
+
+func (defaultNotFoundHandler) Decode(method string, _ []byte) (proto.Message, error) {
+ return nil, fmt.Errorf("method not found: %s", method)
+}
+
+func (defaultNotFoundHandler) Handle(context.Context, string, interface{}) error {
+ panic("never reached")
+}