diff options
author | Patrick Steinhardt <psteinhardt@gitlab.com> | 2023-09-19 09:34:10 +0300 |
---|---|---|
committer | Patrick Steinhardt <psteinhardt@gitlab.com> | 2023-09-19 09:34:10 +0300 |
commit | 76b2eed1b30847ff69b6e96a9845388405101c6e (patch) | |
tree | 919e147c228aedb002265998ab539899b0ade6ab | |
parent | 0bef3b7e4677bb820c8ab8946024688f27fc5282 (diff) | |
parent | a3565b4f6fcc5013adb734fa0dddff046cfbd835 (diff) |
Merge branch 'qmnguyen0711/implement-resizable-semaphore' into 'master'
Implement resizable semaphore data structure
See merge request https://gitlab.com/gitlab-org/gitaly/-/merge_requests/6366
Merged-by: Patrick Steinhardt <psteinhardt@gitlab.com>
Approved-by: Patrick Steinhardt <psteinhardt@gitlab.com>
Approved-by: James Liu <jliu@gitlab.com>
Reviewed-by: Patrick Steinhardt <psteinhardt@gitlab.com>
Reviewed-by: Quang-Minh Nguyen <qmnguyen@gitlab.com>
Reviewed-by: James Liu <jliu@gitlab.com>
Co-authored-by: Quang-Minh Nguyen <qmnguyen@gitlab.com>
-rw-r--r-- | internal/limiter/adaptive_calculator.go | 2 | ||||
-rw-r--r-- | internal/limiter/adaptive_calculator_test.go | 10 | ||||
-rw-r--r-- | internal/limiter/adaptive_limit.go | 62 | ||||
-rw-r--r-- | internal/limiter/adaptive_limit_test.go | 124 | ||||
-rw-r--r-- | internal/limiter/resizable_semaphore.go | 250 | ||||
-rw-r--r-- | internal/limiter/resizable_semaphore_test.go | 572 |
6 files changed, 1002 insertions, 18 deletions
diff --git a/internal/limiter/adaptive_calculator.go b/internal/limiter/adaptive_calculator.go index bb7088527..ebb7526e0 100644 --- a/internal/limiter/adaptive_calculator.go +++ b/internal/limiter/adaptive_calculator.go @@ -262,7 +262,7 @@ func (c *AdaptiveCalculator) calibrateLimits(ctx context.Context) { }).Debugf("Additive increase") } else { // Multiplicative decrease - newLimit = int(math.Floor(float64(limit.Current()) * setting.BackoffBackoff)) + newLimit = int(math.Floor(float64(limit.Current()) * setting.BackoffFactor)) if newLimit < setting.Min { newLimit = setting.Min } diff --git a/internal/limiter/adaptive_calculator_test.go b/internal/limiter/adaptive_calculator_test.go index 7fed9b01b..8da85704d 100644 --- a/internal/limiter/adaptive_calculator_test.go +++ b/internal/limiter/adaptive_calculator_test.go @@ -657,12 +657,14 @@ func (l *testLimit) Update(val int) { l.currents = append(l.currents, val) } +func (*testLimit) AfterUpdate(_ AfterUpdateHook) {} + func (l *testLimit) Setting() AdaptiveSetting { return AdaptiveSetting{ - Initial: l.initial, - Max: l.max, - Min: l.min, - BackoffBackoff: l.backoffBackoff, + Initial: l.initial, + Max: l.max, + Min: l.min, + BackoffFactor: l.backoffBackoff, } } diff --git a/internal/limiter/adaptive_limit.go b/internal/limiter/adaptive_limit.go index 2aefe5eb0..71eddc389 100644 --- a/internal/limiter/adaptive_limit.go +++ b/internal/limiter/adaptive_limit.go @@ -1,30 +1,48 @@ package limiter -import "sync/atomic" +import ( + "sync" +) // AdaptiveSetting is a struct that holds the configuration parameters for an adaptive limiter. type AdaptiveSetting struct { - Initial int - Max int - Min int - BackoffBackoff float64 + Initial int + Max int + Min int + BackoffFactor float64 } +// AfterUpdateHook is a function hook that is triggered when the current limit changes. The callers need to register a hook to +// the AdaptiveLimiter implementation beforehand. They are required to handle errors inside the hook function. +type AfterUpdateHook func(newVal int) + // AdaptiveLimiter is an interface for managing and updating adaptive limits. // It exposes methods to get the name, current limit value, update the limit value, and access its settings. type AdaptiveLimiter interface { Name() string Current() int Update(val int) + AfterUpdate(AfterUpdateHook) Setting() AdaptiveSetting } -// AdaptiveLimit is an implementation of the AdaptiveLimiter interface. It uses an atomic Int32 to represent the current -// limit value, ensuring thread-safe updates. +// AdaptiveLimit is an implementation of the AdaptiveLimiter interface. It uses a mutex to ensure thread-safe access to the limit value. type AdaptiveLimit struct { - name string - current atomic.Int32 - setting AdaptiveSetting + sync.Mutex + + name string + current int + setting AdaptiveSetting + updateHooks []AfterUpdateHook +} + +// NewAdaptiveLimit initializes a new AdaptiveLimit object +func NewAdaptiveLimit(name string, setting AdaptiveSetting) *AdaptiveLimit { + return &AdaptiveLimit{ + name: name, + current: setting.Initial, + setting: setting, + } } // Name returns the name of the adaptive limit @@ -34,12 +52,30 @@ func (l *AdaptiveLimit) Name() string { // Current returns the current limit. This function can be called without the need for synchronization. func (l *AdaptiveLimit) Current() int { - return int(l.current.Load()) + l.Lock() + defer l.Unlock() + + return l.current } -// Update adjusts current limit value. +// Update adjusts the current limit value and executes all registered update hooks. func (l *AdaptiveLimit) Update(val int) { - l.current.Store(int32(val)) + l.Lock() + defer l.Unlock() + + if val != l.current { + l.current = val + for _, hook := range l.updateHooks { + hook(val) + } + } +} + +// AfterUpdate registers a callback when the current limit is updated. Because all updates and hooks are synchronized, +// calling l.Current() inside the update hook in the same goroutine will cause deadlock. Hence, the update hook must +// use the newVal argument instead. +func (l *AdaptiveLimit) AfterUpdate(hook AfterUpdateHook) { + l.updateHooks = append(l.updateHooks, hook) } // Setting returns the configuration parameters for an adaptive limiter. diff --git a/internal/limiter/adaptive_limit_test.go b/internal/limiter/adaptive_limit_test.go new file mode 100644 index 000000000..f6c67f092 --- /dev/null +++ b/internal/limiter/adaptive_limit_test.go @@ -0,0 +1,124 @@ +package limiter + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAdaptiveLimit_New(t *testing.T) { + t.Parallel() + + setting := AdaptiveSetting{ + Initial: 5, + Max: 10, + Min: 1, + BackoffFactor: 0.5, + } + + limit := NewAdaptiveLimit("testLimit", setting) + require.Equal(t, limit.Name(), "testLimit") + require.Equal(t, limit.Current(), 5) + require.Equal(t, limit.Setting(), setting) +} + +func TestAdaptiveLimit_Update(t *testing.T) { + t.Parallel() + + newLimit := func() *AdaptiveLimit { + return NewAdaptiveLimit("testLimit", AdaptiveSetting{ + Initial: 5, + Max: 10, + Min: 1, + BackoffFactor: 0.5, + }) + } + + t.Run("without update hooks", func(t *testing.T) { + limit := newLimit() + + limit.Update(1) + require.Equal(t, 1, limit.Current()) + + limit.Update(2) + require.Equal(t, 2, limit.Current()) + + limit.Update(3) + require.Equal(t, 3, limit.Current()) + }) + + t.Run("new values are different from old values", func(t *testing.T) { + limit := newLimit() + + vals := []int{} + limit.AfterUpdate(func(val int) { + vals = append(vals, val) + }) + + limit.Update(1) + require.Equal(t, 1, limit.Current()) + require.Equal(t, vals, []int{1}) + + limit.Update(2) + require.Equal(t, 2, limit.Current()) + require.Equal(t, vals, []int{1, 2}) + + limit.Update(3) + require.Equal(t, 3, limit.Current()) + require.Equal(t, vals, []int{1, 2, 3}) + }) + + t.Run("new values are the same as old values", func(t *testing.T) { + limit := newLimit() + + vals := []int{} + limit.AfterUpdate(func(val int) { + vals = append(vals, val) + }) + + limit.Update(1) + require.Equal(t, 1, limit.Current()) + require.Equal(t, vals, []int{1}) + + limit.Update(1) + require.Equal(t, 1, limit.Current()) + require.Equal(t, vals, []int{1}) + + limit.Update(2) + require.Equal(t, 2, limit.Current()) + require.Equal(t, vals, []int{1, 2}) + + limit.Update(2) + require.Equal(t, 2, limit.Current()) + require.Equal(t, vals, []int{1, 2}) + }) + + t.Run("multiple update hooks", func(t *testing.T) { + limit := newLimit() + + vals1 := []int{} + limit.AfterUpdate(func(val int) { + vals1 = append(vals1, val) + }) + + vals2 := []int{} + limit.AfterUpdate(func(val int) { + vals2 = append(vals2, val*2) + }) + + limit.Update(1) + require.Equal(t, 1, limit.Current()) + require.Equal(t, vals1, []int{1}) + require.Equal(t, vals2, []int{2}) + + limit.Update(2) + require.Equal(t, 2, limit.Current()) + require.Equal(t, vals1, []int{1, 2}) + require.Equal(t, vals2, []int{2, 4}) + + limit.Update(3) + require.Equal(t, 3, limit.Current()) + require.Equal(t, vals1, []int{1, 2, 3}) + require.Equal(t, vals2, []int{2, 4, 6}) + }) +} diff --git a/internal/limiter/resizable_semaphore.go b/internal/limiter/resizable_semaphore.go new file mode 100644 index 000000000..b09ea43b4 --- /dev/null +++ b/internal/limiter/resizable_semaphore.go @@ -0,0 +1,250 @@ +package limiter + +import ( + "container/list" + "context" + "sync" +) + +// resizableSemaphore struct models a semaphore with a dynamically adjustable size. It bounds the concurrent access to +// resources, allowing a certain level of concurrency. When the concurrency reaches the semaphore's capacity, the callers +// are blocked until a resource becomes available again. The size of the semaphore can be adjusted atomically at any time. +// When the semaphore gets resized to a size smaller than the current number of resources acquired, the semaphore is +// considered to be full. The "leftover" acquirers can still keep the resource until they release the semaphore. The +// semaphore cannot be acquired until the amount of acquirers fall under the size again. +// +// Internally, it uses a doubly-linked list to manage waiters when the semaphore is full. Callers acquire the semaphore +// by invoking `Acquire()`, and release them by calling `Release()`. This struct ensures that the available slots are +// properly managed, and also handles the semaphore's current count and size. It processes resize requests and manages +// try requests and responses, ensuring smooth operation. +// +// This implementation is heavily inspired by "golang.org/x/sync/semaphore" package's implementation. +// +// Note: This struct is not intended to serve as a general-purpose data structure but is specifically designed for +// flexible concurrency control with resizable capacity. +type resizableSemaphore struct { + sync.Mutex + // current represents the current concurrency access to the resources. + current uint + // leftover accounts for the number of extra acquirers when the size shrinks down. + leftover uint + // size is the maximum capacity of the semaphore. It represents the maximum number of concurrent accesses allowed + // to the resource at the current time. + size uint + // waiters is a FIFO list of waiters waiting for the resource. + waiters *list.List +} + +// waiter is a wrapper to be put into the waiting queue. When there is an available resource, the front waiter is pulled +// out and ready channel is closed. +type waiter struct { + ready chan struct{} +} + +// NewResizableSemaphore creates a new resizableSemaphore with the specified initial size. +func NewResizableSemaphore(size uint) *resizableSemaphore { + return &resizableSemaphore{ + size: size, + waiters: list.New(), + } +} + +// Acquire allows the caller to acquire the semaphore. If the semaphore is full, the caller is blocked until there +// is an available slot or the context is canceled. If the context is canceled, context's error is returned. Otherwise, +// this function returns nil after acquired. +func (s *resizableSemaphore) Acquire(ctx context.Context) error { + s.Lock() + if s.count() < s.size { + select { + case <-ctx.Done(): + s.Unlock() + return ctx.Err() + default: + s.current++ + s.Unlock() + return nil + } + } + + w := &waiter{ready: make(chan struct{})} + element := s.waiters.PushBack(w) + s.Unlock() + + select { + case <-ctx.Done(): + return s.stopWaiter(element, w, ctx.Err()) + case <-w.ready: + return nil + } +} + +func (s *resizableSemaphore) stopWaiter(element *list.Element, w *waiter, err error) error { + s.Lock() + defer s.Unlock() + + select { + case <-w.ready: + // If the waiter is ready at the same time as the context cancellation, act as if this + // waiter is not aware of the cancellation. At this point, the linked list item is + // properly removed from the queue and the waiter is considered to acquire the + // semaphore. Otherwise, there might be a race that makes Acquire() returns an error + // even after the acquisition. + err = nil + default: + isFront := s.waiters.Front() == element + s.waiters.Remove(element) + // If we're at the front and there are extra slots left, notify next waiters in the + // queue. If all waiters in the queue have the same context, this action is not + // necessary because the rest will return an error anyway. Unfortunately, as we accept + // ctx as an argument of Acquire(), it's possible for waiters to have different + // contexts. Hence, we need to scan the waiter list, just in case. + if isFront { + s.notifyWaiters() + } + } + return err +} + +// notifyWaiters scans from the head of the s.waiters linked list, removing waiters until there are no free slots. This +// function must only be called after the mutex of s is acquired. +func (s *resizableSemaphore) notifyWaiters() { + for { + element := s.waiters.Front() + if element == nil { + break + } + + if s.count() >= s.size { + return + } + + w := element.Value.(*waiter) + s.current++ + s.waiters.Remove(element) + close(w.ready) + } +} + +// TryAcquire attempts to acquire the semaphore without blocking. On success, it returns nil. On failure, it returns +// ErrMaxQueueSize and leaves the semaphore unchanged. +func (s *resizableSemaphore) TryAcquire() error { + s.Lock() + defer s.Unlock() + + // Technically, if the number of waiters is less than the number of available slots, the caller + // of this function should be put to at the end of the queue. However, the queue always moved + // up when a slot is available or when a waiter's context is cancelled. Thus, as soon as there + // are waiters in the queue, there is no chance for this caller to acquire the semaphore + // without waiting. + if s.count() < s.size && s.waiters.Len() == 0 { + s.current++ + return nil + } + return ErrMaxQueueSize +} + +// Release releases the semaphore. +func (s *resizableSemaphore) Release() { + s.Lock() + defer s.Unlock() + // Deduct leftover first, because we want to release the remaining extra slots that were acquired before + // the semaphore was shrunk. The semaphore can be acquired again when current < size and leftover = 0. + // ┌────────── size ────────────────┐ + // ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ⧅ ⧅ ⧅ ⧅ ⧅ ⧅ + // └─────────── current ───────────┴─ leftover ┘ + if s.leftover > 0 { + s.leftover-- + } else { + s.current-- + } + s.notifyWaiters() +} + +// Count returns the number of concurrent accesses allowed by the semaphore. +func (s *resizableSemaphore) Count() int { + s.Lock() + defer s.Unlock() + return int(s.count()) +} + +func (s *resizableSemaphore) count() uint { + return s.current + s.leftover +} + +// Resize modifies the maximum number of concurrent accesses allowed by the semaphore. +func (s *resizableSemaphore) Resize(newSize uint) { + s.Lock() + defer s.Unlock() + + if newSize == s.size { + return + } + + s.size = newSize + currentCount := s.count() + if newSize > currentCount { + // Case 1: The semaphore is full. There is no leftover. The current and leftover stays intact. + // ┌─────────────── New size ────────────────┐ + // ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ □ □ □ □ □ □ □ + // └───────── current ─────────┘ + // └────── Previous size ─────┘ + // + // Case 2: The semaphore is full. The leftover might exceed the previous size. + // ┌────────────────── New size ──────────────────────────┐ + // ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ⧅ ⧅ ⧅ ⧅ ⧅ ⧅ □ □ □ □ □ + // └─────────── current ───────────┴─ leftover ┘ + // └─────── Previous size ────────┘ + // + // Case 3: The semaphore is not full. It's not feasible to have leftover because leftover is deducted + // before current. If the semaphore's size grows after shrinking down, the leftover is properly + // restructured. + // ┌─────────────── New size ────────────────┐ + // ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ □ □ □ □ □ □ □ + // └────────── current ─────────┘ │ + // └────────── Previous size ─────────┘ + // + // Case 4: The semaphore is not full but the new size is less than the previous size. The queue stays + // idle. It's not necessary to notify the waiters, + // ┌─────────── New size ───────────┐ + // ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ □ □ □ □ □ □ □ + // └────────── current ─────────┘ │ + // └────────── Previous size ─────────┘ + // + // In either case, the new size covers the current and leftover. No need to continue accounting for + // leftover. We also need to notify the waiters to move up the queues if there are available slots. If + // there isn't (case 4), no need to notify the waiters, but the function returns immediately. + s.leftover = 0 + s.current = currentCount + s.notifyWaiters() + } else { + // Case 1: The semaphore is full. There is no leftover. + // ┌──────── New size ──────────┐ + // ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ + // └───────────── current ───────────────┘ + // └─────────── Previous size ──────────┘ + // + // Case 2: The semaphore is full. There are some leftovers. + // ┌──────── New size ──────────┐ + // ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ⧅ ⧅ ⧅ ⧅ ⧅ ⧅ + // └───────────── current ───────────────┴─ leftover ┘ + // └─────────── Previous size ──────────┘ + // + // Case 3: The semaphore is not full. Similar to case 3 above, there shouldn't be any leftover. + // ┌────── New size ────────┐ + // ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ □ □ □ □ + // └────────── current ──────────┘ │ + // └────────── Previous size ───────────┘ + // + // Case 4: The new size is equal to the current count. The semaphore is either saturated or not, the + // leftover is reset. + // ┌────────── New size ──────────┐ + // ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ □ □ □ □ + // └────────── current ──────────┘ │ + // └────────── Previous size ───────────┘ + // + // In all of the above cases, the semaphore sets the current to the new size and convert the rest to + // leftover. There isn't any new slot, hence no need to notify the waiters. + s.current = newSize + s.leftover = currentCount - newSize + } +} diff --git a/internal/limiter/resizable_semaphore_test.go b/internal/limiter/resizable_semaphore_test.go new file mode 100644 index 000000000..af9aa5733 --- /dev/null +++ b/internal/limiter/resizable_semaphore_test.go @@ -0,0 +1,572 @@ +package limiter + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitaly/v16/internal/testhelper" +) + +func TestResizableSemaphore_New(t *testing.T) { + t.Parallel() + + semaphore := NewResizableSemaphore(5) + require.Equal(t, 0, semaphore.Count()) +} + +// TestResizableSemaphore_ContextCanceled ensures that when consumers successfully acquire the semaphore with a given +// context, and that context is subsequently canceled, the consumers are still able to release the semaphore. It also +// tests that consumers waiting to acquire the semaphore with the same canceled context are not able to do so once slots +// become available. +func TestResizableSemaphore_ContextCanceled(t *testing.T) { + t.Parallel() + + t.Run("context is canceled when the semaphore is empty", func(t *testing.T) { + ctx, cancel := context.WithCancel(testhelper.Context(t)) + cancel() + + semaphore := NewResizableSemaphore(5) + + require.Equal(t, context.Canceled, semaphore.Acquire(ctx)) + require.Equal(t, 0, semaphore.Count()) + }) + + t.Run("context is canceled when the semaphore is not full", func(t *testing.T) { + ctx, cancel := context.WithCancel(testhelper.Context(t)) + testResizableSemaphoreCanceledWhenNotFull(t, ctx, cancel, context.Canceled) + }) + + t.Run("context is canceled when the semaphore is full", func(t *testing.T) { + ctx, cancel := context.WithCancel(testhelper.Context(t)) + testResizableSemaphoreCanceledWhenFull(t, ctx, cancel, context.Canceled) + }) + + t.Run("context's deadline exceeded when the semaphore is empty", func(t *testing.T) { + ctx, cancel := context.WithDeadline(testhelper.Context(t), time.Now().Add(-1*time.Hour)) + defer cancel() + + semaphore := NewResizableSemaphore(5) + + require.Equal(t, context.DeadlineExceeded, semaphore.Acquire(ctx)) + require.Equal(t, 0, semaphore.Count()) + }) + + t.Run("context's deadline exceeded when the semaphore is not full", func(t *testing.T) { + ctx, cancel, simulateTimeout := testhelper.ContextWithSimulatedTimeout(testhelper.Context(t)) + defer cancel() + + testResizableSemaphoreCanceledWhenNotFull(t, ctx, simulateTimeout, context.DeadlineExceeded) + }) + + t.Run("context's deadline exceeded when the semaphore is full", func(t *testing.T) { + ctx, cancel, simulateTimeout := testhelper.ContextWithSimulatedTimeout(testhelper.Context(t)) + defer cancel() + + testResizableSemaphoreCanceledWhenFull(t, ctx, simulateTimeout, context.DeadlineExceeded) + }) +} + +func testResizableSemaphoreCanceledWhenNotFull(t *testing.T, ctx context.Context, stopContext context.CancelFunc, expectedErr error) { + semaphore := NewResizableSemaphore(5) + + // 3 goroutines acquired semaphore + beforeCallRelease := make(chan struct{}) + var acquireWg, releaseWg sync.WaitGroup + for i := 0; i < 3; i++ { + acquireWg.Add(1) + releaseWg.Add(1) + go func() { + require.Nil(t, semaphore.Acquire(ctx)) + acquireWg.Done() + + <-beforeCallRelease + + semaphore.Release() + releaseWg.Done() + }() + } + acquireWg.Wait() + + // Now cancel the context + stopContext() + require.Equal(t, expectedErr, semaphore.Acquire(ctx)) + + // The first 3 goroutines can call Release() even if the context is cancelled + close(beforeCallRelease) + releaseWg.Wait() + + require.Equal(t, expectedErr, semaphore.Acquire(ctx)) + require.Equal(t, 0, semaphore.Count()) +} + +func testResizableSemaphoreCanceledWhenFull(t *testing.T, ctx context.Context, stopContext context.CancelFunc, expectedErr error) { + semaphore := NewResizableSemaphore(5) + + // Try to acquire a token of the empty sempahore + require.Nil(t, semaphore.TryAcquire()) + semaphore.Release() + + // 5 goroutines acquired semaphore + beforeCallRelease := make(chan struct{}) + var acquireWg1, releaseWg1 sync.WaitGroup + for i := 0; i < 5; i++ { + acquireWg1.Add(1) + releaseWg1.Add(1) + go func() { + require.Nil(t, semaphore.Acquire(ctx)) + acquireWg1.Done() + + <-beforeCallRelease + + semaphore.Release() + releaseWg1.Done() + }() + } + acquireWg1.Wait() + + // Another 5 waits for sempahore + var acquireWg2 sync.WaitGroup + for i := 0; i < 5; i++ { + acquireWg2.Add(1) + go func() { + // This goroutine is block until the context is cancel, which returns canceled error + require.Equal(t, expectedErr, semaphore.Acquire(ctx)) + acquireWg2.Done() + }() + } + + // Try to acquire a token of the full semaphore + require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire()) + + // Cancel the context + stopContext() + acquireWg2.Wait() + + // The first 5 goroutines can call Release() even if the context is cancelled + close(beforeCallRelease) + releaseWg1.Wait() + + // The last 5 goroutines exits immediately, Acquire() returns error + acquireWg2.Wait() + + require.Equal(t, 0, semaphore.Count()) + + // Now the context is cancelled + require.Equal(t, expectedErr, semaphore.Acquire(ctx)) +} + +func TestResizableSemaphore_Acquire(t *testing.T) { + t.Parallel() + + t.Run("acquire less than the capacity", func(t *testing.T) { + ctx := testhelper.Context(t) + semaphore := NewResizableSemaphore(5) + + waitBeforeRelease, waitRelease := acquireSemaphore(t, ctx, semaphore, 3) + require.Equal(t, 3, semaphore.Count()) + + require.Nil(t, semaphore.Acquire(ctx)) + require.Equal(t, 4, semaphore.Count()) + + require.Nil(t, semaphore.TryAcquire()) + require.Equal(t, 5, semaphore.Count()) + + close(waitBeforeRelease) + waitRelease() + + // Still 2 left + require.Equal(t, 2, semaphore.Count()) + semaphore.Release() + semaphore.Release() + + require.Equal(t, 0, semaphore.Count()) + }) + + t.Run("acquire more than the capacity", func(t *testing.T) { + ctx := testhelper.Context(t) + semaphore := NewResizableSemaphore(5) + + waitBeforeRelease, waitRelease := acquireSemaphore(t, ctx, semaphore, 5) + require.Equal(t, 5, semaphore.Count()) + + require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire()) + require.Equal(t, 5, semaphore.Count()) + + close(waitBeforeRelease) + waitRelease() + + require.Equal(t, 0, semaphore.Count()) + }) + + t.Run("semaphore is full then available again", func(t *testing.T) { + ctx := testhelper.Context(t) + semaphore := NewResizableSemaphore(5) + for i := 0; i < 5; i++ { + require.NoError(t, semaphore.Acquire(ctx)) + } + + waitChan := make(chan error) + go func() { + for i := 0; i < 5; i++ { + waitChan <- semaphore.Acquire(ctx) + } + }() + + // The semaphore is full now + require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire()) + require.Equal(t, 5, semaphore.Count()) + + for i := 0; i < 5; i++ { + // Release one token + semaphore.Release() + // The waiting channel is unlocked + require.Nil(t, <-waitChan) + } + + // Release another token + semaphore.Release() + require.Equal(t, 4, semaphore.Count()) + + // Now TryAcquire can pull out a token + require.Nil(t, semaphore.TryAcquire()) + require.Equal(t, 5, semaphore.Count()) + }) + + t.Run("the semaphore is resized up when empty", func(t *testing.T) { + ctx := testhelper.Context(t) + + semaphore := NewResizableSemaphore(5) + semaphore.Resize(10) + + waitBeforeRelease, waitRelease := acquireSemaphore(t, ctx, semaphore, 9) + require.Equal(t, 9, semaphore.Count()) + + require.Nil(t, semaphore.Acquire(ctx)) + require.Equal(t, 10, semaphore.Count()) + + close(waitBeforeRelease) + waitRelease() + + // Still 1 left + semaphore.Release() + + require.Equal(t, 0, semaphore.Count()) + }) + + t.Run("the semaphore is resized up when not empty", func(t *testing.T) { + ctx := testhelper.Context(t) + semaphore := NewResizableSemaphore(7) + + waitBeforeRelease1, waitRelease1 := acquireSemaphore(t, ctx, semaphore, 5) + require.Equal(t, 5, semaphore.Count()) + + semaphore.Resize(15) + require.Equal(t, 5, semaphore.Count()) + + waitBeforeRelease2, waitRelease2 := acquireSemaphore(t, ctx, semaphore, 5) + require.Equal(t, 10, semaphore.Count()) + + require.Nil(t, semaphore.Acquire(ctx)) + require.Equal(t, 11, semaphore.Count()) + + close(waitBeforeRelease1) + close(waitBeforeRelease2) + waitRelease1() + waitRelease2() + + // Still 1 left + semaphore.Release() + + require.Equal(t, 0, semaphore.Count()) + }) + + t.Run("the semaphore is resized up when full", func(t *testing.T) { + ctx := testhelper.Context(t) + semaphore := NewResizableSemaphore(5) + + waitBeforeRelease1, waitRelease1 := acquireSemaphore(t, ctx, semaphore, 5) + + require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire()) + require.Equal(t, 5, semaphore.Count()) + + semaphore.Resize(10) + + waitBeforeRelease2, waitRelease2 := acquireSemaphore(t, ctx, semaphore, 5) + require.Equal(t, 10, semaphore.Count()) + + require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire()) + require.Equal(t, 10, semaphore.Count()) + + var count atomic.Int32 + for i := 0; i < 10; i++ { + go func() { + require.Nil(t, semaphore.Acquire(ctx)) + count.Add(1) + }() + } + + semaphore.Resize(15) + // Poll until 5 acquires + for count.Load() != 5 { + time.Sleep(1 * time.Millisecond) + } + // Resize to 20 to fit the rest 5 + semaphore.Resize(20) + // Wait for the rest to finish + for count.Load() != 10 { + time.Sleep(1 * time.Millisecond) + } + + close(waitBeforeRelease1) + close(waitBeforeRelease2) + waitRelease1() + waitRelease2() + + require.Equal(t, 10, semaphore.Count()) + }) + + t.Run("the semaphore is resized down when empty", func(t *testing.T) { + ctx := testhelper.Context(t) + semaphore := NewResizableSemaphore(10) + semaphore.Resize(5) + + waitBeforeRelease, waitRelease := acquireSemaphore(t, ctx, semaphore, 4) + require.Equal(t, 4, semaphore.Count()) + + require.Nil(t, semaphore.Acquire(ctx)) + require.Equal(t, 5, semaphore.Count()) + + require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire()) + require.Equal(t, 5, semaphore.Count()) + + close(waitBeforeRelease) + waitRelease() + + // Still 1 left + semaphore.Release() + + require.Equal(t, 0, semaphore.Count()) + }) + + t.Run("the semaphore is resized down when not empty", func(t *testing.T) { + ctx := testhelper.Context(t) + semaphore := NewResizableSemaphore(20) + + waitBeforeRelease1, waitRelease1 := acquireSemaphore(t, ctx, semaphore, 5) + require.Equal(t, 5, semaphore.Count()) + + semaphore.Resize(15) + waitBeforeRelease2, waitRelease2 := acquireSemaphore(t, ctx, semaphore, 5) + require.Equal(t, 10, semaphore.Count()) + + require.Nil(t, semaphore.Acquire(ctx)) + require.Equal(t, 11, semaphore.Count()) + + close(waitBeforeRelease1) + close(waitBeforeRelease2) + waitRelease1() + waitRelease2() + + // Still 1 left + semaphore.Release() + + require.Equal(t, 0, semaphore.Count()) + }) + + t.Run("the semaphore is resized down lower than the current length", func(t *testing.T) { + ctx := testhelper.Context(t) + semaphore := NewResizableSemaphore(10) + + waitBeforeRelease1, waitRelease1 := acquireSemaphore(t, ctx, semaphore, 5) + require.Equal(t, 5, semaphore.Count()) + + semaphore.Resize(3) + require.Equal(t, 5, semaphore.Count()) + + require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire()) + + close(waitBeforeRelease1) + waitRelease1() + require.Equal(t, 0, semaphore.Count()) + + waitBeforeRelease2, waitRelease2 := acquireSemaphore(t, ctx, semaphore, 3) + require.Equal(t, 3, semaphore.Count()) + + require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire()) + require.Equal(t, 3, semaphore.Count()) + + close(waitBeforeRelease2) + waitRelease2() + + require.Equal(t, 0, semaphore.Count()) + }) + + t.Run("the semaphore is resized down when full", func(t *testing.T) { + ctx := testhelper.Context(t) + semaphore := NewResizableSemaphore(10) + + waitBeforeRelease1, waitRelease1 := acquireSemaphore(t, ctx, semaphore, 10) + require.Equal(t, 10, semaphore.Count()) + + semaphore.Resize(5) + require.Equal(t, 10, semaphore.Count()) + require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire()) + + close(waitBeforeRelease1) + waitRelease1() + + require.Equal(t, 0, semaphore.Count()) + + waitBeforeRelease2, waitRelease2 := acquireSemaphore(t, ctx, semaphore, 5) + + require.Equal(t, 5, semaphore.Count()) + require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire()) + + close(waitBeforeRelease2) + waitRelease2() + + require.Equal(t, 0, semaphore.Count()) + }) + + t.Run("the semaphore is resized up and down consecutively", func(t *testing.T) { + ctx := testhelper.Context(t) + semaphore := NewResizableSemaphore(10) + + for i := 0; i < 5; i++ { + require.NoError(t, semaphore.Acquire(ctx)) + } + require.Equal(t, 5, semaphore.Count()) + + semaphore.Resize(7) + require.Equal(t, 5, semaphore.Count()) + + require.NoError(t, semaphore.Acquire(ctx)) + require.Equal(t, 6, semaphore.Count()) + + // Resize down to 3, current = 3, leftover = 3 + semaphore.Resize(3) + require.Equal(t, 6, semaphore.Count()) + + // Cannot acquire + require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire()) + semaphore.Release() + require.Equal(t, 5, semaphore.Count()) + semaphore.Release() + require.Equal(t, 4, semaphore.Count()) + + // Resize down again. Current = 2, leftover = 2 + semaphore.Resize(2) + require.Equal(t, 4, semaphore.Count()) + + require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire()) + semaphore.Release() + require.Equal(t, 3, semaphore.Count()) + semaphore.Release() + require.Equal(t, 2, semaphore.Count()) + + // Leftover is used up, but still cannot acquire + require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire()) + + // Acquireable now + semaphore.Release() + require.Equal(t, 1, semaphore.Count()) + require.NoError(t, semaphore.Acquire(ctx)) + require.Equal(t, 2, semaphore.Count()) + }) +} + +func BenchmarkResizableSemaphore(b *testing.B) { + for _, numIterations := range []uint{100, 1000, 10_000} { + n := numIterations + b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { + b.Run("acquire then release immediately", func(b *testing.B) { + b.ResetTimer() + ctx := testhelper.Context(b) + semaphore := NewResizableSemaphore(2) + + for i := uint(0); i < n; i++ { + b.StartTimer() + err := semaphore.Acquire(ctx) + semaphore.Release() + b.StopTimer() + require.NoError(b, err) + } + + require.Equal(b, 0, semaphore.Count()) + }) + + b.Run("acquire then release after done", func(b *testing.B) { + b.ResetTimer() + ctx := testhelper.Context(b) + semaphore := NewResizableSemaphore(n) + + for i := uint(0); i < n; i++ { + b.StartTimer() + err := semaphore.Acquire(ctx) + b.StopTimer() + require.NoError(b, err) + } + for i := uint(0); i < n; i++ { + b.StartTimer() + semaphore.Release() + b.StopTimer() + } + require.Equal(b, 0, semaphore.Count()) + }) + + b.Run("acquire after waiting", func(b *testing.B) { + b.ResetTimer() + ctx := testhelper.Context(b) + semaphore := NewResizableSemaphore(n) + + for i := uint(0); i < n; i++ { + require.NoError(b, semaphore.Acquire(ctx)) + } + // All of the following acquisitions are blocked + waitChan := make(chan error) + go func() { + for i := uint(0); i < n; i++ { + waitChan <- semaphore.Acquire(ctx) + } + }() + for i := uint(0); i < n; i++ { + // Measure the time since the last release and to the waiter acquires the + // semaphore. + b.StartTimer() + semaphore.Release() + <-waitChan + b.StopTimer() + } + require.Equal(b, n, semaphore.Count()) + }) + }) + } +} + +// acquireSemaphore attempts to acquire semaphore n times using ctx. It returns a channel which can be closed as a +// signal to the consumer to release the semaphore, and a WaitGroup which blocks until all consumers are finished. +func acquireSemaphore(t *testing.T, ctx context.Context, semaphore *resizableSemaphore, n int) (chan struct{}, func()) { + var acquireWg, releaseWg sync.WaitGroup + waitBeforeRelease := make(chan struct{}) + + for i := 0; i < n; i++ { + acquireWg.Add(1) + releaseWg.Add(1) + go func() { + require.Nil(t, semaphore.Acquire(ctx)) + acquireWg.Done() + + <-waitBeforeRelease + semaphore.Release() + releaseWg.Done() + }() + } + acquireWg.Wait() + + return waitBeforeRelease, releaseWg.Wait +} |