package main import ( "bytes" "encoding/pem" "fmt" "net" "net/http" "net/http/httptest" "net/url" "path" "strings" "testing" "time" "github.com/gorilla/websocket" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/labkit/log" "gitlab.com/gitlab-org/gitlab/workhorse/internal/api" ) var ( envTerminalPath = fmt.Sprintf("%s/-/environments/1/terminal.ws", testProject) jobTerminalPath = fmt.Sprintf("%s/-/jobs/1/terminal.ws", testProject) servicesProxyWSPath = fmt.Sprintf("%s/-/jobs/1/proxy.ws", testProject) ) type connWithReq struct { conn *websocket.Conn req *http.Request } func TestChannelHappyPath(t *testing.T) { tests := []struct { name string channelPath string }{ {"environments", envTerminalPath}, {"jobs", jobTerminalPath}, {"services", servicesProxyWSPath}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { serverConns, clientURL, close := wireupChannel(t, test.channelPath, nil, "channel.k8s.io") defer close() client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com") require.NoError(t, err) server := (<-serverConns).conn defer server.Close() message := "test message" // channel.k8s.io: server writes to channel 1, STDOUT require.NoError(t, say(server, "\x01"+message)) requireReadMessage(t, client, websocket.BinaryMessage, message) require.NoError(t, say(client, message)) // channel.k8s.io: client writes get put on channel 0, STDIN requireReadMessage(t, server, websocket.BinaryMessage, "\x00"+message) // Closing the client should send an EOT signal to the server's STDIN client.Close() requireReadMessage(t, server, websocket.BinaryMessage, "\x00\x04") }) } } func TestChannelBadTLS(t *testing.T) { _, clientURL, close := wireupChannel(t, envTerminalPath, badCA, "channel.k8s.io") defer close() _, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com") require.Equal(t, websocket.ErrBadHandshake, err, "unexpected error %v", err) } func TestChannelSessionTimeout(t *testing.T) { serverConns, clientURL, close := wireupChannel(t, envTerminalPath, timeout, "channel.k8s.io") defer close() client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com") require.NoError(t, err) sc := <-serverConns defer sc.conn.Close() client.SetReadDeadline(time.Now().Add(time.Duration(2) * time.Second)) _, _, err = client.ReadMessage() require.True(t, websocket.IsCloseError(err, websocket.CloseAbnormalClosure), "Client connection was not closed, got %v", err) } func TestChannelProxyForwardsHeadersFromUpstream(t *testing.T) { hdr := make(http.Header) hdr.Set("Random-Header", "Value") serverConns, clientURL, close := wireupChannel(t, envTerminalPath, setHeader(hdr), "channel.k8s.io") defer close() client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com") require.NoError(t, err) defer client.Close() sc := <-serverConns defer sc.conn.Close() require.Equal(t, "Value", sc.req.Header.Get("Random-Header"), "Header specified by upstream not sent to remote") } func TestChannelProxyForwardsXForwardedForFromClient(t *testing.T) { serverConns, clientURL, close := wireupChannel(t, envTerminalPath, nil, "channel.k8s.io") defer close() hdr := make(http.Header) hdr.Set("X-Forwarded-For", "127.0.0.2") client, _, err := dialWebsocket(clientURL, hdr, "terminal.gitlab.com") require.NoError(t, err) defer client.Close() clientIP, _, err := net.SplitHostPort(client.LocalAddr().String()) require.NoError(t, err) sc := <-serverConns defer sc.conn.Close() require.Equal(t, "127.0.0.2, "+clientIP, sc.req.Header.Get("X-Forwarded-For"), "X-Forwarded-For from client not sent to remote") } func wireupChannel(t *testing.T, channelPath string, modifier func(*api.Response), subprotocols ...string) (chan connWithReq, string, func()) { serverConns, remote := startWebsocketServer(subprotocols...) authResponse := channelOkBody(remote, nil, subprotocols...) if modifier != nil { modifier(authResponse) } upstream := testAuthServer(t, nil, nil, 200, authResponse) workhorse := startWorkhorseServer(upstream.URL) return serverConns, websocketURL(workhorse.URL, channelPath), func() { workhorse.Close() upstream.Close() remote.Close() } } func startWebsocketServer(subprotocols ...string) (chan connWithReq, *httptest.Server) { upgrader := &websocket.Upgrader{Subprotocols: subprotocols} connCh := make(chan connWithReq, 1) server := httptest.NewTLSServer(webSocketHandler(upgrader, connCh)) return connCh, server } func webSocketHandler(upgrader *websocket.Upgrader, connCh chan connWithReq) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { logEntry := log.WithFields(log.Fields{ "method": r.Method, "url": r.URL, "headers": r.Header, }) logEntry.Info("WEBSOCKET") conn, err := upgrader.Upgrade(w, r, nil) if err != nil { logEntry.WithError(err).Error("WEBSOCKET Upgrade failed") return } connCh <- connWithReq{conn, r} // The connection has been hijacked so it's OK to end here }) } func channelOkBody(remote *httptest.Server, header http.Header, subprotocols ...string) *api.Response { out := &api.Response{ Channel: &api.ChannelSettings{ Url: websocketURL(remote.URL), Header: header, Subprotocols: subprotocols, MaxSessionTime: 0, }, } if len(remote.TLS.Certificates) > 0 { data := bytes.NewBuffer(nil) pem.Encode(data, &pem.Block{Type: "CERTIFICATE", Bytes: remote.TLS.Certificates[0].Certificate[0]}) out.Channel.CAPem = data.String() } return out } func badCA(authResponse *api.Response) { authResponse.Channel.CAPem = "Bad CA" } func timeout(authResponse *api.Response) { authResponse.Channel.MaxSessionTime = 1 } func setHeader(hdr http.Header) func(*api.Response) { return func(authResponse *api.Response) { authResponse.Channel.Header = hdr } } func dialWebsocket(url string, header http.Header, subprotocols ...string) (*websocket.Conn, *http.Response, error) { dialer := &websocket.Dialer{ Subprotocols: subprotocols, } return dialer.Dial(url, header) } func websocketURL(httpURL string, suffix ...string) string { url, err := url.Parse(httpURL) if err != nil { panic(err) } switch url.Scheme { case "http": url.Scheme = "ws" case "https": url.Scheme = "wss" default: panic("Unknown scheme: " + url.Scheme) } url.Path = path.Join(url.Path, strings.Join(suffix, "/")) return url.String() } func say(conn *websocket.Conn, message string) error { return conn.WriteMessage(websocket.TextMessage, []byte(message)) } func requireReadMessage(t *testing.T, conn *websocket.Conn, expectedMessageType int, expectedData string) { messageType, data, err := conn.ReadMessage() require.NoError(t, err) require.Equal(t, expectedMessageType, messageType, "message type") require.Equal(t, expectedData, string(data), "message data") }