aboutsummaryrefslogtreecommitdiff
path: root/src/net/dial.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/net/dial.go')
-rw-r--r--src/net/dial.go297
1 files changed, 165 insertions, 132 deletions
diff --git a/src/net/dial.go b/src/net/dial.go
index 22992d5b7a..3443161004 100644
--- a/src/net/dial.go
+++ b/src/net/dial.go
@@ -5,7 +5,7 @@
package net
import (
- "runtime"
+ "context"
"time"
)
@@ -61,21 +61,34 @@ type Dialer struct {
// Cancel is an optional channel whose closure indicates that
// the dial should be canceled. Not all types of dials support
// cancelation.
+ //
+ // Deprecated: Use DialContext instead.
Cancel <-chan struct{}
}
-// Return either now+Timeout or Deadline, whichever comes first.
-// Or zero, if neither is set.
-func (d *Dialer) deadline(now time.Time) time.Time {
- if d.Timeout == 0 {
- return d.Deadline
+func minNonzeroTime(a, b time.Time) time.Time {
+ if a.IsZero() {
+ return b
}
- timeoutDeadline := now.Add(d.Timeout)
- if d.Deadline.IsZero() || timeoutDeadline.Before(d.Deadline) {
- return timeoutDeadline
- } else {
- return d.Deadline
+ if b.IsZero() || a.Before(b) {
+ return a
}
+ return b
+}
+
+// deadline returns the earliest of:
+// - now+Timeout
+// - d.Deadline
+// - the context's deadline
+// Or zero, if none of Timeout, Deadline, or context's deadline is set.
+func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Time) {
+ if d.Timeout != 0 { // including negative, for historical reasons
+ earliest = now.Add(d.Timeout)
+ }
+ if d, ok := ctx.Deadline(); ok {
+ earliest = minNonzeroTime(earliest, d)
+ }
+ return minNonzeroTime(earliest, d.Deadline)
}
// partialDeadline returns the deadline to use for a single address,
@@ -110,7 +123,7 @@ func (d *Dialer) fallbackDelay() time.Duration {
}
}
-func parseNetwork(net string) (afnet string, proto int, err error) {
+func parseNetwork(ctx context.Context, net string) (afnet string, proto int, err error) {
i := last(net, ':')
if i < 0 { // no colon
switch net {
@@ -129,7 +142,7 @@ func parseNetwork(net string) (afnet string, proto int, err error) {
protostr := net[i+1:]
proto, i, ok := dtoi(protostr, 0)
if !ok || i != len(protostr) {
- proto, err = lookupProtocol(protostr)
+ proto, err = lookupProtocol(ctx, protostr)
if err != nil {
return "", 0, err
}
@@ -142,8 +155,8 @@ func parseNetwork(net string) (afnet string, proto int, err error) {
// resolverAddrList resolves addr using hint and returns a list of
// addresses. The result contains at least one address when error is
// nil.
-func resolveAddrList(op, network, addr string, hint Addr, deadline time.Time) (addrList, error) {
- afnet, _, err := parseNetwork(network)
+func resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) {
+ afnet, _, err := parseNetwork(ctx, network)
if err != nil {
return nil, err
}
@@ -152,6 +165,7 @@ func resolveAddrList(op, network, addr string, hint Addr, deadline time.Time) (a
}
switch afnet {
case "unix", "unixgram", "unixpacket":
+ // TODO(bradfitz): push down context
addr, err := ResolveUnixAddr(afnet, addr)
if err != nil {
return nil, err
@@ -161,7 +175,7 @@ func resolveAddrList(op, network, addr string, hint Addr, deadline time.Time) (a
}
return addrList{addr}, nil
}
- addrs, err := internetAddrList(afnet, addr, deadline)
+ addrs, err := internetAddrList(ctx, afnet, addr)
if err != nil || op != "dial" || hint == nil {
return addrs, err
}
@@ -253,11 +267,10 @@ func DialTimeout(network, address string, timeout time.Duration) (Conn, error) {
return d.Dial(network, address)
}
-// dialContext holds common state for all dial operations.
-type dialContext struct {
+// dialParam contains a Dial's parameters and configuration.
+type dialParam struct {
Dialer
network, address string
- finalDeadline time.Time
}
// Dial connects to the address on the named network.
@@ -265,161 +278,182 @@ type dialContext struct {
// See func Dial for a description of the network and address
// parameters.
func (d *Dialer) Dial(network, address string) (Conn, error) {
- finalDeadline := d.deadline(time.Now())
- addrs, err := resolveAddrList("dial", network, address, d.LocalAddr, finalDeadline)
+ return d.DialContext(context.Background(), network, address)
+}
+
+// DialContext connects to the address on the named network using
+// the provided context.
+//
+// The provided Context must be non-nil.
+//
+// See func Dial for a description of the network and address
+// parameters.
+func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) {
+ if ctx == nil {
+ panic("nil context")
+ }
+ deadline := d.deadline(ctx, time.Now())
+ if !deadline.IsZero() {
+ if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
+ subCtx, cancel := context.WithDeadline(ctx, deadline)
+ defer cancel()
+ ctx = subCtx
+ }
+ }
+ if oldCancel := d.Cancel; oldCancel != nil {
+ subCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ go func() {
+ select {
+ case <-oldCancel:
+ cancel()
+ case <-subCtx.Done():
+ }
+ }()
+ ctx = subCtx
+ }
+
+ addrs, err := resolveAddrList(ctx, "dial", network, address, d.LocalAddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
}
- ctx := &dialContext{
- Dialer: *d,
- network: network,
- address: address,
- finalDeadline: finalDeadline,
+ dp := &dialParam{
+ Dialer: *d,
+ network: network,
+ address: address,
}
- // DualStack mode requires that dialTCP support cancelation. This is
- // not available on plan9 (golang.org/issue/11225), so we ignore it.
var primaries, fallbacks addrList
- if d.DualStack && network == "tcp" && runtime.GOOS != "plan9" {
+ if d.DualStack && network == "tcp" {
primaries, fallbacks = addrs.partition(isIPv4)
} else {
primaries = addrs
}
var c Conn
- if len(fallbacks) == 0 {
- // dialParallel can accept an empty fallbacks list,
- // but this shortcut avoids the goroutine/channel overhead.
- c, err = dialSerial(ctx, primaries, ctx.Cancel)
+ if len(fallbacks) > 0 {
+ c, err = dialParallel(ctx, dp, primaries, fallbacks)
} else {
- c, err = dialParallel(ctx, primaries, fallbacks, ctx.Cancel)
+ c, err = dialSerial(ctx, dp, primaries)
+ }
+ if err != nil {
+ return nil, err
}
- if d.KeepAlive > 0 && err == nil {
- if tc, ok := c.(*TCPConn); ok {
- setKeepAlive(tc.fd, true)
- setKeepAlivePeriod(tc.fd, d.KeepAlive)
- testHookSetKeepAlive()
- }
+ if tc, ok := c.(*TCPConn); ok && d.KeepAlive > 0 {
+ setKeepAlive(tc.fd, true)
+ setKeepAlivePeriod(tc.fd, d.KeepAlive)
+ testHookSetKeepAlive()
}
- return c, err
+ return c, nil
}
// dialParallel races two copies of dialSerial, giving the first a
// head start. It returns the first established connection and
// closes the others. Otherwise it returns an error from the first
// primary address.
-func dialParallel(ctx *dialContext, primaries, fallbacks addrList, userCancel <-chan struct{}) (Conn, error) {
- results := make(chan dialResult, 2)
- cancel := make(chan struct{})
+func dialParallel(ctx context.Context, dp *dialParam, primaries, fallbacks addrList) (Conn, error) {
+ if len(fallbacks) == 0 {
+ return dialSerial(ctx, dp, primaries)
+ }
- // Spawn the primary racer.
- go dialSerialAsync(ctx, primaries, nil, cancel, results)
+ returned := make(chan struct{})
+ defer close(returned)
- // Spawn the fallback racer.
- fallbackTimer := time.NewTimer(ctx.fallbackDelay())
- go dialSerialAsync(ctx, fallbacks, fallbackTimer, cancel, results)
+ type dialResult struct {
+ Conn
+ error
+ primary bool
+ done bool
+ }
+ results := make(chan dialResult) // unbuffered
- // Wait for both racers to succeed or fail.
- var primaryResult, fallbackResult dialResult
- for !primaryResult.done || !fallbackResult.done {
+ startRacer := func(ctx context.Context, primary bool) {
+ ras := primaries
+ if !primary {
+ ras = fallbacks
+ }
+ c, err := dialSerial(ctx, dp, ras)
select {
- case <-userCancel:
- // Forward an external cancelation request.
- if cancel != nil {
- close(cancel)
- cancel = nil
+ case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
+ case <-returned:
+ if c != nil {
+ c.Close()
}
- userCancel = nil
+ }
+ }
+
+ var primary, fallback dialResult
+
+ // Start the main racer.
+ primaryCtx, primaryCancel := context.WithCancel(ctx)
+ defer primaryCancel()
+ go startRacer(primaryCtx, true)
+
+ // Start the timer for the fallback racer.
+ fallbackTimer := time.NewTimer(dp.fallbackDelay())
+ defer fallbackTimer.Stop()
+
+ for {
+ select {
+ case <-fallbackTimer.C:
+ fallbackCtx, fallbackCancel := context.WithCancel(ctx)
+ defer fallbackCancel()
+ go startRacer(fallbackCtx, false)
+
case res := <-results:
- // Drop the result into its assigned bucket.
+ if res.error == nil {
+ return res.Conn, nil
+ }
if res.primary {
- primaryResult = res
+ primary = res
} else {
- fallbackResult = res
+ fallback = res
}
- // On success, cancel the other racer (if one exists.)
- if res.error == nil && cancel != nil {
- close(cancel)
- cancel = nil
+ if primary.done && fallback.done {
+ return nil, primary.error
}
- // If the fallbackTimer was pending, then either we've canceled the
- // fallback because we no longer want it, or we haven't canceled yet
- // and therefore want it to wake up immediately.
- if fallbackTimer.Stop() && cancel != nil {
+ if res.primary && fallbackTimer.Stop() {
+ // If we were able to stop the timer, that means it
+ // was running (hadn't yet started the fallback), but
+ // we just got an error on the primary path, so start
+ // the fallback immediately (in 0 nanoseconds).
fallbackTimer.Reset(0)
}
}
}
-
- // Return, in order of preference:
- // 1. The primary connection (but close the other if we got both.)
- // 2. The fallback connection.
- // 3. The primary error.
- if primaryResult.error == nil {
- if fallbackResult.error == nil {
- fallbackResult.Conn.Close()
- }
- return primaryResult.Conn, nil
- } else if fallbackResult.error == nil {
- return fallbackResult.Conn, nil
- } else {
- return nil, primaryResult.error
- }
-}
-
-type dialResult struct {
- Conn
- error
- primary bool
- done bool
-}
-
-// dialSerialAsync runs dialSerial after some delay, and returns the
-// resulting connection through a channel. When racing two connections,
-// the primary goroutine uses a nil timer to omit the delay.
-func dialSerialAsync(ctx *dialContext, ras addrList, timer *time.Timer, cancel <-chan struct{}, results chan<- dialResult) {
- if timer != nil {
- // We're in the fallback goroutine; sleep before connecting.
- select {
- case <-timer.C:
- case <-cancel:
- // dialSerial will immediately return errCanceled in this case.
- }
- }
- c, err := dialSerial(ctx, ras, cancel)
- results <- dialResult{Conn: c, error: err, primary: timer == nil, done: true}
}
// dialSerial connects to a list of addresses in sequence, returning
// either the first successful connection, or the first error.
-func dialSerial(ctx *dialContext, ras addrList, cancel <-chan struct{}) (Conn, error) {
+func dialSerial(ctx context.Context, dp *dialParam, ras addrList) (Conn, error) {
var firstErr error // The error from the first address is most relevant.
for i, ra := range ras {
select {
- case <-cancel:
- return nil, &OpError{Op: "dial", Net: ctx.network, Source: ctx.LocalAddr, Addr: ra, Err: errCanceled}
+ case <-ctx.Done():
+ return nil, &OpError{Op: "dial", Net: dp.network, Source: dp.LocalAddr, Addr: ra, Err: mapErr(ctx.Err())}
default:
}
- partialDeadline, err := partialDeadline(time.Now(), ctx.finalDeadline, len(ras)-i)
+ deadline, _ := ctx.Deadline()
+ partialDeadline, err := partialDeadline(time.Now(), deadline, len(ras)-i)
if err != nil {
// Ran out of time.
if firstErr == nil {
- firstErr = &OpError{Op: "dial", Net: ctx.network, Source: ctx.LocalAddr, Addr: ra, Err: err}
+ firstErr = &OpError{Op: "dial", Net: dp.network, Source: dp.LocalAddr, Addr: ra, Err: err}
}
break
}
-
- // If this dial is canceled, the implementation is expected to complete
- // quickly, but it's still possible that we could return a spurious Conn,
- // which the caller must Close.
- dialer := func(d time.Time) (Conn, error) {
- return dialSingle(ctx, ra, d, cancel)
+ dialCtx := ctx
+ if partialDeadline.Before(deadline) {
+ var cancel context.CancelFunc
+ dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
+ defer cancel()
}
- c, err := dial(ctx.network, ra, dialer, partialDeadline)
+
+ c, err := dialSingle(dialCtx, dp, ra)
if err == nil {
return c, nil
}
@@ -429,34 +463,33 @@ func dialSerial(ctx *dialContext, ras addrList, cancel <-chan struct{}) (Conn, e
}
if firstErr == nil {
- firstErr = &OpError{Op: "dial", Net: ctx.network, Source: nil, Addr: nil, Err: errMissingAddress}
+ firstErr = &OpError{Op: "dial", Net: dp.network, Source: nil, Addr: nil, Err: errMissingAddress}
}
return nil, firstErr
}
// dialSingle attempts to establish and returns a single connection to
-// the destination address. This must be called through the OS-specific
-// dial function, because some OSes don't implement the deadline feature.
-func dialSingle(ctx *dialContext, ra Addr, deadline time.Time, cancel <-chan struct{}) (c Conn, err error) {
- la := ctx.LocalAddr
+// the destination address.
+func dialSingle(ctx context.Context, dp *dialParam, ra Addr) (c Conn, err error) {
+ la := dp.LocalAddr
switch ra := ra.(type) {
case *TCPAddr:
la, _ := la.(*TCPAddr)
- c, err = testHookDialTCP(ctx.network, la, ra, deadline, cancel)
+ c, err = dialTCP(ctx, dp.network, la, ra)
case *UDPAddr:
la, _ := la.(*UDPAddr)
- c, err = dialUDP(ctx.network, la, ra, deadline)
+ c, err = dialUDP(ctx, dp.network, la, ra)
case *IPAddr:
la, _ := la.(*IPAddr)
- c, err = dialIP(ctx.network, la, ra, deadline)
+ c, err = dialIP(ctx, dp.network, la, ra)
case *UnixAddr:
la, _ := la.(*UnixAddr)
- c, err = dialUnix(ctx.network, la, ra, deadline)
+ c, err = dialUnix(ctx, dp.network, la, ra)
default:
- return nil, &OpError{Op: "dial", Net: ctx.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: ctx.address}}
+ return nil, &OpError{Op: "dial", Net: dp.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: dp.address}}
}
if err != nil {
- return nil, &OpError{Op: "dial", Net: ctx.network, Source: la, Addr: ra, Err: err} // c is non-nil interface containing nil pointer
+ return nil, &OpError{Op: "dial", Net: dp.network, Source: la, Addr: ra, Err: err} // c is non-nil interface containing nil pointer
}
return c, nil
}
@@ -469,7 +502,7 @@ func dialSingle(ctx *dialContext, ra Addr, deadline time.Time, cancel <-chan str
// instead of just the interface with the given host address.
// See Dial for more details about address syntax.
func Listen(net, laddr string) (Listener, error) {
- addrs, err := resolveAddrList("listen", net, laddr, nil, noDeadline)
+ addrs, err := resolveAddrList(context.Background(), "listen", net, laddr, nil)
if err != nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err}
}
@@ -496,7 +529,7 @@ func Listen(net, laddr string) (Listener, error) {
// instead of just the interface with the given host address.
// See Dial for the syntax of laddr.
func ListenPacket(net, laddr string) (PacketConn, error) {
- addrs, err := resolveAddrList("listen", net, laddr, nil, noDeadline)
+ addrs, err := resolveAddrList(context.Background(), "listen", net, laddr, nil)
if err != nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err}
}