Welcome to mirror list, hosted at ThFree Co, Russian Federation.

proxy.go « sidechannel « internal - gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: db184ff8878ea63827a2bfb59d15ffef08f9a82e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
package sidechannel

import (
	"context"
	"errors"
	"fmt"
	"io"

	"gitlab.com/gitlab-org/gitaly/v14/internal/metadata"
	"google.golang.org/grpc"
	grpcMetadata "google.golang.org/grpc/metadata"
)

// NewUnaryProxy creates a gRPC client middleware that proxies sidechannels.
//
// Calls whose outgoing metadata carries no sidechannel entry are handed to
// invoker untouched. For all other calls a sidechannel is registered in
// registry with the proxy callback before invoker runs. An invoker error is
// returned as-is. If the call succeeds, the waiter is closed and any callback
// error other than ErrCallbackDidNotRun is returned wrapped.
func NewUnaryProxy(registry *Registry) grpc.UnaryClientInterceptor {
	return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
		if !hasSidechannelMetadata(ctx) {
			return invoker(ctx, method, req, reply, cc, opts...)
		}

		ctx, waiter := RegisterSidechannel(ctx, registry, proxy(ctx))
		// Cleanup for the path where invoker fails and we return early. On
		// the success path Close has already been called explicitly below,
		// and the result of this deferred call is discarded.
		defer waiter.Close()

		if err := invoker(ctx, method, req, reply, cc, opts...); err != nil {
			return err
		}

		// A callback that never ran is not a failure of the proxied call.
		// errors.Is is used so the sentinel is still recognised if it
		// arrives wrapped.
		if err := waiter.Close(); err != nil && !errors.Is(err, ErrCallbackDidNotRun) {
			return fmt.Errorf("sidechannel: proxy callback: %w", err)
		}
		return nil
	}
}

// NewStreamProxy creates a gRPC client middleware that proxies sidechannels.
//
// Streams whose outgoing metadata has no sidechannel entry are opened through
// streamer unchanged. Otherwise a sidechannel is registered in registry with
// the proxy callback, and the resulting stream is wrapped in a streamWrapper
// that shares the waiter.
func NewStreamProxy(registry *Registry) grpc.StreamClientInterceptor {
	return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
		if !hasSidechannelMetadata(ctx) {
			return streamer(ctx, desc, cc, method, opts...)
		}

		proxyCtx, waiter := RegisterSidechannel(ctx, registry, proxy(ctx))

		// Cleanup only: once the context is done the waiter is closed and
		// the result is dropped. The Close() error that matters is checked
		// and bubbled up in streamWrapper.RecvMsg().
		go func() {
			<-proxyCtx.Done()
			_ = waiter.Close()
		}()

		stream, err := streamer(proxyCtx, desc, cc, method, opts...)
		if err != nil {
			return nil, err
		}

		wrapped := &streamWrapper{
			ClientStream: stream,
			waiter:       waiter,
		}
		return wrapped, nil
	}
}

// streamWrapper pairs a grpc.ClientStream with the sidechannel waiter that was
// registered for it, so that RecvMsg can close the waiter when the stream
// reports io.EOF.
type streamWrapper struct {
	// ClientStream is the wrapped stream. It is embedded, so every method
	// that streamWrapper does not define itself is promoted from it.
	grpc.ClientStream
	// waiter is the sidechannel registration closed by RecvMsg on io.EOF.
	waiter *Waiter
}

// RecvMsg receives the next message from the wrapped stream into m.
//
// Any result other than io.EOF (including nil) is returned unchanged. When the
// wrapped stream reports io.EOF the sidechannel waiter is closed; a callback
// error other than ErrCallbackDidNotRun is returned wrapped in place of
// io.EOF, otherwise io.EOF is returned.
func (sw *streamWrapper) RecvMsg(m interface{}) error {
	// The identity comparison with io.EOF is intentional: only the bare
	// end-of-stream marker triggers the waiter check below.
	if err := sw.ClientStream.RecvMsg(m); err != io.EOF {
		return err
	}

	// A callback that never ran does not turn a cleanly finished stream into
	// a failure. errors.Is is used so the sentinel is still recognised if it
	// arrives wrapped.
	if err := sw.waiter.Close(); err != nil && !errors.Is(err, ErrCallbackDidNotRun) {
		return fmt.Errorf("sidechannel: proxy callback: %w", err)
	}

	return io.EOF
}

// hasSidechannelMetadata reports whether the outgoing gRPC metadata attached
// to ctx holds at least one value under sidechannelMetadataKey.
func hasSidechannelMetadata(ctx context.Context) bool {
	outgoing, found := grpcMetadata.FromOutgoingContext(ctx)
	if !found {
		return false
	}

	values := outgoing.Get(sidechannelMetadataKey)
	return len(values) > 0
}

// proxy returns a sidechannel callback that relays bytes in both directions
// between the upstream connection it is given and a downstream sidechannel it
// opens itself.
//
// The callback returns nil once the upstream-to-downstream copy has finished
// and downstream has been closed without error. It returns early with the
// first copy/CloseWrite error it observes, or with ctx.Err() if ctx is done
// first.
func proxy(ctx context.Context) func(*ClientConn) error {
	return func(upstream *ClientConn) error {
		// NOTE(review): OutgoingToIncoming presumably re-exposes the outgoing
		// metadata of ctx as incoming metadata so OpenSidechannel can read it;
		// confirm against the metadata package.
		downstream, err := OpenSidechannel(metadata.OutgoingToIncoming(ctx))
		if err != nil {
			return err
		}
		// Runs on every return path, including after the explicit Close at
		// the end of this function; its result is discarded.
		defer downstream.Close()

		// Each result channel has capacity 1 and receives exactly one value,
		// so the copying goroutine can deliver its result and exit even if
		// this function has already returned.
		fromDownstream := make(chan error, 1)
		go func() {
			fromDownstream <- func() error {
				if _, err := io.Copy(upstream, downstream); err != nil {
					return fmt.Errorf("copy to upstream: %w", err)
				}

				// downstream has no more data; close the write side of
				// upstream while the opposite copy keeps reading from it.
				if err := upstream.CloseWrite(); err != nil {
					return fmt.Errorf("closewrite upstream: %w", err)
				}

				return nil
			}()
		}()

		fromUpstream := make(chan error, 1)
		go func() {
			fromUpstream <- func() error {
				if _, err := io.Copy(downstream, upstream); err != nil {
					return fmt.Errorf("copy to downstream: %w", err)
				}

				return nil
			}()
		}()

		// Wait until the upstream-to-downstream copy is done. A successful
		// result from the downstream-to-upstream copy does not end the loop;
		// only an error from either side, completion of the upstream copy, or
		// ctx being done does.
	waitForUpstream:
		for {
			select {
			case err := <-fromUpstream:
				if err != nil {
					return err
				}

				break waitForUpstream
			case err := <-fromDownstream:
				if err != nil {
					return err
				}
			case <-ctx.Done():
				return ctx.Err()
			}
		}

		// Close explicitly so that a failure to close downstream is reported
		// to the caller rather than dropped by the deferred Close.
		if err := downstream.Close(); err != nil {
			return fmt.Errorf("close downstream: %w", err)
		}

		return nil
	}
}