From ad7fc2fb5d510d51c42efc7dbe70a667cc9f6069 Mon Sep 17 00:00:00 2001 From: Paul Okstad Date: Sun, 26 May 2019 19:16:38 -0700 Subject: troubleshooting interceptor peeking --- internal/interceptor/cache.go | 44 +++++++++++++++++----- internal/interceptor/cache_test.go | 35 +++++++++++++---- internal/praefect/protoregistry/request_factory.go | 6 ++- 3 files changed, 67 insertions(+), 18 deletions(-) diff --git a/internal/interceptor/cache.go b/internal/interceptor/cache.go index d17ab20d1..5a5082eac 100644 --- a/internal/interceptor/cache.go +++ b/internal/interceptor/cache.go @@ -3,6 +3,8 @@ package interceptor import ( "errors" "fmt" + "log" + "reflect" "github.com/golang/protobuf/proto" "github.com/sirupsen/logrus" @@ -17,6 +19,10 @@ type RepoCache interface { InvalidateRepo(repo *gitalypb.Repository) error } +type RequestFactory interface { + NewRequest() (proto.Message, error) +} + // StreamInvalidator will invalidate any mutating RPC that targets a repository // in a gRPC stream based RPC func StreamInvalidator(c RepoCache, reg *protoregistry.Registry) grpc.StreamServerInterceptor { @@ -26,13 +32,15 @@ func StreamInvalidator(c RepoCache, reg *protoregistry.Registry) grpc.StreamServ logrus.Errorf("unable to lookup method information for %+v", info) } - peeker := &StreamPeeker{ServerStream: ss} + peeker := &StreamPeeker{ + ServerStream: ss, + reqFactory: mInfo, + } switch op := mInfo.Operation; op { case protoregistry.OpAccessor: break case protoregistry.OpMutator: - fmt.Printf("👹") peekedMsg, err := peeker.PeekReq() if err != nil { logrus.Errorf("cache invalidator interceptor unable to peek into stream: %s", err) @@ -59,9 +67,11 @@ func StreamInvalidator(c RepoCache, reg *protoregistry.Registry) grpc.StreamServ type StreamPeeker struct { grpc.ServerStream - peeked bool // did you peek? - peekedMsg interface{} // what did you peek? - peekedErr error // what did you screw up when you peeked? + peeked bool // did you peek? + peekedMsg proto.Message // what did you peek? + peekedErr error // what did you screw up when you peeked? + + reqFactory RequestFactory } // PeekMsg will peek one message into the stream to obtain the client's first @@ -71,15 +81,19 @@ func (sp *StreamPeeker) PeekReq() (proto.Message, error) { if sp.peeked { return nil, errors.New("already peeked") } + sp.peeked = true - sp.peekedErr = sp.ServerStream.RecvMsg(sp.peekedMsg) - pbMsg, ok := sp.peekedMsg.(proto.Message) - if !ok { - return nil, errors.New("peeked message is not protobuf") + var err error + sp.peekedMsg, err = sp.reqFactory.NewRequest() + if err != nil { + return nil, err } - return pbMsg, sp.peekedErr + sp.peekedErr = sp.ServerStream.RecvMsg(sp.peekedMsg) + log.Printf("👽: %#v", sp.peekedMsg) + + return sp.peekedMsg, sp.peekedErr } // RecvMsg overrides the embedded grpc.ServerStream's method of the same name. @@ -89,6 +103,16 @@ func (sp *StreamPeeker) RecvMsg(m interface{}) error { if sp.peeked { sp.peeked = false m = sp.peekedMsg + log.Printf("Forwarding peeked msg: %#v", sp.peekedMsg) + + mv := reflect.ValueOf(m) + if mv.Kind() != reflect.Ptr || mv.IsNil() { + return fmt.Errorf("receievd message of wrong type: %s", mv.Type()) + } + mv.Elem().Set(reflect.ValueOf(sp.peekedMsg).Elem()) + + log.Printf("🤖: %#v", m) + return sp.peekedErr } diff --git a/internal/interceptor/cache_test.go b/internal/interceptor/cache_test.go index ae1fa45e4..fd75c6fa5 100644 --- a/internal/interceptor/cache_test.go +++ b/internal/interceptor/cache_test.go @@ -2,6 +2,7 @@ package interceptor_test import ( "context" + "log" "net" "testing" "time" @@ -19,7 +20,6 @@ import ( //go:generate make testdata/stream.pb.go func TestStreamInvalidator(t *testing.T) { - cache, repoQ := newMockCache() reg := protoregistry.New() @@ -60,16 +60,28 @@ func TestStreamInvalidator(t *testing.T) { }() for i := 0; i < len(expectedInvalidations); i++ { - t.Logf("waiting for repo invalidation #%d", i) + expect := expectedInvalidations[i] select { - case repo := <-repoQ: - require.Equal(t, expectedInvalidations[i], repo) + case actual := <-repoQ: + requireReposEqual(t, actual, expect) case <-ctx.Done(): - break + require.Fail(t, "test timed out") } } + cancel() +} + +// requireReposEqual only compares "important" fields of a repo and ignores +// XXX_* fields +func requireReposEqual(t testing.TB, expect, actual *gitalypb.Repository) { + require.Equal(t, expect.GitAlternateObjectDirectories, actual.GitAlternateObjectDirectories) + require.Equal(t, expect.GitObjectDirectory, actual.GitObjectDirectory) + require.Equal(t, expect.GlProjectPath, actual.GlProjectPath) + require.Equal(t, expect.GlRepository, actual.GlRepository) + require.Equal(t, expect.RelativePath, actual.RelativePath) + require.Equal(t, expect.StorageName, actual.StorageName) } // mockCache allows us to relay back via channel which repos are being @@ -107,6 +119,7 @@ func newTestSvc(t testing.TB, ctx context.Context, srvr *grpc.Server, svc testda }() cleanup := func() { + srvr.Stop() require.NoError(t, <-errQ) } @@ -121,11 +134,19 @@ func newTestSvc(t testing.TB, ctx context.Context, srvr *grpc.Server, svc testda return testdata.NewTestServiceClient(cc), cleanup } -type testSvc struct{} +type testSvc struct { + clientStreamRepoMutatorQ chan<- *testdata.Request +} -func (ts *testSvc) ClientStreamRepoMutator(*testdata.Request, testdata.TestService_ClientStreamRepoMutatorServer) error { +func (ts *testSvc) ClientStreamRepoMutator(req *testdata.Request, cli testdata.TestService_ClientStreamRepoMutatorServer) error { + log.Printf("req: %#v", req) + req = new(testdata.Request) + cli.RecvMsg(req) + log.Printf("req: %#v", req) + //req <- clientStreamRepoMutatorQ return nil } + func (ts *testSvc) ClientStreamRepoAccessor(*testdata.Request, testdata.TestService_ClientStreamRepoAccessorServer) error { return nil } diff --git a/internal/praefect/protoregistry/request_factory.go b/internal/praefect/protoregistry/request_factory.go index 4e53f4b23..20e702b20 100644 --- a/internal/praefect/protoregistry/request_factory.go +++ b/internal/praefect/protoregistry/request_factory.go @@ -13,7 +13,9 @@ import ( // message type for an RPC method. This is useful in gRPC components that treat // messages generically, like a stream interceptor. func requestFactory(mdp *descriptor.MethodDescriptorProto) (func() (proto.Message, error), error) { - reqTypeName := strings.TrimPrefix(mdp.GetInputType(), ".") // not sure why this has a leading dot + // not sure why this has a leading dot + reqTypeName := strings.TrimPrefix(mdp.GetInputType(), ".") + reqType := proto.MessageType(reqTypeName) if reqType == nil { return nil, fmt.Errorf("unable to retrieve protobuf message type for %s", reqTypeName) @@ -21,10 +23,12 @@ func requestFactory(mdp *descriptor.MethodDescriptorProto) (func() (proto.Messag factory := func() (proto.Message, error) { newReq := reflect.New(reqType.Elem()) + val, ok := newReq.Interface().(proto.Message) if !ok { return nil, fmt.Errorf("method request factory does not return proto message: %#v", newReq) } + return val, nil } -- cgit v1.2.3