diff options
| author | Mateusz Poliwczak <mpoliwczak34@gmail.com> | 2024-02-21 17:18:43 +0000 |
|---|---|---|
| committer | Gopher Robot <gobot@golang.org> | 2024-02-21 18:47:20 +0000 |
| commit | c07b9b00361b9a99fc64ffe36f897d24954f99cf (patch) | |
| tree | 694ee8150c33fa3acafc85b6c3071395f81383d2 /src | |
| parent | cd17032721425f1e131f3c19801c6e66972d4b6c (diff) | |
| download | go-c07b9b00361b9a99fc64ffe36f897d24954f99cf.tar.xz | |
net: support context cancellation in acquireThread
acquireThread is already waiting on a channel, so
it can be easily wired up to support context cancellation.
This change will make sure that contexts that are
cancelled at the acquireThread stage (when the limit of
threads is reached) do not queue unnecessarily and cause
an unnecessary cgo call that will be soon aborted by
the doBlockingWithCtx function.
Updates #63978
Change-Id: I8ae4debd51995637567d8f51c6f1ed60f23d6c0c
GitHub-Last-Rev: 4189b9faf07c073a2ca440becee07b6aa9c4e795
GitHub-Pull-Request: golang/go#63985
Reviewed-on: https://go-review.googlesource.com/c/go/+/539360
Auto-Submit: Ian Lance Taylor <iant@golang.org>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
Commit-Queue: Ian Lance Taylor <iant@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Diffstat (limited to 'src')
| -rw-r--r-- | src/net/cgo_unix.go | 38 | ||||
| -rw-r--r-- | src/net/lookup_windows.go | 83 | ||||
| -rw-r--r-- | src/net/net.go | 9 |
3 files changed, 97 insertions, 33 deletions
diff --git a/src/net/cgo_unix.go b/src/net/cgo_unix.go index 9879315019..82ec4441fc 100644 --- a/src/net/cgo_unix.go +++ b/src/net/cgo_unix.go @@ -40,8 +40,20 @@ func (eai addrinfoErrno) isAddrinfoErrno() {} // doBlockingWithCtx executes a blocking function in a separate goroutine when the provided // context is cancellable. It is intended for use with calls that don't support context // cancellation (cgo, syscalls). blocking func may still be running after this function finishes. -func doBlockingWithCtx[T any](ctx context.Context, blocking func() (T, error)) (T, error) { +// For the duration of the execution of the blocking function, the thread is 'acquired' using [acquireThread], +// blocking might not be executed when the context gets cancelled early. +func doBlockingWithCtx[T any](ctx context.Context, lookupName string, blocking func() (T, error)) (T, error) { + if err := acquireThread(ctx); err != nil { + var zero T + return zero, &DNSError{ + Name: lookupName, + Err: mapErr(err).Error(), + IsTimeout: err == context.DeadlineExceeded, + } + } + if ctx.Done() == nil { + defer releaseThread() return blocking() } @@ -52,6 +64,7 @@ func doBlockingWithCtx[T any](ctx context.Context, blocking func() (T, error)) ( res := make(chan result, 1) go func() { + defer releaseThread() var r result r.res, r.err = blocking() res <- r @@ -62,7 +75,11 @@ func doBlockingWithCtx[T any](ctx context.Context, blocking func() (T, error)) ( return r.res, r.err case <-ctx.Done(): var zero T - return zero, mapErr(ctx.Err()) + return zero, &DNSError{ + Name: lookupName, + Err: mapErr(ctx.Err()).Error(), + IsTimeout: ctx.Err() == context.DeadlineExceeded, + } } } @@ -97,7 +114,7 @@ func cgoLookupPort(ctx context.Context, network, service string) (port int, err *_C_ai_family(&hints) = _C_AF_INET6 } - return doBlockingWithCtx(ctx, func() (int, error) { + return doBlockingWithCtx(ctx, network+"/"+service, func() (int, error) { return cgoLookupServicePort(&hints, network, service) }) } @@ -146,9 +163,6 @@ func cgoLookupServicePort(hints *_C_struct_addrinfo, network, service string) (p } func cgoLookupHostIP(network, name string) (addrs []IPAddr, err error) { - acquireThread() - defer releaseThread() - var hints _C_struct_addrinfo *_C_ai_flags(&hints) = cgoAddrInfoFlags *_C_ai_socktype(&hints) = _C_SOCK_STREAM @@ -213,7 +227,7 @@ func cgoLookupHostIP(network, name string) (addrs []IPAddr, err error) { } func cgoLookupIP(ctx context.Context, network, name string) (addrs []IPAddr, err error) { - return doBlockingWithCtx(ctx, func() ([]IPAddr, error) { + return doBlockingWithCtx(ctx, name, func() ([]IPAddr, error) { return cgoLookupHostIP(network, name) }) } @@ -241,15 +255,12 @@ func cgoLookupPTR(ctx context.Context, addr string) (names []string, err error) return nil, &DNSError{Err: "invalid address " + ip.String(), Name: addr} } - return doBlockingWithCtx(ctx, func() ([]string, error) { + return doBlockingWithCtx(ctx, addr, func() ([]string, error) { return cgoLookupAddrPTR(addr, sa, salen) }) } func cgoLookupAddrPTR(addr string, sa *_C_struct_sockaddr, salen _C_socklen_t) (names []string, err error) { - acquireThread() - defer releaseThread() - var gerrno int var b []byte for l := nameinfoLen; l <= maxNameinfoLen; l *= 2 { @@ -310,15 +321,12 @@ func cgoLookupCNAME(ctx context.Context, name string) (cname string, err error, // resSearch will make a call to the 'res_nsearch' routine in the C library // and parse the output as a slice of DNS resources. func resSearch(ctx context.Context, hostname string, rtype, class int) ([]dnsmessage.Resource, error) { - return doBlockingWithCtx(ctx, func() ([]dnsmessage.Resource, error) { + return doBlockingWithCtx(ctx, hostname, func() ([]dnsmessage.Resource, error) { return cgoResSearch(hostname, rtype, class) }) } func cgoResSearch(hostname string, rtype, class int) ([]dnsmessage.Resource, error) { - acquireThread() - defer releaseThread() - resStateSize := unsafe.Sizeof(_C_struct___res_state{}) var state *_C_struct___res_state if resStateSize > 0 { diff --git a/src/net/lookup_windows.go b/src/net/lookup_windows.go index 3048f3269b..946622761c 100644 --- a/src/net/lookup_windows.go +++ b/src/net/lookup_windows.go @@ -54,7 +54,10 @@ func lookupProtocol(ctx context.Context, name string) (int, error) { } ch := make(chan result) // unbuffered go func() { - acquireThread() + if err := acquireThread(ctx); err != nil { + ch <- result{err: mapErr(err)} + return + } defer releaseThread() runtime.LockOSThread() defer runtime.UnlockOSThread() @@ -111,7 +114,13 @@ func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr } getaddr := func() ([]IPAddr, error) { - acquireThread() + if err := acquireThread(ctx); err != nil { + return nil, &DNSError{ + Name: name, + Err: mapErr(err).Error(), + IsTimeout: ctx.Err() == context.DeadlineExceeded, + } + } defer releaseThread() hints := syscall.AddrinfoW{ Family: family, @@ -200,8 +209,14 @@ func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int return lookupPortMap(network, service) } - // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. - acquireThread() + // TODO(bradfitz): finish ctx plumbing + if err := acquireThread(ctx); err != nil { + return 0, &DNSError{ + Name: network + "/" + service, + Err: mapErr(err).Error(), + IsTimeout: ctx.Err() == context.DeadlineExceeded, + } + } defer releaseThread() var hints syscall.AddrinfoW @@ -263,8 +278,14 @@ func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error) return r.goLookupCNAME(ctx, name, order, conf) } - // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. - acquireThread() + // TODO(bradfitz): finish ctx plumbing + if err := acquireThread(ctx); err != nil { + return "", &DNSError{ + Name: name, + Err: mapErr(err).Error(), + IsTimeout: ctx.Err() == context.DeadlineExceeded, + } + } defer releaseThread() var rec *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &rec, nil) @@ -288,8 +309,14 @@ func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) ( if systemConf().mustUseGoResolver(r) { return r.goLookupSRV(ctx, service, proto, name) } - // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. - acquireThread() + // TODO(bradfitz): finish ctx plumbing + if err := acquireThread(ctx); err != nil { + return "", nil, &DNSError{ + Name: name, + Err: mapErr(err).Error(), + IsTimeout: ctx.Err() == context.DeadlineExceeded, + } + } defer releaseThread() var target string if service == "" && proto == "" { @@ -318,8 +345,14 @@ func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) { if systemConf().mustUseGoResolver(r) { return r.goLookupMX(ctx, name) } - // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. - acquireThread() + // TODO(bradfitz): finish ctx plumbing. + if err := acquireThread(ctx); err != nil { + return nil, &DNSError{ + Name: name, + Err: mapErr(err).Error(), + IsTimeout: ctx.Err() == context.DeadlineExceeded, + } + } defer releaseThread() var rec *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &rec, nil) @@ -342,8 +375,14 @@ func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) { if systemConf().mustUseGoResolver(r) { return r.goLookupNS(ctx, name) } - // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. - acquireThread() + // TODO(bradfitz): finish ctx plumbing. + if err := acquireThread(ctx); err != nil { + return nil, &DNSError{ + Name: name, + Err: mapErr(err).Error(), + IsTimeout: ctx.Err() == context.DeadlineExceeded, + } + } defer releaseThread() var rec *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &rec, nil) @@ -365,8 +404,14 @@ func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) if systemConf().mustUseGoResolver(r) { return r.goLookupTXT(ctx, name) } - // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. - acquireThread() + // TODO(bradfitz): finish ctx plumbing. + if err := acquireThread(ctx); err != nil { + return nil, &DNSError{ + Name: name, + Err: mapErr(err).Error(), + IsTimeout: ctx.Err() == context.DeadlineExceeded, + } + } defer releaseThread() var rec *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &rec, nil) @@ -393,8 +438,14 @@ func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error return r.goLookupPTR(ctx, addr, order, conf) } - // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. - acquireThread() + // TODO(bradfitz): finish ctx plumbing. + if err := acquireThread(ctx); err != nil { + return nil, &DNSError{ + Name: addr, + Err: mapErr(err).Error(), + IsTimeout: ctx.Err() == context.DeadlineExceeded, + } + } defer releaseThread() arpa, err := reverseaddr(addr) if err != nil { diff --git a/src/net/net.go b/src/net/net.go index 387f2bb14d..b5f7303db3 100644 --- a/src/net/net.go +++ b/src/net/net.go @@ -727,11 +727,16 @@ var threadLimit chan struct{} var threadOnce sync.Once -func acquireThread() { +func acquireThread(ctx context.Context) error { threadOnce.Do(func() { threadLimit = make(chan struct{}, concurrentThreadsLimit()) }) - threadLimit <- struct{}{} + select { + case threadLimit <- struct{}{}: + return nil + case <-ctx.Done(): + return ctx.Err() + } } func releaseThread() { |
