diff options
| author | Randy Reddig <ydnar@shaderlab.com> | 2023-06-22 16:06:05 +0000 |
|---|---|---|
| committer | Gopher Robot <gobot@golang.org> | 2023-11-27 15:35:52 +0000 |
| commit | b2d7c26edb17864f117d8b0ee73c1843bcc6090f (patch) | |
| tree | d12aab722b68a8120dbef0da54362fd210717dd6 | |
| parent | 1c17e20020f974158d1b45be166660c999d6269b (diff) | |
| download | go-x-crypto-b2d7c26edb17864f117d8b0ee73c1843bcc6090f.tar.xz | |
ssh: add (*Client).DialContext method
This change adds DialContext to ssh.Client, which opens a TCP-IP
connection tunneled over the SSH connection. This is useful for
proxying network connections, e.g. setting
(net/http.Transport).DialContext.
Fixes golang/go#20288.
Change-Id: I110494c00962424ea803065535ebe2209364ac27
GitHub-Last-Rev: 3176984a71a9a1422702e3a071340ecfff71ff62
GitHub-Pull-Request: golang/crypto#260
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/504735
Run-TryBot: Nicola Murino <nicola.murino@gmail.com>
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
Auto-Submit: Nicola Murino <nicola.murino@gmail.com>
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Nicola Murino <nicola.murino@gmail.com>
Commit-Queue: Nicola Murino <nicola.murino@gmail.com>
| -rw-r--r-- | ssh/tcpip.go | 35 | ||||
| -rw-r--r-- | ssh/tcpip_test.go | 33 | ||||
| -rw-r--r-- | ssh/test/dial_unix_test.go | 7 |
3 files changed, 74 insertions, 1 deletions
diff --git a/ssh/tcpip.go b/ssh/tcpip.go index 80d35f5..ef5059a 100644 --- a/ssh/tcpip.go +++ b/ssh/tcpip.go @@ -5,6 +5,7 @@ package ssh import ( + "context" "errors" "fmt" "io" @@ -332,6 +333,40 @@ func (l *tcpListener) Addr() net.Addr { return l.laddr } +// DialContext initiates a connection to the addr from the remote host. +// +// The provided Context must be non-nil. If the context expires before the +// connection is complete, an error is returned. Once successfully connected, +// any expiration of the context will not affect the connection. +// +// See func Dial for additional information. +func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + type connErr struct { + conn net.Conn + err error + } + ch := make(chan connErr) + go func() { + conn, err := c.Dial(n, addr) + select { + case ch <- connErr{conn, err}: + case <-ctx.Done(): + if conn != nil { + conn.Close() + } + } + }() + select { + case res := <-ch: + return res.conn, res.err + case <-ctx.Done(): + return nil, ctx.Err() + } +} + // Dial initiates a connection to the addr from the remote host. // The resulting connection has a zero LocalAddr() and RemoteAddr(). func (c *Client) Dial(n, addr string) (net.Conn, error) { diff --git a/ssh/tcpip_test.go b/ssh/tcpip_test.go index f1265cb..4d85114 100644 --- a/ssh/tcpip_test.go +++ b/ssh/tcpip_test.go @@ -5,7 +5,10 @@ package ssh import ( + "context" + "net" "testing" + "time" ) func TestAutoPortListenBroken(t *testing.T) { @@ -18,3 +21,33 @@ func TestAutoPortListenBroken(t *testing.T) { t.Errorf("version %q marked as broken", works) } } + +func TestClientImplementsDialContext(t *testing.T) { + type ContextDialer interface { + DialContext(context.Context, string, string) (net.Conn, error) + } + // Belt and suspenders assertion, since package net does not + // declare a ContextDialer type. + var _ ContextDialer = &net.Dialer{} + var _ ContextDialer = &Client{} +} + +func TestClientDialContextWithCancel(t *testing.T) { + c := &Client{} + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := c.DialContext(ctx, "tcp", "localhost:1000") + if err != context.Canceled { + t.Errorf("DialContext: got nil error, expected %v", context.Canceled) + } +} + +func TestClientDialContextWithDeadline(t *testing.T) { + c := &Client{} + ctx, cancel := context.WithDeadline(context.Background(), time.Now()) + defer cancel() + _, err := c.DialContext(ctx, "tcp", "localhost:1000") + if err != context.DeadlineExceeded { + t.Errorf("DialContext: got nil error, expected %v", context.DeadlineExceeded) + } +} diff --git a/ssh/test/dial_unix_test.go b/ssh/test/dial_unix_test.go index 0a5f5e3..8ec8d50 100644 --- a/ssh/test/dial_unix_test.go +++ b/ssh/test/dial_unix_test.go @@ -9,6 +9,7 @@ package test // direct-tcpip and direct-streamlocal functional tests import ( + "context" "fmt" "io" "net" @@ -46,7 +47,11 @@ func testDial(t *testing.T, n, listenAddr string, x dialTester) { } }() - conn, err := sshConn.Dial(n, l.Addr().String()) + ctx, cancel := context.WithCancel(context.Background()) + conn, err := sshConn.DialContext(ctx, n, l.Addr().String()) + // Canceling the context after dial should have no effect + // on the opened connection. + cancel() if err != nil { t.Fatalf("Dial: %v", err) } |
