Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Okstad <pokstad@gitlab.com>2019-05-27 05:16:38 +0300
committerPaul Okstad <pokstad@gitlab.com>2019-05-27 05:16:38 +0300
commitad7fc2fb5d510d51c42efc7dbe70a667cc9f6069 (patch)
treeb0eb09a3cc49b501cd13f4a303caa8f3b917ace7
parent39346e71cb51c31d06c29113e5703de411dfc400 (diff)
troubleshooting interceptor peekingpo-inforef-cache
-rw-r--r--internal/interceptor/cache.go44
-rw-r--r--internal/interceptor/cache_test.go35
-rw-r--r--internal/praefect/protoregistry/request_factory.go6
3 files changed, 67 insertions, 18 deletions
diff --git a/internal/interceptor/cache.go b/internal/interceptor/cache.go
index d17ab20d1..5a5082eac 100644
--- a/internal/interceptor/cache.go
+++ b/internal/interceptor/cache.go
@@ -3,6 +3,8 @@ package interceptor
import (
"errors"
"fmt"
+ "log"
+ "reflect"
"github.com/golang/protobuf/proto"
"github.com/sirupsen/logrus"
@@ -17,6 +19,10 @@ type RepoCache interface {
InvalidateRepo(repo *gitalypb.Repository) error
}
+type RequestFactory interface {
+ NewRequest() (proto.Message, error)
+}
+
// StreamInvalidator will invalidate any mutating RPC that targets a repository
// in a gRPC stream based RPC
func StreamInvalidator(c RepoCache, reg *protoregistry.Registry) grpc.StreamServerInterceptor {
@@ -26,13 +32,15 @@ func StreamInvalidator(c RepoCache, reg *protoregistry.Registry) grpc.StreamServ
logrus.Errorf("unable to lookup method information for %+v", info)
}
- peeker := &StreamPeeker{ServerStream: ss}
+ peeker := &StreamPeeker{
+ ServerStream: ss,
+ reqFactory: mInfo,
+ }
switch op := mInfo.Operation; op {
case protoregistry.OpAccessor:
break
case protoregistry.OpMutator:
- fmt.Printf("👹")
peekedMsg, err := peeker.PeekReq()
if err != nil {
logrus.Errorf("cache invalidator interceptor unable to peek into stream: %s", err)
@@ -59,9 +67,11 @@ func StreamInvalidator(c RepoCache, reg *protoregistry.Registry) grpc.StreamServ
type StreamPeeker struct {
grpc.ServerStream
- peeked bool // did you peek?
- peekedMsg interface{} // what did you peek?
- peekedErr error // what did you screw up when you peeked?
+ peeked bool // did you peek?
+ peekedMsg proto.Message // what did you peek?
+ peekedErr error // what did you screw up when you peeked?
+
+ reqFactory RequestFactory
}
// PeekMsg will peek one message into the stream to obtain the client's first
@@ -71,15 +81,19 @@ func (sp *StreamPeeker) PeekReq() (proto.Message, error) {
if sp.peeked {
return nil, errors.New("already peeked")
}
+
sp.peeked = true
- sp.peekedErr = sp.ServerStream.RecvMsg(sp.peekedMsg)
- pbMsg, ok := sp.peekedMsg.(proto.Message)
- if !ok {
- return nil, errors.New("peeked message is not protobuf")
+ var err error
+ sp.peekedMsg, err = sp.reqFactory.NewRequest()
+ if err != nil {
+ return nil, err
}
- return pbMsg, sp.peekedErr
+ sp.peekedErr = sp.ServerStream.RecvMsg(sp.peekedMsg)
+ log.Printf("👽: %#v", sp.peekedMsg)
+
+ return sp.peekedMsg, sp.peekedErr
}
// RecvMsg overrides the embedded grpc.ServerStream's method of the same name.
@@ -89,6 +103,16 @@ func (sp *StreamPeeker) RecvMsg(m interface{}) error {
if sp.peeked {
sp.peeked = false
m = sp.peekedMsg
+ log.Printf("Forwarding peeked msg: %#v", sp.peekedMsg)
+
+ mv := reflect.ValueOf(m)
+ if mv.Kind() != reflect.Ptr || mv.IsNil() {
+ return fmt.Errorf("receievd message of wrong type: %s", mv.Type())
+ }
+ mv.Elem().Set(reflect.ValueOf(sp.peekedMsg).Elem())
+
+ log.Printf("🤖: %#v", m)
+
return sp.peekedErr
}
diff --git a/internal/interceptor/cache_test.go b/internal/interceptor/cache_test.go
index ae1fa45e4..fd75c6fa5 100644
--- a/internal/interceptor/cache_test.go
+++ b/internal/interceptor/cache_test.go
@@ -2,6 +2,7 @@ package interceptor_test
import (
"context"
+ "log"
"net"
"testing"
"time"
@@ -19,7 +20,6 @@ import (
//go:generate make testdata/stream.pb.go
func TestStreamInvalidator(t *testing.T) {
-
cache, repoQ := newMockCache()
reg := protoregistry.New()
@@ -60,16 +60,28 @@ func TestStreamInvalidator(t *testing.T) {
}()
for i := 0; i < len(expectedInvalidations); i++ {
- t.Logf("waiting for repo invalidation #%d", i)
+ expect := expectedInvalidations[i]
select {
- case repo := <-repoQ:
- require.Equal(t, expectedInvalidations[i], repo)
+ case actual := <-repoQ:
+ requireReposEqual(t, actual, expect)
case <-ctx.Done():
- break
+ require.Fail(t, "test timed out")
}
}
+ cancel()
+}
+
+// requireReposEqual only compares "important" fields of a repo and ignores
+// XXX_* fields
+func requireReposEqual(t testing.TB, expect, actual *gitalypb.Repository) {
+ require.Equal(t, expect.GitAlternateObjectDirectories, actual.GitAlternateObjectDirectories)
+ require.Equal(t, expect.GitObjectDirectory, actual.GitObjectDirectory)
+ require.Equal(t, expect.GlProjectPath, actual.GlProjectPath)
+ require.Equal(t, expect.GlRepository, actual.GlRepository)
+ require.Equal(t, expect.RelativePath, actual.RelativePath)
+ require.Equal(t, expect.StorageName, actual.StorageName)
}
// mockCache allows us to relay back via channel which repos are being
@@ -107,6 +119,7 @@ func newTestSvc(t testing.TB, ctx context.Context, srvr *grpc.Server, svc testda
}()
cleanup := func() {
+ srvr.Stop()
require.NoError(t, <-errQ)
}
@@ -121,11 +134,19 @@ func newTestSvc(t testing.TB, ctx context.Context, srvr *grpc.Server, svc testda
return testdata.NewTestServiceClient(cc), cleanup
}
-type testSvc struct{}
+type testSvc struct {
+ clientStreamRepoMutatorQ chan<- *testdata.Request
+}
-func (ts *testSvc) ClientStreamRepoMutator(*testdata.Request, testdata.TestService_ClientStreamRepoMutatorServer) error {
+func (ts *testSvc) ClientStreamRepoMutator(req *testdata.Request, cli testdata.TestService_ClientStreamRepoMutatorServer) error {
+ log.Printf("req: %#v", req)
+ req = new(testdata.Request)
+ cli.RecvMsg(req)
+ log.Printf("req: %#v", req)
+ //req <- clientStreamRepoMutatorQ
return nil
}
+
func (ts *testSvc) ClientStreamRepoAccessor(*testdata.Request, testdata.TestService_ClientStreamRepoAccessorServer) error {
return nil
}
diff --git a/internal/praefect/protoregistry/request_factory.go b/internal/praefect/protoregistry/request_factory.go
index 4e53f4b23..20e702b20 100644
--- a/internal/praefect/protoregistry/request_factory.go
+++ b/internal/praefect/protoregistry/request_factory.go
@@ -13,7 +13,9 @@ import (
// message type for an RPC method. This is useful in gRPC components that treat
// messages generically, like a stream interceptor.
func requestFactory(mdp *descriptor.MethodDescriptorProto) (func() (proto.Message, error), error) {
- reqTypeName := strings.TrimPrefix(mdp.GetInputType(), ".") // not sure why this has a leading dot
+ // not sure why this has a leading dot
+ reqTypeName := strings.TrimPrefix(mdp.GetInputType(), ".")
+
reqType := proto.MessageType(reqTypeName)
if reqType == nil {
return nil, fmt.Errorf("unable to retrieve protobuf message type for %s", reqTypeName)
@@ -21,10 +23,12 @@ func requestFactory(mdp *descriptor.MethodDescriptorProto) (func() (proto.Messag
factory := func() (proto.Message, error) {
newReq := reflect.New(reqType.Elem())
+
val, ok := newReq.Interface().(proto.Message)
if !ok {
return nil, fmt.Errorf("method request factory does not return proto message: %#v", newReq)
}
+
return val, nil
}