diff options
Diffstat (limited to 'ssh/mux_test.go')
| -rw-r--r-- | ssh/mux_test.go | 91 |
1 files changed, 49 insertions, 42 deletions
diff --git a/ssh/mux_test.go b/ssh/mux_test.go index 1db3be5..eae637d 100644 --- a/ssh/mux_test.go +++ b/ssh/mux_test.go @@ -10,7 +10,6 @@ import ( "io" "sync" "testing" - "time" ) func muxPair() (*mux, *mux) { @@ -112,7 +111,11 @@ func TestMuxReadWrite(t *testing.T) { magic := "hello world" magicExt := "hello stderr" + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) go func() { + defer wg.Done() _, err := s.Write([]byte(magic)) if err != nil { t.Errorf("Write: %v", err) @@ -152,13 +155,15 @@ func TestMuxChannelOverflow(t *testing.T) { defer writer.Close() defer mux.Close() - wDone := make(chan int, 1) + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) go func() { + defer wg.Done() if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { t.Errorf("could not fill window: %v", err) } writer.Write(make([]byte, 1)) - wDone <- 1 }() writer.remoteWin.waitWriterBlocked() @@ -175,7 +180,6 @@ func TestMuxChannelOverflow(t *testing.T) { if _, err := reader.SendRequest("hello", true, nil); err == nil { t.Errorf("SendRequest succeeded.") } - <-wDone } func TestMuxChannelCloseWriteUnblock(t *testing.T) { @@ -184,20 +188,21 @@ func TestMuxChannelCloseWriteUnblock(t *testing.T) { defer writer.Close() defer mux.Close() - wDone := make(chan int, 1) + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) go func() { + defer wg.Done() if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { t.Errorf("could not fill window: %v", err) } if _, err := writer.Write(make([]byte, 1)); err != io.EOF { t.Errorf("got %v, want EOF for unblock write", err) } - wDone <- 1 }() writer.remoteWin.waitWriterBlocked() reader.Close() - <-wDone } func TestMuxConnectionCloseWriteUnblock(t *testing.T) { @@ -206,20 +211,21 @@ func TestMuxConnectionCloseWriteUnblock(t *testing.T) { defer writer.Close() defer mux.Close() - wDone := make(chan int, 1) + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) go func() { + defer wg.Done() if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { t.Errorf("could not fill window: %v", err) } if _, err := writer.Write(make([]byte, 1)); err != io.EOF { t.Errorf("got %v, want EOF for unblock write", err) } - wDone <- 1 }() writer.remoteWin.waitWriterBlocked() mux.Close() - <-wDone } func TestMuxReject(t *testing.T) { @@ -227,7 +233,12 @@ func TestMuxReject(t *testing.T) { defer server.Close() defer client.Close() + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) go func() { + defer wg.Done() + ch, ok := <-server.incomingChannels if !ok { t.Error("cannot accept channel") @@ -267,6 +278,7 @@ func TestMuxChannelRequest(t *testing.T) { var received int var wg sync.WaitGroup + t.Cleanup(wg.Wait) wg.Add(1) go func() { for r := range server.incomingRequests { @@ -295,7 +307,6 @@ func TestMuxChannelRequest(t *testing.T) { } if ok { t.Errorf("SendRequest(no): %v", ok) - } client.Close() @@ -389,13 +400,8 @@ func TestMuxUnknownChannelRequests(t *testing.T) { // Wait for the server to send the keepalive message and receive back a // response. - select { - case err := <-kDone: - if err != nil { - t.Fatal(err) - } - case <-time.After(10 * time.Second): - t.Fatalf("server never received ack") + if err := <-kDone; err != nil { + t.Fatal(err) } // Confirm client hasn't closed. @@ -403,13 +409,9 @@ func TestMuxUnknownChannelRequests(t *testing.T) { t.Fatalf("failed to send keepalive: %v", err) } - select { - case err := <-kDone: - if err != nil { - t.Fatal(err) - } - case <-time.After(10 * time.Second): - t.Fatalf("server never shut down") + // Wait for the server to shut down. + if err := <-kDone; err != nil { + t.Fatal(err) } } @@ -525,11 +527,7 @@ func TestMuxClosedChannel(t *testing.T) { defer ch.Close() // Wait for the server to close the channel and send the keepalive. - select { - case <-kDone: - case <-time.After(10 * time.Second): - t.Fatalf("server never received ack") - } + <-kDone // Make sure the channel closed. if _, ok := <-ch.incomingRequests; ok { @@ -541,22 +539,29 @@ func TestMuxClosedChannel(t *testing.T) { t.Fatalf("failed to send keepalive: %v", err) } - select { - case <-kDone: - case <-time.After(10 * time.Second): - t.Fatalf("server never shut down") - } + // Wait for the server to shut down. + <-kDone } func TestMuxGlobalRequest(t *testing.T) { + var sawPeek bool + var wg sync.WaitGroup + defer func() { + wg.Wait() + if !sawPeek { + t.Errorf("never saw 'peek' request") + } + }() + clientMux, serverMux := muxPair() defer serverMux.Close() defer clientMux.Close() - var seen bool + wg.Add(1) go func() { + defer wg.Done() for r := range serverMux.incomingRequests { - seen = seen || r.Type == "peek" + sawPeek = sawPeek || r.Type == "peek" if r.WantReply { err := r.Reply(r.Type == "yes", append([]byte(r.Type), r.Payload...)) @@ -586,10 +591,6 @@ func TestMuxGlobalRequest(t *testing.T) { t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v", ok, data, err) } - - if !seen { - t.Errorf("never saw 'peek' request") - } } func TestMuxGlobalRequestUnblock(t *testing.T) { @@ -739,7 +740,13 @@ func TestMuxMaxPacketSize(t *testing.T) { t.Errorf("could not send packet") } - go a.SendRequest("hello", false, nil) + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + a.SendRequest("hello", false, nil) + wg.Done() + }() _, ok := <-b.incomingRequests if ok { |
