diff options
Diffstat (limited to 'vendor/google.golang.org/grpc/server.go')
-rw-r--r-- | vendor/google.golang.org/grpc/server.go | 484 |
1 files changed, 207 insertions, 277 deletions
diff --git a/vendor/google.golang.org/grpc/server.go b/vendor/google.golang.org/grpc/server.go index 920da5e01..f65162168 100644 --- a/vendor/google.golang.org/grpc/server.go +++ b/vendor/google.golang.org/grpc/server.go @@ -19,6 +19,7 @@ package grpc import ( + "bytes" "errors" "fmt" "io" @@ -29,24 +30,24 @@ import ( "runtime" "strings" "sync" - "sync/atomic" "time" + "io/ioutil" + "golang.org/x/net/context" + "golang.org/x/net/http2" "golang.org/x/net/trace" - "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/encoding" - "google.golang.org/grpc/encoding/proto" "google.golang.org/grpc/grpclog" - "google.golang.org/grpc/internal/channelz" - "google.golang.org/grpc/internal/transport" + "google.golang.org/grpc/internal" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/stats" "google.golang.org/grpc/status" "google.golang.org/grpc/tap" + "google.golang.org/grpc/transport" ) const ( @@ -95,20 +96,16 @@ type Server struct { m map[string]*service // service name -> service info events trace.EventLog - quit chan struct{} - done chan struct{} - quitOnce sync.Once - doneOnce sync.Once - channelzRemoveOnce sync.Once - serveWG sync.WaitGroup // counts active Serve goroutines for GracefulStop - - channelzID int64 // channelz unique identification number - czData *channelzData + quit chan struct{} + done chan struct{} + quitOnce sync.Once + doneOnce sync.Once + serveWG sync.WaitGroup // counts active Serve goroutines for GracefulStop } type options struct { creds credentials.TransportCredentials - codec baseCodec + codec Codec cp Compressor dc Decompressor unaryInt UnaryServerInterceptor @@ -118,6 +115,7 @@ type options struct { maxConcurrentStreams uint32 maxReceiveMessageSize int maxSendMessageSize int + useHandlerImpl bool // use http.Handler-based server unknownStreamDesc *StreamDesc keepaliveParams keepalive.ServerParameters keepalivePolicy keepalive.EnforcementPolicy @@ -126,25 +124,19 @@ type options struct { writeBufferSize int readBufferSize int connectionTimeout time.Duration - maxHeaderListSize *uint32 } var defaultServerOptions = options{ maxReceiveMessageSize: defaultServerMaxReceiveMessageSize, maxSendMessageSize: defaultServerMaxSendMessageSize, connectionTimeout: 120 * time.Second, - writeBufferSize: defaultWriteBufSize, - readBufferSize: defaultReadBufSize, } // A ServerOption sets options such as credentials, codec and keepalive parameters, etc. type ServerOption func(*options) -// WriteBufferSize determines how much data can be batched before doing a write on the wire. -// The corresponding memory allocation for this buffer will be twice the size to keep syscalls low. -// The default value for this buffer is 32KB. -// Zero will disable the write buffer such that each write will be on underlying connection. -// Note: A Send call may not directly translate to a write. +// WriteBufferSize lets you set the size of write buffer, this determines how much data can be batched +// before doing a write on the wire. func WriteBufferSize(s int) ServerOption { return func(o *options) { o.writeBufferSize = s @@ -153,9 +145,6 @@ func WriteBufferSize(s int) ServerOption { // ReadBufferSize lets you set the size of read buffer, this determines how much data can be read at most // for one read syscall. -// The default value for this buffer is 32KB. -// Zero will disable read buffer for a connection so data framer can access the underlying -// conn directly. func ReadBufferSize(s int) ServerOption { return func(o *options) { o.readBufferSize = s @@ -193,8 +182,6 @@ func KeepaliveEnforcementPolicy(kep keepalive.EnforcementPolicy) ServerOption { } // CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling. -// -// This will override any lookups by content-subtype for Codecs registered with RegisterCodec. func CustomCodec(codec Codec) ServerOption { return func(o *options) { o.codec = codec @@ -226,9 +213,7 @@ func RPCDecompressor(dc Decompressor) ServerOption { } // MaxMsgSize returns a ServerOption to set the max message size in bytes the server can receive. -// If this is not set, gRPC uses the default limit. -// -// Deprecated: use MaxRecvMsgSize instead. +// If this is not set, gRPC uses the default limit. Deprecated: use MaxRecvMsgSize instead. func MaxMsgSize(m int) ServerOption { return MaxRecvMsgSize(m) } @@ -335,14 +320,6 @@ func ConnectionTimeout(d time.Duration) ServerOption { } } -// MaxHeaderListSize returns a ServerOption that sets the max (uncompressed) size -// of header list that the server is prepared to accept. -func MaxHeaderListSize(s uint32) ServerOption { - return func(o *options) { - o.maxHeaderListSize = &s - } -} - // NewServer creates a gRPC server which has no service registered and has not // started to accept requests yet. func NewServer(opt ...ServerOption) *Server { @@ -350,24 +327,23 @@ func NewServer(opt ...ServerOption) *Server { for _, o := range opt { o(&opts) } + if opts.codec == nil { + // Set the default codec. + opts.codec = protoCodec{} + } s := &Server{ - lis: make(map[net.Listener]bool), - opts: opts, - conns: make(map[io.Closer]bool), - m: make(map[string]*service), - quit: make(chan struct{}), - done: make(chan struct{}), - czData: new(channelzData), + lis: make(map[net.Listener]bool), + opts: opts, + conns: make(map[io.Closer]bool), + m: make(map[string]*service), + quit: make(chan struct{}), + done: make(chan struct{}), } s.cv = sync.NewCond(&s.mu) if EnableTracing { _, file, line, _ := runtime.Caller(1) s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line)) } - - if channelz.IsOn() { - s.channelzID = channelz.RegisterServer(&channelzServer{s}, "") - } return s } @@ -483,26 +459,6 @@ func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credenti return s.opts.creds.ServerHandshake(rawConn) } -type listenSocket struct { - net.Listener - channelzID int64 -} - -func (l *listenSocket) ChannelzMetric() *channelz.SocketInternalMetric { - return &channelz.SocketInternalMetric{ - SocketOptions: channelz.GetSocketOption(l.Listener), - LocalAddr: l.Listener.Addr(), - } -} - -func (l *listenSocket) Close() error { - err := l.Listener.Close() - if channelz.IsOn() { - channelz.RemoveEntry(l.channelzID) - } - return err -} - // Serve accepts incoming connections on the listener lis, creating a new // ServerTransport and service goroutine for each. The service goroutines // read gRPC requests and then call the registered handlers to reply to them. @@ -531,19 +487,13 @@ func (s *Server) Serve(lis net.Listener) error { } }() - ls := &listenSocket{Listener: lis} - s.lis[ls] = true - - if channelz.IsOn() { - ls.channelzID = channelz.RegisterListenSocket(ls, s.channelzID, "") - } + s.lis[lis] = true s.mu.Unlock() - defer func() { s.mu.Lock() - if s.lis != nil && s.lis[ls] { - ls.Close() - delete(s.lis, ls) + if s.lis != nil && s.lis[lis] { + lis.Close() + delete(s.lis, lis) } s.mu.Unlock() }() @@ -627,19 +577,27 @@ func (s *Server) handleRawConn(rawConn net.Conn) { } s.mu.Unlock() - // Finish handshaking (HTTP2) - st := s.newHTTP2Transport(conn, authInfo) - if st == nil { - return + var serve func() + c := conn.(io.Closer) + if s.opts.useHandlerImpl { + serve = func() { s.serveUsingHandler(conn) } + } else { + // Finish handshaking (HTTP2) + st := s.newHTTP2Transport(conn, authInfo) + if st == nil { + return + } + c = st + serve = func() { s.serveStreams(st) } } rawConn.SetDeadline(time.Time{}) - if !s.addConn(st) { + if !s.addConn(c) { return } go func() { - s.serveStreams(st) - s.removeConn(st) + serve() + s.removeConn(c) }() } @@ -657,8 +615,6 @@ func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) tr InitialConnWindowSize: s.opts.initialConnWindowSize, WriteBufferSize: s.opts.writeBufferSize, ReadBufferSize: s.opts.readBufferSize, - ChannelzParentID: s.channelzID, - MaxHeaderListSize: s.opts.maxHeaderListSize, } st, err := transport.NewServerTransport("http2", c, config) if err != nil { @@ -669,7 +625,6 @@ func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) tr grpclog.Warningln("grpc: Server.Serve failed to create ServerTransport: ", err) return nil } - return st } @@ -694,6 +649,27 @@ func (s *Server) serveStreams(st transport.ServerTransport) { var _ http.Handler = (*Server)(nil) +// serveUsingHandler is called from handleRawConn when s is configured +// to handle requests via the http.Handler interface. It sets up a +// net/http.Server to handle the just-accepted conn. The http.Server +// is configured to route all incoming requests (all HTTP/2 streams) +// to ServeHTTP, which creates a new ServerTransport for each stream. +// serveUsingHandler blocks until conn closes. +// +// This codepath is only used when Server.TestingUseHandlerImpl has +// been configured. This lets the end2end tests exercise the ServeHTTP +// method as one of the environment types. +// +// conn is the *tls.Conn that's already been authenticated. +func (s *Server) serveUsingHandler(conn net.Conn) { + h2s := &http2.Server{ + MaxConcurrentStreams: s.opts.maxConcurrentStreams, + } + h2s.ServeConn(conn, &http2.ServeConnOpts{ + Handler: s, + }) +} + // ServeHTTP implements the Go standard library's http.Handler // interface by responding to the gRPC request r, by looking up // the requested gRPC method in the gRPC server s. @@ -719,7 +695,7 @@ var _ http.Handler = (*Server)(nil) // available through grpc-go's HTTP/2 server, and it is currently EXPERIMENTAL // and subject to change. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandler) + st, err := transport.NewServerHandlerTransport(w, r) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -776,73 +752,39 @@ func (s *Server) removeConn(c io.Closer) { } } -func (s *Server) channelzMetric() *channelz.ServerInternalMetric { - return &channelz.ServerInternalMetric{ - CallsStarted: atomic.LoadInt64(&s.czData.callsStarted), - CallsSucceeded: atomic.LoadInt64(&s.czData.callsSucceeded), - CallsFailed: atomic.LoadInt64(&s.czData.callsFailed), - LastCallStartedTimestamp: time.Unix(0, atomic.LoadInt64(&s.czData.lastCallStartedTime)), - } -} - -func (s *Server) incrCallsStarted() { - atomic.AddInt64(&s.czData.callsStarted, 1) - atomic.StoreInt64(&s.czData.lastCallStartedTime, time.Now().UnixNano()) -} - -func (s *Server) incrCallsSucceeded() { - atomic.AddInt64(&s.czData.callsSucceeded, 1) -} - -func (s *Server) incrCallsFailed() { - atomic.AddInt64(&s.czData.callsFailed, 1) -} - func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options, comp encoding.Compressor) error { - data, err := encode(s.getCodec(stream.ContentSubtype()), msg) - if err != nil { - grpclog.Errorln("grpc: server failed to encode response: ", err) - return err + var ( + outPayload *stats.OutPayload + ) + if s.opts.statsHandler != nil { + outPayload = &stats.OutPayload{} } - compData, err := compress(data, cp, comp) + hdr, data, err := encode(s.opts.codec, msg, cp, outPayload, comp) if err != nil { - grpclog.Errorln("grpc: server failed to compress response: ", err) + grpclog.Errorln("grpc: server failed to encode response: ", err) return err } - hdr, payload := msgHeader(data, compData) - // TODO(dfawley): should we be checking len(data) instead? - if len(payload) > s.opts.maxSendMessageSize { - return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(payload), s.opts.maxSendMessageSize) + if len(data) > s.opts.maxSendMessageSize { + return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), s.opts.maxSendMessageSize) } - err = t.Write(stream, hdr, payload, opts) - if err == nil && s.opts.statsHandler != nil { - s.opts.statsHandler.HandleRPC(stream.Context(), outPayload(false, msg, data, payload, time.Now())) + err = t.Write(stream, hdr, data, opts) + if err == nil && outPayload != nil { + outPayload.SentTime = time.Now() + s.opts.statsHandler.HandleRPC(stream.Context(), outPayload) } return err } func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc, trInfo *traceInfo) (err error) { - if channelz.IsOn() { - s.incrCallsStarted() - defer func() { - if err != nil && err != io.EOF { - s.incrCallsFailed() - } else { - s.incrCallsSucceeded() - } - }() - } sh := s.opts.statsHandler if sh != nil { - beginTime := time.Now() begin := &stats.Begin{ - BeginTime: beginTime, + BeginTime: time.Now(), } sh.HandleRPC(stream.Context(), begin) defer func() { end := &stats.End{ - BeginTime: beginTime, - EndTime: time.Now(), + EndTime: time.Now(), } if err != nil && err != io.EOF { end.Error = toRPCErr(err) @@ -898,32 +840,77 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } } - var inPayload *stats.InPayload - if sh != nil { - inPayload = &stats.InPayload{ - RecvTime: time.Now(), - } + p := &parser{r: stream} + pf, req, err := p.recvMsg(s.opts.maxReceiveMessageSize) + if err == io.EOF { + // The entire stream is done (for unary RPC only). + return err + } + if err == io.ErrUnexpectedEOF { + err = status.Errorf(codes.Internal, io.ErrUnexpectedEOF.Error()) } - d, err := recvAndDecompress(&parser{r: stream}, stream, dc, s.opts.maxReceiveMessageSize, inPayload, decomp) if err != nil { if st, ok := status.FromError(err); ok { if e := t.WriteStatus(stream, st); e != nil { grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e) } + } else { + switch st := err.(type) { + case transport.ConnectionError: + // Nothing to do here. + case transport.StreamError: + if e := t.WriteStatus(stream, status.New(st.Code, st.Desc)); e != nil { + grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e) + } + default: + panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", st, st)) + } } return err } - if channelz.IsOn() { - t.IncrMsgRecv() + if st := checkRecvPayload(pf, stream.RecvCompress(), dc != nil || decomp != nil); st != nil { + if e := t.WriteStatus(stream, st); e != nil { + grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e) + } + return st.Err() + } + var inPayload *stats.InPayload + if sh != nil { + inPayload = &stats.InPayload{ + RecvTime: time.Now(), + } } df := func(v interface{}) error { - if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil { + if inPayload != nil { + inPayload.WireLength = len(req) + } + if pf == compressionMade { + var err error + if dc != nil { + req, err = dc.Do(bytes.NewReader(req)) + if err != nil { + return status.Errorf(codes.Internal, err.Error()) + } + } else { + tmp, _ := decomp.Decompress(bytes.NewReader(req)) + req, err = ioutil.ReadAll(tmp) + if err != nil { + return status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) + } + } + } + if len(req) > s.opts.maxReceiveMessageSize { + // TODO: Revisit the error code. Currently keep it consistent with + // java implementation. + return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(req), s.opts.maxReceiveMessageSize) + } + if err := s.opts.codec.Unmarshal(req, v); err != nil { return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err) } if inPayload != nil { inPayload.Payload = v - inPayload.Data = d - inPayload.Length = len(d) + inPayload.Data = req + inPayload.Length = len(req) sh.HandleRPC(stream.Context(), inPayload) } if trInfo != nil { @@ -931,13 +918,12 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } return nil } - ctx := NewContextWithServerTransportStream(stream.Context(), stream) - reply, appErr := md.Handler(srv.server, ctx, df, s.opts.unaryInt) + reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt) if appErr != nil { appStatus, ok := status.FromError(appErr) if !ok { // Convert appErr if it is not a grpc status error. - appErr = status.Error(codes.Unknown, appErr.Error()) + appErr = status.Error(convertCode(appErr), appErr.Error()) appStatus, _ = status.FromError(appErr) } if trInfo != nil { @@ -952,7 +938,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. if trInfo != nil { trInfo.tr.LazyLog(stringer("OK"), false) } - opts := &transport.Options{Last: true} + opts := &transport.Options{ + Last: true, + Delay: false, + } if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil { if err == io.EOF { @@ -967,15 +956,16 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. switch st := err.(type) { case transport.ConnectionError: // Nothing to do here. + case transport.StreamError: + if e := t.WriteStatus(stream, status.New(st.Code, st.Desc)); e != nil { + grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e) + } default: panic(fmt.Sprintf("grpc: Unexpected error (%T) from sendResponse: %v", st, st)) } } return err } - if channelz.IsOn() { - t.IncrMsgSent() - } if trInfo != nil { trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true) } @@ -986,27 +976,15 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) { - if channelz.IsOn() { - s.incrCallsStarted() - defer func() { - if err != nil && err != io.EOF { - s.incrCallsFailed() - } else { - s.incrCallsSucceeded() - } - }() - } sh := s.opts.statsHandler if sh != nil { - beginTime := time.Now() begin := &stats.Begin{ - BeginTime: beginTime, + BeginTime: time.Now(), } sh.HandleRPC(stream.Context(), begin) defer func() { end := &stats.End{ - BeginTime: beginTime, - EndTime: time.Now(), + EndTime: time.Now(), } if err != nil && err != io.EOF { end.Error = toRPCErr(err) @@ -1014,13 +992,11 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp sh.HandleRPC(stream.Context(), end) }() } - ctx := NewContextWithServerTransportStream(stream.Context(), stream) ss := &serverStream{ - ctx: ctx, - t: t, - s: stream, - p: &parser{r: stream}, - codec: s.getCodec(stream.ContentSubtype()), + t: t, + s: stream, + p: &parser{r: stream}, + codec: s.opts.codec, maxReceiveMessageSize: s.opts.maxReceiveMessageSize, maxSendMessageSize: s.opts.maxSendMessageSize, trInfo: trInfo, @@ -1086,7 +1062,12 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp if appErr != nil { appStatus, ok := status.FromError(appErr) if !ok { - appStatus = status.New(codes.Unknown, appErr.Error()) + switch err := appErr.(type) { + case transport.StreamError: + appStatus = status.New(err.Code, err.Desc) + default: + appStatus = status.New(convertCode(appErr), appErr.Error()) + } appErr = appStatus.Err() } if trInfo != nil { @@ -1105,6 +1086,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp ss.mu.Unlock() } return t.WriteStatus(ss.s, status.New(codes.OK, "")) + } func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream, trInfo *traceInfo) { @@ -1133,27 +1115,47 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str } service := sm[:pos] method := sm[pos+1:] - - if srv, ok := s.m[service]; ok { - if md, ok := srv.md[method]; ok { - s.processUnaryRPC(t, stream, srv, md, trInfo) + srv, ok := s.m[service] + if !ok { + if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil { + s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo) return } - if sd, ok := srv.sd[method]; ok { - s.processStreamingRPC(t, stream, srv, sd, trInfo) - return + if trInfo != nil { + trInfo.tr.LazyLog(&fmtStringer{"Unknown service %v", []interface{}{service}}, true) + trInfo.tr.SetError() } + errDesc := fmt.Sprintf("unknown service %v", service) + if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil { + if trInfo != nil { + trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) + trInfo.tr.SetError() + } + grpclog.Warningf("grpc: Server.handleStream failed to write status: %v", err) + } + if trInfo != nil { + trInfo.tr.Finish() + } + return } - // Unknown service, or known server unknown method. - if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil { - s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo) + // Unary RPC or Streaming RPC? + if md, ok := srv.md[method]; ok { + s.processUnaryRPC(t, stream, srv, md, trInfo) + return + } + if sd, ok := srv.sd[method]; ok { + s.processStreamingRPC(t, stream, srv, sd, trInfo) return } if trInfo != nil { - trInfo.tr.LazyLog(&fmtStringer{"Unknown service %v", []interface{}{service}}, true) + trInfo.tr.LazyLog(&fmtStringer{"Unknown method %v", []interface{}{method}}, true) trInfo.tr.SetError() } - errDesc := fmt.Sprintf("unknown service %v", service) + if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil { + s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo) + return + } + errDesc := fmt.Sprintf("unknown method %v", method) if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil { if trInfo != nil { trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) @@ -1166,42 +1168,6 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str } } -// The key to save ServerTransportStream in the context. -type streamKey struct{} - -// NewContextWithServerTransportStream creates a new context from ctx and -// attaches stream to it. -// -// This API is EXPERIMENTAL. -func NewContextWithServerTransportStream(ctx context.Context, stream ServerTransportStream) context.Context { - return context.WithValue(ctx, streamKey{}, stream) -} - -// ServerTransportStream is a minimal interface that a transport stream must -// implement. This can be used to mock an actual transport stream for tests of -// handler code that use, for example, grpc.SetHeader (which requires some -// stream to be in context). -// -// See also NewContextWithServerTransportStream. -// -// This API is EXPERIMENTAL. -type ServerTransportStream interface { - Method() string - SetHeader(md metadata.MD) error - SendHeader(md metadata.MD) error - SetTrailer(md metadata.MD) error -} - -// ServerTransportStreamFromContext returns the ServerTransportStream saved in -// ctx. Returns nil if the given context has no stream associated with it -// (which implies it is not an RPC invocation context). -// -// This API is EXPERIMENTAL. -func ServerTransportStreamFromContext(ctx context.Context) ServerTransportStream { - s, _ := ctx.Value(streamKey{}).(ServerTransportStream) - return s -} - // Stop stops the gRPC server. It immediately closes all open // connections and listeners. // It cancels all active RPCs on the server side and the corresponding @@ -1219,12 +1185,6 @@ func (s *Server) Stop() { }) }() - s.channelzRemoveOnce.Do(func() { - if channelz.IsOn() { - channelz.RemoveEntry(s.channelzID) - } - }) - s.mu.Lock() listeners := s.lis s.lis = nil @@ -1263,17 +1223,11 @@ func (s *Server) GracefulStop() { }) }() - s.channelzRemoveOnce.Do(func() { - if channelz.IsOn() { - channelz.RemoveEntry(s.channelzID) - } - }) s.mu.Lock() if s.conns == nil { s.mu.Unlock() return } - for lis := range s.lis { lis.Close() } @@ -1302,20 +1256,10 @@ func (s *Server) GracefulStop() { s.mu.Unlock() } -// contentSubtype must be lowercase -// cannot return nil -func (s *Server) getCodec(contentSubtype string) baseCodec { - if s.opts.codec != nil { - return s.opts.codec - } - if contentSubtype == "" { - return encoding.GetCodec(proto.Name) - } - codec := encoding.GetCodec(contentSubtype) - if codec == nil { - return encoding.GetCodec(proto.Name) +func init() { + internal.TestingUseHandlerImpl = func(arg interface{}) { + arg.(*Server).opts.useHandlerImpl = true } - return codec } // SetHeader sets the header metadata. @@ -1328,8 +1272,8 @@ func SetHeader(ctx context.Context, md metadata.MD) error { if md.Len() == 0 { return nil } - stream := ServerTransportStreamFromContext(ctx) - if stream == nil { + stream, ok := transport.StreamFromContext(ctx) + if !ok { return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) } return stream.SetHeader(md) @@ -1338,11 +1282,15 @@ func SetHeader(ctx context.Context, md metadata.MD) error { // SendHeader sends header metadata. It may be called at most once. // The provided md and headers set by SetHeader() will be sent. func SendHeader(ctx context.Context, md metadata.MD) error { - stream := ServerTransportStreamFromContext(ctx) - if stream == nil { + stream, ok := transport.StreamFromContext(ctx) + if !ok { return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) } - if err := stream.SendHeader(md); err != nil { + t := stream.ServerTransport() + if t == nil { + grpclog.Fatalf("grpc: SendHeader: %v has no ServerTransport to send header metadata.", stream) + } + if err := t.WriteHeader(stream, md); err != nil { return toRPCErr(err) } return nil @@ -1354,27 +1302,9 @@ func SetTrailer(ctx context.Context, md metadata.MD) error { if md.Len() == 0 { return nil } - stream := ServerTransportStreamFromContext(ctx) - if stream == nil { + stream, ok := transport.StreamFromContext(ctx) + if !ok { return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) } return stream.SetTrailer(md) } - -// Method returns the method string for the server context. The returned -// string is in the format of "/service/method". -func Method(ctx context.Context) (string, bool) { - s := ServerTransportStreamFromContext(ctx) - if s == nil { - return "", false - } - return s.Method(), true -} - -type channelzServer struct { - s *Server -} - -func (c *channelzServer) ChannelzMetric() *channelz.ServerInternalMetric { - return c.s.channelzMetric() -} |