gitlab.com/gitlab-org/gitaly.git
author    Sami Hiltunen <shiltunen@gitlab.com>    2021-03-17 14:11:31 +0300
committer Sami Hiltunen <shiltunen@gitlab.com>    2021-03-29 14:32:55 +0300
commit    0aa12d0e309cd7eccb95b2a228078c4c4fc9b7a8 (patch)
tree      8fa69088f999ef787a29e82e1b9fd9bae253b723
parent    dec655a6dbf2a93b09ba738f94f74132e5bdcc7f (diff)
add backchannel package for bidirectional gRPC invocations
When Gitaly needs to cast a vote for a reference transaction, it currently relies on Praefect to provide the address information and tokens needed to dial Praefect for voting. This can be problematic, as Praefect's listening socket may not be reachable from Gitaly. Configuring TLS poses additional challenges, as Gitaly would need the certificates to identify Praefect.

Praefect already establishes a connection to Gitaly, so we can avoid these problems by piggybacking Gitaly's votes through that same connection. This piggybacking could be implemented as a persistent bidirectional stream that Praefect opens to Gitaly. However, that would leave many of the complexities gRPC normally manages up to us, such as request routing, concurrency handling, cancellations, and request packing.

To leverage gRPC for the above, we can instead multiplex the network connection established by Praefect and run multiple gRPC/HTTP2 sessions in it. This allows Gitaly to dial a gRPC server running on Praefect's end of the established connection, so we can communicate with the transaction service running in Praefect simply by using gRPC in the usual fashion.

The implementation has components on both the client and the server. The client implements a ServerFactory, a function that returns the server that should serve on the client's end of the connection. It implements a ClientHandshake which multiplexes the connection and starts the backchannel server on it. The resulting gRPC ClientConn looks like a normal ClientConn to the rest of the code.

The server implements most of its functionality in gRPC's ServerHandshake. When a connection is established, the server peeks the stream to see whether the client has indicated that it supports multiplexing. The client indicates this by sending magic bytes to the server right after establishing the connection. If the client is not multiplexing-aware, the server handles the connection as usual.

If the client is multiplexing-aware, the server establishes a multiplexing session, dials the client's backchannel server, and starts serving the connection as usual. The ServerHandshake injects an identifier which can be accessed through the RPC handler's context. The ID can be used to retrieve the peer's backchannel connection when it is needed, namely when votes need to be cast. Connections are uniquely identified by an incrementing counter, so no ID can ever refer to the wrong peer.
-rw-r--r--  NOTICE                                            3995
-rw-r--r--  go.mod                                               1
-rw-r--r--  go.sum                                               8
-rw-r--r--  internal/backchannel/backchannel.go                 60
-rw-r--r--  internal/backchannel/backchannel_example_test.go   139
-rw-r--r--  internal/backchannel/backchannel_test.go           223
-rw-r--r--  internal/backchannel/client.go                     125
-rw-r--r--  internal/backchannel/insecure.go                    40
-rw-r--r--  internal/backchannel/registry.go                    51
-rw-r--r--  internal/backchannel/server.go                     146
10 files changed, 4782 insertions, 6 deletions
diff --git a/NOTICE b/NOTICE
index ddd6bcfb9..261bb9fbd 100644
--- a/NOTICE
+++ b/NOTICE
@@ -5203,6 +5203,4001 @@ func TestLRU_Resize(t *testing.T) {
}
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.gitignore - github.com/hashicorp/yamux
+# Compiled Object files, Static and Dynamic libs (Shared Objects)
+*.o
+*.a
+*.so
+
+# Folders
+_obj
+_test
+
+# Architecture specific extensions/prefixes
+*.[568vq]
+[568vq].out
+
+*.cgo1.go
+*.cgo2.c
+_cgo_defun.c
+_cgo_gotypes.go
+_cgo_export.*
+
+_testmain.go
+
+*.exe
+*.test
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+LICENSE - github.com/hashicorp/yamux
+Mozilla Public License, version 2.0
+
+1. Definitions
+
+1.1. "Contributor"
+
+ means each individual or legal entity that creates, contributes to the
+ creation of, or owns Covered Software.
+
+1.2. "Contributor Version"
+
+ means the combination of the Contributions of others (if any) used by a
+ Contributor and that particular Contributor's Contribution.
+
+1.3. "Contribution"
+
+ means Covered Software of a particular Contributor.
+
+1.4. "Covered Software"
+
+ means Source Code Form to which the initial Contributor has attached the
+ notice in Exhibit A, the Executable Form of such Source Code Form, and
+ Modifications of such Source Code Form, in each case including portions
+ thereof.
+
+1.5. "Incompatible With Secondary Licenses"
+ means
+
+ a. that the initial Contributor has attached the notice described in
+ Exhibit B to the Covered Software; or
+
+ b. that the Covered Software was made available under the terms of
+ version 1.1 or earlier of the License, but not also under the terms of
+ a Secondary License.
+
+1.6. "Executable Form"
+
+ means any form of the work other than Source Code Form.
+
+1.7. "Larger Work"
+
+ means a work that combines Covered Software with other material, in a
+ separate file or files, that is not Covered Software.
+
+1.8. "License"
+
+ means this document.
+
+1.9. "Licensable"
+
+ means having the right to grant, to the maximum extent possible, whether
+ at the time of the initial grant or subsequently, any and all of the
+ rights conveyed by this License.
+
+1.10. "Modifications"
+
+ means any of the following:
+
+ a. any file in Source Code Form that results from an addition to,
+ deletion from, or modification of the contents of Covered Software; or
+
+ b. any new file in Source Code Form that contains any Covered Software.
+
+1.11. "Patent Claims" of a Contributor
+
+ means any patent claim(s), including without limitation, method,
+ process, and apparatus claims, in any patent Licensable by such
+ Contributor that would be infringed, but for the grant of the License,
+ by the making, using, selling, offering for sale, having made, import,
+ or transfer of either its Contributions or its Contributor Version.
+
+1.12. "Secondary License"
+
+ means either the GNU General Public License, Version 2.0, the GNU Lesser
+ General Public License, Version 2.1, the GNU Affero General Public
+ License, Version 3.0, or any later versions of those licenses.
+
+1.13. "Source Code Form"
+
+ means the form of the work preferred for making modifications.
+
+1.14. "You" (or "Your")
+
+ means an individual or a legal entity exercising rights under this
+ License. For legal entities, "You" includes any entity that controls, is
+ controlled by, or is under common control with You. For purposes of this
+ definition, "control" means (a) the power, direct or indirect, to cause
+ the direction or management of such entity, whether by contract or
+ otherwise, or (b) ownership of more than fifty percent (50%) of the
+ outstanding shares or beneficial ownership of such entity.
+
+
+2. License Grants and Conditions
+
+2.1. Grants
+
+ Each Contributor hereby grants You a world-wide, royalty-free,
+ non-exclusive license:
+
+ a. under intellectual property rights (other than patent or trademark)
+ Licensable by such Contributor to use, reproduce, make available,
+ modify, display, perform, distribute, and otherwise exploit its
+ Contributions, either on an unmodified basis, with Modifications, or
+ as part of a Larger Work; and
+
+ b. under Patent Claims of such Contributor to make, use, sell, offer for
+ sale, have made, import, and otherwise transfer either its
+ Contributions or its Contributor Version.
+
+2.2. Effective Date
+
+ The licenses granted in Section 2.1 with respect to any Contribution
+ become effective for each Contribution on the date the Contributor first
+ distributes such Contribution.
+
+2.3. Limitations on Grant Scope
+
+ The licenses granted in this Section 2 are the only rights granted under
+ this License. No additional rights or licenses will be implied from the
+ distribution or licensing of Covered Software under this License.
+ Notwithstanding Section 2.1(b) above, no patent license is granted by a
+ Contributor:
+
+ a. for any code that a Contributor has removed from Covered Software; or
+
+ b. for infringements caused by: (i) Your and any other third party's
+ modifications of Covered Software, or (ii) the combination of its
+ Contributions with other software (except as part of its Contributor
+ Version); or
+
+ c. under Patent Claims infringed by Covered Software in the absence of
+ its Contributions.
+
+ This License does not grant any rights in the trademarks, service marks,
+ or logos of any Contributor (except as may be necessary to comply with
+ the notice requirements in Section 3.4).
+
+2.4. Subsequent Licenses
+
+ No Contributor makes additional grants as a result of Your choice to
+ distribute the Covered Software under a subsequent version of this
+ License (see Section 10.2) or under the terms of a Secondary License (if
+ permitted under the terms of Section 3.3).
+
+2.5. Representation
+
+ Each Contributor represents that the Contributor believes its
+ Contributions are its original creation(s) or it has sufficient rights to
+ grant the rights to its Contributions conveyed by this License.
+
+2.6. Fair Use
+
+ This License is not intended to limit any rights You have under
+ applicable copyright doctrines of fair use, fair dealing, or other
+ equivalents.
+
+2.7. Conditions
+
+ Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in
+ Section 2.1.
+
+
+3. Responsibilities
+
+3.1. Distribution of Source Form
+
+ All distribution of Covered Software in Source Code Form, including any
+ Modifications that You create or to which You contribute, must be under
+ the terms of this License. You must inform recipients that the Source
+ Code Form of the Covered Software is governed by the terms of this
+ License, and how they can obtain a copy of this License. You may not
+ attempt to alter or restrict the recipients' rights in the Source Code
+ Form.
+
+3.2. Distribution of Executable Form
+
+ If You distribute Covered Software in Executable Form then:
+
+ a. such Covered Software must also be made available in Source Code Form,
+ as described in Section 3.1, and You must inform recipients of the
+ Executable Form how they can obtain a copy of such Source Code Form by
+ reasonable means in a timely manner, at a charge no more than the cost
+ of distribution to the recipient; and
+
+ b. You may distribute such Executable Form under the terms of this
+ License, or sublicense it under different terms, provided that the
+ license for the Executable Form does not attempt to limit or alter the
+ recipients' rights in the Source Code Form under this License.
+
+3.3. Distribution of a Larger Work
+
+ You may create and distribute a Larger Work under terms of Your choice,
+ provided that You also comply with the requirements of this License for
+ the Covered Software. If the Larger Work is a combination of Covered
+ Software with a work governed by one or more Secondary Licenses, and the
+ Covered Software is not Incompatible With Secondary Licenses, this
+ License permits You to additionally distribute such Covered Software
+ under the terms of such Secondary License(s), so that the recipient of
+ the Larger Work may, at their option, further distribute the Covered
+ Software under the terms of either this License or such Secondary
+ License(s).
+
+3.4. Notices
+
+ You may not remove or alter the substance of any license notices
+ (including copyright notices, patent notices, disclaimers of warranty, or
+ limitations of liability) contained within the Source Code Form of the
+ Covered Software, except that You may alter any license notices to the
+ extent required to remedy known factual inaccuracies.
+
+3.5. Application of Additional Terms
+
+ You may choose to offer, and to charge a fee for, warranty, support,
+ indemnity or liability obligations to one or more recipients of Covered
+ Software. However, You may do so only on Your own behalf, and not on
+ behalf of any Contributor. You must make it absolutely clear that any
+ such warranty, support, indemnity, or liability obligation is offered by
+ You alone, and You hereby agree to indemnify every Contributor for any
+ liability incurred by such Contributor as a result of warranty, support,
+ indemnity or liability terms You offer. You may include additional
+ disclaimers of warranty and limitations of liability specific to any
+ jurisdiction.
+
+4. Inability to Comply Due to Statute or Regulation
+
+ If it is impossible for You to comply with any of the terms of this License
+ with respect to some or all of the Covered Software due to statute,
+ judicial order, or regulation then You must: (a) comply with the terms of
+ this License to the maximum extent possible; and (b) describe the
+ limitations and the code they affect. Such description must be placed in a
+ text file included with all distributions of the Covered Software under
+ this License. Except to the extent prohibited by statute or regulation,
+ such description must be sufficiently detailed for a recipient of ordinary
+ skill to be able to understand it.
+
+5. Termination
+
+5.1. The rights granted under this License will terminate automatically if You
+ fail to comply with any of its terms. However, if You become compliant,
+ then the rights granted under this License from a particular Contributor
+ are reinstated (a) provisionally, unless and until such Contributor
+ explicitly and finally terminates Your grants, and (b) on an ongoing
+ basis, if such Contributor fails to notify You of the non-compliance by
+ some reasonable means prior to 60 days after You have come back into
+ compliance. Moreover, Your grants from a particular Contributor are
+ reinstated on an ongoing basis if such Contributor notifies You of the
+ non-compliance by some reasonable means, this is the first time You have
+ received notice of non-compliance with this License from such
+ Contributor, and You become compliant prior to 30 days after Your receipt
+ of the notice.
+
+5.2. If You initiate litigation against any entity by asserting a patent
+ infringement claim (excluding declaratory judgment actions,
+ counter-claims, and cross-claims) alleging that a Contributor Version
+ directly or indirectly infringes any patent, then the rights granted to
+ You by any and all Contributors for the Covered Software under Section
+ 2.1 of this License shall terminate.
+
+5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user
+ license agreements (excluding distributors and resellers) which have been
+ validly granted by You or Your distributors under this License prior to
+ termination shall survive termination.
+
+6. Disclaimer of Warranty
+
+ Covered Software is provided under this License on an "as is" basis,
+ without warranty of any kind, either expressed, implied, or statutory,
+ including, without limitation, warranties that the Covered Software is free
+ of defects, merchantable, fit for a particular purpose or non-infringing.
+ The entire risk as to the quality and performance of the Covered Software
+ is with You. Should any Covered Software prove defective in any respect,
+ You (not any Contributor) assume the cost of any necessary servicing,
+ repair, or correction. This disclaimer of warranty constitutes an essential
+ part of this License. No use of any Covered Software is authorized under
+ this License except under this disclaimer.
+
+7. Limitation of Liability
+
+ Under no circumstances and under no legal theory, whether tort (including
+ negligence), contract, or otherwise, shall any Contributor, or anyone who
+ distributes Covered Software as permitted above, be liable to You for any
+ direct, indirect, special, incidental, or consequential damages of any
+ character including, without limitation, damages for lost profits, loss of
+ goodwill, work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses, even if such party shall have been
+ informed of the possibility of such damages. This limitation of liability
+ shall not apply to liability for death or personal injury resulting from
+ such party's negligence to the extent applicable law prohibits such
+ limitation. Some jurisdictions do not allow the exclusion or limitation of
+ incidental or consequential damages, so this exclusion and limitation may
+ not apply to You.
+
+8. Litigation
+
+ Any litigation relating to this License may be brought only in the courts
+ of a jurisdiction where the defendant maintains its principal place of
+ business and such litigation shall be governed by laws of that
+ jurisdiction, without reference to its conflict-of-law provisions. Nothing
+ in this Section shall prevent a party's ability to bring cross-claims or
+ counter-claims.
+
+9. Miscellaneous
+
+ This License represents the complete agreement concerning the subject
+ matter hereof. If any provision of this License is held to be
+ unenforceable, such provision shall be reformed only to the extent
+ necessary to make it enforceable. Any law or regulation which provides that
+ the language of a contract shall be construed against the drafter shall not
+ be used to construe this License against a Contributor.
+
+
+10. Versions of the License
+
+10.1. New Versions
+
+ Mozilla Foundation is the license steward. Except as provided in Section
+ 10.3, no one other than the license steward has the right to modify or
+ publish new versions of this License. Each version will be given a
+ distinguishing version number.
+
+10.2. Effect of New Versions
+
+ You may distribute the Covered Software under the terms of the version
+ of the License under which You originally received the Covered Software,
+ or under the terms of any subsequent version published by the license
+ steward.
+
+10.3. Modified Versions
+
+ If you create software not governed by this License, and you want to
+ create a new license for such software, you may create and use a
+ modified version of this License if you rename the license and remove
+ any references to the name of the license steward (except to note that
+ such modified license differs from this License).
+
+10.4. Distributing Source Code Form that is Incompatible With Secondary
+ Licenses If You choose to distribute Source Code Form that is
+ Incompatible With Secondary Licenses under the terms of this version of
+ the License, the notice described in Exhibit B of this License must be
+ attached.
+
+Exhibit A - Source Code Form License Notice
+
+ This Source Code Form is subject to the
+ terms of the Mozilla Public License, v.
+ 2.0. If a copy of the MPL was not
+ distributed with this file, You can
+ obtain one at
+ http://mozilla.org/MPL/2.0/.
+
+If it is not possible or desirable to put the notice in a particular file,
+then You may include the notice in a location (such as a LICENSE file in a
+relevant directory) where a recipient would be likely to look for such a
+notice.
+
+You may add additional accurate notices of copyright ownership.
+
+Exhibit B - "Incompatible With Secondary Licenses" Notice
+
+ This Source Code Form is "Incompatible
+ With Secondary Licenses", as defined by
+ the Mozilla Public License, v. 2.0.
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+README.md - github.com/hashicorp/yamux
+# Yamux
+
+Yamux (Yet another Multiplexer) is a multiplexing library for Golang.
+It relies on an underlying connection to provide reliability
+and ordering, such as TCP or Unix domain sockets, and provides
+stream-oriented multiplexing. It is inspired by SPDY but is not
+interoperable with it.
+
+Yamux features include:
+
+* Bi-directional streams
+ * Streams can be opened by either client or server
+ * Useful for NAT traversal
+ * Server-side push support
+* Flow control
+ * Avoid starvation
+ * Back-pressure to prevent overwhelming a receiver
+* Keep Alives
+ * Enables persistent connections over a load balancer
+* Efficient
+ * Enables thousands of logical streams with low overhead
+
+## Documentation
+
+For complete documentation, see the associated [Godoc](http://godoc.org/github.com/hashicorp/yamux).
+
+## Specification
+
+The full specification for Yamux is provided in the `spec.md` file.
+It can be used as a guide to implementors of interoperable libraries.
+
+## Usage
+
+Using Yamux is remarkably simple:
+
+```go
+
+func client() {
+ // Get a TCP connection
+ conn, err := net.Dial(...)
+ if err != nil {
+ panic(err)
+ }
+
+ // Setup client side of yamux
+ session, err := yamux.Client(conn, nil)
+ if err != nil {
+ panic(err)
+ }
+
+ // Open a new stream
+ stream, err := session.Open()
+ if err != nil {
+ panic(err)
+ }
+
+ // Stream implements net.Conn
+ stream.Write([]byte("ping"))
+}
+
+func server() {
+ // Accept a TCP connection
+ conn, err := listener.Accept()
+ if err != nil {
+ panic(err)
+ }
+
+ // Setup server side of yamux
+ session, err := yamux.Server(conn, nil)
+ if err != nil {
+ panic(err)
+ }
+
+ // Accept a stream
+ stream, err := session.Accept()
+ if err != nil {
+ panic(err)
+ }
+
+ // Listen for a message
+ buf := make([]byte, 4)
+ stream.Read(buf)
+}
+
+```
+
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+addr.go - github.com/hashicorp/yamux
+package yamux
+
+import (
+ "fmt"
+ "net"
+)
+
+// hasAddr is used to get the address from the underlying connection
+type hasAddr interface {
+ LocalAddr() net.Addr
+ RemoteAddr() net.Addr
+}
+
+// yamuxAddr is used when we cannot get the underlying address
+type yamuxAddr struct {
+ Addr string
+}
+
+func (*yamuxAddr) Network() string {
+ return "yamux"
+}
+
+func (y *yamuxAddr) String() string {
+ return fmt.Sprintf("yamux:%s", y.Addr)
+}
+
+// Addr is used to get the address of the listener.
+func (s *Session) Addr() net.Addr {
+ return s.LocalAddr()
+}
+
+// LocalAddr is used to get the local address of the
+// underlying connection.
+func (s *Session) LocalAddr() net.Addr {
+ addr, ok := s.conn.(hasAddr)
+ if !ok {
+ return &yamuxAddr{"local"}
+ }
+ return addr.LocalAddr()
+}
+
+// RemoteAddr is used to get the address of remote end
+// of the underlying connection
+func (s *Session) RemoteAddr() net.Addr {
+ addr, ok := s.conn.(hasAddr)
+ if !ok {
+ return &yamuxAddr{"remote"}
+ }
+ return addr.RemoteAddr()
+}
+
+// LocalAddr returns the local address
+func (s *Stream) LocalAddr() net.Addr {
+ return s.session.LocalAddr()
+}
+
+// RemoteAddr returns the remote address
+func (s *Stream) RemoteAddr() net.Addr {
+ return s.session.RemoteAddr()
+}
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+bench_test.go - github.com/hashicorp/yamux
+package yamux
+
+import (
+ "io"
+ "io/ioutil"
+ "testing"
+)
+
+func BenchmarkPing(b *testing.B) {
+ client, server := testClientServer()
+ defer func() {
+ client.Close()
+ server.Close()
+ }()
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ rtt, err := client.Ping()
+ if err != nil {
+ b.Fatalf("err: %v", err)
+ }
+ if rtt == 0 {
+ b.Fatalf("bad: %v", rtt)
+ }
+ }
+}
+
+func BenchmarkAccept(b *testing.B) {
+ client, server := testClientServer()
+ defer func() {
+ client.Close()
+ server.Close()
+ }()
+
+ doneCh := make(chan struct{})
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ go func() {
+ defer close(doneCh)
+
+ for i := 0; i < b.N; i++ {
+ stream, err := server.AcceptStream()
+ if err != nil {
+ return
+ }
+ stream.Close()
+ }
+ }()
+
+ for i := 0; i < b.N; i++ {
+ stream, err := client.Open()
+ if err != nil {
+ b.Fatalf("err: %v", err)
+ }
+ stream.Close()
+ }
+ <-doneCh
+}
+
+func BenchmarkSendRecv32(b *testing.B) {
+ const payloadSize = 32
+ benchmarkSendRecv(b, payloadSize, payloadSize)
+}
+
+func BenchmarkSendRecv64(b *testing.B) {
+ const payloadSize = 64
+ benchmarkSendRecv(b, payloadSize, payloadSize)
+}
+
+func BenchmarkSendRecv128(b *testing.B) {
+ const payloadSize = 128
+ benchmarkSendRecv(b, payloadSize, payloadSize)
+}
+
+func BenchmarkSendRecv256(b *testing.B) {
+ const payloadSize = 256
+ benchmarkSendRecv(b, payloadSize, payloadSize)
+}
+
+func BenchmarkSendRecv512(b *testing.B) {
+ const payloadSize = 512
+ benchmarkSendRecv(b, payloadSize, payloadSize)
+}
+
+func BenchmarkSendRecv1024(b *testing.B) {
+ const payloadSize = 1024
+ benchmarkSendRecv(b, payloadSize, payloadSize)
+}
+
+func BenchmarkSendRecv2048(b *testing.B) {
+ const payloadSize = 2048
+ benchmarkSendRecv(b, payloadSize, payloadSize)
+}
+
+func BenchmarkSendRecv4096(b *testing.B) {
+ const payloadSize = 4096
+ benchmarkSendRecv(b, payloadSize, payloadSize)
+}
+
+func BenchmarkSendRecvLarge(b *testing.B) {
+ const sendSize = 512 * 1024 * 1024 //512 MB
+ const recvSize = 4 * 1024 //4 KB
+ benchmarkSendRecv(b, sendSize, recvSize)
+}
+
+func benchmarkSendRecv(b *testing.B, sendSize, recvSize int) {
+ client, server := testClientServer()
+ defer func() {
+ client.Close()
+ server.Close()
+ }()
+
+ sendBuf := make([]byte, sendSize)
+ recvBuf := make([]byte, recvSize)
+ doneCh := make(chan struct{})
+
+ b.SetBytes(int64(sendSize))
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ go func() {
+ defer close(doneCh)
+
+ stream, err := server.AcceptStream()
+ if err != nil {
+ return
+ }
+ defer stream.Close()
+
+ switch {
+ case sendSize == recvSize:
+ for i := 0; i < b.N; i++ {
+ if _, err := stream.Read(recvBuf); err != nil {
+ b.Fatalf("err: %v", err)
+ }
+ }
+
+ case recvSize > sendSize:
+ b.Fatalf("bad test case; recvSize was: %d and sendSize was: %d, but recvSize must be <= sendSize!", recvSize, sendSize)
+
+ default:
+ chunks := sendSize / recvSize
+ for i := 0; i < b.N; i++ {
+ for j := 0; j < chunks; j++ {
+ if _, err := stream.Read(recvBuf); err != nil {
+ b.Fatalf("err: %v", err)
+ }
+ }
+ }
+ }
+ }()
+
+ stream, err := client.Open()
+ if err != nil {
+ b.Fatalf("err: %v", err)
+ }
+ defer stream.Close()
+
+ for i := 0; i < b.N; i++ {
+ if _, err := stream.Write(sendBuf); err != nil {
+ b.Fatalf("err: %v", err)
+ }
+ }
+ <-doneCh
+}
+
+func BenchmarkSendRecvParallel32(b *testing.B) {
+ const payloadSize = 32
+ benchmarkSendRecvParallel(b, payloadSize)
+}
+
+func BenchmarkSendRecvParallel64(b *testing.B) {
+ const payloadSize = 64
+ benchmarkSendRecvParallel(b, payloadSize)
+}
+
+func BenchmarkSendRecvParallel128(b *testing.B) {
+ const payloadSize = 128
+ benchmarkSendRecvParallel(b, payloadSize)
+}
+
+func BenchmarkSendRecvParallel256(b *testing.B) {
+ const payloadSize = 256
+ benchmarkSendRecvParallel(b, payloadSize)
+}
+
+func BenchmarkSendRecvParallel512(b *testing.B) {
+ const payloadSize = 512
+ benchmarkSendRecvParallel(b, payloadSize)
+}
+
+func BenchmarkSendRecvParallel1024(b *testing.B) {
+ const payloadSize = 1024
+ benchmarkSendRecvParallel(b, payloadSize)
+}
+
+func BenchmarkSendRecvParallel2048(b *testing.B) {
+ const payloadSize = 2048
+ benchmarkSendRecvParallel(b, payloadSize)
+}
+
+func BenchmarkSendRecvParallel4096(b *testing.B) {
+ const payloadSize = 4096
+ benchmarkSendRecvParallel(b, payloadSize)
+}
+
+func benchmarkSendRecvParallel(b *testing.B, sendSize int) {
+ client, server := testClientServer()
+ defer func() {
+ client.Close()
+ server.Close()
+ }()
+
+ sendBuf := make([]byte, sendSize)
+ discarder := ioutil.Discard.(io.ReaderFrom)
+ b.SetBytes(int64(sendSize))
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ b.RunParallel(func(pb *testing.PB) {
+ doneCh := make(chan struct{})
+
+ go func() {
+ defer close(doneCh)
+
+ stream, err := server.AcceptStream()
+ if err != nil {
+ return
+ }
+ defer stream.Close()
+
+ if _, err := discarder.ReadFrom(stream); err != nil {
+ b.Fatalf("err: %v", err)
+ }
+ }()
+
+ stream, err := client.Open()
+ if err != nil {
+ b.Fatalf("err: %v", err)
+ }
+
+ for pb.Next() {
+ if _, err := stream.Write(sendBuf); err != nil {
+ b.Fatalf("err: %v", err)
+ }
+ }
+
+ stream.Close()
+ <-doneCh
+ })
+}
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+const.go - github.com/hashicorp/yamux
+package yamux
+
+import (
+ "encoding/binary"
+ "fmt"
+)
+
+var (
+ // ErrInvalidVersion means we received a frame with an
+ // invalid version
+ ErrInvalidVersion = fmt.Errorf("invalid protocol version")
+
+ // ErrInvalidMsgType means we received a frame with an
+ // invalid message type
+ ErrInvalidMsgType = fmt.Errorf("invalid msg type")
+
+ // ErrSessionShutdown is used if there is a shutdown during
+ // an operation
+ ErrSessionShutdown = fmt.Errorf("session shutdown")
+
+ // ErrStreamsExhausted is returned if we have no more
+ // stream ids to issue
+ ErrStreamsExhausted = fmt.Errorf("streams exhausted")
+
+ // ErrDuplicateStream is used if a duplicate stream is
+ // opened inbound
+ ErrDuplicateStream = fmt.Errorf("duplicate stream initiated")
+
+ // ErrRecvWindowExceeded indicates the window was exceeded
+ ErrRecvWindowExceeded = fmt.Errorf("recv window exceeded")
+
+ // ErrTimeout is used when we reach an IO deadline
+ ErrTimeout = fmt.Errorf("i/o deadline reached")
+
+ // ErrStreamClosed is returned when using a closed stream
+ ErrStreamClosed = fmt.Errorf("stream closed")
+
+ // ErrUnexpectedFlag is set when we get an unexpected flag
+ ErrUnexpectedFlag = fmt.Errorf("unexpected flag")
+
+ // ErrRemoteGoAway is used when we get a go away from the other side
+ ErrRemoteGoAway = fmt.Errorf("remote end is not accepting connections")
+
+ // ErrConnectionReset is sent if a stream is reset. This can happen
+ // if the backlog is exceeded, or if there was a remote GoAway.
+ ErrConnectionReset = fmt.Errorf("connection reset")
+
+ // ErrConnectionWriteTimeout indicates that we hit the "safety valve"
+ // timeout writing to the underlying stream connection.
+ ErrConnectionWriteTimeout = fmt.Errorf("connection write timeout")
+
+ // ErrKeepAliveTimeout is sent if a missed keepalive caused the stream close
+ ErrKeepAliveTimeout = fmt.Errorf("keepalive timeout")
+)
+
+const (
+ // protoVersion is the only version we support
+ protoVersion uint8 = 0
+)
+
+const (
+ // Data is used for data frames. They are followed
+ // by length bytes worth of payload.
+ typeData uint8 = iota
+
+ // WindowUpdate is used to change the window of
+ // a given stream. The length indicates the delta
+ // update to the window.
+ typeWindowUpdate
+
+ // Ping is sent as a keep-alive or to measure
+ // the RTT. The StreamID and Length value are echoed
+ // back in the response.
+ typePing
+
+ // GoAway is sent to terminate a session. The StreamID
+ // should be 0 and the length is an error code.
+ typeGoAway
+)
+
+const (
+ // SYN is sent to signal a new stream. May
+ // be sent with a data payload
+ flagSYN uint16 = 1 << iota
+
+ // ACK is sent to acknowledge a new stream. May
+ // be sent with a data payload
+ flagACK
+
+ // FIN is sent to half-close the given stream.
+ // May be sent with a data payload.
+ flagFIN
+
+ // RST is used to hard close a given stream.
+ flagRST
+)
+
+const (
+ // initialStreamWindow is the initial stream window size
+ initialStreamWindow uint32 = 256 * 1024
+)
+
+const (
+ // goAwayNormal is sent on a normal termination
+ goAwayNormal uint32 = iota
+
+ // goAwayProtoErr sent on a protocol error
+ goAwayProtoErr
+
+ // goAwayInternalErr sent on an internal error
+ goAwayInternalErr
+)
+
+const (
+ sizeOfVersion = 1
+ sizeOfType = 1
+ sizeOfFlags = 2
+ sizeOfStreamID = 4
+ sizeOfLength = 4
+ headerSize = sizeOfVersion &#43; sizeOfType &#43; sizeOfFlags &#43;
+ sizeOfStreamID &#43; sizeOfLength
+)
+
+type header []byte
+
+func (h header) Version() uint8 {
+ return h[0]
+}
+
+func (h header) MsgType() uint8 {
+ return h[1]
+}
+
+func (h header) Flags() uint16 {
+ return binary.BigEndian.Uint16(h[2:4])
+}
+
+func (h header) StreamID() uint32 {
+ return binary.BigEndian.Uint32(h[4:8])
+}
+
+func (h header) Length() uint32 {
+ return binary.BigEndian.Uint32(h[8:12])
+}
+
+func (h header) String() string {
+ return fmt.Sprintf(&#34;Vsn:%d Type:%d Flags:%d StreamID:%d Length:%d&#34;,
+ h.Version(), h.MsgType(), h.Flags(), h.StreamID(), h.Length())
+}
+
+func (h header) encode(msgType uint8, flags uint16, streamID uint32, length uint32) {
+ h[0] = protoVersion
+ h[1] = msgType
+ binary.BigEndian.PutUint16(h[2:4], flags)
+ binary.BigEndian.PutUint32(h[4:8], streamID)
+ binary.BigEndian.PutUint32(h[8:12], length)
+}
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+const_test.go - github.com/hashicorp/yamux
+package yamux
+
+import (
+ &#34;testing&#34;
+)
+
+func TestConst(t *testing.T) {
+ if protoVersion != 0 {
+ t.Fatalf(&#34;bad: %v&#34;, protoVersion)
+ }
+
+ if typeData != 0 {
+ t.Fatalf(&#34;bad: %v&#34;, typeData)
+ }
+ if typeWindowUpdate != 1 {
+ t.Fatalf(&#34;bad: %v&#34;, typeWindowUpdate)
+ }
+ if typePing != 2 {
+ t.Fatalf(&#34;bad: %v&#34;, typePing)
+ }
+ if typeGoAway != 3 {
+ t.Fatalf(&#34;bad: %v&#34;, typeGoAway)
+ }
+
+ if flagSYN != 1 {
+ t.Fatalf(&#34;bad: %v&#34;, flagSYN)
+ }
+ if flagACK != 2 {
+ t.Fatalf(&#34;bad: %v&#34;, flagACK)
+ }
+ if flagFIN != 4 {
+ t.Fatalf(&#34;bad: %v&#34;, flagFIN)
+ }
+ if flagRST != 8 {
+ t.Fatalf(&#34;bad: %v&#34;, flagRST)
+ }
+
+ if goAwayNormal != 0 {
+ t.Fatalf(&#34;bad: %v&#34;, goAwayNormal)
+ }
+ if goAwayProtoErr != 1 {
+ t.Fatalf(&#34;bad: %v&#34;, goAwayProtoErr)
+ }
+ if goAwayInternalErr != 2 {
+ t.Fatalf(&#34;bad: %v&#34;, goAwayInternalErr)
+ }
+
+ if headerSize != 12 {
+ t.Fatalf(&#34;bad header size&#34;)
+ }
+}
+
+func TestEncodeDecode(t *testing.T) {
+ hdr := header(make([]byte, headerSize))
+ hdr.encode(typeWindowUpdate, flagACK|flagRST, 1234, 4321)
+
+ if hdr.Version() != protoVersion {
+ t.Fatalf(&#34;bad: %v&#34;, hdr)
+ }
+ if hdr.MsgType() != typeWindowUpdate {
+ t.Fatalf(&#34;bad: %v&#34;, hdr)
+ }
+ if hdr.Flags() != flagACK|flagRST {
+ t.Fatalf(&#34;bad: %v&#34;, hdr)
+ }
+ if hdr.StreamID() != 1234 {
+ t.Fatalf(&#34;bad: %v&#34;, hdr)
+ }
+ if hdr.Length() != 4321 {
+ t.Fatalf(&#34;bad: %v&#34;, hdr)
+ }
+}
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+go.mod - github.com/hashicorp/yamux
+module github.com/hashicorp/yamux
+
+go 1.15
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+mux.go - github.com/hashicorp/yamux
+package yamux
+
+import (
+ &#34;fmt&#34;
+ &#34;io&#34;
+ &#34;log&#34;
+ &#34;os&#34;
+ &#34;time&#34;
+)
+
+// Config is used to tune the Yamux session
+type Config struct {
+ // AcceptBacklog is used to limit how many streams may be
+ // waiting an accept.
+ AcceptBacklog int
+
+ // EnableKeepAlive is used to enable periodic keep-alive
+ // messages using a ping.
+ EnableKeepAlive bool
+
+ // KeepAliveInterval is how often to perform the keep alive
+ KeepAliveInterval time.Duration
+
+ // ConnectionWriteTimeout is meant to be a &#34;safety valve&#34; timeout after
+ // which we will suspect a problem with the underlying connection and
+ // close it. This is only applied to writes, where there&#39;s generally
+ // an expectation that things will move along quickly.
+ ConnectionWriteTimeout time.Duration
+
+ // MaxStreamWindowSize is used to control the maximum
+ // window size that we allow for a stream.
+ MaxStreamWindowSize uint32
+
+ // StreamCloseTimeout is the maximum time that a stream will be allowed
+ // to remain in a half-closed state when `Close` is called before forcibly
+ // closing the connection. Forcibly closed connections will empty the
+ // receive buffer, drop any future packets received for that stream,
+ // and send a RST to the remote side.
+ StreamCloseTimeout time.Duration
+
+ // LogOutput is used to control the log destination. Either Logger or
+ // LogOutput can be set, not both.
+ LogOutput io.Writer
+
+ // Logger is used to pass in the logger to be used. Either Logger or
+ // LogOutput can be set, not both.
+ Logger *log.Logger
+}
+
+// DefaultConfig is used to return a default configuration
+func DefaultConfig() *Config {
+ return &amp;Config{
+ AcceptBacklog: 256,
+ EnableKeepAlive: true,
+ KeepAliveInterval: 30 * time.Second,
+ ConnectionWriteTimeout: 10 * time.Second,
+ MaxStreamWindowSize: initialStreamWindow,
+ StreamCloseTimeout: 5 * time.Minute,
+ LogOutput: os.Stderr,
+ }
+}
+
+// VerifyConfig is used to verify the sanity of configuration
+func VerifyConfig(config *Config) error {
+ if config.AcceptBacklog &lt;= 0 {
+ return fmt.Errorf(&#34;backlog must be positive&#34;)
+ }
+ if config.KeepAliveInterval == 0 {
+ return fmt.Errorf(&#34;keep-alive interval must be positive&#34;)
+ }
+ if config.MaxStreamWindowSize &lt; initialStreamWindow {
+ return fmt.Errorf(&#34;MaxStreamWindowSize must be at least %d&#34;, initialStreamWindow)
+ }
+ if config.LogOutput != nil &amp;&amp; config.Logger != nil {
+ return fmt.Errorf(&#34;both Logger and LogOutput may not be set, select one&#34;)
+ } else if config.LogOutput == nil &amp;&amp; config.Logger == nil {
+ return fmt.Errorf(&#34;one of Logger or LogOutput must be set, select one&#34;)
+ }
+ return nil
+}
+
+// Server is used to initialize a new server-side connection.
+// There must be at most one server-side connection. If a nil config is
+ // provided, the default configuration will be used.
+func Server(conn io.ReadWriteCloser, config *Config) (*Session, error) {
+ if config == nil {
+ config = DefaultConfig()
+ }
+ if err := VerifyConfig(config); err != nil {
+ return nil, err
+ }
+ return newSession(config, conn, false), nil
+}
+
+// Client is used to initialize a new client-side connection.
+// There must be at most one client-side connection.
+func Client(conn io.ReadWriteCloser, config *Config) (*Session, error) {
+ if config == nil {
+ config = DefaultConfig()
+ }
+
+ if err := VerifyConfig(config); err != nil {
+ return nil, err
+ }
+ return newSession(config, conn, true), nil
+}
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+session.go - github.com/hashicorp/yamux
+package yamux
+
+import (
+ &#34;bufio&#34;
+ &#34;fmt&#34;
+ &#34;io&#34;
+ &#34;io/ioutil&#34;
+ &#34;log&#34;
+ &#34;math&#34;
+ &#34;net&#34;
+ &#34;strings&#34;
+ &#34;sync&#34;
+ &#34;sync/atomic&#34;
+ &#34;time&#34;
+)
+
+// Session is used to wrap a reliable ordered connection and to
+// multiplex it into multiple streams.
+type Session struct {
+ // remoteGoAway indicates the remote side does
+ // not want further connections. Must be first for alignment.
+ remoteGoAway int32
+
+ // localGoAway indicates that we should stop
+ // accepting further connections. Must be first for alignment.
+ localGoAway int32
+
+ // nextStreamID is the next stream ID we should
+ // use. This depends on whether we are a client or a server.
+ nextStreamID uint32
+
+ // config holds our configuration
+ config *Config
+
+ // logger is used for our logs
+ logger *log.Logger
+
+ // conn is the underlying connection
+ conn io.ReadWriteCloser
+
+ // bufRead is a buffered reader
+ bufRead *bufio.Reader
+
+ // pings is used to track inflight pings
+ pings map[uint32]chan struct{}
+ pingID uint32
+ pingLock sync.Mutex
+
+ // streams maps a stream id to a stream, and inflight has an entry
+ // for any outgoing stream that has not yet been established. Both are
+ // protected by streamLock.
+ streams map[uint32]*Stream
+ inflight map[uint32]struct{}
+ streamLock sync.Mutex
+
+ // synCh acts like a semaphore. It is sized to the AcceptBacklog which
+ // is assumed to be symmetric between the client and server. This allows
+ // the client to avoid exceeding the backlog and instead blocks the open.
+ synCh chan struct{}
+
+ // acceptCh is used to pass ready streams to the client
+ acceptCh chan *Stream
+
+ // sendCh is used to mark a stream as ready to send,
+ // or to send a header out directly.
+ sendCh chan sendReady
+
+ // recvDoneCh is closed when recv() exits to avoid a race
+ // between stream registration and stream shutdown
+ recvDoneCh chan struct{}
+
+ // shutdown is used to safely close a session
+ shutdown bool
+ shutdownErr error
+ shutdownCh chan struct{}
+ shutdownLock sync.Mutex
+}
+
+// sendReady is used to either mark a stream as ready
+// or to directly send a header
+type sendReady struct {
+ Hdr []byte
+ Body io.Reader
+ Err chan error
+}
+
+// newSession is used to construct a new session
+func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
+ logger := config.Logger
+ if logger == nil {
+ logger = log.New(config.LogOutput, &#34;&#34;, log.LstdFlags)
+ }
+
+ s := &amp;Session{
+ config: config,
+ logger: logger,
+ conn: conn,
+ bufRead: bufio.NewReader(conn),
+ pings: make(map[uint32]chan struct{}),
+ streams: make(map[uint32]*Stream),
+ inflight: make(map[uint32]struct{}),
+ synCh: make(chan struct{}, config.AcceptBacklog),
+ acceptCh: make(chan *Stream, config.AcceptBacklog),
+ sendCh: make(chan sendReady, 64),
+ recvDoneCh: make(chan struct{}),
+ shutdownCh: make(chan struct{}),
+ }
+ if client {
+ s.nextStreamID = 1
+ } else {
+ s.nextStreamID = 2
+ }
+ go s.recv()
+ go s.send()
+ if config.EnableKeepAlive {
+ go s.keepalive()
+ }
+ return s
+}
+
+// IsClosed does a safe check to see if we have shutdown
+func (s *Session) IsClosed() bool {
+ select {
+ case &lt;-s.shutdownCh:
+ return true
+ default:
+ return false
+ }
+}
+
+// CloseChan returns a read-only channel which is closed as
+// soon as the session is closed.
+func (s *Session) CloseChan() &lt;-chan struct{} {
+ return s.shutdownCh
+}
+
+// NumStreams returns the number of currently open streams
+func (s *Session) NumStreams() int {
+ s.streamLock.Lock()
+ num := len(s.streams)
+ s.streamLock.Unlock()
+ return num
+}
+
+// Open is used to create a new stream as a net.Conn
+func (s *Session) Open() (net.Conn, error) {
+ conn, err := s.OpenStream()
+ if err != nil {
+ return nil, err
+ }
+ return conn, nil
+}
+
+// OpenStream is used to create a new stream
+func (s *Session) OpenStream() (*Stream, error) {
+ if s.IsClosed() {
+ return nil, ErrSessionShutdown
+ }
+ if atomic.LoadInt32(&amp;s.remoteGoAway) == 1 {
+ return nil, ErrRemoteGoAway
+ }
+
+ // Block if we have too many inflight SYNs
+ select {
+ case s.synCh &lt;- struct{}{}:
+ case &lt;-s.shutdownCh:
+ return nil, ErrSessionShutdown
+ }
+
+GET_ID:
+ // Get an ID, and check for stream exhaustion
+ id := atomic.LoadUint32(&amp;s.nextStreamID)
+ if id &gt;= math.MaxUint32-1 {
+ return nil, ErrStreamsExhausted
+ }
+ if !atomic.CompareAndSwapUint32(&amp;s.nextStreamID, id, id&#43;2) {
+ goto GET_ID
+ }
+
+ // Register the stream
+ stream := newStream(s, id, streamInit)
+ s.streamLock.Lock()
+ s.streams[id] = stream
+ s.inflight[id] = struct{}{}
+ s.streamLock.Unlock()
+
+ // Send the window update to create
+ if err := stream.sendWindowUpdate(); err != nil {
+ select {
+ case &lt;-s.synCh:
+ default:
+ s.logger.Printf(&#34;[ERR] yamux: aborted stream open without inflight syn semaphore&#34;)
+ }
+ return nil, err
+ }
+ return stream, nil
+}
+
+// Accept is used to block until the next available stream
+// is ready to be accepted.
+func (s *Session) Accept() (net.Conn, error) {
+ conn, err := s.AcceptStream()
+ if err != nil {
+ return nil, err
+ }
+ return conn, err
+}
+
+// AcceptStream is used to block until the next available stream
+// is ready to be accepted.
+func (s *Session) AcceptStream() (*Stream, error) {
+ select {
+ case stream := &lt;-s.acceptCh:
+ if err := stream.sendWindowUpdate(); err != nil {
+ return nil, err
+ }
+ return stream, nil
+ case &lt;-s.shutdownCh:
+ return nil, s.shutdownErr
+ }
+}
+
+// Close is used to close the session and all streams.
+// Attempts to send a GoAway before closing the connection.
+func (s *Session) Close() error {
+ s.shutdownLock.Lock()
+ defer s.shutdownLock.Unlock()
+
+ if s.shutdown {
+ return nil
+ }
+ s.shutdown = true
+ if s.shutdownErr == nil {
+ s.shutdownErr = ErrSessionShutdown
+ }
+ close(s.shutdownCh)
+ s.conn.Close()
+ &lt;-s.recvDoneCh
+
+ s.streamLock.Lock()
+ defer s.streamLock.Unlock()
+ for _, stream := range s.streams {
+ stream.forceClose()
+ }
+ return nil
+}
+
+// exitErr is used to handle an error that is causing the
+// session to terminate.
+func (s *Session) exitErr(err error) {
+ s.shutdownLock.Lock()
+ if s.shutdownErr == nil {
+ s.shutdownErr = err
+ }
+ s.shutdownLock.Unlock()
+ s.Close()
+}
+
+// GoAway can be used to prevent accepting further
+// connections. It does not close the underlying conn.
+func (s *Session) GoAway() error {
+ return s.waitForSend(s.goAway(goAwayNormal), nil)
+}
+
+// goAway is used to send a goAway message
+func (s *Session) goAway(reason uint32) header {
+ atomic.SwapInt32(&amp;s.localGoAway, 1)
+ hdr := header(make([]byte, headerSize))
+ hdr.encode(typeGoAway, 0, 0, reason)
+ return hdr
+}
+
+// Ping is used to measure the RTT response time
+func (s *Session) Ping() (time.Duration, error) {
+ // Get a channel for the ping
+ ch := make(chan struct{})
+
+ // Get a new ping id, mark as pending
+ s.pingLock.Lock()
+ id := s.pingID
+ s.pingID&#43;&#43;
+ s.pings[id] = ch
+ s.pingLock.Unlock()
+
+ // Send the ping request
+ hdr := header(make([]byte, headerSize))
+ hdr.encode(typePing, flagSYN, 0, id)
+ if err := s.waitForSend(hdr, nil); err != nil {
+ return 0, err
+ }
+
+ // Wait for a response
+ start := time.Now()
+ select {
+ case &lt;-ch:
+ case &lt;-time.After(s.config.ConnectionWriteTimeout):
+ s.pingLock.Lock()
+ delete(s.pings, id) // Ignore it if a response comes later.
+ s.pingLock.Unlock()
+ return 0, ErrTimeout
+ case &lt;-s.shutdownCh:
+ return 0, ErrSessionShutdown
+ }
+
+ // Compute the RTT
+ return time.Since(start), nil
+}
+
+// keepalive is a long running goroutine that periodically does
+// a ping to keep the connection alive.
+func (s *Session) keepalive() {
+ for {
+ select {
+ case &lt;-time.After(s.config.KeepAliveInterval):
+ _, err := s.Ping()
+ if err != nil {
+ if err != ErrSessionShutdown {
+ s.logger.Printf(&#34;[ERR] yamux: keepalive failed: %v&#34;, err)
+ s.exitErr(ErrKeepAliveTimeout)
+ }
+ return
+ }
+ case &lt;-s.shutdownCh:
+ return
+ }
+ }
+}
+
+ // waitForSend waits to send a header, checking for a potential shutdown
+func (s *Session) waitForSend(hdr header, body io.Reader) error {
+ errCh := make(chan error, 1)
+ return s.waitForSendErr(hdr, body, errCh)
+}
+
+// waitForSendErr waits to send a header with optional data, checking for a
+// potential shutdown. Since there&#39;s the expectation that sends can happen
+// in a timely manner, we enforce the connection write timeout here.
+func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error {
+ t := timerPool.Get()
+ timer := t.(*time.Timer)
+ timer.Reset(s.config.ConnectionWriteTimeout)
+ defer func() {
+ timer.Stop()
+ select {
+ case &lt;-timer.C:
+ default:
+ }
+ timerPool.Put(t)
+ }()
+
+ ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
+ select {
+ case s.sendCh &lt;- ready:
+ case &lt;-s.shutdownCh:
+ return ErrSessionShutdown
+ case &lt;-timer.C:
+ return ErrConnectionWriteTimeout
+ }
+
+ select {
+ case err := &lt;-errCh:
+ return err
+ case &lt;-s.shutdownCh:
+ return ErrSessionShutdown
+ case &lt;-timer.C:
+ return ErrConnectionWriteTimeout
+ }
+}
+
+// sendNoWait does a send without waiting. Since there&#39;s the expectation that
+// the send happens right here, we enforce the connection write timeout if we
+// can&#39;t queue the header to be sent.
+func (s *Session) sendNoWait(hdr header) error {
+ t := timerPool.Get()
+ timer := t.(*time.Timer)
+ timer.Reset(s.config.ConnectionWriteTimeout)
+ defer func() {
+ timer.Stop()
+ select {
+ case &lt;-timer.C:
+ default:
+ }
+ timerPool.Put(t)
+ }()
+
+ select {
+ case s.sendCh &lt;- sendReady{Hdr: hdr}:
+ return nil
+ case &lt;-s.shutdownCh:
+ return ErrSessionShutdown
+ case &lt;-timer.C:
+ return ErrConnectionWriteTimeout
+ }
+}
+
+// send is a long running goroutine that sends data
+func (s *Session) send() {
+ for {
+ select {
+ case ready := &lt;-s.sendCh:
+ // Send a header if ready
+ if ready.Hdr != nil {
+ sent := 0
+ for sent &lt; len(ready.Hdr) {
+ n, err := s.conn.Write(ready.Hdr[sent:])
+ if err != nil {
+ s.logger.Printf(&#34;[ERR] yamux: Failed to write header: %v&#34;, err)
+ asyncSendErr(ready.Err, err)
+ s.exitErr(err)
+ return
+ }
+ sent &#43;= n
+ }
+ }
+
+ // Send data from a body if given
+ if ready.Body != nil {
+ _, err := io.Copy(s.conn, ready.Body)
+ if err != nil {
+ s.logger.Printf(&#34;[ERR] yamux: Failed to write body: %v&#34;, err)
+ asyncSendErr(ready.Err, err)
+ s.exitErr(err)
+ return
+ }
+ }
+
+ // No error, successful send
+ asyncSendErr(ready.Err, nil)
+ case &lt;-s.shutdownCh:
+ return
+ }
+ }
+}
+
+// recv is a long running goroutine that accepts new data
+func (s *Session) recv() {
+ if err := s.recvLoop(); err != nil {
+ s.exitErr(err)
+ }
+}
+
+// Ensure that the index of the handler (typeData/typeWindowUpdate/etc) matches the message type
+var (
+ handlers = []func(*Session, header) error{
+ typeData: (*Session).handleStreamMessage,
+ typeWindowUpdate: (*Session).handleStreamMessage,
+ typePing: (*Session).handlePing,
+ typeGoAway: (*Session).handleGoAway,
+ }
+)
+
+// recvLoop continues to receive data until a fatal error is encountered
+func (s *Session) recvLoop() error {
+ defer close(s.recvDoneCh)
+ hdr := header(make([]byte, headerSize))
+ for {
+ // Read the header
+ if _, err := io.ReadFull(s.bufRead, hdr); err != nil {
+ if err != io.EOF &amp;&amp; !strings.Contains(err.Error(), &#34;closed&#34;) &amp;&amp; !strings.Contains(err.Error(), &#34;reset by peer&#34;) {
+ s.logger.Printf(&#34;[ERR] yamux: Failed to read header: %v&#34;, err)
+ }
+ return err
+ }
+
+ // Verify the version
+ if hdr.Version() != protoVersion {
+ s.logger.Printf(&#34;[ERR] yamux: Invalid protocol version: %d&#34;, hdr.Version())
+ return ErrInvalidVersion
+ }
+
+ mt := hdr.MsgType()
+ if mt &lt; typeData || mt &gt; typeGoAway {
+ return ErrInvalidMsgType
+ }
+
+ if err := handlers[mt](s, hdr); err != nil {
+ return err
+ }
+ }
+}
+
+// handleStreamMessage handles either a data or window update frame
+func (s *Session) handleStreamMessage(hdr header) error {
+ // Check for a new stream creation
+ id := hdr.StreamID()
+ flags := hdr.Flags()
+ if flags&amp;flagSYN == flagSYN {
+ if err := s.incomingStream(id); err != nil {
+ return err
+ }
+ }
+
+ // Get the stream
+ s.streamLock.Lock()
+ stream := s.streams[id]
+ s.streamLock.Unlock()
+
+ // If we do not have a stream, likely we sent a RST
+ if stream == nil {
+ // Drain any data on the wire
+ if hdr.MsgType() == typeData &amp;&amp; hdr.Length() &gt; 0 {
+ s.logger.Printf(&#34;[WARN] yamux: Discarding data for stream: %d&#34;, id)
+ if _, err := io.CopyN(ioutil.Discard, s.bufRead, int64(hdr.Length())); err != nil {
+ s.logger.Printf(&#34;[ERR] yamux: Failed to discard data: %v&#34;, err)
+ return nil
+ }
+ } else {
+ s.logger.Printf(&#34;[WARN] yamux: frame for missing stream: %v&#34;, hdr)
+ }
+ return nil
+ }
+
+ // Check if this is a window update
+ if hdr.MsgType() == typeWindowUpdate {
+ if err := stream.incrSendWindow(hdr, flags); err != nil {
+ if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
+ s.logger.Printf(&#34;[WARN] yamux: failed to send go away: %v&#34;, sendErr)
+ }
+ return err
+ }
+ return nil
+ }
+
+ // Read the new data
+ if err := stream.readData(hdr, flags, s.bufRead); err != nil {
+ if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
+ s.logger.Printf(&#34;[WARN] yamux: failed to send go away: %v&#34;, sendErr)
+ }
+ return err
+ }
+ return nil
+}
+
+// handlePing is invoked for a typePing frame
+func (s *Session) handlePing(hdr header) error {
+ flags := hdr.Flags()
+ pingID := hdr.Length()
+
+ // Check if this is a query, respond back in a separate context so we
+ // don&#39;t interfere with the receiving thread blocking for the write.
+ if flags&amp;flagSYN == flagSYN {
+ go func() {
+ hdr := header(make([]byte, headerSize))
+ hdr.encode(typePing, flagACK, 0, pingID)
+ if err := s.sendNoWait(hdr); err != nil {
+ s.logger.Printf(&#34;[WARN] yamux: failed to send ping reply: %v&#34;, err)
+ }
+ }()
+ return nil
+ }
+
+ // Handle a response
+ s.pingLock.Lock()
+ ch := s.pings[pingID]
+ if ch != nil {
+ delete(s.pings, pingID)
+ close(ch)
+ }
+ s.pingLock.Unlock()
+ return nil
+}
+
+// handleGoAway is invoked for a typeGoAway frame
+func (s *Session) handleGoAway(hdr header) error {
+ code := hdr.Length()
+ switch code {
+ case goAwayNormal:
+ atomic.SwapInt32(&amp;s.remoteGoAway, 1)
+ case goAwayProtoErr:
+ s.logger.Printf(&#34;[ERR] yamux: received protocol error go away&#34;)
+ return fmt.Errorf(&#34;yamux protocol error&#34;)
+ case goAwayInternalErr:
+ s.logger.Printf(&#34;[ERR] yamux: received internal error go away&#34;)
+ return fmt.Errorf(&#34;remote yamux internal error&#34;)
+ default:
+ s.logger.Printf(&#34;[ERR] yamux: received unexpected go away&#34;)
+ return fmt.Errorf(&#34;unexpected go away received&#34;)
+ }
+ return nil
+}
+
+// incomingStream is used to create a new incoming stream
+func (s *Session) incomingStream(id uint32) error {
+ // Reject immediately if we are doing a go away
+ if atomic.LoadInt32(&amp;s.localGoAway) == 1 {
+ hdr := header(make([]byte, headerSize))
+ hdr.encode(typeWindowUpdate, flagRST, id, 0)
+ return s.sendNoWait(hdr)
+ }
+
+ // Allocate a new stream
+ stream := newStream(s, id, streamSYNReceived)
+
+ s.streamLock.Lock()
+ defer s.streamLock.Unlock()
+
+ // Check if stream already exists
+ if _, ok := s.streams[id]; ok {
+ s.logger.Printf(&#34;[ERR] yamux: duplicate stream declared&#34;)
+ if sendErr := s.sendNoWait(s.goAway(goAwayProtoErr)); sendErr != nil {
+ s.logger.Printf(&#34;[WARN] yamux: failed to send go away: %v&#34;, sendErr)
+ }
+ return ErrDuplicateStream
+ }
+
+ // Register the stream
+ s.streams[id] = stream
+
+ // Check if we&#39;ve exceeded the backlog
+ select {
+ case s.acceptCh &lt;- stream:
+ return nil
+ default:
+ // Backlog exceeded! RST the stream
+ s.logger.Printf(&#34;[WARN] yamux: backlog exceeded, forcing connection reset&#34;)
+ delete(s.streams, id)
+ stream.sendHdr.encode(typeWindowUpdate, flagRST, id, 0)
+ return s.sendNoWait(stream.sendHdr)
+ }
+}
+
+// closeStream is used to close a stream once both sides have
+// issued a close. If there was an in-flight SYN and the stream
+// was not yet established, then this will give the credit back.
+func (s *Session) closeStream(id uint32) {
+ s.streamLock.Lock()
+ if _, ok := s.inflight[id]; ok {
+ select {
+ case &lt;-s.synCh:
+ default:
+ s.logger.Printf(&#34;[ERR] yamux: SYN tracking out of sync&#34;)
+ }
+ }
+ delete(s.streams, id)
+ s.streamLock.Unlock()
+}
+
+// establishStream is used to mark a stream that was in the
+// SYN Sent state as established.
+func (s *Session) establishStream(id uint32) {
+ s.streamLock.Lock()
+ if _, ok := s.inflight[id]; ok {
+ delete(s.inflight, id)
+ } else {
+ s.logger.Printf(&#34;[ERR] yamux: established stream without inflight SYN (no tracking entry)&#34;)
+ }
+ select {
+ case &lt;-s.synCh:
+ default:
+ s.logger.Printf(&#34;[ERR] yamux: established stream without inflight SYN (didn&#39;t have semaphore)&#34;)
+ }
+ s.streamLock.Unlock()
+}
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+session_test.go - github.com/hashicorp/yamux
+package yamux
+
+import (
+ &#34;bytes&#34;
+ &#34;fmt&#34;
+ &#34;io&#34;
+ &#34;io/ioutil&#34;
+ &#34;log&#34;
+ &#34;reflect&#34;
+ &#34;runtime&#34;
+ &#34;strings&#34;
+ &#34;sync&#34;
+ &#34;testing&#34;
+ &#34;time&#34;
+)
+
+type logCapture struct{ bytes.Buffer }
+
+func (l *logCapture) logs() []string {
+ return strings.Split(strings.TrimSpace(l.String()), &#34;\n&#34;)
+}
+
+func (l *logCapture) match(expect []string) bool {
+ return reflect.DeepEqual(l.logs(), expect)
+}
+
+func captureLogs(s *Session) *logCapture {
+ buf := new(logCapture)
+ s.logger = log.New(buf, &#34;&#34;, 0)
+ return buf
+}
+
+type pipeConn struct {
+ reader *io.PipeReader
+ writer *io.PipeWriter
+ writeBlocker sync.Mutex
+}
+
+func (p *pipeConn) Read(b []byte) (int, error) {
+ return p.reader.Read(b)
+}
+
+func (p *pipeConn) Write(b []byte) (int, error) {
+ p.writeBlocker.Lock()
+ defer p.writeBlocker.Unlock()
+ return p.writer.Write(b)
+}
+
+func (p *pipeConn) Close() error {
+ p.reader.Close()
+ return p.writer.Close()
+}
+
+func testConn() (io.ReadWriteCloser, io.ReadWriteCloser) {
+ read1, write1 := io.Pipe()
+ read2, write2 := io.Pipe()
+ conn1 := &amp;pipeConn{reader: read1, writer: write2}
+ conn2 := &amp;pipeConn{reader: read2, writer: write1}
+ return conn1, conn2
+}
+
+func testConf() *Config {
+ conf := DefaultConfig()
+ conf.AcceptBacklog = 64
+ conf.KeepAliveInterval = 100 * time.Millisecond
+ conf.ConnectionWriteTimeout = 250 * time.Millisecond
+ return conf
+}
+
+func testConfNoKeepAlive() *Config {
+ conf := testConf()
+ conf.EnableKeepAlive = false
+ return conf
+}
+
+func testClientServer() (*Session, *Session) {
+ return testClientServerConfig(testConf())
+}
+
+func testClientServerConfig(conf *Config) (*Session, *Session) {
+ conn1, conn2 := testConn()
+ client, _ := Client(conn1, conf)
+ server, _ := Server(conn2, conf)
+ return client, server
+}
+
+func TestPing(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ rtt, err := client.Ping()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if rtt == 0 {
+ t.Fatalf(&#34;bad: %v&#34;, rtt)
+ }
+
+ rtt, err = server.Ping()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if rtt == 0 {
+ t.Fatalf(&#34;bad: %v&#34;, rtt)
+ }
+}
+
+func TestPing_Timeout(t *testing.T) {
+ client, server := testClientServerConfig(testConfNoKeepAlive())
+ defer client.Close()
+ defer server.Close()
+
+ // Prevent the client from responding
+ clientConn := client.conn.(*pipeConn)
+ clientConn.writeBlocker.Lock()
+
+ errCh := make(chan error, 1)
+ go func() {
+ _, err := server.Ping() // Ping via the server session
+ errCh &lt;- err
+ }()
+
+ select {
+ case err := &lt;-errCh:
+ if err != ErrTimeout {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ case &lt;-time.After(client.config.ConnectionWriteTimeout * 2):
+ t.Fatalf(&#34;failed to timeout within expected %v&#34;, client.config.ConnectionWriteTimeout)
+ }
+
+ // Verify that we recover, even if we gave up
+ clientConn.writeBlocker.Unlock()
+
+ go func() {
+ _, err := server.Ping() // Ping via the server session
+ errCh &lt;- err
+ }()
+
+ select {
+ case err := &lt;-errCh:
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ case &lt;-time.After(client.config.ConnectionWriteTimeout):
+ t.Fatalf(&#34;timeout&#34;)
+ }
+}
+
+func TestCloseBeforeAck(t *testing.T) {
+ cfg := testConf()
+ cfg.AcceptBacklog = 8
+ client, server := testClientServerConfig(cfg)
+
+ defer client.Close()
+ defer server.Close()
+
+ for i := 0; i &lt; 8; i&#43;&#43; {
+ s, err := client.OpenStream()
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.Close()
+ }
+
+ for i := 0; i &lt; 8; i&#43;&#43; {
+ s, err := server.AcceptStream()
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.Close()
+ }
+
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ s, err := client.OpenStream()
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.Close()
+ }()
+
+ select {
+ case &lt;-done:
+ case &lt;-time.After(time.Second * 5):
+ t.Fatal(&#34;timed out trying to open stream&#34;)
+ }
+}
+
+func TestAccept(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ if client.NumStreams() != 0 {
+ t.Fatalf(&#34;bad&#34;)
+ }
+ if server.NumStreams() != 0 {
+ t.Fatalf(&#34;bad&#34;)
+ }
+
+ wg := &amp;sync.WaitGroup{}
+ wg.Add(4)
+
+ go func() {
+ defer wg.Done()
+ stream, err := server.AcceptStream()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if id := stream.StreamID(); id != 1 {
+ t.Fatalf(&#34;bad: %v&#34;, id)
+ }
+ if err := stream.Close(); err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ }()
+
+ go func() {
+ defer wg.Done()
+ stream, err := client.AcceptStream()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if id := stream.StreamID(); id != 2 {
+ t.Fatalf(&#34;bad: %v&#34;, id)
+ }
+ if err := stream.Close(); err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ }()
+
+ go func() {
+ defer wg.Done()
+ stream, err := server.OpenStream()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if id := stream.StreamID(); id != 2 {
+ t.Fatalf(&#34;bad: %v&#34;, id)
+ }
+ if err := stream.Close(); err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ }()
+
+ go func() {
+ defer wg.Done()
+ stream, err := client.OpenStream()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if id := stream.StreamID(); id != 1 {
+ t.Fatalf(&#34;bad: %v&#34;, id)
+ }
+ if err := stream.Close(); err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ }()
+
+ doneCh := make(chan struct{})
+ go func() {
+ wg.Wait()
+ close(doneCh)
+ }()
+
+ select {
+ case &lt;-doneCh:
+ case &lt;-time.After(time.Second):
+ panic(&#34;timeout&#34;)
+ }
+}
+
+func TestClose_closeTimeout(t *testing.T) {
+ conf := testConf()
+ conf.StreamCloseTimeout = 10 * time.Millisecond
+ client, server := testClientServerConfig(conf)
+ defer client.Close()
+ defer server.Close()
+
+ if client.NumStreams() != 0 {
+ t.Fatalf(&#34;bad&#34;)
+ }
+ if server.NumStreams() != 0 {
+ t.Fatalf(&#34;bad&#34;)
+ }
+
+ wg := &amp;sync.WaitGroup{}
+ wg.Add(2)
+
+ // Open a stream on the client but only close it on the server.
+ // We want to see if the stream ever gets cleaned up on the client.
+
+ var clientStream *Stream
+ go func() {
+ defer wg.Done()
+ var err error
+ clientStream, err = client.OpenStream()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ }()
+
+ go func() {
+ defer wg.Done()
+ stream, err := server.AcceptStream()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if err := stream.Close(); err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ }()
+
+ doneCh := make(chan struct{})
+ go func() {
+ wg.Wait()
+ close(doneCh)
+ }()
+
+ select {
+ case &lt;-doneCh:
+ case &lt;-time.After(time.Second):
+ panic(&#34;timeout&#34;)
+ }
+
+ // We should have zero streams after our timeout period
+ time.Sleep(100 * time.Millisecond)
+
+ if v := server.NumStreams(); v &gt; 0 {
+ t.Fatalf(&#34;should have zero streams: %d&#34;, v)
+ }
+ if v := client.NumStreams(); v &gt; 0 {
+ t.Fatalf(&#34;should have zero streams: %d&#34;, v)
+ }
+
+ if _, err := clientStream.Write([]byte(&#34;hello&#34;)); err == nil {
+ t.Fatal(&#34;should error on write&#34;)
+ } else if err.Error() != &#34;connection reset&#34; {
+ t.Fatalf(&#34;expected connection reset, got %q&#34;, err)
+ }
+}
+
+func TestNonNilInterface(t *testing.T) {
+ _, server := testClientServer()
+ server.Close()
+
+ conn, err := server.Accept()
+ if err != nil &amp;&amp; conn != nil {
+ t.Error(&#34;bad: accept should return a connection of nil value&#34;)
+ }
+
+ conn, err = server.Open()
+ if err != nil &amp;&amp; conn != nil {
+ t.Error(&#34;bad: open should return a connection of nil value&#34;)
+ }
+}
+
+func TestSendData_Small(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ wg := &amp;sync.WaitGroup{}
+ wg.Add(2)
+
+ go func() {
+ defer wg.Done()
+ stream, err := server.AcceptStream()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+
+ if server.NumStreams() != 1 {
+ t.Fatalf(&#34;bad&#34;)
+ }
+
+ buf := make([]byte, 4)
+ for i := 0; i &lt; 1000; i&#43;&#43; {
+ n, err := stream.Read(buf)
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if n != 4 {
+ t.Fatalf(&#34;short read: %d&#34;, n)
+ }
+ if string(buf) != &#34;test&#34; {
+ t.Fatalf(&#34;bad: %s&#34;, buf)
+ }
+ }
+
+ if err := stream.Close(); err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ }()
+
+ go func() {
+ defer wg.Done()
+ stream, err := client.Open()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+
+ if client.NumStreams() != 1 {
+ t.Fatalf(&#34;bad&#34;)
+ }
+
+ for i := 0; i &lt; 1000; i&#43;&#43; {
+ n, err := stream.Write([]byte(&#34;test&#34;))
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if n != 4 {
+ t.Fatalf(&#34;short write %d&#34;, n)
+ }
+ }
+
+ if err := stream.Close(); err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ }()
+
+ doneCh := make(chan struct{})
+ go func() {
+ wg.Wait()
+ close(doneCh)
+ }()
+ select {
+ case &lt;-doneCh:
+ case &lt;-time.After(time.Second):
+ panic(&#34;timeout&#34;)
+ }
+
+ if client.NumStreams() != 0 {
+ t.Fatalf(&#34;bad&#34;)
+ }
+ if server.NumStreams() != 0 {
+ t.Fatalf(&#34;bad&#34;)
+ }
+}
+
+func TestSendData_Large(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ const (
+ sendSize = 250 * 1024 * 1024
+ recvSize = 4 * 1024
+ )
+
+ data := make([]byte, sendSize)
+ for idx := range data {
+ data[idx] = byte(idx % 256)
+ }
+
+ wg := &amp;sync.WaitGroup{}
+ wg.Add(2)
+
+ go func() {
+ defer wg.Done()
+ stream, err := server.AcceptStream()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ var sz int
+ buf := make([]byte, recvSize)
+ for i := 0; i &lt; sendSize/recvSize; i&#43;&#43; {
+ n, err := stream.Read(buf)
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if n != recvSize {
+ t.Fatalf(&#34;short read: %d&#34;, n)
+ }
+ sz &#43;= n
+ for idx := range buf {
+ if buf[idx] != byte(idx%256) {
+ t.Fatalf(&#34;bad: %v %v %v&#34;, i, idx, buf[idx])
+ }
+ }
+ }
+
+ if err := stream.Close(); err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+
+ t.Logf(&#34;cap=%d, n=%d\n&#34;, stream.recvBuf.Cap(), sz)
+ }()
+
+ go func() {
+ defer wg.Done()
+ stream, err := client.Open()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+
+ n, err := stream.Write(data)
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if n != len(data) {
+ t.Fatalf(&#34;short write %d&#34;, n)
+ }
+
+ if err := stream.Close(); err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ }()
+
+ doneCh := make(chan struct{})
+ go func() {
+ wg.Wait()
+ close(doneCh)
+ }()
+ select {
+ case &lt;-doneCh:
+ case &lt;-time.After(5 * time.Second):
+ panic(&#34;timeout&#34;)
+ }
+}
+
+func TestGoAway(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ if err := server.GoAway(); err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+
+ _, err := client.Open()
+ if err != ErrRemoteGoAway {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+}
+
+func TestManyStreams(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ wg := &amp;sync.WaitGroup{}
+
+ acceptor := func(i int) {
+ defer wg.Done()
+ stream, err := server.AcceptStream()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream.Close()
+
+ buf := make([]byte, 512)
+ for {
+ n, err := stream.Read(buf)
+ if err == io.EOF {
+ return
+ }
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if n == 0 {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ }
+ }
+ sender := func(i int) {
+ defer wg.Done()
+ stream, err := client.Open()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream.Close()
+
+ msg := fmt.Sprintf(&#34;%08d&#34;, i)
+ for i := 0; i &lt; 1000; i&#43;&#43; {
+ n, err := stream.Write([]byte(msg))
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if n != len(msg) {
+ t.Fatalf(&#34;short write %d&#34;, n)
+ }
+ }
+ }
+
+ for i := 0; i &lt; 50; i&#43;&#43; {
+ wg.Add(2)
+ go acceptor(i)
+ go sender(i)
+ }
+
+ wg.Wait()
+}
+
+func TestManyStreams_PingPong(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ wg := &amp;sync.WaitGroup{}
+
+ ping := []byte(&#34;ping&#34;)
+ pong := []byte(&#34;pong&#34;)
+
+ acceptor := func(i int) {
+ defer wg.Done()
+ stream, err := server.AcceptStream()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream.Close()
+
+ buf := make([]byte, 4)
+ for {
+ // Read the &#39;ping&#39;
+ n, err := stream.Read(buf)
+ if err == io.EOF {
+ return
+ }
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if n != 4 {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if !bytes.Equal(buf, ping) {
+ t.Fatalf(&#34;bad: %s&#34;, buf)
+ }
+
+ // Shrink the internal buffer!
+ stream.Shrink()
+
+ // Write out the &#39;pong&#39;
+ n, err = stream.Write(pong)
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if n != 4 {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ }
+ }
+ sender := func(i int) {
+ defer wg.Done()
+ stream, err := client.OpenStream()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream.Close()
+
+ buf := make([]byte, 4)
+ for i := 0; i &lt; 1000; i&#43;&#43; {
+ // Send the &#39;ping&#39;
+ n, err := stream.Write(ping)
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if n != 4 {
+ t.Fatalf(&#34;short write %d&#34;, n)
+ }
+
+ // Read the &#39;pong&#39;
+ n, err = stream.Read(buf)
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if n != 4 {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if !bytes.Equal(buf, pong) {
+ t.Fatalf(&#34;bad: %s&#34;, buf)
+ }
+
+ // Shrink the buffer
+ stream.Shrink()
+ }
+ }
+
+ for i := 0; i &lt; 50; i&#43;&#43; {
+ wg.Add(2)
+ go acceptor(i)
+ go sender(i)
+ }
+
+ wg.Wait()
+}
+
+func TestHalfClose(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ stream, err := client.Open()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if _, err = stream.Write([]byte(&#34;a&#34;)); err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+
+ stream2, err := server.Accept()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ stream2.Close() // Half close
+
+ buf := make([]byte, 4)
+ n, err := stream2.Read(buf)
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if n != 1 {
+ t.Fatalf(&#34;bad: %v&#34;, n)
+ }
+
+ // Send more
+ if _, err = stream.Write([]byte(&#34;bcd&#34;)); err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ stream.Close()
+
+ // Read after close
+ n, err = stream2.Read(buf)
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if n != 3 {
+ t.Fatalf(&#34;bad: %v&#34;, n)
+ }
+
+ // EOF after close
+ n, err = stream2.Read(buf)
+ if err != io.EOF {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if n != 0 {
+ t.Fatalf(&#34;bad: %v&#34;, n)
+ }
+}
+
+func TestReadDeadline(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ stream, err := client.Open()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream.Close()
+
+ stream2, err := server.Accept()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream2.Close()
+
+ if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+
+ buf := make([]byte, 4)
+ if _, err := stream.Read(buf); err != ErrTimeout {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+}
+
+func TestReadDeadline_BlockedRead(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ stream, err := client.Open()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream.Close()
+
+ stream2, err := server.Accept()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream2.Close()
+
+ // Start a read that will block
+ errCh := make(chan error, 1)
+ go func() {
+ buf := make([]byte, 4)
+ _, err := stream.Read(buf)
+ errCh &lt;- err
+ close(errCh)
+ }()
+
+ // Wait to ensure the read has started.
+ time.Sleep(5 * time.Millisecond)
+
+ // Update the read deadline
+ if err := stream.SetReadDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+
+ select {
+ case &lt;-time.After(100 * time.Millisecond):
+ t.Fatal(&#34;expected read timeout&#34;)
+ case err := &lt;-errCh:
+ if err != ErrTimeout {
+ t.Fatalf(&#34;expected ErrTimeout; got %v&#34;, err)
+ }
+ }
+}
+
+func TestWriteDeadline(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ stream, err := client.Open()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream.Close()
+
+ stream2, err := server.Accept()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream2.Close()
+
+ if err := stream.SetWriteDeadline(time.Now().Add(50 * time.Millisecond)); err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+
+ buf := make([]byte, 512)
+ for i := 0; i &lt; int(initialStreamWindow); i&#43;&#43; {
+ _, err := stream.Write(buf)
+ if err != nil &amp;&amp; err == ErrTimeout {
+ return
+ } else if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ }
+ t.Fatalf(&#34;Expected timeout&#34;)
+}
+
+func TestWriteDeadline_BlockedWrite(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ stream, err := client.Open()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream.Close()
+
+ stream2, err := server.Accept()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream2.Close()
+
+ // Start a goroutine making writes that will block
+ errCh := make(chan error, 1)
+ go func() {
+ buf := make([]byte, 512)
+ for i := 0; i &lt; int(initialStreamWindow); i&#43;&#43; {
+ _, err := stream.Write(buf)
+ if err == nil {
+ continue
+ }
+
+ errCh &lt;- err
+ close(errCh)
+ return
+ }
+
+ close(errCh)
+ }()
+
+ // Wait to ensure the write has started.
+ time.Sleep(5 * time.Millisecond)
+
+ // Update the write deadline
+ if err := stream.SetWriteDeadline(time.Now().Add(5 * time.Millisecond)); err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+
+ select {
+ case &lt;-time.After(1 * time.Second):
+ t.Fatal(&#34;expected write timeout&#34;)
+ case err := &lt;-errCh:
+ if err != ErrTimeout {
+ t.Fatalf(&#34;expected ErrTimeout; got %v&#34;, err)
+ }
+ }
+}
+
+func TestBacklogExceeded(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ // Fill the backlog
+ max := client.config.AcceptBacklog
+ for i := 0; i &lt; max; i&#43;&#43; {
+ stream, err := client.Open()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream.Close()
+
+ if _, err := stream.Write([]byte(&#34;foo&#34;)); err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ }
+
+ // Attempt to open a new stream
+ errCh := make(chan error, 1)
+ go func() {
+ _, err := client.Open()
+ errCh &lt;- err
+ }()
+
+ // Shutdown the server
+ go func() {
+ time.Sleep(10 * time.Millisecond)
+ server.Close()
+ }()
+
+ select {
+ case err := &lt;-errCh:
+ if err == nil {
+ t.Fatalf(&#34;open should fail&#34;)
+ }
+ case &lt;-time.After(time.Second):
+ t.Fatalf(&#34;timeout&#34;)
+ }
+}
+
+func TestKeepAlive(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ time.Sleep(200 * time.Millisecond)
+
+ // Ping value should increase
+ client.pingLock.Lock()
+ defer client.pingLock.Unlock()
+ if client.pingID == 0 {
+ t.Fatalf(&#34;should ping&#34;)
+ }
+
+ server.pingLock.Lock()
+ defer server.pingLock.Unlock()
+ if server.pingID == 0 {
+ t.Fatalf(&#34;should ping&#34;)
+ }
+}
+
+func TestKeepAlive_Timeout(t *testing.T) {
+ conn1, conn2 := testConn()
+
+ clientConf := testConf()
+ clientConf.ConnectionWriteTimeout = time.Hour // We&#39;re testing keep alives, not connection writes
+ clientConf.EnableKeepAlive = false // Just test one direction, so it&#39;s deterministic who hangs up on whom
+ client, _ := Client(conn1, clientConf)
+ defer client.Close()
+
+ server, _ := Server(conn2, testConf())
+ defer server.Close()
+
+ _ = captureLogs(client) // Client logs aren&#39;t part of the test
+ serverLogs := captureLogs(server)
+
+ errCh := make(chan error, 1)
+ go func() {
+ _, err := server.Accept() // Wait until server closes
+ errCh &lt;- err
+ }()
+
+ // Prevent the client from responding
+ clientConn := client.conn.(*pipeConn)
+ clientConn.writeBlocker.Lock()
+
+ select {
+ case err := &lt;-errCh:
+ if err != ErrKeepAliveTimeout {
+ t.Fatalf(&#34;unexpected error: %v&#34;, err)
+ }
+ case &lt;-time.After(1 * time.Second):
+ t.Fatalf(&#34;timeout waiting for timeout&#34;)
+ }
+
+ if !server.IsClosed() {
+ t.Fatalf(&#34;server should have closed&#34;)
+ }
+
+ if !serverLogs.match([]string{&#34;[ERR] yamux: keepalive failed: i/o deadline reached&#34;}) {
+ t.Fatalf("server log incorrect: %v", serverLogs.logs())
+ }
+}
+
+func TestLargeWindow(t *testing.T) {
+ conf := DefaultConfig()
+ conf.MaxStreamWindowSize *= 2
+
+ client, server := testClientServerConfig(conf)
+ defer client.Close()
+ defer server.Close()
+
+ stream, err := client.Open()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream.Close()
+
+ stream2, err := server.Accept()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream2.Close()
+
+ stream.SetWriteDeadline(time.Now().Add(10 * time.Millisecond))
+ buf := make([]byte, conf.MaxStreamWindowSize)
+ n, err := stream.Write(buf)
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if n != len(buf) {
+ t.Fatalf(&#34;short write: %d&#34;, n)
+ }
+}
+
+type UnlimitedReader struct{}
+
+func (u *UnlimitedReader) Read(p []byte) (int, error) {
+ runtime.Gosched()
+ return len(p), nil
+}
+
+func TestSendData_VeryLarge(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ var n int64 = 1 * 1024 * 1024 * 1024
+ var workers int = 16
+
+ wg := &amp;sync.WaitGroup{}
+ wg.Add(workers * 2)
+
+ for i := 0; i &lt; workers; i&#43;&#43; {
+ go func() {
+ defer wg.Done()
+ stream, err := server.AcceptStream()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream.Close()
+
+ buf := make([]byte, 4)
+ _, err = stream.Read(buf)
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if !bytes.Equal(buf, []byte{0, 1, 2, 3}) {
+ t.Fatalf(&#34;bad header&#34;)
+ }
+
+ recv, err := io.Copy(ioutil.Discard, stream)
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if recv != n {
+ t.Fatalf(&#34;bad: %v&#34;, recv)
+ }
+ }()
+ }
+ for i := 0; i &lt; workers; i&#43;&#43; {
+ go func() {
+ defer wg.Done()
+ stream, err := client.Open()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream.Close()
+
+ _, err = stream.Write([]byte{0, 1, 2, 3})
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+
+ unlimited := &amp;UnlimitedReader{}
+ sent, err := io.Copy(stream, io.LimitReader(unlimited, n))
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if sent != n {
+ t.Fatalf(&#34;bad: %v&#34;, sent)
+ }
+ }()
+ }
+
+ doneCh := make(chan struct{})
+ go func() {
+ wg.Wait()
+ close(doneCh)
+ }()
+ select {
+ case &lt;-doneCh:
+ case &lt;-time.After(20 * time.Second):
+ panic(&#34;timeout&#34;)
+ }
+}
+
+func TestBacklogExceeded_Accept(t *testing.T) {
+ client, server := testClientServer()
+ defer client.Close()
+ defer server.Close()
+
+ max := 5 * client.config.AcceptBacklog
+ go func() {
+ for i := 0; i &lt; max; i&#43;&#43; {
+ stream, err := server.Accept()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream.Close()
+ }
+ }()
+
+ // Fill the backlog
+ for i := 0; i &lt; max; i&#43;&#43; {
+ stream, err := client.Open()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream.Close()
+
+ if _, err := stream.Write([]byte(&#34;foo&#34;)); err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ }
+}
+
+func TestSession_WindowUpdateWriteDuringRead(t *testing.T) {
+ client, server := testClientServerConfig(testConfNoKeepAlive())
+ defer client.Close()
+ defer server.Close()
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+
+ // Choose a huge flood size that we know will result in a window update.
+ flood := int64(client.config.MaxStreamWindowSize) - 1
+
+ // The server will accept a new stream and then flood data to it.
+ go func() {
+ defer wg.Done()
+
+ stream, err := server.AcceptStream()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream.Close()
+
+ n, err := stream.Write(make([]byte, flood))
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if int64(n) != flood {
+ t.Fatalf(&#34;short write: %d&#34;, n)
+ }
+ }()
+
+ // The client will open a stream, block outbound writes, and then
+ // listen to the flood from the server, which should time out since
+ // it won&#39;t be able to send the window update.
+ go func() {
+ defer wg.Done()
+
+ stream, err := client.OpenStream()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream.Close()
+
+ conn := client.conn.(*pipeConn)
+ conn.writeBlocker.Lock()
+
+ _, err = stream.Read(make([]byte, flood))
+ if err != ErrConnectionWriteTimeout {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ }()
+
+ wg.Wait()
+}
+
+func TestSession_PartialReadWindowUpdate(t *testing.T) {
+ client, server := testClientServerConfig(testConfNoKeepAlive())
+ defer client.Close()
+ defer server.Close()
+
+ var wg sync.WaitGroup
+ wg.Add(1)
+
+ // Choose a huge flood size that we know will result in a window update.
+ flood := int64(client.config.MaxStreamWindowSize)
+ var wr *Stream
+
+ // The server will accept a new stream and then flood data to it.
+ go func() {
+ defer wg.Done()
+
+ var err error
+ wr, err = server.AcceptStream()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer wr.Close()
+
+ if wr.sendWindow != client.config.MaxStreamWindowSize {
+ t.Fatalf(&#34;sendWindow: exp=%d, got=%d&#34;, client.config.MaxStreamWindowSize, wr.sendWindow)
+ }
+
+ n, err := wr.Write(make([]byte, flood))
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if int64(n) != flood {
+ t.Fatalf(&#34;short write: %d&#34;, n)
+ }
+ if wr.sendWindow != 0 {
+ t.Fatalf(&#34;sendWindow: exp=%d, got=%d&#34;, 0, wr.sendWindow)
+ }
+ }()
+
+ stream, err := client.OpenStream()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream.Close()
+
+ wg.Wait()
+
+ _, err = stream.Read(make([]byte, flood/2&#43;1))
+
+ if exp := uint32(flood/2 &#43; 1); wr.sendWindow != exp {
+ t.Errorf(&#34;sendWindow: exp=%d, got=%d&#34;, exp, wr.sendWindow)
+ }
+}
+
+func TestSession_sendNoWait_Timeout(t *testing.T) {
+ client, server := testClientServerConfig(testConfNoKeepAlive())
+ defer client.Close()
+ defer server.Close()
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+
+ go func() {
+ defer wg.Done()
+
+ stream, err := server.AcceptStream()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream.Close()
+ }()
+
+ // The client will open the stream and then block outbound writes, we&#39;ll
+ // probe sendNoWait once it gets into that state.
+ go func() {
+ defer wg.Done()
+
+ stream, err := client.OpenStream()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream.Close()
+
+ conn := client.conn.(*pipeConn)
+ conn.writeBlocker.Lock()
+
+ hdr := header(make([]byte, headerSize))
+ hdr.encode(typePing, flagACK, 0, 0)
+ for {
+ err = client.sendNoWait(hdr)
+ if err == nil {
+ continue
+ } else if err == ErrConnectionWriteTimeout {
+ break
+ } else {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ }
+ }()
+
+ wg.Wait()
+}
+
+func TestSession_PingOfDeath(t *testing.T) {
+ client, server := testClientServerConfig(testConfNoKeepAlive())
+ defer client.Close()
+ defer server.Close()
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+
+ var doPingOfDeath sync.Mutex
+ doPingOfDeath.Lock()
+
+ // This is used later to block outbound writes.
+ conn := server.conn.(*pipeConn)
+
+ // The server will accept a stream, block outbound writes, and then
+ // flood its send channel so that no more headers can be queued.
+ go func() {
+ defer wg.Done()
+
+ stream, err := server.AcceptStream()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream.Close()
+
+ conn.writeBlocker.Lock()
+ for {
+ hdr := header(make([]byte, headerSize))
+ hdr.encode(typePing, 0, 0, 0)
+ err = server.sendNoWait(hdr)
+ if err == nil {
+ continue
+ } else if err == ErrConnectionWriteTimeout {
+ break
+ } else {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ }
+
+ doPingOfDeath.Unlock()
+ }()
+
+ // The client will open a stream and then send the server a ping once it
+ // can no longer write. This makes sure the server doesn&#39;t deadlock reads
+ // while trying to reply to the ping with no ability to write.
+ go func() {
+ defer wg.Done()
+
+ stream, err := client.OpenStream()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream.Close()
+
+ // This ping will never unblock because the ping id will never
+ // show up in a response.
+ doPingOfDeath.Lock()
+ go func() { client.Ping() }()
+
+ // Wait for a while to make sure the previous ping times out,
+ // then turn writes back on and make sure a ping works again.
+ time.Sleep(2 * server.config.ConnectionWriteTimeout)
+ conn.writeBlocker.Unlock()
+ if _, err = client.Ping(); err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ }()
+
+ wg.Wait()
+}
+
+func TestSession_ConnectionWriteTimeout(t *testing.T) {
+ client, server := testClientServerConfig(testConfNoKeepAlive())
+ defer client.Close()
+ defer server.Close()
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+
+ go func() {
+ defer wg.Done()
+
+ stream, err := server.AcceptStream()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream.Close()
+ }()
+
+ // The client will open the stream and then block outbound writes, we&#39;ll
+ // tee up a write and make sure it eventually times out.
+ go func() {
+ defer wg.Done()
+
+ stream, err := client.OpenStream()
+ if err != nil {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ defer stream.Close()
+
+ conn := client.conn.(*pipeConn)
+ conn.writeBlocker.Lock()
+
+ // Since the write goroutine is blocked then this will return a
+ // timeout since it can&#39;t get feedback about whether the write
+ // worked.
+ n, err := stream.Write([]byte(&#34;hello&#34;))
+ if err != ErrConnectionWriteTimeout {
+ t.Fatalf(&#34;err: %v&#34;, err)
+ }
+ if n != 0 {
+ t.Fatalf(&#34;lied about writes: %d&#34;, n)
+ }
+ }()
+
+ wg.Wait()
+}
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+spec.md - github.com/hashicorp/yamux
+# Specification
+
+We use this document to detail the internal specification of Yamux.
+This is used both as a guide for implementing Yamux, but also for
+alternative interoperable libraries to be built.
+
+# Framing
+
+Yamux uses a streaming connection underneath, but imposes a message
+framing so that it can be shared between many logical streams. Each
+frame contains a header like:
+
+* Version (8 bits)
+* Type (8 bits)
+* Flags (16 bits)
+* StreamID (32 bits)
+* Length (32 bits)
+
+This means that each header has a 12 byte overhead.
+All fields are encoded in network order (big endian).
+Each field is described below:
+
+## Version Field
+
+The version field is used for future backward compatibility. At the
+current time, the field is always set to 0, to indicate the initial
+version.
+
+## Type Field
+
+The type field is used to switch the frame message type. The following
+message types are supported:
+
+* 0x0 Data - Used to transmit data. May transmit zero length payloads
+ depending on the flags.
+
+ * 0x1 Window Update - Used to update the sender's receive window size.
+ This is used to implement per-session flow control.
+
+* 0x2 Ping - Used to measure RTT. It can also be used to heart-beat
+ and do keep-alives over TCP.
+
+* 0x3 Go Away - Used to close a session.
+
+## Flag Field
+
+The flags field is used to provide additional information related
+to the message type. The following flags are supported:
+
+ * 0x1 SYN - Signals the start of a new stream. May be sent with a data or
+ window update message. Also sent with a ping to indicate an outbound ping.
+
+ * 0x2 ACK - Acknowledges the start of a new stream. May be sent with a data
+ or window update message. Also sent with a ping to indicate a response.
+
+* 0x4 FIN - Performs a half-close of a stream. May be sent with a data
+ message or window update.
+
+* 0x8 RST - Reset a stream immediately. May be sent with a data or
+ window update message.
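The flag bits above can be written down as constants; a minimal sketch, with illustrative names rather than the library's internal identifiers:

```go
package main

import "fmt"

// Flag bits from the list above. A frame may carry several at once;
// for example, a window update with ACK|FIN both accepts a new stream
// and half-closes it in the same frame.
const (
	flagSYN uint16 = 0x1
	flagACK uint16 = 0x2
	flagFIN uint16 = 0x4
	flagRST uint16 = 0x8
)

func main() {
	flags := flagACK | flagFIN
	fmt.Println(flags&flagACK != 0) // ACK is set
	fmt.Println(flags&flagRST != 0) // RST is not
}
```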
+
+## StreamID Field
+
+The StreamID field is used to identify the logical stream the frame
+ is addressing. The client side should use odd IDs and the server even IDs.
+This prevents any collisions. Additionally, the 0 ID is reserved to represent
+the session.
+
+Both Ping and Go Away messages should always use the 0 StreamID.
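The parity rule can be sketched as a simple predicate (a hypothetical helper, not part of the library):

```go
package main

import "fmt"

// isClientID reflects the ID parity rule above: clients allocate odd
// stream IDs, servers allocate even ones, and ID 0 names the session.
func isClientID(id uint32) bool {
	return id != 0 && id%2 == 1
}

func main() {
	fmt.Println(isClientID(1)) // first client-initiated stream
	fmt.Println(isClientID(2)) // first server-initiated stream
	fmt.Println(isClientID(0)) // reserved for the session itself
}
```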
+
+## Length Field
+
+The meaning of the length field depends on the message type:
+
+ * Data - Provides the number of bytes following the header
+ * Window Update - Provides a delta update to the window size
+ * Ping - Contains an opaque value, echoed back
+ * Go Away - Contains an error code
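Taken together, the fields fit the fixed 12-byte big-endian layout from the Framing section. A round-trip codec might look like the following sketch; the struct and helper names are illustrative, not the library's internal types:

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// frameHeader mirrors the 12-byte header layout described above.
type frameHeader struct {
	Version  uint8
	Type     uint8
	Flags    uint16
	StreamID uint32
	Length   uint32
}

// encodeHeader packs the fields in network (big-endian) byte order.
func encodeHeader(h frameHeader) [12]byte {
	var buf [12]byte
	buf[0] = h.Version
	buf[1] = h.Type
	binary.BigEndian.PutUint16(buf[2:4], h.Flags)
	binary.BigEndian.PutUint32(buf[4:8], h.StreamID)
	binary.BigEndian.PutUint32(buf[8:12], h.Length)
	return buf
}

// decodeHeader reverses encodeHeader.
func decodeHeader(buf [12]byte) frameHeader {
	return frameHeader{
		Version:  buf[0],
		Type:     buf[1],
		Flags:    binary.BigEndian.Uint16(buf[2:4]),
		StreamID: binary.BigEndian.Uint32(buf[4:8]),
		Length:   binary.BigEndian.Uint32(buf[8:12]),
	}
}

func main() {
	// A Data frame with SYN set, opening client stream 1 with 5 payload bytes.
	h := frameHeader{Version: 0, Type: 0x0, Flags: 0x1, StreamID: 1, Length: 5}
	fmt.Println(decodeHeader(encodeHeader(h)) == h) // prints true
}
```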
+
+# Message Flow
+
+There is no explicit connection setup, as Yamux relies on an underlying
+transport to be provided. However, there is a distinction between client
+and server side of the connection.
+
+## Opening a stream
+
+To open a stream, an initial data or window update frame is sent
+with a new StreamID. The SYN flag should be set to signal a new stream.
+
+The receiver must then reply with either a data or window update frame
+with the StreamID along with the ACK flag to accept the stream or with
+the RST flag to reject the stream.
+
+Because we are relying on the reliable stream underneath, a connection
+can begin sending data once the SYN flag is sent. The corresponding
+ACK does not need to be received. This is particularly well suited
+for an RPC system where a client wants to open a stream and immediately
+fire a request without waiting for the RTT of the ACK.
+
+This does introduce the possibility of a connection being rejected
+after data has been sent already. This is a slight semantic difference
+ from TCP, where the connection cannot be refused after it is opened.
+Clients should be prepared to handle this by checking for an error
+that indicates a RST was received.
+
+## Closing a stream
+
+To close a stream, either side sends a data or window update frame
+along with the FIN flag. This does a half-close indicating the sender
+will send no further data.
+
+Once both sides have closed the connection, the stream is closed.
+
+Alternatively, if an error occurs, the RST flag can be used to
+hard close a stream immediately.
+
+## Flow Control
+
+ Yamux initially starts each stream with a 256KB window size.
+There is no window size for the session.
+
+To prevent the streams from stalling, window update frames should be
+sent regularly. Yamux can be configured to provide a larger limit for
+ window sizes. Both sides assume the initial 256KB window, but can
+immediately send a window update as part of the SYN/ACK indicating a
+larger window.
+
+ Both sides should track only the bytes sent in Data frames, as only
+ those count against the window size.
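The window arithmetic above can be sketched as follows, assuming the initial 256KB window. This is a simplified model of one stream's receive window, not the library's implementation:

```go
package main

import "fmt"

// initialWindow is the 256KB per-stream window both sides assume.
const initialWindow uint32 = 256 * 1024

func main() {
	// Receiver-side view: Data frame payload bytes consume the window,
	// and a Window Update delta replenishes it.
	window := initialWindow

	window -= 64 * 1024 // peer sent 64KB of Data payload
	window -= 32 * 1024 // ...and another 32KB

	// After the application has read the data, advertise a delta
	// covering the consumed bytes so the peer may send again.
	delta := uint32(96 * 1024)
	window += delta

	fmt.Println(window == initialWindow) // prints true
}
```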
+
+## Session termination
+
+When a session is being terminated, the Go Away message should
+be sent. The Length should be set to one of the following to
+provide an error code:
+
+* 0x0 Normal termination
+* 0x1 Protocol error
+* 0x2 Internal error
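The termination codes can be captured as constants (illustrative names, not the library's identifiers):

```go
package main

import "fmt"

// Go Away error codes from the list above. The chosen code travels in
// the Length field of a Go Away frame sent on StreamID 0.
const (
	goAwayNormal        uint32 = 0x0
	goAwayProtocolError uint32 = 0x1
	goAwayInternalError uint32 = 0x2
)

func main() {
	fmt.Println(goAwayNormal, goAwayProtocolError, goAwayInternalError) // prints 0 1 2
}
```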
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+stream.go - github.com/hashicorp/yamux
+package yamux
+
+import (
+ &#34;bytes&#34;
+ &#34;io&#34;
+ &#34;sync&#34;
+ &#34;sync/atomic&#34;
+ &#34;time&#34;
+)
+
+type streamState int
+
+const (
+ streamInit streamState = iota
+ streamSYNSent
+ streamSYNReceived
+ streamEstablished
+ streamLocalClose
+ streamRemoteClose
+ streamClosed
+ streamReset
+)
+
+// Stream is used to represent a logical stream
+// within a session.
+type Stream struct {
+ recvWindow uint32
+ sendWindow uint32
+
+ id uint32
+ session *Session
+
+ state streamState
+ stateLock sync.Mutex
+
+ recvBuf *bytes.Buffer
+ recvLock sync.Mutex
+
+ controlHdr header
+ controlErr chan error
+ controlHdrLock sync.Mutex
+
+ sendHdr header
+ sendErr chan error
+ sendLock sync.Mutex
+
+ recvNotifyCh chan struct{}
+ sendNotifyCh chan struct{}
+
+ readDeadline atomic.Value // time.Time
+ writeDeadline atomic.Value // time.Time
+
+ // closeTimer is set with stateLock held to honor the StreamCloseTimeout
+ // setting on Session.
+ closeTimer *time.Timer
+}
+
+// newStream is used to construct a new stream within
+// a given session for an ID
+func newStream(session *Session, id uint32, state streamState) *Stream {
+ s := &amp;Stream{
+ id: id,
+ session: session,
+ state: state,
+ controlHdr: header(make([]byte, headerSize)),
+ controlErr: make(chan error, 1),
+ sendHdr: header(make([]byte, headerSize)),
+ sendErr: make(chan error, 1),
+ recvWindow: initialStreamWindow,
+ sendWindow: initialStreamWindow,
+ recvNotifyCh: make(chan struct{}, 1),
+ sendNotifyCh: make(chan struct{}, 1),
+ }
+ s.readDeadline.Store(time.Time{})
+ s.writeDeadline.Store(time.Time{})
+ return s
+}
+
+// Session returns the associated stream session
+func (s *Stream) Session() *Session {
+ return s.session
+}
+
+// StreamID returns the ID of this stream
+func (s *Stream) StreamID() uint32 {
+ return s.id
+}
+
+// Read is used to read from the stream
+func (s *Stream) Read(b []byte) (n int, err error) {
+ defer asyncNotify(s.recvNotifyCh)
+START:
+ s.stateLock.Lock()
+ switch s.state {
+ case streamLocalClose:
+ fallthrough
+ case streamRemoteClose:
+ fallthrough
+ case streamClosed:
+ s.recvLock.Lock()
+ if s.recvBuf == nil || s.recvBuf.Len() == 0 {
+ s.recvLock.Unlock()
+ s.stateLock.Unlock()
+ return 0, io.EOF
+ }
+ s.recvLock.Unlock()
+ case streamReset:
+ s.stateLock.Unlock()
+ return 0, ErrConnectionReset
+ }
+ s.stateLock.Unlock()
+
+ // If there is no data available, block
+ s.recvLock.Lock()
+ if s.recvBuf == nil || s.recvBuf.Len() == 0 {
+ s.recvLock.Unlock()
+ goto WAIT
+ }
+
+ // Read any bytes
+ n, _ = s.recvBuf.Read(b)
+ s.recvLock.Unlock()
+
+ // Send a window update potentially
+ err = s.sendWindowUpdate()
+ return n, err
+
+WAIT:
+ var timeout &lt;-chan time.Time
+ var timer *time.Timer
+ readDeadline := s.readDeadline.Load().(time.Time)
+ if !readDeadline.IsZero() {
+ delay := readDeadline.Sub(time.Now())
+ timer = time.NewTimer(delay)
+ timeout = timer.C
+ }
+ select {
+ case &lt;-s.recvNotifyCh:
+ if timer != nil {
+ timer.Stop()
+ }
+ goto START
+ case &lt;-timeout:
+ return 0, ErrTimeout
+ }
+}
+
+// Write is used to write to the stream
+func (s *Stream) Write(b []byte) (n int, err error) {
+ s.sendLock.Lock()
+ defer s.sendLock.Unlock()
+ total := 0
+ for total &lt; len(b) {
+ n, err := s.write(b[total:])
+ total &#43;= n
+ if err != nil {
+ return total, err
+ }
+ }
+ return total, nil
+}
+
+ // write is used to write to the stream; it may return after
+ // a short write.
+func (s *Stream) write(b []byte) (n int, err error) {
+ var flags uint16
+ var max uint32
+ var body io.Reader
+START:
+ s.stateLock.Lock()
+ switch s.state {
+ case streamLocalClose:
+ fallthrough
+ case streamClosed:
+ s.stateLock.Unlock()
+ return 0, ErrStreamClosed
+ case streamReset:
+ s.stateLock.Unlock()
+ return 0, ErrConnectionReset
+ }
+ s.stateLock.Unlock()
+
+ // If there is no data available, block
+ window := atomic.LoadUint32(&amp;s.sendWindow)
+ if window == 0 {
+ goto WAIT
+ }
+
+ // Determine the flags if any
+ flags = s.sendFlags()
+
+ // Send up to our send window
+ max = min(window, uint32(len(b)))
+ body = bytes.NewReader(b[:max])
+
+ // Send the header
+ s.sendHdr.encode(typeData, flags, s.id, max)
+ if err = s.session.waitForSendErr(s.sendHdr, body, s.sendErr); err != nil {
+ return 0, err
+ }
+
+ // Reduce our send window
+ atomic.AddUint32(&s.sendWindow, ^uint32(max-1))
+
+ // Unlock
+ return int(max), err
+
+WAIT:
+ var timeout <-chan time.Time
+ writeDeadline := s.writeDeadline.Load().(time.Time)
+ if !writeDeadline.IsZero() {
+ delay := writeDeadline.Sub(time.Now())
+ timeout = time.After(delay)
+ }
+ select {
+ case <-s.sendNotifyCh:
+ goto START
+ case <-timeout:
+ return 0, ErrTimeout
+ }
+ return 0, nil
+}
+
+// sendFlags determines any flags that are appropriate
+// based on the current stream state
+func (s *Stream) sendFlags() uint16 {
+ s.stateLock.Lock()
+ defer s.stateLock.Unlock()
+ var flags uint16
+ switch s.state {
+ case streamInit:
+ flags |= flagSYN
+ s.state = streamSYNSent
+ case streamSYNReceived:
+ flags |= flagACK
+ s.state = streamEstablished
+ }
+ return flags
+}
+
+// sendWindowUpdate potentially sends a window update enabling
+// further writes to take place. Must be invoked with the lock.
+func (s *Stream) sendWindowUpdate() error {
+ s.controlHdrLock.Lock()
+ defer s.controlHdrLock.Unlock()
+
+ // Determine the delta update
+ max := s.session.config.MaxStreamWindowSize
+ var bufLen uint32
+ s.recvLock.Lock()
+ if s.recvBuf != nil {
+ bufLen = uint32(s.recvBuf.Len())
+ }
+ delta := (max - bufLen) - s.recvWindow
+
+ // Determine the flags if any
+ flags := s.sendFlags()
+
+ // Check if we can omit the update
+ if delta < (max/2) && flags == 0 {
+ s.recvLock.Unlock()
+ return nil
+ }
+
+ // Update our window
+ s.recvWindow += delta
+ s.recvLock.Unlock()
+
+ // Send the header
+ s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta)
+ if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil {
+ return err
+ }
+ return nil
+}
+
+// sendClose is used to send a FIN
+func (s *Stream) sendClose() error {
+ s.controlHdrLock.Lock()
+ defer s.controlHdrLock.Unlock()
+
+ flags := s.sendFlags()
+ flags |= flagFIN
+ s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0)
+ if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil {
+ return err
+ }
+ return nil
+}
+
+// Close is used to close the stream
+func (s *Stream) Close() error {
+ closeStream := false
+ s.stateLock.Lock()
+ switch s.state {
+ // Opened means we need to signal a close
+ case streamSYNSent:
+ fallthrough
+ case streamSYNReceived:
+ fallthrough
+ case streamEstablished:
+ s.state = streamLocalClose
+ goto SEND_CLOSE
+
+ case streamLocalClose:
+ case streamRemoteClose:
+ s.state = streamClosed
+ closeStream = true
+ goto SEND_CLOSE
+
+ case streamClosed:
+ case streamReset:
+ default:
+ panic("unhandled state")
+ }
+ s.stateLock.Unlock()
+ return nil
+SEND_CLOSE:
+ // This shouldn't happen (the more realistic scenario to cancel the
+ // timer is via processFlags) but just in case this ever happens, we
+ // cancel the timer to prevent dangling timers.
+ if s.closeTimer != nil {
+ s.closeTimer.Stop()
+ s.closeTimer = nil
+ }
+
+ // If we have a StreamCloseTimeout set we start the timeout timer.
+ // We do this only if we're not already closing the stream since that
+ // means this was a graceful close.
+ //
+ // This prevents memory leaks if one side (this side) closes and the
+ // remote side poorly behaves and never responds with a FIN to complete
+ // the close. After the specified timeout, we clean our resources up no
+ // matter what.
+ if !closeStream && s.session.config.StreamCloseTimeout > 0 {
+ s.closeTimer = time.AfterFunc(
+ s.session.config.StreamCloseTimeout, s.closeTimeout)
+ }
+
+ s.stateLock.Unlock()
+ s.sendClose()
+ s.notifyWaiting()
+ if closeStream {
+ s.session.closeStream(s.id)
+ }
+ return nil
+}
+
+// closeTimeout is called after StreamCloseTimeout during a close to
+// close this stream.
+func (s *Stream) closeTimeout() {
+ // Close our side forcibly
+ s.forceClose()
+
+ // Free the stream from the session map
+ s.session.closeStream(s.id)
+
+ // Send a RST so the remote side closes too.
+ s.sendLock.Lock()
+ defer s.sendLock.Unlock()
+ s.sendHdr.encode(typeWindowUpdate, flagRST, s.id, 0)
+ s.session.sendNoWait(s.sendHdr)
+}
+
+// forceClose is used for when the session is exiting
+func (s *Stream) forceClose() {
+ s.stateLock.Lock()
+ s.state = streamClosed
+ s.stateLock.Unlock()
+ s.notifyWaiting()
+}
+
+// processFlags is used to update the state of the stream
+// based on set flags, if any. Lock must be held
+func (s *Stream) processFlags(flags uint16) error {
+ s.stateLock.Lock()
+ defer s.stateLock.Unlock()
+
+ // Close the stream without holding the state lock
+ closeStream := false
+ defer func() {
+ if closeStream {
+ if s.closeTimer != nil {
+ // Stop our close timeout timer since we gracefully closed
+ s.closeTimer.Stop()
+ }
+
+ s.session.closeStream(s.id)
+ }
+ }()
+
+ if flags&flagACK == flagACK {
+ if s.state == streamSYNSent {
+ s.state = streamEstablished
+ }
+ s.session.establishStream(s.id)
+ }
+ if flags&flagFIN == flagFIN {
+ switch s.state {
+ case streamSYNSent:
+ fallthrough
+ case streamSYNReceived:
+ fallthrough
+ case streamEstablished:
+ s.state = streamRemoteClose
+ s.notifyWaiting()
+ case streamLocalClose:
+ s.state = streamClosed
+ closeStream = true
+ s.notifyWaiting()
+ default:
+ s.session.logger.Printf("[ERR] yamux: unexpected FIN flag in state %d", s.state)
+ return ErrUnexpectedFlag
+ }
+ }
+ if flags&flagRST == flagRST {
+ s.state = streamReset
+ closeStream = true
+ s.notifyWaiting()
+ }
+ return nil
+}
+
+// notifyWaiting notifies all the waiting channels
+func (s *Stream) notifyWaiting() {
+ asyncNotify(s.recvNotifyCh)
+ asyncNotify(s.sendNotifyCh)
+}
+
+// incrSendWindow updates the size of our send window
+func (s *Stream) incrSendWindow(hdr header, flags uint16) error {
+ if err := s.processFlags(flags); err != nil {
+ return err
+ }
+
+ // Increase window, unblock a sender
+ atomic.AddUint32(&s.sendWindow, hdr.Length())
+ asyncNotify(s.sendNotifyCh)
+ return nil
+}
+
+// readData is used to handle a data frame
+func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
+ if err := s.processFlags(flags); err != nil {
+ return err
+ }
+
+ // Check that our recv window is not exceeded
+ length := hdr.Length()
+ if length == 0 {
+ return nil
+ }
+
+ // Wrap in a limited reader
+ conn = &io.LimitedReader{R: conn, N: int64(length)}
+
+ // Copy into buffer
+ s.recvLock.Lock()
+
+ if length > s.recvWindow {
+ s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length)
+ return ErrRecvWindowExceeded
+ }
+
+ if s.recvBuf == nil {
+ // Allocate the receive buffer just-in-time to fit the full data frame.
+ // This way we can read in the whole packet without further allocations.
+ s.recvBuf = bytes.NewBuffer(make([]byte, 0, length))
+ }
+ if _, err := io.Copy(s.recvBuf, conn); err != nil {
+ s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err)
+ s.recvLock.Unlock()
+ return err
+ }
+
+ // Decrement the receive window
+ s.recvWindow -= length
+ s.recvLock.Unlock()
+
+ // Unblock any readers
+ asyncNotify(s.recvNotifyCh)
+ return nil
+}
+
+// SetDeadline sets the read and write deadlines
+func (s *Stream) SetDeadline(t time.Time) error {
+ if err := s.SetReadDeadline(t); err != nil {
+ return err
+ }
+ if err := s.SetWriteDeadline(t); err != nil {
+ return err
+ }
+ return nil
+}
+
+// SetReadDeadline sets the deadline for blocked and future Read calls.
+func (s *Stream) SetReadDeadline(t time.Time) error {
+ s.readDeadline.Store(t)
+ asyncNotify(s.recvNotifyCh)
+ return nil
+}
+
+// SetWriteDeadline sets the deadline for blocked and future Write calls
+func (s *Stream) SetWriteDeadline(t time.Time) error {
+ s.writeDeadline.Store(t)
+ asyncNotify(s.sendNotifyCh)
+ return nil
+}
+
+// Shrink is used to compact the amount of buffers utilized
+// This is useful when using Yamux in a connection pool to reduce
+// the idle memory utilization.
+func (s *Stream) Shrink() {
+ s.recvLock.Lock()
+ if s.recvBuf != nil && s.recvBuf.Len() == 0 {
+ s.recvBuf = nil
+ }
+ s.recvLock.Unlock()
+}
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+util.go - github.com/hashicorp/yamux
+package yamux
+
+import (
+ "sync"
+ "time"
+)
+
+var (
+ timerPool = &sync.Pool{
+ New: func() interface{} {
+ timer := time.NewTimer(time.Hour * 1e6)
+ timer.Stop()
+ return timer
+ },
+ }
+)
+
+// asyncSendErr is used to try an async send of an error
+func asyncSendErr(ch chan error, err error) {
+ if ch == nil {
+ return
+ }
+ select {
+ case ch <- err:
+ default:
+ }
+}
+
+// asyncNotify is used to signal a waiting goroutine
+func asyncNotify(ch chan struct{}) {
+ select {
+ case ch <- struct{}{}:
+ default:
+ }
+}
+
+// min computes the minimum of two values
+func min(a, b uint32) uint32 {
+ if a < b {
+ return a
+ }
+ return b
+}
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+util_test.go - github.com/hashicorp/yamux
+package yamux
+
+import (
+ "testing"
+)
+
+func TestAsyncSendErr(t *testing.T) {
+ ch := make(chan error)
+ asyncSendErr(ch, ErrTimeout)
+ select {
+ case <-ch:
+ t.Fatalf("should not get")
+ default:
+ }
+
+ ch = make(chan error, 1)
+ asyncSendErr(ch, ErrTimeout)
+ select {
+ case <-ch:
+ default:
+ t.Fatalf("should get")
+ }
+}
+
+func TestAsyncNotify(t *testing.T) {
+ ch := make(chan struct{})
+ asyncNotify(ch)
+ select {
+ case <-ch:
+ t.Fatalf("should not get")
+ default:
+ }
+
+ ch = make(chan struct{}, 1)
+ asyncNotify(ch)
+ select {
+ case <-ch:
+ default:
+ t.Fatalf("should get")
+ }
+}
+
+func TestMin(t *testing.T) {
+ if min(1, 2) != 1 {
+ t.Fatalf("bad")
+ }
+ if min(2, 1) != 1 {
+ t.Fatalf("bad")
+ }
+}
+
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
LICENSE - github.com/jcmturner/gofork
Copyright (c) 2009 The Go Authors. All rights reserved.
diff --git a/go.mod b/go.mod
index f9c2d9ae9..9b67ddc36 100644
--- a/go.mod
+++ b/go.mod
@@ -23,6 +23,7 @@ require (
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
github.com/hashicorp/go-uuid v1.0.2 // indirect
github.com/hashicorp/golang-lru v0.5.4
+ github.com/hashicorp/yamux v0.0.0-20210316155119-a95892c5f864
github.com/kelseyhightower/envconfig v1.3.0
github.com/lib/pq v1.2.0
github.com/libgit2/git2go/v31 v31.4.12
diff --git a/go.sum b/go.sum
index 34b473729..ee105d591 100644
--- a/go.sum
+++ b/go.sum
@@ -196,6 +196,8 @@ github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ
github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc=
github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
+github.com/hashicorp/yamux v0.0.0-20210316155119-a95892c5f864 h1:Y4V+SFe7d3iH+9pJCoeWIOS5/xBJIFsltS7E+KJSsJY=
+github.com/hashicorp/yamux v0.0.0-20210316155119-a95892c5f864/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ=
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
@@ -304,14 +306,10 @@ github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+
github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc=
github.com/otiai10/copy v1.0.1 h1:gtBjD8aq4nychvRZ2CyJvFWAw0aja+VHazDdruZKGZA=
github.com/otiai10/copy v1.0.1/go.mod h1:8bMCJrAqOtN/d9oyh5HR7HhLQMvcGMpGdwRDYsfOCHc=
-github.com/otiai10/copy v1.0.1/go.mod h1:8bMCJrAqOtN/d9oyh5HR7HhLQMvcGMpGdwRDYsfOCHc=
-github.com/otiai10/copy v1.0.1/go.mod h1:8bMCJrAqOtN/d9oyh5HR7HhLQMvcGMpGdwRDYsfOCHc=
github.com/otiai10/curr v0.0.0-20150429015615-9b4961190c95/go.mod h1:9qAhocn7zKJG+0mI8eUu6xqkFDYS2kb2saOteoSB3cE=
github.com/otiai10/curr v1.0.0 h1:TJIWdbX0B+kpNagQrjgq8bCMrbhiuX73M2XwgtDMoOI=
github.com/otiai10/curr v1.0.0/go.mod h1:LskTG5wDwr8Rs+nNQ+1LlxRjAtTZZjtJW4rMXl6j4vs=
github.com/otiai10/mint v1.2.3/go.mod h1:YnfyPNhBvnY8bW4SGQHCs/aAFhkgySlMZbrF5U0bOVw=
-github.com/otiai10/mint v1.2.3/go.mod h1:YnfyPNhBvnY8bW4SGQHCs/aAFhkgySlMZbrF5U0bOVw=
-github.com/otiai10/mint v1.2.3/go.mod h1:YnfyPNhBvnY8bW4SGQHCs/aAFhkgySlMZbrF5U0bOVw=
github.com/otiai10/mint v1.3.0 h1:Ady6MKVezQwHBkGzLFbrsywyp09Ah7rkmfjV3Bcr5uc=
github.com/otiai10/mint v1.3.0/go.mod h1:F5AjcsTsWUqX+Na9fpHb52P8pcRX2CI6A3ctIT91xUo=
github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o=
@@ -360,7 +358,6 @@ github.com/sebest/xff v0.0.0-20160910043805-6c115e0ffa35/go.mod h1:wozgYq9WEBQBa
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
-github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.3.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
@@ -423,7 +420,6 @@ gitlab.com/gitlab-org/gitaly v1.68.0/go.mod h1:/pCsB918Zu5wFchZ9hLYin9WkJ2yQqdVN
gitlab.com/gitlab-org/gitlab-shell v1.9.8-0.20201117050822-3f9890ef73dc h1:ENBJqb2gK3cZFHe8dbEM/TPXCwkAxDuxGK255ux4XPg=
gitlab.com/gitlab-org/gitlab-shell v1.9.8-0.20201117050822-3f9890ef73dc/go.mod h1:5QSTbpAHY2v0iIH5uHh2KA9w7sPUqPmnLjDApI/sv1U=
gitlab.com/gitlab-org/labkit v0.0.0-20190221122536-0c3fc7cdd57c/go.mod h1:rYhLgfrbEcyfinG+R3EvKu6bZSsmwQqcXzLfHWSfUKM=
-gitlab.com/gitlab-org/labkit v0.0.0-20190221122536-0c3fc7cdd57c/go.mod h1:rYhLgfrbEcyfinG+R3EvKu6bZSsmwQqcXzLfHWSfUKM=
gitlab.com/gitlab-org/labkit v0.0.0-20200908084045-45895e129029/go.mod h1:SNfxkfUwVNECgtmluVayv0GWFgEjjBs5AzgsowPQuo0=
gitlab.com/gitlab-org/labkit v1.0.0 h1:t2Wr8ygtvHfXAMlCkoEdk5pdb5Gy1IYdr41H7t4kAYw=
gitlab.com/gitlab-org/labkit v1.0.0/go.mod h1:nohrYTSLDnZix0ebXZrbZJjymRar8HeV2roWL5/jw2U=
diff --git a/internal/backchannel/backchannel.go b/internal/backchannel/backchannel.go
new file mode 100644
index 000000000..b22fc71e0
--- /dev/null
+++ b/internal/backchannel/backchannel.go
@@ -0,0 +1,60 @@
+// Package backchannel implements connection multiplexing that allows for invoking
+// gRPC methods from the server to the client.
+//
+// gRPC only allows invoking RPCs from the client to the server. Invoking
+// RPCs from the server to the client can be useful in some cases such as
+// tunneling through firewalls. While implementing such a use case would be
+// possible with plain bidirectional streams, the approach has various limitations
+// that force additional work on the user. All messages in a single stream are ordered
+// and processed sequentially. If concurrency is desired, this would require the user
+// to implement their own concurrency handling. Request routing and cancellations would also
+// have to be implemented separately on top of the bidirectional stream.
+//
+// To do away with these problems, this package provides a multiplexed transport for running two
+// independent gRPC sessions on a single connection. This allows for dialing back to the client from
+// the server to establish another gRPC session where the server and client roles are switched.
+//
+// The server side supports clients that are unaware of the multiplexing. The server peeks the incoming
+// network stream to see if it starts with the magic bytes that indicate a multiplexing aware client.
+// If the magic bytes are present, the server initiates the multiplexing session and dials back to the client
+// over the already established network connection. If the magic bytes are not present, the server restores
+// the bytes back into the original network stream and handles it without a multiplexing session.
+//
+// Usage:
+// 1. Implement a ServerFactory, which is simply a function that returns a Server that can serve on the backchannel
+// connection. Plug in the transport credentials returned by ServerFactory.ClientHandshaker via grpc.WithTransportCredentials.
+// This ensures all connections established by gRPC work with a multiplexing session and have a backchannel Server serving.
+// 2. Configure the ServerHandshake on the server side by passing it into the gRPC server via the grpc.Creds option.
+// The ServerHandshake method is called on each newly established connection. It peeks the network stream to see if a
+// multiplexing session should be initiated. If so, it also dials back to the client's backchannel server. The server
+// makes the backchannel connections available later via the Registry's Backchannel method. The ID of the
+// peer associated with the current RPC handler can be fetched via GetPeerID. The returned ID can be used
+// to access the correct backchannel connection from the Registry.
+package backchannel
+
+import (
+ "io"
+ "net"
+
+ "github.com/hashicorp/yamux"
+)
+
+// magicBytes are sent by the client to server to identify as a multiplexing aware client.
+var magicBytes = []byte("backchannel")
+
+// muxConfig returns a new config to use with the multiplexing session.
+func muxConfig(logger io.Writer) *yamux.Config {
+ cfg := yamux.DefaultConfig()
+ cfg.LogOutput = logger
+ return cfg
+}
+
+// connCloser wraps a net.Conn and calls the provided close function instead when Close
+// is called.
+type connCloser struct {
+ net.Conn
+ close func() error
+}
+
+// Close calls the provided close function.
+func (cc connCloser) Close() error { return cc.close() }
diff --git a/internal/backchannel/backchannel_example_test.go b/internal/backchannel/backchannel_example_test.go
new file mode 100644
index 000000000..20387b2de
--- /dev/null
+++ b/internal/backchannel/backchannel_example_test.go
@@ -0,0 +1,139 @@
+package backchannel_test
+
+import (
+ "context"
+ "fmt"
+ "net"
+
+ "github.com/sirupsen/logrus"
+ "gitlab.com/gitlab-org/gitaly/internal/backchannel"
+ "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+ "google.golang.org/grpc"
+)
+
+func Example() {
+ // Open the server's listener.
+ ln, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ fmt.Printf("failed to start listener: %v", err)
+ return
+ }
+
+ // Registry is for storing the open backchannels. It should be passed into the ServerHandshaker
+ // which creates the backchannel connections and adds them to the registry. The RPC handlers
+ // can use the registry to access available backchannels by their peer ID.
+ registry := backchannel.NewRegistry()
+
+ logger := logrus.NewEntry(logrus.New())
+
+ // ServerHandshaker initiates the multiplexing session on the server side. Once that is done,
+ // it creates the backchannel connection and stores it into the registry. For each connection,
+ // the ServerHandshaker passes down the peer ID via the context. The peer ID identifies a
+ // backchannel connection.
+ handshaker := backchannel.NewServerHandshaker(logger, backchannel.Insecure(), registry)
+
+ // Create the server
+ srv := grpc.NewServer(
+ grpc.Creds(handshaker),
+ grpc.UnknownServiceHandler(func(srv interface{}, stream grpc.ServerStream) error {
+ fmt.Println("Gitaly received a transactional mutator")
+
+ backchannelID, err := backchannel.GetPeerID(stream.Context())
+ if err == backchannel.ErrNonMultiplexedConnection {
+ // This call is from a client that is not multiplexing aware. Client is not
+ // Praefect, so no need to perform voting. The client could be for example
+ // GitLab calling Gitaly directly.
+ fmt.Println("Gitaly responding to a non-multiplexed client")
+ return stream.SendMsg(&gitalypb.CreateBranchResponse{})
+ } else if err != nil {
+ return fmt.Errorf("get peer id: %w", err)
+ }
+
+ backchannelConn, err := registry.Backchannel(backchannelID)
+ if err != nil {
+ return fmt.Errorf("get backchannel: %w", err)
+ }
+
+ fmt.Println("Gitaly sending vote to Praefect via backchannel")
+ if err := backchannelConn.Invoke(
+ stream.Context(), "/Praefect/VoteTransaction",
+ &gitalypb.VoteTransactionRequest{}, &gitalypb.VoteTransactionResponse{},
+ ); err != nil {
+ return fmt.Errorf("invoke backchannel: %w", err)
+ }
+ fmt.Println("Gitaly received vote response via backchannel")
+
+ fmt.Println("Gitaly responding to the transactional mutator")
+ return stream.SendMsg(&gitalypb.CreateBranchResponse{})
+ }),
+ )
+ defer srv.Stop()
+
+ // Start the server
+ go func() {
+ if err := srv.Serve(ln); err != nil {
+ fmt.Printf("failed to serve: %v", err)
+ }
+ }()
+
+ fmt.Printf("Invoke with a multiplexed client:\n\n")
+ if err := invokeWithMuxedClient(logger, ln.Addr().String()); err != nil {
+ fmt.Printf("failed to invoke with muxed client: %v", err)
+ return
+ }
+
+ fmt.Printf("\nInvoke with a non-multiplexed client:\n\n")
+ if err := invokeWithNormalClient(ln.Addr().String()); err != nil {
+ fmt.Printf("failed to invoke with non-muxed client: %v", err)
+ return
+ }
+ // Output:
+ // Invoke with a multiplexed client:
+ //
+ // Gitaly received a transactional mutator
+ // Gitaly sending vote to Praefect via backchannel
+ // Praefect received vote via backchannel
+ // Praefect responding via backchannel
+ // Gitaly received vote response via backchannel
+ // Gitaly responding to the transactional mutator
+ //
+ // Invoke with a non-multiplexed client:
+ //
+ // Gitaly received a transactional mutator
+ // Gitaly responding to a non-multiplexed client
+}
+
+func invokeWithMuxedClient(logger *logrus.Entry, address string) error {
+ // serverFactory gets called on each established connection. The Server it returns
+ // is started on Praefect's end of the connection, which Gitaly can call.
+ serverFactory := backchannel.ServerFactory(func() backchannel.Server {
+ return grpc.NewServer(grpc.UnknownServiceHandler(func(srv interface{}, stream grpc.ServerStream) error {
+ fmt.Println("Praefect received vote via backchannel")
+ fmt.Println("Praefect responding via backchannel")
+ return stream.SendMsg(&gitalypb.VoteTransactionResponse{})
+ }))
+ })
+
+ return invokeWithOpts(address, grpc.WithTransportCredentials(serverFactory.ClientHandshaker(logger, backchannel.Insecure())))
+}
+
+func invokeWithNormalClient(address string) error {
+ return invokeWithOpts(address, grpc.WithInsecure())
+}
+
+func invokeWithOpts(address string, opts ...grpc.DialOption) error {
+ clientConn, err := grpc.Dial(address, opts...)
+ if err != nil {
+ return fmt.Errorf("dial server: %w", err)
+ }
+
+ if err := clientConn.Invoke(context.Background(), "/Gitaly/Mutator", &gitalypb.CreateBranchRequest{}, &gitalypb.CreateBranchResponse{}); err != nil {
+ return fmt.Errorf("call server: %w", err)
+ }
+
+ if err := clientConn.Close(); err != nil {
+ return fmt.Errorf("close clientConn: %w", err)
+ }
+
+ return nil
+}
diff --git a/internal/backchannel/backchannel_test.go b/internal/backchannel/backchannel_test.go
new file mode 100644
index 000000000..6b93153cb
--- /dev/null
+++ b/internal/backchannel/backchannel_test.go
@@ -0,0 +1,223 @@
+package backchannel
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net"
+ "sync"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/gitaly/internal/testhelper"
+ "gitlab.com/gitlab-org/gitaly/proto/go/gitalypb"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+)
+
+type mockTransactionServer struct {
+ voteTransactionFunc func(context.Context, *gitalypb.VoteTransactionRequest) (*gitalypb.VoteTransactionResponse, error)
+ *gitalypb.UnimplementedRefTransactionServer
+}
+
+func (m mockTransactionServer) VoteTransaction(ctx context.Context, req *gitalypb.VoteTransactionRequest) (*gitalypb.VoteTransactionResponse, error) {
+ return m.voteTransactionFunc(ctx, req)
+}
+
+func TestBackchannel_concurrentRequestsFromMultipleClients(t *testing.T) {
+ registry := NewRegistry()
+ handshaker := NewServerHandshaker(testhelper.DiscardTestEntry(t), Insecure(), registry)
+
+ ln, err := net.Listen("tcp", "localhost:0")
+ require.NoError(t, err)
+
+ errNonMultiplexed := status.Error(codes.FailedPrecondition, ErrNonMultiplexedConnection.Error())
+ srv := grpc.NewServer(grpc.Creds(handshaker))
+
+ gitalypb.RegisterRefTransactionServer(srv, mockTransactionServer{
+ voteTransactionFunc: func(ctx context.Context, req *gitalypb.VoteTransactionRequest) (*gitalypb.VoteTransactionResponse, error) {
+ peerID, err := GetPeerID(ctx)
+ if err == ErrNonMultiplexedConnection {
+ return nil, errNonMultiplexed
+ }
+ assert.NoError(t, err)
+
+ cc, err := registry.Backchannel(peerID)
+ if !assert.NoError(t, err) {
+ return nil, err
+ }
+
+ return gitalypb.NewRefTransactionClient(cc).VoteTransaction(ctx, req)
+ },
+ })
+
+ defer srv.Stop()
+ go srv.Serve(ln)
+
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ start := make(chan struct{})
+
+ // Create 25 multiplexed clients and non-multiplexed clients that launch requests
+ // concurrently.
+ var wg sync.WaitGroup
+ for i := uint64(0); i < 25; i++ {
+ i := i
+ wg.Add(2)
+
+ go func() {
+ defer wg.Done()
+
+ <-start
+ client, err := grpc.Dial(ln.Addr().String(), grpc.WithInsecure())
+ if !assert.NoError(t, err) {
+ return
+ }
+
+ resp, err := gitalypb.NewRefTransactionClient(client).VoteTransaction(ctx, &gitalypb.VoteTransactionRequest{})
+ assert.Equal(t, err, errNonMultiplexed)
+ assert.Nil(t, resp)
+
+ assert.NoError(t, client.Close())
+ }()
+
+ go func() {
+ defer wg.Done()
+
+ expectedErr := status.Error(codes.Internal, fmt.Sprintf("multiplexed %d", i))
+ serverFactory := ServerFactory(func() Server {
+ srv := grpc.NewServer()
+ gitalypb.RegisterRefTransactionServer(srv, mockTransactionServer{
+ voteTransactionFunc: func(ctx context.Context, req *gitalypb.VoteTransactionRequest) (*gitalypb.VoteTransactionResponse, error) {
+ assert.Equal(t, &gitalypb.VoteTransactionRequest{TransactionId: i}, req)
+ return nil, expectedErr
+ },
+ })
+
+ return srv
+ })
+
+ <-start
+ client, err := grpc.Dial(ln.Addr().String(),
+ grpc.WithTransportCredentials(serverFactory.ClientHandshaker(testhelper.DiscardTestEntry(t), Insecure())),
+ )
+ if !assert.NoError(t, err) {
+ return
+ }
+
+ // Run two invocations concurrently on each multiplexed client to sanity check
+ // the routing works with multiple requests from a connection.
+ var invocations sync.WaitGroup
+ for invocation := 0; invocation < 2; invocation++ {
+ invocations.Add(1)
+ go func() {
+ defer invocations.Done()
+ resp, err := gitalypb.NewRefTransactionClient(client).VoteTransaction(ctx, &gitalypb.VoteTransactionRequest{TransactionId: i})
+ assert.Equal(t, err, expectedErr)
+ assert.Nil(t, resp)
+ }()
+ }
+
+ invocations.Wait()
+ assert.NoError(t, client.Close())
+ }()
+ }
+
+ // Establish the connection and fire the requests.
+ close(start)
+
+ // Wait for the clients to finish their calls and close their connections.
+ wg.Wait()
+}
+
+type mockSSHService struct {
+ sshUploadPackFunc func(gitalypb.SSHService_SSHUploadPackServer) error
+ *gitalypb.UnimplementedSSHServiceServer
+}
+
+func (m mockSSHService) SSHUploadPack(stream gitalypb.SSHService_SSHUploadPackServer) error {
+ return m.sshUploadPackFunc(stream)
+}
+
+func Benchmark(b *testing.B) {
+ for _, tc := range []struct {
+ desc string
+ multiplexed bool
+ }{
+ {desc: "multiplexed", multiplexed: true},
+ {desc: "normal"},
+ } {
+ b.Run(tc.desc, func(b *testing.B) {
+ for _, messageSize := range []int64{
+ 1024,
+ 1024 * 1024,
+ 3 * 1024 * 1024,
+ } {
+ b.Run(fmt.Sprintf("message size %dkb", messageSize/1024), func(b *testing.B) {
+ var serverOpts []grpc.ServerOption
+ if tc.multiplexed {
+ serverOpts = []grpc.ServerOption{
+ grpc.Creds(NewServerHandshaker(testhelper.DiscardTestEntry(b), Insecure(), NewRegistry())),
+ }
+ }
+
+ srv := grpc.NewServer(serverOpts...)
+ gitalypb.RegisterSSHServiceServer(srv, mockSSHService{
+ sshUploadPackFunc: func(stream gitalypb.SSHService_SSHUploadPackServer) error {
+ for {
+ _, err := stream.Recv()
+ if err != nil {
+ assert.Equal(b, io.EOF, err)
+ return nil
+ }
+ }
+ },
+ })
+
+ ln, err := net.Listen("tcp", "localhost:0")
+ require.NoError(b, err)
+
+ defer srv.Stop()
+ go srv.Serve(ln)
+
+ ctx, cancel := testhelper.Context()
+ defer cancel()
+
+ opts := []grpc.DialOption{grpc.WithBlock(), grpc.WithInsecure()}
+ if tc.multiplexed {
+ nopServer := ServerFactory(func() Server { return grpc.NewServer() })
+ opts = []grpc.DialOption{
+ grpc.WithBlock(),
+ grpc.WithTransportCredentials(nopServer.ClientHandshaker(
+ testhelper.DiscardTestEntry(b), Insecure(),
+ )),
+ }
+ }
+
+ cc, err := grpc.DialContext(ctx, ln.Addr().String(), opts...)
+ require.NoError(b, err)
+
+ defer cc.Close()
+
+ client, err := gitalypb.NewSSHServiceClient(cc).SSHUploadPack(ctx)
+ require.NoError(b, err)
+
+ request := &gitalypb.SSHUploadPackRequest{Stdin: make([]byte, messageSize)}
+ b.SetBytes(messageSize)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ require.NoError(b, client.Send(request))
+ }
+
+ require.NoError(b, client.CloseSend())
+ _, err = client.Recv()
+ require.Equal(b, io.EOF, err)
+ })
+ }
+ })
+ }
+}
diff --git a/internal/backchannel/client.go b/internal/backchannel/client.go
new file mode 100644
index 000000000..a844d6d7f
--- /dev/null
+++ b/internal/backchannel/client.go
@@ -0,0 +1,125 @@
+package backchannel
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "time"
+
+ "github.com/hashicorp/yamux"
+ "github.com/sirupsen/logrus"
+ "google.golang.org/grpc/credentials"
+)
+
+// Server is the interface of a backchannel server.
+type Server interface {
+ // Serve starts serving on the listener.
+ Serve(net.Listener) error
+ // Stop stops the server and closes all connections.
+ Stop()
+}
+
+// ServerFactory returns the server that should serve on the backchannel.
+// Each invocation should return a new server as the servers get stopped when
+// a backchannel closes.
+type ServerFactory func() Server
+
+// ClientHandshaker returns TransportCredentials that perform the client side multiplexing handshake and
+// start the backchannel Server on the established connections. The provided logger is used to log multiplexing
+// errors and the transport credentials are used to initialize the connection prior to the multiplexing.
+func (sf ServerFactory) ClientHandshaker(logger *logrus.Entry, tc credentials.TransportCredentials) credentials.TransportCredentials {
+ return clientHandshaker{TransportCredentials: tc, serverFactory: sf, logger: logger}
+}
+
+type clientHandshaker struct {
+ credentials.TransportCredentials
+ serverFactory ServerFactory
+ logger *logrus.Entry
+}
+
+func (ch clientHandshaker) ClientHandshake(ctx context.Context, serverName string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
+ conn, authInfo, err := ch.TransportCredentials.ClientHandshake(ctx, serverName, conn)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ clientStream, err := ch.serve(ctx, conn)
+ if err != nil {
+ return nil, nil, fmt.Errorf("serve: %w", err)
+ }
+
+ return clientStream, authInfo, nil
+}
+
+func (ch clientHandshaker) serve(ctx context.Context, conn net.Conn) (net.Conn, error) {
+ deadline := time.Time{}
+ if dl, ok := ctx.Deadline(); ok {
+ deadline = dl
+ }
+
+ // gRPC expects the ClientHandshaker implementation to respect the deadline set in the context.
+ if err := conn.SetDeadline(deadline); err != nil {
+ return nil, fmt.Errorf("set connection deadline: %w", err)
+ }
+
+ defer func() {
+ // The deadline has to be cleared after the muxing session is established as we are not
+ // returning the Conn itself but the stream, so gRPC can't clear the deadline we set
+ // on the Conn.
+ if err := conn.SetDeadline(time.Time{}); err != nil {
+ ch.logger.WithError(err).Error("remove connection deadline")
+ }
+ }()
+
+ // Write the magic bytes on the connection so the server knows we're about to initiate
+ // a multiplexing session.
+ if _, err := conn.Write(magicBytes); err != nil {
+ return nil, fmt.Errorf("write backchannel magic bytes: %w", err)
+ }
+
+ logger := ch.logger.WriterLevel(logrus.ErrorLevel)
+
+ // Initiate the multiplexing session.
+ muxSession, err := yamux.Client(conn, muxConfig(logger))
+ if err != nil {
+ logger.Close()
+ return nil, fmt.Errorf("open multiplexing session: %w", err)
+ }
+
+ go func() {
+ <-muxSession.CloseChan()
+ logger.Close()
+ }()
+
+ // Initiate the stream to the server. This is used by the client's gRPC session.
+ clientToServer, err := muxSession.Open()
+ if err != nil {
+ return nil, fmt.Errorf("open client stream: %w", err)
+ }
+
+ // Run the backchannel server.
+ server := ch.serverFactory()
+ serveErr := make(chan error, 1)
+ go func() { serveErr <- server.Serve(muxSession) }()
+
+ return connCloser{
+ Conn: clientToServer,
+ close: func() error {
+ // Stop closes the listener, which is the muxing session. Closing the
+ // muxing session closes the underlying network connection.
+ //
+ // There's no sense in doing a graceful shutdown. The connection is being closed,
+ // it would no longer receive a response from the server.
+ server.Stop()
+ // Serve returns a non-nil error only if it returned before Stop was called,
+ // indicating a serving failure rather than a graceful stop.
+ return <-serveErr
+ }}, nil
+}
+
+func (ch clientHandshaker) Clone() credentials.TransportCredentials {
+ return clientHandshaker{
+ TransportCredentials: ch.TransportCredentials.Clone(),
+ serverFactory: ch.serverFactory,
+ logger: ch.logger,
+ }
+}
diff --git a/internal/backchannel/insecure.go b/internal/backchannel/insecure.go
new file mode 100644
index 000000000..678a90527
--- /dev/null
+++ b/internal/backchannel/insecure.go
@@ -0,0 +1,40 @@
+package backchannel
+
+import (
+ "context"
+ "net"
+
+ "google.golang.org/grpc/credentials"
+)
+
+type insecureAuthInfo struct{ credentials.CommonAuthInfo }
+
+func (insecureAuthInfo) AuthType() string { return "insecure" }
+
+type insecure struct{}
+
+func (insecure) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
+ return conn, insecureAuthInfo{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
+}
+
+func (insecure) ClientHandshake(_ context.Context, _ string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
+ return conn, insecureAuthInfo{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
+}
+
+func (insecure) Info() credentials.ProtocolInfo {
+ return credentials.ProtocolInfo{SecurityProtocol: "insecure"}
+}
+
+func (insecure) Clone() credentials.TransportCredentials { return Insecure() }
+
+func (insecure) OverrideServerName(string) error { return nil }
+
+// Insecure can be used in place of transport credentials when no transport security is configured.
+// Its handshakes simply return the passed in connection.
+//
+// Similar credentials are already implemented in gRPC:
+// https://github.com/grpc/grpc-go/blob/702608ffae4d03a6821b96d3e2311973d34b96dc/credentials/insecure/insecure.go
+// We've reimplemented these here as upgrading our gRPC version was very involved. Once
+// we've upgraded to a version that contains the insecure credentials, this implementation can be removed and
+// substituted by the official implementation.
+func Insecure() credentials.TransportCredentials { return insecure{} }
diff --git a/internal/backchannel/registry.go b/internal/backchannel/registry.go
new file mode 100644
index 000000000..407d770f9
--- /dev/null
+++ b/internal/backchannel/registry.go
@@ -0,0 +1,51 @@
+package backchannel
+
+import (
+ "fmt"
+ "sync"
+
+ "google.golang.org/grpc"
+)
+
+// ID is a monotonically increasing number that uniquely identifies a peer connection.
+type ID uint64
+
+// Registry is a thread safe registry for backchannels. It enables accessing the backchannels via a
+// unique ID.
+type Registry struct {
+ m sync.RWMutex
+ currentID ID
+ backchannels map[ID]*grpc.ClientConn
+}
+
+// NewRegistry returns a new Registry.
+func NewRegistry() *Registry { return &Registry{backchannels: map[ID]*grpc.ClientConn{}} }
+
+// Backchannel returns a backchannel for the ID. Returns an error if no backchannel is registered
+// for the ID.
+func (r *Registry) Backchannel(id ID) (*grpc.ClientConn, error) {
+ r.m.RLock()
+ defer r.m.RUnlock()
+ backchannel, ok := r.backchannels[id]
+ if !ok {
+ return nil, fmt.Errorf("no backchannel for peer %d", id)
+ }
+
+ return backchannel, nil
+}
+
+// RegisterBackchannel registers a new backchannel and returns its unique ID.
+func (r *Registry) RegisterBackchannel(conn *grpc.ClientConn) ID {
+ r.m.Lock()
+ defer r.m.Unlock()
+ r.currentID++
+ r.backchannels[r.currentID] = conn
+ return r.currentID
+}
+
+// RemoveBackchannel removes a backchannel from the registry.
+func (r *Registry) RemoveBackchannel(id ID) {
+ r.m.Lock()
+ defer r.m.Unlock()
+ delete(r.backchannels, id)
+}
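The registry above is a mutex-guarded map keyed by a monotonically increasing ID. A stripped-down, self-contained sketch of the same pattern (storing a plain string in place of `*grpc.ClientConn`, so it runs without gRPC) behaves like this:

```go
package main

import (
	"fmt"
	"sync"
)

// registry is a minimal stand-in for backchannel.Registry; the value type is
// a string instead of *grpc.ClientConn so the sketch is self-contained.
type registry struct {
	m         sync.RWMutex
	currentID uint64
	values    map[uint64]string
}

func newRegistry() *registry { return &registry{values: map[uint64]string{}} }

// register stores a value under the next ID and returns that ID.
// IDs start at 1 and are never reused, even after removal.
func (r *registry) register(v string) uint64 {
	r.m.Lock()
	defer r.m.Unlock()
	r.currentID++
	r.values[r.currentID] = v
	return r.currentID
}

// get returns the value for the ID, or an error if none is registered.
func (r *registry) get(id uint64) (string, error) {
	r.m.RLock()
	defer r.m.RUnlock()
	v, ok := r.values[id]
	if !ok {
		return "", fmt.Errorf("no backchannel for peer %d", id)
	}
	return v, nil
}

// remove deletes the entry; the ID is retired, not recycled.
func (r *registry) remove(id uint64) {
	r.m.Lock()
	defer r.m.Unlock()
	delete(r.values, id)
}

func main() {
	r := newRegistry()
	id := r.register("praefect-1")
	v, _ := r.get(id)
	fmt.Println(id, v)

	r.remove(id)
	if _, err := r.get(id); err != nil {
		fmt.Println(err)
	}
}
```

Never reusing IDs is what lets an RPC handler safely treat a stale ID as "peer gone" rather than risk dialing a different peer's backchannel.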
diff --git a/internal/backchannel/server.go b/internal/backchannel/server.go
new file mode 100644
index 000000000..44a769ca9
--- /dev/null
+++ b/internal/backchannel/server.go
@@ -0,0 +1,146 @@
+package backchannel
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net"
+
+ "github.com/hashicorp/yamux"
+ "github.com/sirupsen/logrus"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/peer"
+)
+
+// ErrNonMultiplexedConnection is returned when attempting to get the peer id of a non-multiplexed
+// connection.
+var ErrNonMultiplexedConnection = errors.New("non-multiplexed connection")
+
+// authInfoWrapper is used to pass the peer id through the context to the RPC handlers.
+type authInfoWrapper struct {
+ id ID
+ credentials.AuthInfo
+}
+
+func (w authInfoWrapper) peerID() ID { return w.id }
+
+// GetPeerID gets the ID of the current peer connection.
+func GetPeerID(ctx context.Context) (ID, error) {
+ peerInfo, ok := peer.FromContext(ctx)
+ if !ok {
+ return 0, errors.New("no peer info in context")
+ }
+
+ wrapper, ok := peerInfo.AuthInfo.(interface{ peerID() ID })
+ if !ok {
+ return 0, ErrNonMultiplexedConnection
+ }
+
+ return wrapper.peerID(), nil
+}
+
+// ServerHandshaker implements the server side handshake of the multiplexed connection.
+type ServerHandshaker struct {
+ registry *Registry
+ logger *logrus.Entry
+ credentials.TransportCredentials
+}
+
+// NewServerHandshaker returns a new server side implementation of the backchannel.
+func NewServerHandshaker(logger *logrus.Entry, tc credentials.TransportCredentials, reg *Registry) credentials.TransportCredentials {
+ return ServerHandshaker{
+ TransportCredentials: tc,
+ registry: reg,
+ logger: logger,
+ }
+}
+
+// restoredConn allows for restoring the connection's stream after peeking it. If the connection
+// was not multiplexed, the peeked bytes are restored back into the stream.
+type restoredConn struct {
+ net.Conn
+ reader io.Reader
+}
+
+func (rc *restoredConn) Read(b []byte) (int, error) { return rc.reader.Read(b) }
+
+// ServerHandshake peeks the connection to determine whether the client supports establishing a
+// backchannel by multiplexing the network connection. If so, it establishes a gRPC ClientConn back
+// to the client and stores its ID in the AuthInfo where it can be later accessed by the RPC handlers.
+// gRPC sets an IO timeout on the connection before calling ServerHandshake, so we don't have to handle
+// timeouts separately.
+func (s ServerHandshaker) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
+ conn, authInfo, err := s.TransportCredentials.ServerHandshake(conn)
+ if err != nil {
+ return nil, nil, fmt.Errorf("wrapped server handshake: %w", err)
+ }
+
+ peeked, err := ioutil.ReadAll(io.LimitReader(conn, int64(len(magicBytes))))
+ if err != nil {
+ return nil, nil, fmt.Errorf("peek network stream: %w", err)
+ }
+
+ if !bytes.Equal(peeked, magicBytes) {
+ // If the client connection is not multiplexed, restore the peeked bytes back into the stream.
+ // The original authInfo is returned unchanged, so GetPeerID reports this as a
+ // non-multiplexed connection.
+ return &restoredConn{
+ Conn: conn,
+ reader: io.MultiReader(bytes.NewReader(peeked), conn),
+ }, authInfo, nil
+ }
+
+ // It is not necessary to clean up any of the multiplexing-related sessions on errors as the
+ // gRPC server closes the conn if there is an error, which closes the multiplexing
+ // session as well.
+
+ logger := s.logger.WriterLevel(logrus.ErrorLevel)
+
+ // Open the server side of the multiplexing session.
+ muxSession, err := yamux.Server(conn, muxConfig(logger))
+ if err != nil {
+ logger.Close()
+ return nil, nil, fmt.Errorf("create multiplexing session: %w", err)
+ }
+
+ // Accept the client's stream. This is the client's gRPC session to the server.
+ clientToServerStream, err := muxSession.Accept()
+ if err != nil {
+ logger.Close()
+ return nil, nil, fmt.Errorf("accept client's stream: %w", err)
+ }
+
+ // The address does not actually matter but we set it so backchannelConn.Target returns a meaningful value.
+ // WithInsecure is used as the multiplexer operates within a TLS session already if one is configured.
+ backchannelConn, err := grpc.Dial(
+ "multiplexed/"+conn.RemoteAddr().String(),
+ grpc.WithInsecure(),
+ grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { return muxSession.Open() }),
+ )
+ if err != nil {
+ logger.Close()
+ return nil, nil, fmt.Errorf("dial backchannel: %w", err)
+ }
+
+ id := s.registry.RegisterBackchannel(backchannelConn)
+ // The returned connection must close the underlying network connection. We redirect the close
+ // to the muxSession, which also closes the underlying connection.
+ return connCloser{
+ Conn: clientToServerStream,
+ close: func() error {
+ s.registry.RemoveBackchannel(id)
+ backchannelConn.Close()
+ muxSession.Close()
+ logger.Close()
+ return nil
+ },
+ },
+ authInfoWrapper{
+ id: id,
+ AuthInfo: authInfo,
+ }, nil
+}