aboutsummaryrefslogtreecommitdiff
path: root/ssh/handshake_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'ssh/handshake_test.go')
-rw-r--r--ssh/handshake_test.go220
1 files changed, 220 insertions, 0 deletions
diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go
index 2bc607b..019e47f 100644
--- a/ssh/handshake_test.go
+++ b/ssh/handshake_test.go
@@ -539,6 +539,226 @@ func TestDisconnect(t *testing.T) {
}
}
+type mockKeyingTransport struct {
+ packetConn
+ kexInitAllowed chan struct{}
+ kexInitSent chan struct{}
+}
+
+func (n *mockKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
+ return nil
+}
+
+func (n *mockKeyingTransport) writePacket(packet []byte) error {
+ if packet[0] == msgKexInit {
+ <-n.kexInitAllowed
+ n.kexInitSent <- struct{}{}
+ }
+ return n.packetConn.writePacket(packet)
+}
+
+func (n *mockKeyingTransport) readPacket() ([]byte, error) {
+ return n.packetConn.readPacket()
+}
+
+func (n *mockKeyingTransport) setStrictMode() error { return nil }
+
+func (n *mockKeyingTransport) setInitialKEXDone() {}
+
+func TestHandshakePendingPacketsWait(t *testing.T) {
+ a, b := memPipe()
+
+ trS := &mockKeyingTransport{
+ packetConn: a,
+ kexInitAllowed: make(chan struct{}, 2),
+ kexInitSent: make(chan struct{}, 2),
+ }
+ // Allow the first KEX.
+ trS.kexInitAllowed <- struct{}{}
+
+ trC := &mockKeyingTransport{
+ packetConn: b,
+ kexInitAllowed: make(chan struct{}, 2),
+ kexInitSent: make(chan struct{}, 2),
+ }
+ // Allow the first KEX.
+ trC.kexInitAllowed <- struct{}{}
+
+ clientConf := &ClientConfig{
+ HostKeyCallback: InsecureIgnoreHostKey(),
+ }
+ clientConf.SetDefaults()
+
+ v := []byte("version")
+ client := newClientTransport(trC, v, v, clientConf, "addr", nil)
+
+ serverConf := &ServerConfig{}
+ serverConf.AddHostKey(testSigners["ecdsa"])
+ serverConf.AddHostKey(testSigners["rsa"])
+ serverConf.SetDefaults()
+ server := newServerTransport(trS, v, v, serverConf)
+
+ if err := server.waitSession(); err != nil {
+ t.Fatalf("server.waitSession: %v", err)
+ }
+ if err := client.waitSession(); err != nil {
+ t.Fatalf("client.waitSession: %v", err)
+ }
+
+ <-trC.kexInitSent
+ <-trS.kexInitSent
+
+ // Allow and request new KEX server side.
+ trS.kexInitAllowed <- struct{}{}
+ server.requestKeyExchange()
+ // Wait until the KEX init is sent.
+ <-trS.kexInitSent
+ // The client is not allowed to respond to the KEX, so writes will be
+ // blocked on the server side once the packets queue is full.
+ for i := 0; i < maxPendingPackets; i++ {
+ p := []byte{msgRequestSuccess, byte(i)}
+ if err := server.writePacket(p); err != nil {
+ t.Errorf("unexpected write error: %v", err)
+ }
+ }
+ // The packets queue is now full, the next write will block.
+ server.mu.Lock()
+ if len(server.pendingPackets) != maxPendingPackets {
+ t.Errorf("unexpected pending packets size; got: %d, want: %d", len(server.pendingPackets), maxPendingPackets)
+ }
+ server.mu.Unlock()
+
+ writeDone := make(chan struct{})
+ go func() {
+ defer close(writeDone)
+
+ p := []byte{msgRequestSuccess, byte(65)}
+ // This write will block until KEX completes.
+ err := server.writePacket(p)
+ if err != nil {
+ t.Errorf("unexpected write error: %v", err)
+ }
+ }()
+
+ // Consume packets on the client side
+ readDone := make(chan bool)
+ go func() {
+ defer close(readDone)
+
+ for {
+ if _, err := client.readPacket(); err != nil {
+ if err != io.EOF {
+ t.Errorf("unexpected read error: %v", err)
+ }
+ break
+ }
+ }
+ }()
+
+ // Allow the client to reply to the KEX and so unblock the write goroutine.
+ trC.kexInitAllowed <- struct{}{}
+ <-trC.kexInitSent
+ <-writeDone
+ // Close the client to unblock the read goroutine.
+ client.Close()
+ <-readDone
+ server.Close()
+}
+
+func TestHandshakePendingPacketsError(t *testing.T) {
+ a, b := memPipe()
+
+ trS := &mockKeyingTransport{
+ packetConn: a,
+ kexInitAllowed: make(chan struct{}, 2),
+ kexInitSent: make(chan struct{}, 2),
+ }
+ // Allow the first KEX.
+ trS.kexInitAllowed <- struct{}{}
+
+ trC := &mockKeyingTransport{
+ packetConn: b,
+ kexInitAllowed: make(chan struct{}, 2),
+ kexInitSent: make(chan struct{}, 2),
+ }
+ // Allow the first KEX.
+ trC.kexInitAllowed <- struct{}{}
+
+ clientConf := &ClientConfig{
+ HostKeyCallback: InsecureIgnoreHostKey(),
+ }
+ clientConf.SetDefaults()
+
+ v := []byte("version")
+ client := newClientTransport(trC, v, v, clientConf, "addr", nil)
+
+ serverConf := &ServerConfig{}
+ serverConf.AddHostKey(testSigners["ecdsa"])
+ serverConf.AddHostKey(testSigners["rsa"])
+ serverConf.SetDefaults()
+ server := newServerTransport(trS, v, v, serverConf)
+
+ if err := server.waitSession(); err != nil {
+ t.Fatalf("server.waitSession: %v", err)
+ }
+ if err := client.waitSession(); err != nil {
+ t.Fatalf("client.waitSession: %v", err)
+ }
+
+ <-trC.kexInitSent
+ <-trS.kexInitSent
+
+ // Allow and request new KEX server side.
+ trS.kexInitAllowed <- struct{}{}
+ server.requestKeyExchange()
+ // Wait until the KEX init is sent.
+ <-trS.kexInitSent
+ // The client is not allowed to respond to the KEX, so writes will be
+ // blocked on the server side once the packets queue is full.
+ for i := 0; i < maxPendingPackets; i++ {
+ p := []byte{msgRequestSuccess, byte(i)}
+ if err := server.writePacket(p); err != nil {
+ t.Errorf("unexpected write error: %v", err)
+ }
+ }
+ // The packets queue is now full, the next write will block.
+ writeDone := make(chan struct{})
+ go func() {
+ defer close(writeDone)
+
+ p := []byte{msgRequestSuccess, byte(65)}
+ // This write will block until KEX completes.
+ err := server.writePacket(p)
+ if err != io.EOF {
+ t.Errorf("unexpected write error: %v", err)
+ }
+ }()
+
+ // Consume packets on the client side
+ readDone := make(chan bool)
+ go func() {
+ defer close(readDone)
+
+ for {
+ if _, err := client.readPacket(); err != nil {
+ if err != io.EOF {
+ t.Errorf("unexpected read error: %v", err)
+ }
+ break
+ }
+ }
+ }()
+
+ // Close the server to unblock the write after an error
+ server.Close()
+ <-writeDone
+ // Unblock the pending write and close the client to unblock the read
+ // goroutine.
+ trC.kexInitAllowed <- struct{}{}
+ client.Close()
+ <-readDone
+}
+
func TestHandshakeRekeyDefault(t *testing.T) {
clientConf := &ClientConfig{
Config: Config{