Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJacob Vosmaer <jacob@gitlab.com>2021-06-17 14:35:51 +0300
committerJacob Vosmaer <jacob@gitlab.com>2021-07-12 15:34:41 +0300
commit8a925b40a5e35848600600ee72441224f99af0fa (patch)
treea6c0ea5eee006c42f095dbaa31e0cda20d4ca773
parent357a4b1dab81459b3d9323f10737b1698586cd5c (diff)
Add StreamRPC library code
Changelog: other
-rw-r--r--Makefile3
-rw-r--r--doc/README.md1
-rw-r--r--doc/stream_rpc.md101
-rw-r--r--internal/streamrpc/client.go145
-rw-r--r--internal/streamrpc/common.go82
-rw-r--r--internal/streamrpc/frame_test.go139
-rw-r--r--internal/streamrpc/protocol_test.go117
-rw-r--r--internal/streamrpc/rpc_test.go315
-rw-r--r--internal/streamrpc/server.go166
-rw-r--r--internal/streamrpc/testdata/test.pb.go167
-rw-r--r--internal/streamrpc/testdata/test.proto16
-rw-r--r--internal/streamrpc/testdata/test_grpc.pb.go102
12 files changed, 1353 insertions, 1 deletions
diff --git a/Makefile b/Makefile
index 399684425..ec9896dbb 100644
--- a/Makefile
+++ b/Makefile
@@ -336,7 +336,8 @@ proto: ${PROTOC} ${PROTOC_GEN_GO} ${PROTOC_GEN_GO_GRPC} ${SOURCE_DIR}/.ruby-bund
${SOURCE_DIR}/internal/praefect/mock/mock.proto \
${SOURCE_DIR}/internal/middleware/cache/testdata/stream.proto \
${SOURCE_DIR}/internal/helper/chunk/testdata/test.proto \
- ${SOURCE_DIR}/internal/middleware/limithandler/testdata/test.proto
+ ${SOURCE_DIR}/internal/middleware/limithandler/testdata/test.proto \
+ ${SOURCE_DIR}/internal/streamrpc/testdata/test.proto
${PROTOC} ${SHARED_PROTOC_OPTS} -I ${SOURCE_DIR}/proto --go_out=${SOURCE_DIR}/proto --go-grpc_out=${SOURCE_DIR}/proto ${SOURCE_DIR}/proto/go/internal/linter/testdata/*.proto
.PHONY: lint-proto
diff --git a/doc/README.md b/doc/README.md
index 96fffbad1..fd96df56a 100644
--- a/doc/README.md
+++ b/doc/README.md
@@ -38,6 +38,7 @@ For configuration please read [praefects configuration documentation](doc/config
- [Tips for reading Git source code](reading_git_source.md)
- [Serverside Git Usage](serverside_git_usage.md)
- [Object Pools](object_pools.md)
+- [StreamRPC](stream_rpc.md)
#### RFCs
diff --git a/doc/stream_rpc.md b/doc/stream_rpc.md
new file mode 100644
index 000000000..f78c6723f
--- /dev/null
+++ b/doc/stream_rpc.md
@@ -0,0 +1,101 @@
+# StreamRPC
+
+StreamRPC is a remote procedure call (RPC) protocol implemented by
+Gitaly. It is used for RPC's that transfer a high volume of byte stream
+data, such as the server side of `git fetch`.
+
+For background on why we created StreamRPC, see
+https://gitlab.com/groups/gitlab-com/gl-infra/-/epics/463.
+
+## Design goals
+
+1. Give RPC handlers direct access to the underlying network socket or
+TLS stream
+1. Interoperate with existing Gitaly gRPC middlewares (logging,
+authentication, metrics etc.)
+1. Allow for efficient proxying in Praefect
+
+## Semantics
+
+A StreamRPC call has two phases: the handshake phase and the stream
+phase. The structure of the handshake phase is described by the
+StreamRPC protocol. The stream phase has no inherent structure as far
+as StreamRPC is concerned; it is up to the RPC client and RPC handler
+what they do with the stream.
+
+The handshake phase consists of two steps:
+
+1. Client sends request
+1. Server sends response
+
+To allow for a clean transition from the handshake phase to the stream
+phase, the handshake phase uses frames with length prefixes. Length
+prefixes make it possible to implement the handshake without buffered
+IO. When the transition to the stream phase happens, we do not have to
+hand over buffered data.
+
+All length prefixes in the handshake phase are big-endian uint32
+values (4 bytes). Length prefixes do not include the length of the
+prefix itself, so a blob `foobar` would be encoded as
+`\x00\x00\x00\x06foobar`.
+
+### Request
+
+The request consists of a length prefix and a JSON object.
+
+The JSON object has three fields:
+
+1. `Method` (`string`): this is a gRPC style method name, including the
+gRPC service name.
+1. `Metadata` (`map[string][]string`): this contains gRPC metadata, which
+can be compared to HTTP headers. It is used for authentication,
+correlation ID's, etc.
+1. `Message` (`string`): this field contains a base64-encoded (RFC 4648)
+Protobuf message. This is here because Praefect, and some of our
+middlewares, try to inspect the first request of each RPC to see what
+repository etc. it targets. By having this request as part of the
+protocol, we can support Praefect and gRPC middlewares in a natural
+way.
+
+### Response
+
+The response is a length-prefixed empty string OR a JSON object.
+
+The server accepts the request and transitions to the stream phase if
+and only if the frame is empty. That is, the accepting response is
+`\x00\x00\x00\x00`.
+
+In case of a rejection, the server returns a JSON object with an error
+message in the `Error` field.
+
+If the server rejects the request it will close the connection.
+
+## Relation to gRPC
+
+StreamRPC is designed to be embedded into a gRPC service (Gitaly).
+StreamRPC RPC's are defined using Protobuf just like regular Gitaly
+gRPC RPC's. From the point of view of gRPC middleware, StreamRPC RPC's
+are unary methods which return an empty response message.
+
+```protobuf
+import "google/protobuf/empty.proto";
+
+service ExampleService {
+ rpc ExampleStream(ExampleStreamRequest) returns (google.protobuf.Empty) {
+ option (op_type) = {
+ op: ACCESSOR
+ };
+ }
+}
+
+message ExampleStreamRequest {
+ Repository repository = 1 [(target_repository)=true];
+}
+```
+
+The server handler may return an error after the stream phase, which
+will be logged on the server, but this error cannot be transmitted to
+the client. This is because the stream phase lasts until the
+connection is closed. There is no way for the server to transmit the
+error "after the stream phase", because the connection is then already
+closed.
diff --git a/internal/streamrpc/client.go b/internal/streamrpc/client.go
new file mode 100644
index 000000000..b26d43f33
--- /dev/null
+++ b/internal/streamrpc/client.go
@@ -0,0 +1,145 @@
+package streamrpc
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net"
+ "time"
+
+ "github.com/golang/protobuf/proto"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/metadata"
+)
+
+// DialFunc is an abstraction that allows Call to transparently handle
+// unencrypted connections and TLS connections.
+type DialFunc func(time.Duration) (net.Conn, error)
+
+func doCall(dial DialFunc, request []byte, callback func(net.Conn) error) error {
+ deadline := time.Now().Add(defaultHandshakeTimeout)
+
+ c, err := dial(time.Until(deadline))
+ if err != nil {
+ return fmt.Errorf("dial: %w", err)
+ }
+ defer c.Close()
+
+ if err := sendFrame(c, request, deadline); err != nil {
+ return fmt.Errorf("send request: %w", err)
+ }
+
+ responseBytes, err := recvFrame(c, deadline)
+ if err != nil {
+ return fmt.Errorf("receive response: %w", err)
+ }
+
+ if len(responseBytes) > 0 {
+ var resp response
+ if err := json.Unmarshal(responseBytes, &resp); err != nil {
+ return fmt.Errorf("unmarshal response: %w", err)
+ }
+
+ return &RequestRejectedError{resp.Error}
+ }
+
+ return callback(c)
+}
+
+// RequestRejectedError is returned by Call if the server explicitly
+// rejected the request (as opposed to e.g. an IO timeout).
+type RequestRejectedError struct{ message string }
+
+func (r *RequestRejectedError) Error() string { return r.message }
+
+type callOptions struct {
+ creds credentials.PerRPCCredentials
+ interceptor grpc.UnaryClientInterceptor
+}
+
+func (opts *callOptions) addCredentials(ctx context.Context) (context.Context, error) {
+ headers, err := opts.creds.GetRequestMetadata(ctx)
+ if err != nil {
+ return nil, err
+ }
+ for k, v := range headers {
+ ctx = metadata.AppendToOutgoingContext(ctx, k, v)
+ }
+ return ctx, nil
+}
+
+// CallOption is an abstraction that lets us pass 0 or more options to a call.
+type CallOption func(*callOptions)
+
+type nullCredentials struct{}
+
+func (nullCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
+ return nil, nil
+}
+
+func (nullCredentials) RequireTransportSecurity() bool { return false }
+
+var _ credentials.PerRPCCredentials = nullCredentials{}
+
+func nullClientInterceptor(ctx context.Context, method string, req, resp interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
+ return invoker(ctx, method, req, resp, cc, opts...)
+}
+
+func defaultCallOptions() *callOptions {
+ return &callOptions{
+ creds: nullCredentials{},
+ interceptor: nullClientInterceptor,
+ }
+}
+
+// WithCredentials adds gRPC per-request credentials to an outgoing call.
+func WithCredentials(creds credentials.PerRPCCredentials) CallOption {
+ return func(o *callOptions) { o.creds = creds }
+}
+
+// WithClientInterceptor adds a gRPC unary client interceptor to an outgoing call.
+func WithClientInterceptor(interceptor grpc.UnaryClientInterceptor) CallOption {
+ return func(o *callOptions) { o.interceptor = interceptor }
+}
+
+// Call makes a StreamRPC call. The dial function determines the remote
+// address. The method argument is the full name of the StreamRPC method
+// we are calling (e.g. "/foo.BarService/BazStream"). Msg is the request
+// message. If the server accepts the call, callback is called with a
+// connection.
+func Call(ctx context.Context, dial DialFunc, method string, msg proto.Message, callback func(net.Conn) error, opts ...CallOption) (err error) {
+ callOpts := defaultCallOptions()
+ for _, o := range opts {
+ o(callOpts)
+ }
+
+ invoke := func(ctx context.Context, method string, msg, _ interface{}, _ *grpc.ClientConn, _ ...grpc.CallOption) error {
+ ctx, err := callOpts.addCredentials(ctx)
+ if err != nil {
+ return err
+ }
+
+ msgBytes, err := proto.Marshal(msg.(proto.Message))
+ if err != nil {
+ return err
+ }
+
+ req := &request{
+ Method: method,
+ Message: msgBytes,
+ }
+ if md, ok := metadata.FromOutgoingContext(ctx); ok {
+ req.Metadata = md
+ }
+
+ reqBytes, err := json.Marshal(req)
+ if err != nil {
+ return err
+ }
+
+ return doCall(dial, reqBytes, callback)
+ }
+
+ return callOpts.interceptor(ctx, method, msg, nil, nil, invoke)
+}
diff --git a/internal/streamrpc/common.go b/internal/streamrpc/common.go
new file mode 100644
index 000000000..88c3b6148
--- /dev/null
+++ b/internal/streamrpc/common.go
@@ -0,0 +1,82 @@
+package streamrpc
+
+import (
+ "encoding/binary"
+ "errors"
+ "io"
+ "net"
+ "os"
+ "time"
+)
+
+type request struct {
+ Method string
+ Message []byte
+ Metadata map[string][]string
+}
+
+type response struct{ Error string }
+
+const (
+ defaultHandshakeTimeout = 10 * time.Second
+
+ // The frames exchanged during the handshake have a uint32 length prefix
+ // so their theoretical maximum size is 4GB. We don't want to allow that
+ // so we enforce a lower limit. This number was chosen because it is
+ // close to the default grpc-go maximum message size.
+ maxFrameSize = (1 << 20) - 1
+)
+
+var (
+ errFrameTooLarge = errors.New("frame too large")
+)
+
+func sendFrame(c net.Conn, frame []byte, deadline time.Time) error {
+ if len(frame) > maxFrameSize {
+ return errFrameTooLarge
+ }
+
+ header := make([]byte, 4)
+ binary.BigEndian.PutUint32(header, uint32(len(frame)))
+ buffers := net.Buffers([][]byte{header, frame})
+
+ return errAsync(deadline, func() error { _, err := buffers.WriteTo(c); return err })
+}
+
+func recvFrame(c net.Conn, deadline time.Time) ([]byte, error) {
+ header := make([]byte, 4)
+ if err := errAsync(deadline, func() error { _, err := io.ReadFull(c, header); return err }); err != nil {
+ return nil, err
+ }
+
+ size := binary.BigEndian.Uint32(header)
+ if size > maxFrameSize {
+ return nil, errFrameTooLarge
+ }
+ frame := make([]byte, size)
+ if err := errAsync(deadline, func() error { _, err := io.ReadFull(c, frame); return err }); err != nil {
+ return nil, err
+ }
+
+ return frame, nil
+}
+
+// errAsync is a hack to work around the fact that grpc-go calls
+// SetDeadline on connections _after_ handing them over to us. Because of
+// this race, we cannot use SetDeadline which would have been nicer.
+//
+// https://github.com/grpc/grpc-go/blob/v1.38.0/server.go#L853
+func errAsync(deadline time.Time, f func() error) error {
+ tm := time.NewTimer(time.Until(deadline))
+ defer tm.Stop()
+
+ errC := make(chan error, 1)
+ go func() { errC <- f() }()
+
+ select {
+ case <-tm.C:
+ return os.ErrDeadlineExceeded
+ case err := <-errC:
+ return err
+ }
+}
diff --git a/internal/streamrpc/frame_test.go b/internal/streamrpc/frame_test.go
new file mode 100644
index 000000000..cb665a47f
--- /dev/null
+++ b/internal/streamrpc/frame_test.go
@@ -0,0 +1,139 @@
+package streamrpc
+
+import (
+ "io"
+ "io/ioutil"
+ "net"
+ "os"
+ "strings"
+ "syscall"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestSendFrame(t *testing.T) {
+ largeString := strings.Repeat("x", 0xfffff)
+
+ testCases := []struct {
+ desc string
+ in string
+ out string
+ err error
+ }{
+ {desc: "empty", out: "\x00\x00\x00\x00"},
+ {desc: "not empty", in: "hello", out: "\x00\x00\x00\x05hello"},
+ {desc: "very large", in: largeString, out: "\x00\x0f\xff\xff" + largeString},
+ {desc: "too large", in: "z" + largeString, err: errFrameTooLarge},
+ }
+
+ type result struct {
+ data string
+ err error
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ client, server := socketPair(t)
+ ch := make(chan result, 1)
+ go func() {
+ out, err := ioutil.ReadAll(server)
+ ch <- result{data: string(out), err: err}
+ }()
+
+ err := sendFrame(client, []byte(tc.in), time.Now().Add(10*time.Second))
+ if tc.err != nil {
+ require.Equal(t, tc.err, err)
+ return
+ }
+
+ require.NoError(t, err)
+ require.NoError(t, client.Close())
+
+ res := <-ch
+ require.NoError(t, res.err)
+ require.Equal(t, tc.out, res.data)
+ })
+ }
+}
+
+func TestSendFrame_timeout(t *testing.T) {
+ client, _ := socketPair(t)
+
+ // Ensure frame is bigger than write buffer, so that sendFrame will
+ // block. Otherwise we cannot observe the timeout behavior.
+ frame := make([]byte, 10*1024)
+ require.NoError(t, client.(*net.UnixConn).SetWriteBuffer(1024))
+
+ err := sendFrame(client, frame, time.Now())
+ require.ErrorIs(t, err, os.ErrDeadlineExceeded)
+}
+
+func TestRecvFrame(t *testing.T) {
+ largeString := strings.Repeat("x", 0xfffff)
+
+ testCases := []struct {
+ desc string
+ out string
+ in string
+ err error
+ }{
+ {desc: "empty", in: "\x00\x00\x00\x00", out: ""},
+ {desc: "not empty", in: "\x00\x00\x00\x05hello", out: "hello"},
+ {desc: "very large", in: "\x00\x0f\xff\xff" + largeString, out: largeString},
+ {desc: "too large", in: "\x00\x10\x00\x00" + "z" + largeString, err: errFrameTooLarge},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ client, server := socketPair(t)
+ ch := make(chan error, 1)
+ go func() {
+ ch <- func() error {
+ n, err := server.Write([]byte(tc.in))
+ if n != len(tc.in) {
+ return io.ErrShortWrite
+ }
+ return err
+ }()
+ }()
+
+ out, err := recvFrame(client, time.Now().Add(10*time.Second))
+ if tc.err != nil {
+ require.Equal(t, tc.err, err)
+ return
+ }
+
+ require.NoError(t, err)
+ require.Equal(t, tc.out, string(out))
+
+ require.NoError(t, <-ch)
+ })
+ }
+}
+
+func TestRecvFrame_timeout(t *testing.T) {
+ client, _ := socketPair(t)
+ _, err := recvFrame(client, time.Now())
+ require.ErrorIs(t, err, os.ErrDeadlineExceeded)
+}
+
+func socketPair(t *testing.T) (net.Conn, net.Conn) {
+ t.Helper()
+
+ fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
+ require.NoError(t, err)
+
+ conns := make([]net.Conn, 2)
+ for i, fd := range fds[:] {
+ f := os.NewFile(uintptr(fd), "socket pair")
+ c, err := net.FileConn(f)
+ require.NoError(t, err)
+ require.NoError(t, f.Close())
+ t.Cleanup(func() { c.Close() })
+ conns[i] = c
+ }
+
+ return conns[0], conns[1]
+}
diff --git a/internal/streamrpc/protocol_test.go b/internal/streamrpc/protocol_test.go
new file mode 100644
index 000000000..5d169c179
--- /dev/null
+++ b/internal/streamrpc/protocol_test.go
@@ -0,0 +1,117 @@
+package streamrpc
+
+import (
+ "context"
+ "fmt"
+ "io/ioutil"
+ "sort"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+ testpb "gitlab.com/gitlab-org/gitaly/v14/internal/streamrpc/testdata"
+ "google.golang.org/grpc/metadata"
+ "google.golang.org/protobuf/types/known/emptypb"
+)
+
+// TestProtocol exclusively uses hard-coded strings to prevent breaking
+// changes in the wire protocol.
+func TestProtocol(t *testing.T) {
+ testCases := []struct {
+ desc string
+ in string
+ out string
+ }{
+ {
+ desc: "successful request",
+ in: "\x00\x00\x00\x28" + `{"Method":"/test.streamrpc.Test/Stream"}`,
+ out: strings.Join([]string{
+ "\x00\x00\x00\x00", // Server accepts
+ "\n", // Handler prints request field (empty) followed by newline
+ }, ""),
+ },
+ {
+ desc: "unknown method",
+ in: "\x00\x00\x00\x1b" + `{"Method":"does not exist"}`,
+ out: strings.Join([]string{
+ "\x00\x00\x00\x2c", // Server rejects by sending non-empty error message
+ `{"Error":"method not found: does not exist"}`,
+ }, ""),
+ },
+ {
+ desc: "request with message and metadata",
+ in: strings.Join([]string{
+ "\x00\x00\x00\x73",
+ `{`,
+ `"Method":"/test.streamrpc.Test/Stream",`,
+ `"Message":"EgtoZWxsbyB3b3JsZA==",`, // &testpb.StreamRequest{StringField: "hello world"}
+ `"Metadata":{"k1":["v1","v2"],"k2":["v3"]}`,
+ `}`,
+ }, ""),
+ out: strings.Join([]string{
+ "\x00\x00\x00\x00", // Server accepts
+ "k1: v1\nk1: v2\nk2: v3\n", // Server echoes metadata key-value pairs
+ "hello world\n", // Server echoes field from request message
+ }, ""),
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ dial := startServer(
+ t,
+ NewServer(),
+ func(ctx context.Context, in *testpb.StreamRequest) (*emptypb.Empty, error) {
+ c, err := AcceptConnection(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ var mdKeys []string
+ var md metadata.MD
+ if ctxMD, ok := metadata.FromIncomingContext(ctx); ok {
+ md = ctxMD
+ }
+
+ for k := range md {
+ mdKeys = append(mdKeys, k)
+ }
+
+ // Direct go map iteration is non-deterministic. Sort the keys to make it
+ // deterministic.
+ sort.Strings(mdKeys)
+
+ // Echo back metadata so tests can see it was received correctly
+ for _, k := range mdKeys {
+ for _, v := range md[k] {
+ if _, err := fmt.Fprintf(c, "%s: %s\n", k, v); err != nil {
+ return nil, err
+ }
+ }
+ }
+
+ // Echo back string field so tests can see request was received correctly
+ if _, err := fmt.Fprintln(c, in.StringField); err != nil {
+ return nil, err
+ }
+
+ return nil, nil
+ },
+ )
+
+ c, err := dial(10 * time.Second)
+ require.NoError(t, err)
+ defer c.Close()
+ require.NoError(t, c.SetDeadline(time.Now().Add(10*time.Second)))
+
+ n, err := c.Write([]byte(tc.in))
+ require.NoError(t, err)
+ require.Equal(t, len(tc.in), n)
+
+ out, err := ioutil.ReadAll(c)
+ require.NoError(t, err)
+ require.Equal(t, tc.out, string(out))
+ })
+ }
+}
diff --git a/internal/streamrpc/rpc_test.go b/internal/streamrpc/rpc_test.go
new file mode 100644
index 000000000..669222780
--- /dev/null
+++ b/internal/streamrpc/rpc_test.go
@@ -0,0 +1,315 @@
+package streamrpc
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "io"
+ "io/ioutil"
+ "math/rand"
+ "net"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+ testpb "gitlab.com/gitlab-org/gitaly/v14/internal/streamrpc/testdata"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/metadata"
+ "google.golang.org/protobuf/types/known/emptypb"
+)
+
+func TestCall(t *testing.T) {
+ const (
+ testKey = "test key"
+ testValue = "test value"
+ blobSize = 1024 * 1024
+ )
+
+ var receivedValues []string
+ var receivedField string
+
+ dial := startServer(
+ t,
+ NewServer(),
+ func(ctx context.Context, in *testpb.StreamRequest) (*emptypb.Empty, error) {
+ receivedField = in.StringField
+
+ if md, ok := metadata.FromIncomingContext(ctx); ok {
+ receivedValues = md[testKey]
+ }
+
+ c, err := AcceptConnection(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ _, err = io.CopyN(c, c, blobSize)
+ return nil, err
+ },
+ )
+
+ in := make([]byte, blobSize)
+ _, err := rand.Read(in)
+ require.NoError(t, err)
+
+ var out []byte
+ require.NotEqual(t, in, out)
+
+ ctx := metadata.AppendToOutgoingContext(context.Background(), testKey, testValue)
+ require.NoError(t, Call(
+ ctx,
+ dial,
+ "/test.streamrpc.Test/Stream",
+ &testpb.StreamRequest{StringField: "hello world"},
+ func(c net.Conn) error {
+ errC := make(chan error, 1)
+ go func() {
+ var err error
+ out, err = ioutil.ReadAll(c)
+ errC <- err
+ }()
+
+ if _, err := io.Copy(c, bytes.NewReader(in)); err != nil {
+ return err
+ }
+ if err := <-errC; err != nil {
+ return err
+ }
+
+ return c.Close()
+ },
+ ))
+
+ require.Equal(t, "hello world", receivedField, "request propagates")
+ require.Equal(t, []string{testValue}, receivedValues, "grpc metadata stored in client ctx propagates")
+ require.Equal(t, in, out, "byte stream works")
+}
+
+func TestCall_serverError(t *testing.T) {
+ dial := startServer(
+ t,
+ NewServer(),
+ func(ctx context.Context, in *testpb.StreamRequest) (*emptypb.Empty, error) {
+ return nil, errors.New("this is the server error")
+ },
+ )
+
+ callError := Call(
+ context.Background(),
+ dial,
+ "/test.streamrpc.Test/Stream",
+ &testpb.StreamRequest{},
+ func(c net.Conn) error { panic("never reached") },
+ )
+
+ require.Equal(t, &RequestRejectedError{"this is the server error"}, callError)
+}
+
+func TestCall_clientMiddleware(t *testing.T) {
+ const (
+ testKey = "test key"
+ testValue = "test value"
+ )
+
+ var receivedValues []string
+ var receivedField string
+
+ dial := startServer(
+ t,
+ NewServer(),
+ func(ctx context.Context, in *testpb.StreamRequest) (*emptypb.Empty, error) {
+ _, err := AcceptConnection(ctx)
+ return nil, err
+ },
+ )
+
+ var middlewareMethod string
+ ctx := metadata.AppendToOutgoingContext(context.Background(), testKey, testValue)
+
+ const testMethod = "/test.streamrpc.Test/Stream"
+ require.NoError(t, Call(
+ ctx,
+ dial,
+ testMethod,
+ &testpb.StreamRequest{StringField: "hello world"},
+ func(c net.Conn) error { return nil },
+ WithClientInterceptor(func(ctx context.Context, method string, req, _ interface{}, _ *grpc.ClientConn, invoker grpc.UnaryInvoker, _ ...grpc.CallOption) error {
+ middlewareMethod = method
+ receivedField = req.(*testpb.StreamRequest).StringField
+ if md, ok := metadata.FromOutgoingContext(ctx); ok {
+ receivedValues = md[testKey]
+ }
+ return invoker(ctx, method, req, nil, nil)
+ }),
+ ))
+
+ require.Equal(t, testMethod, middlewareMethod, "client middleware sees correct method")
+ require.Equal(t, "hello world", receivedField, "client middleware sees request")
+ require.Equal(t, []string{testValue}, receivedValues, "client middleware sees context metadata")
+}
+
+func TestCall_clientMiddlewareReject(t *testing.T) {
+ dial := startServer(
+ t,
+ NewServer(),
+ func(ctx context.Context, in *testpb.StreamRequest) (*emptypb.Empty, error) {
+ panic("never reached")
+ },
+ )
+
+ middlewareError := errors.New("middleware says no")
+
+ err := Call(
+ context.Background(),
+ dial,
+ "/test.streamrpc.Test/Stream",
+ &testpb.StreamRequest{StringField: "hello world"},
+ func(c net.Conn) error { return nil },
+ WithClientInterceptor(func(ctx context.Context, method string, req, _ interface{}, _ *grpc.ClientConn, invoker grpc.UnaryInvoker, _ ...grpc.CallOption) error {
+ return middlewareError
+ }),
+ )
+
+ require.Equal(t, middlewareError, err)
+}
+
+func TestCall_serverMiddleware(t *testing.T) {
+ const (
+ testKey = "test key"
+ testValue = "test value"
+ testMethod = "/test.streamrpc.Test/Stream"
+ )
+
+ var (
+ receivedField string
+ middlewareMethod string
+ receivedValues []string
+ )
+
+ interceptorDone := make(chan struct{})
+
+ dial := startServer(
+ t,
+ NewServer(WithServerInterceptor(func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
+ defer close(interceptorDone)
+ middlewareMethod = info.FullMethod
+ receivedField = req.(*testpb.StreamRequest).StringField
+ if md, ok := metadata.FromIncomingContext(ctx); ok {
+ receivedValues = md[testKey]
+ }
+ return handler(ctx, req)
+ })),
+ func(ctx context.Context, in *testpb.StreamRequest) (*emptypb.Empty, error) {
+ _, err := AcceptConnection(ctx)
+ return nil, err
+ },
+ )
+
+ ctx := metadata.AppendToOutgoingContext(context.Background(), testKey, testValue)
+ require.NoError(t, Call(
+ ctx,
+ dial,
+ testMethod,
+ &testpb.StreamRequest{StringField: "hello world"},
+ func(c net.Conn) error { return nil },
+ ))
+
+ <-interceptorDone
+ require.Equal(t, testMethod, middlewareMethod, "server middleware sees correct method")
+ require.Equal(t, "hello world", receivedField, "server middleware sees request")
+ require.Equal(t, []string{testValue}, receivedValues, "server middleware sees context metadata")
+}
+
+func TestCall_serverMiddlewareReject(t *testing.T) {
+ dial := startServer(
+ t,
+ NewServer(WithServerInterceptor(func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
+ return nil, errors.New("middleware says no")
+ })),
+ func(ctx context.Context, in *testpb.StreamRequest) (*emptypb.Empty, error) {
+ panic("never reached")
+ },
+ )
+
+ err := Call(
+ context.Background(),
+ dial,
+ "/test.streamrpc.Test/Stream",
+ &testpb.StreamRequest{},
+ func(c net.Conn) error { return nil },
+ )
+
+ require.Equal(t, &RequestRejectedError{message: "middleware says no"}, err)
+}
+
+type testCredentials struct {
+ values map[string]string
+}
+
+func (tc *testCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
+ out := make(map[string]string)
+ for k, v := range tc.values {
+ out[k] = v
+ }
+ return out, nil
+}
+
+func (*testCredentials) RequireTransportSecurity() bool { return false }
+
+func TestCall_credentials(t *testing.T) {
+ receivedValues := make(map[string]string)
+ interceptorDone := make(chan struct{})
+
+ dial := startServer(
+ t,
+ NewServer(),
+ func(ctx context.Context, in *testpb.StreamRequest) (*emptypb.Empty, error) {
+ defer close(interceptorDone)
+
+ if md, ok := metadata.FromIncomingContext(ctx); ok {
+ receivedValues["key 1"] = strings.Join(md["key 1"], ",")
+ receivedValues["key 2"] = strings.Join(md["key 2"], ",")
+ }
+
+ _, err := AcceptConnection(ctx)
+ return nil, err
+ },
+ )
+
+ inputs := map[string]string{
+ "key 1": "value a",
+ "key 2": "value b",
+ }
+
+ require.NoError(t, Call(
+ context.Background(),
+ dial,
+ "/test.streamrpc.Test/Stream",
+ &testpb.StreamRequest{},
+ func(c net.Conn) error { return nil },
+ WithCredentials(&testCredentials{inputs}),
+ ))
+
+ <-interceptorDone
+ require.Equal(t, inputs, receivedValues)
+}
+
+func startServer(t *testing.T, s *Server, th testHandler) DialFunc {
+ t.Helper()
+ testpb.RegisterTestServer(s, &server{testHandler: th})
+ client, server := socketPair(t)
+ go func() { _ = s.Handle(server) }()
+ return func(time.Duration) (net.Conn, error) { return client, nil }
+}
+
+type testHandler func(ctx context.Context, in *testpb.StreamRequest) (*emptypb.Empty, error)
+
+type server struct {
+ testpb.UnimplementedTestServer
+ testHandler
+}
+
+func (s *server) Stream(ctx context.Context, in *testpb.StreamRequest) (*emptypb.Empty, error) {
+ return s.testHandler(ctx, in)
+}
diff --git a/internal/streamrpc/server.go b/internal/streamrpc/server.go
new file mode 100644
index 000000000..502764fd4
--- /dev/null
+++ b/internal/streamrpc/server.go
@@ -0,0 +1,166 @@
+package streamrpc
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net"
+ "time"
+
+ "github.com/golang/protobuf/proto"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/metadata"
+)
+
+var _ grpc.ServiceRegistrar = &Server{}
+
+// Server handles network connections and routes them to StreamRPC handlers.
+type Server struct {
+ methods map[string]*method
+ interceptor grpc.UnaryServerInterceptor
+}
+
+type method struct {
+ *grpc.MethodDesc
+ implementation interface{}
+}
+
+// ServerOption is an abstraction that lets you pass 0 or more server
+// options to NewServer.
+type ServerOption func(*Server)
+
+// WithServerInterceptor adds a unary gRPC server interceptor.
+func WithServerInterceptor(interceptor grpc.UnaryServerInterceptor) ServerOption {
+ return func(s *Server) { s.interceptor = interceptor }
+}
+
+// NewServer returns a new StreamRPC server. You can pass the result to
+// grpc-go RegisterFooServer functions.
+func NewServer(opts ...ServerOption) *Server {
+ s := &Server{
+ methods: make(map[string]*method),
+ }
+ for _, o := range opts {
+ o(s)
+ }
+ return s
+}
+
+// RegisterService implements grpc.ServiceRegistrar. It makes it possible
+// to pass a *Server to grpc-go foopb.RegisterFooServer functions as the
+// first argument.
+func (s *Server) RegisterService(sd *grpc.ServiceDesc, impl interface{}) {
+ for i := range sd.Methods {
+ m := &sd.Methods[i]
+ s.methods["/"+sd.ServiceName+"/"+m.MethodName] = &method{
+ MethodDesc: m,
+ implementation: impl,
+ }
+ }
+}
+
+// Handle handles an incoming network connection with the StreamRPC
+// protocol. It is intended to be called from a net.Listener.Accept loop
+// (or something equivalent).
+func (s *Server) Handle(c net.Conn) error {
+ defer c.Close()
+
+ deadline := time.Now().Add(defaultHandshakeTimeout)
+ req, err := recvFrame(c, deadline)
+ if err != nil {
+ return err
+ }
+
+ session := &serverSession{
+ c: c,
+ deadline: deadline,
+ }
+ if err := s.handleSession(session, req); err != nil {
+ return session.reject(err)
+ }
+
+ return nil
+}
+
+func (s *Server) handleSession(session *serverSession, reqBytes []byte) error {
+ req := &request{}
+ if err := json.Unmarshal(reqBytes, req); err != nil {
+ return err
+ }
+
+ method, ok := s.methods[req.Method]
+ if !ok {
+ return fmt.Errorf("method not found: %s", req.Method)
+ }
+
+ ctx, cancel := serverContext(session, req)
+ defer cancel()
+
+ if _, err := method.Handler(
+ method.implementation,
+ ctx,
+ func(msg interface{}) error { return proto.Unmarshal(req.Message, msg.(proto.Message)) },
+ s.interceptor,
+ ); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func serverContext(session *serverSession, req *request) (context.Context, func()) {
+ ctx := context.Background()
+ ctx = context.WithValue(ctx, sessionKey{}, session)
+ ctx = metadata.NewIncomingContext(ctx, req.Metadata)
+ return context.WithCancel(ctx)
+}
+
+type sessionKey struct{}
+
+// AcceptConnection completes the StreamRPC handshake on the server side.
+// It notifies the client that the server has accepted the stream, and
+// returns the connection.
+func AcceptConnection(ctx context.Context) (net.Conn, error) {
+ session, ok := ctx.Value(sessionKey{}).(*serverSession)
+ if !ok {
+ return nil, errors.New("context has no serverSession")
+ }
+ return session.Accept()
+}
+
+// serverSession wraps an incoming connection whose handshake has not
+// been completed yet.
+type serverSession struct {
+ c net.Conn
+ accepted bool
+ deadline time.Time
+}
+
+// Accept completes the handshake on the connection wrapped by ss and
+// unwraps the connection.
+func (ss *serverSession) Accept() (net.Conn, error) {
+ if ss.accepted {
+ return nil, errors.New("connection already accepted")
+ }
+
+ ss.accepted = true
+ if err := sendFrame(ss.c, nil, ss.deadline); err != nil {
+ return nil, fmt.Errorf("accept session: %w", err)
+ }
+
+ return ss.c, nil
+}
+
+func (ss *serverSession) reject(err error) error {
+ if ss.accepted {
+ return nil
+ }
+
+ buf, err := json.Marshal(&response{Error: err.Error()})
+ if err != nil {
+ return fmt.Errorf("mashal response: %w", err)
+ }
+
+ return sendFrame(ss.c, buf, ss.deadline)
+}
diff --git a/internal/streamrpc/testdata/test.pb.go b/internal/streamrpc/testdata/test.pb.go
new file mode 100644
index 000000000..cc267eefd
--- /dev/null
+++ b/internal/streamrpc/testdata/test.pb.go
@@ -0,0 +1,167 @@
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// versions:
+// protoc-gen-go v1.26.0
+// protoc v3.17.3
+// source: streamrpc/testdata/test.proto
+
+package testdata
+
+import (
+ protoreflect "google.golang.org/protobuf/reflect/protoreflect"
+ protoimpl "google.golang.org/protobuf/runtime/protoimpl"
+ emptypb "google.golang.org/protobuf/types/known/emptypb"
+ reflect "reflect"
+ sync "sync"
+)
+
+const (
+ // Verify that this generated code is sufficiently up-to-date.
+ _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
+ // Verify that runtime/protoimpl is sufficiently up-to-date.
+ _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
+)
+
+type StreamRequest struct {
+ state protoimpl.MessageState
+ sizeCache protoimpl.SizeCache
+ unknownFields protoimpl.UnknownFields
+
+ Fail bool `protobuf:"varint,1,opt,name=fail,proto3" json:"fail,omitempty"`
+ StringField string `protobuf:"bytes,2,opt,name=string_field,json=stringField,proto3" json:"string_field,omitempty"`
+}
+
+func (x *StreamRequest) Reset() {
+ *x = StreamRequest{}
+ if protoimpl.UnsafeEnabled {
+ mi := &file_streamrpc_testdata_test_proto_msgTypes[0]
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ ms.StoreMessageInfo(mi)
+ }
+}
+
+func (x *StreamRequest) String() string {
+ return protoimpl.X.MessageStringOf(x)
+}
+
+func (*StreamRequest) ProtoMessage() {}
+
+func (x *StreamRequest) ProtoReflect() protoreflect.Message {
+ mi := &file_streamrpc_testdata_test_proto_msgTypes[0]
+ if protoimpl.UnsafeEnabled && x != nil {
+ ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+ if ms.LoadMessageInfo() == nil {
+ ms.StoreMessageInfo(mi)
+ }
+ return ms
+ }
+ return mi.MessageOf(x)
+}
+
+// Deprecated: Use StreamRequest.ProtoReflect.Descriptor instead.
+func (*StreamRequest) Descriptor() ([]byte, []int) {
+ return file_streamrpc_testdata_test_proto_rawDescGZIP(), []int{0}
+}
+
+func (x *StreamRequest) GetFail() bool {
+ if x != nil {
+ return x.Fail
+ }
+ return false
+}
+
+func (x *StreamRequest) GetStringField() string {
+ if x != nil {
+ return x.StringField
+ }
+ return ""
+}
+
+var File_streamrpc_testdata_test_proto protoreflect.FileDescriptor
+
+var file_streamrpc_testdata_test_proto_rawDesc = []byte{
+ 0x0a, 0x1d, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x72, 0x70, 0x63, 0x2f, 0x74, 0x65, 0x73, 0x74,
+ 0x64, 0x61, 0x74, 0x61, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12,
+ 0x0e, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x72, 0x70, 0x63, 0x1a,
+ 0x1b, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66,
+ 0x2f, 0x65, 0x6d, 0x70, 0x74, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x46, 0x0a, 0x0d,
+ 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a,
+ 0x04, 0x66, 0x61, 0x69, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x66, 0x61, 0x69,
+ 0x6c, 0x12, 0x21, 0x0a, 0x0c, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x5f, 0x66, 0x69, 0x65, 0x6c,
+ 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x46,
+ 0x69, 0x65, 0x6c, 0x64, 0x32, 0x49, 0x0a, 0x04, 0x54, 0x65, 0x73, 0x74, 0x12, 0x41, 0x0a, 0x06,
+ 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x1d, 0x2e, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x73, 0x74,
+ 0x72, 0x65, 0x61, 0x6d, 0x72, 0x70, 0x63, 0x2e, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x52, 0x65,
+ 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70,
+ 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42,
+ 0x3e, 0x5a, 0x3c, 0x67, 0x69, 0x74, 0x6c, 0x61, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x69,
+ 0x74, 0x6c, 0x61, 0x62, 0x2d, 0x6f, 0x72, 0x67, 0x2f, 0x67, 0x69, 0x74, 0x61, 0x6c, 0x79, 0x2f,
+ 0x76, 0x31, 0x34, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x73, 0x74, 0x72,
+ 0x65, 0x61, 0x6d, 0x72, 0x70, 0x63, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x64, 0x61, 0x74, 0x61, 0x62,
+ 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
+}
+
+var (
+ file_streamrpc_testdata_test_proto_rawDescOnce sync.Once
+ file_streamrpc_testdata_test_proto_rawDescData = file_streamrpc_testdata_test_proto_rawDesc
+)
+
+func file_streamrpc_testdata_test_proto_rawDescGZIP() []byte {
+ file_streamrpc_testdata_test_proto_rawDescOnce.Do(func() {
+ file_streamrpc_testdata_test_proto_rawDescData = protoimpl.X.CompressGZIP(file_streamrpc_testdata_test_proto_rawDescData)
+ })
+ return file_streamrpc_testdata_test_proto_rawDescData
+}
+
+var file_streamrpc_testdata_test_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
+var file_streamrpc_testdata_test_proto_goTypes = []interface{}{
+ (*StreamRequest)(nil), // 0: test.streamrpc.StreamRequest
+ (*emptypb.Empty)(nil), // 1: google.protobuf.Empty
+}
+var file_streamrpc_testdata_test_proto_depIdxs = []int32{
+ 0, // 0: test.streamrpc.Test.Stream:input_type -> test.streamrpc.StreamRequest
+ 1, // 1: test.streamrpc.Test.Stream:output_type -> google.protobuf.Empty
+ 1, // [1:2] is the sub-list for method output_type
+ 0, // [0:1] is the sub-list for method input_type
+ 0, // [0:0] is the sub-list for extension type_name
+ 0, // [0:0] is the sub-list for extension extendee
+ 0, // [0:0] is the sub-list for field type_name
+}
+
+func init() { file_streamrpc_testdata_test_proto_init() }
+func file_streamrpc_testdata_test_proto_init() {
+ if File_streamrpc_testdata_test_proto != nil {
+ return
+ }
+ if !protoimpl.UnsafeEnabled {
+ file_streamrpc_testdata_test_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*StreamRequest); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ }
+ type x struct{}
+ out := protoimpl.TypeBuilder{
+ File: protoimpl.DescBuilder{
+ GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
+ RawDescriptor: file_streamrpc_testdata_test_proto_rawDesc,
+ NumEnums: 0,
+ NumMessages: 1,
+ NumExtensions: 0,
+ NumServices: 1,
+ },
+ GoTypes: file_streamrpc_testdata_test_proto_goTypes,
+ DependencyIndexes: file_streamrpc_testdata_test_proto_depIdxs,
+ MessageInfos: file_streamrpc_testdata_test_proto_msgTypes,
+ }.Build()
+ File_streamrpc_testdata_test_proto = out.File
+ file_streamrpc_testdata_test_proto_rawDesc = nil
+ file_streamrpc_testdata_test_proto_goTypes = nil
+ file_streamrpc_testdata_test_proto_depIdxs = nil
+}
diff --git a/internal/streamrpc/testdata/test.proto b/internal/streamrpc/testdata/test.proto
new file mode 100644
index 000000000..7614a07a9
--- /dev/null
+++ b/internal/streamrpc/testdata/test.proto
@@ -0,0 +1,16 @@
+syntax = "proto3";
+
+package test.streamrpc;
+
+option go_package = "gitlab.com/gitlab-org/gitaly/v14/internal/streamrpc/testdata";
+
+import "google/protobuf/empty.proto";
+
+service Test {
+ rpc Stream(StreamRequest) returns (google.protobuf.Empty) {}
+}
+
+message StreamRequest {
+ bool fail = 1;
+ string string_field = 2;
+}
diff --git a/internal/streamrpc/testdata/test_grpc.pb.go b/internal/streamrpc/testdata/test_grpc.pb.go
new file mode 100644
index 000000000..53fb22792
--- /dev/null
+++ b/internal/streamrpc/testdata/test_grpc.pb.go
@@ -0,0 +1,102 @@
+// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
+
+package testdata
+
+import (
+ context "context"
+ grpc "google.golang.org/grpc"
+ codes "google.golang.org/grpc/codes"
+ status "google.golang.org/grpc/status"
+ emptypb "google.golang.org/protobuf/types/known/emptypb"
+)
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the grpc package it is being compiled against.
+// Requires gRPC-Go v1.32.0 or later.
+const _ = grpc.SupportPackageIsVersion7
+
+// TestClient is the client API for Test service.
+//
+// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
+type TestClient interface {
+ Stream(ctx context.Context, in *StreamRequest, opts ...grpc.CallOption) (*emptypb.Empty, error)
+}
+
+type testClient struct {
+ cc grpc.ClientConnInterface
+}
+
+func NewTestClient(cc grpc.ClientConnInterface) TestClient {
+ return &testClient{cc}
+}
+
+func (c *testClient) Stream(ctx context.Context, in *StreamRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) {
+ out := new(emptypb.Empty)
+ err := c.cc.Invoke(ctx, "/test.streamrpc.Test/Stream", in, out, opts...)
+ if err != nil {
+ return nil, err
+ }
+ return out, nil
+}
+
+// TestServer is the server API for Test service.
+// All implementations must embed UnimplementedTestServer
+// for forward compatibility
+type TestServer interface {
+ Stream(context.Context, *StreamRequest) (*emptypb.Empty, error)
+ mustEmbedUnimplementedTestServer()
+}
+
+// UnimplementedTestServer must be embedded to have forward compatible implementations.
+type UnimplementedTestServer struct {
+}
+
+func (UnimplementedTestServer) Stream(context.Context, *StreamRequest) (*emptypb.Empty, error) {
+ return nil, status.Errorf(codes.Unimplemented, "method Stream not implemented")
+}
+func (UnimplementedTestServer) mustEmbedUnimplementedTestServer() {}
+
+// UnsafeTestServer may be embedded to opt out of forward compatibility for this service.
+// Use of this interface is not recommended, as added methods to TestServer will
+// result in compilation errors.
+type UnsafeTestServer interface {
+ mustEmbedUnimplementedTestServer()
+}
+
+func RegisterTestServer(s grpc.ServiceRegistrar, srv TestServer) {
+ s.RegisterService(&Test_ServiceDesc, srv)
+}
+
+func _Test_Stream_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
+ in := new(StreamRequest)
+ if err := dec(in); err != nil {
+ return nil, err
+ }
+ if interceptor == nil {
+ return srv.(TestServer).Stream(ctx, in)
+ }
+ info := &grpc.UnaryServerInfo{
+ Server: srv,
+ FullMethod: "/test.streamrpc.Test/Stream",
+ }
+ handler := func(ctx context.Context, req interface{}) (interface{}, error) {
+ return srv.(TestServer).Stream(ctx, req.(*StreamRequest))
+ }
+ return interceptor(ctx, in, info, handler)
+}
+
+// Test_ServiceDesc is the grpc.ServiceDesc for Test service.
+// It's only intended for direct use with grpc.RegisterService,
+// and not to be introspected or modified (even as a copy)
+var Test_ServiceDesc = grpc.ServiceDesc{
+ ServiceName: "test.streamrpc.Test",
+ HandlerType: (*TestServer)(nil),
+ Methods: []grpc.MethodDesc{
+ {
+ MethodName: "Stream",
+ Handler: _Test_Stream_Handler,
+ },
+ },
+ Streams: []grpc.StreamDesc{},
+ Metadata: "streamrpc/testdata/test.proto",
+}