Welcome to mirror list, hosted at ThFree Co, Russian Federation.

mux.go « listenmux « grpc « internal - gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: f5a1ed67a97fe68fd9c6e2e5dd4e772fc290bf01 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
package listenmux

import (
	"bytes"
	"fmt"
	"io"
	"net"

	"google.golang.org/grpc/credentials"
)

// magicLen is the exact length, in bytes, of every handshaker's magic string.
// ServerHandshake reads at most this many bytes from each new connection to
// decide which handshaker, if any, should take it over.
const magicLen = 11

// Handshaker represents a multiplexed connection type.
type Handshaker interface {
	// Handshake is called with a valid Conn and AuthInfo when grpc-go
	// accepts a connection that presents the Magic() magic bytes. From the
	// point of view of grpc-go, it is part of the
	// credentials.TransportCredentials.ServerHandshake callback. The return
	// values of Handshake become the return values of
	// credentials.TransportCredentials.ServerHandshake. The magic bytes have
	// already been read from the Conn by the time Handshake is called.
	Handshake(net.Conn, credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error)

	// Magic returns the magic bytes that clients must send on the wire to
	// reach this handshaker. The string must be exactly magicLen (11) bytes
	// long; Mux.Register panics when given a handshaker whose magic has any
	// other length.
	Magic() string
}

// Compile-time assertion that *Mux satisfies credentials.TransportCredentials.
var _ credentials.TransportCredentials = (*Mux)(nil)

// Mux is a listener multiplexer that plugs into grpc-go as a
// "TransportCredentials" callback. Methods that Mux does not define itself are
// promoted from the embedded, wrapped credentials. Build one with New; the zero
// value carries no wrapped credentials.
type Mux struct {
	credentials.TransportCredentials

	// handshakers maps each registered magic string to the handshaker that
	// claims connections starting with it.
	handshakers map[string]Handshaker
}

// Register registers a handshaker. It is not thread-safe: it writes the
// handshakers map that ServerHandshake reads, so it must not run concurrently
// with ServerHandshake or with another Register call.
//
// Register panics if h.Magic() is not exactly magicLen bytes long, or if a
// handshaker with the same magic bytes is already registered. Both are
// programming errors: the first would make h unreachable, the second would
// silently replace an existing handshaker.
func (m *Mux) Register(h Handshaker) {
	magic := h.Magic()
	if len(magic) != magicLen {
		panic("wrong magic bytes length")
	}

	// A zero-value Mux has a nil map; allocate it lazily rather than failing
	// with an opaque nil-map assignment panic.
	if m.handshakers == nil {
		m.handshakers = make(map[string]Handshaker)
	}

	if _, exists := m.handshakers[magic]; exists {
		panic(fmt.Sprintf("handshaker for magic bytes %q already registered", magic))
	}

	m.handshakers[magic] = h
}

// New wraps the transport credentials tc in a *Mux with an empty handshaker
// set. Until handshakers are added with Register, every connection is handed
// back as a plain, non-multiplexed connection.
func New(tc credentials.TransportCredentials) *Mux {
	mux := &Mux{TransportCredentials: tc}
	mux.handshakers = make(map[string]Handshaker)
	return mux
}

// restoredConn is a net.Conn whose reads come from reader rather than from the
// embedded connection. It lets bytes that were peeked off a connection be put
// back in front of its stream; every method other than Read goes straight to
// the embedded Conn.
type restoredConn struct {
	net.Conn

	// reader serves all Read calls in place of the embedded Conn.
	reader io.Reader
}

// Read satisfies io.Reader by delegating to rc.reader, shadowing the embedded
// Conn's Read method.
func (rc *restoredConn) Read(p []byte) (int, error) {
	return rc.reader.Read(p)
}

// ServerHandshake is part of the credentials.TransportCredentials interface.
// It first runs the wrapped credentials' handshake, then reads up to magicLen
// bytes from the resulting connection. When those bytes match a registered
// handshaker's magic, the connection is handed to that handshaker and its
// results are returned as-is. Otherwise the bytes are replayed in front of the
// connection's stream, so the caller sees the connection unchanged.
func (m *Mux) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
	secured, info, err := m.TransportCredentials.ServerHandshake(conn)
	if err != nil {
		return nil, nil, fmt.Errorf("wrapped server handshake: %w", err)
	}

	// A prefix shorter than magicLen (the stream ended early) is not an error;
	// it simply cannot match any registered magic.
	prefix, err := io.ReadAll(io.LimitReader(secured, magicLen))
	if err != nil {
		return nil, nil, fmt.Errorf("peek network stream: %w", err)
	}

	handshaker, found := m.handshakers[string(prefix)]
	if !found {
		// Not multiplexed: put the consumed bytes back ahead of the stream.
		replay := io.MultiReader(bytes.NewReader(prefix), secured)
		return &restoredConn{Conn: secured, reader: replay}, info, nil
	}

	return handshaker.Handshake(secured, info)
}