diff options
Diffstat (limited to 'internal/praefect/grpc-proxy/proxy/peeker.go')
-rw-r--r-- | internal/praefect/grpc-proxy/proxy/peeker.go | 54 |
1 files changed, 28 insertions, 26 deletions
diff --git a/internal/praefect/grpc-proxy/proxy/peeker.go b/internal/praefect/grpc-proxy/proxy/peeker.go index 1ba7cfdaf..1c07a343f 100644 --- a/internal/praefect/grpc-proxy/proxy/peeker.go +++ b/internal/praefect/grpc-proxy/proxy/peeker.go @@ -2,23 +2,23 @@ package proxy import ( "errors" - "fmt" - "golang.org/x/net/context" "google.golang.org/grpc" ) -// StreamModifier abstracts away the gRPC stream being forwarded so that it can +// StreamPeeker abstracts away the gRPC stream being forwarded so that it can // be inspected and modified. -type StreamModifier interface { - // Peek allows a director to peak a messages into the stream without +type StreamPeeker interface { + // Peek allows a director to peek a messages into the stream without // removing those messages from the stream that will be forwarded to // the backend server. - Peek(ctx context.Context) (frame []byte, _ error) + Peek() (Frame, error) +} - // Modify modifies a payload in the stream. It will replace a frame with the - // payload it is given - Modify(ctx context.Context, payload []byte) error +// Frame contains a payload that can be optionally modified +type Frame interface { + Modify(payload []byte) error + Payload() []byte } type partialStream struct { @@ -42,30 +42,25 @@ func newPeeker(stream grpc.ServerStream) *peeker { // peek quanity var ErrInvalidPeekCount = errors.New("peek count must be greater than zero") -func (p peeker) Peek(ctx context.Context) ([]byte, error) { - payloads, err := p.peek(ctx, 1) +func (p peeker) Peek() (Frame, error) { + frames, err := p.peek(1) if err != nil { return nil, err } - if len(payloads) != 1 { + if len(frames) != 1 { return nil, errors.New("failed to peek 1 message") } - return payloads[0], nil -} - -func (p peeker) Modify(ctx context.Context, payload []byte) error { - return p.modify(ctx, [][]byte{payload}) + return frames[0], nil } -func (p peeker) peek(ctx context.Context, n uint) ([][]byte, error) { +func (p peeker) peek(n uint) ([]*frame, error) { if n < 1 { return nil, ErrInvalidPeekCount } p.consumedStream.frames = make([]*frame, n) - peekedFrames := make([][]byte, n) for i := 0; i < len(p.consumedStream.frames); i++ { f := &frame{} @@ -74,20 +69,27 @@ func (p peeker) peek(ctx context.Context, n uint) ([][]byte, error) { break } p.consumedStream.frames[i] = f - peekedFrames[i] = f.payload } - return peekedFrames, nil + return p.consumedStream.frames, nil } -func (p peeker) modify(ctx context.Context, payloads [][]byte) error { - if len(payloads) != len(p.consumedStream.frames) { - return fmt.Errorf("replacement frames count %d does not match consumed frames count %d", len(payloads), len(p.consumedStream.frames)) +func (f *frame) Payload() []byte { + return f.payload +} + +func (f *frame) Modify(payload []byte) error { + f.Lock() + defer f.Unlock() + + if f.consumed { + return errors.New("frame has already been consumed") } - for i, payload := range payloads { - p.consumedStream.frames[i].payload = payload + if f.payload == nil { + return errors.New("frame payload is empty") } + f.payload = payload return nil } |