Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-foss.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'workhorse/channel_test.go')
-rw-r--r--workhorse/channel_test.go245
1 files changed, 245 insertions, 0 deletions
diff --git a/workhorse/channel_test.go b/workhorse/channel_test.go
new file mode 100644
index 00000000000..cd8957ed829
--- /dev/null
+++ b/workhorse/channel_test.go
@@ -0,0 +1,245 @@
+package main
+
+import (
+ "bytes"
+ "encoding/pem"
+ "fmt"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "path"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/gorilla/websocket"
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/labkit/log"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+)
+
+var (
+ envTerminalPath = fmt.Sprintf("%s/-/environments/1/terminal.ws", testProject)
+ jobTerminalPath = fmt.Sprintf("%s/-/jobs/1/terminal.ws", testProject)
+ servicesProxyWSPath = fmt.Sprintf("%s/-/jobs/1/proxy.ws", testProject)
+)
+
+type connWithReq struct {
+ conn *websocket.Conn
+ req *http.Request
+}
+
+func TestChannelHappyPath(t *testing.T) {
+ tests := []struct {
+ name string
+ channelPath string
+ }{
+ {"environments", envTerminalPath},
+ {"jobs", jobTerminalPath},
+ {"services", servicesProxyWSPath},
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ serverConns, clientURL, close := wireupChannel(t, test.channelPath, nil, "channel.k8s.io")
+ defer close()
+
+ client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
+ require.NoError(t, err)
+
+ server := (<-serverConns).conn
+ defer server.Close()
+
+ message := "test message"
+
+ // channel.k8s.io: server writes to channel 1, STDOUT
+ require.NoError(t, say(server, "\x01"+message))
+ requireReadMessage(t, client, websocket.BinaryMessage, message)
+
+ require.NoError(t, say(client, message))
+
+ // channel.k8s.io: client writes get put on channel 0, STDIN
+ requireReadMessage(t, server, websocket.BinaryMessage, "\x00"+message)
+
+ // Closing the client should send an EOT signal to the server's STDIN
+ client.Close()
+ requireReadMessage(t, server, websocket.BinaryMessage, "\x00\x04")
+ })
+ }
+}
+
+func TestChannelBadTLS(t *testing.T) {
+ _, clientURL, close := wireupChannel(t, envTerminalPath, badCA, "channel.k8s.io")
+ defer close()
+
+ _, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
+ require.Equal(t, websocket.ErrBadHandshake, err, "unexpected error %v", err)
+}
+
+func TestChannelSessionTimeout(t *testing.T) {
+ serverConns, clientURL, close := wireupChannel(t, envTerminalPath, timeout, "channel.k8s.io")
+ defer close()
+
+ client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
+ require.NoError(t, err)
+
+ sc := <-serverConns
+ defer sc.conn.Close()
+
+ client.SetReadDeadline(time.Now().Add(time.Duration(2) * time.Second))
+ _, _, err = client.ReadMessage()
+
+ require.True(t, websocket.IsCloseError(err, websocket.CloseAbnormalClosure), "Client connection was not closed, got %v", err)
+}
+
+func TestChannelProxyForwardsHeadersFromUpstream(t *testing.T) {
+ hdr := make(http.Header)
+ hdr.Set("Random-Header", "Value")
+ serverConns, clientURL, close := wireupChannel(t, envTerminalPath, setHeader(hdr), "channel.k8s.io")
+ defer close()
+
+ client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
+ require.NoError(t, err)
+ defer client.Close()
+
+ sc := <-serverConns
+ defer sc.conn.Close()
+ require.Equal(t, "Value", sc.req.Header.Get("Random-Header"), "Header specified by upstream not sent to remote")
+}
+
+func TestChannelProxyForwardsXForwardedForFromClient(t *testing.T) {
+ serverConns, clientURL, close := wireupChannel(t, envTerminalPath, nil, "channel.k8s.io")
+ defer close()
+
+ hdr := make(http.Header)
+ hdr.Set("X-Forwarded-For", "127.0.0.2")
+ client, _, err := dialWebsocket(clientURL, hdr, "terminal.gitlab.com")
+ require.NoError(t, err)
+ defer client.Close()
+
+ clientIP, _, err := net.SplitHostPort(client.LocalAddr().String())
+ require.NoError(t, err)
+
+ sc := <-serverConns
+ defer sc.conn.Close()
+
+ require.Equal(t, "127.0.0.2, "+clientIP, sc.req.Header.Get("X-Forwarded-For"), "X-Forwarded-For from client not sent to remote")
+}
+
+func wireupChannel(t *testing.T, channelPath string, modifier func(*api.Response), subprotocols ...string) (chan connWithReq, string, func()) {
+ serverConns, remote := startWebsocketServer(subprotocols...)
+ authResponse := channelOkBody(remote, nil, subprotocols...)
+ if modifier != nil {
+ modifier(authResponse)
+ }
+ upstream := testAuthServer(t, nil, nil, 200, authResponse)
+ workhorse := startWorkhorseServer(upstream.URL)
+
+ return serverConns, websocketURL(workhorse.URL, channelPath), func() {
+ workhorse.Close()
+ upstream.Close()
+ remote.Close()
+ }
+}
+
+func startWebsocketServer(subprotocols ...string) (chan connWithReq, *httptest.Server) {
+ upgrader := &websocket.Upgrader{Subprotocols: subprotocols}
+
+ connCh := make(chan connWithReq, 1)
+ server := httptest.NewTLSServer(webSocketHandler(upgrader, connCh))
+
+ return connCh, server
+}
+
+func webSocketHandler(upgrader *websocket.Upgrader, connCh chan connWithReq) http.HandlerFunc {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ logEntry := log.WithFields(log.Fields{
+ "method": r.Method,
+ "url": r.URL,
+ "headers": r.Header,
+ })
+
+ logEntry.Info("WEBSOCKET")
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ logEntry.WithError(err).Error("WEBSOCKET Upgrade failed")
+ return
+ }
+ connCh <- connWithReq{conn, r}
+ // The connection has been hijacked so it's OK to end here
+ })
+}
+
+func channelOkBody(remote *httptest.Server, header http.Header, subprotocols ...string) *api.Response {
+ out := &api.Response{
+ Channel: &api.ChannelSettings{
+ Url: websocketURL(remote.URL),
+ Header: header,
+ Subprotocols: subprotocols,
+ MaxSessionTime: 0,
+ },
+ }
+
+ if len(remote.TLS.Certificates) > 0 {
+ data := bytes.NewBuffer(nil)
+ pem.Encode(data, &pem.Block{Type: "CERTIFICATE", Bytes: remote.TLS.Certificates[0].Certificate[0]})
+ out.Channel.CAPem = data.String()
+ }
+
+ return out
+}
+
+func badCA(authResponse *api.Response) {
+ authResponse.Channel.CAPem = "Bad CA"
+}
+
+func timeout(authResponse *api.Response) {
+ authResponse.Channel.MaxSessionTime = 1
+}
+
+func setHeader(hdr http.Header) func(*api.Response) {
+ return func(authResponse *api.Response) {
+ authResponse.Channel.Header = hdr
+ }
+}
+
+func dialWebsocket(url string, header http.Header, subprotocols ...string) (*websocket.Conn, *http.Response, error) {
+ dialer := &websocket.Dialer{
+ Subprotocols: subprotocols,
+ }
+
+ return dialer.Dial(url, header)
+}
+
+func websocketURL(httpURL string, suffix ...string) string {
+ url, err := url.Parse(httpURL)
+ if err != nil {
+ panic(err)
+ }
+
+ switch url.Scheme {
+ case "http":
+ url.Scheme = "ws"
+ case "https":
+ url.Scheme = "wss"
+ default:
+ panic("Unknown scheme: " + url.Scheme)
+ }
+
+ url.Path = path.Join(url.Path, strings.Join(suffix, "/"))
+
+ return url.String()
+}
+
+func say(conn *websocket.Conn, message string) error {
+ return conn.WriteMessage(websocket.TextMessage, []byte(message))
+}
+
+func requireReadMessage(t *testing.T, conn *websocket.Conn, expectedMessageType int, expectedData string) {
+ messageType, data, err := conn.ReadMessage()
+ require.NoError(t, err)
+
+ require.Equal(t, expectedMessageType, messageType, "message type")
+ require.Equal(t, expectedData, string(data), "message data")
+}