package limithandler
import (
"context"
"fmt"
"sync"
"time"
"golang.org/x/sync/semaphore"
)
// LimitedFunc is the unit of work whose concurrent execution the limiter
// throttles. It returns an opaque response and an error, both forwarded
// verbatim to the caller of Limit.
type LimitedFunc func() (resp interface{}, err error)
// ConcurrencyMonitor allows the concurrency limiter's lifecycle events to
// be observed (e.g. for metrics). Implementations must be safe for
// concurrent use, since Limit may be called from many goroutines.
type ConcurrencyMonitor interface {
// Queued is called when a request starts waiting for a semaphore slot.
Queued(ctx context.Context)
// Dequeued is called when the wait ends, whether the slot was acquired
// or the context was cancelled.
Dequeued(ctx context.Context)
// Enter is called once a slot has been acquired; acquireTime is the
// total time spent queued (including semaphore wait).
Enter(ctx context.Context, acquireTime time.Duration)
// Exit is called when the limited function has finished and the slot
// is about to be released.
Exit(ctx context.Context)
}
// ConcurrencyLimiter contains rate limiter state: one weighted semaphore
// per lock key, lazily created and reference-counted so idle keys do not
// accumulate in the map.
type ConcurrencyLimiter struct {
// semaphores maps each lock key to its reference-counted semaphore;
// guarded by mux.
semaphores map[string]*semaphoreReference
// max is the number of concurrent calls allowed per key; a value <= 0
// disables limiting entirely (see Limit).
max int64
// mux guards semaphores.
mux *sync.Mutex
// monitor receives lifecycle callbacks; never nil (see NewLimiter).
monitor ConcurrencyMonitor
}
// semaphoreReference pairs a weighted semaphore with a reference count so
// the owning map entry can be deleted once no caller holds it.
type semaphoreReference struct {
// A weighted semaphore is like a mutex, but with a number of 'slots'.
// When locking the locker requests 1 or more slots to be locked.
// In this package, the number of slots is the number of concurrent requests the rate limiter lets through.
// https://godoc.org/golang.org/x/sync/semaphore
*semaphore.Weighted
// count is the number of in-flight Limit calls holding this reference;
// guarded by ConcurrencyLimiter.mux, always >= 1 while in the map.
count int
}
// getSemaphore lazily creates the semaphore for the given key, or reuses
// the existing one, incrementing its reference count either way. The
// matching putSemaphore call must follow once the caller is done.
func (c *ConcurrencyLimiter) getSemaphore(lockKey string) *semaphoreReference {
c.mux.Lock()
defer c.mux.Unlock()

if existing, ok := c.semaphores[lockKey]; ok && existing != nil {
existing.count++
return existing
}

created := &semaphoreReference{
Weighted: semaphore.NewWeighted(c.max),
count:    1, // the caller receives this reference, so it starts at 1
}
c.semaphores[lockKey] = created
return created
}
// putSemaphore releases one reference to the semaphore stored under
// lockKey, deleting the map entry once nobody references it anymore.
// It panics if the refcounting invariants are violated (programmer bug).
func (c *ConcurrencyLimiter) putSemaphore(lockKey string) {
c.mux.Lock()
defer c.mux.Unlock()

ref := c.semaphores[lockKey]
switch {
case ref == nil:
panic("semaphore should be in the map")
case ref.count <= 0:
panic(fmt.Sprintf("bad semaphore ref count %d", ref.count))
}

// Drop our reference; the last one out removes the entry.
if ref.count--; ref.count == 0 {
delete(c.semaphores, lockKey)
}
}
// countSemaphores reports how many distinct lock keys currently have a
// live semaphore. Intended for tests and introspection.
func (c *ConcurrencyLimiter) countSemaphores() int {
c.mux.Lock()
n := len(c.semaphores)
c.mux.Unlock()
return n
}
// Limit runs f while ensuring that at most c.max calls sharing the same
// lockKey execute concurrently. A non-positive limit disables throttling
// entirely. If ctx is cancelled while queued, f is never invoked and the
// context's error is returned.
func (c *ConcurrencyLimiter) Limit(ctx context.Context, lockKey string, f LimitedFunc) (interface{}, error) {
if c.max <= 0 {
return f()
}

queuedAt := time.Now()
c.monitor.Queued(ctx)

sem := c.getSemaphore(lockKey)
defer c.putSemaphore(lockKey)

acquireErr := sem.Acquire(ctx, 1)
c.monitor.Dequeued(ctx)
if acquireErr != nil {
return nil, acquireErr
}
defer sem.Release(1)

// Report the full time spent queueing, including the semaphore wait.
c.monitor.Enter(ctx, time.Since(queuedAt))
defer c.monitor.Exit(ctx)

return f()
}
// NewLimiter creates a new rate limiter allowing at most max concurrent
// calls per lock key. Passing a nil monitor installs a no-op monitor, so
// the returned limiter never has to nil-check it.
func NewLimiter(max int, monitor ConcurrencyMonitor) *ConcurrencyLimiter {
if monitor == nil {
monitor = &nullConcurrencyMonitor{}
}

limiter := &ConcurrencyLimiter{
semaphores: make(map[string]*semaphoreReference),
max:        int64(max),
mux:        &sync.Mutex{},
monitor:    monitor,
}
return limiter
}
// nullConcurrencyMonitor is the no-op ConcurrencyMonitor installed by
// NewLimiter when the caller supplies nil; every callback does nothing.
type nullConcurrencyMonitor struct{}

func (c *nullConcurrencyMonitor) Queued(ctx context.Context) {}
func (c *nullConcurrencyMonitor) Dequeued(ctx context.Context) {}
func (c *nullConcurrencyMonitor) Enter(ctx context.Context, acquireTime time.Duration) {}
func (c *nullConcurrencyMonitor) Exit(ctx context.Context) {}
|