Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitlab-foss.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'workhorse/internal/channel')
-rw-r--r--workhorse/internal/channel/auth_checker.go69
-rw-r--r--workhorse/internal/channel/auth_checker_test.go53
-rw-r--r--workhorse/internal/channel/channel.go132
-rw-r--r--workhorse/internal/channel/proxy.go56
-rw-r--r--workhorse/internal/channel/wrappers.go134
-rw-r--r--workhorse/internal/channel/wrappers_test.go155
6 files changed, 599 insertions, 0 deletions
diff --git a/workhorse/internal/channel/auth_checker.go b/workhorse/internal/channel/auth_checker.go
new file mode 100644
index 00000000000..f44850e0861
--- /dev/null
+++ b/workhorse/internal/channel/auth_checker.go
@@ -0,0 +1,69 @@
+package channel
+
+import (
+ "errors"
+ "net/http"
+ "time"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+)
+
+type AuthCheckerFunc func() *api.ChannelSettings
+
+// Regularly checks that authorization is still valid for a channel, outputting
+// to the stopper when it isn't
+type AuthChecker struct {
+ Checker AuthCheckerFunc
+ Template *api.ChannelSettings
+ StopCh chan error
+ Done chan struct{}
+ Count int64
+}
+
+var ErrAuthChanged = errors.New("connection closed: authentication changed or endpoint unavailable")
+
+func NewAuthChecker(f AuthCheckerFunc, template *api.ChannelSettings, stopCh chan error) *AuthChecker {
+ return &AuthChecker{
+ Checker: f,
+ Template: template,
+ StopCh: stopCh,
+ Done: make(chan struct{}),
+ }
+}
+func (c *AuthChecker) Loop(interval time.Duration) {
+ for {
+ select {
+ case <-time.After(interval):
+ settings := c.Checker()
+ if !c.Template.IsEqual(settings) {
+ c.StopCh <- ErrAuthChanged
+ return
+ }
+ c.Count = c.Count + 1
+ case <-c.Done:
+ return
+ }
+ }
+}
+
+func (c *AuthChecker) Close() error {
+ close(c.Done)
+ return nil
+}
+
+// Generates a CheckerFunc from an *api.API + request needing authorization
+func authCheckFunc(myAPI *api.API, r *http.Request, suffix string) AuthCheckerFunc {
+ return func() *api.ChannelSettings {
+ httpResponse, authResponse, err := myAPI.PreAuthorize(suffix, r)
+ if err != nil {
+ return nil
+ }
+ defer httpResponse.Body.Close()
+
+ if httpResponse.StatusCode != http.StatusOK || authResponse == nil {
+ return nil
+ }
+
+ return authResponse.Channel
+ }
+}
diff --git a/workhorse/internal/channel/auth_checker_test.go b/workhorse/internal/channel/auth_checker_test.go
new file mode 100644
index 00000000000..18beb45cf3a
--- /dev/null
+++ b/workhorse/internal/channel/auth_checker_test.go
@@ -0,0 +1,53 @@
+package channel
+
+import (
+ "testing"
+ "time"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+)
+
+func checkerSeries(values ...*api.ChannelSettings) AuthCheckerFunc {
+ return func() *api.ChannelSettings {
+ if len(values) == 0 {
+ return nil
+ }
+ out := values[0]
+ values = values[1:]
+ return out
+ }
+}
+
+func TestAuthCheckerStopsWhenAuthFails(t *testing.T) {
+ template := &api.ChannelSettings{Url: "ws://example.com"}
+ stopCh := make(chan error)
+ series := checkerSeries(template, template, template)
+ ac := NewAuthChecker(series, template, stopCh)
+
+ go ac.Loop(1 * time.Millisecond)
+ if err := <-stopCh; err != ErrAuthChanged {
+ t.Fatalf("Expected ErrAuthChanged, got %v", err)
+ }
+
+ if ac.Count != 3 {
+ t.Fatalf("Expected 3 successful checks, got %v", ac.Count)
+ }
+}
+
+func TestAuthCheckerStopsWhenAuthChanges(t *testing.T) {
+ template := &api.ChannelSettings{Url: "ws://example.com"}
+ changed := template.Clone()
+ changed.Url = "wss://example.com"
+ stopCh := make(chan error)
+ series := checkerSeries(template, changed, template)
+ ac := NewAuthChecker(series, template, stopCh)
+
+ go ac.Loop(1 * time.Millisecond)
+ if err := <-stopCh; err != ErrAuthChanged {
+ t.Fatalf("Expected ErrAuthChanged, got %v", err)
+ }
+
+ if ac.Count != 1 {
+ t.Fatalf("Expected 1 successful check, got %v", ac.Count)
+ }
+}
diff --git a/workhorse/internal/channel/channel.go b/workhorse/internal/channel/channel.go
new file mode 100644
index 00000000000..381ce95df82
--- /dev/null
+++ b/workhorse/internal/channel/channel.go
@@ -0,0 +1,132 @@
+package channel
+
+import (
+ "fmt"
+ "net/http"
+ "time"
+
+ "github.com/gorilla/websocket"
+
+ "gitlab.com/gitlab-org/labkit/log"
+
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/api"
+ "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper"
+)
+
+var (
+ // See doc/channel.md for documentation of this subprotocol
+ subprotocols = []string{"terminal.gitlab.com", "base64.terminal.gitlab.com"}
+ upgrader = &websocket.Upgrader{Subprotocols: subprotocols}
+ ReauthenticationInterval = 5 * time.Minute
+ BrowserPingInterval = 30 * time.Second
+)
+
+func Handler(myAPI *api.API) http.Handler {
+ return myAPI.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) {
+ if err := a.Channel.Validate(); err != nil {
+ helper.Fail500(w, r, err)
+ return
+ }
+
+ proxy := NewProxy(2) // two stoppers: auth checker, max time
+ checker := NewAuthChecker(
+ authCheckFunc(myAPI, r, "authorize"),
+ a.Channel,
+ proxy.StopCh,
+ )
+ defer checker.Close()
+ go checker.Loop(ReauthenticationInterval)
+ go closeAfterMaxTime(proxy, a.Channel.MaxSessionTime)
+
+ ProxyChannel(w, r, a.Channel, proxy)
+ }, "authorize")
+}
+
+func ProxyChannel(w http.ResponseWriter, r *http.Request, settings *api.ChannelSettings, proxy *Proxy) {
+ server, err := connectToServer(settings, r)
+ if err != nil {
+ helper.Fail500(w, r, err)
+ log.ContextLogger(r.Context()).WithError(err).Print("Channel: connecting to server failed")
+ return
+ }
+ defer server.UnderlyingConn().Close()
+ serverAddr := server.UnderlyingConn().RemoteAddr().String()
+
+ client, err := upgradeClient(w, r)
+ if err != nil {
+ log.ContextLogger(r.Context()).WithError(err).Print("Channel: upgrading client to websocket failed")
+ return
+ }
+
+ // Regularly send ping messages to the browser to keep the websocket from
+ // being timed out by intervening proxies.
+ go pingLoop(client)
+
+ defer client.UnderlyingConn().Close()
+ clientAddr := getClientAddr(r) // We can't know the port with confidence
+
+ logEntry := log.WithContextFields(r.Context(), log.Fields{
+ "clientAddr": clientAddr,
+ "serverAddr": serverAddr,
+ })
+
+ logEntry.Print("Channel: started proxying")
+
+ defer logEntry.Print("Channel: finished proxying")
+
+ if err := proxy.Serve(server, client, serverAddr, clientAddr); err != nil {
+ logEntry.WithError(err).Print("Channel: error proxying")
+ }
+}
+
+// In the future, we might want to look at X-Client-Ip or X-Forwarded-For
+func getClientAddr(r *http.Request) string {
+ return r.RemoteAddr
+}
+
+func upgradeClient(w http.ResponseWriter, r *http.Request) (Connection, error) {
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ return Wrap(conn, conn.Subprotocol()), nil
+}
+
+func pingLoop(conn Connection) {
+ for {
+ time.Sleep(BrowserPingInterval)
+ deadline := time.Now().Add(5 * time.Second)
+ if err := conn.WriteControl(websocket.PingMessage, nil, deadline); err != nil {
+ // Either the connection was already closed so no further pings are
+ // needed, or this connection is now dead and no further pings can
+ // be sent.
+ break
+ }
+ }
+}
+
+func connectToServer(settings *api.ChannelSettings, r *http.Request) (Connection, error) {
+ settings = settings.Clone()
+
+ helper.SetForwardedFor(&settings.Header, r)
+
+ conn, _, err := settings.Dial()
+ if err != nil {
+ return nil, err
+ }
+
+ return Wrap(conn, conn.Subprotocol()), nil
+}
+
+func closeAfterMaxTime(proxy *Proxy, maxSessionTime int) {
+ if maxSessionTime == 0 {
+ return
+ }
+
+ <-time.After(time.Duration(maxSessionTime) * time.Second)
+ proxy.StopCh <- fmt.Errorf(
+ "connection closed: session time greater than maximum time allowed - %v seconds",
+ maxSessionTime,
+ )
+}
diff --git a/workhorse/internal/channel/proxy.go b/workhorse/internal/channel/proxy.go
new file mode 100644
index 00000000000..71f58092276
--- /dev/null
+++ b/workhorse/internal/channel/proxy.go
@@ -0,0 +1,56 @@
+package channel
+
+import (
+ "fmt"
+ "net"
+ "time"
+
+ "github.com/gorilla/websocket"
+)
+
+// ANSI "end of channel" code
+var eot = []byte{0x04}
+
+// An abstraction of gorilla's *websocket.Conn
+type Connection interface {
+ UnderlyingConn() net.Conn
+ ReadMessage() (int, []byte, error)
+ WriteMessage(int, []byte) error
+ WriteControl(int, []byte, time.Time) error
+}
+
+type Proxy struct {
+ StopCh chan error
+}
+
+// stoppers is the number of goroutines that may attempt to call Stop()
+func NewProxy(stoppers int) *Proxy {
+ return &Proxy{
+ StopCh: make(chan error, stoppers+2), // each proxy() call is a stopper
+ }
+}
+
+func (p *Proxy) Serve(upstream, downstream Connection, upstreamAddr, downstreamAddr string) error {
+ // This signals the upstream channel to kill the exec'd process
+ defer upstream.WriteMessage(websocket.BinaryMessage, eot)
+
+ go p.proxy(upstream, downstream, upstreamAddr, downstreamAddr)
+ go p.proxy(downstream, upstream, downstreamAddr, upstreamAddr)
+
+ return <-p.StopCh
+}
+
+func (p *Proxy) proxy(to, from Connection, toAddr, fromAddr string) {
+ for {
+ messageType, data, err := from.ReadMessage()
+ if err != nil {
+ p.StopCh <- fmt.Errorf("reading from %s: %s", fromAddr, err)
+ break
+ }
+
+ if err := to.WriteMessage(messageType, data); err != nil {
+ p.StopCh <- fmt.Errorf("writing to %s: %s", toAddr, err)
+ break
+ }
+ }
+}
diff --git a/workhorse/internal/channel/wrappers.go b/workhorse/internal/channel/wrappers.go
new file mode 100644
index 00000000000..6fd955bedc7
--- /dev/null
+++ b/workhorse/internal/channel/wrappers.go
@@ -0,0 +1,134 @@
+package channel
+
+import (
+ "encoding/base64"
+ "net"
+ "time"
+
+ "github.com/gorilla/websocket"
+)
+
+func Wrap(conn Connection, subprotocol string) Connection {
+ switch subprotocol {
+ case "channel.k8s.io":
+ return &kubeWrapper{base64: false, conn: conn}
+ case "base64.channel.k8s.io":
+ return &kubeWrapper{base64: true, conn: conn}
+ case "terminal.gitlab.com":
+ return &gitlabWrapper{base64: false, conn: conn}
+ case "base64.terminal.gitlab.com":
+ return &gitlabWrapper{base64: true, conn: conn}
+ }
+
+ return conn
+}
+
+type kubeWrapper struct {
+ base64 bool
+ conn Connection
+}
+
+type gitlabWrapper struct {
+ base64 bool
+ conn Connection
+}
+
+func (w *gitlabWrapper) ReadMessage() (int, []byte, error) {
+ mt, data, err := w.conn.ReadMessage()
+ if err != nil {
+ return mt, data, err
+ }
+
+ if isData(mt) {
+ mt = websocket.BinaryMessage
+ if w.base64 {
+ data, err = decodeBase64(data)
+ }
+ }
+
+ return mt, data, err
+}
+
+func (w *gitlabWrapper) WriteMessage(mt int, data []byte) error {
+ if isData(mt) {
+ if w.base64 {
+ mt = websocket.TextMessage
+ data = encodeBase64(data)
+ } else {
+ mt = websocket.BinaryMessage
+ }
+ }
+
+ return w.conn.WriteMessage(mt, data)
+}
+
+func (w *gitlabWrapper) WriteControl(mt int, data []byte, deadline time.Time) error {
+ return w.conn.WriteControl(mt, data, deadline)
+}
+
+func (w *gitlabWrapper) UnderlyingConn() net.Conn {
+ return w.conn.UnderlyingConn()
+}
+
+// Coalesces all wsstreams into a single stream. In practice, we should only
+// receive data on stream 1.
+func (w *kubeWrapper) ReadMessage() (int, []byte, error) {
+ mt, data, err := w.conn.ReadMessage()
+ if err != nil {
+ return mt, data, err
+ }
+
+ if isData(mt) {
+ mt = websocket.BinaryMessage
+
+ // Remove the WSStream channel number, decode to raw
+ if len(data) > 0 {
+ data = data[1:]
+ if w.base64 {
+ data, err = decodeBase64(data)
+ }
+ }
+ }
+
+ return mt, data, err
+}
+
+// Always sends to wsstream 0
+func (w *kubeWrapper) WriteMessage(mt int, data []byte) error {
+ if isData(mt) {
+ if w.base64 {
+ mt = websocket.TextMessage
+ data = append([]byte{'0'}, encodeBase64(data)...)
+ } else {
+ mt = websocket.BinaryMessage
+ data = append([]byte{0}, data...)
+ }
+ }
+
+ return w.conn.WriteMessage(mt, data)
+}
+
+func (w *kubeWrapper) WriteControl(mt int, data []byte, deadline time.Time) error {
+ return w.conn.WriteControl(mt, data, deadline)
+}
+
+func (w *kubeWrapper) UnderlyingConn() net.Conn {
+ return w.conn.UnderlyingConn()
+}
+
+func isData(mt int) bool {
+ return mt == websocket.BinaryMessage || mt == websocket.TextMessage
+}
+
+func encodeBase64(data []byte) []byte {
+ buf := make([]byte, base64.StdEncoding.EncodedLen(len(data)))
+ base64.StdEncoding.Encode(buf, data)
+
+ return buf
+}
+
+func decodeBase64(data []byte) ([]byte, error) {
+ buf := make([]byte, base64.StdEncoding.DecodedLen(len(data)))
+ n, err := base64.StdEncoding.Decode(buf, data)
+ return buf[:n], err
+}
diff --git a/workhorse/internal/channel/wrappers_test.go b/workhorse/internal/channel/wrappers_test.go
new file mode 100644
index 00000000000..1e0226f85d8
--- /dev/null
+++ b/workhorse/internal/channel/wrappers_test.go
@@ -0,0 +1,155 @@
+package channel
+
+import (
+ "bytes"
+ "errors"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/gorilla/websocket"
+)
+
+type testcase struct {
+ input *fakeConn
+ expected *fakeConn
+}
+
+type fakeConn struct {
+ // WebSocket message type
+ mt int
+ data []byte
+ err error
+}
+
+func (f *fakeConn) ReadMessage() (int, []byte, error) {
+ return f.mt, f.data, f.err
+}
+
+func (f *fakeConn) WriteMessage(mt int, data []byte) error {
+ f.mt = mt
+ f.data = data
+ return f.err
+}
+
+func (f *fakeConn) WriteControl(mt int, data []byte, _ time.Time) error {
+ f.mt = mt
+ f.data = data
+ return f.err
+}
+
+func (f *fakeConn) UnderlyingConn() net.Conn {
+ return nil
+}
+
+func fake(mt int, data []byte, err error) *fakeConn {
+ return &fakeConn{mt: mt, data: []byte(data), err: err}
+}
+
+var (
+ msg = []byte("foo bar")
+ msgBase64 = []byte("Zm9vIGJhcg==")
+ kubeMsg = append([]byte{0}, msg...)
+ kubeMsgBase64 = append([]byte{'0'}, msgBase64...)
+
+ errFake = errors.New("fake error")
+
+ text = websocket.TextMessage
+ binary = websocket.BinaryMessage
+ other = 999
+
+ fakeOther = fake(other, []byte("foo"), nil)
+)
+
+func requireEqualConn(t *testing.T, expected, actual *fakeConn, msg string, args ...interface{}) {
+ if expected.mt != actual.mt {
+ t.Logf("messageType expected to be %v but was %v", expected.mt, actual.mt)
+ t.Fatalf(msg, args...)
+ }
+
+ if !bytes.Equal(expected.data, actual.data) {
+ t.Logf("data expected to be %q but was %q: ", expected.data, actual.data)
+ t.Fatalf(msg, args...)
+ }
+
+ if expected.err != actual.err {
+ t.Logf("error expected to be %v but was %v", expected.err, actual.err)
+ t.Fatalf(msg, args...)
+ }
+}
+
+func TestReadMessage(t *testing.T) {
+ testCases := map[string][]testcase{
+ "channel.k8s.io": {
+ {fake(binary, kubeMsg, errFake), fake(binary, kubeMsg, errFake)},
+ {fake(binary, kubeMsg, nil), fake(binary, msg, nil)},
+ {fake(text, kubeMsg, nil), fake(binary, msg, nil)},
+ {fakeOther, fakeOther},
+ },
+ "base64.channel.k8s.io": {
+ {fake(text, kubeMsgBase64, errFake), fake(text, kubeMsgBase64, errFake)},
+ {fake(text, kubeMsgBase64, nil), fake(binary, msg, nil)},
+ {fake(binary, kubeMsgBase64, nil), fake(binary, msg, nil)},
+ {fakeOther, fakeOther},
+ },
+ "terminal.gitlab.com": {
+ {fake(binary, msg, errFake), fake(binary, msg, errFake)},
+ {fake(binary, msg, nil), fake(binary, msg, nil)},
+ {fake(text, msg, nil), fake(binary, msg, nil)},
+ {fakeOther, fakeOther},
+ },
+ "base64.terminal.gitlab.com": {
+ {fake(text, msgBase64, errFake), fake(text, msgBase64, errFake)},
+ {fake(text, msgBase64, nil), fake(binary, msg, nil)},
+ {fake(binary, msgBase64, nil), fake(binary, msg, nil)},
+ {fakeOther, fakeOther},
+ },
+ }
+
+ for subprotocol, cases := range testCases {
+ for i, tc := range cases {
+ conn := Wrap(tc.input, subprotocol)
+ mt, data, err := conn.ReadMessage()
+ actual := fake(mt, data, err)
+ requireEqualConn(t, tc.expected, actual, "%s test case %v", subprotocol, i)
+ }
+ }
+}
+
+func TestWriteMessage(t *testing.T) {
+ testCases := map[string][]testcase{
+ "channel.k8s.io": {
+ {fake(binary, msg, errFake), fake(binary, kubeMsg, errFake)},
+ {fake(binary, msg, nil), fake(binary, kubeMsg, nil)},
+ {fake(text, msg, nil), fake(binary, kubeMsg, nil)},
+ {fakeOther, fakeOther},
+ },
+ "base64.channel.k8s.io": {
+ {fake(binary, msg, errFake), fake(text, kubeMsgBase64, errFake)},
+ {fake(binary, msg, nil), fake(text, kubeMsgBase64, nil)},
+ {fake(text, msg, nil), fake(text, kubeMsgBase64, nil)},
+ {fakeOther, fakeOther},
+ },
+ "terminal.gitlab.com": {
+ {fake(binary, msg, errFake), fake(binary, msg, errFake)},
+ {fake(binary, msg, nil), fake(binary, msg, nil)},
+ {fake(text, msg, nil), fake(binary, msg, nil)},
+ {fakeOther, fakeOther},
+ },
+ "base64.terminal.gitlab.com": {
+ {fake(binary, msg, errFake), fake(text, msgBase64, errFake)},
+ {fake(binary, msg, nil), fake(text, msgBase64, nil)},
+ {fake(text, msg, nil), fake(text, msgBase64, nil)},
+ {fakeOther, fakeOther},
+ },
+ }
+
+ for subprotocol, cases := range testCases {
+ for i, tc := range cases {
+ actual := fake(0, nil, tc.input.err)
+ conn := Wrap(actual, subprotocol)
+ actual.err = conn.WriteMessage(tc.input.mt, tc.input.data)
+ requireEqualConn(t, tc.expected, actual, "%s test case %v", subprotocol, i)
+ }
+ }
+}