Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-pages.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'internal/ratelimiter/tls_test.go')
-rw-r--r--internal/ratelimiter/tls_test.go194
1 files changed, 194 insertions, 0 deletions
diff --git a/internal/ratelimiter/tls_test.go b/internal/ratelimiter/tls_test.go
new file mode 100644
index 00000000..6763514b
--- /dev/null
+++ b/internal/ratelimiter/tls_test.go
@@ -0,0 +1,194 @@
+package ratelimiter
+
+import (
+ "crypto/tls"
+ "errors"
+ "net"
+ "strconv"
+ "testing"
+ "time"
+
+ "github.com/prometheus/client_golang/prometheus/testutil"
+ "github.com/sirupsen/logrus"
+ testlog "github.com/sirupsen/logrus/hooks/test"
+ "github.com/stretchr/testify/require"
+)
+
+func TestTLSHostnameKey(t *testing.T) {
+ info := &tls.ClientHelloInfo{ServerName: "group.gitlab.io"}
+
+ require.Equal(t, "group.gitlab.io", TLSHostnameKey(info))
+}
+
+func TestTLSClientIPKey(t *testing.T) {
+ tests := []struct {
+ addr string
+ expected string
+ }{
+ {
+ "10.1.2.3:1234",
+ "10.1.2.3",
+ },
+ {
+ "[2001:db8:3333:4444:5555:6666:7777:8888]:1234",
+ "2001:db8:3333:4444:5555:6666:7777:8888",
+ },
+ }
+
+ for _, tt := range tests {
+ addr, err := net.ResolveTCPAddr("tcp", tt.addr)
+ require.NoError(t, err)
+ conn := stubConn{remoteAddr: addr}
+ info := &tls.ClientHelloInfo{Conn: conn}
+
+ require.Equal(t, tt.expected, TLSClientIPKey(info))
+ }
+}
+
+func TestGetCertificateMiddleware(t *testing.T) {
+ tests := map[string]struct {
+ useHostnameAsKey bool
+ enforced bool
+ limitPerSecond float64
+ burst int
+ successfulReqCnt int
+ }{
+ "ip_limiter": {
+ useHostnameAsKey: false,
+ enforced: true,
+ limitPerSecond: 0.1,
+ burst: 5,
+ successfulReqCnt: 5,
+ },
+ "hostname_limiter": {
+ useHostnameAsKey: true,
+ enforced: true,
+ limitPerSecond: 0.1,
+ burst: 5,
+ successfulReqCnt: 5,
+ },
+ "not_enforced": {
+ useHostnameAsKey: false,
+ enforced: false,
+ limitPerSecond: 0.1,
+ burst: 5,
+ successfulReqCnt: 5,
+ },
+ "disabled": {
+ useHostnameAsKey: false,
+ enforced: true,
+ limitPerSecond: 0,
+ burst: 5,
+ successfulReqCnt: 10,
+ },
+ "slowly_approach_limit": {
+ useHostnameAsKey: false,
+ enforced: true,
+ limitPerSecond: 0.2,
+ burst: 5,
+ successfulReqCnt: 6, // 5 * 0.2 gives another 1 request
+ },
+ }
+
+ expectedCert := &tls.Certificate{}
+ expectedErr := errors.New("expected error")
+
+ getCertificate := func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
+ return expectedCert, expectedErr
+ }
+
+ for name, tt := range tests {
+ t.Run(name, func(t *testing.T) {
+ hook := testlog.NewGlobal()
+ blocked, cachedEntries, cacheReqs := newTestMetrics(t)
+
+ keyFunc := TLSClientIPKey
+ if tt.useHostnameAsKey {
+ keyFunc = TLSHostnameKey
+ }
+
+ rl := New("limit_name",
+ WithCachedEntriesMetric(cachedEntries),
+ WithCachedRequestsMetric(cacheReqs),
+ WithBlockedCountMetric(blocked),
+ WithNow(stubNow()),
+ WithLimitPerSecond(tt.limitPerSecond),
+ WithBurstSize(tt.burst),
+ WithEnforce(tt.enforced),
+ WithTLSKeyFunc(keyFunc))
+
+ middlewareGetCert := rl.GetCertificateMiddleware(getCertificate)
+
+ addr, err := net.ResolveTCPAddr("tcp", "10.1.2.3:12345")
+ require.NoError(t, err)
+ conn := stubConn{remoteAddr: addr}
+ info := &tls.ClientHelloInfo{Conn: conn, ServerName: "group.gitlab.io"}
+
+ for i := 0; i < tt.successfulReqCnt; i++ {
+ cert, err := middlewareGetCert(info)
+ require.Equal(t, expectedCert, cert)
+ require.Equal(t, expectedErr, err)
+ }
+
+ // When rate-limiter disabled altogether
+ if tt.limitPerSecond <= 0 {
+ return
+ }
+
+ cert, err := middlewareGetCert(info)
+ if tt.enforced {
+ require.Nil(t, cert)
+ require.Equal(t, err, ErrTLSRateLimited)
+ } else {
+ require.Equal(t, expectedCert, cert)
+ require.Equal(t, expectedErr, err)
+ }
+
+ require.NotNil(t, hook.LastEntry())
+ require.Equal(t, "TLS connection rate-limited", hook.LastEntry().Message)
+ expectedFields := logrus.Fields{
+ "rate_limiter_name": "limit_name",
+ "source_ip": "10.1.2.3",
+ "req_host": "group.gitlab.io",
+ "rate_limiter_limit_per_second": tt.limitPerSecond,
+ "rate_limiter_burst_size": tt.burst,
+ "enforced": tt.enforced,
+ }
+ require.Equal(t, expectedFields, hook.LastEntry().Data)
+
+ // make another request with different key and expect success
+ if tt.useHostnameAsKey {
+ info.ServerName = "another-group.gitlab.io"
+ } else {
+ addr, err := net.ResolveTCPAddr("tcp", "10.10.20.30:12345")
+ require.NoError(t, err)
+ conn = stubConn{remoteAddr: addr}
+ info.Conn = conn
+ }
+
+ cert, err = middlewareGetCert(info)
+ require.Equal(t, expectedCert, cert)
+ require.Equal(t, expectedErr, err)
+
+ blockedCount := testutil.ToFloat64(blocked.WithLabelValues("limit_name", strconv.FormatBool(tt.enforced)))
+ require.Equal(t, float64(1), blockedCount, "blocked count")
+
+ cachedCount := testutil.ToFloat64(cachedEntries.WithLabelValues("limit_name"))
+ require.Equal(t, float64(2), cachedCount, "cached count") // 1 for first key + 1 for different one
+
+ cacheReqMiss := testutil.ToFloat64(cacheReqs.WithLabelValues("limit_name", "miss"))
+ require.Equal(t, float64(2), cacheReqMiss, "miss count") // 1 for first key + 1 for different one
+ cacheReqHit := testutil.ToFloat64(cacheReqs.WithLabelValues("limit_name", "hit"))
+ require.Equal(t, float64(tt.successfulReqCnt), cacheReqHit, "hit count")
+ })
+ }
+}
+
+func stubNow() func() time.Time {
+ now := time.Now()
+ return func() time.Time {
+ now = now.Add(time.Second)
+
+ return now
+ }
+}