diff options
author | Jacob Vosmaer <jacob@gitlab.com> | 2021-06-17 14:35:51 +0300 |
---|---|---|
committer | Jacob Vosmaer <jacob@gitlab.com> | 2021-07-12 15:34:41 +0300 |
commit | 8a925b40a5e35848600600ee72441224f99af0fa (patch) | |
tree | a6c0ea5eee006c42f095dbaa31e0cda20d4ca773 | |
parent | 357a4b1dab81459b3d9323f10737b1698586cd5c (diff) |
Add StreamRPC library code
Changelog: other
-rw-r--r-- | Makefile | 3 | ||||
-rw-r--r-- | doc/README.md | 1 | ||||
-rw-r--r-- | doc/stream_rpc.md | 101 | ||||
-rw-r--r-- | internal/streamrpc/client.go | 145 | ||||
-rw-r--r-- | internal/streamrpc/common.go | 82 | ||||
-rw-r--r-- | internal/streamrpc/frame_test.go | 139 | ||||
-rw-r--r-- | internal/streamrpc/protocol_test.go | 117 | ||||
-rw-r--r-- | internal/streamrpc/rpc_test.go | 315 | ||||
-rw-r--r-- | internal/streamrpc/server.go | 166 | ||||
-rw-r--r-- | internal/streamrpc/testdata/test.pb.go | 167 | ||||
-rw-r--r-- | internal/streamrpc/testdata/test.proto | 16 | ||||
-rw-r--r-- | internal/streamrpc/testdata/test_grpc.pb.go | 102 |
12 files changed, 1353 insertions, 1 deletions
@@ -336,7 +336,8 @@ proto: ${PROTOC} ${PROTOC_GEN_GO} ${PROTOC_GEN_GO_GRPC} ${SOURCE_DIR}/.ruby-bund ${SOURCE_DIR}/internal/praefect/mock/mock.proto \ ${SOURCE_DIR}/internal/middleware/cache/testdata/stream.proto \ ${SOURCE_DIR}/internal/helper/chunk/testdata/test.proto \ - ${SOURCE_DIR}/internal/middleware/limithandler/testdata/test.proto + ${SOURCE_DIR}/internal/middleware/limithandler/testdata/test.proto \ + ${SOURCE_DIR}/internal/streamrpc/testdata/test.proto ${PROTOC} ${SHARED_PROTOC_OPTS} -I ${SOURCE_DIR}/proto --go_out=${SOURCE_DIR}/proto --go-grpc_out=${SOURCE_DIR}/proto ${SOURCE_DIR}/proto/go/internal/linter/testdata/*.proto .PHONY: lint-proto diff --git a/doc/README.md b/doc/README.md index 96fffbad1..fd96df56a 100644 --- a/doc/README.md +++ b/doc/README.md @@ -38,6 +38,7 @@ For configuration please read [praefects configuration documentation](doc/config - [Tips for reading Git source code](reading_git_source.md) - [Serverside Git Usage](serverside_git_usage.md) - [Object Pools](object_pools.md) +- [StreamRPC](stream_rpc.md) #### RFCs diff --git a/doc/stream_rpc.md b/doc/stream_rpc.md new file mode 100644 index 000000000..f78c6723f --- /dev/null +++ b/doc/stream_rpc.md @@ -0,0 +1,101 @@ +# StreamRPC + +StreamRPC is a remote procedure call (RPC) protocol implemented by +Gitaly. It is used for RPC's that transfer a high volume of byte stream +data, such as the server side of `git fetch`. + +For background on why we created StreamRPC, see +https://gitlab.com/groups/gitlab-com/gl-infra/-/epics/463. + +## Design goals + +1. Give RPC handlers direct access to the underlying network socket or +TLS stream +1. Interoperate with existing Gitaly gRPC middlewares (logging, +authentication, metrics etc.) +1. Allow for efficient proxying in Praefect + +## Semantics + +A StreamRPC call has two phases: the handshake phase and the stream +phase. The structure of the handshake phase is described by the +StreamRPC protocol. The stream phase has no inherent structure as far +as StreamRPC is concerned; it is up to the RPC client and RPC handler +what they do with the stream. + +The handshake phase consists of two steps: + +1. Client sends request +1. Server sends response + +To allow for a clean transition from the handshake phase to the stream +phase, the handshake phase uses frames with length prefixes. Length +prefixes make it possible to implement the handshake without buffered +IO. When the transition to the stream phase happens, we do not have to +hand over buffered data. + +All length prefixes in the handshake phase are big-endian uint32 +values (4 bytes). Length prefixes do not include the length of the +prefix itself, so a blob `foobar` would be encoded as +`\x00\x00\x00\x06foobar`. + +### Request + +The request consists of a length prefix and a JSON object. + +The JSON object has three fields: + +1. `Method` (`string`): this is a gRPC style method name, including the +gRPC service name. +1. `Metadata` (`map[string][]string`): this contains gRPC metadata, which +can be compared to HTTP headers. It is used for authentication, +correlation ID's, etc. +1. `Message` (`string`): this field contains a base64-encoded (RFC 4648) +Protobuf message. This is here because Praefect, and some of our +middlewares, try to inspect the first request of each RPC to see what +repository etc. it targets. By having this request as part of the +protocol, we can support Praefect and gRPC middlewares in a natural +way. + +### Response + +The response is a length-prefixed empty string OR a JSON object. + +The server accepts the request and transitions to the stream phase if +and only if the frame is empty. That is, the accepting response is +`\x00\x00\x00\x00`. + +In case of a rejection, the server returns a JSON object with an error +message in the `Error` field. + +If the server rejects the request it will close the connection. + +## Relation to gRPC + +StreamRPC is designed to be embedded into a gRPC service (Gitaly). +StreamRPC RPC's are defined using Protobuf just like regular Gitaly +gRPC RPC's. From the point of view of gRPC middleware, StreamRPC RPC's +are unary methods which return an empty response message. + +```protobuf +import "google/protobuf/empty.proto"; + +service ExampleService { + rpc ExampleStream(ExampleStreamRequest) returns (google.protobuf.Empty) { + option (op_type) = { + op: ACCESSOR + }; + } +} + +message ExampleStreamRequest { + Repository repository = 1 [(target_repository)=true]; +} +``` + +The server handler may return an error after the stream phase, which +will be logged on the server, but this error cannot be transmitted to +the client. This is because the stream phase lasts until the +connection is closed. There is no way for the server to transmit the +error "after the stream phase", because the connection is then already +closed. diff --git a/internal/streamrpc/client.go b/internal/streamrpc/client.go new file mode 100644 index 000000000..b26d43f33 --- /dev/null +++ b/internal/streamrpc/client.go @@ -0,0 +1,145 @@ +package streamrpc + +import ( + "context" + "encoding/json" + "fmt" + "net" + "time" + + "github.com/golang/protobuf/proto" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/metadata" +) + +// DialFunc is an abstraction that allows Call to transparently handle +// unencrypted connections and TLS connections. +type DialFunc func(time.Duration) (net.Conn, error) + +func doCall(dial DialFunc, request []byte, callback func(net.Conn) error) error { + deadline := time.Now().Add(defaultHandshakeTimeout) + + c, err := dial(time.Until(deadline)) + if err != nil { + return fmt.Errorf("dial: %w", err) + } + defer c.Close() + + if err := sendFrame(c, request, deadline); err != nil { + return fmt.Errorf("send request: %w", err) + } + + responseBytes, err := recvFrame(c, deadline) + if err != nil { + return fmt.Errorf("receive response: %w", err) + } + + if len(responseBytes) > 0 { + var resp response + if err := json.Unmarshal(responseBytes, &resp); err != nil { + return fmt.Errorf("unmarshal response: %w", err) + } + + return &RequestRejectedError{resp.Error} + } + + return callback(c) +} + +// RequestRejectedError is returned by Call if the server explicitly +// rejected the request (as opposed to e.g. an IO timeout). +type RequestRejectedError struct{ message string } + +func (r *RequestRejectedError) Error() string { return r.message } + +type callOptions struct { + creds credentials.PerRPCCredentials + interceptor grpc.UnaryClientInterceptor +} + +func (opts *callOptions) addCredentials(ctx context.Context) (context.Context, error) { + headers, err := opts.creds.GetRequestMetadata(ctx) + if err != nil { + return nil, err + } + for k, v := range headers { + ctx = metadata.AppendToOutgoingContext(ctx, k, v) + } + return ctx, nil +} + +// CallOption is an abstraction that lets us pass 0 or more options to a call. +type CallOption func(*callOptions) + +type nullCredentials struct{} + +func (nullCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + return nil, nil +} + +func (nullCredentials) RequireTransportSecurity() bool { return false } + +var _ credentials.PerRPCCredentials = nullCredentials{} + +func nullClientInterceptor(ctx context.Context, method string, req, resp interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + return invoker(ctx, method, req, resp, cc, opts...) +} + +func defaultCallOptions() *callOptions { + return &callOptions{ + creds: nullCredentials{}, + interceptor: nullClientInterceptor, + } +} + +// WithCredentials adds gRPC per-request credentials to an outgoing call. +func WithCredentials(creds credentials.PerRPCCredentials) CallOption { + return func(o *callOptions) { o.creds = creds } +} + +// WithClientInterceptor adds a gRPC unary client interceptor to an outgoing call. +func WithClientInterceptor(interceptor grpc.UnaryClientInterceptor) CallOption { + return func(o *callOptions) { o.interceptor = interceptor } +} + +// Call makes a StreamRPC call. The dial function determines the remote +// address. The method argument is the full name of the StreamRPC method +// we are calling (e.g. "/foo.BarService/BazStream"). Msg is the request +// message. If the server accepts the call, callback is called with a +// connection. +func Call(ctx context.Context, dial DialFunc, method string, msg proto.Message, callback func(net.Conn) error, opts ...CallOption) (err error) { + callOpts := defaultCallOptions() + for _, o := range opts { + o(callOpts) + } + + invoke := func(ctx context.Context, method string, msg, _ interface{}, _ *grpc.ClientConn, _ ...grpc.CallOption) error { + ctx, err := callOpts.addCredentials(ctx) + if err != nil { + return err + } + + msgBytes, err := proto.Marshal(msg.(proto.Message)) + if err != nil { + return err + } + + req := &request{ + Method: method, + Message: msgBytes, + } + if md, ok := metadata.FromOutgoingContext(ctx); ok { + req.Metadata = md + } + + reqBytes, err := json.Marshal(req) + if err != nil { + return err + } + + return doCall(dial, reqBytes, callback) + } + + return callOpts.interceptor(ctx, method, msg, nil, nil, invoke) +} diff --git a/internal/streamrpc/common.go b/internal/streamrpc/common.go new file mode 100644 index 000000000..88c3b6148 --- /dev/null +++ b/internal/streamrpc/common.go @@ -0,0 +1,82 @@ +package streamrpc + +import ( + "encoding/binary" + "errors" + "io" + "net" + "os" + "time" +) + +type request struct { + Method string + Message []byte + Metadata map[string][]string +} + +type response struct{ Error string } + +const ( + defaultHandshakeTimeout = 10 * time.Second + + // The frames exchanged during the handshake have a uint32 length prefix + // so their theoretical maximum size is 4GB. We don't want to allow that + // so we enforce a lower limit. This number was chosen because it is + // close to the default grpc-go maximum message size. + maxFrameSize = (1 << 20) - 1 +) + +var ( + errFrameTooLarge = errors.New("frame too large") +) + +func sendFrame(c net.Conn, frame []byte, deadline time.Time) error { + if len(frame) > maxFrameSize { + return errFrameTooLarge + } + + header := make([]byte, 4) + binary.BigEndian.PutUint32(header, uint32(len(frame))) + buffers := net.Buffers([][]byte{header, frame}) + + return errAsync(deadline, func() error { _, err := buffers.WriteTo(c); return err }) +} + +func recvFrame(c net.Conn, deadline time.Time) ([]byte, error) { + header := make([]byte, 4) + if err := errAsync(deadline, func() error { _, err := io.ReadFull(c, header); return err }); err != nil { + return nil, err + } + + size := binary.BigEndian.Uint32(header) + if size > maxFrameSize { + return nil, errFrameTooLarge + } + frame := make([]byte, size) + if err := errAsync(deadline, func() error { _, err := io.ReadFull(c, frame); return err }); err != nil { + return nil, err + } + + return frame, nil +} + +// errAsync is a hack to work around the fact that grpc-go calls +// SetDeadline on connections _after_ handing them over to us. Because of +// this race, we cannot use SetDeadline which would have been nicer. +// +// https://github.com/grpc/grpc-go/blob/v1.38.0/server.go#L853 +func errAsync(deadline time.Time, f func() error) error { + tm := time.NewTimer(time.Until(deadline)) + defer tm.Stop() + + errC := make(chan error, 1) + go func() { errC <- f() }() + + select { + case <-tm.C: + return os.ErrDeadlineExceeded + case err := <-errC: + return err + } +} diff --git a/internal/streamrpc/frame_test.go b/internal/streamrpc/frame_test.go new file mode 100644 index 000000000..cb665a47f --- /dev/null +++ b/internal/streamrpc/frame_test.go @@ -0,0 +1,139 @@ +package streamrpc + +import ( + "io" + "io/ioutil" + "net" + "os" + "strings" + "syscall" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestSendFrame(t *testing.T) { + largeString := strings.Repeat("x", 0xfffff) + + testCases := []struct { + desc string + in string + out string + err error + }{ + {desc: "empty", out: "\x00\x00\x00\x00"}, + {desc: "not empty", in: "hello", out: "\x00\x00\x00\x05hello"}, + {desc: "very large", in: largeString, out: "\x00\x0f\xff\xff" + largeString}, + {desc: "too large", in: "z" + largeString, err: errFrameTooLarge}, + } + + type result struct { + data string + err error + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + client, server := socketPair(t) + ch := make(chan result, 1) + go func() { + out, err := ioutil.ReadAll(server) + ch <- result{data: string(out), err: err} + }() + + err := sendFrame(client, []byte(tc.in), time.Now().Add(10*time.Second)) + if tc.err != nil { + require.Equal(t, tc.err, err) + return + } + + require.NoError(t, err) + require.NoError(t, client.Close()) + + res := <-ch + require.NoError(t, res.err) + require.Equal(t, tc.out, res.data) + }) + } +} + +func TestSendFrame_timeout(t *testing.T) { + client, _ := socketPair(t) + + // Ensure frame is bigger than write buffer, so that sendFrame will + // block. Otherwise we cannot observe the timeout behavior. + frame := make([]byte, 10*1024) + require.NoError(t, client.(*net.UnixConn).SetWriteBuffer(1024)) + + err := sendFrame(client, frame, time.Now()) + require.ErrorIs(t, err, os.ErrDeadlineExceeded) +} + +func TestRecvFrame(t *testing.T) { + largeString := strings.Repeat("x", 0xfffff) + + testCases := []struct { + desc string + out string + in string + err error + }{ + {desc: "empty", in: "\x00\x00\x00\x00", out: ""}, + {desc: "not empty", in: "\x00\x00\x00\x05hello", out: "hello"}, + {desc: "very large", in: "\x00\x0f\xff\xff" + largeString, out: largeString}, + {desc: "too large", in: "\x00\x10\x00\x00" + "z" + largeString, err: errFrameTooLarge}, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + client, server := socketPair(t) + ch := make(chan error, 1) + go func() { + ch <- func() error { + n, err := server.Write([]byte(tc.in)) + if n != len(tc.in) { + return io.ErrShortWrite + } + return err + }() + }() + + out, err := recvFrame(client, time.Now().Add(10*time.Second)) + if tc.err != nil { + require.Equal(t, tc.err, err) + return + } + + require.NoError(t, err) + require.Equal(t, tc.out, string(out)) + + require.NoError(t, <-ch) + }) + } +} + +func TestRecvFrame_timeout(t *testing.T) { + client, _ := socketPair(t) + _, err := recvFrame(client, time.Now()) + require.ErrorIs(t, err, os.ErrDeadlineExceeded) +} + +func socketPair(t *testing.T) (net.Conn, net.Conn) { + t.Helper() + + fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) + require.NoError(t, err) + + conns := make([]net.Conn, 2) + for i, fd := range fds[:] { + f := os.NewFile(uintptr(fd), "socket pair") + c, err := net.FileConn(f) + require.NoError(t, err) + require.NoError(t, f.Close()) + t.Cleanup(func() { c.Close() }) + conns[i] = c + } + + return conns[0], conns[1] +} diff --git a/internal/streamrpc/protocol_test.go b/internal/streamrpc/protocol_test.go new file mode 100644 index 000000000..5d169c179 --- /dev/null +++ b/internal/streamrpc/protocol_test.go @@ -0,0 +1,117 @@ +package streamrpc + +import ( + "context" + "fmt" + "io/ioutil" + "sort" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + testpb "gitlab.com/gitlab-org/gitaly/v14/internal/streamrpc/testdata" + "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/types/known/emptypb" +) + +// TestProtocol exclusively uses hard-coded strings to prevent breaking +// changes in the wire protocol. +func TestProtocol(t *testing.T) { + testCases := []struct { + desc string + in string + out string + }{ + { + desc: "successful request", + in: "\x00\x00\x00\x28" + `{"Method":"/test.streamrpc.Test/Stream"}`, + out: strings.Join([]string{ + "\x00\x00\x00\x00", // Server accepts + "\n", // Handler prints request field (empty) followed by newline + }, ""), + }, + { + desc: "unknown method", + in: "\x00\x00\x00\x1b" + `{"Method":"does not exist"}`, + out: strings.Join([]string{ + "\x00\x00\x00\x2c", // Server rejects by sending non-empty error message + `{"Error":"method not found: does not exist"}`, + }, ""), + }, + { + desc: "request with message and metadata", + in: strings.Join([]string{ + "\x00\x00\x00\x73", + `{`, + `"Method":"/test.streamrpc.Test/Stream",`, + `"Message":"EgtoZWxsbyB3b3JsZA==",`, // &testpb.StreamRequest{StringField: "hello world"} + `"Metadata":{"k1":["v1","v2"],"k2":["v3"]}`, + `}`, + }, ""), + out: strings.Join([]string{ + "\x00\x00\x00\x00", // Server accepts + "k1: v1\nk1: v2\nk2: v3\n", // Server echoes metadata key-value pairs + "hello world\n", // Server echoes field from request message + }, ""), + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + dial := startServer( + t, + NewServer(), + func(ctx context.Context, in *testpb.StreamRequest) (*emptypb.Empty, error) { + c, err := AcceptConnection(ctx) + if err != nil { + return nil, err + } + + var mdKeys []string + var md metadata.MD + if ctxMD, ok := metadata.FromIncomingContext(ctx); ok { + md = ctxMD + } + + for k := range md { + mdKeys = append(mdKeys, k) + } + + // Direct go map iteration is non-deterministic. Sort the keys to make it + // deterministic. + sort.Strings(mdKeys) + + // Echo back metadata so tests can see it was received correctly + for _, k := range mdKeys { + for _, v := range md[k] { + if _, err := fmt.Fprintf(c, "%s: %s\n", k, v); err != nil { + return nil, err + } + } + } + + // Echo back string field so tests can see request was received correctly + if _, err := fmt.Fprintln(c, in.StringField); err != nil { + return nil, err + } + + return nil, nil + }, + ) + + c, err := dial(10 * time.Second) + require.NoError(t, err) + defer c.Close() + require.NoError(t, c.SetDeadline(time.Now().Add(10*time.Second))) + + n, err := c.Write([]byte(tc.in)) + require.NoError(t, err) + require.Equal(t, len(tc.in), n) + + out, err := ioutil.ReadAll(c) + require.NoError(t, err) + require.Equal(t, tc.out, string(out)) + }) + } +} diff --git a/internal/streamrpc/rpc_test.go b/internal/streamrpc/rpc_test.go new file mode 100644 index 000000000..669222780 --- /dev/null +++ b/internal/streamrpc/rpc_test.go @@ -0,0 +1,315 @@ +package streamrpc + +import ( + "bytes" + "context" + "errors" + "io" + "io/ioutil" + "math/rand" + "net" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + testpb "gitlab.com/gitlab-org/gitaly/v14/internal/streamrpc/testdata" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/types/known/emptypb" +) + +func TestCall(t *testing.T) { + const ( + testKey = "test key" + testValue = "test value" + blobSize = 1024 * 1024 + ) + + var receivedValues []string + var receivedField string + + dial := startServer( + t, + NewServer(), + func(ctx context.Context, in *testpb.StreamRequest) (*emptypb.Empty, error) { + receivedField = in.StringField + + if md, ok := metadata.FromIncomingContext(ctx); ok { + receivedValues = md[testKey] + } + + c, err := AcceptConnection(ctx) + if err != nil { + return nil, err + } + + _, err = io.CopyN(c, c, blobSize) + return nil, err + }, + ) + + in := make([]byte, blobSize) + _, err := rand.Read(in) + require.NoError(t, err) + + var out []byte + require.NotEqual(t, in, out) + + ctx := metadata.AppendToOutgoingContext(context.Background(), testKey, testValue) + require.NoError(t, Call( + ctx, + dial, + "/test.streamrpc.Test/Stream", + &testpb.StreamRequest{StringField: "hello world"}, + func(c net.Conn) error { + errC := make(chan error, 1) + go func() { + var err error + out, err = ioutil.ReadAll(c) + errC <- err + }() + + if _, err := io.Copy(c, bytes.NewReader(in)); err != nil { + return err + } + if err := <-errC; err != nil { + return err + } + + return c.Close() + }, + )) + + require.Equal(t, "hello world", receivedField, "request propagates") + require.Equal(t, []string{testValue}, receivedValues, "grpc metadata stored in client ctx propagates") + require.Equal(t, in, out, "byte stream works") +} + +func TestCall_serverError(t *testing.T) { + dial := startServer( + t, + NewServer(), + func(ctx context.Context, in *testpb.StreamRequest) (*emptypb.Empty, error) { + return nil, errors.New("this is the server error") + }, + ) + + callError := Call( + context.Background(), + dial, + "/test.streamrpc.Test/Stream", + &testpb.StreamRequest{}, + func(c net.Conn) error { panic("never reached") }, + ) + + require.Equal(t, &RequestRejectedError{"this is the server error"}, callError) +} + +func TestCall_clientMiddleware(t *testing.T) { + const ( + testKey = "test key" + testValue = "test value" + ) + + var receivedValues []string + var receivedField string + + dial := startServer( + t, + NewServer(), + func(ctx context.Context, in *testpb.StreamRequest) (*emptypb.Empty, error) { + _, err := AcceptConnection(ctx) + return nil, err + }, + ) + + var middlewareMethod string + ctx := metadata.AppendToOutgoingContext(context.Background(), testKey, testValue) + + const testMethod = "/test.streamrpc.Test/Stream" + require.NoError(t, Call( + ctx, + dial, + testMethod, + &testpb.StreamRequest{StringField: "hello world"}, + func(c net.Conn) error { return nil }, + WithClientInterceptor(func(ctx context.Context, method string, req, _ interface{}, _ *grpc.ClientConn, invoker grpc.UnaryInvoker, _ ...grpc.CallOption) error { + middlewareMethod = method + receivedField = req.(*testpb.StreamRequest).StringField + if md, ok := metadata.FromOutgoingContext(ctx); ok { + receivedValues = md[testKey] + } + return invoker(ctx, method, req, nil, nil) + }), + )) + + require.Equal(t, testMethod, middlewareMethod, "client middleware sees correct method") + require.Equal(t, "hello world", receivedField, "client middleware sees request") + require.Equal(t, []string{testValue}, receivedValues, "client middleware sees context metadata") +} + +func TestCall_clientMiddlewareReject(t *testing.T) { + dial := startServer( + t, + NewServer(), + func(ctx context.Context, in *testpb.StreamRequest) (*emptypb.Empty, error) { + panic("never reached") + }, + ) + + middlewareError := errors.New("middleware says no") + + err := Call( + context.Background(), + dial, + "/test.streamrpc.Test/Stream", + &testpb.StreamRequest{StringField: "hello world"}, + func(c net.Conn) error { return nil }, + WithClientInterceptor(func(ctx context.Context, method string, req, _ interface{}, _ *grpc.ClientConn, invoker grpc.UnaryInvoker, _ ...grpc.CallOption) error { + return middlewareError + }), + ) + + require.Equal(t, middlewareError, err) +} + +func TestCall_serverMiddleware(t *testing.T) { + const ( + testKey = "test key" + testValue = "test value" + testMethod = "/test.streamrpc.Test/Stream" + ) + + var ( + receivedField string + middlewareMethod string + receivedValues []string + ) + + interceptorDone := make(chan struct{}) + + dial := startServer( + t, + NewServer(WithServerInterceptor(func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + defer close(interceptorDone) + middlewareMethod = info.FullMethod + receivedField = req.(*testpb.StreamRequest).StringField + if md, ok := metadata.FromIncomingContext(ctx); ok { + receivedValues = md[testKey] + } + return handler(ctx, req) + })), + func(ctx context.Context, in *testpb.StreamRequest) (*emptypb.Empty, error) { + _, err := AcceptConnection(ctx) + return nil, err + }, + ) + + ctx := metadata.AppendToOutgoingContext(context.Background(), testKey, testValue) + require.NoError(t, Call( + ctx, + dial, + testMethod, + &testpb.StreamRequest{StringField: "hello world"}, + func(c net.Conn) error { return nil }, + )) + + <-interceptorDone + require.Equal(t, testMethod, middlewareMethod, "server middleware sees correct method") + require.Equal(t, "hello world", receivedField, "server middleware sees request") + require.Equal(t, []string{testValue}, receivedValues, "server middleware sees context metadata") +} + +func TestCall_serverMiddlewareReject(t *testing.T) { + dial := startServer( + t, + NewServer(WithServerInterceptor(func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + return nil, errors.New("middleware says no") + })), + func(ctx context.Context, in *testpb.StreamRequest) (*emptypb.Empty, error) { + panic("never reached") + }, + ) + + err := Call( + context.Background(), + dial, + "/test.streamrpc.Test/Stream", + &testpb.StreamRequest{}, + func(c net.Conn) error { return nil }, + ) + + require.Equal(t, &RequestRejectedError{message: "middleware says no"}, err) +} + +type testCredentials struct { + values map[string]string +} + +func (tc *testCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + out := make(map[string]string) + for k, v := range tc.values { + out[k] = v + } + return out, nil +} + +func (*testCredentials) RequireTransportSecurity() bool { return false } + +func TestCall_credentials(t *testing.T) { + receivedValues := make(map[string]string) + interceptorDone := make(chan struct{}) + + dial := startServer( + t, + NewServer(), + func(ctx context.Context, in *testpb.StreamRequest) (*emptypb.Empty, error) { + defer close(interceptorDone) + + if md, ok := metadata.FromIncomingContext(ctx); ok { + receivedValues["key 1"] = strings.Join(md["key 1"], ",") + receivedValues["key 2"] = strings.Join(md["key 2"], ",") + } + + _, err := AcceptConnection(ctx) + return nil, err + }, + ) + + inputs := map[string]string{ + "key 1": "value a", + "key 2": "value b", + } + + require.NoError(t, Call( + context.Background(), + dial, + "/test.streamrpc.Test/Stream", + &testpb.StreamRequest{}, + func(c net.Conn) error { return nil }, + WithCredentials(&testCredentials{inputs}), + )) + + <-interceptorDone + require.Equal(t, inputs, receivedValues) +} + +func startServer(t *testing.T, s *Server, th testHandler) DialFunc { + t.Helper() + testpb.RegisterTestServer(s, &server{testHandler: th}) + client, server := socketPair(t) + go func() { _ = s.Handle(server) }() + return func(time.Duration) (net.Conn, error) { return client, nil } +} + +type testHandler func(ctx context.Context, in *testpb.StreamRequest) (*emptypb.Empty, error) + +type server struct { + testpb.UnimplementedTestServer + testHandler +} + +func (s *server) Stream(ctx context.Context, in *testpb.StreamRequest) (*emptypb.Empty, error) { + return s.testHandler(ctx, in) +} diff --git a/internal/streamrpc/server.go b/internal/streamrpc/server.go new file mode 100644 index 000000000..502764fd4 --- /dev/null +++ b/internal/streamrpc/server.go @@ -0,0 +1,166 @@ +package streamrpc + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net" + "time" + + "github.com/golang/protobuf/proto" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +var _ grpc.ServiceRegistrar = &Server{} + +// Server handles network connections and routes them to StreamRPC handlers. +type Server struct { + methods map[string]*method + interceptor grpc.UnaryServerInterceptor +} + +type method struct { + *grpc.MethodDesc + implementation interface{} +} + +// ServerOption is an abstraction that lets you pass 0 or more server +// options to NewServer. +type ServerOption func(*Server) + +// WithServerInterceptor adds a unary gRPC server interceptor. +func WithServerInterceptor(interceptor grpc.UnaryServerInterceptor) ServerOption { + return func(s *Server) { s.interceptor = interceptor } +} + +// NewServer returns a new StreamRPC server. You can pass the result to +// grpc-go RegisterFooServer functions. +func NewServer(opts ...ServerOption) *Server { + s := &Server{ + methods: make(map[string]*method), + } + for _, o := range opts { + o(s) + } + return s +} + +// RegisterService implements grpc.ServiceRegistrar. It makes it possible +// to pass a *Server to grpc-go foopb.RegisterFooServer functions as the +// first argument. +func (s *Server) RegisterService(sd *grpc.ServiceDesc, impl interface{}) { + for i := range sd.Methods { + m := &sd.Methods[i] + s.methods["/"+sd.ServiceName+"/"+m.MethodName] = &method{ + MethodDesc: m, + implementation: impl, + } + } +} + +// Handle handles an incoming network connection with the StreamRPC +// protocol. It is intended to be called from a net.Listener.Accept loop +// (or something equivalent). +func (s *Server) Handle(c net.Conn) error { + defer c.Close() + + deadline := time.Now().Add(defaultHandshakeTimeout) + req, err := recvFrame(c, deadline) + if err != nil { + return err + } + + session := &serverSession{ + c: c, + deadline: deadline, + } + if err := s.handleSession(session, req); err != nil { + return session.reject(err) + } + + return nil +} + +func (s *Server) handleSession(session *serverSession, reqBytes []byte) error { + req := &request{} + if err := json.Unmarshal(reqBytes, req); err != nil { + return err + } + + method, ok := s.methods[req.Method] + if !ok { + return fmt.Errorf("method not found: %s", req.Method) + } + + ctx, cancel := serverContext(session, req) + defer cancel() + + if _, err := method.Handler( + method.implementation, + ctx, + func(msg interface{}) error { return proto.Unmarshal(req.Message, msg.(proto.Message)) }, + s.interceptor, + ); err != nil { + return err + } + + return nil +} + +func serverContext(session *serverSession, req *request) (context.Context, func()) { + ctx := context.Background() + ctx = context.WithValue(ctx, sessionKey{}, session) + ctx = metadata.NewIncomingContext(ctx, req.Metadata) + return context.WithCancel(ctx) +} + +type sessionKey struct{} + +// AcceptConnection completes the StreamRPC handshake on the server side. +// It notifies the client that the server has accepted the stream, and +// returns the connection. +func AcceptConnection(ctx context.Context) (net.Conn, error) { + session, ok := ctx.Value(sessionKey{}).(*serverSession) + if !ok { + return nil, errors.New("context has no serverSession") + } + return session.Accept() +} + +// serverSession wraps an incoming connection whose handshake has not +// been completed yet. +type serverSession struct { + c net.Conn + accepted bool + deadline time.Time +} + +// Accept completes the handshake on the connection wrapped by ss and +// unwraps the connection. +func (ss *serverSession) Accept() (net.Conn, error) { + if ss.accepted { + return nil, errors.New("connection already accepted") + } + + ss.accepted = true + if err := sendFrame(ss.c, nil, ss.deadline); err != nil { + return nil, fmt.Errorf("accept session: %w", err) + } + + return ss.c, nil +} + +func (ss *serverSession) reject(err error) error { + if ss.accepted { + return nil + } + + buf, err := json.Marshal(&response{Error: err.Error()}) + if err != nil { + return fmt.Errorf("mashal response: %w", err) + } + + return sendFrame(ss.c, buf, ss.deadline) +} diff --git a/internal/streamrpc/testdata/test.pb.go b/internal/streamrpc/testdata/test.pb.go new file mode 100644 index 000000000..cc267eefd --- /dev/null +++ b/internal/streamrpc/testdata/test.pb.go @@ -0,0 +1,167 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.26.0 +// protoc v3.17.3 +// source: streamrpc/testdata/test.proto + +package testdata + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + emptypb "google.golang.org/protobuf/types/known/emptypb" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type StreamRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Fail bool `protobuf:"varint,1,opt,name=fail,proto3" json:"fail,omitempty"` + StringField string `protobuf:"bytes,2,opt,name=string_field,json=stringField,proto3" json:"string_field,omitempty"` +} + +func (x *StreamRequest) Reset() { + *x = StreamRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_streamrpc_testdata_test_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *StreamRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StreamRequest) ProtoMessage() {} + +func (x *StreamRequest) ProtoReflect() protoreflect.Message { + mi := &file_streamrpc_testdata_test_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StreamRequest.ProtoReflect.Descriptor instead. +func (*StreamRequest) Descriptor() ([]byte, []int) { + return file_streamrpc_testdata_test_proto_rawDescGZIP(), []int{0} +} + +func (x *StreamRequest) GetFail() bool { + if x != nil { + return x.Fail + } + return false +} + +func (x *StreamRequest) GetStringField() string { + if x != nil { + return x.StringField + } + return "" +} + +var File_streamrpc_testdata_test_proto protoreflect.FileDescriptor + +var file_streamrpc_testdata_test_proto_rawDesc = []byte{ + 0x0a, 0x1d, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x72, 0x70, 0x63, 0x2f, 0x74, 0x65, 0x73, 0x74, + 0x64, 0x61, 0x74, 0x61, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, + 0x0e, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x72, 0x70, 0x63, 0x1a, + 0x1b, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, + 0x2f, 0x65, 0x6d, 0x70, 0x74, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x46, 0x0a, 0x0d, + 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, + 0x04, 0x66, 0x61, 0x69, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x66, 0x61, 0x69, + 0x6c, 0x12, 0x21, 0x0a, 0x0c, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x5f, 0x66, 0x69, 0x65, 0x6c, + 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x46, + 0x69, 0x65, 0x6c, 0x64, 0x32, 0x49, 0x0a, 0x04, 0x54, 0x65, 0x73, 0x74, 0x12, 0x41, 0x0a, 0x06, + 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x1d, 0x2e, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x73, 0x74, + 0x72, 0x65, 0x61, 0x6d, 0x72, 0x70, 0x63, 0x2e, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, + 0x3e, 0x5a, 0x3c, 0x67, 0x69, 0x74, 0x6c, 0x61, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x69, + 0x74, 0x6c, 0x61, 0x62, 0x2d, 0x6f, 0x72, 0x67, 0x2f, 0x67, 0x69, 0x74, 0x61, 0x6c, 0x79, 0x2f, + 0x76, 0x31, 0x34, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x73, 0x74, 0x72, + 0x65, 0x61, 0x6d, 0x72, 0x70, 0x63, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x64, 0x61, 0x74, 0x61, 0x62, + 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_streamrpc_testdata_test_proto_rawDescOnce sync.Once + file_streamrpc_testdata_test_proto_rawDescData = file_streamrpc_testdata_test_proto_rawDesc +) + +func file_streamrpc_testdata_test_proto_rawDescGZIP() []byte { + file_streamrpc_testdata_test_proto_rawDescOnce.Do(func() { + file_streamrpc_testdata_test_proto_rawDescData = protoimpl.X.CompressGZIP(file_streamrpc_testdata_test_proto_rawDescData) + }) + return file_streamrpc_testdata_test_proto_rawDescData +} + +var file_streamrpc_testdata_test_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_streamrpc_testdata_test_proto_goTypes = []interface{}{ + (*StreamRequest)(nil), // 0: test.streamrpc.StreamRequest + (*emptypb.Empty)(nil), // 1: google.protobuf.Empty +} +var file_streamrpc_testdata_test_proto_depIdxs = []int32{ + 0, // 0: test.streamrpc.Test.Stream:input_type -> test.streamrpc.StreamRequest + 1, // 1: test.streamrpc.Test.Stream:output_type -> google.protobuf.Empty + 1, // [1:2] is the sub-list for method output_type + 0, // [0:1] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_streamrpc_testdata_test_proto_init() } +func file_streamrpc_testdata_test_proto_init() { + if File_streamrpc_testdata_test_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_streamrpc_testdata_test_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*StreamRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_streamrpc_testdata_test_proto_rawDesc, + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_streamrpc_testdata_test_proto_goTypes, + DependencyIndexes: file_streamrpc_testdata_test_proto_depIdxs, + MessageInfos: file_streamrpc_testdata_test_proto_msgTypes, + }.Build() + File_streamrpc_testdata_test_proto = out.File + file_streamrpc_testdata_test_proto_rawDesc = nil + file_streamrpc_testdata_test_proto_goTypes = nil + file_streamrpc_testdata_test_proto_depIdxs = nil +} diff --git a/internal/streamrpc/testdata/test.proto b/internal/streamrpc/testdata/test.proto new file mode 100644 index 000000000..7614a07a9 --- /dev/null +++ b/internal/streamrpc/testdata/test.proto @@ -0,0 +1,16 @@ +syntax = "proto3"; + +package test.streamrpc; + +option go_package = "gitlab.com/gitlab-org/gitaly/v14/internal/streamrpc/testdata"; + +import "google/protobuf/empty.proto"; + +service Test { + rpc Stream(StreamRequest) returns (google.protobuf.Empty) {} +} + +message StreamRequest { + bool fail = 1; + string string_field = 2; +} diff --git a/internal/streamrpc/testdata/test_grpc.pb.go b/internal/streamrpc/testdata/test_grpc.pb.go new file mode 100644 index 000000000..53fb22792 --- /dev/null +++ b/internal/streamrpc/testdata/test_grpc.pb.go @@ -0,0 +1,102 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. + +package testdata + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" + emptypb "google.golang.org/protobuf/types/known/emptypb" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.32.0 or later. +const _ = grpc.SupportPackageIsVersion7 + +// TestClient is the client API for Test service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type TestClient interface { + Stream(ctx context.Context, in *StreamRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) +} + +type testClient struct { + cc grpc.ClientConnInterface +} + +func NewTestClient(cc grpc.ClientConnInterface) TestClient { + return &testClient{cc} +} + +func (c *testClient) Stream(ctx context.Context, in *StreamRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) { + out := new(emptypb.Empty) + err := c.cc.Invoke(ctx, "/test.streamrpc.Test/Stream", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// TestServer is the server API for Test service. +// All implementations must embed UnimplementedTestServer +// for forward compatibility +type TestServer interface { + Stream(context.Context, *StreamRequest) (*emptypb.Empty, error) + mustEmbedUnimplementedTestServer() +} + +// UnimplementedTestServer must be embedded to have forward compatible implementations. +type UnimplementedTestServer struct { +} + +func (UnimplementedTestServer) Stream(context.Context, *StreamRequest) (*emptypb.Empty, error) { + return nil, status.Errorf(codes.Unimplemented, "method Stream not implemented") +} +func (UnimplementedTestServer) mustEmbedUnimplementedTestServer() {} + +// UnsafeTestServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to TestServer will +// result in compilation errors. +type UnsafeTestServer interface { + mustEmbedUnimplementedTestServer() +} + +func RegisterTestServer(s grpc.ServiceRegistrar, srv TestServer) { + s.RegisterService(&Test_ServiceDesc, srv) +} + +func _Test_Stream_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(StreamRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(TestServer).Stream(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/test.streamrpc.Test/Stream", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(TestServer).Stream(ctx, req.(*StreamRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// Test_ServiceDesc is the grpc.ServiceDesc for Test service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var Test_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "test.streamrpc.Test", + HandlerType: (*TestServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Stream", + Handler: _Test_Stream_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "streamrpc/testdata/test.proto", +} |