Welcome to mirror list, hosted at ThFree Co, Russian Federation.

gitlab.com/gitlab-org/gitaly.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--NOTICE3995
-rw-r--r--go.mod1
-rw-r--r--go.sum8
-rw-r--r--internal/backchannel/backchannel.go60
-rw-r--r--internal/backchannel/backchannel_example_test.go139
-rw-r--r--internal/backchannel/backchannel_test.go223
-rw-r--r--internal/backchannel/client.go125
-rw-r--r--internal/backchannel/insecure.go40
-rw-r--r--internal/backchannel/registry.go51
-rw-r--r--internal/backchannel/server.go146
10 files changed, 4782 insertions, 6 deletions
diff --git a/NOTICE b/NOTICE
index ddd6bcfb9..261bb9fbd 100644
--- a/NOTICE
+++ b/NOTICE
@@ -5203,6 +5203,4001 @@ func TestLRU_Resize(t *testing.T) {
}
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.gitignore - github.com/hashicorp/yamux
+# Compiled Object files, Static and Dynamic libs (Shared Objects)
+*.o
+*.a
+*.so
+
+# Folders
+_obj
+_test
+
+# Architecture specific extensions/prefixes
+*.[568vq]
+[568vq].out
+
+*.cgo1.go
+*.cgo2.c
+_cgo_defun.c
+_cgo_gotypes.go
+_cgo_export.*
+
+_testmain.go
+
+*.exe
+*.test
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+LICENSE - github.com/hashicorp/yamux
+Mozilla Public License, version 2.0
+
+1. Definitions
+
+1.1. "Contributor"
+
+ means each individual or legal entity that creates, contributes to the
+ creation of, or owns Covered Software.
+
+1.2. "Contributor Version"
+
+ means the combination of the Contributions of others (if any) used by a
+ Contributor and that particular Contributor's Contribution.
+
+1.3. "Contribution"
+
+ means Covered Software of a particular Contributor.
+
+1.4. "Covered Software"
+
+ means Source Code Form to which the initial Contributor has attached the
+ notice in Exhibit A, the Executable Form of such Source Code Form, and
+ Modifications of such Source Code Form, in each case including portions
+ thereof.
+
+1.5. "Incompatible With Secondary Licenses"
+ means
+
+ a. that the initial Contributor has attached the notice described in
+ Exhibit B to the Covered Software; or
+
+ b. that the Covered Software was made available under the terms of
+ version 1.1 or earlier of the License, but not also under the terms of
+ a Secondary License.
+
+1.6. "Executable Form"
+
+ means any form of the work other than Source Code Form.
+
+1.7. "Larger Work"
+
+ means a work that combines Covered Software with other material, in a
+ separate file or files, that is not Covered Software.
+
+1.8. "License"
+
+ means this document.
+
+1.9. "Licensable"
+
+ means having the right to grant, to the maximum extent possible, whether
+ at the time of the initial grant or subsequently, any and all of the
+ rights conveyed by this License.
+
+1.10. "Modifications"
+
+ means any of the following:
+
+ a. any file in Source Code Form that results from an addition to,
+ deletion from, or modification of the contents of Covered Software; or
+
+ b. any new file in Source Code Form that contains any Covered Software.
+
+1.11. "Patent Claims" of a Contributor
+
+ means any patent claim(s), including without limitation, method,
+ process, and apparatus claims, in any patent Licensable by such
+ Contributor that would be infringed, but for the grant of the License,
+ by the making, using, selling, offering for sale, having made, import,
+ or transfer of either its Contributions or its Contributor Version.
+
+1.12. "Secondary License"
+
+ means either the GNU General Public License, Version 2.0, the GNU Lesser
+ General Public License, Version 2.1, the GNU Affero General Public
+ License, Version 3.0, or any later versions of those licenses.
+
+1.13. "Source Code Form"
+
+ means the form of the work preferred for making modifications.
+
+1.14. "You" (or "Your")
+
+ means an individual or a legal entity exercising rights under this
+ License. For legal entities, "You" includes any entity that controls, is
+ controlled by, or is under common control with You. For purposes of this
+ definition, "control" means (a) the power, direct or indirect, to cause
+ the direction or management of such entity, whether by contract or
+ otherwise, or (b) ownership of more than fifty percent (50%) of the
+ outstanding shares or beneficial ownership of such entity.
+
+
+2. License Grants and Conditions
+
+2.1. Grants
+
+ Each Contributor hereby grants You a world-wide, royalty-free,
+ non-exclusive license:
+
+ a. under intellectual property rights (other than patent or trademark)
+ Licensable by such Contributor to use, reproduce, make available,
+ modify, display, perform, distribute, and otherwise exploit its
+ Contributions, either on an unmodified basis, with Modifications, or
+ as part of a Larger Work; and
+
+ b. under Patent Claims of such Contributor to make, use, sell, offer for
+ sale, have made, import, and otherwise transfer either its
+ Contributions or its Contributor Version.
+
+2.2. Effective Date
+
+ The licenses granted in Section 2.1 with respect to any Contribution
+ become effective for each Contribution on the date the Contributor first
+ distributes such Contribution.
+
+2.3. Limitations on Grant Scope
+
+ The licenses granted in this Section 2 are the only rights granted under
+ this License. No additional rights or licenses will be implied from the
+ distribution or licensing of Covered Software under this License.
+ Notwithstanding Section 2.1(b) above, no patent license is granted by a
+ Contributor:
+
+ a. for any code that a Contributor has removed from Covered Software; or
+
+ b. for infringements caused by: (i) Your and any other third party's
+ modifications of Covered Software, or (ii) the combination of its
+ Contributions with other software (except as part of its Contributor
+ Version); or
+
+ c. under Patent Claims infringed by Covered Software in the absence of
+ its Contributions.
+
+ This License does not grant any rights in the trademarks, service marks,
+ or logos of any Contributor (except as may be necessary to comply with
+ the notice requirements in Section 3.4).
+
+2.4. Subsequent Licenses
+
+ No Contributor makes additional grants as a result of Your choice to
+ distribute the Covered Software under a subsequent version of this
+ License (see Section 10.2) or under the terms of a Secondary License (if
+ permitted under the terms of Section 3.3).
+
+2.5. Representation
+
+ Each Contributor represents that the Contributor believes its
+ Contributions are its original creation(s) or it has sufficient rights to
+ grant the rights to its Contributions conveyed by this License.
+
+2.6. Fair Use
+
+ This License is not intended to limit any rights You have under
+ applicable copyright doctrines of fair use, fair dealing, or other
+ equivalents.
+
+2.7. Conditions
+
+ Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in
+ Section 2.1.
+
+
+3. Responsibilities
+
+3.1. Distribution of Source Form
+
+ All distribution of Covered Software in Source Code Form, including any
+ Modifications that You create or to which You contribute, must be under
+ the terms of this License. You must inform recipients that the Source
+ Code Form of the Covered Software is governed by the terms of this
+ License, and how they can obtain a copy of this License. You may not
+ attempt to alter or restrict the recipients' rights in the Source Code
+ Form.
+
+3.2. Distribution of Executable Form
+
+ If You distribute Covered Software in Executable Form then:
+
+ a. such Covered Software must also be made available in Source Code Form,
+ as described in Section 3.1, and You must inform recipients of the
+ Executable Form how they can obtain a copy of such Source Code Form by
+ reasonable means in a timely manner, at a charge no more than the cost
+ of distribution to the recipient; and
+
+ b. You may distribute such Executable Form under the terms of this
+ License, or sublicense it under different terms, provided that the
+ license for the Executable Form does not attempt to limit or alter the
+ recipients' rights in the Source Code Form under this License.
+
+3.3. Distribution of a Larger Work
+
+ You may create and distribute a Larger Work under terms of Your choice,
+ provided that You also comply with the requirements of this License for
+ the Covered Software. If the Larger Work is a combination of Covered
+ Software with a work governed by one or more Secondary Licenses, and the
+ Covered Software is not Incompatible With Secondary Licenses, this
+ License permits You to additionally distribute such Covered Software
+ under the terms of such Secondary License(s), so that the recipient of
+ the Larger Work may, at their option, further distribute the Covered
+ Software under the terms of either this License or such Secondary
+ License(s).
+
+3.4. Notices
+
+ You may not remove or alter the substance of any license notices
+ (including copyright notices, patent notices, disclaimers of warranty, or
+ limitations of liability) contained within the Source Code Form of the
+ Covered Software, except that You may alter any license notices to the
+ extent required to remedy known factual inaccuracies.
+
+3.5. Application of Additional Terms
+
+ You may choose to offer, and to charge a fee for, warranty, support,
+ indemnity or liability obligations to one or more recipients of Covered
+ Software. However, You may do so only on Your own behalf, and not on
+ behalf of any Contributor. You must make it absolutely clear that any
+ such warranty, support, indemnity, or liability obligation is offered by
+ You alone, and You hereby agree to indemnify every Contributor for any
+ liability incurred by such Contributor as a result of warranty, support,
+ indemnity or liability terms You offer. You may include additional
+ disclaimers of warranty and limitations of liability specific to any
+ jurisdiction.
+
+4. Inability to Comply Due to Statute or Regulation
+
+ If it is impossible for You to comply with any of the terms of this License
+ with respect to some or all of the Covered Software due to statute,
+ judicial order, or regulation then You must: (a) comply with the terms of
+ this License to the maximum extent possible; and (b) describe the
+ limitations and the code they affect. Such description must be placed in a
+ text file included with all distributions of the Covered Software under
+ this License. Except to the extent prohibited by statute or regulation,
+ such description must be sufficiently detailed for a recipient of ordinary
+ skill to be able to understand it.
+
+5. Termination
+
+5.1. The rights granted under this License will terminate automatically if You
+ fail to comply with any of its terms. However, if You become compliant,
+ then the rights granted under this License from a particular Contributor
+ are reinstated (a) provisionally, unless and until such Contributor
+ explicitly and finally terminates Your grants, and (b) on an ongoing
+ basis, if such Contributor fails to notify You of the non-compliance by
+ some reasonable means prior to 60 days after You have come back into
+ compliance. Moreover, Your grants from a particular Contributor are
+ reinstated on an ongoing basis if such Contributor notifies You of the
+ non-compliance by some reasonable means, this is the first time You have
+ received notice of non-compliance with this License from such
+ Contributor, and You become compliant prior to 30 days after Your receipt
+ of the notice.
+
+5.2. If You initiate litigation against any entity by asserting a patent
+ infringement claim (excluding declaratory judgment actions,
+ counter-claims, and cross-claims) alleging that a Contributor Version
+ directly or indirectly infringes any patent, then the rights granted to
+ You by any and all Contributors for the Covered Software under Section
+ 2.1 of this License shall terminate.
+
+5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user
+ license agreements (excluding distributors and resellers) which have been
+ validly granted by You or Your distributors under this License prior to
+ termination shall survive termination.
+
+6. Disclaimer of Warranty
+
+ Covered Software is provided under this License on an "as is" basis,
+ without warranty of any kind, either expressed, implied, or statutory,
+ including, without limitation, warranties that the Covered Software is free
+ of defects, merchantable, fit for a particular purpose or non-infringing.
+ The entire risk as to the quality and performance of the Covered Software
+ is with You. Should any Covered Software prove defective in any respect,
+ You (not any Contributor) assume the cost of any necessary servicing,
+ repair, or correction. This disclaimer of warranty constitutes an essential
+ part of this License. No use of any Covered Software is authorized under
+ this License except under this disclaimer.
+
+7. Limitation of Liability
+
+ Under no circumstances and under no legal theory, whether tort (including
+ negligence), contract, or otherwise, shall any Contributor, or anyone who
+ distributes Covered Software as permitted above, be liable to You for any
+ direct, indirect, special, incidental, or consequential damages of any
+ character including, without limitation, damages for lost profits, loss of
+ goodwill, work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses, even if such party shall have been
+ informed of the possibility of such damages. This limitation of liability
+ shall not apply to liability for death or personal injury resulting from
+ such party's negligence to the extent applicable law prohibits such
+ limitation. Some jurisdictions do not allow the exclusion or limitation of
+ incidental or consequential damages, so this exclusion and limitation may
+ not apply to You.
+
+8. Litigation
+
+ Any litigation relating to this License may be brought only in the courts
+ of a jurisdiction where the defendant maintains its principal place of
+ business and such litigation shall be governed by laws of that
+ jurisdiction, without reference to its conflict-of-law provisions. Nothing
+ in this Section shall prevent a party's ability to bring cross-claims or
+ counter-claims.
+
+9. Miscellaneous
+
+ This License represents the complete agreement concerning the subject
+ matter hereof. If any provision of this License is held to be
+ unenforceable, such provision shall be reformed only to the extent
+ necessary to make it enforceable. Any law or regulation which provides that
+ the language of a contract shall be construed against the drafter shall not
+ be used to construe this License against a Contributor.
+
+
+10. Versions of the License
+
+10.1. New Versions
+
+ Mozilla Foundation is the license steward. Except as provided in Section
+ 10.3, no one other than the license steward has the right to modify or
+ publish new versions of this License. Each version will be given a
+ distinguishing version number.
+
+10.2. Effect of New Versions
+
+ You may distribute the Covered Software under the terms of the version
+ of the License under which You originally received the Covered Software,
+ or under the terms of any subsequent version published by the license
+ steward.
+
+10.3. Modified Versions
+
+ If you create software not governed by this License, and you want to
+ create a new license for such software, you may create and use a
+ modified version of this License if you rename the license and remove
+ any references to the name of the license steward (except to note that
+ such modified license differs from this License).
+
+10.4. Distributing Source Code Form that is Incompatible With Secondary
+ Licenses If You choose to distribute Source Code Form that is
+ Incompatible With Secondary Licenses under the terms of this version of
+ the License, the notice described in Exhibit B of this License must be
+ attached.
+
+Exhibit A - Source Code Form License Notice
+
+ This Source Code Form is subject to the
+ terms of the Mozilla Public License, v.
+ 2.0. If a copy of the MPL was not
+ distributed with this file, You can
+ obtain one at
+ http://mozilla.org/MPL/2.0/.
+
+If it is not possible or desirable to put the notice in a particular file,
+then You may include the notice in a location (such as a LICENSE file in a
+relevant directory) where a recipient would be likely to look for such a
+notice.
+
+You may add additional accurate notices of copyright ownership.
+
+Exhibit B - "Incompatible With Secondary Licenses" Notice
+
+ This Source Code Form is "Incompatible
+ With Secondary Licenses", as defined by
+ the Mozilla Public License, v. 2.0.
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+README.md - github.com/hashicorp/yamux
+# Yamux
+
+Yamux (Yet another Multiplexer) is a multiplexing library for Golang.
+It relies on an underlying connection to provide reliability
+and ordering, such as TCP or Unix domain sockets, and provides
+stream-oriented multiplexing. It is inspired by SPDY but is not
+interoperable with it.
+
+Yamux features include:
+
+* Bi-directional streams
+ * Streams can be opened by either client or server
+ * Useful for NAT traversal
+ * Server-side push support
+* Flow control
+ * Avoid starvation
+ * Back-pressure to prevent overwhelming a receiver
+* Keep Alives
+ * Enables persistent connections over a load balancer
+* Efficient
+ * Enables thousands of logical streams with low overhead
+
+## Documentation
+
+For complete documentation, see the associated [Godoc](http://godoc.org/github.com/hashicorp/yamux).
+
+## Specification
+
+The full specification for Yamux is provided in the `spec.md` file.
+It can be used as a guide to implementors of interoperable libraries.
+
+## Usage
+
+Using Yamux is remarkably simple:
+
+```go
+
+func client() {
+ // Get a TCP connection
+ conn, err := net.Dial(...)
+ if err != nil {
+ panic(err)
+ }
+
+ // Setup client side of yamux
+ session, err := yamux.Client(conn, nil)
+ if err != nil {
+ panic(err)
+ }
+
+ // Open a new stream
+ stream, err := session.Open()
+ if err != nil {
+ panic(err)
+ }
+
+ // Stream implements net.Conn
+ stream.Write([]byte("ping"))
+}
+
+func server() {
+ // Accept a TCP connection
+ conn, err := listener.Accept()
+ if err != nil {
+ panic(err)
+ }
+
+ // Setup server side of yamux
+ session, err := yamux.Server(conn, nil)
+ if err != nil {
+ panic(err)
+ }
+
+ // Accept a stream
+ stream, err := session.Accept()
+ if err != nil {
+ panic(err)
+ }
+
+ // Listen for a message
+ buf := make([]byte, 4)
+ stream.Read(buf)
+}
+
+```
+
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+addr.go - github.com/hashicorp/yamux
+package yamux
+
+import (
+ "fmt"
+ "net"
+)
+
+// hasAddr is used to get the address from the underlying connection
+type hasAddr interface {
+ LocalAddr() net.Addr
+ RemoteAddr() net.Addr
+}
+
+// yamuxAddr is used when we cannot get the underlying address
+type yamuxAddr struct {
+ Addr string
+}
+
+func (*yamuxAddr) Network() string {
+ return "yamux"
+}
+
+func (y *yamuxAddr) String() string {
+ return fmt.Sprintf("yamux:%s", y.Addr)
+}
+
+// Addr is used to get the address of the listener.
+func (s *Session) Addr() net.Addr {
+ return s.LocalAddr()
+}
+
+// LocalAddr is used to get the local address of the
+// underlying connection.
+func (s *Session) LocalAddr() net.Addr {
+ addr, ok := s.conn.(hasAddr)
+ if !ok {
+ return &yamuxAddr{"local"}
+ }
+ return addr.LocalAddr()
+}
+
+// RemoteAddr is used to get the address of remote end
+// of the underlying connection
+func (s *Session) RemoteAddr() net.Addr {
+ addr, ok := s.conn.(hasAddr)
+ if !ok {
+ return &yamuxAddr{"remote"}
+ }
+ return addr.RemoteAddr()
+}
+
+// LocalAddr returns the local address
+func (s *Stream) LocalAddr() net.Addr {
+ return s.session.LocalAddr()
+}
+
+// RemoteAddr returns the remote address
+func (s *Stream) RemoteAddr() net.Addr {
+ return s.session.RemoteAddr()
+}
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+bench_test.go - github.com/hashicorp/yamux
+package yamux
+
+import (
+ "io"
+ "io/ioutil"
+ "testing"
+)
+
+func BenchmarkPing(b *testing.B) {
+ client, server := testClientServer()
+ defer func() {
+ client.Close()
+ server.Close()
+ }()
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ rtt, err := client.Ping()
+ if err != nil {
+ b.Fatalf("err: %v", err)
+ }
+ if rtt == 0 {
+ b.Fatalf("bad: %v", rtt)
+ }
+ }
+}
+
+func BenchmarkAccept(b *testing.B) {
+ client, server := testClientServer()
+ defer func() {
+ client.Close()
+ server.Close()
+ }()
+
+ doneCh := make(chan struct{})
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ go func() {
+ defer close(doneCh)
+
+ for i := 0; i < b.N; i++ {
+ stream, err := server.AcceptStream()
+ if err != nil {
+ return
+ }
+ stream.Close()
+ }
+ }()
+
+ for i := 0; i < b.N; i++ {
+ stream, err := client.Open()
+ if err != nil {
+ b.Fatalf("err: %v", err)
+ }
+ stream.Close()
+ }
+ <-doneCh
+}
+
+func BenchmarkSendRecv32(b *testing.B) {
+ const payloadSize = 32
+ benchmarkSendRecv(b, payloadSize, payloadSize)
+}
+
+func BenchmarkSendRecv64(b *testing.B) {
+ const payloadSize = 64
+ benchmarkSendRecv(b, payloadSize, payloadSize)
+}
+
+func BenchmarkSendRecv128(b *testing.B) {
+ const payloadSize = 128
+ benchmarkSendRecv(b, payloadSize, payloadSize)
+}
+
+func BenchmarkSendRecv256(b *testing.B) {
+ const payloadSize = 256
+ benchmarkSendRecv(b, payloadSize, payloadSize)
+}
+
+func BenchmarkSendRecv512(b *testing.B) {
+ const payloadSize = 512
+ benchmarkSendRecv(b, payloadSize, payloadSize)
+}
+
+func BenchmarkSendRecv1024(b *testing.B) {
+ const payloadSize = 1024
+ benchmarkSendRecv(b, payloadSize, payloadSize)
+}
+
+func BenchmarkSendRecv2048(b *testing.B) {
+ const payloadSize = 2048
+ benchmarkSendRecv(b, payloadSize, payloadSize)
+}
+
+func BenchmarkSendRecv4096(b *testing.B) {
+ const payloadSize = 4096
+ benchmarkSendRecv(b, payloadSize, payloadSize)
+}
+
+func BenchmarkSendRecvLarge(b *testing.B) {
+ const sendSize = 512 * 1024 * 1024 //512 MB
+ const recvSize = 4 * 1024 //4 KB
+ benchmarkSendRecv(b, sendSize, recvSize)
+}
+
+func benchmarkSendRecv(b *testing.B, sendSize, recvSize int) {
+ client, server := testClientServer()
+ defer func() {
+ client.Close()
+ server.Close()
+ }()
+
+ sendBuf := make([]byte, sendSize)
+ recvBuf := make([]byte, recvSize)
+ doneCh := make(chan struct{})
+
+ b.SetBytes(int64(sendSize))
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ go func() {
+ defer close(doneCh)
+
+ stream, err := server.AcceptStream()
+ if err != nil {
+ return
+ }
+ defer stream.Close()
+
+ switch {
+ case sendSize == recvSize:
+ for i := 0; i < b.N; i++ {
+ if _, err := stream.Read(recvBuf); err != nil {
+ b.Fatalf("err: %v", err)
+ }
+ }
+
+ case recvSize > sendSize:
+ b.Fatalf("bad test case; recvSize was: %d and sendSize was: %d, but recvSize must be <= sendSize!", recvSize, sendSize)
+
+ default:
+ chunks := sendSize / recvSize
+ for i := 0; i < b.N; i++ {
+ for j := 0; j < chunks; j++ {
+ if _, err := stream.Read(recvBuf); err != nil {
+ b.Fatalf("err: %v", err)
+ }
+ }
+ }
+ }
+ }()
+
+ stream, err := client.Open()
+ if err != nil {
+ b.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+
+ for i := 0; i < b.N; i++ {
+ if _, err := stream.Write(sendBuf); err != nil {
+ b.Fatalf("err: %v", err)
+ }
+ }
+ <-doneCh
+}
+
+func BenchmarkSendRecvParallel32(b *testing.B) {
+ const payloadSize = 32
+ benchmarkSendRecvParallel(b, payloadSize)
+}
+
+func BenchmarkSendRecvParallel64(b *testing.B) {
+ const payloadSize = 64
+ benchmarkSendRecvParallel(b, payloadSize)
+}
+
+func BenchmarkSendRecvParallel128(b *testing.B) {
+ const payloadSize = 128
+ benchmarkSendRecvParallel(b, payloadSize)
+}
+
+func BenchmarkSendRecvParallel256(b *testing.B) {
+ const payloadSize = 256
+ benchmarkSendRecvParallel(b, payloadSize)
+}
+
+func BenchmarkSendRecvParallel512(b *testing.B) {
+ const payloadSize = 512
+ benchmarkSendRecvParallel(b, payloadSize)
+}
+
+func BenchmarkSendRecvParallel1024(b *testing.B) {
+ const payloadSize = 1024
+ benchmarkSendRecvParallel(b, payloadSize)
+}
+
+func BenchmarkSendRecvParallel2048(b *testing.B) {
+ const payloadSize = 2048
+ benchmarkSendRecvParallel(b, payloadSize)
+}
+
+func BenchmarkSendRecvParallel4096(b *testing.B) {
+ const payloadSize = 4096
+ benchmarkSendRecvParallel(b, payloadSize)
+}
+
+func benchmarkSendRecvParallel(b *testing.B, sendSize int) {
+ client, server := testClientServer()
+ defer func() {
+ client.Close()
+ server.Close()
+ }()
+
+ sendBuf := make([]byte, sendSize)
+ discarder := ioutil.Discard.(io.ReaderFrom)
+ b.SetBytes(int64(sendSize))
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ b.RunParallel(func(pb *testing.PB) {
+ doneCh := make(chan struct{})
+
+ go func() {
+ defer close(doneCh)
+
+ stream, err := server.AcceptStream()
+ if err != nil {
+ return
+ }
+ defer stream.Close()
+
+ if _, err := discarder.ReadFrom(stream); err != nil {
+ b.Fatalf("err: %v", err)
+ }
+ }()
+
+ stream, err := client.Open()
+ if err != nil {
+ b.Fatalf("err: %v", err)
+ }
+
+ for pb.Next() {
+ if _, err := stream.Write(sendBuf); err != nil {
+ b.Fatalf("err: %v", err)
+ }
+ }
+
+ stream.Close()
+ <-doneCh
+ })
+}
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+const.go - github.com/hashicorp/yamux
+package yamux
+
+import (
+ "encoding/binary"
+ "fmt"
+)
+
+var (
+ // ErrInvalidVersion means we received a frame with an
+ // invalid version
+ ErrInvalidVersion = fmt.Errorf("invalid protocol version")
+
+ // ErrInvalidMsgType means we received a frame with an
+ // invalid message type
+ ErrInvalidMsgType = fmt.Errorf("invalid msg type")
+
+ // ErrSessionShutdown is used if there is a shutdown during
+ // an operation
+ ErrSessionShutdown = fmt.Errorf("session shutdown")
+
+ // ErrStreamsExhausted is returned if we have no more
+ // stream ids to issue
+ ErrStreamsExhausted = fmt.Errorf("streams exhausted")
+
+ // ErrDuplicateStream is used if a duplicate stream is
+ // opened inbound
+ ErrDuplicateStream = fmt.Errorf("duplicate stream initiated")
+
+ // ErrReceiveWindowExceeded indicates the window was exceeded
+ ErrRecvWindowExceeded = fmt.Errorf("recv window exceeded")
+
+ // ErrTimeout is used when we reach an IO deadline
+ ErrTimeout = fmt.Errorf("i/o deadline reached")
+
+ // ErrStreamClosed is returned when using a closed stream
+ ErrStreamClosed = fmt.Errorf("stream closed")
+
+ // ErrUnexpectedFlag is set when we get an unexpected flag
+ ErrUnexpectedFlag = fmt.Errorf("unexpected flag")
+
+ // ErrRemoteGoAway is used when we get a go away from the other side
+ ErrRemoteGoAway = fmt.Errorf("remote end is not accepting connections")
+
+ // ErrConnectionReset is sent if a stream is reset. This can happen
+ // if the backlog is exceeded, or if there was a remote GoAway.
+ ErrConnectionReset = fmt.Errorf("connection reset")
+
+ // ErrConnectionWriteTimeout indicates that we hit the "safety valve"
+ // timeout writing to the underlying stream connection.
+ ErrConnectionWriteTimeout = fmt.Errorf("connection write timeout")
+
+ // ErrKeepAliveTimeout is sent if a missed keepalive caused the stream close
+ ErrKeepAliveTimeout = fmt.Errorf("keepalive timeout")
+)
+
+const (
+ // protoVersion is the only version we support
+ protoVersion uint8 = 0
+)
+
+const (
+ // Data is used for data frames. They are followed
+ // by length bytes worth of payload.
+ typeData uint8 = iota
+
+ // WindowUpdate is used to change the window of
+ // a given stream. The length indicates the delta
+ // update to the window.
+ typeWindowUpdate
+
+ // Ping is sent as a keep-alive or to measure
+ // the RTT. The StreamID and Length value are echoed
+ // back in the response.
+ typePing
+
+ // GoAway is sent to terminate a session. The StreamID
+ // should be 0 and the length is an error code.
+ typeGoAway
+)
+
+const (
+ // SYN is sent to signal a new stream. May
+ // be sent with a data payload
+ flagSYN uint16 = 1 << iota
+
+ // ACK is sent to acknowledge a new stream. May
+ // be sent with a data payload
+ flagACK
+
+ // FIN is sent to half-close the given stream.
+ // May be sent with a data payload.
+ flagFIN
+
+ // RST is used to hard close a given stream.
+ flagRST
+)
+
+const (
+ // initialStreamWindow is the initial stream window size
+ initialStreamWindow uint32 = 256 * 1024
+)
+
+const (
+ // goAwayNormal is sent on a normal termination
+ goAwayNormal uint32 = iota
+
+ // goAwayProtoErr sent on a protocol error
+ goAwayProtoErr
+
+ // goAwayInternalErr sent on an internal error
+ goAwayInternalErr
+)
+
+const (
+ sizeOfVersion = 1
+ sizeOfType = 1
+ sizeOfFlags = 2
+ sizeOfStreamID = 4
+ sizeOfLength = 4
+ headerSize = sizeOfVersion + sizeOfType + sizeOfFlags +
+ sizeOfStreamID + sizeOfLength
+)
+
+type header []byte
+
+func (h header) Version() uint8 {
+ return h[0]
+}
+
+func (h header) MsgType() uint8 {
+ return h[1]
+}
+
+func (h header) Flags() uint16 {
+ return binary.BigEndian.Uint16(h[2:4])
+}
+
+func (h header) StreamID() uint32 {
+ return binary.BigEndian.Uint32(h[4:8])
+}
+
+func (h header) Length() uint32 {
+ return binary.BigEndian.Uint32(h[8:12])
+}
+
+func (h header) String() string {
+ return fmt.Sprintf("Vsn:%d Type:%d Flags:%d StreamID:%d Length:%d",
+ h.Version(), h.MsgType(), h.Flags(), h.StreamID(), h.Length())
+}
+
+func (h header) encode(msgType uint8, flags uint16, streamID uint32, length uint32) {
+ h[0] = protoVersion
+ h[1] = msgType
+ binary.BigEndian.PutUint16(h[2:4], flags)
+ binary.BigEndian.PutUint32(h[4:8], streamID)
+ binary.BigEndian.PutUint32(h[8:12], length)
+}
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+const_test.go - github.com/hashicorp/yamux
+package yamux
+
+import (
+ "testing"
+)
+
+func TestConst(t *testing.T) {
+ if protoVersion != 0 {
+ t.Fatalf("bad: %v", protoVersion)
+ }
+
+ if typeData != 0 {
+ t.Fatalf("bad: %v", typeData)
+ }
+ if typeWindowUpdate != 1 {
+ t.Fatalf("bad: %v", typeWindowUpdate)
+ }
+ if typePing != 2 {
+ t.Fatalf("bad: %v", typePing)
+ }
+ if typeGoAway != 3 {
+ t.Fatalf("bad: %v", typeGoAway)
+ }
+
+ if flagSYN != 1 {
+ t.Fatalf("bad: %v", flagSYN)
+ }
+ if flagACK != 2 {
+ t.Fatalf("bad: %v", flagACK)
+ }
+ if flagFIN != 4 {
+ t.Fatalf("bad: %v", flagFIN)
+ }
+ if flagRST != 8 {
+ t.Fatalf("bad: %v", flagRST)
+ }
+
+ if goAwayNormal != 0 {
+ t.Fatalf("bad: %v", goAwayNormal)
+ }
+ if goAwayProtoErr != 1 {
+ t.Fatalf("bad: %v", goAwayProtoErr)
+ }
+ if goAwayInternalErr != 2 {
+ t.Fatalf("bad: %v", goAwayInternalErr)
+ }
+
+ if headerSize != 12 {
+ t.Fatalf("bad header size")
+ }
+}
+
+func TestEncodeDecode(t *testing.T) {
+ hdr := header(make([]byte, headerSize))
+ hdr.encode(typeWindowUpdate, flagACK|flagRST, 1234, 4321)
+
+ if hdr.Version() != protoVersion {
+ t.Fatalf("bad: %v", hdr)
+ }
+ if hdr.MsgType() != typeWindowUpdate {
+ t.Fatalf("bad: %v", hdr)
+ }
+ if hdr.Flags() != flagACK|flagRST {
+ t.Fatalf("bad: %v", hdr)
+ }
+ if hdr.StreamID() != 1234 {
+ t.Fatalf("bad: %v", hdr)
+ }
+ if hdr.Length() != 4321 {
+ t.Fatalf("bad: %v", hdr)
+ }
+}
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+go.mod - github.com/hashicorp/yamux
+module github.com/hashicorp/yamux
+
+go 1.15
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+mux.go - github.com/hashicorp/yamux
+package yamux
+
+import (
+ "fmt"
+ "io"
+ "log"
+ "os"
+ "time"
+)
+
+// Config is used to tune the Yamux session
+type Config struct {
+ // AcceptBacklog is used to limit how many streams may be
+ // waiting an accept.
+ AcceptBacklog int
+
+ // EnableKeepalive is used to do a period keep alive
+ // messages using a ping.
+ EnableKeepAlive bool
+
+ // KeepAliveInterval is how often to perform the keep alive
+ KeepAliveInterval time.Duration
+
+ // ConnectionWriteTimeout is meant to be a "safety valve" timeout after
+ // we which will suspect a problem with the underlying connection and
+ // close it. This is only applied to writes, where's there's generally
+ // an expectation that things will move along quickly.
+ ConnectionWriteTimeout time.Duration
+
+ // MaxStreamWindowSize is used to control the maximum
+ // window size that we allow for a stream.
+ MaxStreamWindowSize uint32
+
+ // StreamCloseTimeout is the maximum time that a stream will allowed to
+ // be in a half-closed state when `Close` is called before forcibly
+ // closing the connection. Forcibly closed connections will empty the
+ // receive buffer, drop any future packets received for that stream,
+ // and send a RST to the remote side.
+ StreamCloseTimeout time.Duration
+
+ // LogOutput is used to control the log destination. Either Logger or
+ // LogOutput can be set, not both.
+ LogOutput io.Writer
+
+ // Logger is used to pass in the logger to be used. Either Logger or
+ // LogOutput can be set, not both.
+ Logger *log.Logger
+}
+
+// DefaultConfig is used to return a default configuration
+func DefaultConfig() *Config {
+ return &Config{
+ AcceptBacklog: 256,
+ EnableKeepAlive: true,
+ KeepAliveInterval: 30 * time.Second,
+ ConnectionWriteTimeout: 10 * time.Second,
+ MaxStreamWindowSize: initialStreamWindow,
+ StreamCloseTimeout: 5 * time.Minute,
+ LogOutput: os.Stderr,
+ }
+}
+
+// VerifyConfig is used to verify the sanity of configuration
+func VerifyConfig(config *Config) error {
+ if config.AcceptBacklog <= 0 {
+ return fmt.Errorf("backlog must be positive")
+ }
+ if config.KeepAliveInterval == 0 {
+ return fmt.Errorf("keep-alive interval must be positive")
+ }
+ if config.MaxStreamWindowSize < initialStreamWindow {
+ return fmt.Errorf("MaxStreamWindowSize must be larger than %d", initialStreamWindow)
+ }
+ if config.LogOutput != nil && config.Logger != nil {
+ return fmt.Errorf("both Logger and LogOutput may not be set, select one")
+ } else if config.LogOutput == nil && config.Logger == nil {
+ return fmt.Errorf("one of Logger or LogOutput must be set, select one")
+ }
+ return nil
+}
+
+// Server is used to initialize a new server-side connection.
+// There must be at most one server-side connection. If a nil config is
+// provided, the DefaultConfiguration will be used.
+func Server(conn io.ReadWriteCloser, config *Config) (*Session, error) {
+ if config == nil {
+ config = DefaultConfig()
+ }
+ if err := VerifyConfig(config); err != nil {
+ return nil, err
+ }
+ return newSession(config, conn, false), nil
+}
+
+// Client is used to initialize a new client-side connection.
+// There must be at most one client-side connection.
+func Client(conn io.ReadWriteCloser, config *Config) (*Session, error) {
+ if config == nil {
+ config = DefaultConfig()
+ }
+
+ if err := VerifyConfig(config); err != nil {
+ return nil, err
+ }
+ return newSession(config, conn, true), nil
+}
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+session.go - github.com/hashicorp/yamux
+package yamux
+
+import (
+ "bufio"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "math"
+ "net"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+// Session is used to wrap a reliable ordered connection and to
+// multiplex it into multiple streams.
+type Session struct {
+ // remoteGoAway indicates the remote side does
+ // not want futher connections. Must be first for alignment.
+ remoteGoAway int32
+
+ // localGoAway indicates that we should stop
+ // accepting futher connections. Must be first for alignment.
+ localGoAway int32
+
+ // nextStreamID is the next stream we should
+ // send. This depends if we are a client/server.
+ nextStreamID uint32
+
+ // config holds our configuration
+ config *Config
+
+ // logger is used for our logs
+ logger *log.Logger
+
+ // conn is the underlying connection
+ conn io.ReadWriteCloser
+
+ // bufRead is a buffered reader
+ bufRead *bufio.Reader
+
+ // pings is used to track inflight pings
+ pings map[uint32]chan struct{}
+ pingID uint32
+ pingLock sync.Mutex
+
+ // streams maps a stream id to a stream, and inflight has an entry
+ // for any outgoing stream that has not yet been established. Both are
+ // protected by streamLock.
+ streams map[uint32]*Stream
+ inflight map[uint32]struct{}
+ streamLock sync.Mutex
+
+ // synCh acts like a semaphore. It is sized to the AcceptBacklog which
+ // is assumed to be symmetric between the client and server. This allows
+ // the client to avoid exceeding the backlog and instead blocks the open.
+ synCh chan struct{}
+
+ // acceptCh is used to pass ready streams to the client
+ acceptCh chan *Stream
+
+ // sendCh is used to mark a stream as ready to send,
+ // or to send a header out directly.
+ sendCh chan sendReady
+
+ // recvDoneCh is closed when recv() exits to avoid a race
+ // between stream registration and stream shutdown
+ recvDoneCh chan struct{}
+
+ // shutdown is used to safely close a session
+ shutdown bool
+ shutdownErr error
+ shutdownCh chan struct{}
+ shutdownLock sync.Mutex
+}
+
+// sendReady is used to either mark a stream as ready
+// or to directly send a header
+type sendReady struct {
+ Hdr []byte
+ Body io.Reader
+ Err chan error
+}
+
+// newSession is used to construct a new session
+func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
+ logger := config.Logger
+ if logger == nil {
+ logger = log.New(config.LogOutput, "", log.LstdFlags)
+ }
+
+ s := &Session{
+ config: config,
+ logger: logger,
+ conn: conn,
+ bufRead: bufio.NewReader(conn),
+ pings: make(map[uint32]chan struct{}),
+ streams: make(map[uint32]*Stream),
+ inflight: make(map[uint32]struct{}),
+ synCh: make(chan struct{}, config.AcceptBacklog),
+ acceptCh: make(chan *Stream, config.AcceptBacklog),
+ sendCh: make(chan sendReady, 64),
+ recvDoneCh: make(chan struct{}),
+ shutdownCh: make(chan struct{}),
+ }
+ if client {
+ s.nextStreamID = 1
+ } else {
+ s.nextStreamID = 2
+ }
+ go s.recv()
+ go s.send()
+ if config.EnableKeepAlive {
+ go s.keepalive()
+ }
+ return s
+}
+
+// IsClosed does a safe check to see if we have shutdown
+func (s *Session) IsClosed() bool {
+ select {
+ case <-s.shutdownCh:
+ return true
+ default:
+ return false
+ }
+}
+
+// CloseChan returns a read-only channel which is closed as
+// soon as the session is closed.
+func (s *Session) CloseChan() <-chan struct{} {
+ return s.shutdownCh
+}
+
+// NumStreams returns the number of currently open streams
+func (s *Session) NumStreams() int {
+ s.streamLock.Lock()
+ num := len(s.streams)
+ s.streamLock.Unlock()
+ return num
+}
+
+// Open is used to create a new stream as a net.Conn
+func (s *Session) Open() (net.Conn, error) {
+ conn, err := s.OpenStream()
+ if err != nil {
+ return nil, err
+ }
+ return conn, nil
+}
+
+// OpenStream is used to create a new stream
+func (s *Session) OpenStream() (*Stream, error) {
+ if s.IsClosed() {
+ return nil, ErrSessionShutdown
+ }
+ if atomic.LoadInt32(&s.remoteGoAway) == 1 {
+ return nil, ErrRemoteGoAway
+ }
+
+ // Block if we have too many inflight SYNs
+ select {
+ case s.synCh <- struct{}{}:
+ case <-s.shutdownCh:
+ return nil, ErrSessionShutdown
+ }
+
+GET_ID:
+ // Get an ID, and check for stream exhaustion
+ id := atomic.LoadUint32(&s.nextStreamID)
+ if id >= math.MaxUint32-1 {
+ return nil, ErrStreamsExhausted
+ }
+ if !atomic.CompareAndSwapUint32(&s.nextStreamID, id, id+2) {
+ goto GET_ID
+ }
+
+ // Register the stream
+ stream := newStream(s, id, streamInit)
+ s.streamLock.Lock()
+ s.streams[id] = stream
+ s.inflight[id] = struct{}{}
+ s.streamLock.Unlock()
+
+ // Send the window update to create
+ if err := stream.sendWindowUpdate(); err != nil {
+ select {
+ case <-s.synCh:
+ default:
+ s.logger.Printf("[ERR] yamux: aborted stream open without inflight syn semaphore")
+ }
+ return nil, err
+ }
+ return stream, nil
+}
+
+// Accept is used to block until the next available stream
+// is ready to be accepted.
+func (s *Session) Accept() (net.Conn, error) {
+ conn, err := s.AcceptStream()
+ if err != nil {
+ return nil, err
+ }
+ return conn, err
+}
+
+// AcceptStream is used to block until the next available stream
+// is ready to be accepted.
+func (s *Session) AcceptStream() (*Stream, error) {
+ select {
+ case stream := <-s.acceptCh:
+ if err := stream.sendWindowUpdate(); err != nil {
+ return nil, err
+ }
+ return stream, nil
+ case <-s.shutdownCh:
+ return nil, s.shutdownErr
+ }
+}
+
+// Close is used to close the session and all streams.
+// Attempts to send a GoAway before closing the connection.
+func (s *Session) Close() error {
+ s.shutdownLock.Lock()
+ defer s.shutdownLock.Unlock()
+
+ if s.shutdown {
+ return nil
+ }
+ s.shutdown = true
+ if s.shutdownErr == nil {
+ s.shutdownErr = ErrSessionShutdown
+ }
+ close(s.shutdownCh)
+ s.conn.Close()
+ <-s.recvDoneCh
+
+ s.streamLock.Lock()
+ defer s.streamLock.Unlock()
+ for _, stream := range s.streams {
+ stream.forceClose()
+ }
+ return nil
+}
+
+// exitErr is used to handle an error that is causing the
+// session to terminate.
+func (s *Session) exitErr(err error) {
+ s.shutdownLock.Lock()
+ if s.shutdownErr == nil {
+ s.shutdownErr = err
+ }
+ s.shutdownLock.Unlock()
+ s.Close()
+}
+
+// GoAway can be used to prevent accepting further
+// connections. It does not close the underlying conn.
+func (s *Session) GoAway() error {
+ return s.waitForSend(s.goAway(goAwayNormal), nil)
+}
+
+// goAway is used to send a goAway message
+func (s *Session) goAway(reason uint32) header {
+ atomic.SwapInt32(&s.localGoAway, 1)
+ hdr := header(make([]byte, headerSize))
+ hdr.encode(typeGoAway, 0, 0, reason)
+ return hdr
+}
+
+// Ping is used to measure the RTT response time
+func (s *Session) Ping() (time.Duration, error) {
+ // Get a channel for the ping
+ ch := make(chan struct{})
+
+ // Get a new ping id, mark as pending
+ s.pingLock.Lock()
+ id := s.pingID
+ s.pingID++
+ s.pings[id] = ch
+ s.pingLock.Unlock()
+
+ // Send the ping request
+ hdr := header(make([]byte, headerSize))
+ hdr.encode(typePing, flagSYN, 0, id)
+ if err := s.waitForSend(hdr, nil); err != nil {
+ return 0, err
+ }
+
+ // Wait for a response
+ start := time.Now()
+ select {
+ case <-ch:
+ case <-time.After(s.config.ConnectionWriteTimeout):
+ s.pingLock.Lock()
+ delete(s.pings, id) // Ignore it if a response comes later.
+ s.pingLock.Unlock()
+ return 0, ErrTimeout
+ case <-s.shutdownCh:
+ return 0, ErrSessionShutdown
+ }
+
+ // Compute the RTT
+ return time.Now().Sub(start), nil
+}
+
+// keepalive is a long running goroutine that periodically does
+// a ping to keep the connection alive.
+func (s *Session) keepalive() {
+ for {
+ select {
+ case <-time.After(s.config.KeepAliveInterval):
+ _, err := s.Ping()
+ if err != nil {
+ if err != ErrSessionShutdown {
+ s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
+ s.exitErr(ErrKeepAliveTimeout)
+ }
+ return
+ }
+ case <-s.shutdownCh:
+ return
+ }
+ }
+}
+
+// waitForSendErr waits to send a header, checking for a potential shutdown
+func (s *Session) waitForSend(hdr header, body io.Reader) error {
+ errCh := make(chan error, 1)
+ return s.waitForSendErr(hdr, body, errCh)
+}
+
+// waitForSendErr waits to send a header with optional data, checking for a
+// potential shutdown. Since there's the expectation that sends can happen
+// in a timely manner, we enforce the connection write timeout here.
+func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error {
+ t := timerPool.Get()
+ timer := t.(*time.Timer)
+ timer.Reset(s.config.ConnectionWriteTimeout)
+ defer func() {
+ timer.Stop()
+ select {
+ case <-timer.C:
+ default:
+ }
+ timerPool.Put(t)
+ }()
+
+ ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
+ select {
+ case s.sendCh <- ready:
+ case <-s.shutdownCh:
+ return ErrSessionShutdown
+ case <-timer.C:
+ return ErrConnectionWriteTimeout
+ }
+
+ select {
+ case err := <-errCh:
+ return err
+ case <-s.shutdownCh:
+ return ErrSessionShutdown
+ case <-timer.C:
+ return ErrConnectionWriteTimeout
+ }
+}
+
+// sendNoWait does a send without waiting. Since there's the expectation that
+// the send happens right here, we enforce the connection write timeout if we
+// can't queue the header to be sent.
+func (s *Session) sendNoWait(hdr header) error {
+ t := timerPool.Get()
+ timer := t.(*time.Timer)
+ timer.Reset(s.config.ConnectionWriteTimeout)
+ defer func() {
+ timer.Stop()
+ select {
+ case <-timer.C:
+ default:
+ }
+ timerPool.Put(t)
+ }()
+
+ select {
+ case s.sendCh <- sendReady{Hdr: hdr}:
+ return nil
+ case <-s.shutdownCh:
+ return ErrSessionShutdown
+ case <-timer.C:
+ return ErrConnectionWriteTimeout
+ }
+}
+
+// send is a long running goroutine that sends data
+func (s *Session) send() {
+ for {
+ select {
+ case ready := <-s.sendCh:
+ // Send a header if ready
+ if ready.Hdr != nil {
+ sent := 0
+ for sent < len(ready.Hdr) {
+ n, err := s.conn.Write(ready.Hdr[sent:])
+ if err != nil {
+ s.logger.Printf("[ERR] yamux: Failed to write header: %v", err)
+ asyncSendErr(ready.Err, err)
+ s.exitErr(err)
+ return
+ }
+ sent += n
+ }
+ }
+
+ // Send data from a body if given
+ if ready.Body != nil {
+ _, err := io.Copy(s.conn, ready.Body)
+ if err != nil {
+ s.logger.Printf("[ERR] yamux: Failed to write body: %v", err)
+ asyncSendErr(ready.Err, err)
+ s.exitErr(err)
+ return
+ }
+ }
+
+ // No error, successful send
+ asyncSendErr(ready.Err, nil)
+ case <-s.shutdownCh:
+ return
+ }
+ }
+}
+
+// recv is a long running goroutine that accepts new data
+func (s *Session) recv() {
+ if err := s.recvLoop(); err != nil {
+ s.exitErr(err)
+ }
+}
+
+// Ensure that the index of the handler (typeData/typeWindowUpdate/etc) matches the message type
+var (
+ handlers = []func(*Session, header) error{
+ typeData: (*Session).handleStreamMessage,
+ typeWindowUpdate: (*Session).handleStreamMessage,
+ typePing: (*Session).handlePing,
+ typeGoAway: (*Session).handleGoAway,
+ }
+)
+
+// recvLoop continues to receive data until a fatal error is encountered
+func (s *Session) recvLoop() error {
+ defer close(s.recvDoneCh)
+ hdr := header(make([]byte, headerSize))
+ for {
+ // Read the header
+ if _, err := io.ReadFull(s.bufRead, hdr); err != nil {
+ if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") {
+ s.logger.Printf("[ERR] yamux: Failed to read header: %v", err)
+ }
+ return err
+ }
+
+ // Verify the version
+ if hdr.Version() != protoVersion {
+ s.logger.Printf("[ERR] yamux: Invalid protocol version: %d", hdr.Version())
+ return ErrInvalidVersion
+ }
+
+ mt := hdr.MsgType()
+ if mt < typeData || mt > typeGoAway {
+ return ErrInvalidMsgType
+ }
+
+ if err := handlers[mt](s, hdr); err != nil {
+ return err
+ }
+ }
+}
+
+// handleStreamMessage handles either a data or window update frame
+func (s *Session) handleStreamMessage(hdr header) error {
+ // Check for a new stream creation
+ id := hdr.StreamID()
+ flags := hdr.Flags()
+ if flags&flagSYN == flagSYN {
+ if err := s.incomingStream(id); err != nil {
+ return err
+ }
+ }
+
+ // Get the stream
+ s.streamLock.Lock()
+ stream := s.streams[id]
+ s.streamLock.Unlock()
+
+ // If we do not have a stream, likely we sent a RST
+ if stream == nil {
+ // Drain any data on the wire
+ if hdr.MsgType() == typeData && hdr.Length() > 0 {
+ s.logger.Printf("[WARN] yamux: Discarding data for stream: %d", id)
+ if _, err := io.CopyN(ioutil.Discard, s.bufRead, int64(hdr.Length())); err != nil {
+ s.logger.Printf("[ERR] yamux: Failed to discard data: %v", err)
+ return nil
+ }
+ } else {
+ s.logger.Printf("[WARN] yamux: frame for missing stream: %v", hdr)
+ }
+ return nil
+ }
+
+ // Check if this is a window update
+ if hdr.MsgType() == typeWindowUpdate {
+ if err := stream.incrSendWindow(hdr, flags); err != nil {
+ if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
+ s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
+ }
+ return err
+ }
+ return nil
+ }
+
+ // Read the new data
+ if err := stream.readData(hdr, flags, s.bufRead); err != nil {
+ if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
+ s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
+ }
+ return err
+ }
+ return nil
+}
+
+// handlePing is invokde for a typePing frame
+func (s *Session) handlePing(hdr header) error {
+ flags := hdr.Flags()
+ pingID := hdr.Length()
+
+ // Check if this is a query, respond back in a separate context so we
+ // don't interfere with the receiving thread blocking for the write.
+ if flags&flagSYN == flagSYN {
+ go func() {
+ hdr := header(make([]byte, headerSize))
+ hdr.encode(typePing, flagACK, 0, pingID)
+ if err := s.sendNoWait(hdr); err != nil {
+ s.logger.Printf("[WARN] yamux: failed to send ping reply: %v", err)
+ }
+ }()
+ return nil
+ }
+
+ // Handle a response
+ s.pingLock.Lock()
+ ch := s.pings[pingID]
+ if ch != nil {
+ delete(s.pings, pingID)
+ close(ch)
+ }
+ s.pingLock.Unlock()
+ return nil
+}
+
+// handleGoAway is invokde for a typeGoAway frame
+func (s *Session) handleGoAway(hdr header) error {
+ code := hdr.Length()
+ switch code {
+ case goAwayNormal:
+ atomic.SwapInt32(&s.remoteGoAway, 1)
+ case goAwayProtoErr:
+ s.logger.Printf("[ERR] yamux: received protocol error go away")
+ return fmt.Errorf("yamux protocol error")
+ case goAwayInternalErr:
+ s.logger.Printf("[ERR] yamux: received internal error go away")
+ return fmt.Errorf("remote yamux internal error")
+ default:
+ s.logger.Printf("[ERR] yamux: received unexpected go away")
+ return fmt.Errorf("unexpected go away received")
+ }
+ return nil
+}
+
+// incomingStream is used to create a new incoming stream
+func (s *Session) incomingStream(id uint32) error {
+ // Reject immediately if we are doing a go away
+ if atomic.LoadInt32(&s.localGoAway) == 1 {
+ hdr := header(make([]byte, headerSize))
+ hdr.encode(typeWindowUpdate, flagRST, id, 0)
+ return s.sendNoWait(hdr)
+ }
+
+ // Allocate a new stream
+ stream := newStream(s, id, streamSYNReceived)
+
+ s.streamLock.Lock()
+ defer s.streamLock.Unlock()
+
+ // Check if stream already exists
+ if _, ok := s.streams[id]; ok {
+ s.logger.Printf("[ERR] yamux: duplicate stream declared")
+ if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
+ s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
+ }
+ return ErrDuplicateStream
+ }
+
+ // Register the stream
+ s.streams[id] = stream
+
+ // Check if we've exceeded the backlog
+ select {
+ case s.acceptCh <- stream:
+ return nil
+ default:
+ // Backlog exceeded! RST the stream
+ s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset")
+ delete(s.streams, id)
+ stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0)
+ return s.sendNoWait(stream.sendHdr)
+ }
+}
+
+// closeStream is used to close a stream once both sides have
+// issued a close. If there was an in-flight SYN and the stream
+// was not yet established, then this will give the credit back.
+func (s *Session) closeStream(id uint32) {
+ s.streamLock.Lock()
+ if _, ok := s.inflight[id]; ok {
+ select {
+ case <-s.synCh:
+ default:
+ s.logger.Printf("[ERR] yamux: SYN tracking out of sync")
+ }
+ }
+ delete(s.streams, id)
+ s.streamLock.Unlock()
+}
+
+// establishStream is used to mark a stream that was in the
+// SYN Sent state as established.
+func (s *Session) establishStream(id uint32) {
+ s.streamLock.Lock()
+ if _, ok := s.inflight[id]; ok {
+ delete(s.inflight, id)
+ } else {
+ s.logger.Printf("[ERR] yamux: established stream without inflight SYN (no tracking entry)")
+ }
+ select {
+ case <-s.synCh:
+ default:
+ s.logger.Printf("[ERR] yamux: established stream without inflight SYN (didn't have semaphore)")
+ }
+ s.streamLock.Unlock()
+}
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+session_test.go - github.com/hashicorp/yamux
+package yamux
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "reflect"
+ "runtime"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+)
+
+type logCapture struct{ bytes.Buffer }
+
+func (l *logCapture) logs() []string {
+ return strings.Split(strings.TrimSpace(l.String()), "\n")
+}
+
+func (l *logCapture) match(expect []string) bool {
+ return reflect.DeepEqual(l.logs(), expect)
+}
+
+func captureLogs(s *Session) *logCapture {
+ buf := new(logCapture)
+ s.logger = log.New(buf, "", 0)
+ return buf
+}
+
+type pipeConn struct {
+ reader *io.PipeReader
+ writer *io.PipeWriter
+ writeBlocker sync.Mutex
+}
+
+func (p *pipeConn) Read(b []byte) (int, error) {
+ return p.reader.Read(b)
+}
+
+func (p *pipeConn) Write(b []byte) (int, error) {
+ p.writeBlocker.Lock()
+ defer p.writeBlocker.Unlock()
+ return p.writer.Write(b)
+}
+
+func (p *pipeConn) Close() error {
+ p.reader.Close()
+ return p.writer.Close()
+}
+
+func testConn() (io.ReadWriteCloser, io.ReadWriteCloser) {
+ read1, write1 := io.Pipe()
+ read2, write2 := io.Pipe()
+ conn1 := &pipeConn{reader: read1, writer: write2}
+ conn2 := &pipeConn{reader: read2, writer: write1}
+ return conn1, conn2
+}
+
+func testConf() *Config {
+ conf := DefaultConfig()
+ conf.AcceptBacklog = 64
+ conf.KeepAliveInterval = 100 * time.Millisecond
+ conf.ConnectionWriteTimeout = 250 * time.Millisecond
+ return conf
+}
+
+func testConfNoKeepAlive() *Config {
+ conf := testConf()
+ conf.EnableKeepAlive = false
+ return conf
+}
+
+func testClientServer() (*Session, *Session) {
+ return testClientServerConfig(testConf())
+}
+
+func testClientServerConfig(conf *Config) (*Session, *Session) {
+ conn1, conn2 := testConn()
+ client, _ := Client(conn1, conf)
+ server, _ := Server(conn2, conf)
+ return client, server
+}
+
+func TestPing(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ rtt, err := client.Ping()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if rtt == 0 {
+ t.Fatalf("bad: %v", rtt)
+ }
+
+ rtt, err = server.Ping()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if rtt == 0 {
+ t.Fatalf("bad: %v", rtt)
+ }
+}
+
+func TestPing_Timeout(t *testing.T) {
+ client, server := testClientServerConfig(testConfNoKeepAlive())
+ defer client.Close()
+ defer server.Close()
+
+ // Prevent the client from responding
+ clientConn := client.conn.(*pipeConn)
+ clientConn.writeBlocker.Lock()
+
+ errCh := make(chan error, 1)
+ go func() {
+ _, err := server.Ping() // Ping via the server session
+ errCh <- err
+ }()
+
+ select {
+ case err := <-errCh:
+ if err != ErrTimeout {
+ t.Fatalf("err: %v", err)
+ }
+ case <-time.After(client.config.ConnectionWriteTimeout * 2):
+ t.Fatalf("failed to timeout within expected %v", client.config.ConnectionWriteTimeout)
+ }
+
+ // Verify that we recover, even if we gave up
+ clientConn.writeBlocker.Unlock()
+
+ go func() {
+ _, err := server.Ping() // Ping via the server session
+ errCh <- err
+ }()
+
+ select {
+ case err := <-errCh:
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ case <-time.After(client.config.ConnectionWriteTimeout):
+ t.Fatalf("timeout")
+ }
+}
+
+func TestCloseBeforeAck(t *testing.T) {
+ cfg := testConf()
+ cfg.AcceptBacklog = 8
+ client, server := testClientServerConfig(cfg)
+
+ defer client.Close()
+ defer server.Close()
+
+ for i := 0; i < 8; i++ {
+ s, err := client.OpenStream()
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.Close()
+ }
+
+ for i := 0; i < 8; i++ {
+ s, err := server.AcceptStream()
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.Close()
+ }
+
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ s, err := client.OpenStream()
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.Close()
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(time.Second * 5):
+ t.Fatal("timed out trying to open stream")
+ }
+}
+
+func TestAccept(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ if client.NumStreams() != 0 {
+ t.Fatalf("bad")
+ }
+ if server.NumStreams() != 0 {
+ t.Fatalf("bad")
+ }
+
+ wg := &sync.WaitGroup{}
+ wg.Add(4)
+
+ go func() {
+ defer wg.Done()
+ stream, err := server.AcceptStream()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if id := stream.StreamID(); id != 1 {
+ t.Fatalf("bad: %v", id)
+ }
+ if err := stream.Close(); err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ }()
+
+ go func() {
+ defer wg.Done()
+ stream, err := client.AcceptStream()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if id := stream.StreamID(); id != 2 {
+ t.Fatalf("bad: %v", id)
+ }
+ if err := stream.Close(); err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ }()
+
+ go func() {
+ defer wg.Done()
+ stream, err := server.OpenStream()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if id := stream.StreamID(); id != 2 {
+ t.Fatalf("bad: %v", id)
+ }
+ if err := stream.Close(); err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ }()
+
+ go func() {
+ defer wg.Done()
+ stream, err := client.OpenStream()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if id := stream.StreamID(); id != 1 {
+ t.Fatalf("bad: %v", id)
+ }
+ if err := stream.Close(); err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ }()
+
+ doneCh := make(chan struct{})
+ go func() {
+ wg.Wait()
+ close(doneCh)
+ }()
+
+ select {
+ case <-doneCh:
+ case <-time.After(time.Second):
+ panic("timeout")
+ }
+}
+
+func TestClose_closeTimeout(t *testing.T) {
+ conf := testConf()
+ conf.StreamCloseTimeout = 10 * time.Millisecond
+ client, server := testClientServerConfig(conf)
+ defer client.Close()
+ defer server.Close()
+
+ if client.NumStreams() != 0 {
+ t.Fatalf("bad")
+ }
+ if server.NumStreams() != 0 {
+ t.Fatalf("bad")
+ }
+
+ wg := &sync.WaitGroup{}
+ wg.Add(2)
+
+ // Open a stream on the client but only close it on the server.
+ // We want to see if the stream ever gets cleaned up on the client.
+
+ var clientStream *Stream
+ go func() {
+ defer wg.Done()
+ var err error
+ clientStream, err = client.OpenStream()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ }()
+
+ go func() {
+ defer wg.Done()
+ stream, err := server.AcceptStream()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if err := stream.Close(); err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ }()
+
+ doneCh := make(chan struct{})
+ go func() {
+ wg.Wait()
+ close(doneCh)
+ }()
+
+ select {
+ case <-doneCh:
+ case <-time.After(time.Second):
+ panic("timeout")
+ }
+
+ // We should have zero streams after our timeout period
+ time.Sleep(100 * time.Millisecond)
+
+ if v := server.NumStreams(); v > 0 {
+ t.Fatalf("should have zero streams: %d", v)
+ }
+ if v := client.NumStreams(); v > 0 {
+ t.Fatalf("should have zero streams: %d", v)
+ }
+
+ if _, err := clientStream.Write([]byte("hello")); err == nil {
+ t.Fatal("should error on write")
+ } else if err.Error() != "connection reset" {
+ t.Fatalf("expected connection reset, got %q", err)
+ }
+}
+
+func TestNonNilInterface(t *testing.T) {
+ _, server := testClientServer()
+ server.Close()
+
+ conn, err := server.Accept()
+ if err != nil && conn != nil {
+ t.Error("bad: accept should return a connection of nil value")
+ }
+
+ conn, err = server.Open()
+ if err != nil && conn != nil {
+ t.Error("bad: open should return a connection of nil value")
+ }
+}
+
+func TestSendData_Small(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ wg := &sync.WaitGroup{}
+ wg.Add(2)
+
+ go func() {
+ defer wg.Done()
+ stream, err := server.AcceptStream()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ if server.NumStreams() != 1 {
+ t.Fatalf("bad")
+ }
+
+ buf := make([]byte, 4)
+ for i := 0; i < 1000; i++ {
+ n, err := stream.Read(buf)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if n != 4 {
+ t.Fatalf("short read: %d", n)
+ }
+ if string(buf) != "test" {
+ t.Fatalf("bad: %s", buf)
+ }
+ }
+
+ if err := stream.Close(); err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ }()
+
+ go func() {
+ defer wg.Done()
+ stream, err := client.Open()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ if client.NumStreams() != 1 {
+ t.Fatalf("bad")
+ }
+
+ for i := 0; i < 1000; i++ {
+ n, err := stream.Write([]byte("test"))
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if n != 4 {
+ t.Fatalf("short write %d", n)
+ }
+ }
+
+ if err := stream.Close(); err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ }()
+
+ doneCh := make(chan struct{})
+ go func() {
+ wg.Wait()
+ close(doneCh)
+ }()
+ select {
+ case <-doneCh:
+ case <-time.After(time.Second):
+ panic("timeout")
+ }
+
+ if client.NumStreams() != 0 {
+ t.Fatalf("bad")
+ }
+ if server.NumStreams() != 0 {
+ t.Fatalf("bad")
+ }
+}
+
+func TestSendData_Large(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ const (
+ sendSize = 250 * 1024 * 1024
+ recvSize = 4 * 1024
+ )
+
+ data := make([]byte, sendSize)
+ for idx := range data {
+ data[idx] = byte(idx % 256)
+ }
+
+ wg := &sync.WaitGroup{}
+ wg.Add(2)
+
+ go func() {
+ defer wg.Done()
+ stream, err := server.AcceptStream()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ var sz int
+ buf := make([]byte, recvSize)
+ for i := 0; i < sendSize/recvSize; i++ {
+ n, err := stream.Read(buf)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if n != recvSize {
+ t.Fatalf("short read: %d", n)
+ }
+ sz += n
+ for idx := range buf {
+ if buf[idx] != byte(idx%256) {
+ t.Fatalf("bad: %v %v %v", i, idx, buf[idx])
+ }
+ }
+ }
+
+ if err := stream.Close(); err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ t.Logf("cap=%d, n=%d\n", stream.recvBuf.Cap(), sz)
+ }()
+
+ go func() {
+ defer wg.Done()
+ stream, err := client.Open()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ n, err := stream.Write(data)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if n != len(data) {
+ t.Fatalf("short write %d", n)
+ }
+
+ if err := stream.Close(); err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ }()
+
+ doneCh := make(chan struct{})
+ go func() {
+ wg.Wait()
+ close(doneCh)
+ }()
+ select {
+ case <-doneCh:
+ case <-time.After(5 * time.Second):
+ panic("timeout")
+ }
+}
+
+func TestGoAway(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ if err := server.GoAway(); err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ _, err := client.Open()
+ if err != ErrRemoteGoAway {
+ t.Fatalf("err: %v", err)
+ }
+}
+
+func TestManyStreams(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ wg := &sync.WaitGroup{}
+
+ acceptor := func(i int) {
+ defer wg.Done()
+ stream, err := server.AcceptStream()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+
+ buf := make([]byte, 512)
+ for {
+ n, err := stream.Read(buf)
+ if err == io.EOF {
+ return
+ }
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if n == 0 {
+ t.Fatalf("err: %v", err)
+ }
+ }
+ }
+ sender := func(i int) {
+ defer wg.Done()
+ stream, err := client.Open()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+
+ msg := fmt.Sprintf("%08d", i)
+ for i := 0; i < 1000; i++ {
+ n, err := stream.Write([]byte(msg))
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if n != len(msg) {
+ t.Fatalf("short write %d", n)
+ }
+ }
+ }
+
+ for i := 0; i < 50; i++ {
+ wg.Add(2)
+ go acceptor(i)
+ go sender(i)
+ }
+
+ wg.Wait()
+}
+
+func TestManyStreams_PingPong(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ wg := &sync.WaitGroup{}
+
+ ping := []byte("ping")
+ pong := []byte("pong")
+
+ acceptor := func(i int) {
+ defer wg.Done()
+ stream, err := server.AcceptStream()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+
+ buf := make([]byte, 4)
+ for {
+ // Read the 'ping'
+ n, err := stream.Read(buf)
+ if err == io.EOF {
+ return
+ }
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if n != 4 {
+ t.Fatalf("err: %v", err)
+ }
+ if !bytes.Equal(buf, ping) {
+ t.Fatalf("bad: %s", buf)
+ }
+
+ // Shrink the internal buffer!
+ stream.Shrink()
+
+ // Write out the 'pong'
+ n, err = stream.Write(pong)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if n != 4 {
+ t.Fatalf("err: %v", err)
+ }
+ }
+ }
+ sender := func(i int) {
+ defer wg.Done()
+ stream, err := client.OpenStream()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+
+ buf := make([]byte, 4)
+ for i := 0; i < 1000; i++ {
+ // Send the 'ping'
+ n, err := stream.Write(ping)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if n != 4 {
+ t.Fatalf("short write %d", n)
+ }
+
+ // Read the 'pong'
+ n, err = stream.Read(buf)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if n != 4 {
+ t.Fatalf("err: %v", err)
+ }
+ if !bytes.Equal(buf, pong) {
+ t.Fatalf("bad: %s", buf)
+ }
+
+ // Shrink the buffer
+ stream.Shrink()
+ }
+ }
+
+ for i := 0; i < 50; i++ {
+ wg.Add(2)
+ go acceptor(i)
+ go sender(i)
+ }
+
+ wg.Wait()
+}
+
+func TestHalfClose(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ stream, err := client.Open()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if _, err = stream.Write([]byte("a")); err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ stream2, err := server.Accept()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ stream2.Close() // Half close
+
+ buf := make([]byte, 4)
+ n, err := stream2.Read(buf)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if n != 1 {
+ t.Fatalf("bad: %v", n)
+ }
+
+ // Send more
+ if _, err = stream.Write([]byte("bcd")); err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ stream.Close()
+
+ // Read after close
+ n, err = stream2.Read(buf)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if n != 3 {
+ t.Fatalf("bad: %v", n)
+ }
+
+ // EOF after close
+ n, err = stream2.Read(buf)
+ if err != io.EOF {
+ t.Fatalf("err: %v", err)
+ }
+ if n != 0 {
+ t.Fatalf("bad: %v", n)
+ }
+}
+
+func TestReadDeadline(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ stream, err := client.Open()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+
+ stream2, err := server.Accept()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream2.Close()
+
+ if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ buf := make([]byte, 4)
+ if _, err := stream.Read(buf); err != ErrTimeout {
+ t.Fatalf("err: %v", err)
+ }
+}
+
+func TestReadDeadline_BlockedRead(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ stream, err := client.Open()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+
+ stream2, err := server.Accept()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream2.Close()
+
+ // Start a read that will block
+ errCh := make(chan error, 1)
+ go func() {
+ buf := make([]byte, 4)
+ _, err := stream.Read(buf)
+ errCh <- err
+ close(errCh)
+ }()
+
+ // Wait to ensure the read has started.
+ time.Sleep(5 * time.Millisecond)
+
+ // Update the read deadline
+ if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ select {
+ case <-time.After(100 * time.Millisecond):
+ t.Fatal("expected read timeout")
+ case err := <-errCh:
+ if err != ErrTimeout {
+ t.Fatalf("expected ErrTimeout; got %v", err)
+ }
+ }
+}
+
+func TestWriteDeadline(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ stream, err := client.Open()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+
+ stream2, err := server.Accept()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream2.Close()
+
+ if err := stream.SetWriteDeadline(time.Now().Add(50 * time.Millisecond)); err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ buf := make([]byte, 512)
+ for i := 0; i < int(initialStreamWindow); i++ {
+ _, err := stream.Write(buf)
+ if err != nil && err == ErrTimeout {
+ return
+ } else if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ }
+ t.Fatalf("Expected timeout")
+}
+
+func TestWriteDeadline_BlockedWrite(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ stream, err := client.Open()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+
+ stream2, err := server.Accept()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream2.Close()
+
+ // Start a goroutine making writes that will block
+ errCh := make(chan error, 1)
+ go func() {
+ buf := make([]byte, 512)
+ for i := 0; i < int(initialStreamWindow); i++ {
+ _, err := stream.Write(buf)
+ if err == nil {
+ continue
+ }
+
+ errCh <- err
+ close(errCh)
+ return
+ }
+
+ close(errCh)
+ }()
+
+ // Wait to ensure the write has started.
+ time.Sleep(5 * time.Millisecond)
+
+ // Update the write deadline
+ if err := stream.SetWriteDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ select {
+ case <-time.After(1 * time.Second):
+ t.Fatal("expected write timeout")
+ case err := <-errCh:
+ if err != ErrTimeout {
+ t.Fatalf("expected ErrTimeout; got %v", err)
+ }
+ }
+}
+
+func TestBacklogExceeded(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ // Fill the backlog
+ max := client.config.AcceptBacklog
+ for i := 0; i < max; i++ {
+ stream, err := client.Open()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+
+ if _, err := stream.Write([]byte("foo")); err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ }
+
+ // Attempt to open a new stream
+ errCh := make(chan error, 1)
+ go func() {
+ _, err := client.Open()
+ errCh <- err
+ }()
+
+ // Shutdown the server
+ go func() {
+ time.Sleep(10 * time.Millisecond)
+ server.Close()
+ }()
+
+ select {
+ case err := <-errCh:
+ if err == nil {
+ t.Fatalf("open should fail")
+ }
+ case <-time.After(time.Second):
+ t.Fatalf("timeout")
+ }
+}
+
+func TestKeepAlive(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ time.Sleep(200 * time.Millisecond)
+
+ // Ping value should increase
+ client.pingLock.Lock()
+ defer client.pingLock.Unlock()
+ if client.pingID == 0 {
+ t.Fatalf("should ping")
+ }
+
+ server.pingLock.Lock()
+ defer server.pingLock.Unlock()
+ if server.pingID == 0 {
+ t.Fatalf("should ping")
+ }
+}
+
+func TestKeepAlive_Timeout(t *testing.T) {
+ conn1, conn2 := testConn()
+
+ clientConf := testConf()
+ clientConf.ConnectionWriteTimeout = time.Hour // We're testing keep alives, not connection writes
+ clientConf.EnableKeepAlive = false // Just test one direction, so it's deterministic who hangs up on whom
+ client, _ := Client(conn1, clientConf)
+ defer client.Close()
+
+ server, _ := Server(conn2, testConf())
+ defer server.Close()
+
+ _ = captureLogs(client) // Client logs aren't part of the test
+ serverLogs := captureLogs(server)
+
+ errCh := make(chan error, 1)
+ go func() {
+ _, err := server.Accept() // Wait until server closes
+ errCh <- err
+ }()
+
+ // Prevent the client from responding
+ clientConn := client.conn.(*pipeConn)
+ clientConn.writeBlocker.Lock()
+
+ select {
+ case err := <-errCh:
+ if err != ErrKeepAliveTimeout {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ case <-time.After(1 * time.Second):
+ t.Fatalf("timeout waiting for timeout")
+ }
+
+ if !server.IsClosed() {
+ t.Fatalf("server should have closed")
+ }
+
+ if !serverLogs.match([]string{"[ERR] yamux: keepalive failed: i/o deadline reached"}) {
+ t.Fatalf("server log incorect: %v", serverLogs.logs())
+ }
+}
+
+func TestLargeWindow(t *testing.T) {
+ conf := DefaultConfig()
+ conf.MaxStreamWindowSize *= 2
+
+ client, server := testClientServerConfig(conf)
+ defer client.Close()
+ defer server.Close()
+
+ stream, err := client.Open()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+
+ stream2, err := server.Accept()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream2.Close()
+
+ stream.SetWriteDeadline(time.Now().Add(10 * time.Millisecond))
+ buf := make([]byte, conf.MaxStreamWindowSize)
+ n, err := stream.Write(buf)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if n != len(buf) {
+ t.Fatalf("short write: %d", n)
+ }
+}
+
+type UnlimitedReader struct{}
+
+func (u *UnlimitedReader) Read(p []byte) (int, error) {
+ runtime.Gosched()
+ return len(p), nil
+}
+
+func TestSendData_VeryLarge(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ var n int64 = 1 * 1024 * 1024 * 1024
+ var workers int = 16
+
+ wg := &sync.WaitGroup{}
+ wg.Add(workers * 2)
+
+ for i := 0; i < workers; i++ {
+ go func() {
+ defer wg.Done()
+ stream, err := server.AcceptStream()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+
+ buf := make([]byte, 4)
+ _, err = stream.Read(buf)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if !bytes.Equal(buf, []byte{0, 1, 2, 3}) {
+ t.Fatalf("bad header")
+ }
+
+ recv, err := io.Copy(ioutil.Discard, stream)
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if recv != n {
+ t.Fatalf("bad: %v", recv)
+ }
+ }()
+ }
+ for i := 0; i < workers; i++ {
+ go func() {
+ defer wg.Done()
+ stream, err := client.Open()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+
+ _, err = stream.Write([]byte{0, 1, 2, 3})
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+
+ unlimited := &UnlimitedReader{}
+ sent, err := io.Copy(stream, io.LimitReader(unlimited, n))
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if sent != n {
+ t.Fatalf("bad: %v", sent)
+ }
+ }()
+ }
+
+ doneCh := make(chan struct{})
+ go func() {
+ wg.Wait()
+ close(doneCh)
+ }()
+ select {
+ case <-doneCh:
+ case <-time.After(20 * time.Second):
+ panic("timeout")
+ }
+}
+
+func TestBacklogExceeded_Accept(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ max := 5 * client.config.AcceptBacklog
+ go func() {
+ for i := 0; i < max; i++ {
+ stream, err := server.Accept()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+ }
+ }()
+
+ // Fill the backlog
+ for i := 0; i < max; i++ {
+ stream, err := client.Open()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+
+ if _, err := stream.Write([]byte("foo")); err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ }
+}
+
+func TestSession_WindowUpdateWriteDuringRead(t *testing.T) {
+ client, server := testClientServerConfig(testConfNoKeepAlive())
+ defer client.Close()
+ defer server.Close()
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+
+ // Choose a huge flood size that we know will result in a window update.
+ flood := int64(client.config.MaxStreamWindowSize) - 1
+
+ // The server will accept a new stream and then flood data to it.
+ go func() {
+ defer wg.Done()
+
+ stream, err := server.AcceptStream()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+
+ n, err := stream.Write(make([]byte, flood))
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if int64(n) != flood {
+ t.Fatalf("short write: %d", n)
+ }
+ }()
+
+ // The client will open a stream, block outbound writes, and then
+ // listen to the flood from the server, which should time out since
+ // it won't be able to send the window update.
+ go func() {
+ defer wg.Done()
+
+ stream, err := client.OpenStream()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+
+ conn := client.conn.(*pipeConn)
+ conn.writeBlocker.Lock()
+
+ _, err = stream.Read(make([]byte, flood))
+ if err != ErrConnectionWriteTimeout {
+ t.Fatalf("err: %v", err)
+ }
+ }()
+
+ wg.Wait()
+}
+
+func TestSession_PartialReadWindowUpdate(t *testing.T) {
+ client, server := testClientServerConfig(testConfNoKeepAlive())
+ defer client.Close()
+ defer server.Close()
+
+ var wg sync.WaitGroup
+ wg.Add(1)
+
+ // Choose a huge flood size that we know will result in a window update.
+ flood := int64(client.config.MaxStreamWindowSize)
+ var wr *Stream
+
+ // The server will accept a new stream and then flood data to it.
+ go func() {
+ defer wg.Done()
+
+ var err error
+ wr, err = server.AcceptStream()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer wr.Close()
+
+ if wr.sendWindow != client.config.MaxStreamWindowSize {
+ t.Fatalf("sendWindow: exp=%d, got=%d", client.config.MaxStreamWindowSize, wr.sendWindow)
+ }
+
+ n, err := wr.Write(make([]byte, flood))
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ if int64(n) != flood {
+ t.Fatalf("short write: %d", n)
+ }
+ if wr.sendWindow != 0 {
+ t.Fatalf("sendWindow: exp=%d, got=%d", 0, wr.sendWindow)
+ }
+ }()
+
+ stream, err := client.OpenStream()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+
+ wg.Wait()
+
+ _, err = stream.Read(make([]byte, flood/2+1))
+
+ if exp := uint32(flood/2 + 1); wr.sendWindow != exp {
+ t.Errorf("sendWindow: exp=%d, got=%d", exp, wr.sendWindow)
+ }
+}
+
+func TestSession_sendNoWait_Timeout(t *testing.T) {
+ client, server := testClientServerConfig(testConfNoKeepAlive())
+ defer client.Close()
+ defer server.Close()
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+
+ go func() {
+ defer wg.Done()
+
+ stream, err := server.AcceptStream()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+ }()
+
+ // The client will open the stream and then block outbound writes, we'll
+ // probe sendNoWait once it gets into that state.
+ go func() {
+ defer wg.Done()
+
+ stream, err := client.OpenStream()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+
+ conn := client.conn.(*pipeConn)
+ conn.writeBlocker.Lock()
+
+ hdr := header(make([]byte, headerSize))
+ hdr.encode(typePing, flagACK, 0, 0)
+ for {
+ err = client.sendNoWait(hdr)
+ if err == nil {
+ continue
+ } else if err == ErrConnectionWriteTimeout {
+ break
+ } else {
+ t.Fatalf("err: %v", err)
+ }
+ }
+ }()
+
+ wg.Wait()
+}
+
+func TestSession_PingOfDeath(t *testing.T) {
+ client, server := testClientServerConfig(testConfNoKeepAlive())
+ defer client.Close()
+ defer server.Close()
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+
+ var doPingOfDeath sync.Mutex
+ doPingOfDeath.Lock()
+
+ // This is used later to block outbound writes.
+ conn := server.conn.(*pipeConn)
+
+ // The server will accept a stream, block outbound writes, and then
+ // flood its send channel so that no more headers can be queued.
+ go func() {
+ defer wg.Done()
+
+ stream, err := server.AcceptStream()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+
+ conn.writeBlocker.Lock()
+ for {
+ hdr := header(make([]byte, headerSize))
+ hdr.encode(typePing, 0, 0, 0)
+ err = server.sendNoWait(hdr)
+ if err == nil {
+ continue
+ } else if err == ErrConnectionWriteTimeout {
+ break
+ } else {
+ t.Fatalf("err: %v", err)
+ }
+ }
+
+ doPingOfDeath.Unlock()
+ }()
+
+ // The client will open a stream and then send the server a ping once it
+ // can no longer write. This makes sure the server doesn't deadlock reads
+ // while trying to reply to the ping with no ability to write.
+ go func() {
+ defer wg.Done()
+
+ stream, err := client.OpenStream()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+
+ // This ping will never unblock because the ping id will never
+ // show up in a response.
+ doPingOfDeath.Lock()
+ go func() { client.Ping() }()
+
+ // Wait for a while to make sure the previous ping times out,
+ // then turn writes back on and make sure a ping works again.
+ time.Sleep(2 * server.config.ConnectionWriteTimeout)
+ conn.writeBlocker.Unlock()
+ if _, err = client.Ping(); err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ }()
+
+ wg.Wait()
+}
+
+func TestSession_ConnectionWriteTimeout(t *testing.T) {
+ client, server := testClientServerConfig(testConfNoKeepAlive())
+ defer client.Close()
+ defer server.Close()
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+
+ go func() {
+ defer wg.Done()
+
+ stream, err := server.AcceptStream()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+ }()
+
+ // The client will open the stream and then block outbound writes, we'll
+ // tee up a write and make sure it eventually times out.
+ go func() {
+ defer wg.Done()
+
+ stream, err := client.OpenStream()
+ if err != nil {
+ t.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+
+ conn := client.conn.(*pipeConn)
+ conn.writeBlocker.Lock()
+
+ // Since the write goroutine is blocked then this will return a
+ // timeout since it can't get feedback about whether the write
+ // worked.
+ n, err := stream.Write([]byte("hello"))
+ if err != ErrConnectionWriteTimeout {
+ t.Fatalf("err: %v", err)
+ }
+ if n != 0 {
+ t.Fatalf("lied about writes: %d", n)
+ }
+ }()
+
+ wg.Wait()
+}
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+spec.md - github.com/hashicorp/yamux
+# Specification
+
+We use this document to detail the internal specification of Yamux.
+This is used both as a guide for implementing Yamux, but also for
+alternative interoperable libraries to be built.
+
+# Framing
+
+Yamux uses a streaming connection underneath, but imposes a message
+framing so that it can be shared between many logical streams. Each
+frame contains a header like:
+
+* Version (8 bits)
+* Type (8 bits)
+* Flags (16 bits)
+* StreamID (32 bits)
+* Length (32 bits)
+
+This means that each header has a 12 byte overhead.
+All fields are encoded in network order (big endian).
+Each field is described below:
+
+## Version Field
+
+The version field is used for future backward compatibility. At the
+current time, the field is always set to 0, to indicate the initial
+version.
+
+## Type Field
+
+The type field is used to switch the frame message type. The following
+message types are supported:
+
+* 0x0 Data - Used to transmit data. May transmit zero length payloads
+ depending on the flags.
+
+* 0x1 Window Update - Used to updated the senders receive window size.
+ This is used to implement per-session flow control.
+
+* 0x2 Ping - Used to measure RTT. It can also be used to heart-beat
+ and do keep-alives over TCP.
+
+* 0x3 Go Away - Used to close a session.
+
+## Flag Field
+
+The flags field is used to provide additional information related
+to the message type. The following flags are supported:
+
+* 0x1 SYN - Signals the start of a new stream. May be sent with a data or
+ window update message. Also sent with a ping to indicate outbound.
+
+* 0x2 ACK - Acknowledges the start of a new stream. May be sent with a data
+ or window update message. Also sent with a ping to indicate response.
+
+* 0x4 FIN - Performs a half-close of a stream. May be sent with a data
+ message or window update.
+
+* 0x8 RST - Reset a stream immediately. May be sent with a data or
+ window update message.
+
+## StreamID Field
+
+The StreamID field is used to identify the logical stream the frame
+is addressing. The client side should use odd ID's, and the server even.
+This prevents any collisions. Additionally, the 0 ID is reserved to represent
+the session.
+
+Both Ping and Go Away messages should always use the 0 StreamID.
+
+## Length Field
+
+The meaning of the length field depends on the message type:
+
+* Data - provides the length of bytes following the header
+* Window update - provides a delta update to the window size
+* Ping - Contains an opaque value, echoed back
+* Go Away - Contains an error code
+
+# Message Flow
+
+There is no explicit connection setup, as Yamux relies on an underlying
+transport to be provided. However, there is a distinction between client
+and server side of the connection.
+
+## Opening a stream
+
+To open a stream, an initial data or window update frame is sent
+with a new StreamID. The SYN flag should be set to signal a new stream.
+
+The receiver must then reply with either a data or window update frame
+with the StreamID along with the ACK flag to accept the stream or with
+the RST flag to reject the stream.
+
+Because we are relying on the reliable stream underneath, a connection
+can begin sending data once the SYN flag is sent. The corresponding
+ACK does not need to be received. This is particularly well suited
+for an RPC system where a client wants to open a stream and immediately
+fire a request without waiting for the RTT of the ACK.
+
+This does introduce the possibility of a connection being rejected
+after data has been sent already. This is a slight semantic difference
+from TCP, where the conection cannot be refused after it is opened.
+Clients should be prepared to handle this by checking for an error
+that indicates a RST was received.
+
+## Closing a stream
+
+To close a stream, either side sends a data or window update frame
+along with the FIN flag. This does a half-close indicating the sender
+will send no further data.
+
+Once both sides have closed the connection, the stream is closed.
+
+Alternatively, if an error occurs, the RST flag can be used to
+hard close a stream immediately.
+
+## Flow Control
+
+When Yamux is initially starts each stream with a 256KB window size.
+There is no window size for the session.
+
+To prevent the streams from stalling, window update frames should be
+sent regularly. Yamux can be configured to provide a larger limit for
+windows sizes. Both sides assume the initial 256KB window, but can
+immediately send a window update as part of the SYN/ACK indicating a
+larger window.
+
+Both sides should track the number of bytes sent in Data frames
+only, as only they are tracked as part of the window size.
+
+## Session termination
+
+When a session is being terminated, the Go Away message should
+be sent. The Length should be set to one of the following to
+provide an error code:
+
+* 0x0 Normal termination
+* 0x1 Protocol error
+* 0x2 Internal error
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+stream.go - github.com/hashicorp/yamux
+package yamux
+
+import (
+ "bytes"
+ "io"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+type streamState int
+
+const (
+ streamInit streamState = iota
+ streamSYNSent
+ streamSYNReceived
+ streamEstablished
+ streamLocalClose
+ streamRemoteClose
+ streamClosed
+ streamReset
+)
+
+// Stream is used to represent a logical stream
+// within a session.
+type Stream struct {
+ recvWindow uint32
+ sendWindow uint32
+
+ id uint32
+ session *Session
+
+ state streamState
+ stateLock sync.Mutex
+
+ recvBuf *bytes.Buffer
+ recvLock sync.Mutex
+
+ controlHdr header
+ controlErr chan error
+ controlHdrLock sync.Mutex
+
+ sendHdr header
+ sendErr chan error
+ sendLock sync.Mutex
+
+ recvNotifyCh chan struct{}
+ sendNotifyCh chan struct{}
+
+ readDeadline atomic.Value // time.Time
+ writeDeadline atomic.Value // time.Time
+
+ // closeTimer is set with stateLock held to honor the StreamCloseTimeout
+ // setting on Session.
+ closeTimer *time.Timer
+}
+
+// newStream is used to construct a new stream within
+// a given session for an ID
+func newStream(session *Session, id uint32, state streamState) *Stream {
+ s := &Stream{
+ id: id,
+ session: session,
+ state: state,
+ controlHdr: header(make([]byte, headerSize)),
+ controlErr: make(chan error, 1),
+ sendHdr: header(make([]byte, headerSize)),
+ sendErr: make(chan error, 1),
+ recvWindow: initialStreamWindow,
+ sendWindow: initialStreamWindow,
+ recvNotifyCh: make(chan struct{}, 1),
+ sendNotifyCh: make(chan struct{}, 1),
+ }
+ s.readDeadline.Store(time.Time{})
+ s.writeDeadline.Store(time.Time{})
+ return s
+}
+
+// Session returns the associated stream session
+func (s *Stream) Session() *Session {
+ return s.session
+}
+
+// StreamID returns the ID of this stream
+func (s *Stream) StreamID() uint32 {
+ return s.id
+}
+
+// Read is used to read from the stream
+func (s *Stream) Read(b []byte) (n int, err error) {
+ defer asyncNotify(s.recvNotifyCh)
+START:
+ s.stateLock.Lock()
+ switch s.state {
+ case streamLocalClose:
+ fallthrough
+ case streamRemoteClose:
+ fallthrough
+ case streamClosed:
+ s.recvLock.Lock()
+ if s.recvBuf == nil || s.recvBuf.Len() == 0 {
+ s.recvLock.Unlock()
+ s.stateLock.Unlock()
+ return 0, io.EOF
+ }
+ s.recvLock.Unlock()
+ case streamReset:
+ s.stateLock.Unlock()
+ return 0, ErrConnectionReset
+ }
+ s.stateLock.Unlock()
+
+ // If there is no data available, block
+ s.recvLock.Lock()
+ if s.recvBuf == nil || s.recvBuf.Len() == 0 {
+ s.recvLock.Unlock()
+ goto WAIT
+ }
+
+ // Read any bytes
+ n, _ = s.recvBuf.Read(b)
+ s.recvLock.Unlock()
+
+ // Send a window update potentially
+ err = s.sendWindowUpdate()
+ return n, err
+
+WAIT:
+ var timeout <-chan time.Time
+ var timer *time.Timer
+ readDeadline := s.readDeadline.Load().(time.Time)
+ if !readDeadline.IsZero() {
+ delay := readDeadline.Sub(time.Now())
+ timer = time.NewTimer(delay)
+ timeout = timer.C
+ }
+ select {
+ case <-s.recvNotifyCh:
+ if timer != nil {
+ timer.Stop()
+ }
+ goto START
+ case <-timeout:
+ return 0, ErrTimeout
+ }
+}
+
+// Write is used to write to the stream
+func (s *Stream) Write(b []byte) (n int, err error) {
+ s.sendLock.Lock()
+ defer s.sendLock.Unlock()
+ total := 0
+ for total < len(b) {
+ n, err := s.write(b[total:])
+ total += n
+ if err != nil {
+ return total, err
+ }
+ }
+ return total, nil
+}
+
+// write is used to write to the stream, may return on
+// a short write.
+func (s *Stream) write(b []byte) (n int, err error) {
+ var flags uint16
+ var max uint32
+ var body io.Reader
+START:
+ s.stateLock.Lock()
+ switch s.state {
+ case streamLocalClose:
+ fallthrough
+ case streamClosed:
+ s.stateLock.Unlock()
+ return 0, ErrStreamClosed
+ case streamReset:
+ s.stateLock.Unlock()
+ return 0, ErrConnectionReset
+ }
+ s.stateLock.Unlock()
+
+ // If there is no data available, block
+ window := atomic.LoadUint32(&s.sendWindow)
+ if window == 0 {
+ goto WAIT
+ }
+
+ // Determine the flags if any
+ flags = s.sendFlags()
+
+ // Send up to our send window
+ max = min(window, uint32(len(b)))
+ body = bytes.NewReader(b[:max])
+
+ // Send the header
+ s.sendHdr.encode(typeData, flags, s.id, max)
+ if err = s.session.waitForSendErr(s.sendHdr, body, s.sendErr); err != nil {
+ return 0, err
+ }
+
+ // Reduce our send window
+ atomic.AddUint32(&s.sendWindow, ^uint32(max-1))
+
+ // Unlock
+ return int(max), err
+
+WAIT:
+ var timeout <-chan time.Time
+ writeDeadline := s.writeDeadline.Load().(time.Time)
+ if !writeDeadline.IsZero() {
+ delay := writeDeadline.Sub(time.Now())
+ timeout = time.After(delay)
+ }
+ select {
+ case <-s.sendNotifyCh:
+ goto START
+ case <-timeout:
+ return 0, ErrTimeout
+ }
+ return 0, nil
+}
+
+// sendFlags determines any flags that are appropriate
+// based on the current stream state
+func (s *Stream) sendFlags() uint16 {
+ s.stateLock.Lock()
+ defer s.stateLock.Unlock()
+ var flags uint16
+ switch s.state {
+ case streamInit:
+ flags |= flagSYN
+ s.state = streamSYNSent
+ case streamSYNReceived:
+ flags |= flagACK
+ s.state = streamEstablished
+ }
+ return flags
+}
+
+// sendWindowUpdate potentially sends a window update enabling
+// further writes to take place. Must be invoked with the lock.
+func (s *Stream) sendWindowUpdate() error {
+ s.controlHdrLock.Lock()
+ defer s.controlHdrLock.Unlock()
+
+ // Determine the delta update
+ max := s.session.config.MaxStreamWindowSize
+ var bufLen uint32
+ s.recvLock.Lock()
+ if s.recvBuf != nil {
+ bufLen = uint32(s.recvBuf.Len())
+ }
+ delta := (max - bufLen) - s.recvWindow
+
+ // Determine the flags if any
+ flags := s.sendFlags()
+
+ // Check if we can omit the update
+ if delta < (max/2) && flags == 0 {
+ s.recvLock.Unlock()
+ return nil
+ }
+
+ // Update our window
+ s.recvWindow += delta
+ s.recvLock.Unlock()
+
+ // Send the header
+ s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta)
+ if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil {
+ return err
+ }
+ return nil
+}
+
+// sendClose is used to send a FIN
+func (s *Stream) sendClose() error {
+ s.controlHdrLock.Lock()
+ defer s.controlHdrLock.Unlock()
+
+ flags := s.sendFlags()
+ flags |= flagFIN
+ s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0)
+ if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil {
+ return err
+ }
+ return nil
+}
+
+// Close is used to close the stream
+func (s *Stream) Close() error {
+ closeStream := false
+ s.stateLock.Lock()
+ switch s.state {
+ // Opened means we need to signal a close
+ case streamSYNSent:
+ fallthrough
+ case streamSYNReceived:
+ fallthrough
+ case streamEstablished:
+ s.state = streamLocalClose
+ goto SEND_CLOSE
+
+ case streamLocalClose:
+ case streamRemoteClose:
+ s.state = streamClosed
+ closeStream = true
+ goto SEND_CLOSE
+
+ case streamClosed:
+ case streamReset:
+ default:
+ panic("unhandled state")
+ }
+ s.stateLock.Unlock()
+ return nil
+SEND_CLOSE:
+ // This shouldn't happen (the more realistic scenario to cancel the
+ // timer is via processFlags) but just in case this ever happens, we
+ // cancel the timer to prevent dangling timers.
+ if s.closeTimer != nil {
+ s.closeTimer.Stop()
+ s.closeTimer = nil
+ }
+
+ // If we have a StreamCloseTimeout set we start the timeout timer.
+ // We do this only if we're not already closing the stream since that
+ // means this was a graceful close.
+ //
+ // This prevents memory leaks if one side (this side) closes and the
+ // remote side poorly behaves and never responds with a FIN to complete
+ // the close. After the specified timeout, we clean our resources up no
+ // matter what.
+ if !closeStream && s.session.config.StreamCloseTimeout > 0 {
+ s.closeTimer = time.AfterFunc(
+ s.session.config.StreamCloseTimeout, s.closeTimeout)
+ }
+
+ s.stateLock.Unlock()
+ s.sendClose()
+ s.notifyWaiting()
+ if closeStream {
+ s.session.closeStream(s.id)
+ }
+ return nil
+}
+
+// closeTimeout is called after StreamCloseTimeout during a close to
+// close this stream.
+func (s *Stream) closeTimeout() {
+ // Close our side forcibly
+ s.forceClose()
+
+ // Free the stream from the session map
+ s.session.closeStream(s.id)
+
+ // Send a RST so the remote side closes too.
+ s.sendLock.Lock()
+ defer s.sendLock.Unlock()
+ s.sendHdr.encode(typeWindowUpdate, flagRST, s.id, 0)
+ s.session.sendNoWait(s.sendHdr)
+}
+
+// forceClose is used for when the session is exiting
+func (s *Stream) forceClose() {
+ s.stateLock.Lock()
+ s.state = streamClosed
+ s.stateLock.Unlock()
+ s.notifyWaiting()
+}
+
+// processFlags is used to update the state of the stream
+// based on set flags, if any. Lock must be held
+func (s *Stream) processFlags(flags uint16) error {
+ s.stateLock.Lock()
+ defer s.stateLock.Unlock()
+
+ // Close the stream without holding the state lock
+ closeStream := false
+ defer func() {
+ if closeStream {
+ if s.closeTimer != nil {
+ // Stop our close timeout timer since we gracefully closed
+ s.closeTimer.Stop()
+ }
+
+ s.session.closeStream(s.id)
+ }
+ }()
+
+ if flags&flagACK == flagACK {
+ if s.state == streamSYNSent {
+ s.state = streamEstablished
+ }
+ s.session.establishStream(s.id)
+ }
+ if flags&flagFIN == flagFIN {
+ switch s.state {
+ case streamSYNSent:
+ fallthrough
+ case streamSYNReceived:
+ fallthrough
+ case streamEstablished:
+ s.state = streamRemoteClose
+ s.notifyWaiting()
+ case streamLocalClose:
+ s.state = streamClosed
+ closeStream = true
+ s.notifyWaiting()
+ default:
+ s.session.logger.Printf("[ERR] yamux: unexpected FIN flag in state %d", s.state)
+ return ErrUnexpectedFlag
+ }
+ }
+ if flags&flagRST == flagRST {
+ s.state = streamReset
+ closeStream = true
+ s.notifyWaiting()
+ }
+ return nil
+}
+
+// notifyWaiting notifies all the waiting channels
+func (s *Stream) notifyWaiting() {
+ asyncNotify(s.recvNotifyCh)
+ asyncNotify(s.sendNotifyCh)
+}
+
+// incrSendWindow updates the size of our send window
+func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
+ if err := s.processFlags(flags); err != nil {
+ return err
+ }
+
+ // Increase window, unblock a sender
+ atomic.AddUint32(&s.sendWindow, hdr.Length())
+ asyncNotify(s.sendNotifyCh)
+ return nil
+}
+
+// readData is used to handle a data frame
+func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
+ if err := s.processFlags(flags); err != nil {
+ return err
+ }
+
+ // Check that our recv window is not exceeded
+ length := hdr.Length()
+ if length == 0 {
+ return nil
+ }
+
+ // Wrap in a limited reader
+ conn = &io.LimitedReader{R: conn, N: int64(length)}
+
+ // Copy into buffer
+ s.recvLock.Lock()
+
+ if length > s.recvWindow {
+ s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length)
+ return ErrRecvWindowExceeded
+ }
+
+ if s.recvBuf == nil {
+ // Allocate the receive buffer just-in-time to fit the full data frame.
+ // This way we can read in the whole packet without further allocations.
+ s.recvBuf = bytes.NewBuffer(make([]byte, 0, length))
+ }
+ if _, err := io.Copy(s.recvBuf, conn); err != nil {
+ s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err)
+ s.recvLock.Unlock()
+ return err
+ }
+
+ // Decrement the receive window
+ s.recvWindow -= length
+ s.recvLock.Unlock()
+
+ // Unblock any readers
+ asyncNotify(s.recvNotifyCh)
+ return nil
+}
+
+// SetDeadline sets the read and write deadlines
+func (s *Stream) SetDeadline(t time.Time) error {
+ if err := s.SetReadDeadline(t); err != nil {
+ return err
+ }
+ if err := s.SetWriteDeadline(t); err != nil {
+ return err
+ }
+ return nil
+}
+
+// SetReadDeadline sets the deadline for blocked and future Read calls.
+func (s *Stream) SetReadDeadline(t time.Time) error {
+ s.readDeadline.Store(t)
+ asyncNotify(s.recvNotifyCh)
+ return nil
+}
+
+// SetWriteDeadline sets the deadline for blocked and future Write calls
+func (s *Stream) SetWriteDeadline(t time.Time) error {
+ s.writeDeadline.Store(t)
+ asyncNotify(s.sendNotifyCh)
+ return nil
+}
+
+// Shrink is used to compact the amount of buffers utilized
+// This is useful when using Yamux in a connection pool to reduce
+// the idle memory utilization.
+func (s *Stream) Shrink() {
+ s.recvLock.Lock()
+ if s.recvBuf != nil && s.recvBuf.Len() == 0 {
+ s.recvBuf = nil
+ }
+ s.recvLock.Unlock()
+}
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+util.go - github.com/hashicorp/yamux
+package yamux
+
+import (
+ "sync"
+ "time"
+)
+
+var (
+ timerPool = &sync.Pool{
+ New: func() interface{} {
+ timer := time.NewTimer(time.Hour * 1e6)
+ timer.Stop()
+ return timer
+ },
+ }
+)
+
+// asyncSendErr is used to try an async send of an error
+func asyncSendErr(ch chan error, err error) {
+ if ch == nil {
+ return
+ }
+ select {
+ case ch <- err:
+ default:
+ }
+}
+
+// asyncNotify is used to signal a waiting goroutine
+func asyncNotify(ch chan struct{}) {
+ select {
+ case ch <- struct{}{}:
+ default:
+ }
+}
+
+// min computes the minimum of two values
+func min(a, b uint32) uint32 {
+ if a < b {
+ return a
+ }
+ return b
+}
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+util_test.go - github.com/hashicorp/yamux
+package yamux
+
+import (
+ "testing"
+)
+
+func TestAsyncSendErr(t *testing.T) {
+ ch := make(chan error)
+ asyncSendErr(ch, ErrTimeout)
+ select {
+ case <-ch:
+ t.Fatalf("should not get")
+ default:
+ }
+
+ ch = make(chan error, 1)
+ asyncSendErr(ch, ErrTimeout)
+ select {
+ case <-ch:
+ default:
+ t.Fatalf("should get")
+ }
+}
+
+func TestAsyncNotify(t *testing.T) {
+ ch := make(chan struct{})
+ asyncNotify(ch)
+ select {
+ case <-ch:
+ t.Fatalf("should not get")
+ default:
+ }
+
+ ch = make(chan struct{}, 1)
+ asyncNotify(ch)
+ select {
+ case <-ch:
+ default:
+ t.Fatalf("should get")
+ }
+}
+
+func TestMin(t *testing.T) {
+ if min(1, 2) != 1 {
+ t.Fatalf("bad")
+ }
+ if min(2, 1) != 1 {
+ t.Fatalf("bad")
+ }
+}
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
LICENSE - github.com/jcmturner/gofork
Copyright (c) 2009 The Go Authors. All rights reserved.
diff --git a/go.mod b/go.mod
index f9c2d9ae9..9b67ddc36 100644
--- a/go.mod
+++ b/go.mod
@@ -23,6 +23,7 @@ require (
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
github.com/hashicorp/go-uuid v1.0.2 // indirect
github.com/hashicorp/golang-lru v0.5.4
+ github.com/hashicorp/yamux v0.0.0-20210316155119-a95892c5f864
github.com/kelseyhightower/envconfig v1.3.0
github.com/lib/pq v1.2.0
github.com/libgit2/git2go/v31 v31.4.12
diff --git a/go.sum b/go.sum
index 34b473729..ee105d591 100644
--- a/go.sum
+++ b/go.sum
@@ -196,6 +196,8 @@ github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ
github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc=
github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
+github.com/hashicorp/yamux v0.0.0-20210316155119-a95892c5f864 h1:Y4V+SFe7d3iH+9pJCoeWIOS5/xBJIFsltS7E+KJSsJY=
+github.com/hashicorp/yamux v0.0.0-20210316155119-a95892c5f864/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ=
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
@@ -304,14 +306,10 @@ github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+
github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc=
github.com/otiai10/copy v1.0.1 h1:gtBjD8aq4nychvRZ2CyJvFWAw0aja+VHazDdruZKGZA=
github.com/otiai10/copy v1.0.1/go.mod h1:8bMCJrAqOtN/d9oyh5HR7HhLQMvcGMpGdwRDYsfOCHc=
-github.com/otiai10/copy v1.0.1/go.mod h1:8bMCJrAqOtN/d9oyh5HR7HhLQMvcGMpGdwRDYsfOCHc=
-github.com/otiai10/copy v1.0.1/go.mod h1:8bMCJrAqOtN/d9oyh5HR7HhLQMvcGMpGdwRDYsfOCHc=
github.com/otiai10/curr v0.0.0-20150429015615-9b4961190c95/go.mod h1:9qAhocn7zKJG+0mI8eUu6xqkFDYS2kb2saOteoSB3cE=
github.com/otiai10/curr v1.0.0 h1:TJIWdbX0B+kpNagQrjgq8bCMrbhiuX73M2XwgtDMoOI=
github.com/otiai10/curr v1.0.0/go.mod h1:LskTG5wDwr8Rs+nNQ+1LlxRjAtTZZjtJW4rMXl6j4vs=
github.com/otiai10/mint v1.2.3/go.mod h1:YnfyPNhBvnY8bW4SGQHCs/aAFhkgySlMZbrF5U0bOVw=
-github.com/otiai10/mint v1.2.3/go.mod h1:YnfyPNhBvnY8bW4SGQHCs/aAFhkgySlMZbrF5U0bOVw=
-github.com/otiai10/mint v1.2.3/go.mod h1:YnfyPNhBvnY8bW4SGQHCs/aAFhkgySlMZbrF5U0bOVw=
github.com/otiai10/mint v1.3.0 h1:Ady6MKVezQwHBkGzLFbrsywyp09Ah7rkmfjV3Bcr5uc=
github.com/otiai10/mint v1.3.0/go.mod h1:F5AjcsTsWUqX+Na9fpHb52P8pcRX2CI6A3ctIT91xUo=
github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o=
@@ -360,7 +358,6 @@ github.com/sebest/xff v0.0.0-20160910043805-6c115e0ffa35/go.mod h1:wozgYq9WEBQBa
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
-github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.3.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
@@ -423,7 +420,6 @@ gitlab.com/gitlab-org/gitaly v1.68.0/go.mod h1:/pCsB918Zu5wFchZ9hLYin9WkJ2yQqdVN
gitlab.com/gitlab-org/gitlab-shell v1.9.8-0.20201117050822-3f9890ef73dc h1:ENBJqb2gK3cZFHe8dbEM/TPXCwkAxDuxGK255ux4XPg=
gitlab.com/gitlab-org/gitlab-shell v1.9.8-0.20201117050822-3f9890ef73dc/go.mod h1:5QSTbpAHY2v0iIH5uHh2KA9w7sPUqPmnLjDApI/sv1U=
gitlab.com/gitlab-org/labkit v0.0.0-20190221122536-0c3fc7cdd57c/go.mod h1:rYhLgfrbEcyfinG+R3EvKu6bZSsmwQqcXzLfHWSfUKM=
-gitlab.com/gitlab-org/labkit v0.0.0-20190221122536-0c3fc7cdd57c/go.mod h1:rYhLgfrbEcyfinG+R3EvKu6bZSsmwQqcXzLfHWSfUKM=
gitlab.com/gitlab-org/labkit v0.0.0-20200908084045-45895e129029/go.mod h1:SNfxkfUwVNECgtmluVayv0GWFgEjjBs5AzgsowPQuo0=
gitlab.com/gitlab-org/labkit v1.0.0 h1:t2Wr8ygtvHfXAMlCkoEdk5pdb5Gy1IYdr41H7t4kAYw=
gitlab.com/gitlab-org/labkit v1.0.0/go.mod h1:nohrYTSLDnZix0ebXZrbZJjymRar8HeV2roWL5/jw2U=
diff --git a/internal/backchannel/backchannel.go b/internal/backchannel/backchannel.go
new file mode 100644
index 000000000..b22fc71e0
--- /dev/null
+++ b/internal/backchannel/backchannel.go
@@ -0,0 +1,60 @@
+// Package backchannel implements connection multiplexing that allows for invoking
+// gRPC methods from the server to the client.
+//
+// gRPC allows only for invoking RPCs from client to the server. Invoking
+// RPCs from the server to the client can be useful in some cases such as
+// tunneling through firewalls. While implementing such a use case would be
+// possible with plain bidirectional streams, the approach has various limitations
+// that force additional work on the user. All messages in a single stream are ordered
+// and processed sequentially. If concurrency is desired, this would require the user
+// to implement their own concurrency handling. Request routing and cancellations would also
+// have to be implemented separately on top of the bidirectional stream.
+//
+// To do away with these problems, this package provides a multiplexed transport for running two
+// independent gRPC sessions on a single connection. This allows for dialing back to the client from
+// the server to establish another gRPC session where the server and client roles are switched.
+//
+// The server side supports clients that are unaware of the multiplexing. The server peeks the incoming
+// network stream to see if it starts with the magic bytes that indicate a multiplexing aware client.
+// If the magic bytes are present, the server initiates the multiplexing session and dials back to the client
+// over the already established network connection. If the magic bytes are not present, the server restores the
+// the bytes back into the original network stream and handles it without a multiplexing session.
+//
+// Usage:
+// 1. Implement a ServerFactory, which is simply a function that returns a Server that can serve on the backchannel
+// connection. Plug in the ClientHandshake returned by the ServerFactory.ClientHandshaker via grpc.WithTransportCredentials.
+// This ensures all connections established by gRPC work with a multiplexing session and have a backchannel Server serving.
+// 2. Configure the ServerHandshake on the server side by passing it into the gRPC server via the grpc.Creds option.
+// The ServerHandshake method is called on each newly established connection. It peeks the network stream to see if a
+// multiplexing session should be initiated. If so, it also dials back to the client's backchannel server. Server
+// makes the backchannel connection's available later via the Registry's Backchannel method. The ID of the
+// peer associated with the current RPC handler can be fetched via GetPeerID. The returned ID can be used
+// to access the correct backchannel connection from the Registry.
+package backchannel
+
+import (
+ "io"
+ "net"
+
+ "github.com/hashicorp/yamux"
+)
+
+// magicBytes are sent by the client to server to identify as a multiplexing aware client.
+var magicBytes = []byte("backchannel")
+
+// muxConfig returns a new config to use with the multiplexing session.
+func muxConfig(logger io.Writer) *yamux.Config {
+ cfg := yamux.DefaultConfig()
+ cfg.LogOutput = logger
+ return cfg
+}
+
+// connCloser wraps a net.Conn and calls the provided close function instead when Close
+// is called.
+type connCloser struct {
+ net.Conn
+ close func() error
+}
+
+// Close calls the provided close function.
+func (cc connCloser) Close() error { return cc.close() }
diff --git a/internal/backchannel/backchannel_example_test.go b/internal/backchannel/backchannel_example_test.go
new file mode 100644
index 000000000..20387b2de
--- /dev/null
+++ b/internal/backchannel/backchannel_example_test.go
@@ -0,0 +1,139 @@
+package backchannel_test
+
+import (
+ "context"
+ "fmt"
+ "net"
+
+ "github.com/sirupsen/logrus"
+ "gitlab.com/gitlab-org/gitaly/internal/backchannel"
+ "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+ "google.golang.org/grpc"
+)
+
+func Example() {
+ // Open the server's listener.
+ ln, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ fmt.Printf("failed to start listener: %v", err)
+ return
+ }
+
+ // Registry is for storing the open backchannels. It should be passed into the ServerHandshaker
+ // which creates the backchannel connections and adds them to the registry. The RPC handlers
+ // can use the registry to access available backchannels by their peer ID.
+ registry := backchannel.NewRegistry()
+
+ logger := logrus.NewEntry(logrus.New())
+
+ // ServerHandshaker initiates the multiplexing session on the server side. Once that is done,
+ // it creates the backchannel connection and stores it into the registry. For each connection,
+ // the ServerHandshaker passes down the peer ID via the context. The peer ID identifies a
+ // backchannel connection.
+ handshaker := backchannel.NewServerHandshaker(logger, backchannel.Insecure(), registry)
+
+ // Create the server
+ srv := grpc.NewServer(
+ grpc.Creds(handshaker),
+ grpc.UnknownServiceHandler(func(srv interface{}, stream grpc.ServerStream) error {
+ fmt.Println("Gitaly received a transactional mutator")
+
+ backchannelID, err := backchannel.GetPeerID(stream.Context())
+ if err == backchannel.ErrNonMultiplexedConnection {
+ // This call is from a client that is not multiplexing aware. Client is not
+ // Praefect, so no need to perform voting. The client could be for example
+ // GitLab calling Gitaly directly.
+ fmt.Println("Gitaly responding to a non-multiplexed client")
+ return stream.SendMsg(&gitalypb.CreateBranchResponse{})
+ } else if err != nil {
+ return fmt.Errorf("get peer id: %w", err)
+ }
+
+ backchannelConn, err := registry.Backchannel(backchannelID)
+ if err != nil {
+ return fmt.Errorf("get backchannel: %w", err)
+ }
+
+ fmt.Println("Gitaly sending vote to Praefect via backchannel")
+ if err := backchannelConn.Invoke(
+ stream.Context(), "/Praefect/VoteTransaction",
+ &gitalypb.VoteTransactionRequest{}, &gitalypb.VoteTransactionResponse{},
+ ); err != nil {
+ return fmt.Errorf("invoke backchannel: %w", err)
+ }
+ fmt.Println("Gitaly received vote response via backchannel")
+
+ fmt.Println("Gitaly responding to the transactional mutator")
+ return stream.SendMsg(&gitalypb.CreateBranchResponse{})
+ }),
+ )
+ defer srv.Stop()
+
+ // Start the server
+ go func() {
+ if err := srv.Serve(ln); err != nil {
+ fmt.Printf("failed to serve: %v", err)
+ }
+ }()
+
+ fmt.Printf("Invoke with a multiplexed client:\n\n")
+ if err := invokeWithMuxedClient(logger, ln.Addr().String()); err != nil {
+ fmt.Printf("failed to invoke with muxed client: %v", err)
+ return
+ }
+
+ fmt.Printf("\nInvoke with a non-multiplexed client:\n\n")
+ if err := invokeWithNormalClient(ln.Addr().String()); err != nil {
+ fmt.Printf("failed to invoke with non-muxed client: %v", err)
+ return
+ }
+ // Output:
+ // Invoke with a multiplexed client:
+ //
+ // Gitaly received a transactional mutator
+ // Gitaly sending vote to Praefect via backchannel
+ // Praefect received vote via backchannel
+ // Praefect responding via backchannel
+ // Gitaly received vote response via backchannel
+ // Gitaly responding to the transactional mutator
+ //
+ // Invoke with a non-multiplexed client:
+ //
+ // Gitaly received a transactional mutator
+ // Gitaly responding to a non-multiplexed client
+}
+
+func invokeWithMuxedClient(logger *logrus.Entry, address string) error {
+ // serverFactory gets called on each established connection. The Server it returns
+ // is started on Praefect's end of the connection, which Gitaly can call.
+ serverFactory := backchannel.ServerFactory(func() backchannel.Server {
+ return grpc.NewServer(grpc.UnknownServiceHandler(func(srv interface{}, stream grpc.ServerStream) error {
+ fmt.Println("Praefect received vote via backchannel")
+ fmt.Println("Praefect responding via backchannel")
+ return stream.SendMsg(&gitalypb.VoteTransactionResponse{})
+ }))
+ })
+
+ return invokeWithOpts(address, grpc.WithTransportCredentials(serverFactory.ClientHandshaker(logger, backchannel.Insecure())))
+}
+
+func invokeWithNormalClient(address string) error {
+ return invokeWithOpts(address, grpc.WithInsecure())
+}
+
+func invokeWithOpts(address string, opts ...grpc.DialOption) error {
+ clientConn, err := grpc.Dial(address, opts...)
+ if err != nil {
+ return fmt.Errorf("dial server: %w", err)
+ }
+
+ if err := clientConn.Invoke(context.Background(), "/Gitaly/Mutator", &gitalypb.CreateBranchRequest{}, &gitalypb.CreateBranchResponse{}); err != nil {
+ return fmt.Errorf("call server: %w", err)
+ }
+
+ if err := clientConn.Close(); err != nil {
+ return fmt.Errorf("close clientConn: %w", err)
+ }
+
+ return nil
+}
diff --git a/internal/backchannel/backchannel_test.go b/internal/backchannel/backchannel_test.go
new file mode 100644
index 000000000..6b93153cb
--- /dev/null
+++ b/internal/backchannel/backchannel_test.go
@@ -0,0 +1,223 @@
+package backchannel
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net"
+ "sync"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitaly/internal/testhelper"
+ "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+)
+
+type mockTransactionServer struct {
+ voteTransactionFunc func(context.Context, *gitalypb.VoteTransactionRequest) (*gitalypb.VoteTransactionResponse, error)
+ *gitalypb.UnimplementedRefTransactionServer
+}
+
+func (m mockTransactionServer) VoteTransaction(ctx context.Context, req *gitalypb.VoteTransactionRequest) (*gitalypb.VoteTransactionResponse, error) {
+ return m.voteTransactionFunc(ctx, req)
+}
+
+func TestBackchannel_concurrentRequestsFromMultipleClients(t *testing.T) {
+ registry := NewRegistry()
+ handshaker := NewServerHandshaker(testhelper.DiscardTestEntry(t), Insecure(), registry)
+
+ ln, err := net.Listen("tcp", "localhost:0")
+ require.NoError(t, err)
+
+ errNonMultiplexed := status.Error(codes.FailedPrecondition, ErrNonMultiplexedConnection.Error())
+ srv := grpc.NewServer(grpc.Creds(handshaker))
+
+ gitalypb.RegisterRefTransactionServer(srv, mockTransactionServer{
+ voteTransactionFunc: func(ctx context.Context, req *gitalypb.VoteTransactionRequest) (*gitalypb.VoteTransactionResponse, error) {
+ peerID, err := GetPeerID(ctx)
+ if err == ErrNonMultiplexedConnection {
+ return nil, errNonMultiplexed
+ }
+ assert.NoError(t, err)
+
+ cc, err := registry.Backchannel(peerID)
+ if !assert.NoError(t, err) {
+ return nil, err
+ }
+
+ return gitalypb.NewRefTransactionClient(cc).VoteTransaction(ctx, req)
+ },
+ })
+
+ defer srv.Stop()
+ go srv.Serve(ln)
+
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ start := make(chan struct{})
+
+ // Create 25 multiplexed clients and non-multiplexed clients that launch requests
+ // concurrently.
+ var wg sync.WaitGroup
+ for i := uint64(0); i < 25; i++ {
+ i := i
+ wg.Add(2)
+
+ go func() {
+ defer wg.Done()
+
+ <-start
+ client, err := grpc.Dial(ln.Addr().String(), grpc.WithInsecure())
+ if !assert.NoError(t, err) {
+ return
+ }
+
+ resp, err := gitalypb.NewRefTransactionClient(client).VoteTransaction(ctx, &gitalypb.VoteTransactionRequest{})
+ assert.Equal(t, err, errNonMultiplexed)
+ assert.Nil(t, resp)
+
+ assert.NoError(t, client.Close())
+ }()
+
+ go func() {
+ defer wg.Done()
+
+ expectedErr := status.Error(codes.Internal, fmt.Sprintf("multiplexed %d", i))
+ serverFactory := ServerFactory(func() Server {
+ srv := grpc.NewServer()
+ gitalypb.RegisterRefTransactionServer(srv, mockTransactionServer{
+ voteTransactionFunc: func(ctx context.Context, req *gitalypb.VoteTransactionRequest) (*gitalypb.VoteTransactionResponse, error) {
+ assert.Equal(t, &gitalypb.VoteTransactionRequest{TransactionId: i}, req)
+ return nil, expectedErr
+ },
+ })
+
+ return srv
+ })
+
+ <-start
+ client, err := grpc.Dial(ln.Addr().String(),
+ grpc.WithTransportCredentials(serverFactory.ClientHandshaker(testhelper.DiscardTestEntry(t), Insecure())),
+ )
+ if !assert.NoError(t, err) {
+ return
+ }
+
+ // Run two invocations concurrently on each multiplexed client to sanity check
+ // the routing works with multiple requests from a connection.
+ var invocations sync.WaitGroup
+ for invocation := 0; invocation < 2; invocation++ {
+ invocations.Add(1)
+ go func() {
+ defer invocations.Done()
+ resp, err := gitalypb.NewRefTransactionClient(client).VoteTransaction(ctx, &gitalypb.VoteTransactionRequest{TransactionId: i})
+ assert.Equal(t, err, expectedErr)
+ assert.Nil(t, resp)
+ }()
+ }
+
+ invocations.Wait()
+ assert.NoError(t, client.Close())
+ }()
+ }
+
+ // Establish the connection and fire the requests.
+ close(start)
+
+ // Wait for the clients to finish their calls and close their connections.
+ wg.Wait()
+}
+
+type mockSSHService struct {
+ sshUploadPackFunc func(gitalypb.SSHService_SSHUploadPackServer) error
+ *gitalypb.UnimplementedSSHServiceServer
+}
+
+func (m mockSSHService) SSHUploadPack(stream gitalypb.SSHService_SSHUploadPackServer) error {
+ return m.sshUploadPackFunc(stream)
+}
+
+func Benchmark(b *testing.B) {
+ for _, tc := range []struct {
+ desc string
+ multiplexed bool
+ }{
+ {desc: "multiplexed", multiplexed: true},
+ {desc: "normal"},
+ } {
+ b.Run(tc.desc, func(b *testing.B) {
+ for _, messageSize := range []int64{
+ 1024,
+ 1024 * 1024,
+ 3 * 1024 * 1024,
+ } {
+ b.Run(fmt.Sprintf("message size %dkb", messageSize/1024), func(b *testing.B) {
+ var serverOpts []grpc.ServerOption
+ if tc.multiplexed {
+ serverOpts = []grpc.ServerOption{
+ grpc.Creds(NewServerHandshaker(testhelper.DiscardTestEntry(b), Insecure(), NewRegistry())),
+ }
+ }
+
+ srv := grpc.NewServer(serverOpts...)
+ gitalypb.RegisterSSHServiceServer(srv, mockSSHService{
+ sshUploadPackFunc: func(stream gitalypb.SSHService_SSHUploadPackServer) error {
+ for {
+ _, err := stream.Recv()
+ if err != nil {
+ assert.Equal(b, io.EOF, err)
+ return nil
+ }
+ }
+ },
+ })
+
+ ln, err := net.Listen("tcp", "localhost:0")
+ require.NoError(b, err)
+
+ defer srv.Stop()
+ go srv.Serve(ln)
+
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ opts := []grpc.DialOption{grpc.WithBlock(), grpc.WithInsecure()}
+ if tc.multiplexed {
+ nopServer := ServerFactory(func() Server { return grpc.NewServer() })
+ opts = []grpc.DialOption{
+ grpc.WithBlock(),
+ grpc.WithTransportCredentials(nopServer.ClientHandshaker(
+ testhelper.DiscardTestEntry(b), Insecure(),
+ )),
+ }
+ }
+
+ cc, err := grpc.DialContext(ctx, ln.Addr().String(), opts...)
+ require.NoError(b, err)
+
+ defer cc.Close()
+
+ client, err := gitalypb.NewSSHServiceClient(cc).SSHUploadPack(ctx)
+ require.NoError(b, err)
+
+ request := &gitalypb.SSHUploadPackRequest{Stdin: make([]byte, messageSize)}
+ b.SetBytes(messageSize)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ require.NoError(b, client.Send(request))
+ }
+
+ require.NoError(b, client.CloseSend())
+ _, err = client.Recv()
+ require.Equal(b, io.EOF, err)
+ })
+ }
+ })
+ }
+}
diff --git a/internal/backchannel/client.go b/internal/backchannel/client.go
new file mode 100644
index 000000000..a844d6d7f
--- /dev/null
+++ b/internal/backchannel/client.go
@@ -0,0 +1,125 @@
+package backchannel
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "time"
+
+ "github.com/hashicorp/yamux"
+ "github.com/sirupsen/logrus"
+ "google.golang.org/grpc/credentials"
+)
+
+// Server is the interface of a backchannel server.
+type Server interface {
+ // Serve starts serving on the listener.
+ Serve(net.Listener) error
+ // Stops the server and closes all connections.
+ Stop()
+}
+
+// ServerFactory returns the server that should serve on the backchannel.
+// Each invocation should return a new server as the servers get stopped when
+// a backchannel closes.
+type ServerFactory func() Server
+
+// ClientHandshaker returns TransportCredentials that perform the client side multiplexing handshake and
+// start the backchannel Server on the established connections. The provided logger is used to log multiplexing
+// errors and the transport credentials are used to intiliaze the connection prior to the multiplexing.
+func (sf ServerFactory) ClientHandshaker(logger *logrus.Entry, tc credentials.TransportCredentials) credentials.TransportCredentials {
+ return clientHandshaker{TransportCredentials: tc, serverFactory: sf, logger: logger}
+}
+
+type clientHandshaker struct {
+ credentials.TransportCredentials
+ serverFactory ServerFactory
+ logger *logrus.Entry
+}
+
+func (ch clientHandshaker) ClientHandshake(ctx context.Context, serverName string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
+ conn, authInfo, err := ch.TransportCredentials.ClientHandshake(ctx, serverName, conn)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ clientStream, err := ch.serve(ctx, conn)
+ if err != nil {
+ return nil, nil, fmt.Errorf("serve: %w", err)
+ }
+
+ return clientStream, authInfo, nil
+}
+
+func (ch clientHandshaker) serve(ctx context.Context, conn net.Conn) (net.Conn, error) {
+ deadline := time.Time{}
+ if dl, ok := ctx.Deadline(); ok {
+ deadline = dl
+ }
+
+ // gRPC expects the ClientHandshaker implementation to respect the deadline set in the context.
+ if err := conn.SetDeadline(deadline); err != nil {
+ return nil, fmt.Errorf("set connection deadline: %w", err)
+ }
+
+ defer func() {
+ // The deadline has to be cleared after the muxing session is established as we are not
+ // returning the Conn itself but the stream, thus gRPC can't clear the deadline we set
+ // on the Conn.
+ if err := conn.SetDeadline(time.Time{}); err != nil {
+ ch.logger.WithError(err).Error("remove connection deadline")
+ }
+ }()
+
+ // Write the magic bytes on the connection so the server knows we're about to initiate
+ // a multiplexing session.
+ if _, err := conn.Write(magicBytes); err != nil {
+ return nil, fmt.Errorf("write backchannel magic bytes: %w", err)
+ }
+
+ logger := ch.logger.WriterLevel(logrus.ErrorLevel)
+
+ // Initiate the multiplexing session.
+ muxSession, err := yamux.Client(conn, muxConfig(logger))
+ if err != nil {
+ logger.Close()
+ return nil, fmt.Errorf("open multiplexing session: %w", err)
+ }
+
+ go func() {
+ <-muxSession.CloseChan()
+ logger.Close()
+ }()
+
+ // Initiate the stream to the server. This is used by the client's gRPC session.
+ clientToServer, err := muxSession.Open()
+ if err != nil {
+ return nil, fmt.Errorf("open client stream: %w", err)
+ }
+
+ // Run the backchannel server.
+ server := ch.serverFactory()
+ serveErr := make(chan error, 1)
+ go func() { serveErr <- server.Serve(muxSession) }()
+
+ return connCloser{
+ Conn: clientToServer,
+ close: func() error {
+ // Stop closes the listener, which is the muxing session. Closing the
+ // muxing session closes the underlying network connection.
+ //
+ // There's no sense in doing a graceful shutdown. The connection is being closed,
+ // it would no longer receive a response from the server.
+ server.Stop()
+ // Serve returns a non-nil error if it returned before Stop was called. If the error
+ // is non-nil, it indicates a serving failure prior to calling Stop.
+ return <-serveErr
+ }}, nil
+}
+
+func (ch clientHandshaker) Clone() credentials.TransportCredentials {
+ return clientHandshaker{
+ TransportCredentials: ch.TransportCredentials.Clone(),
+ serverFactory: ch.serverFactory,
+ }
+}
diff --git a/internal/backchannel/insecure.go b/internal/backchannel/insecure.go
new file mode 100644
index 000000000..678a90527
--- /dev/null
+++ b/internal/backchannel/insecure.go
@@ -0,0 +1,40 @@
+package backchannel
+
+import (
+ "context"
+ "net"
+
+ "google.golang.org/grpc/credentials"
+)
+
+type insecureAuthInfo struct{ credentials.CommonAuthInfo }
+
+func (insecureAuthInfo) AuthType() string { return "insecure" }
+
+type insecure struct{}
+
+func (insecure) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
+ return conn, insecureAuthInfo{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
+}
+
+func (insecure) ClientHandshake(_ context.Context, _ string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
+ return conn, insecureAuthInfo{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
+}
+
+func (insecure) Info() credentials.ProtocolInfo {
+ return credentials.ProtocolInfo{SecurityProtocol: "insecure"}
+}
+
+func (insecure) Clone() credentials.TransportCredentials { return Insecure() }
+
+func (insecure) OverrideServerName(string) error { return nil }
+
+// Insecure can be used in place of transport credentials when no transport security is configured.
+// Its handshakes simply return the passed in connection.
+//
+// Similar credentials are already implemented in gRPC:
+// https://github.com/grpc/grpc-go/blob/702608ffae4d03a6821b96d3e2311973d34b96dc/credentials/insecure/insecure.go
+// We've reimplemented these here as upgrading our gRPC version was very involved. Once
+// we've upgrade to a version that contains the insecure credentials, this implementation can be removed and
+// substituted by the official implementation.
+func Insecure() credentials.TransportCredentials { return insecure{} }
diff --git a/internal/backchannel/registry.go b/internal/backchannel/registry.go
new file mode 100644
index 000000000..407d770f9
--- /dev/null
+++ b/internal/backchannel/registry.go
@@ -0,0 +1,51 @@
+package backchannel
+
+import (
+ "fmt"
+ "sync"
+
+ "google.golang.org/grpc"
+)
+
+// ID is a monotonically increasing number that uniquely identifies a peer connection.
+type ID uint64
+
+// Registry is a thread safe registry for backchannels. It enables accessing the backchannels via a
+// unique ID.
+type Registry struct {
+ m sync.RWMutex
+ currentID ID
+ backchannels map[ID]*grpc.ClientConn
+}
+
+// NewRegistry returns a new Registry.
+func NewRegistry() *Registry { return &Registry{backchannels: map[ID]*grpc.ClientConn{}} }
+
+// Backchannel returns a backchannel for the ID. Returns an error if no backchannel is registered
+// for the ID.
+func (r *Registry) Backchannel(id ID) (*grpc.ClientConn, error) {
+ r.m.RLock()
+ defer r.m.RUnlock()
+ backchannel, ok := r.backchannels[id]
+ if !ok {
+ return nil, fmt.Errorf("no backchannel for peer %d", id)
+ }
+
+ return backchannel, nil
+}
+
+// RegisterBackchannel registers a new backchannel and returns its unique ID.
+func (r *Registry) RegisterBackchannel(conn *grpc.ClientConn) ID {
+ r.m.Lock()
+ defer r.m.Unlock()
+ r.currentID++
+ r.backchannels[r.currentID] = conn
+ return r.currentID
+}
+
+// RemoveBackchannel removes a backchannel from the registry.
+func (r *Registry) RemoveBackchannel(id ID) {
+ r.m.Lock()
+ defer r.m.Unlock()
+ delete(r.backchannels, id)
+}
diff --git a/internal/backchannel/server.go b/internal/backchannel/server.go
new file mode 100644
index 000000000..44a769ca9
--- /dev/null
+++ b/internal/backchannel/server.go
@@ -0,0 +1,146 @@
+package backchannel
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net"
+
+ "github.com/hashicorp/yamux"
+ "github.com/sirupsen/logrus"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/peer"
+)
+
+// ErrNonMultiplexedConnection is returned when attempting to get the peer id of a non-multiplexed
+// connection.
+var ErrNonMultiplexedConnection = errors.New("non-multiplexed connection")
+
+// authInfoWrapper is used to pass the peer id through the context to the RPC handlers.
+type authInfoWrapper struct {
+ id ID
+ credentials.AuthInfo
+}
+
+func (w authInfoWrapper) peerID() ID { return w.id }
+
+// GetPeerID gets the ID of the current peer connection.
+func GetPeerID(ctx context.Context) (ID, error) {
+ peerInfo, ok := peer.FromContext(ctx)
+ if !ok {
+ return 0, errors.New("no peer info in context")
+ }
+
+ wrapper, ok := peerInfo.AuthInfo.(interface{ peerID() ID })
+ if !ok {
+ return 0, ErrNonMultiplexedConnection
+ }
+
+ return wrapper.peerID(), nil
+}
+
+// ServerHandshaker implements the server side handshake of the multiplexed connection.
+type ServerHandshaker struct {
+ registry *Registry
+ logger *logrus.Entry
+ credentials.TransportCredentials
+}
+
+// NewServerHandshaker returns a new server side implementation of the backchannel.
+func NewServerHandshaker(logger *logrus.Entry, tc credentials.TransportCredentials, reg *Registry) credentials.TransportCredentials {
+ return ServerHandshaker{
+ TransportCredentials: tc,
+ registry: reg,
+ logger: logger,
+ }
+}
+
+// restoredConn allows for restoring the connection's stream after peeking it. If the connection
+// was not multiplexed, the peeked bytes are restored back into the stream.
+type restoredConn struct {
+ net.Conn
+ reader io.Reader
+}
+
+func (rc *restoredConn) Read(b []byte) (int, error) { return rc.reader.Read(b) }
+
+// ServerHandshake peeks the connection to determine whether the client supports establishing a
+// backchannel by multiplexing the network connection. If so, it establishes a gRPC ClientConn back
+// to the client and stores it's ID in the AuthInfo where it can be later accessed by the RPC handlers.
+// gRPC sets an IO timeout on the connection before calling ServerHandshake, so we don't have to handle
+// timeouts separately.
+func (s ServerHandshaker) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
+ conn, authInfo, err := s.TransportCredentials.ServerHandshake(conn)
+ if err != nil {
+ return nil, nil, fmt.Errorf("wrapped server handshake: %w", err)
+ }
+
+ peeked, err := ioutil.ReadAll(io.LimitReader(conn, int64(len(magicBytes))))
+ if err != nil {
+ return nil, nil, fmt.Errorf("peek network stream: %w", err)
+ }
+
+ if !bytes.Equal(peeked, magicBytes) {
+ // If the client connection is not multiplexed, restore the peeked bytes back into the stream.
+ // We also set a 0 peer ID in the authInfo to indicate that the server handshake was attempted
+ // but this was not a multiplexed connection.
+ return &restoredConn{
+ Conn: conn,
+ reader: io.MultiReader(bytes.NewReader(peeked), conn),
+ }, authInfo, nil
+ }
+
+ // It is not necessary to clean up any of the multiplexing-related sessions on errors as the
+ // gRPC server closes the conn if there is an error, which closes the multiplexing
+ // session as well.
+
+ logger := s.logger.WriterLevel(logrus.ErrorLevel)
+
+ // Open the server side of the multiplexing session.
+ muxSession, err := yamux.Server(conn, muxConfig(logger))
+ if err != nil {
+ logger.Close()
+ return nil, nil, fmt.Errorf("create multiplexing session: %w", err)
+ }
+
+ // Accept the client's stream. This is the client's gRPC session to the server.
+ clientToServerStream, err := muxSession.Accept()
+ if err != nil {
+ logger.Close()
+ return nil, nil, fmt.Errorf("accept client's stream: %w", err)
+ }
+
+ // The address does not actually matter but we set it so clientConn.Target returns a meaningful value.
+ // WithInsecure is used as the multiplexer operates within a TLS session already if one is configured.
+ backchannelConn, err := grpc.Dial(
+ "multiplexed/"+conn.RemoteAddr().String(),
+ grpc.WithInsecure(),
+ grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { return muxSession.Open() }),
+ )
+ if err != nil {
+ logger.Close()
+ return nil, nil, fmt.Errorf("dial backchannel: %w", err)
+ }
+
+ id := s.registry.RegisterBackchannel(backchannelConn)
+ // The returned connection must close the underlying network connection, we redirect the close
+ // to the muxSession which also closes the underlying connection.
+ return connCloser{
+ Conn: clientToServerStream,
+ close: func() error {
+ s.registry.RemoveBackchannel(id)
+ backchannelConn.Close()
+ muxSession.Close()
+ logger.Close()
+ return nil
+ },
+ },
+ authInfoWrapper{
+ id: id,
+ AuthInfo: authInfo,
+ }, nil
+}