diff options
Diffstat (limited to 'workhorse/internal/upstream/roundtripper')
3 files changed, 127 insertions, 0 deletions
diff --git a/workhorse/internal/upstream/roundtripper/roundtripper.go b/workhorse/internal/upstream/roundtripper/roundtripper.go new file mode 100644 index 00000000000..84f1983b471 --- /dev/null +++ b/workhorse/internal/upstream/roundtripper/roundtripper.go @@ -0,0 +1,61 @@ +package roundtripper + +import ( + "context" + "fmt" + "net" + "net/http" + "net/url" + "time" + + "gitlab.com/gitlab-org/labkit/correlation" + "gitlab.com/gitlab-org/labkit/tracing" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/badgateway" +) + +func mustParseAddress(address, scheme string) string { + if scheme == "https" { + panic("TLS is not supported for backend connections") + } + + for _, suffix := range []string{"", ":" + scheme} { + address += suffix + if host, port, err := net.SplitHostPort(address); err == nil && host != "" && port != "" { + return host + ":" + port + } + } + + panic(fmt.Errorf("could not parse host:port from address %q and scheme %q", address, scheme)) +} + +// NewBackendRoundTripper returns a new RoundTripper instance using the provided values +func NewBackendRoundTripper(backend *url.URL, socket string, proxyHeadersTimeout time.Duration, developmentMode bool) http.RoundTripper { + // Copied from the definition of http.DefaultTransport. We can't literally copy http.DefaultTransport because of its hidden internal state. + transport, dialer := newBackendTransport() + transport.ResponseHeaderTimeout = proxyHeadersTimeout + + if backend != nil && socket == "" { + address := mustParseAddress(backend.Host, backend.Scheme) + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.DialContext(ctx, "tcp", address) + } + } else if socket != "" { + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.DialContext(ctx, "unix", socket) + } + } else { + panic("backend is nil and socket is empty") + } + + return tracing.NewRoundTripper( + correlation.NewInstrumentedRoundTripper( + badgateway.NewRoundTripper(developmentMode, transport), + ), + ) +} + +// NewTestBackendRoundTripper sets up a RoundTripper for testing purposes +func NewTestBackendRoundTripper(backend *url.URL) http.RoundTripper { + return NewBackendRoundTripper(backend, "", 0, true) +} diff --git a/workhorse/internal/upstream/roundtripper/roundtripper_test.go b/workhorse/internal/upstream/roundtripper/roundtripper_test.go new file mode 100644 index 00000000000..79ffa244918 --- /dev/null +++ b/workhorse/internal/upstream/roundtripper/roundtripper_test.go @@ -0,0 +1,39 @@ +package roundtripper + +import ( + "strconv" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMustParseAddress(t *testing.T) { + successExamples := []struct{ address, scheme, expected string }{ + {"1.2.3.4:56", "http", "1.2.3.4:56"}, + {"[::1]:23", "http", "::1:23"}, + {"4.5.6.7", "http", "4.5.6.7:http"}, + } + for i, example := range successExamples { + t.Run(strconv.Itoa(i), func(t *testing.T) { + require.Equal(t, example.expected, mustParseAddress(example.address, example.scheme)) + }) + } +} + +func TestMustParseAddressPanic(t *testing.T) { + panicExamples := []struct{ address, scheme string }{ + {"1.2.3.4", ""}, + {"1.2.3.4", "https"}, + } + + for i, panicExample := range panicExamples { + t.Run(strconv.Itoa(i), func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic") + } + }() + mustParseAddress(panicExample.address, panicExample.scheme) + }) + } +} diff --git a/workhorse/internal/upstream/roundtripper/transport.go b/workhorse/internal/upstream/roundtripper/transport.go new file mode 100644 index 00000000000..84d9623b129 --- /dev/null +++ b/workhorse/internal/upstream/roundtripper/transport.go @@ -0,0 +1,27 @@ +package roundtripper + +import ( + "net" + "net/http" + "time" +) + +// newBackendTransport setups the default HTTP transport which Workhorse uses +// to communicate with the upstream +func newBackendTransport() (*http.Transport, *net.Dialer) { + dialler := &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + } + + transport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: dialler.DialContext, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + + return transport, dialler +} |