diff options
author | Andrew Newdigate <andrew@gitlab.com> | 2019-01-09 23:07:29 +0300 |
---|---|---|
committer | Andrew Newdigate <andrew@gitlab.com> | 2019-01-14 14:39:32 +0300 |
commit | 7596a62637a23caf2e9e9451e6fc1d6cf12b6792 (patch) | |
tree | edf7b14db6fa510cea0b252d9669bc4bbfd6c039 | |
parent | e8bb2b1482860ed18b6a42ab48bc882c1089df2f (diff) |
Reintroduce a specific dialler for unix sockets
61f6c92779a70d577727e7eefa337409effd69ef removed the Dialer for unix
socket. This was done because a change to the GRPC library caused the
Dialer to stop working, and because the default implementation works as
expected for users not using a proxy.
Unfortunately this led to a regression for users with HTTP or
HTTPS proxy configurations exposed via the `http_proxy` or `https_proxy`
environment variables.
For this reason, we reintroduce the dialer for Unix socket connections.
-rw-r--r-- | client/address_parser.go | 32 | ||||
-rw-r--r-- | client/address_parser_test.go | 74 | ||||
-rw-r--r-- | client/dial.go | 70 |
3 files changed, 133 insertions, 43 deletions
diff --git a/client/address_parser.go b/client/address_parser.go index 55969e909..a052342ae 100644 --- a/client/address_parser.go +++ b/client/address_parser.go @@ -3,22 +3,36 @@ package client import ( "fmt" "net/url" + "strings" ) -func parseAddress(rawAddress string) (canonicalAddress string, err error) { +// extractHostFromRemoteURL will convert Gitaly-style URL addresses of the form +// scheme://host:port to the "host:port" addresses used by `grpc.Dial` +func extractHostFromRemoteURL(rawAddress string) (hostAndPort string, err error) { u, err := url.Parse(rawAddress) if err != nil { return "", err } - // tcp:// addresses are a special case which `grpc.Dial` expects in a - // different format - if u.Scheme == "tcp" || u.Scheme == "tls" { - if u.Path != "" { - return "", fmt.Errorf("%s addresses should not have a path", u.Scheme) - } - return u.Host, nil + if u.Path != "" { + return "", fmt.Errorf("remote addresses should not have a path") } - return u.String(), nil + if u.Host == "" { + return "", fmt.Errorf("remote addresses should have a host") + } + + return u.Host, nil +} + +// extractPathFromSocketURL will convert Gitaly-style URL addresses of the form +// unix:/path/to/socket into file paths: `/path/to/socket` +const unixPrefix = "unix:" + +func extractPathFromSocketURL(rawAddress string) (socketPath string, err error) { + if !strings.HasPrefix(rawAddress, unixPrefix) { + return "", fmt.Errorf("invalid socket address: %s", rawAddress) + } + + return strings.TrimPrefix(rawAddress, unixPrefix), nil } diff --git a/client/address_parser_test.go b/client/address_parser_test.go index 820a902b3..f5dc2f31b 100644 --- a/client/address_parser_test.go +++ b/client/address_parser_test.go @@ -2,19 +2,16 @@ package client import ( "testing" + + "github.com/stretchr/testify/require" ) -func TestParseAddress(t *testing.T) { +func Test_extractHostFromRemoteURL(t *testing.T) { testCases := []struct { raw string canonical string invalid bool }{ - {raw: "unix:/foo/bar.socket", canonical: "unix:///foo/bar.socket"}, - {raw: "unix:///foo/bar.socket", canonical: "unix:///foo/bar.socket"}, - // Mainly for test purposes we explicitly want to support relative paths - {raw: "unix://foo/bar.socket", canonical: "unix://foo/bar.socket"}, - {raw: "unix:foo/bar.socket", canonical: "unix:foo/bar.socket"}, {raw: "tcp://1.2.3.4", canonical: "1.2.3.4"}, {raw: "tcp://1.2.3.4:567", canonical: "1.2.3.4:567"}, {raw: "tcp://foobar", canonical: "foobar"}, @@ -24,28 +21,53 @@ func TestParseAddress(t *testing.T) { {raw: "tcp:///foo/bar.socket", invalid: true}, {raw: "tcp:/foo/bar.socket", invalid: true}, {raw: "tcp://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:9999", canonical: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:9999"}, - {raw: "foobar:9999", canonical: "foobar:9999"}, - // As per https://github.com/grpc/grpc/blob/master/doc/naming.md... - {raw: "dns:///127.0.0.1:9999", canonical: "dns:///127.0.0.1:9999"}, - {raw: "dns:///[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:9999", canonical: "dns:///[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:9999"}, + {raw: "foobar:9999", invalid: true}, + {raw: "unix:/foo/bar.socket", invalid: true}, + {raw: "unix:///foo/bar.socket", invalid: true}, + {raw: "unix://foo/bar.socket", invalid: true}, + {raw: "unix:foo/bar.socket", invalid: true}, + } + + for _, tc := range testCases { + t.Run(tc.raw, func(t *testing.T) { + canonical, err := extractHostFromRemoteURL(tc.raw) + if tc.invalid { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.Equal(t, tc.canonical, canonical) + }) + } +} + +func Test_extractPathFromSocketURL(t *testing.T) { + testCases := []struct { + raw string + path string + invalid bool + }{ + {raw: "unix:/foo/bar.socket", path: "/foo/bar.socket"}, + {raw: "unix:///foo/bar.socket", path: "///foo/bar.socket"}, // Silly but valid + {raw: "unix:foo/bar.socket", path: "foo/bar.socket"}, + {raw: "unix:../foo/bar.socket", path: "../foo/bar.socket"}, + {raw: "unix:path/with/a/colon:/in/it", path: "path/with/a/colon:/in/it"}, + {raw: "tcp://1.2.3.4", invalid: true}, + {raw: "foo/bar.socket", invalid: true}, } for _, tc := range testCases { - canonical, err := parseAddress(tc.raw) - - if err == nil && tc.invalid { - t.Errorf("%v: expected error, got none", tc) - } else if err != nil && !tc.invalid { - t.Errorf("%v: parse error: %v", tc, err) - continue - } - - if tc.invalid { - continue - } - - if tc.canonical != canonical { - t.Errorf("%v: expected %q, got %q", tc, tc.canonical, canonical) - } + t.Run(tc.raw, func(t *testing.T) { + path, err := extractPathFromSocketURL(tc.raw) + + if tc.invalid { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.Equal(t, tc.path, path) + }) } } diff --git a/client/dial.go b/client/dial.go index d0a51c0c1..fe4a3e683 100644 --- a/client/dial.go +++ b/client/dial.go @@ -1,6 +1,10 @@ package client import ( + "fmt" + "net" + "time" + "google.golang.org/grpc/credentials" "net/url" @@ -11,14 +15,30 @@ import ( // DefaultDialOpts hold the default DialOptions for connection to Gitaly over UNIX-socket var DefaultDialOpts = []grpc.DialOption{} +type connectionType int + +const ( + invalidConnection connectionType = iota + tcpConnection + tlsConnection + unixConnection +) + // Dial gitaly func Dial(rawAddress string, connOpts []grpc.DialOption) (*grpc.ClientConn, error) { - canonicalAddress, err := parseAddress(rawAddress) - if err != nil { - return nil, err - } + var canonicalAddress string + var err error + + switch getConnectionType(rawAddress) { + case invalidConnection: + return nil, fmt.Errorf("invalid connection string: %s", rawAddress) + + case tlsConnection: + canonicalAddress, err = extractHostFromRemoteURL(rawAddress) // Ensure the form: "host:port" ... + if err != nil { + return nil, err + } - if isTLS(rawAddress) { certPool, err := systemCertPool() if err != nil { return nil, err @@ -26,8 +46,29 @@ func Dial(rawAddress string, connOpts []grpc.DialOption) (*grpc.ClientConn, erro creds := credentials.NewClientTLSFromCert(certPool, "") connOpts = append(connOpts, grpc.WithTransportCredentials(creds)) - } else { + + case tcpConnection: + canonicalAddress, err = extractHostFromRemoteURL(rawAddress) // Ensure the form: "host:port" ... + if err != nil { + return nil, err + } connOpts = append(connOpts, grpc.WithInsecure()) + + case unixConnection: + canonicalAddress = rawAddress // This will be overriden by the custom dialer... + connOpts = append( + connOpts, + grpc.WithInsecure(), + grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) { + path, err := extractPathFromSocketURL(addr) + if err != nil { + return nil, err + } + + return net.DialTimeout("unix", path, timeout) + }), + ) + } conn, err := grpc.Dial(canonicalAddress, connOpts...) @@ -38,7 +79,20 @@ func Dial(rawAddress string, connOpts []grpc.DialOption) (*grpc.ClientConn, erro return conn, nil } -func isTLS(rawAddress string) bool { +func getConnectionType(rawAddress string) connectionType { u, err := url.Parse(rawAddress) - return err == nil && u.Scheme == "tls" + if err != nil { + return invalidConnection + } + + switch u.Scheme { + case "tls": + return tlsConnection + case "unix": + return unixConnection + case "tcp": + return tcpConnection + default: + return invalidConnection + } } |