diff options
author | Patrick Steinhardt <psteinhardt@gitlab.com> | 2020-03-30 10:36:35 +0300 |
---|---|---|
committer | Patrick Steinhardt <psteinhardt@gitlab.com> | 2020-03-30 10:36:35 +0300 |
commit | e932dbba5a4e0052f398f951e875cf31aebc5392 (patch) | |
tree | 8996ddd8f1d4258a1379465040ed2a473daf1e59 | |
parent | 80bcbe7c3e3a76f4531e7b44ec93e538c88769ec (diff) | |
parent | 944081210ac74e2ff2e7d76e561d30e861ad8fe5 (diff) |
Merge branch 'jv-use-reference-count' into 'master'
Use reference counting in limithandler middleware
Closes #2588 and #2574
See merge request gitlab-org/gitaly!1984
3 files changed, 52 insertions, 56 deletions
diff --git a/changelogs/unreleased/jv-use-reference-count.yml b/changelogs/unreleased/jv-use-reference-count.yml new file mode 100644 index 000000000..a88801ce6 --- /dev/null +++ b/changelogs/unreleased/jv-use-reference-count.yml @@ -0,0 +1,5 @@ +--- +title: Use reference counting in limithandler middleware +merge_request: 1984 +author: +type: fixed diff --git a/internal/middleware/limithandler/concurrency_limiter.go b/internal/middleware/limithandler/concurrency_limiter.go index 66343f715..9c437ed9f 100644 --- a/internal/middleware/limithandler/concurrency_limiter.go +++ b/internal/middleware/limithandler/concurrency_limiter.go @@ -2,6 +2,7 @@ package limithandler import ( "context" + "fmt" "sync" "time" @@ -21,50 +22,56 @@ type ConcurrencyMonitor interface { // ConcurrencyLimiter contains rate limiter state type ConcurrencyLimiter struct { + semaphores map[string]*semaphoreReference + max int64 + mux *sync.Mutex + monitor ConcurrencyMonitor +} + +type semaphoreReference struct { // A weighted semaphore is like a mutex, but with a number of 'slots'. // When locking the locker requests 1 or more slots to be locked. // In this package, the number of slots is the number of concurrent requests the rate limiter lets through. // https://godoc.org/golang.org/x/sync/semaphore - semaphores map[string]*semaphore.Weighted - max int64 - mux *sync.Mutex - monitor ConcurrencyMonitor + *semaphore.Weighted + count int } // Lazy create a semaphore for the given key -func (c *ConcurrencyLimiter) getSemaphore(lockKey string) *semaphore.Weighted { +func (c *ConcurrencyLimiter) getSemaphore(lockKey string) *semaphoreReference { c.mux.Lock() defer c.mux.Unlock() - ws := c.semaphores[lockKey] - if ws != nil { - return ws + if ref := c.semaphores[lockKey]; ref != nil { + ref.count++ + return ref } - w := semaphore.NewWeighted(c.max) - c.semaphores[lockKey] = w - return w + ref := &semaphoreReference{ + Weighted: semaphore.NewWeighted(c.max), + count: 1, // The caller gets this reference so the initial value is 1 + } + c.semaphores[lockKey] = ref + return ref } -func (c *ConcurrencyLimiter) attemptCollection(lockKey string) { +func (c *ConcurrencyLimiter) putSemaphore(lockKey string) { c.mux.Lock() defer c.mux.Unlock() - ws := c.semaphores[lockKey] - if ws == nil { - return + ref := c.semaphores[lockKey] + if ref == nil { + panic("semaphore should be in the map") } - if !ws.TryAcquire(c.max) { - return + if ref.count <= 0 { + panic(fmt.Sprintf("bad semaphore ref count %d", ref.count)) } - // By releasing, we prevent a lockup of goroutines that have already - // acquired the semaphore, but have yet to acquire on it - ws.Release(c.max) - - // If we managed to acquire all the locks, we can remove the semaphore for this key - delete(c.semaphores, lockKey) + ref.count-- + if ref.count == 0 { + delete(c.semaphores, lockKey) + } } func (c *ConcurrencyLimiter) countSemaphores() int { @@ -83,26 +90,20 @@ func (c *ConcurrencyLimiter) Limit(ctx context.Context, lockKey string, f Limite start := time.Now() c.monitor.Queued(ctx) - w := c.getSemaphore(lockKey) - - // Attempt to cleanup the semaphore it's no longer being used - defer c.attemptCollection(lockKey) + sem := c.getSemaphore(lockKey) + defer c.putSemaphore(lockKey) - err := w.Acquire(ctx, 1) + err := sem.Acquire(ctx, 1) c.monitor.Dequeued(ctx) - if err != nil { return nil, err } + defer sem.Release(1) c.monitor.Enter(ctx, time.Since(start)) defer c.monitor.Exit(ctx) - defer w.Release(1) - - resp, err := f() - - return resp, err + return f() } // NewLimiter creates a new rate limiter @@ -112,7 +113,7 @@ func NewLimiter(max int, monitor ConcurrencyMonitor) *ConcurrencyLimiter { } return &ConcurrencyLimiter{ - semaphores: make(map[string]*semaphore.Weighted), + semaphores: make(map[string]*semaphoreReference), max: int64(max), mux: &sync.Mutex{}, monitor: monitor, diff --git a/internal/middleware/limithandler/concurrency_limiter_test.go b/internal/middleware/limithandler/concurrency_limiter_test.go index 735ff77c6..7300fbacc 100644 --- a/internal/middleware/limithandler/concurrency_limiter_test.go +++ b/internal/middleware/limithandler/concurrency_limiter_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type counter struct { @@ -73,9 +74,7 @@ func TestLimiter(t *testing.T) { concurrency int maxConcurrency int iterations int - delay time.Duration buckets int - wantMaxRange []int wantMonitorCalls bool }{ { @@ -83,9 +82,7 @@ func TestLimiter(t *testing.T) { concurrency: 1, maxConcurrency: 1, iterations: 1, - delay: 1 * time.Millisecond, buckets: 1, - wantMaxRange: []int{1, 1}, wantMonitorCalls: true, }, { @@ -93,19 +90,15 @@ func TestLimiter(t *testing.T) { concurrency: 100, maxConcurrency: 2, iterations: 10, - delay: 1 * time.Millisecond, buckets: 1, - wantMaxRange: []int{2, 3}, wantMonitorCalls: true, }, { name: "two-by-two", concurrency: 100, maxConcurrency: 2, - delay: 1000 * time.Nanosecond, iterations: 4, buckets: 2, - wantMaxRange: []int{4, 5}, wantMonitorCalls: true, }, { @@ -113,9 +106,7 @@ func TestLimiter(t *testing.T) { concurrency: 10, maxConcurrency: 0, iterations: 200, - delay: 1000 * time.Nanosecond, buckets: 1, - wantMaxRange: []int{8, 10}, wantMonitorCalls: false, }, { @@ -125,17 +116,19 @@ func TestLimiter(t *testing.T) { // We use a long delay here to prevent flakiness in CI. If the delay is // too short, the first goroutines to enter the critical section will be // gone before we hit the intended maximum concurrency. - delay: 5 * time.Millisecond, iterations: 40, buckets: 50, - wantMaxRange: []int{95, 105}, wantMonitorCalls: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + expectedGaugeMax := tt.maxConcurrency * tt.buckets + if tt.maxConcurrency <= 0 { + expectedGaugeMax = tt.concurrency + } + gauge := &counter{} - start := make(chan struct{}) limiter := NewLimiter(tt.maxConcurrency, gauge) wg := sync.WaitGroup{} @@ -152,7 +145,7 @@ func TestLimiter(t *testing.T) { gauge.up() - if gauge.max >= tt.wantMaxRange[0] { + if gauge.max >= expectedGaugeMax { full.Broadcast() return } @@ -165,7 +158,6 @@ func TestLimiter(t *testing.T) { // concurrently. for c := 0; c < tt.concurrency; c++ { go func(counter int) { - <-start for i := 0; i < tt.iterations; i++ { lockKey := strconv.Itoa((i ^ counter) % tt.buckets) @@ -173,24 +165,22 @@ func TestLimiter(t *testing.T) { primePump() current := gauge.currentVal() - assert.True(t, current <= tt.wantMaxRange[1], "Expected the number of concurrent operations (%v) to not exceed the maximum concurrency (%v)", current, tt.wantMaxRange[1]) - assert.True(t, limiter.countSemaphores() <= tt.buckets, "Expected the number of semaphores (%v) to be lte number of buckets (%v)", limiter.countSemaphores(), tt.buckets) + require.True(t, current <= expectedGaugeMax, "Expected the number of concurrent operations (%v) to not exceed the maximum concurrency (%v)", current, expectedGaugeMax) + + require.True(t, limiter.countSemaphores() <= tt.buckets, "Expected the number of semaphores (%v) to be lte number of buckets (%v)", limiter.countSemaphores(), tt.buckets) gauge.down() return nil, nil }) - - time.Sleep(tt.delay) } wg.Done() }(c) } - close(start) wg.Wait() - assert.True(t, tt.wantMaxRange[0] <= gauge.max && gauge.max <= tt.wantMaxRange[1], "Expected maximum concurrency to be in the range [%v,%v] but got %v", tt.wantMaxRange[0], tt.wantMaxRange[1], gauge.max) + assert.Equal(t, expectedGaugeMax, gauge.max, "Expected maximum concurrency") assert.Equal(t, 0, gauge.current) assert.Equal(t, 0, limiter.countSemaphores()) |