aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ssh/client.go2
-rw-r--r--ssh/client_test.go25
2 files changed, 24 insertions, 3 deletions
diff --git a/ssh/client.go b/ssh/client.go
index bdc356c..fd8c497 100644
--- a/ssh/client.go
+++ b/ssh/client.go
@@ -82,7 +82,7 @@ func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan
if err := conn.clientHandshake(addr, &fullConf); err != nil {
c.Close()
- return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %v", err)
+ return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %w", err)
}
conn.mux = newMux(conn.transport)
return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil
diff --git a/ssh/client_test.go b/ssh/client_test.go
index c114573..2621f0e 100644
--- a/ssh/client_test.go
+++ b/ssh/client_test.go
@@ -7,6 +7,9 @@ package ssh
import (
"bytes"
"crypto/rand"
+ "errors"
+ "fmt"
+ "net"
"strings"
"testing"
)
@@ -207,9 +210,12 @@ func TestBannerCallback(t *testing.T) {
}
func TestNewClientConn(t *testing.T) {
+ errHostKeyMismatch := errors.New("host key mismatch")
+
for _, tt := range []struct {
- name string
- user string
+ name string
+ user string
+ simulateHostKeyMismatch HostKeyCallback
}{
{
name: "good user field for ConnMetadata",
@@ -219,6 +225,13 @@ func TestNewClientConn(t *testing.T) {
name: "empty user field for ConnMetadata",
user: "",
},
+ {
+ name: "host key mismatch",
+ user: "testuser",
+ simulateHostKeyMismatch: func(hostname string, remote net.Addr, key PublicKey) error {
+ return fmt.Errorf("%w: %s", errHostKeyMismatch, bytes.TrimSpace(MarshalAuthorizedKey(key)))
+ },
+ },
} {
t.Run(tt.name, func(t *testing.T) {
c1, c2, err := netPipe()
@@ -243,8 +256,16 @@ func TestNewClientConn(t *testing.T) {
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
+
+ if tt.simulateHostKeyMismatch != nil {
+ clientConf.HostKeyCallback = tt.simulateHostKeyMismatch
+ }
+
clientConn, _, _, err := NewClientConn(c2, "", clientConf)
if err != nil {
+ if tt.simulateHostKeyMismatch != nil && errors.Is(err, errHostKeyMismatch) {
+ return
+ }
t.Fatal(err)
}