1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
|
package listenmux
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net"
"google.golang.org/grpc/credentials"
)
const (
	// magicLen is the exact number of bytes ServerHandshake peeks from each
	// connection, and the exact length every Handshaker's Magic() must have.
	magicLen = 11
)
// Handshaker represents a multiplexed connection type.
type Handshaker interface {
	// Handshake is called with a valid Conn and AuthInfo when grpc-go
	// accepts a connection that presents the Magic() magic bytes. From the
	// point of view of grpc-go, it is part of the
	// credentials.TransportCredentials.ServerHandshake callback. The return
	// values of Handshake become the return values of
	// credentials.TransportCredentials.ServerHandshake.
	//
	// The Conn is the one produced by the wrapped transport credentials'
	// ServerHandshake, and the magic bytes have already been read from it:
	// the first byte Handshake reads is the one following the magic.
	Handshake(net.Conn, credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error)

	// Magic returns the magic bytes that clients must send on the wire to
	// reach this handshaker. The string must be exactly magicLen (11) bytes
	// long; Mux.Register panics if it has any other length.
	Magic() string
}
// Mux is a listener multiplexer that plugs into grpc-go as a "TransportCredentials" callback.
type Mux struct {
	// TransportCredentials is the wrapped set of credentials. Its
	// ServerHandshake always runs first; every method Mux does not define
	// itself is promoted from it.
	credentials.TransportCredentials

	// handshakers maps a magic-byte string (magicLen bytes long) to the
	// Handshaker that takes over connections presenting it.
	handshakers map[string]Handshaker
}

// Compile-time check that *Mux satisfies credentials.TransportCredentials.
var _ credentials.TransportCredentials = (*Mux)(nil)
// Register registers a handshaker. It is not thread-safe.
//
// h.Magic() must return exactly magicLen bytes. Any other length is a
// programming error, and Register panics. Registering a second handshaker
// with the same magic bytes replaces the first one.
func (m *Mux) Register(h Handshaker) {
	// Call Magic once so that the value validated is the value stored.
	magic := h.Magic()
	if len(magic) != magicLen {
		panic(fmt.Sprintf("wrong magic bytes length: %q has %d bytes, want %d", magic, len(magic), magicLen))
	}
	// Create the map lazily so that a zero-value Mux fails with a clear
	// error elsewhere rather than an opaque nil-map assignment panic here.
	if m.handshakers == nil {
		m.handshakers = make(map[string]Handshaker)
	}
	m.handshakers[magic] = h
}
// New builds a *Mux around the transport credentials tc. Until handshakers
// are added with Register, no connection is ever dispatched away from tc's
// own ServerHandshake result.
func New(tc credentials.TransportCredentials) *Mux {
	mux := new(Mux)
	mux.TransportCredentials = tc
	mux.handshakers = map[string]Handshaker{}
	return mux
}
// restoredConn is a net.Conn whose reads come from reader rather than
// directly from the embedded Conn. When a connection turns out not to be
// multiplexed, reader replays the bytes consumed while peeking and then
// continues with the rest of the stream. All other net.Conn methods are
// promoted from the embedded Conn.
type restoredConn struct {
	net.Conn
	reader io.Reader
}

// Read implements io.Reader by delegating to c.reader.
func (c *restoredConn) Read(p []byte) (int, error) {
	return c.reader.Read(p)
}
// ServerHandshake peeks the connection to determine whether the client
// wants to make a multiplexed connection. It is part of the
// TransportCredentials interface.
//
// It first runs the wrapped credentials' ServerHandshake. It then reads up to
// magicLen bytes from the resulting conn:
//   - If those bytes match a registered Handshaker's magic, the conn (with the
//     magic bytes already consumed) and AuthInfo are passed to that
//     Handshaker, and its results are returned unchanged.
//   - Otherwise the conn is returned wrapped, so that the peeked bytes are
//     replayed to the next reader before the rest of the stream.
//
// The peek blocks until magicLen bytes have arrived, the peer closes its
// write side, or a read fails (for example on a deadline set on conn by the
// caller).
func (m *Mux) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
	conn, authInfo, err := m.TransportCredentials.ServerHandshake(conn)
	if err != nil {
		return nil, nil, fmt.Errorf("wrapped server handshake: %w", err)
	}

	peeked, err := ioutil.ReadAll(io.LimitReader(conn, magicLen))
	if err != nil {
		// The wrapped conn never reaches the caller on this path, so close it
		// here instead of leaking the state the wrapped handshake built on top
		// of the raw connection. The peek error is the one worth reporting,
		// so the Close error is deliberately ignored.
		_ = conn.Close()
		return nil, nil, fmt.Errorf("peek network stream: %w", err)
	}

	if h, ok := m.handshakers[string(peeked)]; ok {
		return h.Handshake(conn, authInfo)
	}

	// Not a multiplexed connection: put the peeked bytes back in front of
	// the stream so the default protocol sees it intact.
	return &restoredConn{
		Conn:   conn,
		reader: io.MultiReader(bytes.NewReader(peeked), conn),
	}, authInfo, nil
}
|