Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Okstad <pokstad@gitlab.com>2019-07-16 16:22:51 +0300
committerJacob Vosmaer <jacob@gitlab.com>2019-07-16 16:22:51 +0300
commit45882789b1abca4efa53f51f5500ecf8e5ead506 (patch)
tree3ec80942886115eb7a85dca8abc860ddb54c0cac
parentb52d9ceab4603724b748f4911e2ad77e6501b5a0 (diff)
Cache invalidation via gRPC interceptor
-rw-r--r--changelogs/unreleased/po-cache-invalidator.yml5
-rw-r--r--internal/middleware/cache/Makefile3
-rw-r--r--internal/middleware/cache/cache.go165
-rw-r--r--internal/middleware/cache/cache_test.go183
-rw-r--r--internal/middleware/cache/testdata/stream.pb.go281
-rw-r--r--internal/middleware/cache/testdata/stream.proto28
-rw-r--r--internal/praefect/protoregistry/protoregistry.go38
-rw-r--r--internal/praefect/protoregistry/protoregistry_test.go3
-rw-r--r--internal/praefect/protoregistry/targetrepo_test.go2
9 files changed, 687 insertions, 21 deletions
diff --git a/changelogs/unreleased/po-cache-invalidator.yml b/changelogs/unreleased/po-cache-invalidator.yml
new file mode 100644
index 000000000..f47d4c482
--- /dev/null
+++ b/changelogs/unreleased/po-cache-invalidator.yml
@@ -0,0 +1,5 @@
+---
+title: Cache invalidation via gRPC interceptor
+merge_request: 1268
+author:
+type: added
diff --git a/internal/middleware/cache/Makefile b/internal/middleware/cache/Makefile
new file mode 100644
index 000000000..7eee3b0e2
--- /dev/null
+++ b/internal/middleware/cache/Makefile
@@ -0,0 +1,3 @@
+testdata/%.pb.go: testdata/%.proto
+ go mod download gitlab.com/gitlab-org/gitaly-proto
+ protoc --go_out=paths=source_relative,plugins=grpc:./testdata -I${shell go list -m -f '{{ .Dir }}' gitlab.com/gitlab-org/gitaly-proto} -I$(shell pwd)/testdata $(shell pwd)/testdata/*.proto
diff --git a/internal/middleware/cache/cache.go b/internal/middleware/cache/cache.go
new file mode 100644
index 000000000..6ad16a59c
--- /dev/null
+++ b/internal/middleware/cache/cache.go
@@ -0,0 +1,165 @@
+package cache
+
+import (
+ "fmt"
+ "sync"
+
+ "github.com/golang/protobuf/proto"
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/sirupsen/logrus"
+ "gitlab.com/gitlab-org/gitaly-proto/go/gitalypb"
+ diskcache "gitlab.com/gitlab-org/gitaly/internal/cache"
+ "gitlab.com/gitlab-org/gitaly/internal/praefect/protoregistry"
+ "google.golang.org/grpc"
+)
+
+var (
+ rpcTotal = prometheus.NewCounter(
+ prometheus.CounterOpts{
+ Name: "gitaly_cacheinvalidator_rpc_total",
+ Help: "Total number of RPCs encountered by cache invalidator",
+ },
+ )
+ rpcOpTypes = prometheus.NewCounterVec(
+ prometheus.CounterOpts{
+ Name: "gitaly_cacheinvalidator_optype_total",
+ Help: "Total number of operation types encountered by cache invalidator",
+ },
+ []string{"type"},
+ )
+)
+
+func init() {
+ prometheus.MustRegister(rpcTotal)
+ prometheus.MustRegister(rpcOpTypes)
+}
+
+func countRPCType(mInfo protoregistry.MethodInfo) {
+ rpcTotal.Inc()
+
+ switch mInfo.Operation {
+ case protoregistry.OpAccessor:
+ rpcOpTypes.WithLabelValues("accessor").Inc()
+ case protoregistry.OpMutator:
+ rpcOpTypes.WithLabelValues("mutator").Inc()
+ default:
+ rpcOpTypes.WithLabelValues("unknown").Inc()
+ }
+}
+
+// Invalidator is able to invalidate parts of the cache pertinent to a
+// specific repository. Before a repo mutating operation, StartLease should
+// be called. Once the operation is complete, the returned LeaseEnder should
+// be invoked to end the lease.
+type Invalidator interface {
+ StartLease(repo *gitalypb.Repository) (diskcache.LeaseEnder, error)
+}
+
+// StreamInvalidator will invalidate any mutating RPC that targets a
+// repository in a gRPC stream based RPC
+func StreamInvalidator(ci Invalidator, reg *protoregistry.Registry) grpc.StreamServerInterceptor {
+ return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+ mInfo, err := reg.LookupMethod(info.FullMethod)
+ countRPCType(mInfo)
+ if err != nil {
+ logrus.WithField("FullMethodName", info.FullMethod).Errorf("unable to lookup method information for %+v", info)
+ }
+
+ if mInfo.Operation == protoregistry.OpAccessor {
+ return handler(srv, ss)
+ }
+
+ handler, callback := invalidateCache(ci, mInfo, handler)
+ peeker := newStreamPeeker(ss, callback)
+ return handler(srv, peeker)
+ }
+}
+
+type recvMsgCallback func(interface{}, error) error
+
+func invalidateCache(ci Invalidator, mInfo protoregistry.MethodInfo, handler grpc.StreamHandler) (grpc.StreamHandler, recvMsgCallback) {
+ var le struct {
+ sync.RWMutex
+ diskcache.LeaseEnder
+ }
+
+ // ensures that the lease ender is invoked after the original handler
+ wrappedHandler := func(srv interface{}, stream grpc.ServerStream) error {
+ defer func() {
+ le.RLock()
+ defer le.RUnlock()
+
+ if le.LeaseEnder == nil {
+ return
+ }
+ if err := le.EndLease(stream.Context()); err != nil {
+ logrus.Errorf("unable to end lease: %q", err)
+ }
+ }()
+ return handler(srv, stream)
+ }
+
+ // starts the cache lease and sets the lease ender iff the request's target
+ // repository can be determined from the first request message
+ peekerCallback := func(firstReq interface{}, err error) error {
+ if err != nil {
+ return err
+ }
+
+ pbFirstReq, ok := firstReq.(proto.Message)
+ if !ok {
+ return fmt.Errorf("cache invalidation expected protobuf request, but got %T", firstReq)
+ }
+
+ target, err := mInfo.TargetRepo(pbFirstReq)
+ if err != nil {
+ return err
+ }
+
+ le.Lock()
+ defer le.Unlock()
+
+ le.LeaseEnder, err = ci.StartLease(target)
+ if err != nil {
+ return err
+ }
+
+ return nil
+ }
+
+ return wrappedHandler, peekerCallback
+}
+
+// streamPeeker allows a stream interceptor to insert peeking logic to perform
+// an action when the first RecvMsg
+type streamPeeker struct {
+ grpc.ServerStream
+
+ // onFirstRecvCallback is called the first time the server stream's RecvMsg
+ // is invoked. It passes the results of the stream's RecvMsg as the
+ // callback's parameters.
+ onFirstRecvOnce sync.Once
+ onFirstRecvCallback recvMsgCallback
+}
+
+// newStreamPeeker returns a wrapped stream that allows a callback to be called
+// on the first invocation of RecvMsg.
+func newStreamPeeker(stream grpc.ServerStream, callback recvMsgCallback) grpc.ServerStream {
+ return &streamPeeker{
+ ServerStream: stream,
+ onFirstRecvCallback: callback,
+ }
+}
+
+// RecvMsg overrides the embedded grpc.ServerStream's method of the same name so
+// that the callback is called on the first call.
+func (sp *streamPeeker) RecvMsg(m interface{}) error {
+ err := sp.ServerStream.RecvMsg(m)
+ sp.onFirstRecvOnce.Do(func() {
+ err := sp.onFirstRecvCallback(m, err)
+ if err != nil {
+ logrus.Errorf("unable to invalidate cache: %q", err)
+ }
+ })
+ return err
+}
diff --git a/internal/middleware/cache/cache_test.go b/internal/middleware/cache/cache_test.go
new file mode 100644
index 000000000..c04cb4d6a
--- /dev/null
+++ b/internal/middleware/cache/cache_test.go
@@ -0,0 +1,183 @@
+package cache_test
+
+import (
+ "context"
+ "io"
+ "net"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/golang/protobuf/proto"
+ "github.com/golang/protobuf/protoc-gen-go/descriptor"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitaly-proto/go/gitalypb"
+ diskcache "gitlab.com/gitlab-org/gitaly/internal/cache"
+ "gitlab.com/gitlab-org/gitaly/internal/middleware/cache"
+ "gitlab.com/gitlab-org/gitaly/internal/middleware/cache/testdata"
+ "gitlab.com/gitlab-org/gitaly/internal/praefect/protoregistry"
+ "google.golang.org/grpc"
+)
+
+//go:generate make testdata/stream.pb.go
+func TestStreamInvalidator(t *testing.T) {
+ mCache := newMockCache()
+
+ reg := protoregistry.New()
+ require.NoError(t, reg.RegisterFiles(streamFileDesc(t)))
+
+ srvr := grpc.NewServer(
+ grpc.StreamInterceptor(
+ grpc.StreamServerInterceptor(
+ cache.StreamInvalidator(mCache, reg),
+ ),
+ ),
+ )
+
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
+
+ svc := &testSvc{}
+
+ cli, cleanup := newTestSvc(t, ctx, srvr, svc)
+ defer cleanup()
+
+ repo1 := &gitalypb.Repository{
+ GitAlternateObjectDirectories: []string{"1"},
+ GitObjectDirectory: "1",
+ GlProjectPath: "1",
+ GlRepository: "1",
+ RelativePath: "1",
+ StorageName: "1",
+ }
+
+ repo2 := &gitalypb.Repository{
+ GitAlternateObjectDirectories: []string{"2"},
+ GitObjectDirectory: "2",
+ GlProjectPath: "2",
+ GlRepository: "2",
+ RelativePath: "2",
+ StorageName: "2",
+ }
+
+ repo3 := &gitalypb.Repository{
+ GitAlternateObjectDirectories: []string{"3"},
+ GitObjectDirectory: "3",
+ GlProjectPath: "3",
+ GlRepository: "3",
+ RelativePath: "3",
+ StorageName: "3",
+ }
+
+ expectedSvcRequests := []gitalypb.Repository{*repo1, *repo2, *repo3}
+ expectedInvalidations := []gitalypb.Repository{*repo2, *repo3}
+
+ // Should NOT trigger cache invalidation
+ c, err := cli.ClientStreamRepoAccessor(ctx, &testdata.Request{
+ Destination: repo1,
+ })
+ assert.NoError(t, err)
+ _, err = c.Recv() // make client call synchronous by waiting for close
+ assert.Equal(t, err, io.EOF)
+
+ // Should trigger cache invalidation
+ c, err = cli.ClientStreamRepoMutator(ctx, &testdata.Request{
+ Destination: repo2,
+ })
+ assert.NoError(t, err)
+ _, err = c.Recv() // make client call synchronous by waiting for close
+ assert.Equal(t, err, io.EOF)
+
+ // Should trigger cache invalidation
+ c, err = cli.ClientStreamRepoMutator(ctx, &testdata.Request{
+ Destination: repo3,
+ })
+ assert.NoError(t, err)
+ _, err = c.Recv() // make client call synchronous by waiting for close
+ assert.Equal(t, err, io.EOF)
+
+ require.Equal(t, expectedInvalidations, mCache.(*mockCache).invalidatedRepos)
+ require.Equal(t, expectedSvcRequests, svc.repoRequests)
+ require.Equal(t, 2, mCache.(*mockCache).endedLeases.count)
+}
+
+// mockCache allows us to relay back via channel which repos are being
+// invalidated in the cache
+type mockCache struct {
+ invalidatedRepos []gitalypb.Repository
+ endedLeases *struct {
+ sync.RWMutex
+ count int
+ }
+}
+
+func newMockCache() cache.Invalidator {
+ return &mockCache{
+ endedLeases: &struct {
+ sync.RWMutex
+ count int
+ }{},
+ }
+}
+
+func (mc *mockCache) EndLease(_ context.Context) error {
+ mc.endedLeases.Lock()
+ defer mc.endedLeases.Unlock()
+ mc.endedLeases.count++
+
+ return nil
+}
+
+func (mc *mockCache) StartLease(repo *gitalypb.Repository) (diskcache.LeaseEnder, error) {
+ mc.invalidatedRepos = append(mc.invalidatedRepos, *repo)
+ return mc, nil
+}
+
+func streamFileDesc(t testing.TB) *descriptor.FileDescriptorProto {
+ fdp, err := protoregistry.ExtractFileDescriptor(proto.FileDescriptor("stream.proto"))
+ require.NoError(t, err)
+ return fdp
+}
+
+func newTestSvc(t testing.TB, ctx context.Context, srvr *grpc.Server, svc testdata.TestServiceServer) (testdata.TestServiceClient, func()) {
+ testdata.RegisterTestServiceServer(srvr, svc)
+
+ lis, err := net.Listen("tcp", ":0")
+ require.NoError(t, err)
+
+ errQ := make(chan error)
+
+ go func() {
+ errQ <- srvr.Serve(lis)
+ }()
+
+ cleanup := func() {
+ srvr.Stop()
+ require.NoError(t, <-errQ)
+ }
+
+ cc, err := grpc.DialContext(
+ ctx,
+ lis.Addr().String(),
+ grpc.WithBlock(),
+ grpc.WithInsecure(),
+ )
+ require.NoError(t, err)
+
+ return testdata.NewTestServiceClient(cc), cleanup
+}
+
+type testSvc struct {
+ repoRequests []gitalypb.Repository
+}
+
+func (ts *testSvc) ClientStreamRepoMutator(req *testdata.Request, _ testdata.TestService_ClientStreamRepoMutatorServer) error {
+ ts.repoRequests = append(ts.repoRequests, *req.GetDestination())
+ return nil
+}
+
+func (ts *testSvc) ClientStreamRepoAccessor(req *testdata.Request, _ testdata.TestService_ClientStreamRepoAccessorServer) error {
+ ts.repoRequests = append(ts.repoRequests, *req.GetDestination())
+ return nil
+}
diff --git a/internal/middleware/cache/testdata/stream.pb.go b/internal/middleware/cache/testdata/stream.pb.go
new file mode 100644
index 000000000..cbce32ff3
--- /dev/null
+++ b/internal/middleware/cache/testdata/stream.pb.go
@@ -0,0 +1,281 @@
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// source: stream.proto
+
+package testdata
+
+import (
+ context "context"
+ fmt "fmt"
+ proto "github.com/golang/protobuf/proto"
+ gitalypb "gitlab.com/gitlab-org/gitaly-proto/go/gitalypb"
+ grpc "google.golang.org/grpc"
+ math "math"
+)
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ = proto.Marshal
+var _ = fmt.Errorf
+var _ = math.Inf
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the proto package it is being compiled against.
+// A compilation error at this line likely means your copy of the
+// proto package needs to be updated.
+const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
+
+type Request struct {
+ Destination *gitalypb.Repository `protobuf:"bytes,1,opt,name=destination,proto3" json:"destination,omitempty"`
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *Request) Reset() { *m = Request{} }
+func (m *Request) String() string { return proto.CompactTextString(m) }
+func (*Request) ProtoMessage() {}
+func (*Request) Descriptor() ([]byte, []int) {
+ return fileDescriptor_bb17ef3f514bfe54, []int{0}
+}
+
+func (m *Request) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_Request.Unmarshal(m, b)
+}
+func (m *Request) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_Request.Marshal(b, m, deterministic)
+}
+func (m *Request) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_Request.Merge(m, src)
+}
+func (m *Request) XXX_Size() int {
+ return xxx_messageInfo_Request.Size(m)
+}
+func (m *Request) XXX_DiscardUnknown() {
+ xxx_messageInfo_Request.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Request proto.InternalMessageInfo
+
+func (m *Request) GetDestination() *gitalypb.Repository {
+ if m != nil {
+ return m.Destination
+ }
+ return nil
+}
+
+type Response struct {
+ XXX_NoUnkeyedLiteral struct{} `json:"-"`
+ XXX_unrecognized []byte `json:"-"`
+ XXX_sizecache int32 `json:"-"`
+}
+
+func (m *Response) Reset() { *m = Response{} }
+func (m *Response) String() string { return proto.CompactTextString(m) }
+func (*Response) ProtoMessage() {}
+func (*Response) Descriptor() ([]byte, []int) {
+ return fileDescriptor_bb17ef3f514bfe54, []int{1}
+}
+
+func (m *Response) XXX_Unmarshal(b []byte) error {
+ return xxx_messageInfo_Response.Unmarshal(m, b)
+}
+func (m *Response) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
+ return xxx_messageInfo_Response.Marshal(b, m, deterministic)
+}
+func (m *Response) XXX_Merge(src proto.Message) {
+ xxx_messageInfo_Response.Merge(m, src)
+}
+func (m *Response) XXX_Size() int {
+ return xxx_messageInfo_Response.Size(m)
+}
+func (m *Response) XXX_DiscardUnknown() {
+ xxx_messageInfo_Response.DiscardUnknown(m)
+}
+
+var xxx_messageInfo_Response proto.InternalMessageInfo
+
+func init() {
+ proto.RegisterType((*Request)(nil), "testdata.Request")
+ proto.RegisterType((*Response)(nil), "testdata.Response")
+}
+
+func init() { proto.RegisterFile("stream.proto", fileDescriptor_bb17ef3f514bfe54) }
+
+var fileDescriptor_bb17ef3f514bfe54 = []byte{
+ // 257 bytes of a gzipped FileDescriptorProto
+ 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0xd0, 0xb1, 0x4e, 0xc3, 0x30,
+ 0x10, 0x06, 0x60, 0xb9, 0x12, 0x25, 0xb8, 0x5d, 0xf0, 0x42, 0x95, 0x09, 0x65, 0xca, 0x82, 0x5d,
+ 0x0a, 0x7b, 0x55, 0x18, 0x51, 0x97, 0x94, 0x89, 0xed, 0xea, 0x9c, 0x52, 0x4b, 0x8e, 0x1d, 0x7c,
+ 0x57, 0x50, 0x9f, 0xa4, 0xcf, 0xc0, 0x2b, 0x76, 0x42, 0x34, 0x54, 0xaa, 0x98, 0xd8, 0xac, 0x5f,
+ 0xd6, 0x77, 0xff, 0x9d, 0x1c, 0x13, 0x27, 0x84, 0x56, 0x77, 0x29, 0x72, 0x54, 0x19, 0x23, 0x71,
+ 0x0d, 0x0c, 0xf9, 0x98, 0x36, 0x90, 0xb0, 0xee, 0xf3, 0x62, 0x2e, 0x2f, 0x2b, 0x7c, 0xdf, 0x22,
+ 0xb1, 0x7a, 0x94, 0xa3, 0x1a, 0x89, 0x5d, 0x00, 0x76, 0x31, 0x4c, 0xc4, 0xad, 0x28, 0x47, 0x33,
+ 0xa5, 0x1b, 0xc7, 0xe0, 0x77, 0xba, 0xc2, 0x2e, 0x92, 0xe3, 0x98, 0x76, 0xd5, 0xf9, 0xb7, 0x42,
+ 0xca, 0xac, 0x42, 0xea, 0x62, 0x20, 0x9c, 0x7d, 0x09, 0x39, 0x7a, 0x45, 0xe2, 0x15, 0xa6, 0x0f,
+ 0x67, 0x51, 0x2d, 0xe5, 0xcd, 0xb3, 0x77, 0x18, 0x78, 0x75, 0xac, 0xf2, 0x43, 0x2c, 0xb7, 0x0c,
+ 0x1c, 0x93, 0xba, 0xd6, 0xa7, 0x42, 0xfa, 0x77, 0x7e, 0xae, 0xce, 0xa3, 0x5e, 0x2c, 0xae, 0x0e,
+ 0xfb, 0xf2, 0x22, 0x13, 0xb9, 0xb8, 0x9f, 0x0a, 0xf5, 0x22, 0x27, 0x7f, 0xb9, 0x85, 0xb5, 0x48,
+ 0xf4, 0x7f, 0x6f, 0x78, 0xd8, 0x97, 0x83, 0x6c, 0x30, 0x15, 0x4f, 0x8b, 0xb7, 0x79, 0xe3, 0xd8,
+ 0xc3, 0x5a, 0xdb, 0xd8, 0x9a, 0xfe, 0x79, 0x17, 0x53, 0x63, 0xfa, 0x7d, 0x8d, 0x0b, 0x8c, 0x29,
+ 0x80, 0x37, 0xad, 0xab, 0x6b, 0x8f, 0x9f, 0x90, 0xd0, 0x58, 0xb0, 0x1b, 0x34, 0x27, 0x75, 0x3d,
+ 0x3c, 0x9e, 0xf0, 0xe1, 0x3b, 0x00, 0x00, 0xff, 0xff, 0x77, 0x8a, 0x5d, 0x72, 0x6a, 0x01, 0x00,
+ 0x00,
+}
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ context.Context
+var _ grpc.ClientConn
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the grpc package it is being compiled against.
+const _ = grpc.SupportPackageIsVersion4
+
+// TestServiceClient is the client API for TestService service.
+//
+// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
+type TestServiceClient interface {
+ ClientStreamRepoMutator(ctx context.Context, in *Request, opts ...grpc.CallOption) (TestService_ClientStreamRepoMutatorClient, error)
+ ClientStreamRepoAccessor(ctx context.Context, in *Request, opts ...grpc.CallOption) (TestService_ClientStreamRepoAccessorClient, error)
+}
+
+type testServiceClient struct {
+ cc *grpc.ClientConn
+}
+
+func NewTestServiceClient(cc *grpc.ClientConn) TestServiceClient {
+ return &testServiceClient{cc}
+}
+
+func (c *testServiceClient) ClientStreamRepoMutator(ctx context.Context, in *Request, opts ...grpc.CallOption) (TestService_ClientStreamRepoMutatorClient, error) {
+ stream, err := c.cc.NewStream(ctx, &_TestService_serviceDesc.Streams[0], "/testdata.TestService/ClientStreamRepoMutator", opts...)
+ if err != nil {
+ return nil, err
+ }
+ x := &testServiceClientStreamRepoMutatorClient{stream}
+ if err := x.ClientStream.SendMsg(in); err != nil {
+ return nil, err
+ }
+ if err := x.ClientStream.CloseSend(); err != nil {
+ return nil, err
+ }
+ return x, nil
+}
+
+type TestService_ClientStreamRepoMutatorClient interface {
+ Recv() (*Response, error)
+ grpc.ClientStream
+}
+
+type testServiceClientStreamRepoMutatorClient struct {
+ grpc.ClientStream
+}
+
+func (x *testServiceClientStreamRepoMutatorClient) Recv() (*Response, error) {
+ m := new(Response)
+ if err := x.ClientStream.RecvMsg(m); err != nil {
+ return nil, err
+ }
+ return m, nil
+}
+
+func (c *testServiceClient) ClientStreamRepoAccessor(ctx context.Context, in *Request, opts ...grpc.CallOption) (TestService_ClientStreamRepoAccessorClient, error) {
+ stream, err := c.cc.NewStream(ctx, &_TestService_serviceDesc.Streams[1], "/testdata.TestService/ClientStreamRepoAccessor", opts...)
+ if err != nil {
+ return nil, err
+ }
+ x := &testServiceClientStreamRepoAccessorClient{stream}
+ if err := x.ClientStream.SendMsg(in); err != nil {
+ return nil, err
+ }
+ if err := x.ClientStream.CloseSend(); err != nil {
+ return nil, err
+ }
+ return x, nil
+}
+
+type TestService_ClientStreamRepoAccessorClient interface {
+ Recv() (*Response, error)
+ grpc.ClientStream
+}
+
+type testServiceClientStreamRepoAccessorClient struct {
+ grpc.ClientStream
+}
+
+func (x *testServiceClientStreamRepoAccessorClient) Recv() (*Response, error) {
+ m := new(Response)
+ if err := x.ClientStream.RecvMsg(m); err != nil {
+ return nil, err
+ }
+ return m, nil
+}
+
+// TestServiceServer is the server API for TestService service.
+type TestServiceServer interface {
+ ClientStreamRepoMutator(*Request, TestService_ClientStreamRepoMutatorServer) error
+ ClientStreamRepoAccessor(*Request, TestService_ClientStreamRepoAccessorServer) error
+}
+
+func RegisterTestServiceServer(s *grpc.Server, srv TestServiceServer) {
+ s.RegisterService(&_TestService_serviceDesc, srv)
+}
+
+func _TestService_ClientStreamRepoMutator_Handler(srv interface{}, stream grpc.ServerStream) error {
+ m := new(Request)
+ if err := stream.RecvMsg(m); err != nil {
+ return err
+ }
+ return srv.(TestServiceServer).ClientStreamRepoMutator(m, &testServiceClientStreamRepoMutatorServer{stream})
+}
+
+type TestService_ClientStreamRepoMutatorServer interface {
+ Send(*Response) error
+ grpc.ServerStream
+}
+
+type testServiceClientStreamRepoMutatorServer struct {
+ grpc.ServerStream
+}
+
+func (x *testServiceClientStreamRepoMutatorServer) Send(m *Response) error {
+ return x.ServerStream.SendMsg(m)
+}
+
+func _TestService_ClientStreamRepoAccessor_Handler(srv interface{}, stream grpc.ServerStream) error {
+ m := new(Request)
+ if err := stream.RecvMsg(m); err != nil {
+ return err
+ }
+ return srv.(TestServiceServer).ClientStreamRepoAccessor(m, &testServiceClientStreamRepoAccessorServer{stream})
+}
+
+type TestService_ClientStreamRepoAccessorServer interface {
+ Send(*Response) error
+ grpc.ServerStream
+}
+
+type testServiceClientStreamRepoAccessorServer struct {
+ grpc.ServerStream
+}
+
+func (x *testServiceClientStreamRepoAccessorServer) Send(m *Response) error {
+ return x.ServerStream.SendMsg(m)
+}
+
+var _TestService_serviceDesc = grpc.ServiceDesc{
+ ServiceName: "testdata.TestService",
+ HandlerType: (*TestServiceServer)(nil),
+ Methods: []grpc.MethodDesc{},
+ Streams: []grpc.StreamDesc{
+ {
+ StreamName: "ClientStreamRepoMutator",
+ Handler: _TestService_ClientStreamRepoMutator_Handler,
+ ServerStreams: true,
+ },
+ {
+ StreamName: "ClientStreamRepoAccessor",
+ Handler: _TestService_ClientStreamRepoAccessor_Handler,
+ ServerStreams: true,
+ },
+ },
+ Metadata: "stream.proto",
+}
diff --git a/internal/middleware/cache/testdata/stream.proto b/internal/middleware/cache/testdata/stream.proto
new file mode 100644
index 000000000..f10aa08af
--- /dev/null
+++ b/internal/middleware/cache/testdata/stream.proto
@@ -0,0 +1,28 @@
+syntax = "proto3";
+
+package testdata;
+
+import "shared.proto";
+
+option go_package = "gitlab.com/gitlab-org/gitaly/internal/middleware/cache/testdata";
+
+message Request {
+ gitaly.Repository destination = 1;
+}
+
+message Response{}
+
+service TestService {
+ rpc ClientStreamRepoMutator(Request) returns (stream Response) {
+ option (gitaly.op_type) = {
+ op: MUTATOR
+ target_repository_field: "1"
+ };
+ }
+
+ rpc ClientStreamRepoAccessor(Request) returns (stream Response) {
+ option (gitaly.op_type) = {
+ op: ACCESSOR
+ };
+ }
+}
diff --git a/internal/praefect/protoregistry/protoregistry.go b/internal/praefect/protoregistry/protoregistry.go
index 83f8a16bd..c0418893d 100644
--- a/internal/praefect/protoregistry/protoregistry.go
+++ b/internal/praefect/protoregistry/protoregistry.go
@@ -21,7 +21,7 @@ var GitalyProtoFileDescriptors []*descriptor.FileDescriptorProto
func init() {
for _, protoName := range gitalypb.GitalyProtos {
gz := proto.FileDescriptor(protoName)
- fd, err := extractFile(gz)
+ fd, err := ExtractFileDescriptor(gz)
if err != nil {
panic(err)
}
@@ -49,6 +49,7 @@ type MethodInfo struct {
Operation OpType
targetRepo []int
requestName string // protobuf message name for input type
+
}
// TargetRepo returns the target repository for a protobuf message if it exists
@@ -66,33 +67,35 @@ func (mi MethodInfo) TargetRepo(msg proto.Message) (*gitalypb.Repository, error)
// Registry contains info about RPC methods
type Registry struct {
sync.RWMutex
- protos map[string]map[string]MethodInfo
+ protos map[string]MethodInfo
}
// New creates a new ProtoRegistry
func New() *Registry {
return &Registry{
- protos: make(map[string]map[string]MethodInfo),
+ protos: make(map[string]MethodInfo),
}
}
-// RegisterFiles takes one or more descriptor.FileDescriptorProto and populates the registry with its info
+// RegisterFiles takes one or more descriptor.FileDescriptorProto and populates
+// the registry with its info
func (pr *Registry) RegisterFiles(protos ...*descriptor.FileDescriptorProto) error {
pr.Lock()
defer pr.Unlock()
for _, p := range protos {
- for _, serviceDescriptorProto := range p.GetService() {
- for _, methodDescriptorProto := range serviceDescriptorProto.GetMethod() {
- mi, err := parseMethodInfo(methodDescriptorProto)
+ for _, svc := range p.GetService() {
+ for _, method := range svc.GetMethod() {
+ mi, err := parseMethodInfo(method)
if err != nil {
return err
}
- if _, ok := pr.protos[serviceDescriptorProto.GetName()]; !ok {
- pr.protos[serviceDescriptorProto.GetName()] = make(map[string]MethodInfo)
- }
- pr.protos[serviceDescriptorProto.GetName()][methodDescriptorProto.GetName()] = mi
+ fullMethodName := fmt.Sprintf(
+ "/%s.%s/%s",
+ p.GetPackage(), svc.GetName(), method.GetName(),
+ )
+ pr.protos[fullMethodName] = mi
}
}
}
@@ -189,23 +192,20 @@ func parseOID(rawFieldOID string) ([]int, error) {
}
// LookupMethod looks up an MethodInfo by service and method name
-func (pr *Registry) LookupMethod(service, method string) (MethodInfo, error) {
+func (pr *Registry) LookupMethod(fullMethodName string) (MethodInfo, error) {
pr.RLock()
defer pr.RUnlock()
- if _, ok := pr.protos[service]; !ok {
- return MethodInfo{}, fmt.Errorf("service not found: %v", service)
- }
- methodInfo, ok := pr.protos[service][method]
+ methodInfo, ok := pr.protos[fullMethodName]
if !ok {
- return MethodInfo{}, fmt.Errorf("method not found: %v", method)
+ return MethodInfo{}, fmt.Errorf("full method name not found: %v", fullMethodName)
}
return methodInfo, nil
}
-// extractFile extracts a FileDescriptorProto from a gzip'd buffer.
+// ExtractFileDescriptor extracts a FileDescriptorProto from a gzip'd buffer.
// https://github.com/golang/protobuf/blob/9eb2c01ac278a5d89ce4b2be68fe4500955d8179/descriptor/descriptor.go#L50
-func extractFile(gz []byte) (*descriptor.FileDescriptorProto, error) {
+func ExtractFileDescriptor(gz []byte) (*descriptor.FileDescriptorProto, error) {
r, err := gzip.NewReader(bytes.NewReader(gz))
if err != nil {
return nil, fmt.Errorf("failed to open gzip reader: %v", err)
diff --git a/internal/praefect/protoregistry/protoregistry_test.go b/internal/praefect/protoregistry/protoregistry_test.go
index 56cffe255..3d896d6d7 100644
--- a/internal/praefect/protoregistry/protoregistry_test.go
+++ b/internal/praefect/protoregistry/protoregistry_test.go
@@ -1,6 +1,7 @@
package protoregistry_test
import (
+ "fmt"
"testing"
"github.com/stretchr/testify/assert"
@@ -187,7 +188,7 @@ func TestPopulatesProtoRegistry(t *testing.T) {
for serviceName, methods := range expectedResults {
for methodName, opType := range methods {
- methodInfo, err := r.LookupMethod(serviceName, methodName)
+ methodInfo, err := r.LookupMethod(fmt.Sprintf("/gitaly.%s/%s", serviceName, methodName))
require.NoError(t, err)
assert.Equalf(t, opType, methodInfo.Operation, "expect %s:%s to have the correct op type", serviceName, methodName)
}
diff --git a/internal/praefect/protoregistry/targetrepo_test.go b/internal/praefect/protoregistry/targetrepo_test.go
index fc016807d..f2c1f394e 100644
--- a/internal/praefect/protoregistry/targetrepo_test.go
+++ b/internal/praefect/protoregistry/targetrepo_test.go
@@ -63,7 +63,7 @@ func TestProtoRegistryTargetRepo(t *testing.T) {
for _, tc := range testcases {
desc := fmt.Sprintf("%s:%s %s", tc.svc, tc.method, tc.desc)
t.Run(desc, func(t *testing.T) {
- info, err := r.LookupMethod(tc.svc, tc.method)
+ info, err := r.LookupMethod(fmt.Sprintf("/gitaly.%s/%s", tc.svc, tc.method))
require.NoError(t, err)
actualTarget, actualErr := info.TargetRepo(tc.pbMsg)