gitlab.com/gitlab-org/gitaly.git
author     Patrick Steinhardt <psteinhardt@gitlab.com>  2023-09-19 09:34:10 +0300
committer  Patrick Steinhardt <psteinhardt@gitlab.com>  2023-09-19 09:34:10 +0300
commit     76b2eed1b30847ff69b6e96a9845388405101c6e (patch)
tree       919e147c228aedb002265998ab539899b0ade6ab
parent     0bef3b7e4677bb820c8ab8946024688f27fc5282 (diff)
parent     a3565b4f6fcc5013adb734fa0dddff046cfbd835 (diff)
Merge branch 'qmnguyen0711/implement-resizable-semaphore' into 'master'
Implement resizable semaphore data structure

See merge request https://gitlab.com/gitlab-org/gitaly/-/merge_requests/6366

Merged-by: Patrick Steinhardt <psteinhardt@gitlab.com>
Approved-by: Patrick Steinhardt <psteinhardt@gitlab.com>
Approved-by: James Liu <jliu@gitlab.com>
Reviewed-by: Patrick Steinhardt <psteinhardt@gitlab.com>
Reviewed-by: Quang-Minh Nguyen <qmnguyen@gitlab.com>
Reviewed-by: James Liu <jliu@gitlab.com>
Co-authored-by: Quang-Minh Nguyen <qmnguyen@gitlab.com>
-rw-r--r--  internal/limiter/adaptive_calculator.go         2
-rw-r--r--  internal/limiter/adaptive_calculator_test.go    10
-rw-r--r--  internal/limiter/adaptive_limit.go              62
-rw-r--r--  internal/limiter/adaptive_limit_test.go         124
-rw-r--r--  internal/limiter/resizable_semaphore.go         250
-rw-r--r--  internal/limiter/resizable_semaphore_test.go    572
6 files changed, 1002 insertions, 18 deletions
diff --git a/internal/limiter/adaptive_calculator.go b/internal/limiter/adaptive_calculator.go
index bb7088527..ebb7526e0 100644
--- a/internal/limiter/adaptive_calculator.go
+++ b/internal/limiter/adaptive_calculator.go
@@ -262,7 +262,7 @@ func (c *AdaptiveCalculator) calibrateLimits(ctx context.Context) {
}).Debugf("Additive increase")
} else {
// Multiplicative decrease
- newLimit = int(math.Floor(float64(limit.Current()) * setting.BackoffBackoff))
+ newLimit = int(math.Floor(float64(limit.Current()) * setting.BackoffFactor))
if newLimit < setting.Min {
newLimit = setting.Min
}
diff --git a/internal/limiter/adaptive_calculator_test.go b/internal/limiter/adaptive_calculator_test.go
index 7fed9b01b..8da85704d 100644
--- a/internal/limiter/adaptive_calculator_test.go
+++ b/internal/limiter/adaptive_calculator_test.go
@@ -657,12 +657,14 @@ func (l *testLimit) Update(val int) {
l.currents = append(l.currents, val)
}
+func (*testLimit) AfterUpdate(_ AfterUpdateHook) {}
+
func (l *testLimit) Setting() AdaptiveSetting {
return AdaptiveSetting{
- Initial: l.initial,
- Max: l.max,
- Min: l.min,
- BackoffBackoff: l.backoffBackoff,
+ Initial: l.initial,
+ Max: l.max,
+ Min: l.min,
+ BackoffFactor: l.backoffBackoff,
}
}
diff --git a/internal/limiter/adaptive_limit.go b/internal/limiter/adaptive_limit.go
index 2aefe5eb0..71eddc389 100644
--- a/internal/limiter/adaptive_limit.go
+++ b/internal/limiter/adaptive_limit.go
@@ -1,30 +1,48 @@
package limiter
-import "sync/atomic"
+import (
+ "sync"
+)
// AdaptiveSetting is a struct that holds the configuration parameters for an adaptive limiter.
type AdaptiveSetting struct {
- Initial int
- Max int
- Min int
- BackoffBackoff float64
+ Initial int
+ Max int
+ Min int
+ BackoffFactor float64
}
+// AfterUpdateHook is a function hook that is triggered when the current limit changes. Callers need to register hooks with
+// the AdaptiveLimiter implementation beforehand and are required to handle any errors inside the hook function.
+type AfterUpdateHook func(newVal int)
+
// AdaptiveLimiter is an interface for managing and updating adaptive limits.
// It exposes methods to get the name, current limit value, update the limit value, and access its settings.
type AdaptiveLimiter interface {
Name() string
Current() int
Update(val int)
+ AfterUpdate(AfterUpdateHook)
Setting() AdaptiveSetting
}
-// AdaptiveLimit is an implementation of the AdaptiveLimiter interface. It uses an atomic Int32 to represent the current
-// limit value, ensuring thread-safe updates.
+// AdaptiveLimit is an implementation of the AdaptiveLimiter interface. It uses a mutex to ensure thread-safe access to the limit value.
type AdaptiveLimit struct {
- name string
- current atomic.Int32
- setting AdaptiveSetting
+ sync.Mutex
+
+ name string
+ current int
+ setting AdaptiveSetting
+ updateHooks []AfterUpdateHook
+}
+
+// NewAdaptiveLimit initializes a new AdaptiveLimit object
+func NewAdaptiveLimit(name string, setting AdaptiveSetting) *AdaptiveLimit {
+ return &AdaptiveLimit{
+ name: name,
+ current: setting.Initial,
+ setting: setting,
+ }
}
// Name returns the name of the adaptive limit
@@ -34,12 +52,30 @@ func (l *AdaptiveLimit) Name() string {
// Current returns the current limit. This function can be called without the need for synchronization.
func (l *AdaptiveLimit) Current() int {
- return int(l.current.Load())
+ l.Lock()
+ defer l.Unlock()
+
+ return l.current
}
-// Update adjusts current limit value.
+// Update adjusts the current limit value and executes all registered update hooks if the value changes.
func (l *AdaptiveLimit) Update(val int) {
- l.current.Store(int32(val))
+ l.Lock()
+ defer l.Unlock()
+
+ if val != l.current {
+ l.current = val
+ for _, hook := range l.updateHooks {
+ hook(val)
+ }
+ }
+}
+
+// AfterUpdate registers a callback that is invoked when the current limit is updated. Because all updates and hooks are synchronized,
+// calling l.Current() inside the update hook in the same goroutine will cause a deadlock. Hence, the update hook must
+// use the newVal argument instead.
+func (l *AdaptiveLimit) AfterUpdate(hook AfterUpdateHook) {
+ l.updateHooks = append(l.updateHooks, hook)
}
// Setting returns the configuration parameters for an adaptive limiter.
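
The AfterUpdate hook above only fires when Update() actually changes the value, and the hooks run while the limit's mutex is held. Below is a minimal sketch of how a caller might use it, written as if it lived inside the limiter package; the function name exampleAfterUpdate and the printed messages are hypothetical, while NewAdaptiveLimit, AdaptiveSetting, AfterUpdate and Update come from the diff above.

package limiter

import "fmt"

// exampleAfterUpdate is a hypothetical sketch showing the hook semantics.
func exampleAfterUpdate() {
    limit := NewAdaptiveLimit("exampleLimit", AdaptiveSetting{
        Initial:       5,
        Max:           10,
        Min:           1,
        BackoffFactor: 0.5,
    })

    // The hook must use newVal rather than limit.Current(): Update() holds the
    // mutex while running hooks, so calling Current() here would deadlock.
    limit.AfterUpdate(func(newVal int) {
        fmt.Println("limit updated to", newVal)
    })

    limit.Update(7) // prints "limit updated to 7"
    limit.Update(7) // value unchanged, the hook does not fire
    limit.Update(3) // prints "limit updated to 3"
}
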
diff --git a/internal/limiter/adaptive_limit_test.go b/internal/limiter/adaptive_limit_test.go
new file mode 100644
index 000000000..f6c67f092
--- /dev/null
+++ b/internal/limiter/adaptive_limit_test.go
@@ -0,0 +1,124 @@
+package limiter
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestAdaptiveLimit_New(t *testing.T) {
+ t.Parallel()
+
+ setting := AdaptiveSetting{
+ Initial: 5,
+ Max: 10,
+ Min: 1,
+ BackoffFactor: 0.5,
+ }
+
+ limit := NewAdaptiveLimit("testLimit", setting)
+ require.Equal(t, limit.Name(), "testLimit")
+ require.Equal(t, limit.Current(), 5)
+ require.Equal(t, limit.Setting(), setting)
+}
+
+func TestAdaptiveLimit_Update(t *testing.T) {
+ t.Parallel()
+
+ newLimit := func() *AdaptiveLimit {
+ return NewAdaptiveLimit("testLimit", AdaptiveSetting{
+ Initial: 5,
+ Max: 10,
+ Min: 1,
+ BackoffFactor: 0.5,
+ })
+ }
+
+ t.Run("without update hooks", func(t *testing.T) {
+ limit := newLimit()
+
+ limit.Update(1)
+ require.Equal(t, 1, limit.Current())
+
+ limit.Update(2)
+ require.Equal(t, 2, limit.Current())
+
+ limit.Update(3)
+ require.Equal(t, 3, limit.Current())
+ })
+
+ t.Run("new values are different from old values", func(t *testing.T) {
+ limit := newLimit()
+
+ vals := []int{}
+ limit.AfterUpdate(func(val int) {
+ vals = append(vals, val)
+ })
+
+ limit.Update(1)
+ require.Equal(t, 1, limit.Current())
+ require.Equal(t, vals, []int{1})
+
+ limit.Update(2)
+ require.Equal(t, 2, limit.Current())
+ require.Equal(t, vals, []int{1, 2})
+
+ limit.Update(3)
+ require.Equal(t, 3, limit.Current())
+ require.Equal(t, vals, []int{1, 2, 3})
+ })
+
+ t.Run("new values are the same as old values", func(t *testing.T) {
+ limit := newLimit()
+
+ vals := []int{}
+ limit.AfterUpdate(func(val int) {
+ vals = append(vals, val)
+ })
+
+ limit.Update(1)
+ require.Equal(t, 1, limit.Current())
+ require.Equal(t, vals, []int{1})
+
+ limit.Update(1)
+ require.Equal(t, 1, limit.Current())
+ require.Equal(t, vals, []int{1})
+
+ limit.Update(2)
+ require.Equal(t, 2, limit.Current())
+ require.Equal(t, vals, []int{1, 2})
+
+ limit.Update(2)
+ require.Equal(t, 2, limit.Current())
+ require.Equal(t, vals, []int{1, 2})
+ })
+
+ t.Run("multiple update hooks", func(t *testing.T) {
+ limit := newLimit()
+
+ vals1 := []int{}
+ limit.AfterUpdate(func(val int) {
+ vals1 = append(vals1, val)
+ })
+
+ vals2 := []int{}
+ limit.AfterUpdate(func(val int) {
+ vals2 = append(vals2, val*2)
+ })
+
+ limit.Update(1)
+ require.Equal(t, 1, limit.Current())
+ require.Equal(t, vals1, []int{1})
+ require.Equal(t, vals2, []int{2})
+
+ limit.Update(2)
+ require.Equal(t, 2, limit.Current())
+ require.Equal(t, vals1, []int{1, 2})
+ require.Equal(t, vals2, []int{2, 4})
+
+ limit.Update(3)
+ require.Equal(t, 3, limit.Current())
+ require.Equal(t, vals1, []int{1, 2, 3})
+ require.Equal(t, vals2, []int{2, 4, 6})
+ })
+}
diff --git a/internal/limiter/resizable_semaphore.go b/internal/limiter/resizable_semaphore.go
new file mode 100644
index 000000000..b09ea43b4
--- /dev/null
+++ b/internal/limiter/resizable_semaphore.go
@@ -0,0 +1,250 @@
+package limiter
+
+import (
+ "container/list"
+ "context"
+ "sync"
+)
+
+// resizableSemaphore struct models a semaphore with a dynamically adjustable size. It bounds the concurrent access to
+// resources, allowing a certain level of concurrency. When the concurrency reaches the semaphore's capacity, the callers
+// are blocked until a resource becomes available again. The size of the semaphore can be adjusted atomically at any time.
+// When the semaphore gets resized to a size smaller than the current number of resources acquired, the semaphore is
+// considered to be full. The "leftover" acquirers can still keep the resource until they release the semaphore. The
+// semaphore cannot be acquired again until the number of acquirers falls below the size.
+//
+// Internally, it uses a doubly-linked list to manage waiters when the semaphore is full. Callers acquire the semaphore
+// by invoking `Acquire()`, and release them by calling `Release()`. This struct ensures that the available slots are
+// properly managed, and also handles the semaphore's current count and size. It processes resize requests and manages
+// try requests and responses, ensuring smooth operation.
+//
+// This implementation is heavily inspired by "golang.org/x/sync/semaphore" package's implementation.
+//
+// Note: This struct is not intended to serve as a general-purpose data structure but is specifically designed for
+// flexible concurrency control with resizable capacity.
+type resizableSemaphore struct {
+ sync.Mutex
+ // current represents the current concurrency access to the resources.
+ current uint
+ // leftover accounts for the number of extra acquirers when the size shrinks down.
+ leftover uint
+ // size is the maximum capacity of the semaphore. It represents the maximum number of concurrent accesses allowed
+ // to the resource at the current time.
+ size uint
+ // waiters is a FIFO list of waiters waiting for the resource.
+ waiters *list.List
+}
+
+// waiter is a wrapper to be put into the waiting queue. When there is an available resource, the front waiter is pulled
+// out and its ready channel is closed.
+type waiter struct {
+ ready chan struct{}
+}
+
+// NewResizableSemaphore creates a new resizableSemaphore with the specified initial size.
+func NewResizableSemaphore(size uint) *resizableSemaphore {
+ return &resizableSemaphore{
+ size: size,
+ waiters: list.New(),
+ }
+}
+
+// Acquire allows the caller to acquire the semaphore. If the semaphore is full, the caller is blocked until there
+// is an available slot or the context is canceled. If the context is canceled, the context's error is returned. Otherwise,
+// this function returns nil once the semaphore is acquired.
+func (s *resizableSemaphore) Acquire(ctx context.Context) error {
+ s.Lock()
+ if s.count() < s.size {
+ select {
+ case <-ctx.Done():
+ s.Unlock()
+ return ctx.Err()
+ default:
+ s.current++
+ s.Unlock()
+ return nil
+ }
+ }
+
+ w := &waiter{ready: make(chan struct{})}
+ element := s.waiters.PushBack(w)
+ s.Unlock()
+
+ select {
+ case <-ctx.Done():
+ return s.stopWaiter(element, w, ctx.Err())
+ case <-w.ready:
+ return nil
+ }
+}
+
+func (s *resizableSemaphore) stopWaiter(element *list.Element, w *waiter, err error) error {
+ s.Lock()
+ defer s.Unlock()
+
+ select {
+ case <-w.ready:
+ // If the waiter is ready at the same time as the context cancellation, act as if this
+ // waiter is not aware of the cancellation. At this point, the linked list item is
+ // properly removed from the queue and the waiter is considered to acquire the
+ // semaphore. Otherwise, there might be a race that makes Acquire() return an error
+ // even after the acquisition.
+ err = nil
+ default:
+ isFront := s.waiters.Front() == element
+ s.waiters.Remove(element)
+ // If we're at the front and there are extra slots left, notify next waiters in the
+ // queue. If all waiters in the queue have the same context, this action is not
+ // necessary because the rest will return an error anyway. Unfortunately, as we accept
+ // ctx as an argument of Acquire(), it's possible for waiters to have different
+ // contexts. Hence, we need to scan the waiter list, just in case.
+ if isFront {
+ s.notifyWaiters()
+ }
+ }
+ return err
+}
+
+// notifyWaiters scans from the head of the s.waiters linked list, removing waiters until there are no free slots. This
+// function must only be called after the mutex of s is acquired.
+func (s *resizableSemaphore) notifyWaiters() {
+ for {
+ element := s.waiters.Front()
+ if element == nil {
+ break
+ }
+
+ if s.count() >= s.size {
+ return
+ }
+
+ w := element.Value.(*waiter)
+ s.current++
+ s.waiters.Remove(element)
+ close(w.ready)
+ }
+}
+
+// TryAcquire attempts to acquire the semaphore without blocking. On success, it returns nil. On failure, it returns
+// ErrMaxQueueSize and leaves the semaphore unchanged.
+func (s *resizableSemaphore) TryAcquire() error {
+ s.Lock()
+ defer s.Unlock()
+
+ // Technically, if the number of waiters is less than the number of available slots, the caller
+ // of this function should be put at the end of the queue. However, the queue always moves
+ // up when a slot is available or when a waiter's context is cancelled. Thus, as soon as there
+ // are waiters in the queue, there is no chance for this caller to acquire the semaphore
+ // without waiting.
+ if s.count() < s.size && s.waiters.Len() == 0 {
+ s.current++
+ return nil
+ }
+ return ErrMaxQueueSize
+}
+
+// Release releases the semaphore.
+func (s *resizableSemaphore) Release() {
+ s.Lock()
+ defer s.Unlock()
+ // Deduct leftover first, because we want to release the remaining extra slots that were acquired before
+ // the semaphore was shrunk. The semaphore can be acquired again when current < size and leftover = 0.
+ // ┌────────── size ────────────────┐
+ // ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ⧅ ⧅ ⧅ ⧅ ⧅ ⧅
+ // └─────────── current ───────────┴─ leftover ┘
+ if s.leftover > 0 {
+ s.leftover--
+ } else {
+ s.current--
+ }
+ s.notifyWaiters()
+}
+
+// Count returns the number of slots currently held by acquirers, including any leftover after a shrink.
+func (s *resizableSemaphore) Count() int {
+ s.Lock()
+ defer s.Unlock()
+ return int(s.count())
+}
+
+func (s *resizableSemaphore) count() uint {
+ return s.current + s.leftover
+}
+
+// Resize modifies the maximum number of concurrent accesses allowed by the semaphore.
+func (s *resizableSemaphore) Resize(newSize uint) {
+ s.Lock()
+ defer s.Unlock()
+
+ if newSize == s.size {
+ return
+ }
+
+ s.size = newSize
+ currentCount := s.count()
+ if newSize > currentCount {
+ // Case 1: The semaphore is full. There is no leftover. The current and leftover stay intact.
+ // ┌─────────────── New size ────────────────┐
+ // ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ □ □ □ □ □ □ □
+ // └───────── current ─────────┘
+ // └────── Previous size ─────┘
+ //
+ // Case 2: The semaphore is full. The leftover might exceed the previous size.
+ // ┌────────────────── New size ──────────────────────────┐
+ // ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ⧅ ⧅ ⧅ ⧅ ⧅ ⧅ □ □ □ □ □
+ // └─────────── current ───────────┴─ leftover ┘
+ // └─────── Previous size ────────┘
+ //
+ // Case 3: The semaphore is not full. It's not feasible to have leftover because leftover is deducted
+ // before current. If the semaphore's size grows after shrinking down, the leftover is properly
+ // restructured.
+ // ┌─────────────── New size ────────────────┐
+ // ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ □ □ □ □ □ □ □
+ // └────────── current ─────────┘ │
+ // └────────── Previous size ─────────┘
+ //
+ // Case 4: The semaphore is not full but the new size is less than the previous size. The queue stays
+ // idle. It's not necessary to notify the waiters.
+ // ┌─────────── New size ───────────┐
+ // ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ □ □ □ □ □ □ □
+ // └────────── current ─────────┘ │
+ // └────────── Previous size ─────────┘
+ //
+ // In all of these cases, the new size covers the current and leftover, so there is no need to continue accounting
+ // for leftover. We also need to notify the waiters to move up the queue if there are available slots. If
+ // there aren't (case 4), notifying the waiters is unnecessary, but notifyWaiters() returns immediately anyway.
+ s.leftover = 0
+ s.current = currentCount
+ s.notifyWaiters()
+ } else {
+ // Case 1: The semaphore is full. There is no leftover.
+ // ┌──────── New size ──────────┐
+ // ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■
+ // └───────────── current ───────────────┘
+ // └─────────── Previous size ──────────┘
+ //
+ // Case 2: The semaphore is full. There are some leftovers.
+ // ┌──────── New size ──────────┐
+ // ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ⧅ ⧅ ⧅ ⧅ ⧅ ⧅
+ // └───────────── current ───────────────┴─ leftover ┘
+ // └─────────── Previous size ──────────┘
+ //
+ // Case 3: The semaphore is not full. Similar to case 3 above, there shouldn't be any leftover.
+ // ┌────── New size ────────┐
+ // ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ □ □ □ □
+ // └────────── current ──────────┘ │
+ // └────────── Previous size ───────────┘
+ //
+ // Case 4: The new size is equal to the current count. Whether or not the semaphore is saturated, the
+ // leftover is reset.
+ // ┌────────── New size ──────────┐
+ // ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ □ □ □ □
+ // └────────── current ──────────┘ │
+ // └────────── Previous size ───────────┘
+ //
+ // In all of the above cases, the semaphore sets the current to the new size and converts the rest to
+ // leftover. There aren't any new slots, hence no need to notify the waiters.
+ s.current = newSize
+ s.leftover = currentCount - newSize
+ }
+}
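
To make the resize-down behaviour above concrete, here is a minimal sketch (not part of the merge request) of the leftover accounting, again written as if it lived inside the limiter package; exampleResizableSemaphoreShrink is a hypothetical name, the methods are the ones defined above.

package limiter

import "context"

// exampleResizableSemaphoreShrink is a hypothetical sketch: shrinking the
// semaphore below the number of held slots turns the excess into "leftover",
// which is paid back by the next releases before new acquisitions can succeed.
func exampleResizableSemaphoreShrink(ctx context.Context) {
    sem := NewResizableSemaphore(5)

    for i := 0; i < 5; i++ {
        if err := sem.Acquire(ctx); err != nil {
            return // context canceled
        }
    }

    sem.Resize(3) // current = 3, leftover = 2; Count() still reports 5

    _ = sem.TryAcquire() // ErrMaxQueueSize: more slots are held than the new size

    sem.Release() // consumes one leftover slot, Count() = 4
    sem.Release() // consumes the last leftover slot, Count() = 3
    sem.Release() // frees a real slot, Count() = 2

    _ = sem.TryAcquire() // succeeds: Count() (2) < size (3) and nobody is waiting
}
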
diff --git a/internal/limiter/resizable_semaphore_test.go b/internal/limiter/resizable_semaphore_test.go
new file mode 100644
index 000000000..af9aa5733
--- /dev/null
+++ b/internal/limiter/resizable_semaphore_test.go
@@ -0,0 +1,572 @@
+package limiter
+
+import (
+ "context"
+ "fmt"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitaly/v16/internal/testhelper"
+)
+
+func TestResizableSemaphore_New(t *testing.T) {
+ t.Parallel()
+
+ semaphore := NewResizableSemaphore(5)
+ require.Equal(t, 0, semaphore.Count())
+}
+
+// TestResizableSemaphore_ContextCanceled ensures that when consumers successfully acquire the semaphore with a given
+// context, and that context is subsequently canceled, the consumers are still able to release the semaphore. It also
+// tests that consumers waiting to acquire the semaphore with the same canceled context are not able to do so once slots
+// become available.
+func TestResizableSemaphore_ContextCanceled(t *testing.T) {
+ t.Parallel()
+
+ t.Run("context is canceled when the semaphore is empty", func(t *testing.T) {
+ ctx, cancel := context.WithCancel(testhelper.Context(t))
+ cancel()
+
+ semaphore := NewResizableSemaphore(5)
+
+ require.Equal(t, context.Canceled, semaphore.Acquire(ctx))
+ require.Equal(t, 0, semaphore.Count())
+ })
+
+ t.Run("context is canceled when the semaphore is not full", func(t *testing.T) {
+ ctx, cancel := context.WithCancel(testhelper.Context(t))
+ testResizableSemaphoreCanceledWhenNotFull(t, ctx, cancel, context.Canceled)
+ })
+
+ t.Run("context is canceled when the semaphore is full", func(t *testing.T) {
+ ctx, cancel := context.WithCancel(testhelper.Context(t))
+ testResizableSemaphoreCanceledWhenFull(t, ctx, cancel, context.Canceled)
+ })
+
+ t.Run("context's deadline exceeded when the semaphore is empty", func(t *testing.T) {
+ ctx, cancel := context.WithDeadline(testhelper.Context(t), time.Now().Add(-1*time.Hour))
+ defer cancel()
+
+ semaphore := NewResizableSemaphore(5)
+
+ require.Equal(t, context.DeadlineExceeded, semaphore.Acquire(ctx))
+ require.Equal(t, 0, semaphore.Count())
+ })
+
+ t.Run("context's deadline exceeded when the semaphore is not full", func(t *testing.T) {
+ ctx, cancel, simulateTimeout := testhelper.ContextWithSimulatedTimeout(testhelper.Context(t))
+ defer cancel()
+
+ testResizableSemaphoreCanceledWhenNotFull(t, ctx, simulateTimeout, context.DeadlineExceeded)
+ })
+
+ t.Run("context's deadline exceeded when the semaphore is full", func(t *testing.T) {
+ ctx, cancel, simulateTimeout := testhelper.ContextWithSimulatedTimeout(testhelper.Context(t))
+ defer cancel()
+
+ testResizableSemaphoreCanceledWhenFull(t, ctx, simulateTimeout, context.DeadlineExceeded)
+ })
+}
+
+func testResizableSemaphoreCanceledWhenNotFull(t *testing.T, ctx context.Context, stopContext context.CancelFunc, expectedErr error) {
+ semaphore := NewResizableSemaphore(5)
+
+ // 3 goroutines acquired semaphore
+ beforeCallRelease := make(chan struct{})
+ var acquireWg, releaseWg sync.WaitGroup
+ for i := 0; i < 3; i++ {
+ acquireWg.Add(1)
+ releaseWg.Add(1)
+ go func() {
+ require.Nil(t, semaphore.Acquire(ctx))
+ acquireWg.Done()
+
+ <-beforeCallRelease
+
+ semaphore.Release()
+ releaseWg.Done()
+ }()
+ }
+ acquireWg.Wait()
+
+ // Now cancel the context
+ stopContext()
+ require.Equal(t, expectedErr, semaphore.Acquire(ctx))
+
+ // The first 3 goroutines can call Release() even if the context is cancelled
+ close(beforeCallRelease)
+ releaseWg.Wait()
+
+ require.Equal(t, expectedErr, semaphore.Acquire(ctx))
+ require.Equal(t, 0, semaphore.Count())
+}
+
+func testResizableSemaphoreCanceledWhenFull(t *testing.T, ctx context.Context, stopContext context.CancelFunc, expectedErr error) {
+ semaphore := NewResizableSemaphore(5)
+
+ // Try to acquire a token of the empty semaphore
+ require.Nil(t, semaphore.TryAcquire())
+ semaphore.Release()
+
+ // 5 goroutines acquired semaphore
+ beforeCallRelease := make(chan struct{})
+ var acquireWg1, releaseWg1 sync.WaitGroup
+ for i := 0; i < 5; i++ {
+ acquireWg1.Add(1)
+ releaseWg1.Add(1)
+ go func() {
+ require.Nil(t, semaphore.Acquire(ctx))
+ acquireWg1.Done()
+
+ <-beforeCallRelease
+
+ semaphore.Release()
+ releaseWg1.Done()
+ }()
+ }
+ acquireWg1.Wait()
+
+ // Another 5 goroutines wait for the semaphore
+ var acquireWg2 sync.WaitGroup
+ for i := 0; i < 5; i++ {
+ acquireWg2.Add(1)
+ go func() {
+ // This goroutine is blocked until the context is canceled, so Acquire() returns the cancellation error
+ require.Equal(t, expectedErr, semaphore.Acquire(ctx))
+ acquireWg2.Done()
+ }()
+ }
+
+ // Try to acquire a token of the full semaphore
+ require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire())
+
+ // Cancel the context
+ stopContext()
+ acquireWg2.Wait()
+
+ // The first 5 goroutines can call Release() even if the context is cancelled
+ close(beforeCallRelease)
+ releaseWg1.Wait()
+
+ // The last 5 goroutines exit immediately, Acquire() returns an error
+ acquireWg2.Wait()
+
+ require.Equal(t, 0, semaphore.Count())
+
+ // Now the context is cancelled
+ require.Equal(t, expectedErr, semaphore.Acquire(ctx))
+}
+
+func TestResizableSemaphore_Acquire(t *testing.T) {
+ t.Parallel()
+
+ t.Run("acquire less than the capacity", func(t *testing.T) {
+ ctx := testhelper.Context(t)
+ semaphore := NewResizableSemaphore(5)
+
+ waitBeforeRelease, waitRelease := acquireSemaphore(t, ctx, semaphore, 3)
+ require.Equal(t, 3, semaphore.Count())
+
+ require.Nil(t, semaphore.Acquire(ctx))
+ require.Equal(t, 4, semaphore.Count())
+
+ require.Nil(t, semaphore.TryAcquire())
+ require.Equal(t, 5, semaphore.Count())
+
+ close(waitBeforeRelease)
+ waitRelease()
+
+ // Still 2 left
+ require.Equal(t, 2, semaphore.Count())
+ semaphore.Release()
+ semaphore.Release()
+
+ require.Equal(t, 0, semaphore.Count())
+ })
+
+ t.Run("acquire more than the capacity", func(t *testing.T) {
+ ctx := testhelper.Context(t)
+ semaphore := NewResizableSemaphore(5)
+
+ waitBeforeRelease, waitRelease := acquireSemaphore(t, ctx, semaphore, 5)
+ require.Equal(t, 5, semaphore.Count())
+
+ require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire())
+ require.Equal(t, 5, semaphore.Count())
+
+ close(waitBeforeRelease)
+ waitRelease()
+
+ require.Equal(t, 0, semaphore.Count())
+ })
+
+ t.Run("semaphore is full then available again", func(t *testing.T) {
+ ctx := testhelper.Context(t)
+ semaphore := NewResizableSemaphore(5)
+ for i := 0; i < 5; i++ {
+ require.NoError(t, semaphore.Acquire(ctx))
+ }
+
+ waitChan := make(chan error)
+ go func() {
+ for i := 0; i < 5; i++ {
+ waitChan <- semaphore.Acquire(ctx)
+ }
+ }()
+
+ // The semaphore is full now
+ require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire())
+ require.Equal(t, 5, semaphore.Count())
+
+ for i := 0; i < 5; i++ {
+ // Release one token
+ semaphore.Release()
+ // The waiting channel is unlocked
+ require.Nil(t, <-waitChan)
+ }
+
+ // Release another token
+ semaphore.Release()
+ require.Equal(t, 4, semaphore.Count())
+
+ // Now TryAcquire can pull out a token
+ require.Nil(t, semaphore.TryAcquire())
+ require.Equal(t, 5, semaphore.Count())
+ })
+
+ t.Run("the semaphore is resized up when empty", func(t *testing.T) {
+ ctx := testhelper.Context(t)
+
+ semaphore := NewResizableSemaphore(5)
+ semaphore.Resize(10)
+
+ waitBeforeRelease, waitRelease := acquireSemaphore(t, ctx, semaphore, 9)
+ require.Equal(t, 9, semaphore.Count())
+
+ require.Nil(t, semaphore.Acquire(ctx))
+ require.Equal(t, 10, semaphore.Count())
+
+ close(waitBeforeRelease)
+ waitRelease()
+
+ // Still 1 left
+ semaphore.Release()
+
+ require.Equal(t, 0, semaphore.Count())
+ })
+
+ t.Run("the semaphore is resized up when not empty", func(t *testing.T) {
+ ctx := testhelper.Context(t)
+ semaphore := NewResizableSemaphore(7)
+
+ waitBeforeRelease1, waitRelease1 := acquireSemaphore(t, ctx, semaphore, 5)
+ require.Equal(t, 5, semaphore.Count())
+
+ semaphore.Resize(15)
+ require.Equal(t, 5, semaphore.Count())
+
+ waitBeforeRelease2, waitRelease2 := acquireSemaphore(t, ctx, semaphore, 5)
+ require.Equal(t, 10, semaphore.Count())
+
+ require.Nil(t, semaphore.Acquire(ctx))
+ require.Equal(t, 11, semaphore.Count())
+
+ close(waitBeforeRelease1)
+ close(waitBeforeRelease2)
+ waitRelease1()
+ waitRelease2()
+
+ // Still 1 left
+ semaphore.Release()
+
+ require.Equal(t, 0, semaphore.Count())
+ })
+
+ t.Run("the semaphore is resized up when full", func(t *testing.T) {
+ ctx := testhelper.Context(t)
+ semaphore := NewResizableSemaphore(5)
+
+ waitBeforeRelease1, waitRelease1 := acquireSemaphore(t, ctx, semaphore, 5)
+
+ require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire())
+ require.Equal(t, 5, semaphore.Count())
+
+ semaphore.Resize(10)
+
+ waitBeforeRelease2, waitRelease2 := acquireSemaphore(t, ctx, semaphore, 5)
+ require.Equal(t, 10, semaphore.Count())
+
+ require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire())
+ require.Equal(t, 10, semaphore.Count())
+
+ var count atomic.Int32
+ for i := 0; i < 10; i++ {
+ go func() {
+ require.Nil(t, semaphore.Acquire(ctx))
+ count.Add(1)
+ }()
+ }
+
+ semaphore.Resize(15)
+ // Poll until 5 acquires
+ for count.Load() != 5 {
+ time.Sleep(1 * time.Millisecond)
+ }
+ // Resize to 20 to fit the remaining 5
+ semaphore.Resize(20)
+ // Wait for the rest to finish
+ for count.Load() != 10 {
+ time.Sleep(1 * time.Millisecond)
+ }
+
+ close(waitBeforeRelease1)
+ close(waitBeforeRelease2)
+ waitRelease1()
+ waitRelease2()
+
+ require.Equal(t, 10, semaphore.Count())
+ })
+
+ t.Run("the semaphore is resized down when empty", func(t *testing.T) {
+ ctx := testhelper.Context(t)
+ semaphore := NewResizableSemaphore(10)
+ semaphore.Resize(5)
+
+ waitBeforeRelease, waitRelease := acquireSemaphore(t, ctx, semaphore, 4)
+ require.Equal(t, 4, semaphore.Count())
+
+ require.Nil(t, semaphore.Acquire(ctx))
+ require.Equal(t, 5, semaphore.Count())
+
+ require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire())
+ require.Equal(t, 5, semaphore.Count())
+
+ close(waitBeforeRelease)
+ waitRelease()
+
+ // Still 1 left
+ semaphore.Release()
+
+ require.Equal(t, 0, semaphore.Count())
+ })
+
+ t.Run("the semaphore is resized down when not empty", func(t *testing.T) {
+ ctx := testhelper.Context(t)
+ semaphore := NewResizableSemaphore(20)
+
+ waitBeforeRelease1, waitRelease1 := acquireSemaphore(t, ctx, semaphore, 5)
+ require.Equal(t, 5, semaphore.Count())
+
+ semaphore.Resize(15)
+ waitBeforeRelease2, waitRelease2 := acquireSemaphore(t, ctx, semaphore, 5)
+ require.Equal(t, 10, semaphore.Count())
+
+ require.Nil(t, semaphore.Acquire(ctx))
+ require.Equal(t, 11, semaphore.Count())
+
+ close(waitBeforeRelease1)
+ close(waitBeforeRelease2)
+ waitRelease1()
+ waitRelease2()
+
+ // Still 1 left
+ semaphore.Release()
+
+ require.Equal(t, 0, semaphore.Count())
+ })
+
+ t.Run("the semaphore is resized down lower than the current length", func(t *testing.T) {
+ ctx := testhelper.Context(t)
+ semaphore := NewResizableSemaphore(10)
+
+ waitBeforeRelease1, waitRelease1 := acquireSemaphore(t, ctx, semaphore, 5)
+ require.Equal(t, 5, semaphore.Count())
+
+ semaphore.Resize(3)
+ require.Equal(t, 5, semaphore.Count())
+
+ require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire())
+
+ close(waitBeforeRelease1)
+ waitRelease1()
+ require.Equal(t, 0, semaphore.Count())
+
+ waitBeforeRelease2, waitRelease2 := acquireSemaphore(t, ctx, semaphore, 3)
+ require.Equal(t, 3, semaphore.Count())
+
+ require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire())
+ require.Equal(t, 3, semaphore.Count())
+
+ close(waitBeforeRelease2)
+ waitRelease2()
+
+ require.Equal(t, 0, semaphore.Count())
+ })
+
+ t.Run("the semaphore is resized down when full", func(t *testing.T) {
+ ctx := testhelper.Context(t)
+ semaphore := NewResizableSemaphore(10)
+
+ waitBeforeRelease1, waitRelease1 := acquireSemaphore(t, ctx, semaphore, 10)
+ require.Equal(t, 10, semaphore.Count())
+
+ semaphore.Resize(5)
+ require.Equal(t, 10, semaphore.Count())
+ require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire())
+
+ close(waitBeforeRelease1)
+ waitRelease1()
+
+ require.Equal(t, 0, semaphore.Count())
+
+ waitBeforeRelease2, waitRelease2 := acquireSemaphore(t, ctx, semaphore, 5)
+
+ require.Equal(t, 5, semaphore.Count())
+ require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire())
+
+ close(waitBeforeRelease2)
+ waitRelease2()
+
+ require.Equal(t, 0, semaphore.Count())
+ })
+
+ t.Run("the semaphore is resized up and down consecutively", func(t *testing.T) {
+ ctx := testhelper.Context(t)
+ semaphore := NewResizableSemaphore(10)
+
+ for i := 0; i < 5; i++ {
+ require.NoError(t, semaphore.Acquire(ctx))
+ }
+ require.Equal(t, 5, semaphore.Count())
+
+ semaphore.Resize(7)
+ require.Equal(t, 5, semaphore.Count())
+
+ require.NoError(t, semaphore.Acquire(ctx))
+ require.Equal(t, 6, semaphore.Count())
+
+ // Resize down to 3, current = 3, leftover = 3
+ semaphore.Resize(3)
+ require.Equal(t, 6, semaphore.Count())
+
+ // Cannot acquire
+ require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire())
+ semaphore.Release()
+ require.Equal(t, 5, semaphore.Count())
+ semaphore.Release()
+ require.Equal(t, 4, semaphore.Count())
+
+ // Resize down again. Current = 2, leftover = 2
+ semaphore.Resize(2)
+ require.Equal(t, 4, semaphore.Count())
+
+ require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire())
+ semaphore.Release()
+ require.Equal(t, 3, semaphore.Count())
+ semaphore.Release()
+ require.Equal(t, 2, semaphore.Count())
+
+ // Leftover is used up, but still cannot acquire
+ require.Equal(t, ErrMaxQueueSize, semaphore.TryAcquire())
+
+ // Acquirable now
+ semaphore.Release()
+ require.Equal(t, 1, semaphore.Count())
+ require.NoError(t, semaphore.Acquire(ctx))
+ require.Equal(t, 2, semaphore.Count())
+ })
+}
+
+func BenchmarkResizableSemaphore(b *testing.B) {
+ for _, numIterations := range []uint{100, 1000, 10_000} {
+ n := numIterations
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ b.Run("acquire then release immediately", func(b *testing.B) {
+ b.ResetTimer()
+ ctx := testhelper.Context(b)
+ semaphore := NewResizableSemaphore(2)
+
+ for i := uint(0); i < n; i++ {
+ b.StartTimer()
+ err := semaphore.Acquire(ctx)
+ semaphore.Release()
+ b.StopTimer()
+ require.NoError(b, err)
+ }
+
+ require.Equal(b, 0, semaphore.Count())
+ })
+
+ b.Run("acquire then release after done", func(b *testing.B) {
+ b.ResetTimer()
+ ctx := testhelper.Context(b)
+ semaphore := NewResizableSemaphore(n)
+
+ for i := uint(0); i < n; i++ {
+ b.StartTimer()
+ err := semaphore.Acquire(ctx)
+ b.StopTimer()
+ require.NoError(b, err)
+ }
+ for i := uint(0); i < n; i++ {
+ b.StartTimer()
+ semaphore.Release()
+ b.StopTimer()
+ }
+ require.Equal(b, 0, semaphore.Count())
+ })
+
+ b.Run("acquire after waiting", func(b *testing.B) {
+ b.ResetTimer()
+ ctx := testhelper.Context(b)
+ semaphore := NewResizableSemaphore(n)
+
+ for i := uint(0); i < n; i++ {
+ require.NoError(b, semaphore.Acquire(ctx))
+ }
+ // All of the following acquisitions are blocked
+ waitChan := make(chan error)
+ go func() {
+ for i := uint(0); i < n; i++ {
+ waitChan <- semaphore.Acquire(ctx)
+ }
+ }()
+ for i := uint(0); i < n; i++ {
+ // Measure the time from this release until the waiter acquires the
+ // semaphore.
+ b.StartTimer()
+ semaphore.Release()
+ <-waitChan
+ b.StopTimer()
+ }
+ require.Equal(b, n, semaphore.Count())
+ })
+ })
+ }
+}
+
+// acquireSemaphore acquires the semaphore n times using ctx. It returns a channel which can be closed as a
+// signal to the consumers to release the semaphore, and a wait function that blocks until all consumers have released.
+func acquireSemaphore(t *testing.T, ctx context.Context, semaphore *resizableSemaphore, n int) (chan struct{}, func()) {
+ var acquireWg, releaseWg sync.WaitGroup
+ waitBeforeRelease := make(chan struct{})
+
+ for i := 0; i < n; i++ {
+ acquireWg.Add(1)
+ releaseWg.Add(1)
+ go func() {
+ require.Nil(t, semaphore.Acquire(ctx))
+ acquireWg.Done()
+
+ <-waitBeforeRelease
+ semaphore.Release()
+ releaseWg.Done()
+ }()
+ }
+ acquireWg.Wait()
+
+ return waitBeforeRelease, releaseWg.Wait
+}
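
Taken together, the two pieces of this merge request are meant to be wired up so that every calibration of the adaptive limit resizes the semaphore. A hedged sketch of that wiring follows, assuming it lives inside the limiter package; exampleAdaptiveConcurrency and the setting values are hypothetical, only the types and methods introduced above are used.

package limiter

import "context"

// exampleAdaptiveConcurrency is a hypothetical sketch of the intended interplay:
// the adaptive limit drives the semaphore's size, and request handlers bound
// their concurrency by acquiring the semaphore.
func exampleAdaptiveConcurrency(ctx context.Context) error {
    limit := NewAdaptiveLimit("rpcConcurrency", AdaptiveSetting{
        Initial:       10,
        Max:           50,
        Min:           2,
        BackoffFactor: 0.75,
    })
    sem := NewResizableSemaphore(uint(limit.Current()))

    // Whenever the calibrator updates the limit, the semaphore follows.
    limit.AfterUpdate(func(newVal int) {
        sem.Resize(uint(newVal))
    })

    // A request handler bounds its concurrency with the semaphore.
    if err := sem.Acquire(ctx); err != nil {
        return err
    }
    defer sem.Release()

    // ... perform the rate-limited work ...
    return nil
}
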