1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
|
package limithandler
import (
"context"
"google.golang.org/grpc"
)
type (
	// GetLockKey derives the lock key of an RPC invocation from its context.
	// Invocations for which it returns an empty string are not limited.
	GetLockKey func(context.Context) string

	// LimiterMiddleware contains rate limiter state: the concurrency limiter
	// registered for each full gRPC method name, and the function used to
	// derive the lock key of each invocation.
	LimiterMiddleware struct {
		methodLimiters map[string]*ConcurrencyLimiter
		getLockKey     GetLockKey
	}

	// wrappedStream decorates a grpc.ServerStream so that concurrency limiting
	// is applied when the first message of the stream is received. The initial
	// field stays true until that first message has been received.
	wrappedStream struct {
		grpc.ServerStream
		info              *grpc.StreamServerInfo
		limiterMiddleware *LimiterMiddleware
		initial           bool
	}
)

// maxConcurrencyPerRepoPerRPC maps a full gRPC method name to the limit that is
// handed to NewLimiter for that method. It is assigned by SetMaxRepoConcurrency
// and read by createLimiterConfig.
var maxConcurrencyPerRepoPerRPC map[string]int
// UnaryInterceptor returns a grpc.UnaryServerInterceptor that runs each call
// through the concurrency limiter registered for the invoked method. Calls with
// an empty lock key, or to methods without a limiter, go straight to the handler.
func (c *LimiterMiddleware) UnaryInterceptor() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		invoke := func() (interface{}, error) {
			return handler(ctx, req)
		}

		key := c.getLockKey(ctx)
		if key == "" {
			return invoke()
		}

		if limiter := c.methodLimiters[info.FullMethod]; limiter != nil {
			return limiter.Limit(ctx, key, invoke)
		}

		// No concurrency limiting configured for this method.
		return invoke()
	}
}
// StreamInterceptor returns a grpc.StreamServerInterceptor that hands the
// handler a wrappedStream. Limiting itself happens lazily, in the wrapper's
// RecvMsg, once the first message of the stream arrives.
func (c *LimiterMiddleware) StreamInterceptor() grpc.StreamServerInterceptor {
	interceptor := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		return handler(srv, &wrappedStream{
			ServerStream:      stream,
			info:              info,
			limiterMiddleware: c,
			initial:           true,
		})
	}
	return interceptor
}
// RecvMsg receives the next message from the wrapped stream. After the first
// message has been received successfully, it also applies the concurrency
// limiter configured for the stream's method (if any) and blocks until one of:
//
//   - the limiter admits the stream: nil is returned and the limiter slot is
//     held until the stream's context is done;
//   - the stream's context is done: ctx.Err() is returned;
//   - Limit returns a non-nil error without admitting the stream: that error
//     is returned.
//
// Later calls only forward to the wrapped stream.
func (w *wrappedStream) RecvMsg(m interface{}) error {
	if err := w.ServerStream.RecvMsg(m); err != nil {
		return err
	}

	// Only perform limiting on the first request of a stream.
	if !w.initial {
		return nil
	}
	w.initial = false

	ctx := w.Context()

	lockKey := w.limiterMiddleware.getLockKey(ctx)
	if lockKey == "" {
		return nil
	}

	limiter := w.limiterMiddleware.methodLimiters[w.info.FullMethod]
	if limiter == nil {
		// No concurrency limiting
		return nil
	}

	ready := make(chan struct{})
	// Buffered so the goroutine can always report a failure and exit, even if
	// this function has already returned through another select case.
	limitErr := make(chan error, 1)

	// Limit runs in its own goroutine because the slot must stay held for the
	// whole stream, not just for this call: the callback signals readiness and
	// then blocks until ctx is done. The goroutine ends when Limit returns.
	go func() {
		_, err := limiter.Limit(ctx, lockKey, func() (interface{}, error) {
			close(ready)
			<-ctx.Done()
			return nil, nil
		})
		if err != nil {
			limitErr <- err
		}
	}()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-ready:
		// It's our turn!
		return nil
	case err := <-limitErr:
		// Limit failed without running the callback. Without this case the
		// error would be lost and the stream would stall until ctx ended.
		return err
	}
}
// New creates a LimiterMiddleware. Its per-method limiters are built by
// createLimiterConfig from the map last passed to SetMaxRepoConcurrency, and
// getLockKey is used to derive each invocation's lock key.
func New(getLockKey GetLockKey) LimiterMiddleware {
	limiters := createLimiterConfig()

	var middleware LimiterMiddleware
	middleware.methodLimiters = limiters
	middleware.getLockKey = getLockKey
	return middleware
}
// createLimiterConfig builds one ConcurrencyLimiter per entry in
// maxConcurrencyPerRepoPerRPC, keyed by full method name. Each limiter receives
// the configured limit and a monitor from NewPromMonitor("gitaly", method).
func createLimiterConfig() map[string]*ConcurrencyLimiter {
	limiters := make(map[string]*ConcurrencyLimiter, len(maxConcurrencyPerRepoPerRPC))
	for method := range maxConcurrencyPerRepoPerRPC {
		limit := maxConcurrencyPerRepoPerRPC[method]
		limiters[method] = NewLimiter(limit, NewPromMonitor("gitaly", method))
	}
	return limiters
}
// SetMaxRepoConcurrency configures the maximum concurrency per repo per RPC,
// keyed by full gRPC method name. It takes effect for middlewares built by New
// after this call.
func SetMaxRepoConcurrency(config map[string]int) {
	// Stored by reference, not copied: later mutations of config by the caller
	// are visible here.
	maxConcurrencyPerRepoPerRPC = config
}
|