diff options
author | John Cai <jcai@gitlab.com> | 2019-07-26 01:34:14 +0300 |
---|---|---|
committer | John Cai <jcai@gitlab.com> | 2019-07-27 00:03:27 +0300 |
commit | d6f7ded9fe3e12799593508f2c553d4c945311d2 (patch) | |
tree | a4cbac0cc22de9b2ca3a0ed6db3c7e35369409d3 | |
parent | 52880019875947e465913a376b86ca484b34f98c (diff) |
Make Frame interface to allow modificationsjc-replace-frames-with-locking
-rw-r--r-- | internal/praefect/coordinator.go | 2 | ||||
-rw-r--r-- | internal/praefect/grpc-proxy/proxy/codec.go | 5 | ||||
-rw-r--r-- | internal/praefect/grpc-proxy/proxy/director.go | 2 | ||||
-rw-r--r-- | internal/praefect/grpc-proxy/proxy/examples_test.go | 2 | ||||
-rw-r--r-- | internal/praefect/grpc-proxy/proxy/handler.go | 4 | ||||
-rw-r--r-- | internal/praefect/grpc-proxy/proxy/handler_test.go | 2 | ||||
-rw-r--r-- | internal/praefect/grpc-proxy/proxy/peeker.go | 54 | ||||
-rw-r--r-- | internal/praefect/grpc-proxy/proxy/peeker_test.go | 14 |
8 files changed, 47 insertions, 38 deletions
diff --git a/internal/praefect/coordinator.go b/internal/praefect/coordinator.go index c238604f3..8f64022cb 100644 --- a/internal/praefect/coordinator.go +++ b/internal/praefect/coordinator.go @@ -69,7 +69,7 @@ func (c *Coordinator) GetStorageNode(storage string) (Node, error) { } // streamDirector determines which downstream servers receive requests -func (c *Coordinator) streamDirector(ctx context.Context, fullMethodName string, peeker proxy.StreamModifier) (context.Context, *grpc.ClientConn, error) { +func (c *Coordinator) streamDirector(ctx context.Context, fullMethodName string, peeker proxy.StreamPeeker) (context.Context, *grpc.ClientConn, error) { // For phase 1, we need to route messages based on the storage location // to the appropriate Gitaly node. c.log.Debugf("Stream director received method %s", fullMethodName) diff --git a/internal/praefect/grpc-proxy/proxy/codec.go b/internal/praefect/grpc-proxy/proxy/codec.go index 24d5f5cea..117aa1a0e 100644 --- a/internal/praefect/grpc-proxy/proxy/codec.go +++ b/internal/praefect/grpc-proxy/proxy/codec.go @@ -6,6 +6,7 @@ package proxy import ( "fmt" + "sync" "github.com/golang/protobuf/proto" "google.golang.org/grpc" @@ -33,7 +34,9 @@ type rawCodec struct { } type frame struct { - payload []byte + payload []byte + consumed bool + sync.Mutex } func (c *rawCodec) Marshal(v interface{}) ([]byte, error) { diff --git a/internal/praefect/grpc-proxy/proxy/director.go b/internal/praefect/grpc-proxy/proxy/director.go index 10a63b228..37f2be2d1 100644 --- a/internal/praefect/grpc-proxy/proxy/director.go +++ b/internal/praefect/grpc-proxy/proxy/director.go @@ -21,4 +21,4 @@ import ( // are invoked. So decisions around authorization, monitoring etc. are better to be handled there. // // See the rather rich example. -type StreamDirector func(ctx context.Context, fullMethodName string, peeker StreamModifier) (context.Context, *grpc.ClientConn, error) +type StreamDirector func(ctx context.Context, fullMethodName string, peeker StreamPeeker) (context.Context, *grpc.ClientConn, error) diff --git a/internal/praefect/grpc-proxy/proxy/examples_test.go b/internal/praefect/grpc-proxy/proxy/examples_test.go index 2c2090363..e312f3a5c 100644 --- a/internal/praefect/grpc-proxy/proxy/examples_test.go +++ b/internal/praefect/grpc-proxy/proxy/examples_test.go @@ -39,7 +39,7 @@ func ExampleTransparentHandler() { // Provide sa simple example of a director that shields internal services and dials a staging or production backend. // This is a *very naive* implementation that creates a new connection on every request. Consider using pooling. func ExampleStreamDirector() { - director = func(ctx context.Context, fullMethodName string, _ proxy.StreamModifier) (context.Context, *grpc.ClientConn, error) { + director = func(ctx context.Context, fullMethodName string, _ proxy.StreamPeeker) (context.Context, *grpc.ClientConn, error) { // Make sure we never forward internal services. if strings.HasPrefix(fullMethodName, "/com.example.internal.") { return nil, nil, grpc.Errorf(codes.Unimplemented, "Unknown method") diff --git a/internal/praefect/grpc-proxy/proxy/handler.go b/internal/praefect/grpc-proxy/proxy/handler.go index daf12d4b1..f18cb2e61 100644 --- a/internal/praefect/grpc-proxy/proxy/handler.go +++ b/internal/praefect/grpc-proxy/proxy/handler.go @@ -160,6 +160,10 @@ func (s *handler) forwardServerToClient(src grpc.ServerStream, dst grpc.ClientSt // number of frames can be peeked break } + frame.Lock() + defer frame.Unlock() + defer func() { frame.consumed = true }() + if err := dst.SendMsg(frame); err != nil { ret <- err return diff --git a/internal/praefect/grpc-proxy/proxy/handler_test.go b/internal/praefect/grpc-proxy/proxy/handler_test.go index 0fff36ed4..0a4fabd20 100644 --- a/internal/praefect/grpc-proxy/proxy/handler_test.go +++ b/internal/praefect/grpc-proxy/proxy/handler_test.go @@ -207,7 +207,7 @@ func (s *ProxyHappySuite) SetupSuite() { // Setup of the proxy's Director. s.serverClientConn, err = grpc.Dial(s.serverListener.Addr().String(), grpc.WithInsecure(), grpc.WithCodec(proxy.Codec())) require.NoError(s.T(), err, "must not error on deferred client Dial") - director := func(ctx context.Context, fullName string, _ proxy.StreamModifier) (context.Context, *grpc.ClientConn, error) { + director := func(ctx context.Context, fullName string, _ proxy.StreamPeeker) (context.Context, *grpc.ClientConn, error) { md, ok := metadata.FromIncomingContext(ctx) if ok { if _, exists := md[rejectingMdKey]; exists { diff --git a/internal/praefect/grpc-proxy/proxy/peeker.go b/internal/praefect/grpc-proxy/proxy/peeker.go index 1ba7cfdaf..1c07a343f 100644 --- a/internal/praefect/grpc-proxy/proxy/peeker.go +++ b/internal/praefect/grpc-proxy/proxy/peeker.go @@ -2,23 +2,23 @@ package proxy import ( "errors" - "fmt" - "golang.org/x/net/context" "google.golang.org/grpc" ) -// StreamModifier abstracts away the gRPC stream being forwarded so that it can +// StreamPeeker abstracts away the gRPC stream being forwarded so that it can // be inspected and modified. -type StreamModifier interface { - // Peek allows a director to peak a messages into the stream without +type StreamPeeker interface { + // Peek allows a director to peek a messages into the stream without // removing those messages from the stream that will be forwarded to // the backend server. - Peek(ctx context.Context) (frame []byte, _ error) + Peek() (Frame, error) +} - // Modify modifies a payload in the stream. It will replace a frame with the - // payload it is given - Modify(ctx context.Context, payload []byte) error +// Frame contains a payload that can be optionally modified +type Frame interface { + Modify(payload []byte) error + Payload() []byte } type partialStream struct { @@ -42,30 +42,25 @@ func newPeeker(stream grpc.ServerStream) *peeker { // peek quanity var ErrInvalidPeekCount = errors.New("peek count must be greater than zero") -func (p peeker) Peek(ctx context.Context) ([]byte, error) { - payloads, err := p.peek(ctx, 1) +func (p peeker) Peek() (Frame, error) { + frames, err := p.peek(1) if err != nil { return nil, err } - if len(payloads) != 1 { + if len(frames) != 1 { return nil, errors.New("failed to peek 1 message") } - return payloads[0], nil -} - -func (p peeker) Modify(ctx context.Context, payload []byte) error { - return p.modify(ctx, [][]byte{payload}) + return frames[0], nil } -func (p peeker) peek(ctx context.Context, n uint) ([][]byte, error) { +func (p peeker) peek(n uint) ([]*frame, error) { if n < 1 { return nil, ErrInvalidPeekCount } p.consumedStream.frames = make([]*frame, n) - peekedFrames := make([][]byte, n) for i := 0; i < len(p.consumedStream.frames); i++ { f := &frame{} @@ -74,20 +69,27 @@ func (p peeker) peek(ctx context.Context, n uint) ([][]byte, error) { break } p.consumedStream.frames[i] = f - peekedFrames[i] = f.payload } - return peekedFrames, nil + return p.consumedStream.frames, nil } -func (p peeker) modify(ctx context.Context, payloads [][]byte) error { - if len(payloads) != len(p.consumedStream.frames) { - return fmt.Errorf("replacement frames count %d does not match consumed frames count %d", len(payloads), len(p.consumedStream.frames)) +func (f *frame) Payload() []byte { + return f.payload +} + +func (f *frame) Modify(payload []byte) error { + f.Lock() + defer f.Unlock() + + if f.consumed { + return errors.New("frame has already been consumed") } - for i, payload := range payloads { - p.consumedStream.frames[i].payload = payload + if f.payload == nil { + return errors.New("frame payload is empty") } + f.payload = payload return nil } diff --git a/internal/praefect/grpc-proxy/proxy/peeker_test.go b/internal/praefect/grpc-proxy/proxy/peeker_test.go index 4874b8952..1173f52d4 100644 --- a/internal/praefect/grpc-proxy/proxy/peeker_test.go +++ b/internal/praefect/grpc-proxy/proxy/peeker_test.go @@ -28,14 +28,14 @@ func TestStreamPeeking(t *testing.T) { pingReqSent := &testservice.PingRequest{Value: "hi"} // director will peek into stream before routing traffic - director := func(ctx context.Context, fullMethodName string, peeker proxy.StreamModifier) (context.Context, *grpc.ClientConn, error) { + director := func(ctx context.Context, fullMethodName string, peeker proxy.StreamPeeker) (context.Context, *grpc.ClientConn, error) { t.Logf("director routing method %s to backend", fullMethodName) - peekedMsg, err := peeker.Peek(ctx) + peekedMsg, err := peeker.Peek() require.NoError(t, err) peekedRequest := new(testservice.PingRequest) - err = proto.Unmarshal(peekedMsg, peekedRequest) + err = proto.Unmarshal(peekedMsg.Payload(), peekedRequest) require.NoError(t, err) require.Equal(t, pingReqSent, peekedRequest) @@ -87,14 +87,14 @@ func TestStreamInjecting(t *testing.T) { newValue := "bye" // director will peek into stream and change some frames - director := func(ctx context.Context, fullMethodName string, peeker proxy.StreamModifier) (context.Context, *grpc.ClientConn, error) { + director := func(ctx context.Context, fullMethodName string, peeker proxy.StreamPeeker) (context.Context, *grpc.ClientConn, error) { t.Logf("modifying request for method %s", fullMethodName) - peekedMsg, err := peeker.Peek(ctx) + peekedMsg, err := peeker.Peek() require.NoError(t, err) peekedRequest := new(testservice.PingRequest) - require.NoError(t, proto.Unmarshal(peekedMsg, peekedRequest)) + require.NoError(t, proto.Unmarshal(peekedMsg.Payload(), peekedRequest)) require.Equal(t, "hi", peekedRequest.GetValue()) peekedRequest.Value = newValue @@ -102,7 +102,7 @@ func TestStreamInjecting(t *testing.T) { newPayload, err := proto.Marshal(peekedRequest) require.NoError(t, err) - require.NoError(t, peeker.Modify(ctx, newPayload)) + require.NoError(t, peekedMsg.Modify(newPayload)) return ctx, backendCC, nil } |