aboutsummaryrefslogtreecommitdiff
path: root/ssh/session_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'ssh/session_test.go')
-rw-r--r--ssh/session_test.go71
1 files changed, 59 insertions, 12 deletions
diff --git a/ssh/session_test.go b/ssh/session_test.go
index 521677f..807a913 100644
--- a/ssh/session_test.go
+++ b/ssh/session_test.go
@@ -13,6 +13,7 @@ import (
"io"
"math/rand"
"net"
+ "sync"
"testing"
"golang.org/x/crypto/ssh/terminal"
@@ -27,8 +28,14 @@ func dial(handler serverType, t *testing.T) *Client {
t.Fatalf("netPipe: %v", err)
}
+ var wg sync.WaitGroup
+ t.Cleanup(wg.Wait)
+ wg.Add(1)
go func() {
- defer c1.Close()
+ defer func() {
+ c1.Close()
+ wg.Done()
+ }()
conf := ServerConfig{
NoClientAuth: true,
}
@@ -39,7 +46,11 @@ func dial(handler serverType, t *testing.T) *Client {
t.Errorf("Unable to handshake: %v", err)
return
}
- go DiscardRequests(reqs)
+ wg.Add(1)
+ go func() {
+ DiscardRequests(reqs)
+ wg.Done()
+ }()
for newCh := range chans {
if newCh.ChannelType() != "session" {
@@ -52,8 +63,10 @@ func dial(handler serverType, t *testing.T) *Client {
t.Errorf("Accept: %v", err)
continue
}
+ wg.Add(1)
go func() {
handler(ch, inReqs, t)
+ wg.Done()
}()
}
if err := conn.Wait(); err != io.EOF {
@@ -338,8 +351,13 @@ func TestServerWindow(t *testing.T) {
t.Fatal(err)
}
defer session.Close()
- result := make(chan []byte)
+ serverStdin, err := session.StdinPipe()
+ if err != nil {
+ t.Fatalf("StdinPipe failed: %v", err)
+ }
+
+ result := make(chan []byte)
go func() {
defer close(result)
echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
@@ -355,10 +373,6 @@ func TestServerWindow(t *testing.T) {
result <- echoedBuf.Bytes()
}()
- serverStdin, err := session.StdinPipe()
- if err != nil {
- t.Fatalf("StdinPipe failed: %v", err)
- }
written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes)
if err != nil {
t.Errorf("failed to copy origBuf to serverStdin: %v", err)
@@ -648,29 +662,44 @@ func TestSessionID(t *testing.T) {
User: "user",
}
+ var wg sync.WaitGroup
+ t.Cleanup(wg.Wait)
+
srvErrCh := make(chan error, 1)
+ wg.Add(1)
go func() {
+ defer wg.Done()
conn, chans, reqs, err := NewServerConn(c1, serverConf)
srvErrCh <- err
if err != nil {
return
}
serverID <- conn.SessionID()
- go DiscardRequests(reqs)
+ wg.Add(1)
+ go func() {
+ DiscardRequests(reqs)
+ wg.Done()
+ }()
for ch := range chans {
ch.Reject(Prohibited, "")
}
}()
cliErrCh := make(chan error, 1)
+ wg.Add(1)
go func() {
+ defer wg.Done()
conn, chans, reqs, err := NewClientConn(c2, "", clientConf)
cliErrCh <- err
if err != nil {
return
}
clientID <- conn.SessionID()
- go DiscardRequests(reqs)
+ wg.Add(1)
+ go func() {
+ DiscardRequests(reqs)
+ wg.Done()
+ }()
for ch := range chans {
ch.Reject(Prohibited, "")
}
@@ -738,6 +767,8 @@ func TestHostKeyAlgorithms(t *testing.T) {
serverConf.AddHostKey(testSigners["rsa"])
serverConf.AddHostKey(testSigners["ecdsa"])
+ var wg sync.WaitGroup
+ t.Cleanup(wg.Wait)
connect := func(clientConf *ClientConfig, want string) {
var alg string
clientConf.HostKeyCallback = func(h string, a net.Addr, key PublicKey) error {
@@ -751,7 +782,11 @@ func TestHostKeyAlgorithms(t *testing.T) {
defer c1.Close()
defer c2.Close()
- go NewServerConn(c1, serverConf)
+ wg.Add(1)
+ go func() {
+ NewServerConn(c1, serverConf)
+ wg.Done()
+ }()
_, _, _, err = NewClientConn(c2, "", clientConf)
if err != nil {
t.Fatalf("NewClientConn: %v", err)
@@ -785,7 +820,11 @@ func TestHostKeyAlgorithms(t *testing.T) {
defer c1.Close()
defer c2.Close()
- go NewServerConn(c1, serverConf)
+ wg.Add(1)
+ go func() {
+ NewServerConn(c1, serverConf)
+ wg.Done()
+ }()
clientConf.HostKeyAlgorithms = []string{"nonexistent-hostkey-algo"}
_, _, _, err = NewClientConn(c2, "", clientConf)
if err == nil {
@@ -818,14 +857,22 @@ func TestServerClientAuthCallback(t *testing.T) {
User: someUsername,
}
+ var wg sync.WaitGroup
+ t.Cleanup(wg.Wait)
+ wg.Add(1)
go func() {
+ defer wg.Done()
_, chans, reqs, err := NewServerConn(c1, serverConf)
if err != nil {
t.Errorf("server handshake: %v", err)
userCh <- "error"
return
}
- go DiscardRequests(reqs)
+ wg.Add(1)
+ go func() {
+ DiscardRequests(reqs)
+ wg.Done()
+ }()
for ch := range chans {
ch.Reject(Prohibited, "")
}