aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore2
-rw-r--r--Makefile18
-rw-r--r--ssh/keys.go6
-rw-r--r--ssh/keys_test.go41
-rw-r--r--ssh/knownhosts/db.go68
-rw-r--r--ssh/session.go30
6 files changed, 151 insertions, 14 deletions
diff --git a/.gitignore b/.gitignore
index 5a9d62e..a7b710b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,4 @@
# Add no patterns to .gitignore except for files generated by the build.
last-change
+cover.html
+cover.out
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..5cb209c
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,18 @@
+COVER_OUT:=cover.out
+COVER_HTML:=cover.html
+
+.PHONY: test
+test:
+ CGO_ENABLED=1 go test -failfast -timeout=5m -race -coverprofile=$(COVER_OUT) ./...
+ go tool cover -html=$(COVER_OUT) -o $(COVER_HTML)
+
+.PHONY: lint
+lint:
+ -fieldalignment ./...
+ -shadow ./...
+ -golangci-lint run \
+ --presets bugs,metalinter,performance,unused \
+ --disable exhaustive \
+ --disable musttag \
+ --disable bodyclose \
+ ./...
diff --git a/ssh/keys.go b/ssh/keys.go
index 47a0753..18851e7 100644
--- a/ssh/keys.go
+++ b/ssh/keys.go
@@ -1271,6 +1271,12 @@ func (*PassphraseMissingError) Error() string {
return "ssh: this private key is passphrase protected"
}
+// Is return true if the target is an instance of PassphraseMissingError.
+func (errPassMissing *PassphraseMissingError) Is(target error) (ok bool) {
+ _, ok = target.(*PassphraseMissingError)
+ return ok
+}
+
// ParseRawPrivateKey returns a private key from a PEM encoded private key. It supports
// RSA, DSA, ECDSA, and Ed25519 private keys in PKCS#1, PKCS#8, OpenSSL, and OpenSSH
// formats. If the private key is encrypted, it will return a PassphraseMissingError.
diff --git a/ssh/keys_test.go b/ssh/keys_test.go
index a1165ec..ed5bb1a 100644
--- a/ssh/keys_test.go
+++ b/ssh/keys_test.go
@@ -272,18 +272,18 @@ func TestParseEncryptedPrivateKeysWithPassphrase(t *testing.T) {
}
func TestParseEncryptedPrivateKeysWithUnsupportedCiphers(t *testing.T) {
- for _, tt := range testdata.UnsupportedCipherData {
- t.Run(tt.Name, func(t *testing.T){
- _, err := ParsePrivateKeyWithPassphrase(tt.PEMBytes, []byte(tt.EncryptionKey))
- if err == nil {
- t.Fatalf("expected 'unknown cipher' error for %q, got nil", tt.Name)
- // If this cipher is now supported, remove it from testdata.UnsupportedCipherData
- }
- if !strings.Contains(err.Error(), "unknown cipher") {
- t.Errorf("wanted 'unknown cipher' error, got %v", err.Error())
- }
- })
- }
+ for _, tt := range testdata.UnsupportedCipherData {
+ t.Run(tt.Name, func(t *testing.T) {
+ _, err := ParsePrivateKeyWithPassphrase(tt.PEMBytes, []byte(tt.EncryptionKey))
+ if err == nil {
+ t.Fatalf("expected 'unknown cipher' error for %q, got nil", tt.Name)
+ // If this cipher is now supported, remove it from testdata.UnsupportedCipherData
+ }
+ if !strings.Contains(err.Error(), "unknown cipher") {
+ t.Errorf("wanted 'unknown cipher' error, got %v", err.Error())
+ }
+ })
+ }
}
func TestParseEncryptedPrivateKeysWithIncorrectPassphrase(t *testing.T) {
@@ -863,3 +863,20 @@ cLYUOHfQDw==
t.Fatal("parsing an SSH certificate using another certificate as signature key succeeded; expected failure")
}
}
+
+func TestPassphraseMissingErrorIs(t *testing.T) {
+ var (
+ errPassMissing = &PassphraseMissingError{}
+
+ err error
+ )
+
+ _, err = ParseRawPrivateKey(testdata.PEMEncryptedKeys[0].PEMBytes)
+ if err == nil {
+ t.Fatalf(`got error nil, want %T`, errPassMissing)
+ }
+
+ if !errors.Is(err, errPassMissing) {
+ t.Fatalf(`got error %T, want %T `, err, errPassMissing)
+ }
+}
diff --git a/ssh/knownhosts/db.go b/ssh/knownhosts/db.go
new file mode 100644
index 0000000..9d81617
--- /dev/null
+++ b/ssh/knownhosts/db.go
@@ -0,0 +1,68 @@
+// Copyright 2026 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package knownhosts
+
+import (
+ "fmt"
+ "net"
+ "os"
+
+ "golang.org/x/crypto/ssh"
+)
+
+type DB interface {
+ // HostKeyAlgorithms takes an address and returns a list of matching key types.
+ HostKeyAlgorithms(address string) ([]string, error)
+
+ // HostKeyCallback is knownhosts.New without the DB initialization.
+ HostKeyCallback() ssh.HostKeyCallback
+}
+
+// NewDB creates a new known_hosts database from the files given and returns
+// it.
+func NewDB(files ...string) (DB, error) {
+ logp := `NewDB`
+ db := newHostKeyDB()
+ for _, fn := range files {
+ f, err := os.Open(fn)
+ if err != nil {
+ return nil, fmt.Errorf(`%s: %w`, logp, err)
+ }
+ defer f.Close()
+ err = db.Read(f, fn)
+ if err != nil {
+ return nil, fmt.Errorf(`%s: %w`, logp, err)
+ }
+ }
+ return db, nil
+}
+
+// HostKeyAlgorithms returns a list of host key algorithms associated
+// with the given address.
+func (db *hostKeyDB) HostKeyAlgorithms(address string) (knownTypes []string, err error) {
+ logp := `HostKeyAlgorithms`
+ host, port, err := net.SplitHostPort(address)
+ if err != nil {
+ return nil, fmt.Errorf(`%s: %w`, logp, err)
+ }
+
+ hostToCheck := addr{host, port}
+ for _, l := range db.lines {
+ if l.match(hostToCheck) {
+ knownTypes = append(knownTypes, l.knownKey.Key.Type())
+ }
+ }
+ return knownTypes, nil
+}
+
+// HostKeyCallback is the way to get the ssh.HostKeyCallback if you have used
+// NewDB.
+func (db *hostKeyDB) HostKeyCallback() ssh.HostKeyCallback {
+ var certChecker ssh.CertChecker
+ certChecker.IsHostAuthority = db.IsHostAuthority
+ certChecker.IsRevoked = db.IsRevoked
+ certChecker.HostKeyFallback = db.check
+ return certChecker.CheckHostKey
+}
diff --git a/ssh/session.go b/ssh/session.go
index acef622..f940686 100644
--- a/ssh/session.go
+++ b/ssh/session.go
@@ -9,6 +9,7 @@ package ssh
import (
"bytes"
+ "context"
"encoding/binary"
"errors"
"fmt"
@@ -314,6 +315,16 @@ func (s *Session) Run(cmd string) error {
return s.Wait()
}
+// RunWithContext similar to [Session.Run] but with context.
+func (s *Session) RunWithContext(ctx context.Context, cmd string) (err error) {
+ err = s.Start(cmd)
+ if err != nil {
+ return err
+ }
+
+ return s.waitWithContext(ctx)
+}
+
// Output runs cmd on the remote host and returns its standard output.
func (s *Session) Output(cmd string) ([]byte, error) {
if s.Stdout != nil {
@@ -397,10 +408,22 @@ func (s *Session) start() error {
// unsuccessfully or is interrupted by a signal, the error is of type
// *ExitError. Other error types may be returned for I/O problems.
func (s *Session) Wait() error {
+ return s.waitWithContext(context.Background())
+}
+
+// waitWithContext wait for remote command to exit or terminate with SIGKILL
+// when its receive context cancellation.
+func (s *Session) waitWithContext(ctx context.Context) (err error) {
if !s.started {
return errors.New("ssh: session not started")
}
- waitErr := <-s.exitStatus
+
+ var waitErr error
+ select {
+ case <-ctx.Done():
+ waitErr = s.Signal(SIGKILL)
+ case waitErr = <-s.exitStatus:
+ }
if s.stdinPipeWriter != nil {
s.stdinPipeWriter.Close()
@@ -414,7 +437,10 @@ func (s *Session) Wait() error {
if waitErr != nil {
return waitErr
}
- return copyError
+ if copyError != nil {
+ return copyError
+ }
+ return context.Cause(ctx)
}
func (s *Session) wait(reqs <-chan *Request) error {