diff options
Diffstat (limited to 'src/net/dial.go')
| -rw-r--r-- | src/net/dial.go | 297 |
1 files changed, 165 insertions, 132 deletions
diff --git a/src/net/dial.go b/src/net/dial.go index 22992d5b7a..3443161004 100644 --- a/src/net/dial.go +++ b/src/net/dial.go @@ -5,7 +5,7 @@ package net import ( - "runtime" + "context" "time" ) @@ -61,21 +61,34 @@ type Dialer struct { // Cancel is an optional channel whose closure indicates that // the dial should be canceled. Not all types of dials support // cancelation. + // + // Deprecated: Use DialContext instead. Cancel <-chan struct{} } -// Return either now+Timeout or Deadline, whichever comes first. -// Or zero, if neither is set. -func (d *Dialer) deadline(now time.Time) time.Time { - if d.Timeout == 0 { - return d.Deadline +func minNonzeroTime(a, b time.Time) time.Time { + if a.IsZero() { + return b } - timeoutDeadline := now.Add(d.Timeout) - if d.Deadline.IsZero() || timeoutDeadline.Before(d.Deadline) { - return timeoutDeadline - } else { - return d.Deadline + if b.IsZero() || a.Before(b) { + return a } + return b +} + +// deadline returns the earliest of: +// - now+Timeout +// - d.Deadline +// - the context's deadline +// Or zero, if none of Timeout, Deadline, or context's deadline is set. +func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Time) { + if d.Timeout != 0 { // including negative, for historical reasons + earliest = now.Add(d.Timeout) + } + if d, ok := ctx.Deadline(); ok { + earliest = minNonzeroTime(earliest, d) + } + return minNonzeroTime(earliest, d.Deadline) } // partialDeadline returns the deadline to use for a single address, @@ -110,7 +123,7 @@ func (d *Dialer) fallbackDelay() time.Duration { } } -func parseNetwork(net string) (afnet string, proto int, err error) { +func parseNetwork(ctx context.Context, net string) (afnet string, proto int, err error) { i := last(net, ':') if i < 0 { // no colon switch net { @@ -129,7 +142,7 @@ func parseNetwork(net string) (afnet string, proto int, err error) { protostr := net[i+1:] proto, i, ok := dtoi(protostr, 0) if !ok || i != len(protostr) { - proto, err = lookupProtocol(protostr) + proto, err = lookupProtocol(ctx, protostr) if err != nil { return "", 0, err } @@ -142,8 +155,8 @@ func parseNetwork(net string) (afnet string, proto int, err error) { // resolverAddrList resolves addr using hint and returns a list of // addresses. The result contains at least one address when error is // nil. -func resolveAddrList(op, network, addr string, hint Addr, deadline time.Time) (addrList, error) { - afnet, _, err := parseNetwork(network) +func resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) { + afnet, _, err := parseNetwork(ctx, network) if err != nil { return nil, err } @@ -152,6 +165,7 @@ func resolveAddrList(op, network, addr string, hint Addr, deadline time.Time) (a } switch afnet { case "unix", "unixgram", "unixpacket": + // TODO(bradfitz): push down context addr, err := ResolveUnixAddr(afnet, addr) if err != nil { return nil, err @@ -161,7 +175,7 @@ func resolveAddrList(op, network, addr string, hint Addr, deadline time.Time) (a } return addrList{addr}, nil } - addrs, err := internetAddrList(afnet, addr, deadline) + addrs, err := internetAddrList(ctx, afnet, addr) if err != nil || op != "dial" || hint == nil { return addrs, err } @@ -253,11 +267,10 @@ func DialTimeout(network, address string, timeout time.Duration) (Conn, error) { return d.Dial(network, address) } -// dialContext holds common state for all dial operations. -type dialContext struct { +// dialParam contains a Dial's parameters and configuration. +type dialParam struct { Dialer network, address string - finalDeadline time.Time } // Dial connects to the address on the named network. @@ -265,161 +278,182 @@ type dialContext struct { // See func Dial for a description of the network and address // parameters. func (d *Dialer) Dial(network, address string) (Conn, error) { - finalDeadline := d.deadline(time.Now()) - addrs, err := resolveAddrList("dial", network, address, d.LocalAddr, finalDeadline) + return d.DialContext(context.Background(), network, address) +} + +// DialContext connects to the address on the named network using +// the provided context. +// +// The provided Context must be non-nil. +// +// See func Dial for a description of the network and address +// parameters. +func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) { + if ctx == nil { + panic("nil context") + } + deadline := d.deadline(ctx, time.Now()) + if !deadline.IsZero() { + if d, ok := ctx.Deadline(); !ok || deadline.Before(d) { + subCtx, cancel := context.WithDeadline(ctx, deadline) + defer cancel() + ctx = subCtx + } + } + if oldCancel := d.Cancel; oldCancel != nil { + subCtx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + select { + case <-oldCancel: + cancel() + case <-subCtx.Done(): + } + }() + ctx = subCtx + } + + addrs, err := resolveAddrList(ctx, "dial", network, address, d.LocalAddr) if err != nil { return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err} } - ctx := &dialContext{ - Dialer: *d, - network: network, - address: address, - finalDeadline: finalDeadline, + dp := &dialParam{ + Dialer: *d, + network: network, + address: address, } - // DualStack mode requires that dialTCP support cancelation. This is - // not available on plan9 (golang.org/issue/11225), so we ignore it. var primaries, fallbacks addrList - if d.DualStack && network == "tcp" && runtime.GOOS != "plan9" { + if d.DualStack && network == "tcp" { primaries, fallbacks = addrs.partition(isIPv4) } else { primaries = addrs } var c Conn - if len(fallbacks) == 0 { - // dialParallel can accept an empty fallbacks list, - // but this shortcut avoids the goroutine/channel overhead. - c, err = dialSerial(ctx, primaries, ctx.Cancel) + if len(fallbacks) > 0 { + c, err = dialParallel(ctx, dp, primaries, fallbacks) } else { - c, err = dialParallel(ctx, primaries, fallbacks, ctx.Cancel) + c, err = dialSerial(ctx, dp, primaries) + } + if err != nil { + return nil, err } - if d.KeepAlive > 0 && err == nil { - if tc, ok := c.(*TCPConn); ok { - setKeepAlive(tc.fd, true) - setKeepAlivePeriod(tc.fd, d.KeepAlive) - testHookSetKeepAlive() - } + if tc, ok := c.(*TCPConn); ok && d.KeepAlive > 0 { + setKeepAlive(tc.fd, true) + setKeepAlivePeriod(tc.fd, d.KeepAlive) + testHookSetKeepAlive() } - return c, err + return c, nil } // dialParallel races two copies of dialSerial, giving the first a // head start. It returns the first established connection and // closes the others. Otherwise it returns an error from the first // primary address. -func dialParallel(ctx *dialContext, primaries, fallbacks addrList, userCancel <-chan struct{}) (Conn, error) { - results := make(chan dialResult, 2) - cancel := make(chan struct{}) +func dialParallel(ctx context.Context, dp *dialParam, primaries, fallbacks addrList) (Conn, error) { + if len(fallbacks) == 0 { + return dialSerial(ctx, dp, primaries) + } - // Spawn the primary racer. - go dialSerialAsync(ctx, primaries, nil, cancel, results) + returned := make(chan struct{}) + defer close(returned) - // Spawn the fallback racer. - fallbackTimer := time.NewTimer(ctx.fallbackDelay()) - go dialSerialAsync(ctx, fallbacks, fallbackTimer, cancel, results) + type dialResult struct { + Conn + error + primary bool + done bool + } + results := make(chan dialResult) // unbuffered - // Wait for both racers to succeed or fail. - var primaryResult, fallbackResult dialResult - for !primaryResult.done || !fallbackResult.done { + startRacer := func(ctx context.Context, primary bool) { + ras := primaries + if !primary { + ras = fallbacks + } + c, err := dialSerial(ctx, dp, ras) select { - case <-userCancel: - // Forward an external cancelation request. - if cancel != nil { - close(cancel) - cancel = nil + case results <- dialResult{Conn: c, error: err, primary: primary, done: true}: + case <-returned: + if c != nil { + c.Close() } - userCancel = nil + } + } + + var primary, fallback dialResult + + // Start the main racer. + primaryCtx, primaryCancel := context.WithCancel(ctx) + defer primaryCancel() + go startRacer(primaryCtx, true) + + // Start the timer for the fallback racer. + fallbackTimer := time.NewTimer(dp.fallbackDelay()) + defer fallbackTimer.Stop() + + for { + select { + case <-fallbackTimer.C: + fallbackCtx, fallbackCancel := context.WithCancel(ctx) + defer fallbackCancel() + go startRacer(fallbackCtx, false) + case res := <-results: - // Drop the result into its assigned bucket. + if res.error == nil { + return res.Conn, nil + } if res.primary { - primaryResult = res + primary = res } else { - fallbackResult = res + fallback = res } - // On success, cancel the other racer (if one exists.) - if res.error == nil && cancel != nil { - close(cancel) - cancel = nil + if primary.done && fallback.done { + return nil, primary.error } - // If the fallbackTimer was pending, then either we've canceled the - // fallback because we no longer want it, or we haven't canceled yet - // and therefore want it to wake up immediately. - if fallbackTimer.Stop() && cancel != nil { + if res.primary && fallbackTimer.Stop() { + // If we were able to stop the timer, that means it + // was running (hadn't yet started the fallback), but + // we just got an error on the primary path, so start + // the fallback immediately (in 0 nanoseconds). fallbackTimer.Reset(0) } } } - - // Return, in order of preference: - // 1. The primary connection (but close the other if we got both.) - // 2. The fallback connection. - // 3. The primary error. - if primaryResult.error == nil { - if fallbackResult.error == nil { - fallbackResult.Conn.Close() - } - return primaryResult.Conn, nil - } else if fallbackResult.error == nil { - return fallbackResult.Conn, nil - } else { - return nil, primaryResult.error - } -} - -type dialResult struct { - Conn - error - primary bool - done bool -} - -// dialSerialAsync runs dialSerial after some delay, and returns the -// resulting connection through a channel. When racing two connections, -// the primary goroutine uses a nil timer to omit the delay. -func dialSerialAsync(ctx *dialContext, ras addrList, timer *time.Timer, cancel <-chan struct{}, results chan<- dialResult) { - if timer != nil { - // We're in the fallback goroutine; sleep before connecting. - select { - case <-timer.C: - case <-cancel: - // dialSerial will immediately return errCanceled in this case. - } - } - c, err := dialSerial(ctx, ras, cancel) - results <- dialResult{Conn: c, error: err, primary: timer == nil, done: true} } // dialSerial connects to a list of addresses in sequence, returning // either the first successful connection, or the first error. -func dialSerial(ctx *dialContext, ras addrList, cancel <-chan struct{}) (Conn, error) { +func dialSerial(ctx context.Context, dp *dialParam, ras addrList) (Conn, error) { var firstErr error // The error from the first address is most relevant. for i, ra := range ras { select { - case <-cancel: - return nil, &OpError{Op: "dial", Net: ctx.network, Source: ctx.LocalAddr, Addr: ra, Err: errCanceled} + case <-ctx.Done(): + return nil, &OpError{Op: "dial", Net: dp.network, Source: dp.LocalAddr, Addr: ra, Err: mapErr(ctx.Err())} default: } - partialDeadline, err := partialDeadline(time.Now(), ctx.finalDeadline, len(ras)-i) + deadline, _ := ctx.Deadline() + partialDeadline, err := partialDeadline(time.Now(), deadline, len(ras)-i) if err != nil { // Ran out of time. if firstErr == nil { - firstErr = &OpError{Op: "dial", Net: ctx.network, Source: ctx.LocalAddr, Addr: ra, Err: err} + firstErr = &OpError{Op: "dial", Net: dp.network, Source: dp.LocalAddr, Addr: ra, Err: err} } break } - - // If this dial is canceled, the implementation is expected to complete - // quickly, but it's still possible that we could return a spurious Conn, - // which the caller must Close. - dialer := func(d time.Time) (Conn, error) { - return dialSingle(ctx, ra, d, cancel) + dialCtx := ctx + if partialDeadline.Before(deadline) { + var cancel context.CancelFunc + dialCtx, cancel = context.WithDeadline(ctx, partialDeadline) + defer cancel() } - c, err := dial(ctx.network, ra, dialer, partialDeadline) + + c, err := dialSingle(dialCtx, dp, ra) if err == nil { return c, nil } @@ -429,34 +463,33 @@ func dialSerial(ctx *dialContext, ras addrList, cancel <-chan struct{}) (Conn, e } if firstErr == nil { - firstErr = &OpError{Op: "dial", Net: ctx.network, Source: nil, Addr: nil, Err: errMissingAddress} + firstErr = &OpError{Op: "dial", Net: dp.network, Source: nil, Addr: nil, Err: errMissingAddress} } return nil, firstErr } // dialSingle attempts to establish and returns a single connection to -// the destination address. This must be called through the OS-specific -// dial function, because some OSes don't implement the deadline feature. -func dialSingle(ctx *dialContext, ra Addr, deadline time.Time, cancel <-chan struct{}) (c Conn, err error) { - la := ctx.LocalAddr +// the destination address. +func dialSingle(ctx context.Context, dp *dialParam, ra Addr) (c Conn, err error) { + la := dp.LocalAddr switch ra := ra.(type) { case *TCPAddr: la, _ := la.(*TCPAddr) - c, err = testHookDialTCP(ctx.network, la, ra, deadline, cancel) + c, err = dialTCP(ctx, dp.network, la, ra) case *UDPAddr: la, _ := la.(*UDPAddr) - c, err = dialUDP(ctx.network, la, ra, deadline) + c, err = dialUDP(ctx, dp.network, la, ra) case *IPAddr: la, _ := la.(*IPAddr) - c, err = dialIP(ctx.network, la, ra, deadline) + c, err = dialIP(ctx, dp.network, la, ra) case *UnixAddr: la, _ := la.(*UnixAddr) - c, err = dialUnix(ctx.network, la, ra, deadline) + c, err = dialUnix(ctx, dp.network, la, ra) default: - return nil, &OpError{Op: "dial", Net: ctx.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: ctx.address}} + return nil, &OpError{Op: "dial", Net: dp.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: dp.address}} } if err != nil { - return nil, &OpError{Op: "dial", Net: ctx.network, Source: la, Addr: ra, Err: err} // c is non-nil interface containing nil pointer + return nil, &OpError{Op: "dial", Net: dp.network, Source: la, Addr: ra, Err: err} // c is non-nil interface containing nil pointer } return c, nil } @@ -469,7 +502,7 @@ func dialSingle(ctx *dialContext, ra Addr, deadline time.Time, cancel <-chan str // instead of just the interface with the given host address. // See Dial for more details about address syntax. func Listen(net, laddr string) (Listener, error) { - addrs, err := resolveAddrList("listen", net, laddr, nil, noDeadline) + addrs, err := resolveAddrList(context.Background(), "listen", net, laddr, nil) if err != nil { return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err} } @@ -496,7 +529,7 @@ func Listen(net, laddr string) (Listener, error) { // instead of just the interface with the given host address. // See Dial for the syntax of laddr. func ListenPacket(net, laddr string) (PacketConn, error) { - addrs, err := resolveAddrList("listen", net, laddr, nil, noDeadline) + addrs, err := resolveAddrList(context.Background(), "listen", net, laddr, nil) if err != nil { return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err} } |
