aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ssh/server.go27
-rw-r--r--ssh/server_test.go374
2 files changed, 401 insertions, 0 deletions
diff --git a/ssh/server.go b/ssh/server.go
index f06c08c..064dcba 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -44,6 +44,9 @@ type Permissions struct {
// pass data from the authentication callbacks to the server
// application layer.
Extensions map[string]string
+
+ // ExtraData allows to store user defined data.
+ ExtraData map[any]any
}
type GSSAPIWithMICConfig struct {
@@ -127,6 +130,21 @@ type ServerConfig struct {
// Permissions.Extensions entry.
PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
+ // VerifiedPublicKeyCallback, if non-nil, is called after a client
+ // successfully confirms having control over a key that was previously
+ // approved by PublicKeyCallback. The permissions object passed to the
+ // callback is the one returned by PublicKeyCallback for the given public
+ // key and its ownership is transferred to the callback. The returned
+ // Permissions object can be the same object, optionally modified, or a
+ // completely new object. If VerifiedPublicKeyCallback is non-nil,
+ // PublicKeyCallback is not allowed to return a PartialSuccessError, which
+ // can instead be returned by VerifiedPublicKeyCallback.
+ //
+ // VerifiedPublicKeyCallback does not affect which authentication methods
+ // are included in the list of methods that can be attempted by the client.
+ VerifiedPublicKeyCallback func(conn ConnMetadata, key PublicKey, permissions *Permissions,
+ signatureAlgorithm string) (*Permissions, error)
+
// KeyboardInteractiveCallback, if non-nil, is called when
// keyboard-interactive authentication is selected (RFC
// 4256). The client object's Challenge function should be
@@ -653,6 +671,9 @@ userAuthLoop:
candidate.pubKeyData = pubKeyData
candidate.perms, candidate.result = authConfig.PublicKeyCallback(s, pubKey)
_, isPartialSuccessError := candidate.result.(*PartialSuccessError)
+ if isPartialSuccessError && config.VerifiedPublicKeyCallback != nil {
+ return nil, errors.New("ssh: invalid library usage: PublicKeyCallback must not return partial success when VerifiedPublicKeyCallback is defined")
+ }
if (candidate.result == nil || isPartialSuccessError) &&
candidate.perms != nil &&
@@ -723,6 +744,12 @@ userAuthLoop:
authErr = candidate.result
perms = candidate.perms
+ if authErr == nil && config.VerifiedPublicKeyCallback != nil {
+ // Only call VerifiedPublicKeyCallback after the key has been accepted
+ // and successfully verified. If authErr is non-nil, the key is not
+ // considered verified and the callback must not run.
+ perms, authErr = config.VerifiedPublicKeyCallback(s, pubKey, perms, algo)
+ }
}
case "gssapi-with-mic":
if authConfig.GSSAPIWithMICConfig == nil {
diff --git a/ssh/server_test.go b/ssh/server_test.go
index 5bd18db..e48f7e3 100644
--- a/ssh/server_test.go
+++ b/ssh/server_test.go
@@ -434,6 +434,380 @@ func TestPreAuthConnAndBanners(t *testing.T) {
}
}
+func TestVerifiedPublicKeyCallback(t *testing.T) {
+ c1, c2, err := netPipe()
+ if err != nil {
+ t.Fatalf("netPipe: %v", err)
+ }
+ defer c1.Close()
+ defer c2.Close()
+
+ extraKey := "extra"
+ extraDataString := "just a string"
+
+ serverConf := &ServerConfig{
+ VerifiedPublicKeyCallback: func(conn ConnMetadata, key PublicKey, permissions *Permissions, signatureAlgorithm string) (*Permissions, error) {
+ if permissions != nil && permissions.ExtraData != nil {
+ if !reflect.DeepEqual(map[any]any{extraKey: extraDataString}, permissions.ExtraData) {
+ t.Errorf("expected extra data: %v; got: %v", extraDataString, permissions.ExtraData)
+ }
+ } else {
+ t.Error("expected extra data is missing")
+ }
+ if signatureAlgorithm != KeyAlgoRSASHA256 {
+ t.Errorf("expected signature algorithm: %q; got: %q", KeyAlgoRSASHA256, signatureAlgorithm)
+ }
+ return permissions, nil
+ },
+ PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
+ return &Permissions{ExtraData: map[any]any{extraKey: extraDataString}}, nil
+ },
+ }
+ serverConf.AddHostKey(testSigners["rsa"])
+
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ conn, _, _, err := NewServerConn(c1, serverConf)
+ if err != nil {
+ t.Errorf("unexpected server error: %v", err)
+ }
+ if !reflect.DeepEqual(map[any]any{extraKey: extraDataString}, conn.Permissions.ExtraData) {
+ t.Errorf("expected extra data: %v; got: %v", extraDataString, conn.Permissions.ExtraData)
+ }
+ }()
+
+ clientConf := ClientConfig{
+ User: "user",
+ Auth: []AuthMethod{
+ PublicKeys(testSigners["rsa"]),
+ },
+ HostKeyCallback: InsecureIgnoreHostKey(),
+ }
+
+ _, _, _, err = NewClientConn(c2, "", &clientConf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ <-done
+}
+
+func TestVerifiedPublicCallbackPartialSuccess(t *testing.T) {
+ c1, c2, err := netPipe()
+ if err != nil {
+ t.Fatalf("netPipe: %v", err)
+ }
+ defer c1.Close()
+ defer c2.Close()
+
+ serverConf := &ServerConfig{
+ PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
+ if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
+ return nil, nil
+ }
+ return nil, errors.New("invalid credentials")
+ },
+ VerifiedPublicKeyCallback: func(conn ConnMetadata, key PublicKey, permissions *Permissions, signatureAlgorithm string) (*Permissions, error) {
+ if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
+ return nil, &PartialSuccessError{
+ Next: ServerAuthCallbacks{
+ PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+ if string(password) == clientPassword {
+ return nil, nil
+ }
+ return nil, nil
+ },
+ },
+ }
+ }
+ return nil, errors.New("invalid credentials")
+ },
+ }
+ serverConf.AddHostKey(testSigners["rsa"])
+
+ clientConf := ClientConfig{
+ User: "user",
+ Auth: []AuthMethod{
+ PublicKeys(testSigners["rsa"]),
+ Password(clientPassword),
+ },
+ HostKeyCallback: InsecureIgnoreHostKey(),
+ }
+
+ go NewServerConn(c1, serverConf)
+
+ _, _, _, err = NewClientConn(c2, "", &clientConf)
+ if err != nil {
+ t.Fatalf("client login error: %s", err)
+ }
+}
+
+func TestVerifiedPublicKeyCallbackPwdAndKey(t *testing.T) {
+ c1, c2, err := netPipe()
+ if err != nil {
+ t.Fatalf("netPipe: %v", err)
+ }
+ defer c1.Close()
+ defer c2.Close()
+
+ serverConf := &ServerConfig{
+ PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+ if string(password) == clientPassword {
+ return nil, &PartialSuccessError{
+ Next: ServerAuthCallbacks{
+ PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
+ if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
+ return nil, nil
+ }
+ return nil, errors.New("invalid credentials")
+ },
+ },
+ }
+ }
+ return nil, errors.New("invalid credentials")
+
+ },
+ VerifiedPublicKeyCallback: func(conn ConnMetadata, key PublicKey, permissions *Permissions, signatureAlgorithm string) (*Permissions, error) {
+ if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
+ return nil, nil
+ }
+ return nil, errors.New("invalid credentials")
+ },
+ }
+ serverConf.AddHostKey(testSigners["rsa"])
+
+ clientConf := ClientConfig{
+ User: "user",
+ Auth: []AuthMethod{
+ Password(clientPassword),
+ PublicKeys(testSigners["rsa"]),
+ },
+ HostKeyCallback: InsecureIgnoreHostKey(),
+ }
+
+ go NewServerConn(c1, serverConf)
+
+ _, _, _, err = NewClientConn(c2, "", &clientConf)
+ if err != nil {
+ t.Fatalf("client login error: %s", err)
+ }
+}
+
+func TestVerifiedPubKeyCallbackAuthMethods(t *testing.T) {
+ c1, c2, err := netPipe()
+ if err != nil {
+ t.Fatalf("netPipe: %v", err)
+ }
+ defer c1.Close()
+ defer c2.Close()
+
+ serverConf := &ServerConfig{
+ PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+ return nil, nil
+ },
+ VerifiedPublicKeyCallback: func(conn ConnMetadata, key PublicKey, permissions *Permissions, signatureAlgorithm string) (*Permissions, error) {
+ if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
+ return nil, nil
+ }
+ return nil, errors.New("invalid credentials")
+ },
+ }
+ serverConf.AddHostKey(testSigners["rsa"])
+
+ clientConf := ClientConfig{
+ User: "user",
+ Auth: []AuthMethod{
+ PublicKeys(testSigners["rsa"]),
+ },
+ HostKeyCallback: InsecureIgnoreHostKey(),
+ }
+
+ go NewServerConn(c1, serverConf)
+
+ _, _, _, err = NewClientConn(c2, "", &clientConf)
+ if err == nil {
+ t.Fatal("client login succeed with only VerifiedPublicKeyCallback defined")
+ }
+}
+
+func TestVerifiedPubKeyCallbackError(t *testing.T) {
+ c1, c2, err := netPipe()
+ if err != nil {
+ t.Fatalf("netPipe: %v", err)
+ }
+ defer c1.Close()
+ defer c2.Close()
+
+ serverConf := &ServerConfig{
+ PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
+ return nil, nil
+ },
+ VerifiedPublicKeyCallback: func(conn ConnMetadata, key PublicKey, permissions *Permissions, signatureAlgorithm string) (*Permissions, error) {
+ return nil, errors.New("invalid credentials")
+ },
+ }
+ serverConf.AddHostKey(testSigners["rsa"])
+
+ clientConf := ClientConfig{
+ User: "user",
+ Auth: []AuthMethod{
+ PublicKeys(testSigners["rsa"]),
+ },
+ HostKeyCallback: InsecureIgnoreHostKey(),
+ }
+
+ go NewServerConn(c1, serverConf)
+
+ _, _, _, err = NewClientConn(c2, "", &clientConf)
+ if err == nil {
+ t.Fatal("client login succeed with VerifiedPublicKeyCallback returning an error")
+ }
+}
+
+func TestVerifiedPublicCallbackPartialSuccessBadUsage(t *testing.T) {
+ c1, c2, err := netPipe()
+ if err != nil {
+ t.Fatalf("netPipe: %v", err)
+ }
+ defer c1.Close()
+ defer c2.Close()
+
+ serverConf := &ServerConfig{
+ PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
+ if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
+ // Returning PartialSuccessError is not permitted when
+ // VerifiedPublicKeyCallback is defined. This callback is
+ // invoked for both query requests and real authentications,
+ // while VerifiedPublicKeyCallback is only triggered if the
+ // client has proven control of the key.
+ return nil, &PartialSuccessError{
+ Next: ServerAuthCallbacks{
+ PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+ if string(password) == clientPassword {
+ return nil, nil
+ }
+ return nil, nil
+ },
+ },
+ }
+ }
+ return nil, errors.New("invalid credentials")
+ },
+ VerifiedPublicKeyCallback: func(conn ConnMetadata, key PublicKey, permissions *Permissions, signatureAlgorithm string) (*Permissions, error) {
+ if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) {
+ return nil, &PartialSuccessError{
+ Next: ServerAuthCallbacks{
+ PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+ if string(password) == clientPassword {
+ return nil, nil
+ }
+ return nil, nil
+ },
+ },
+ }
+ }
+ return nil, errors.New("invalid credentials")
+ },
+ }
+ serverConf.AddHostKey(testSigners["rsa"])
+
+ clientConf := ClientConfig{
+ User: "user",
+ Auth: []AuthMethod{
+ PublicKeys(testSigners["rsa"]),
+ Password(clientPassword),
+ },
+ HostKeyCallback: InsecureIgnoreHostKey(),
+ }
+
+ go NewServerConn(c1, serverConf)
+
+ _, _, _, err = NewClientConn(c2, "", &clientConf)
+ if err == nil {
+ t.Fatal("authentication suceeded with PartialSuccess returned from PublicKeyCallback and VerifiedPublicKeyCallback defined")
+ }
+}
+
+func TestVerifiedPublicKeyCallbackOnError(t *testing.T) {
+ c1, c2, err := netPipe()
+ if err != nil {
+ t.Fatalf("netPipe: %v", err)
+ }
+ defer c1.Close()
+ defer c2.Close()
+
+ var verifiedCallbackCalled bool
+
+ serverConf := &ServerConfig{
+ VerifiedPublicKeyCallback: func(conn ConnMetadata, key PublicKey, permissions *Permissions, signatureAlgorithm string) (*Permissions, error) {
+ verifiedCallbackCalled = true
+ return nil, nil
+ },
+ PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
+ return nil, errors.New("invalid key")
+ },
+ }
+ serverConf.AddHostKey(testSigners["rsa"])
+
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ NewServerConn(c1, serverConf)
+ }()
+
+ clientConf := ClientConfig{
+ User: "user",
+ Auth: []AuthMethod{
+ PublicKeys(testSigners["rsa"]),
+ },
+ HostKeyCallback: InsecureIgnoreHostKey(),
+ }
+
+ _, _, _, err = NewClientConn(c2, "", &clientConf)
+ if err == nil {
+ t.Fatal("authentication should fail")
+ }
+ <-done
+ if verifiedCallbackCalled {
+ t.Error("VerifiedPublicKeyCallback called after PublicKeyCallback returned an error")
+ }
+}
+
+func TestVerifiedPublicKeyCallbackOnly(t *testing.T) {
+ c1, c2, err := netPipe()
+ if err != nil {
+ t.Fatalf("netPipe: %v", err)
+ }
+ defer c1.Close()
+ defer c2.Close()
+
+ serverConf := &ServerConfig{
+ VerifiedPublicKeyCallback: func(conn ConnMetadata, key PublicKey, permissions *Permissions, signatureAlgorithm string) (*Permissions, error) {
+ return nil, nil
+ },
+ }
+ serverConf.AddHostKey(testSigners["rsa"])
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ NewServerConn(c1, serverConf)
+ }()
+
+ clientConf := ClientConfig{
+ User: "user",
+ Auth: []AuthMethod{
+ PublicKeys(testSigners["rsa"]),
+ },
+ HostKeyCallback: InsecureIgnoreHostKey(),
+ }
+
+ _, _, _, err = NewClientConn(c2, "", &clientConf)
+ if err == nil {
+ t.Fatal("authentication suceeded with only VerifiedPublicKeyCallback defined")
+ }
+ <-done
+}
+
type markerConn struct {
closed uint32
used uint32