diff options
author | Paul Okstad <pokstad@gitlab.com> | 2019-07-16 16:22:51 +0300 |
---|---|---|
committer | Jacob Vosmaer <jacob@gitlab.com> | 2019-07-16 16:22:51 +0300 |
commit | 45882789b1abca4efa53f51f5500ecf8e5ead506 (patch) | |
tree | 3ec80942886115eb7a85dca8abc860ddb54c0cac | |
parent | b52d9ceab4603724b748f4911e2ad77e6501b5a0 (diff) |
Cache invalidation via gRPC interceptor
-rw-r--r-- | changelogs/unreleased/po-cache-invalidator.yml | 5 | ||||
-rw-r--r-- | internal/middleware/cache/Makefile | 3 | ||||
-rw-r--r-- | internal/middleware/cache/cache.go | 165 | ||||
-rw-r--r-- | internal/middleware/cache/cache_test.go | 183 | ||||
-rw-r--r-- | internal/middleware/cache/testdata/stream.pb.go | 281 | ||||
-rw-r--r-- | internal/middleware/cache/testdata/stream.proto | 28 | ||||
-rw-r--r-- | internal/praefect/protoregistry/protoregistry.go | 38 | ||||
-rw-r--r-- | internal/praefect/protoregistry/protoregistry_test.go | 3 | ||||
-rw-r--r-- | internal/praefect/protoregistry/targetrepo_test.go | 2 |
9 files changed, 687 insertions, 21 deletions
diff --git a/changelogs/unreleased/po-cache-invalidator.yml b/changelogs/unreleased/po-cache-invalidator.yml new file mode 100644 index 000000000..f47d4c482 --- /dev/null +++ b/changelogs/unreleased/po-cache-invalidator.yml @@ -0,0 +1,5 @@ +--- +title: Cache invalidation via gRPC interceptor +merge_request: 1268 +author: +type: added diff --git a/internal/middleware/cache/Makefile b/internal/middleware/cache/Makefile new file mode 100644 index 000000000..7eee3b0e2 --- /dev/null +++ b/internal/middleware/cache/Makefile @@ -0,0 +1,3 @@ +testdata/%.pb.go: testdata/%.proto + go mod download gitlab.com/gitlab-org/gitaly-proto + protoc --go_out=paths=source_relative,plugins=grpc:./testdata -I${shell go list -m -f '{{ .Dir }}' gitlab.com/gitlab-org/gitaly-proto} -I$(shell pwd)/testdata $(shell pwd)/testdata/*.proto diff --git a/internal/middleware/cache/cache.go b/internal/middleware/cache/cache.go new file mode 100644 index 000000000..6ad16a59c --- /dev/null +++ b/internal/middleware/cache/cache.go @@ -0,0 +1,165 @@ +package cache + +import ( + "fmt" + "sync" + + "github.com/golang/protobuf/proto" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" + "gitlab.com/gitlab-org/gitaly-proto/go/gitalypb" + diskcache "gitlab.com/gitlab-org/gitaly/internal/cache" + "gitlab.com/gitlab-org/gitaly/internal/praefect/protoregistry" + "google.golang.org/grpc" +) + +var ( + rpcTotal = prometheus.NewCounter( + prometheus.CounterOpts{ + Name: "gitaly_cacheinvalidator_rpc_total", + Help: "Total number of RPCs encountered by cache invalidator", + }, + ) + rpcOpTypes = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitaly_cacheinvalidator_optype_total", + Help: "Total number of operation types encountered by cache invalidator", + }, + []string{"type"}, + ) +) + +func init() { + prometheus.MustRegister(rpcTotal) + prometheus.MustRegister(rpcOpTypes) +} + +func countRPCType(mInfo protoregistry.MethodInfo) { + rpcTotal.Inc() + + switch mInfo.Operation { + case protoregistry.OpAccessor: + rpcOpTypes.WithLabelValues("accessor").Inc() + case protoregistry.OpMutator: + rpcOpTypes.WithLabelValues("mutator").Inc() + default: + rpcOpTypes.WithLabelValues("unknown").Inc() + } +} + +// Invalidator is able to invalidate parts of the cache pertinent to a +// specific repository. Before a repo mutating operation, StartLease should +// be called. Once the operation is complete, the returned LeaseEnder should +// be invoked to end the lease. +type Invalidator interface { + StartLease(repo *gitalypb.Repository) (diskcache.LeaseEnder, error) +} + +// StreamInvalidator will invalidate any mutating RPC that targets a +// repository in a gRPC stream based RPC +func StreamInvalidator(ci Invalidator, reg *protoregistry.Registry) grpc.StreamServerInterceptor { + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + mInfo, err := reg.LookupMethod(info.FullMethod) + countRPCType(mInfo) + if err != nil { + logrus.WithField("FullMethodName", info.FullMethod).Errorf("unable to lookup method information for %+v", info) + } + + if mInfo.Operation == protoregistry.OpAccessor { + return handler(srv, ss) + } + + handler, callback := invalidateCache(ci, mInfo, handler) + peeker := newStreamPeeker(ss, callback) + return handler(srv, peeker) + } +} + +type recvMsgCallback func(interface{}, error) error + +func invalidateCache(ci Invalidator, mInfo protoregistry.MethodInfo, handler grpc.StreamHandler) (grpc.StreamHandler, recvMsgCallback) { + var le struct { + sync.RWMutex + diskcache.LeaseEnder + } + + // ensures that the lease ender is invoked after the original handler + wrappedHandler := func(srv interface{}, stream grpc.ServerStream) error { + defer func() { + le.RLock() + defer le.RUnlock() + + if le.LeaseEnder == nil { + return + } + if err := le.EndLease(stream.Context()); err != nil { + logrus.Errorf("unable to end lease: %q", err) + } + }() + return handler(srv, stream) + } + + // starts the cache lease and sets the lease ender iff the request's target + // repository can be determined from the first request message + peekerCallback := func(firstReq interface{}, err error) error { + if err != nil { + return err + } + + pbFirstReq, ok := firstReq.(proto.Message) + if !ok { + return fmt.Errorf("cache invalidation expected protobuf request, but got %T", firstReq) + } + + target, err := mInfo.TargetRepo(pbFirstReq) + if err != nil { + return err + } + + le.Lock() + defer le.Unlock() + + le.LeaseEnder, err = ci.StartLease(target) + if err != nil { + return err + } + + return nil + } + + return wrappedHandler, peekerCallback +} + +// streamPeeker allows a stream interceptor to insert peeking logic to perform +// an action when the first RecvMsg +type streamPeeker struct { + grpc.ServerStream + + // onFirstRecvCallback is called the first time the server stream's RecvMsg + // is invoked. It passes the results of the stream's RecvMsg as the + // callback's parameters. + onFirstRecvOnce sync.Once + onFirstRecvCallback recvMsgCallback +} + +// newStreamPeeker returns a wrapped stream that allows a callback to be called +// on the first invocation of RecvMsg. +func newStreamPeeker(stream grpc.ServerStream, callback recvMsgCallback) grpc.ServerStream { + return &streamPeeker{ + ServerStream: stream, + onFirstRecvCallback: callback, + } +} + +// RecvMsg overrides the embedded grpc.ServerStream's method of the same name so +// that the callback is called on the first call. +func (sp *streamPeeker) RecvMsg(m interface{}) error { + err := sp.ServerStream.RecvMsg(m) + sp.onFirstRecvOnce.Do(func() { + err := sp.onFirstRecvCallback(m, err) + if err != nil { + logrus.Errorf("unable to invalidate cache: %q", err) + } + }) + return err +} diff --git a/internal/middleware/cache/cache_test.go b/internal/middleware/cache/cache_test.go new file mode 100644 index 000000000..c04cb4d6a --- /dev/null +++ b/internal/middleware/cache/cache_test.go @@ -0,0 +1,183 @@ +package cache_test + +import ( + "context" + "io" + "net" + "sync" + "testing" + "time" + + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/protoc-gen-go/descriptor" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitaly-proto/go/gitalypb" + diskcache "gitlab.com/gitlab-org/gitaly/internal/cache" + "gitlab.com/gitlab-org/gitaly/internal/middleware/cache" + "gitlab.com/gitlab-org/gitaly/internal/middleware/cache/testdata" + "gitlab.com/gitlab-org/gitaly/internal/praefect/protoregistry" + "google.golang.org/grpc" +) + +//go:generate make testdata/stream.pb.go +func TestStreamInvalidator(t *testing.T) { + mCache := newMockCache() + + reg := protoregistry.New() + require.NoError(t, reg.RegisterFiles(streamFileDesc(t))) + + srvr := grpc.NewServer( + grpc.StreamInterceptor( + grpc.StreamServerInterceptor( + cache.StreamInvalidator(mCache, reg), + ), + ), + ) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + svc := &testSvc{} + + cli, cleanup := newTestSvc(t, ctx, srvr, svc) + defer cleanup() + + repo1 := &gitalypb.Repository{ + GitAlternateObjectDirectories: []string{"1"}, + GitObjectDirectory: "1", + GlProjectPath: "1", + GlRepository: "1", + RelativePath: "1", + StorageName: "1", + } + + repo2 := &gitalypb.Repository{ + GitAlternateObjectDirectories: []string{"2"}, + GitObjectDirectory: "2", + GlProjectPath: "2", + GlRepository: "2", + RelativePath: "2", + StorageName: "2", + } + + repo3 := &gitalypb.Repository{ + GitAlternateObjectDirectories: []string{"3"}, + GitObjectDirectory: "3", + GlProjectPath: "3", + GlRepository: "3", + RelativePath: "3", + StorageName: "3", + } + + expectedSvcRequests := []gitalypb.Repository{*repo1, *repo2, *repo3} + expectedInvalidations := []gitalypb.Repository{*repo2, *repo3} + + // Should NOT trigger cache invalidation + c, err := cli.ClientStreamRepoAccessor(ctx, &testdata.Request{ + Destination: repo1, + }) + assert.NoError(t, err) + _, err = c.Recv() // make client call synchronous by waiting for close + assert.Equal(t, err, io.EOF) + + // Should trigger cache invalidation + c, err = cli.ClientStreamRepoMutator(ctx, &testdata.Request{ + Destination: repo2, + }) + assert.NoError(t, err) + _, err = c.Recv() // make client call synchronous by waiting for close + assert.Equal(t, err, io.EOF) + + // Should trigger cache invalidation + c, err = cli.ClientStreamRepoMutator(ctx, &testdata.Request{ + Destination: repo3, + }) + assert.NoError(t, err) + _, err = c.Recv() // make client call synchronous by waiting for close + assert.Equal(t, err, io.EOF) + + require.Equal(t, expectedInvalidations, mCache.(*mockCache).invalidatedRepos) + require.Equal(t, expectedSvcRequests, svc.repoRequests) + require.Equal(t, 2, mCache.(*mockCache).endedLeases.count) +} + +// mockCache allows us to relay back via channel which repos are being +// invalidated in the cache +type mockCache struct { + invalidatedRepos []gitalypb.Repository + endedLeases *struct { + sync.RWMutex + count int + } +} + +func newMockCache() cache.Invalidator { + return &mockCache{ + endedLeases: &struct { + sync.RWMutex + count int + }{}, + } +} + +func (mc *mockCache) EndLease(_ context.Context) error { + mc.endedLeases.Lock() + defer mc.endedLeases.Unlock() + mc.endedLeases.count++ + + return nil +} + +func (mc *mockCache) StartLease(repo *gitalypb.Repository) (diskcache.LeaseEnder, error) { + mc.invalidatedRepos = append(mc.invalidatedRepos, *repo) + return mc, nil +} + +func streamFileDesc(t testing.TB) *descriptor.FileDescriptorProto { + fdp, err := protoregistry.ExtractFileDescriptor(proto.FileDescriptor("stream.proto")) + require.NoError(t, err) + return fdp +} + +func newTestSvc(t testing.TB, ctx context.Context, srvr *grpc.Server, svc testdata.TestServiceServer) (testdata.TestServiceClient, func()) { + testdata.RegisterTestServiceServer(srvr, svc) + + lis, err := net.Listen("tcp", ":0") + require.NoError(t, err) + + errQ := make(chan error) + + go func() { + errQ <- srvr.Serve(lis) + }() + + cleanup := func() { + srvr.Stop() + require.NoError(t, <-errQ) + } + + cc, err := grpc.DialContext( + ctx, + lis.Addr().String(), + grpc.WithBlock(), + grpc.WithInsecure(), + ) + require.NoError(t, err) + + return testdata.NewTestServiceClient(cc), cleanup +} + +type testSvc struct { + repoRequests []gitalypb.Repository +} + +func (ts *testSvc) ClientStreamRepoMutator(req *testdata.Request, _ testdata.TestService_ClientStreamRepoMutatorServer) error { + ts.repoRequests = append(ts.repoRequests, *req.GetDestination()) + return nil +} + +func (ts *testSvc) ClientStreamRepoAccessor(req *testdata.Request, _ testdata.TestService_ClientStreamRepoAccessorServer) error { + ts.repoRequests = append(ts.repoRequests, *req.GetDestination()) + return nil +} diff --git a/internal/middleware/cache/testdata/stream.pb.go b/internal/middleware/cache/testdata/stream.pb.go new file mode 100644 index 000000000..cbce32ff3 --- /dev/null +++ b/internal/middleware/cache/testdata/stream.pb.go @@ -0,0 +1,281 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: stream.proto + +package testdata + +import ( + context "context" + fmt "fmt" + proto "github.com/golang/protobuf/proto" + gitalypb "gitlab.com/gitlab-org/gitaly-proto/go/gitalypb" + grpc "google.golang.org/grpc" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type Request struct { + Destination *gitalypb.Repository `protobuf:"bytes,1,opt,name=destination,proto3" json:"destination,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Request) Reset() { *m = Request{} } +func (m *Request) String() string { return proto.CompactTextString(m) } +func (*Request) ProtoMessage() {} +func (*Request) Descriptor() ([]byte, []int) { + return fileDescriptor_bb17ef3f514bfe54, []int{0} +} + +func (m *Request) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Request.Unmarshal(m, b) +} +func (m *Request) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Request.Marshal(b, m, deterministic) +} +func (m *Request) XXX_Merge(src proto.Message) { + xxx_messageInfo_Request.Merge(m, src) +} +func (m *Request) XXX_Size() int { + return xxx_messageInfo_Request.Size(m) +} +func (m *Request) XXX_DiscardUnknown() { + xxx_messageInfo_Request.DiscardUnknown(m) +} + +var xxx_messageInfo_Request proto.InternalMessageInfo + +func (m *Request) GetDestination() *gitalypb.Repository { + if m != nil { + return m.Destination + } + return nil +} + +type Response struct { + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Response) Reset() { *m = Response{} } +func (m *Response) String() string { return proto.CompactTextString(m) } +func (*Response) ProtoMessage() {} +func (*Response) Descriptor() ([]byte, []int) { + return fileDescriptor_bb17ef3f514bfe54, []int{1} +} + +func (m *Response) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Response.Unmarshal(m, b) +} +func (m *Response) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Response.Marshal(b, m, deterministic) +} +func (m *Response) XXX_Merge(src proto.Message) { + xxx_messageInfo_Response.Merge(m, src) +} +func (m *Response) XXX_Size() int { + return xxx_messageInfo_Response.Size(m) +} +func (m *Response) XXX_DiscardUnknown() { + xxx_messageInfo_Response.DiscardUnknown(m) +} + +var xxx_messageInfo_Response proto.InternalMessageInfo + +func init() { + proto.RegisterType((*Request)(nil), "testdata.Request") + proto.RegisterType((*Response)(nil), "testdata.Response") +} + +func init() { proto.RegisterFile("stream.proto", fileDescriptor_bb17ef3f514bfe54) } + +var fileDescriptor_bb17ef3f514bfe54 = []byte{ + // 257 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0xd0, 0xb1, 0x4e, 0xc3, 0x30, + 0x10, 0x06, 0x60, 0xb9, 0x12, 0x25, 0xb8, 0x5d, 0xf0, 0x42, 0x95, 0x09, 0x65, 0xca, 0x82, 0x5d, + 0x0a, 0x7b, 0x55, 0x18, 0x51, 0x97, 0x94, 0x89, 0xed, 0xea, 0x9c, 0x52, 0x4b, 0x8e, 0x1d, 0x7c, + 0x57, 0x50, 0x9f, 0xa4, 0xcf, 0xc0, 0x2b, 0x76, 0x42, 0x34, 0x54, 0xaa, 0x98, 0xd8, 0xac, 0x5f, + 0xd6, 0x77, 0xff, 0x9d, 0x1c, 0x13, 0x27, 0x84, 0x56, 0x77, 0x29, 0x72, 0x54, 0x19, 0x23, 0x71, + 0x0d, 0x0c, 0xf9, 0x98, 0x36, 0x90, 0xb0, 0xee, 0xf3, 0x62, 0x2e, 0x2f, 0x2b, 0x7c, 0xdf, 0x22, + 0xb1, 0x7a, 0x94, 0xa3, 0x1a, 0x89, 0x5d, 0x00, 0x76, 0x31, 0x4c, 0xc4, 0xad, 0x28, 0x47, 0x33, + 0xa5, 0x1b, 0xc7, 0xe0, 0x77, 0xba, 0xc2, 0x2e, 0x92, 0xe3, 0x98, 0x76, 0xd5, 0xf9, 0xb7, 0x42, + 0xca, 0xac, 0x42, 0xea, 0x62, 0x20, 0x9c, 0x7d, 0x09, 0x39, 0x7a, 0x45, 0xe2, 0x15, 0xa6, 0x0f, + 0x67, 0x51, 0x2d, 0xe5, 0xcd, 0xb3, 0x77, 0x18, 0x78, 0x75, 0xac, 0xf2, 0x43, 0x2c, 0xb7, 0x0c, + 0x1c, 0x93, 0xba, 0xd6, 0xa7, 0x42, 0xfa, 0x77, 0x7e, 0xae, 0xce, 0xa3, 0x5e, 0x2c, 0xae, 0x0e, + 0xfb, 0xf2, 0x22, 0x13, 0xb9, 0xb8, 0x9f, 0x0a, 0xf5, 0x22, 0x27, 0x7f, 0xb9, 0x85, 0xb5, 0x48, + 0xf4, 0x7f, 0x6f, 0x78, 0xd8, 0x97, 0x83, 0x6c, 0x30, 0x15, 0x4f, 0x8b, 0xb7, 0x79, 0xe3, 0xd8, + 0xc3, 0x5a, 0xdb, 0xd8, 0x9a, 0xfe, 0x79, 0x17, 0x53, 0x63, 0xfa, 0x7d, 0x8d, 0x0b, 0x8c, 0x29, + 0x80, 0x37, 0xad, 0xab, 0x6b, 0x8f, 0x9f, 0x90, 0xd0, 0x58, 0xb0, 0x1b, 0x34, 0x27, 0x75, 0x3d, + 0x3c, 0x9e, 0xf0, 0xe1, 0x3b, 0x00, 0x00, 0xff, 0xff, 0x77, 0x8a, 0x5d, 0x72, 0x6a, 0x01, 0x00, + 0x00, +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// TestServiceClient is the client API for TestService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream. +type TestServiceClient interface { + ClientStreamRepoMutator(ctx context.Context, in *Request, opts ...grpc.CallOption) (TestService_ClientStreamRepoMutatorClient, error) + ClientStreamRepoAccessor(ctx context.Context, in *Request, opts ...grpc.CallOption) (TestService_ClientStreamRepoAccessorClient, error) +} + +type testServiceClient struct { + cc *grpc.ClientConn +} + +func NewTestServiceClient(cc *grpc.ClientConn) TestServiceClient { + return &testServiceClient{cc} +} + +func (c *testServiceClient) ClientStreamRepoMutator(ctx context.Context, in *Request, opts ...grpc.CallOption) (TestService_ClientStreamRepoMutatorClient, error) { + stream, err := c.cc.NewStream(ctx, &_TestService_serviceDesc.Streams[0], "/testdata.TestService/ClientStreamRepoMutator", opts...) + if err != nil { + return nil, err + } + x := &testServiceClientStreamRepoMutatorClient{stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +type TestService_ClientStreamRepoMutatorClient interface { + Recv() (*Response, error) + grpc.ClientStream +} + +type testServiceClientStreamRepoMutatorClient struct { + grpc.ClientStream +} + +func (x *testServiceClientStreamRepoMutatorClient) Recv() (*Response, error) { + m := new(Response) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func (c *testServiceClient) ClientStreamRepoAccessor(ctx context.Context, in *Request, opts ...grpc.CallOption) (TestService_ClientStreamRepoAccessorClient, error) { + stream, err := c.cc.NewStream(ctx, &_TestService_serviceDesc.Streams[1], "/testdata.TestService/ClientStreamRepoAccessor", opts...) + if err != nil { + return nil, err + } + x := &testServiceClientStreamRepoAccessorClient{stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +type TestService_ClientStreamRepoAccessorClient interface { + Recv() (*Response, error) + grpc.ClientStream +} + +type testServiceClientStreamRepoAccessorClient struct { + grpc.ClientStream +} + +func (x *testServiceClientStreamRepoAccessorClient) Recv() (*Response, error) { + m := new(Response) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// TestServiceServer is the server API for TestService service. +type TestServiceServer interface { + ClientStreamRepoMutator(*Request, TestService_ClientStreamRepoMutatorServer) error + ClientStreamRepoAccessor(*Request, TestService_ClientStreamRepoAccessorServer) error +} + +func RegisterTestServiceServer(s *grpc.Server, srv TestServiceServer) { + s.RegisterService(&_TestService_serviceDesc, srv) +} + +func _TestService_ClientStreamRepoMutator_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(Request) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(TestServiceServer).ClientStreamRepoMutator(m, &testServiceClientStreamRepoMutatorServer{stream}) +} + +type TestService_ClientStreamRepoMutatorServer interface { + Send(*Response) error + grpc.ServerStream +} + +type testServiceClientStreamRepoMutatorServer struct { + grpc.ServerStream +} + +func (x *testServiceClientStreamRepoMutatorServer) Send(m *Response) error { + return x.ServerStream.SendMsg(m) +} + +func _TestService_ClientStreamRepoAccessor_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(Request) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(TestServiceServer).ClientStreamRepoAccessor(m, &testServiceClientStreamRepoAccessorServer{stream}) +} + +type TestService_ClientStreamRepoAccessorServer interface { + Send(*Response) error + grpc.ServerStream +} + +type testServiceClientStreamRepoAccessorServer struct { + grpc.ServerStream +} + +func (x *testServiceClientStreamRepoAccessorServer) Send(m *Response) error { + return x.ServerStream.SendMsg(m) +} + +var _TestService_serviceDesc = grpc.ServiceDesc{ + ServiceName: "testdata.TestService", + HandlerType: (*TestServiceServer)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "ClientStreamRepoMutator", + Handler: _TestService_ClientStreamRepoMutator_Handler, + ServerStreams: true, + }, + { + StreamName: "ClientStreamRepoAccessor", + Handler: _TestService_ClientStreamRepoAccessor_Handler, + ServerStreams: true, + }, + }, + Metadata: "stream.proto", +} diff --git a/internal/middleware/cache/testdata/stream.proto b/internal/middleware/cache/testdata/stream.proto new file mode 100644 index 000000000..f10aa08af --- /dev/null +++ b/internal/middleware/cache/testdata/stream.proto @@ -0,0 +1,28 @@ +syntax = "proto3"; + +package testdata; + +import "shared.proto"; + +option go_package = "gitlab.com/gitlab-org/gitaly/internal/middleware/cache/testdata"; + +message Request { + gitaly.Repository destination = 1; +} + +message Response{} + +service TestService { + rpc ClientStreamRepoMutator(Request) returns (stream Response) { + option (gitaly.op_type) = { + op: MUTATOR + target_repository_field: "1" + }; + } + + rpc ClientStreamRepoAccessor(Request) returns (stream Response) { + option (gitaly.op_type) = { + op: ACCESSOR + }; + } +} diff --git a/internal/praefect/protoregistry/protoregistry.go b/internal/praefect/protoregistry/protoregistry.go index 83f8a16bd..c0418893d 100644 --- a/internal/praefect/protoregistry/protoregistry.go +++ b/internal/praefect/protoregistry/protoregistry.go @@ -21,7 +21,7 @@ var GitalyProtoFileDescriptors []*descriptor.FileDescriptorProto func init() { for _, protoName := range gitalypb.GitalyProtos { gz := proto.FileDescriptor(protoName) - fd, err := extractFile(gz) + fd, err := ExtractFileDescriptor(gz) if err != nil { panic(err) } @@ -49,6 +49,7 @@ type MethodInfo struct { Operation OpType targetRepo []int requestName string // protobuf message name for input type + } // TargetRepo returns the target repository for a protobuf message if it exists @@ -66,33 +67,35 @@ func (mi MethodInfo) TargetRepo(msg proto.Message) (*gitalypb.Repository, error) // Registry contains info about RPC methods type Registry struct { sync.RWMutex - protos map[string]map[string]MethodInfo + protos map[string]MethodInfo } // New creates a new ProtoRegistry func New() *Registry { return &Registry{ - protos: make(map[string]map[string]MethodInfo), + protos: make(map[string]MethodInfo), } } -// RegisterFiles takes one or more descriptor.FileDescriptorProto and populates the registry with its info +// RegisterFiles takes one or more descriptor.FileDescriptorProto and populates +// the registry with its info func (pr *Registry) RegisterFiles(protos ...*descriptor.FileDescriptorProto) error { pr.Lock() defer pr.Unlock() for _, p := range protos { - for _, serviceDescriptorProto := range p.GetService() { - for _, methodDescriptorProto := range serviceDescriptorProto.GetMethod() { - mi, err := parseMethodInfo(methodDescriptorProto) + for _, svc := range p.GetService() { + for _, method := range svc.GetMethod() { + mi, err := parseMethodInfo(method) if err != nil { return err } - if _, ok := pr.protos[serviceDescriptorProto.GetName()]; !ok { - pr.protos[serviceDescriptorProto.GetName()] = make(map[string]MethodInfo) - } - pr.protos[serviceDescriptorProto.GetName()][methodDescriptorProto.GetName()] = mi + fullMethodName := fmt.Sprintf( + "/%s.%s/%s", + p.GetPackage(), svc.GetName(), method.GetName(), + ) + pr.protos[fullMethodName] = mi } } } @@ -189,23 +192,20 @@ func parseOID(rawFieldOID string) ([]int, error) { } // LookupMethod looks up an MethodInfo by service and method name -func (pr *Registry) LookupMethod(service, method string) (MethodInfo, error) { +func (pr *Registry) LookupMethod(fullMethodName string) (MethodInfo, error) { pr.RLock() defer pr.RUnlock() - if _, ok := pr.protos[service]; !ok { - return MethodInfo{}, fmt.Errorf("service not found: %v", service) - } - methodInfo, ok := pr.protos[service][method] + methodInfo, ok := pr.protos[fullMethodName] if !ok { - return MethodInfo{}, fmt.Errorf("method not found: %v", method) + return MethodInfo{}, fmt.Errorf("full method name not found: %v", fullMethodName) } return methodInfo, nil } -// extractFile extracts a FileDescriptorProto from a gzip'd buffer. +// ExtractFileDescriptor extracts a FileDescriptorProto from a gzip'd buffer. // https://github.com/golang/protobuf/blob/9eb2c01ac278a5d89ce4b2be68fe4500955d8179/descriptor/descriptor.go#L50 -func extractFile(gz []byte) (*descriptor.FileDescriptorProto, error) { +func ExtractFileDescriptor(gz []byte) (*descriptor.FileDescriptorProto, error) { r, err := gzip.NewReader(bytes.NewReader(gz)) if err != nil { return nil, fmt.Errorf("failed to open gzip reader: %v", err) diff --git a/internal/praefect/protoregistry/protoregistry_test.go b/internal/praefect/protoregistry/protoregistry_test.go index 56cffe255..3d896d6d7 100644 --- a/internal/praefect/protoregistry/protoregistry_test.go +++ b/internal/praefect/protoregistry/protoregistry_test.go @@ -1,6 +1,7 @@ package protoregistry_test import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -187,7 +188,7 @@ func TestPopulatesProtoRegistry(t *testing.T) { for serviceName, methods := range expectedResults { for methodName, opType := range methods { - methodInfo, err := r.LookupMethod(serviceName, methodName) + methodInfo, err := r.LookupMethod(fmt.Sprintf("/gitaly.%s/%s", serviceName, methodName)) require.NoError(t, err) assert.Equalf(t, opType, methodInfo.Operation, "expect %s:%s to have the correct op type", serviceName, methodName) } diff --git a/internal/praefect/protoregistry/targetrepo_test.go b/internal/praefect/protoregistry/targetrepo_test.go index fc016807d..f2c1f394e 100644 --- a/internal/praefect/protoregistry/targetrepo_test.go +++ b/internal/praefect/protoregistry/targetrepo_test.go @@ -63,7 +63,7 @@ func TestProtoRegistryTargetRepo(t *testing.T) { for _, tc := range testcases { desc := fmt.Sprintf("%s:%s %s", tc.svc, tc.method, tc.desc) t.Run(desc, func(t *testing.T) { - info, err := r.LookupMethod(tc.svc, tc.method) + info, err := r.LookupMethod(fmt.Sprintf("/gitaly.%s/%s", tc.svc, tc.method)) require.NoError(t, err) actualTarget, actualErr := info.TargetRepo(tc.pbMsg) |