aboutsummaryrefslogtreecommitdiff
path: root/ssh
diff options
context:
space:
mode:
Diffstat (limited to 'ssh')
-rw-r--r--ssh/channel.go28
-rw-r--r--ssh/mempipe_test.go20
-rw-r--r--ssh/mux_test.go71
3 files changed, 108 insertions, 11 deletions
diff --git a/ssh/channel.go b/ssh/channel.go
index c0834c0..cc0bb7a 100644
--- a/ssh/channel.go
+++ b/ssh/channel.go
@@ -187,9 +187,11 @@ type channel struct {
pending *buffer
extPending *buffer
- // windowMu protects myWindow, the flow-control window.
- windowMu sync.Mutex
- myWindow uint32
+ // windowMu protects myWindow, the flow-control window, and myConsumed,
+ // the number of bytes consumed since we last increased myWindow
+ windowMu sync.Mutex
+ myWindow uint32
+ myConsumed uint32
// writeMu serializes calls to mux.conn.writePacket() and
// protects sentClose and packetPool. This mutex must be
@@ -332,14 +334,24 @@ func (ch *channel) handleData(packet []byte) error {
return nil
}
-func (c *channel) adjustWindow(n uint32) error {
+func (c *channel) adjustWindow(adj uint32) error {
c.windowMu.Lock()
- // Since myWindow is managed on our side, and can never exceed
- // the initial window setting, we don't worry about overflow.
- c.myWindow += uint32(n)
+ // Since myConsumed and myWindow are managed on our side, and can never
+ // exceed the initial window setting, we don't worry about overflow.
+ c.myConsumed += adj
+ var sendAdj uint32
+ if (channelWindowSize-c.myWindow > 3*c.maxIncomingPayload) ||
+ (c.myWindow < channelWindowSize/2) {
+ sendAdj = c.myConsumed
+ c.myConsumed = 0
+ c.myWindow += sendAdj
+ }
c.windowMu.Unlock()
+ if sendAdj == 0 {
+ return nil
+ }
return c.sendMessage(windowAdjustMsg{
- AdditionalBytes: uint32(n),
+ AdditionalBytes: sendAdj,
})
}
diff --git a/ssh/mempipe_test.go b/ssh/mempipe_test.go
index 8697cd6..f27339c 100644
--- a/ssh/mempipe_test.go
+++ b/ssh/mempipe_test.go
@@ -13,9 +13,10 @@ import (
// An in-memory packetConn. It is safe to call Close and writePacket
// from different goroutines.
type memTransport struct {
- eof bool
- pending [][]byte
- write *memTransport
+ eof bool
+ pending [][]byte
+ write *memTransport
+ writeCount uint64
sync.Mutex
*sync.Cond
}
@@ -63,9 +64,16 @@ func (t *memTransport) writePacket(p []byte) error {
copy(c, p)
t.write.pending = append(t.write.pending, c)
t.write.Cond.Signal()
+ t.writeCount++
return nil
}
+func (t *memTransport) getWriteCount() uint64 {
+ t.write.Lock()
+ defer t.write.Unlock()
+ return t.writeCount
+}
+
func memPipe() (a, b packetConn) {
t1 := memTransport{}
t2 := memTransport{}
@@ -81,6 +89,9 @@ func TestMemPipe(t *testing.T) {
if err := a.writePacket([]byte{42}); err != nil {
t.Fatalf("writePacket: %v", err)
}
+ if wc := a.(*memTransport).getWriteCount(); wc != 1 {
+ t.Fatalf("got %v, want 1", wc)
+ }
if err := a.Close(); err != nil {
t.Fatal("Close: ", err)
}
@@ -95,6 +106,9 @@ func TestMemPipe(t *testing.T) {
if err != io.EOF {
t.Fatalf("got %v, %v, want EOF", p, err)
}
+ if wc := b.(*memTransport).getWriteCount(); wc != 0 {
+ t.Fatalf("got %v, want 0", wc)
+ }
}
func TestDoubleClose(t *testing.T) {
diff --git a/ssh/mux_test.go b/ssh/mux_test.go
index eae637d..21f0ac3 100644
--- a/ssh/mux_test.go
+++ b/ssh/mux_test.go
@@ -182,6 +182,40 @@ func TestMuxChannelOverflow(t *testing.T) {
}
}
+func TestMuxChannelReadUnblock(t *testing.T) {
+ reader, writer, mux := channelPair(t)
+ defer reader.Close()
+ defer writer.Close()
+ defer mux.Close()
+
+ var wg sync.WaitGroup
+ t.Cleanup(wg.Wait)
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
+ t.Errorf("could not fill window: %v", err)
+ }
+ if _, err := writer.Write(make([]byte, 1)); err != nil {
+ t.Errorf("Write: %v", err)
+ }
+ writer.Close()
+ }()
+
+ writer.remoteWin.waitWriterBlocked()
+
+ buf := make([]byte, 32768)
+ for {
+ _, err := reader.Read(buf)
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ t.Fatalf("Read: %v", err)
+ }
+ }
+}
+
func TestMuxChannelCloseWriteUnblock(t *testing.T) {
reader, writer, mux := channelPair(t)
defer reader.Close()
@@ -754,6 +788,43 @@ func TestMuxMaxPacketSize(t *testing.T) {
}
}
+func TestMuxChannelWindowDeferredUpdates(t *testing.T) {
+ s, c, mux := channelPair(t)
+ cTransport := mux.conn.(*memTransport)
+ defer s.Close()
+ defer c.Close()
+ defer mux.Close()
+
+ var wg sync.WaitGroup
+ t.Cleanup(wg.Wait)
+
+ data := make([]byte, 1024)
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ _, err := s.Write(data)
+ if err != nil {
+ t.Errorf("Write: %v", err)
+ return
+ }
+ }()
+ cWritesInit := cTransport.getWriteCount()
+ buf := make([]byte, 1)
+ for i := 0; i < len(data); i++ {
+ n, err := c.Read(buf)
+ if n != len(buf) || err != nil {
+ t.Fatalf("Read: %v, %v", n, err)
+ }
+ }
+ cWrites := cTransport.getWriteCount() - cWritesInit
+ // reading 1 KiB should not cause any window updates to be sent, but allow
+ // for some unexpected writes
+ if cWrites > 30 {
+ t.Fatalf("reading 1 KiB from channel caused %v writes", cWrites)
+ }
+}
+
// Don't ship code with debug=true.
func TestDebug(t *testing.T) {
if debugMux {