diff options
author | Jaime Martinez <jmartinez@gitlab.com> | 2021-09-22 08:14:39 +0300 |
---|---|---|
committer | Jaime Martinez <jmartinez@gitlab.com> | 2021-09-30 06:36:45 +0300 |
commit | 128bf6b4cbee628cdb1fe3ee26b9442a18b85ef3 (patch) | |
tree | 0ba093ed727ce943b1499a8f82af3861d08ee24f | |
parent | 49f2c8908c0d831e6a2880c8ad659116ad70d74d (diff) |
test: use goroutines
-rw-r--r-- | internal/ratelimiter/ratelimiter.go | 29 | ||||
-rw-r--r-- | internal/ratelimiter/ratelimiter_test.go | 71 |
2 files changed, 48 insertions, 52 deletions
diff --git a/internal/ratelimiter/ratelimiter.go b/internal/ratelimiter/ratelimiter.go index 3965f49f..f30722c0 100644 --- a/internal/ratelimiter/ratelimiter.go +++ b/internal/ratelimiter/ratelimiter.go @@ -10,37 +10,34 @@ import ( ) const ( - // DefaultPerDomainFrequency the maximum number of requests per second to be allowed per domain. - // The default value of 25ms equals 1 request every 25ms -> 40 rps - DefaultPerDomainFrequency = 25 * time.Millisecond - // DefaultPerDomainBurstSize is the maximum burst allowed per rate limiter - // E.g. The first 40 requests within 25ms will succeed, but the 41st will fail until the next + // DefaultPerDomainFrequency is the rate in time.Duration at which the rate.Limiter + // bucket is filled with 1 token. A token is equivalent to a request. + // The default value of 20ms, or 1 request every 20ms, equals 50 requests per second. + DefaultPerDomainFrequency = 20 * time.Millisecond + // DefaultPerDomainBurstSize is the maximum burst allowed per rate limiter. + // E.g. The first 50 requests within 20ms will succeed, but the 51st will fail until the next // refill occurs at DefaultPerDomainFrequency, allowing only 1 request per rate frequency. - DefaultPerDomainBurstSize = 40 + DefaultPerDomainBurstSize = 50 - // avg of ~18,000 unique domains per hour + // based on an avg of ~18,000 unique domains per hour // https://log.gprd.gitlab.net/app/lens#/edit/3c45a610-15c9-11ec-a012-eb2e5674cacf?_g=h@e78830b defaultDomainsItems = 20000 defaultDomainsExpirationInterval = time.Hour ) -type cache interface { - FindOrFetch(cacheNamespace, key string, fetchFn func() (interface{}, error)) (interface{}, error) -} - // Option function to configure a RateLimiter type Option func(*RateLimiter) -// RateLimiter holds a map ot domain names with counters that enable rate limiting per domain. +// RateLimiter holds an LRU cache of elements to be rate limited. Currently, it supports +// a domainsCache and each item returns a rate.Limiter. // It uses "golang.org/x/time/rate" as its Token Bucket rate limiter per domain entry. // See example https://www.fatalerrors.org/a/design-and-implementation-of-time-rate-limiter-for-golang-standard-library.html -// Cleanup runs every cleanupTimer iteration over all domains and removing them if -// the time since counter.lastSeen is greater than the domainMaxTTL. +// It also holds a now function that can be mocked in unit tests. type RateLimiter struct { now func() time.Time perDomainFrequency time.Duration perDomainBurstSize int - domainsCache cache + domainsCache *lru.Cache // TODO: add sourceIPCache https://gitlab.com/gitlab-org/gitlab-pages/-/issues/630 } @@ -101,6 +98,6 @@ func (rl *RateLimiter) getDomainCounter(domain string) *rate.Limiter { func (rl *RateLimiter) DomainAllowed(domain string) (res bool) { limiter := rl.getDomainCounter(domain) - // AllowN allows us to use the rl.now function so we can test this more easily. + // AllowN allows us to use the rl.now function, so we can test this more easily. return limiter.AllowN(rl.now(), 1) } diff --git a/internal/ratelimiter/ratelimiter_test.go b/internal/ratelimiter/ratelimiter_test.go index e2352209..96f50b1b 100644 --- a/internal/ratelimiter/ratelimiter_test.go +++ b/internal/ratelimiter/ratelimiter_test.go @@ -1,6 +1,7 @@ package ratelimiter import ( + "sync" "testing" "time" @@ -23,32 +24,27 @@ func TestDomainAllowed(t *testing.T) { now string domainRate time.Duration perDomainBurstPerSecond int - domain string reqNum int }{ - "one_request_per_second": { - domainRate: 1, // 1 per second + "one_request_per_nanosecond": { + domainRate: time.Nanosecond, // 1 per nanosecond perDomainBurstPerSecond: 1, reqNum: 2, - domain: "rate.gitlab.io", }, - "one_request_per_second_but_big_bucket": { - domainRate: 1, // 1 per second + "one_request_per_nanosecond_but_big_bucket": { + domainRate: time.Nanosecond, perDomainBurstPerSecond: 10, reqNum: 11, - domain: "rate.gitlab.io", }, "three_req_per_second_bucket_size_one": { domainRate: 3, // 3 per second perDomainBurstPerSecond: 1, // max burst 1 means 1 at a time reqNum: 3, - domain: "rate.gitlab.io", }, "10_requests_per_second": { domainRate: 10, perDomainBurstPerSecond: 10, reqNum: 11, - domain: "rate.gitlab.io", }, } @@ -61,48 +57,51 @@ func TestDomainAllowed(t *testing.T) { ) for i := 0; i < tc.reqNum; i++ { - got := rl.DomainAllowed(tc.domain) + got := rl.DomainAllowed("rate.gitlab.io") if i < tc.perDomainBurstPerSecond { - require.Truef(t, got, "expected true for request no. %d", i+1) + require.Truef(t, got, "expected true for request no. %d", i) } else { - require.False(t, got, "expected false for request no. %d", i+1) + // requests should fail after reaching tc.perDomainBurstPerSecond because mockNow + // always returns the same time + require.False(t, got, "expected false for request no. %d", i) } } }) } } -func TestDomainAllowedWitSleeps(t *testing.T) { +func TestSingleRateLimiterWithMultipleDomains(t *testing.T) { rate := 10 * time.Millisecond rl := New( WithPerDomainFrequency(rate), WithPerDomainBurstSize(1), ) - domain := "test.gitlab.io" + wg := sync.WaitGroup{} + wg.Add(3) - t.Run("one request every 10ms with burst 1", func(t *testing.T) { - // prove cache entries per domain - t.Parallel() + testFn := func(domain string) func(t *testing.T) { + return func(t *testing.T) { + go func() { + defer wg.Done() - for i := 0; i < 10; i++ { - got := rl.DomainAllowed(domain) - require.Truef(t, got, "expected true for request no. %d", i+1) - time.Sleep(rate) - } - }) - - t.Run("requests start failing after reaching burst", func(t *testing.T) { - // prove cache entries per domain - t.Parallel() - - for i := 0; i < 5; i++ { - got := rl.DomainAllowed(domain + ".diff") - if i < 1 { - require.Truef(t, got, "expected true for request no. %d", i) - } else { - require.False(t, got, "expected false for request no. %d", i) - } + for i := 0; i < 5; i++ { + got := rl.DomainAllowed(domain) + require.Truef(t, got, "expected true for request no. %d", i) + time.Sleep(rate) + } + }() } - }) + } + + first := "first.gitlab.io" + t.Run(first, testFn(first)) + + second := "second.gitlab.io" + t.Run(second, testFn(second)) + + third := "third.gitlab.io" + t.Run(third, testFn(third)) + + wg.Wait() } |