diff options
author | Jacob Vosmaer <jacob@gitlab.com> | 2021-07-01 19:12:36 +0300 |
---|---|---|
committer | Jacob Vosmaer <jacob@gitlab.com> | 2021-07-01 19:12:36 +0300 |
commit | 4ebb87a3334cbd51a203d3fe462c7cd588d347da (patch) | |
tree | ea98a318af0788b2d0f9e6caf201805b60848cef | |
parent | bca44e6c708e26c67137febd5fe2a932d9406e2b (diff) |
Sketch streamrpc "not found" handlerjv-streamrpc-praefect
-rw-r--r-- | internal/streamrpc/server.go | 45 |
1 files changed, 41 insertions, 4 deletions
diff --git a/internal/streamrpc/server.go b/internal/streamrpc/server.go index 9a4a83a18..de79957ab 100644 --- a/internal/streamrpc/server.go +++ b/internal/streamrpc/server.go @@ -19,6 +19,7 @@ var _ grpc.ServiceRegistrar = &Server{} type Server struct { methods map[string]*method interceptor grpc.UnaryServerInterceptor + notFound NotFoundHandler } type method struct { @@ -40,6 +41,10 @@ func WithServerInterceptor(interceptor grpc.UnaryServerInterceptor) ServerOption func NewServer(opts ...ServerOption) *Server { s := &Server{ methods: make(map[string]*method), + interceptor: func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + return handler(ctx, req) + }, + notFound: defaultNotFoundHandler{}, } for _, o := range opts { o(s) @@ -87,14 +92,14 @@ func (s *Server) handleSession(session *serverSession, reqBytes []byte) error { return err } + ctx, cancel := serverContext(session, req) + defer cancel() + method, ok := s.methods[req.Method] if !ok { - return fmt.Errorf("method not found: %s", req.Method) + return s.handleNotFound(ctx, req.Method, req.Message) } - ctx, cancel := serverContext(session, req) - defer cancel() - if _, err := method.Handler( method.implementation, ctx, @@ -107,6 +112,23 @@ func (s *Server) handleSession(session *serverSession, reqBytes []byte) error { return nil } +func (s *Server) handleNotFound(ctx context.Context, method string, reqBytes []byte) error { + msg, err := s.notFound.Decode(method, reqBytes) + if err != nil { + return err + } + + _, err = s.interceptor( + ctx, + msg, + &grpc.UnaryServerInfo{FullMethod: method}, + func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, s.notFound.Handle(ctx, method, req) + }, + ) + return err +} + func serverContext(session *serverSession, req *request) (context.Context, func()) { ctx := context.Background() ctx = context.WithValue(ctx, theSessionKey, session) @@ -164,3 +186,18 @@ func (ss *serverSession) reject(err error) error { return sendFrame(ss.c, buf, ss.deadline) } + +type NotFoundHandler interface { + Decode(method string, buf []byte) (proto.Message, error) + Handle(ctx context.Context, method string, req interface{}) error +} + +type defaultNotFoundHandler struct{} + +func (defaultNotFoundHandler) Decode(method string, _ []byte) (proto.Message, error) { + return nil, fmt.Errorf("method not found: %s", method) +} + +func (defaultNotFoundHandler) Handle(context.Context, string, interface{}) error { + panic("never reached") +} |