aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorBrad Fitzpatrick <bradfitz@golang.org>2020-01-15 19:27:32 +0000
committerBrad Fitzpatrick <bradfitz@golang.org>2020-04-20 20:33:36 +0000
commit40a144b94f14c3f3fbe06e097a236a5543ada57f (patch)
tree1cb2036d542b8f7bc260c32b7cabfb9dd5461e9e /src
parent7004be998ba708f027c95615d7885efeb9f73e75 (diff)
downloadgo-40a144b94f14c3f3fbe06e097a236a5543ada57f.tar.xz
crypto/tls: add Dialer
Fixes #18482 Change-Id: I99d65dc5d824c00093ea61e7445fc121314af87f Reviewed-on: https://go-review.googlesource.com/c/go/+/214977 Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> Reviewed-by: Ian Lance Taylor <iant@golang.org>
Diffstat (limited to 'src')
-rw-r--r--src/crypto/tls/conn.go8
-rw-r--r--src/crypto/tls/tls.go89
-rw-r--r--src/crypto/tls/tls_test.go42
-rw-r--r--src/go/build/deps_test.go2
4 files changed, 127 insertions, 14 deletions
diff --git a/src/crypto/tls/conn.go b/src/crypto/tls/conn.go
index 6bda73e085..bf2111cb97 100644
--- a/src/crypto/tls/conn.go
+++ b/src/crypto/tls/conn.go
@@ -1334,8 +1334,12 @@ func (c *Conn) closeNotify() error {
// Handshake runs the client or server handshake
// protocol if it has not yet been run.
-// Most uses of this package need not call Handshake
-// explicitly: the first Read or Write will call it automatically.
+//
+// Most uses of this package need not call Handshake explicitly: the
+// first Read or Write will call it automatically.
+//
+// For control over canceling or setting a timeout on a handshake, use
+// the Dialer's DialContext method.
func (c *Conn) Handshake() error {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
diff --git a/src/crypto/tls/tls.go b/src/crypto/tls/tls.go
index d98abdaea1..36d98d39eb 100644
--- a/src/crypto/tls/tls.go
+++ b/src/crypto/tls/tls.go
@@ -13,6 +13,7 @@ package tls
import (
"bytes"
+ "context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
@@ -111,29 +112,35 @@ func (timeoutError) Temporary() bool { return true }
// DialWithDialer interprets a nil configuration as equivalent to the zero
// configuration; see the documentation of Config for the defaults.
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
+ return dial(context.Background(), dialer, network, addr, config)
+}
+
+func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
// We want the Timeout and Deadline values from dialer to cover the
// whole process: TCP connection and TLS handshake. This means that we
// also need to start our own timers now.
- timeout := dialer.Timeout
+ timeout := netDialer.Timeout
- if !dialer.Deadline.IsZero() {
- deadlineTimeout := time.Until(dialer.Deadline)
+ if !netDialer.Deadline.IsZero() {
+ deadlineTimeout := time.Until(netDialer.Deadline)
if timeout == 0 || deadlineTimeout < timeout {
timeout = deadlineTimeout
}
}
- var errChannel chan error
-
+ // hsErrCh is non-nil if we might not wait for Handshake to complete.
+ var hsErrCh chan error
+ if timeout != 0 || ctx.Done() != nil {
+ hsErrCh = make(chan error, 2)
+ }
if timeout != 0 {
- errChannel = make(chan error, 2)
timer := time.AfterFunc(timeout, func() {
- errChannel <- timeoutError{}
+ hsErrCh <- timeoutError{}
})
defer timer.Stop()
}
- rawConn, err := dialer.Dial(network, addr)
+ rawConn, err := netDialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
@@ -158,14 +165,26 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*
conn := Client(rawConn, config)
- if timeout == 0 {
+ if hsErrCh == nil {
err = conn.Handshake()
} else {
go func() {
- errChannel <- conn.Handshake()
+ hsErrCh <- conn.Handshake()
}()
- err = <-errChannel
+ select {
+ case <-ctx.Done():
+ err = ctx.Err()
+ case err = <-hsErrCh:
+ if err != nil {
+ // If the error was due to the context
+ // closing, prefer the context's error, rather
+ // than some random network teardown error.
+ if e := ctx.Err(); e != nil {
+ err = e
+ }
+ }
+ }
}
if err != nil {
@@ -186,6 +205,54 @@ func Dial(network, addr string, config *Config) (*Conn, error) {
return DialWithDialer(new(net.Dialer), network, addr, config)
}
+// Dialer dials TLS connections given a configuration and a Dialer for the
+// underlying connection.
+type Dialer struct {
+ // NetDialer is the optional dialer to use for the TLS connections'
+ // underlying TCP connections.
+ // A nil NetDialer is equivalent to the net.Dialer zero value.
+ NetDialer *net.Dialer
+
+ // Config is the TLS configuration to use for new connections.
+ // A nil configuration is equivalent to the zero
+ // configuration; see the documentation of Config for the
+ // defaults.
+ Config *Config
+}
+
+// Dial connects to the given network address and initiates a TLS
+// handshake, returning the resulting TLS connection.
+//
+// The returned Conn, if any, will always be of type *Conn.
+func (d *Dialer) Dial(network, addr string) (net.Conn, error) {
+ return d.DialContext(context.Background(), network, addr)
+}
+
+func (d *Dialer) netDialer() *net.Dialer {
+ if d.NetDialer != nil {
+ return d.NetDialer
+ }
+ return new(net.Dialer)
+}
+
+// Dial connects to the given network address and initiates a TLS
+// handshake, returning the resulting TLS connection.
+//
+// The provided Context must be non-nil. If the context expires before
+// the connection is complete, an error is returned. Once successfully
+// connected, any expiration of the context will not affect the
+// connection.
+//
+// The returned Conn, if any, will always be of type *Conn.
+func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
+ c, err := dial(ctx, d.netDialer(), network, addr, d.Config)
+ if err != nil {
+ // Don't return c (a typed nil) in an interface.
+ return nil, err
+ }
+ return c, nil
+}
+
// LoadX509KeyPair reads and parses a public/private key pair from a pair
// of files. The files must contain PEM encoded data. The certificate file
// may contain intermediate certificates following the leaf certificate to
diff --git a/src/crypto/tls/tls_test.go b/src/crypto/tls/tls_test.go
index 42fd5e1b8c..85005d4950 100644
--- a/src/crypto/tls/tls_test.go
+++ b/src/crypto/tls/tls_test.go
@@ -6,6 +6,7 @@ package tls
import (
"bytes"
+ "context"
"crypto"
"crypto/x509"
"encoding/json"
@@ -272,6 +273,47 @@ func TestDeadlineOnWrite(t *testing.T) {
}
}
+type readerFunc func([]byte) (int, error)
+
+func (f readerFunc) Read(b []byte) (int, error) { return f(b) }
+
+// TestDialer tests that tls.Dialer.DialContext can abort in the middle of a handshake.
+// (The other cases are all handled by the existing dial tests in this package, which
+// all also flow through the same code shared code paths)
+func TestDialer(t *testing.T) {
+ ln := newLocalListener(t)
+ defer ln.Close()
+
+ unblockServer := make(chan struct{}) // close-only
+ defer close(unblockServer)
+ go func() {
+ conn, err := ln.Accept()
+ if err != nil {
+ return
+ }
+ defer conn.Close()
+ <-unblockServer
+ }()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ d := Dialer{Config: &Config{
+ Rand: readerFunc(func(b []byte) (n int, err error) {
+ // By the time crypto/tls wants randomness, that means it has a TCP
+ // connection, so we're past the Dialer's dial and now blocked
+ // in a handshake. Cancel our context and see if we get unstuck.
+ // (Our TCP listener above never reads or writes, so the Handshake
+ // would otherwise be stuck forever)
+ cancel()
+ return len(b), nil
+ }),
+ ServerName: "foo",
+ }}
+ _, err := d.DialContext(ctx, "tcp", ln.Addr().String())
+ if err != context.Canceled {
+ t.Errorf("err = %v; want context.Canceled", err)
+ }
+}
+
func isTimeoutError(err error) bool {
if ne, ok := err.(net.Error); ok {
return ne.Timeout()
diff --git a/src/go/build/deps_test.go b/src/go/build/deps_test.go
index 24f79cabb3..fad165cf60 100644
--- a/src/go/build/deps_test.go
+++ b/src/go/build/deps_test.go
@@ -408,7 +408,7 @@ var pkgDeps = map[string][]string{
// SSL/TLS.
"crypto/tls": {
"L4", "CRYPTO-MATH", "OS", "golang.org/x/crypto/cryptobyte", "golang.org/x/crypto/hkdf",
- "container/list", "crypto/x509", "encoding/pem", "net", "syscall", "crypto/ed25519",
+ "container/list", "context", "crypto/x509", "encoding/pem", "net", "syscall", "crypto/ed25519",
},
"crypto/x509": {
"L4", "CRYPTO-MATH", "OS", "CGO", "crypto/ed25519",