diff options
| -rw-r--r-- | .gitignore | 2 | ||||
| -rw-r--r-- | Makefile | 18 | ||||
| -rw-r--r-- | ssh/keys.go | 6 | ||||
| -rw-r--r-- | ssh/keys_test.go | 41 | ||||
| -rw-r--r-- | ssh/knownhosts/db.go | 68 | ||||
| -rw-r--r-- | ssh/session.go | 30 |
6 files changed, 151 insertions, 14 deletions
@@ -1,2 +1,4 @@ # Add no patterns to .gitignore except for files generated by the build. last-change +cover.html +cover.out diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..5cb209c --- /dev/null +++ b/Makefile @@ -0,0 +1,18 @@ +COVER_OUT:=cover.out +COVER_HTML:=cover.html + +.PHONY: test +test: + CGO_ENABLED=1 go test -failfast -timeout=5m -race -coverprofile=$(COVER_OUT) ./... + go tool cover -html=$(COVER_OUT) -o $(COVER_HTML) + +.PHONY: lint +lint: + -fieldalignment ./... + -shadow ./... + -golangci-lint run \ + --presets bugs,metalinter,performance,unused \ + --disable exhaustive \ + --disable musttag \ + --disable bodyclose \ + ./... diff --git a/ssh/keys.go b/ssh/keys.go index 47a0753..18851e7 100644 --- a/ssh/keys.go +++ b/ssh/keys.go @@ -1271,6 +1271,12 @@ func (*PassphraseMissingError) Error() string { return "ssh: this private key is passphrase protected" } +// Is return true if the target is an instance of PassphraseMissingError. +func (errPassMissing *PassphraseMissingError) Is(target error) (ok bool) { + _, ok = target.(*PassphraseMissingError) + return ok +} + // ParseRawPrivateKey returns a private key from a PEM encoded private key. It supports // RSA, DSA, ECDSA, and Ed25519 private keys in PKCS#1, PKCS#8, OpenSSL, and OpenSSH // formats. If the private key is encrypted, it will return a PassphraseMissingError. diff --git a/ssh/keys_test.go b/ssh/keys_test.go index a1165ec..ed5bb1a 100644 --- a/ssh/keys_test.go +++ b/ssh/keys_test.go @@ -272,18 +272,18 @@ func TestParseEncryptedPrivateKeysWithPassphrase(t *testing.T) { } func TestParseEncryptedPrivateKeysWithUnsupportedCiphers(t *testing.T) { - for _, tt := range testdata.UnsupportedCipherData { - t.Run(tt.Name, func(t *testing.T){ - _, err := ParsePrivateKeyWithPassphrase(tt.PEMBytes, []byte(tt.EncryptionKey)) - if err == nil { - t.Fatalf("expected 'unknown cipher' error for %q, got nil", tt.Name) - // If this cipher is now supported, remove it from testdata.UnsupportedCipherData - } - if !strings.Contains(err.Error(), "unknown cipher") { - t.Errorf("wanted 'unknown cipher' error, got %v", err.Error()) - } - }) - } + for _, tt := range testdata.UnsupportedCipherData { + t.Run(tt.Name, func(t *testing.T) { + _, err := ParsePrivateKeyWithPassphrase(tt.PEMBytes, []byte(tt.EncryptionKey)) + if err == nil { + t.Fatalf("expected 'unknown cipher' error for %q, got nil", tt.Name) + // If this cipher is now supported, remove it from testdata.UnsupportedCipherData + } + if !strings.Contains(err.Error(), "unknown cipher") { + t.Errorf("wanted 'unknown cipher' error, got %v", err.Error()) + } + }) + } } func TestParseEncryptedPrivateKeysWithIncorrectPassphrase(t *testing.T) { @@ -863,3 +863,20 @@ cLYUOHfQDw== t.Fatal("parsing an SSH certificate using another certificate as signature key succeeded; expected failure") } } + +func TestPassphraseMissingErrorIs(t *testing.T) { + var ( + errPassMissing = &PassphraseMissingError{} + + err error + ) + + _, err = ParseRawPrivateKey(testdata.PEMEncryptedKeys[0].PEMBytes) + if err == nil { + t.Fatalf(`got error nil, want %T`, errPassMissing) + } + + if !errors.Is(err, errPassMissing) { + t.Fatalf(`got error %T, want %T `, err, errPassMissing) + } +} diff --git a/ssh/knownhosts/db.go b/ssh/knownhosts/db.go new file mode 100644 index 0000000..9d81617 --- /dev/null +++ b/ssh/knownhosts/db.go @@ -0,0 +1,68 @@ +// Copyright 2026 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package knownhosts + +import ( + "fmt" + "net" + "os" + + "golang.org/x/crypto/ssh" +) + +type DB interface { + // HostKeyAlgorithms takes an address and returns a list of matching key types. + HostKeyAlgorithms(address string) ([]string, error) + + // HostKeyCallback is knownhosts.New without the DB initialization. + HostKeyCallback() ssh.HostKeyCallback +} + +// NewDB creates a new known_hosts database from the files given and returns +// it. +func NewDB(files ...string) (DB, error) { + logp := `NewDB` + db := newHostKeyDB() + for _, fn := range files { + f, err := os.Open(fn) + if err != nil { + return nil, fmt.Errorf(`%s: %w`, logp, err) + } + defer f.Close() + err = db.Read(f, fn) + if err != nil { + return nil, fmt.Errorf(`%s: %w`, logp, err) + } + } + return db, nil +} + +// HostKeyAlgorithms returns a list of host key algorithms associated +// with the given address. +func (db *hostKeyDB) HostKeyAlgorithms(address string) (knownTypes []string, err error) { + logp := `HostKeyAlgorithms` + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, fmt.Errorf(`%s: %w`, logp, err) + } + + hostToCheck := addr{host, port} + for _, l := range db.lines { + if l.match(hostToCheck) { + knownTypes = append(knownTypes, l.knownKey.Key.Type()) + } + } + return knownTypes, nil +} + +// HostKeyCallback is the way to get the ssh.HostKeyCallback if you have used +// NewDB. +func (db *hostKeyDB) HostKeyCallback() ssh.HostKeyCallback { + var certChecker ssh.CertChecker + certChecker.IsHostAuthority = db.IsHostAuthority + certChecker.IsRevoked = db.IsRevoked + certChecker.HostKeyFallback = db.check + return certChecker.CheckHostKey +} diff --git a/ssh/session.go b/ssh/session.go index acef622..f940686 100644 --- a/ssh/session.go +++ b/ssh/session.go @@ -9,6 +9,7 @@ package ssh import ( "bytes" + "context" "encoding/binary" "errors" "fmt" @@ -314,6 +315,16 @@ func (s *Session) Run(cmd string) error { return s.Wait() } +// RunWithContext similar to [Session.Run] but with context. +func (s *Session) RunWithContext(ctx context.Context, cmd string) (err error) { + err = s.Start(cmd) + if err != nil { + return err + } + + return s.waitWithContext(ctx) +} + // Output runs cmd on the remote host and returns its standard output. func (s *Session) Output(cmd string) ([]byte, error) { if s.Stdout != nil { @@ -397,10 +408,22 @@ func (s *Session) start() error { // unsuccessfully or is interrupted by a signal, the error is of type // *ExitError. Other error types may be returned for I/O problems. func (s *Session) Wait() error { + return s.waitWithContext(context.Background()) +} + +// waitWithContext wait for remote command to exit or terminate with SIGKILL +// when its receive context cancellation. +func (s *Session) waitWithContext(ctx context.Context) (err error) { if !s.started { return errors.New("ssh: session not started") } - waitErr := <-s.exitStatus + + var waitErr error + select { + case <-ctx.Done(): + waitErr = s.Signal(SIGKILL) + case waitErr = <-s.exitStatus: + } if s.stdinPipeWriter != nil { s.stdinPipeWriter.Close() @@ -414,7 +437,10 @@ func (s *Session) Wait() error { if waitErr != nil { return waitErr } - return copyError + if copyError != nil { + return copyError + } + return context.Cause(ctx) } func (s *Session) wait(reqs <-chan *Request) error { |
