diff options
| author | Dave Cheney <dave@cheney.net> | 2012-07-21 14:19:30 +1000 |
|---|---|---|
| committer | Dave Cheney <dave@cheney.net> | 2012-07-21 14:19:30 +1000 |
| commit | e751d5236aebf3b015b9a54aadf6af1807d3956e (patch) | |
| tree | 90cc52377fcffbe37822c7bd64d697aaa0fd7f9a /ssh/server_test.go | |
| parent | d1bf83abcbf80a644ecdaabdecf1f429b9c8f63a (diff) | |
| download | go-x-crypto-e751d5236aebf3b015b9a54aadf6af1807d3956e.tar.xz | |
go.crypto/ssh: improve TestServerWindow robustness
Fix a few resource leaks and prevent the test from
hanging if an error occurs reading from the remote
server.
R=agl, gustav.paul, kardianos
CC=golang-dev
https://golang.org/cl/6423065
Diffstat (limited to 'ssh/server_test.go')
| -rw-r--r-- | ssh/server_test.go | 15 |
1 files changed, 7 insertions, 8 deletions
diff --git a/ssh/server_test.go b/ssh/server_test.go index 898b356..18dbff1 100644 --- a/ssh/server_test.go +++ b/ssh/server_test.go @@ -17,18 +17,18 @@ const windowTestBytes = 16000 * 200 // CopyNRandomly copies n bytes from src to dst. It uses a variable, and random, // buffer size to exercise more code paths. -func CopyNRandomly(dst io.Writer, src io.Reader, n int64) (written int64, err error) { +func CopyNRandomly(dst io.Writer, src io.Reader, n int) (written int, err error) { buf := make([]byte, 32*1024) for written < n { l := (rand.Intn(30) + 1) * 1024 - if d := n - written; d < int64(l) { - l = int(d) + if d := n - written; d < l { + l = d } nr, er := src.Read(buf[0:l]) if nr > 0 { nw, ew := dst.Write(buf[0:nr]) if nw > 0 { - written += int64(nw) + written += nw } if ew != nil { err = ew @@ -75,6 +75,7 @@ func runSSHClient(t *testing.T, addr string) { // Read back the data from the server. go func() { defer session.Close() + defer close(wait) serverStdout, err := session.StdoutPipe() if err != nil { t.Fatal(err) @@ -87,7 +88,6 @@ func runSSHClient(t *testing.T, addr string) { if n != windowTestBytes { t.Fatalf("Read only %d bytes from server, expected %d", n, windowTestBytes) } - wait <- true }() serverStdin, err := session.StdinPipe() @@ -126,11 +126,10 @@ func startSSHServer(t *testing.T) (addr string) { } addr = listener.Addr().String() - go func() { + defer listener.Close() for { sConn, err := listener.Accept() - err = sConn.Handshake() if err != nil { if err != io.EOF { @@ -147,6 +146,7 @@ func startSSHServer(t *testing.T) (addr string) { } func connRun(t *testing.T, sConn *ServerConn) { + defer sConn.Close() for { channel, err := sConn.Accept() if err != nil { @@ -167,7 +167,6 @@ func connRun(t *testing.T, sConn *ServerConn) { go func() { defer channel.Close() - n, err := CopyNRandomly(channel, channel, windowTestBytes) if err != nil && err != io.EOF { if err == io.ErrShortWrite { |
