Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndrew Newdigate <andrew@gitlab.com>2019-01-09 23:07:29 +0300
committerAndrew Newdigate <andrew@gitlab.com>2019-01-14 14:39:32 +0300
commit7596a62637a23caf2e9e9451e6fc1d6cf12b6792 (patch)
treeedf7b14db6fa510cea0b252d9669bc4bbfd6c039
parente8bb2b1482860ed18b6a42ab48bc882c1089df2f (diff)
Reintroduce a specific dialler for unix sockets
61f6c92779a70d577727e7eefa337409effd69ef removed the Dialer for unix socket. This was done because a change to the GRPC library caused the Dialer to stop working, and because the default implementation works as expected for users not using a proxy. Unfortunately this led to a regression for users with HTTP or HTTPS proxy configurations exposed via the `http_proxy` or `https_proxy` environment variables. For this reason, we reintroduce the dialer for Unix socket connections.
-rw-r--r--client/address_parser.go32
-rw-r--r--client/address_parser_test.go74
-rw-r--r--client/dial.go70
3 files changed, 133 insertions, 43 deletions
diff --git a/client/address_parser.go b/client/address_parser.go
index 55969e909..a052342ae 100644
--- a/client/address_parser.go
+++ b/client/address_parser.go
@@ -3,22 +3,36 @@ package client
import (
"fmt"
"net/url"
+ "strings"
)
-func parseAddress(rawAddress string) (canonicalAddress string, err error) {
+// extractHostFromRemoteURL will convert Gitaly-style URL addresses of the form
+// scheme://host:port to the "host:port" addresses used by `grpc.Dial`
+func extractHostFromRemoteURL(rawAddress string) (hostAndPort string, err error) {
u, err := url.Parse(rawAddress)
if err != nil {
return "", err
}
- // tcp:// addresses are a special case which `grpc.Dial` expects in a
- // different format
- if u.Scheme == "tcp" || u.Scheme == "tls" {
- if u.Path != "" {
- return "", fmt.Errorf("%s addresses should not have a path", u.Scheme)
- }
- return u.Host, nil
+ if u.Path != "" {
+ return "", fmt.Errorf("remote addresses should not have a path")
}
- return u.String(), nil
+ if u.Host == "" {
+ return "", fmt.Errorf("remote addresses should have a host")
+ }
+
+ return u.Host, nil
+}
+
+// extractPathFromSocketURL will convert Gitaly-style URL addresses of the form
+// unix:/path/to/socket into file paths: `/path/to/socket`
+const unixPrefix = "unix:"
+
+func extractPathFromSocketURL(rawAddress string) (socketPath string, err error) {
+ if !strings.HasPrefix(rawAddress, unixPrefix) {
+ return "", fmt.Errorf("invalid socket address: %s", rawAddress)
+ }
+
+ return strings.TrimPrefix(rawAddress, unixPrefix), nil
}
diff --git a/client/address_parser_test.go b/client/address_parser_test.go
index 820a902b3..f5dc2f31b 100644
--- a/client/address_parser_test.go
+++ b/client/address_parser_test.go
@@ -2,19 +2,16 @@ package client
import (
"testing"
+
+ "github.com/stretchr/testify/require"
)
-func TestParseAddress(t *testing.T) {
+func Test_extractHostFromRemoteURL(t *testing.T) {
testCases := []struct {
raw string
canonical string
invalid bool
}{
- {raw: "unix:/foo/bar.socket", canonical: "unix:///foo/bar.socket"},
- {raw: "unix:///foo/bar.socket", canonical: "unix:///foo/bar.socket"},
- // Mainly for test purposes we explicitly want to support relative paths
- {raw: "unix://foo/bar.socket", canonical: "unix://foo/bar.socket"},
- {raw: "unix:foo/bar.socket", canonical: "unix:foo/bar.socket"},
{raw: "tcp://1.2.3.4", canonical: "1.2.3.4"},
{raw: "tcp://1.2.3.4:567", canonical: "1.2.3.4:567"},
{raw: "tcp://foobar", canonical: "foobar"},
@@ -24,28 +21,53 @@ func TestParseAddress(t *testing.T) {
{raw: "tcp:///foo/bar.socket", invalid: true},
{raw: "tcp:/foo/bar.socket", invalid: true},
{raw: "tcp://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:9999", canonical: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:9999"},
- {raw: "foobar:9999", canonical: "foobar:9999"},
- // As per https://github.com/grpc/grpc/blob/master/doc/naming.md...
- {raw: "dns:///127.0.0.1:9999", canonical: "dns:///127.0.0.1:9999"},
- {raw: "dns:///[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:9999", canonical: "dns:///[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:9999"},
+ {raw: "foobar:9999", invalid: true},
+ {raw: "unix:/foo/bar.socket", invalid: true},
+ {raw: "unix:///foo/bar.socket", invalid: true},
+ {raw: "unix://foo/bar.socket", invalid: true},
+ {raw: "unix:foo/bar.socket", invalid: true},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.raw, func(t *testing.T) {
+ canonical, err := extractHostFromRemoteURL(tc.raw)
+ if tc.invalid {
+ require.Error(t, err)
+ return
+ }
+
+ require.NoError(t, err)
+ require.Equal(t, tc.canonical, canonical)
+ })
+ }
+}
+
+func Test_extractPathFromSocketURL(t *testing.T) {
+ testCases := []struct {
+ raw string
+ path string
+ invalid bool
+ }{
+ {raw: "unix:/foo/bar.socket", path: "/foo/bar.socket"},
+ {raw: "unix:///foo/bar.socket", path: "///foo/bar.socket"}, // Silly but valid
+ {raw: "unix:foo/bar.socket", path: "foo/bar.socket"},
+ {raw: "unix:../foo/bar.socket", path: "../foo/bar.socket"},
+ {raw: "unix:path/with/a/colon:/in/it", path: "path/with/a/colon:/in/it"},
+ {raw: "tcp://1.2.3.4", invalid: true},
+ {raw: "foo/bar.socket", invalid: true},
}
for _, tc := range testCases {
- canonical, err := parseAddress(tc.raw)
-
- if err == nil && tc.invalid {
- t.Errorf("%v: expected error, got none", tc)
- } else if err != nil && !tc.invalid {
- t.Errorf("%v: parse error: %v", tc, err)
- continue
- }
-
- if tc.invalid {
- continue
- }
-
- if tc.canonical != canonical {
- t.Errorf("%v: expected %q, got %q", tc, tc.canonical, canonical)
- }
+ t.Run(tc.raw, func(t *testing.T) {
+ path, err := extractPathFromSocketURL(tc.raw)
+
+ if tc.invalid {
+ require.Error(t, err)
+ return
+ }
+
+ require.NoError(t, err)
+ require.Equal(t, tc.path, path)
+ })
}
}
diff --git a/client/dial.go b/client/dial.go
index d0a51c0c1..fe4a3e683 100644
--- a/client/dial.go
+++ b/client/dial.go
@@ -1,6 +1,10 @@
package client
import (
+ "fmt"
+ "net"
+ "time"
+
"google.golang.org/grpc/credentials"
"net/url"
@@ -11,14 +15,30 @@ import (
// DefaultDialOpts hold the default DialOptions for connection to Gitaly over UNIX-socket
var DefaultDialOpts = []grpc.DialOption{}
+type connectionType int
+
+const (
+ invalidConnection connectionType = iota
+ tcpConnection
+ tlsConnection
+ unixConnection
+)
+
// Dial gitaly
func Dial(rawAddress string, connOpts []grpc.DialOption) (*grpc.ClientConn, error) {
- canonicalAddress, err := parseAddress(rawAddress)
- if err != nil {
- return nil, err
- }
+ var canonicalAddress string
+ var err error
+
+ switch getConnectionType(rawAddress) {
+ case invalidConnection:
+ return nil, fmt.Errorf("invalid connection string: %s", rawAddress)
+
+ case tlsConnection:
+ canonicalAddress, err = extractHostFromRemoteURL(rawAddress) // Ensure the form: "host:port" ...
+ if err != nil {
+ return nil, err
+ }
- if isTLS(rawAddress) {
certPool, err := systemCertPool()
if err != nil {
return nil, err
@@ -26,8 +46,29 @@ func Dial(rawAddress string, connOpts []grpc.DialOption) (*grpc.ClientConn, erro
creds := credentials.NewClientTLSFromCert(certPool, "")
connOpts = append(connOpts, grpc.WithTransportCredentials(creds))
- } else {
+
+ case tcpConnection:
+ canonicalAddress, err = extractHostFromRemoteURL(rawAddress) // Ensure the form: "host:port" ...
+ if err != nil {
+ return nil, err
+ }
connOpts = append(connOpts, grpc.WithInsecure())
+
+ case unixConnection:
+ canonicalAddress = rawAddress // This will be overriden by the custom dialer...
+ connOpts = append(
+ connOpts,
+ grpc.WithInsecure(),
+ grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) {
+ path, err := extractPathFromSocketURL(addr)
+ if err != nil {
+ return nil, err
+ }
+
+ return net.DialTimeout("unix", path, timeout)
+ }),
+ )
+
}
conn, err := grpc.Dial(canonicalAddress, connOpts...)
@@ -38,7 +79,20 @@ func Dial(rawAddress string, connOpts []grpc.DialOption) (*grpc.ClientConn, erro
return conn, nil
}
-func isTLS(rawAddress string) bool {
+func getConnectionType(rawAddress string) connectionType {
u, err := url.Parse(rawAddress)
- return err == nil && u.Scheme == "tls"
+ if err != nil {
+ return invalidConnection
+ }
+
+ switch u.Scheme {
+ case "tls":
+ return tlsConnection
+ case "unix":
+ return unixConnection
+ case "tcp":
+ return tcpConnection
+ default:
+ return invalidConnection
+ }
}