aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBrad Fitzpatrick <bradfitz@golang.org>2024-09-17 11:31:20 -0700
committerGopher Robot <gobot@golang.org>2025-01-18 11:27:23 -0800
commita8ea4be81f0769fd5857e087083cbb6d3cb9f196 (patch)
tree1bbf654295f9b382970bec65995791e0183d7df1
parent71d3a4cfdb0360795ce5f2d7041e01823fd22eb6 (diff)
downloadgo-x-crypto-a8ea4be81f0769fd5857e087083cbb6d3cb9f196.tar.xz
ssh: add ServerConfig.PreAuthConnCallback, ServerPreAuthConn (banner) interface
Fixes golang/go#68688 Change-Id: Id5f72b32c61c9383a26ec182339486a432c7cdf5 Reviewed-on: https://go-review.googlesource.com/c/crypto/+/613856 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Auto-Submit: Nicola Murino <nicola.murino@gmail.com> Reviewed-by: Jonathan Amsterdam <jba@google.com> Reviewed-by: Nicola Murino <nicola.murino@gmail.com> Reviewed-by: Roland Shoemaker <roland@golang.org>
-rw-r--r--ssh/handshake.go14
-rw-r--r--ssh/server.go50
-rw-r--r--ssh/server_test.go86
3 files changed, 135 insertions, 15 deletions
diff --git a/ssh/handshake.go b/ssh/handshake.go
index 56cdc7c..fef687d 100644
--- a/ssh/handshake.go
+++ b/ssh/handshake.go
@@ -80,6 +80,7 @@ type handshakeTransport struct {
pendingPackets [][]byte // Used when a key exchange is in progress.
writePacketsLeft uint32
writeBytesLeft int64
+ userAuthComplete bool // whether the user authentication phase is complete
// If the read loop wants to schedule a kex, it pings this
// channel, and the write loop will send out a kex
@@ -552,16 +553,25 @@ func (t *handshakeTransport) sendKexInit() error {
return nil
}
+var errSendBannerPhase = errors.New("ssh: SendAuthBanner outside of authentication phase")
+
func (t *handshakeTransport) writePacket(p []byte) error {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
switch p[0] {
case msgKexInit:
return errors.New("ssh: only handshakeTransport can send kexInit")
case msgNewKeys:
return errors.New("ssh: only handshakeTransport can send newKeys")
+ case msgUserAuthBanner:
+ if t.userAuthComplete {
+ return errSendBannerPhase
+ }
+ case msgUserAuthSuccess:
+ t.userAuthComplete = true
}
- t.mu.Lock()
- defer t.mu.Unlock()
if t.writeError != nil {
return t.writeError
}
diff --git a/ssh/server.go b/ssh/server.go
index 5b5ccd9..1839ddc 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -59,6 +59,27 @@ type GSSAPIWithMICConfig struct {
Server GSSAPIServer
}
+// SendAuthBanner implements [ServerPreAuthConn].
+func (s *connection) SendAuthBanner(msg string) error {
+ return s.transport.writePacket(Marshal(&userAuthBannerMsg{
+ Message: msg,
+ }))
+}
+
+func (*connection) unexportedMethodForFutureProofing() {}
+
+// ServerPreAuthConn is the interface available on an incoming server
+// connection before authentication has completed.
+type ServerPreAuthConn interface {
+ unexportedMethodForFutureProofing() // permits growing ServerPreAuthConn safely later, ala testing.TB
+
+ ConnMetadata
+
+ // SendAuthBanner sends a banner message to the client.
+ // It returns an error once the authentication phase has ended.
+ SendAuthBanner(string) error
+}
+
// ServerConfig holds server specific configuration data.
type ServerConfig struct {
// Config contains configuration shared between client and server.
@@ -118,6 +139,12 @@ type ServerConfig struct {
// attempts.
AuthLogCallback func(conn ConnMetadata, method string, err error)
+ // PreAuthConnCallback, if non-nil, is called upon receiving a new connection
+ // before any authentication has started. The provided ServerPreAuthConn
+ // can be used at any time before authentication is complete, including
+ // after this callback has returned.
+ PreAuthConnCallback func(ServerPreAuthConn)
+
// ServerVersion is the version identification string to announce in
// the public handshake.
// If empty, a reasonable default is used.
@@ -488,6 +515,10 @@ func (b *BannerError) Error() string {
}
func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
+ if config.PreAuthConnCallback != nil {
+ config.PreAuthConnCallback(s)
+ }
+
sessionID := s.transport.getSessionID()
var cache pubKeyCache
var perms *Permissions
@@ -495,7 +526,7 @@ func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, err
authFailures := 0
noneAuthCount := 0
var authErrs []error
- var displayedBanner bool
+ var calledBannerCallback bool
partialSuccessReturned := false
// Set the initial authentication callbacks from the config. They can be
// changed if a PartialSuccessError is returned.
@@ -542,14 +573,10 @@ userAuthLoop:
s.user = userAuthReq.User
- if !displayedBanner && config.BannerCallback != nil {
- displayedBanner = true
- msg := config.BannerCallback(s)
- if msg != "" {
- bannerMsg := &userAuthBannerMsg{
- Message: msg,
- }
- if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil {
+ if !calledBannerCallback && config.BannerCallback != nil {
+ calledBannerCallback = true
+ if msg := config.BannerCallback(s); msg != "" {
+ if err := s.SendAuthBanner(msg); err != nil {
return nil, err
}
}
@@ -762,10 +789,7 @@ userAuthLoop:
var bannerErr *BannerError
if errors.As(authErr, &bannerErr) {
if bannerErr.Message != "" {
- bannerMsg := &userAuthBannerMsg{
- Message: bannerErr.Message,
- }
- if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil {
+ if err := s.SendAuthBanner(bannerErr.Message); err != nil {
return nil, err
}
}
diff --git a/ssh/server_test.go b/ssh/server_test.go
index ba1bd10..c2b24f4 100644
--- a/ssh/server_test.go
+++ b/ssh/server_test.go
@@ -348,6 +348,92 @@ func TestPublicKeyCallbackLastSeen(t *testing.T) {
}
}
+func TestPreAuthConnAndBanners(t *testing.T) {
+ testDone := make(chan struct{})
+ defer close(testDone)
+
+ authConnc := make(chan ServerPreAuthConn, 1)
+ serverConfig := &ServerConfig{
+ PreAuthConnCallback: func(c ServerPreAuthConn) {
+ t.Logf("got ServerPreAuthConn: %v", c)
+ authConnc <- c // for use later in the test
+ for _, s := range []string{"hello1", "hello2"} {
+ if err := c.SendAuthBanner(s); err != nil {
+ t.Errorf("failed to send banner %q: %v", s, err)
+ }
+ }
+ // Now start a goroutine to spam SendAuthBanner in hopes
+ // of hitting a race.
+ go func() {
+ for {
+ select {
+ case <-testDone:
+ return
+ default:
+ if err := c.SendAuthBanner("attempted-race"); err != nil && err != errSendBannerPhase {
+ t.Errorf("unexpected error from SendAuthBanner: %v", err)
+ }
+ time.Sleep(5 * time.Millisecond)
+ }
+ }
+ }()
+ },
+ NoClientAuth: true,
+ NoClientAuthCallback: func(ConnMetadata) (*Permissions, error) {
+ t.Logf("got NoClientAuthCallback")
+ return &Permissions{}, nil
+ },
+ }
+ serverConfig.AddHostKey(testSigners["rsa"])
+
+ var banners []string
+ clientConfig := &ClientConfig{
+ User: "test",
+ HostKeyCallback: InsecureIgnoreHostKey(),
+ BannerCallback: func(msg string) error {
+ if msg != "attempted-race" {
+ banners = append(banners, msg)
+ }
+ return nil
+ },
+ }
+
+ c1, c2, err := netPipe()
+ if err != nil {
+ t.Fatalf("netPipe: %v", err)
+ }
+ defer c1.Close()
+ defer c2.Close()
+ go newServer(c1, serverConfig)
+ c, _, _, err := NewClientConn(c2, "", clientConfig)
+ if err != nil {
+ t.Fatalf("client connection failed: %v", err)
+ }
+ defer c.Close()
+
+ wantBanners := []string{
+ "hello1",
+ "hello2",
+ }
+ if !reflect.DeepEqual(banners, wantBanners) {
+ t.Errorf("got banners:\n%q\nwant banners:\n%q", banners, wantBanners)
+ }
+
+ // Now that we're authenticated, verify that use of SendBanner
+ // is an error.
+ var bc ServerPreAuthConn
+ select {
+ case bc = <-authConnc:
+ default:
+ t.Fatal("expected ServerPreAuthConn")
+ }
+ if err := bc.SendAuthBanner("wrong-phase"); err == nil {
+ t.Error("unexpected success of SendAuthBanner after authentication")
+ } else if err != errSendBannerPhase {
+ t.Errorf("unexpected error: %v; want %v", err, errSendBannerPhase)
+ }
+}
+
type markerConn struct {
closed uint32
used uint32