diff options
| author | Michael Anthony Knyszek <mknyszek@google.com> | 2024-05-31 20:22:32 +0000 |
|---|---|---|
| committer | Michael Knyszek <mknyszek@google.com> | 2024-06-07 19:09:18 +0000 |
| commit | 9d2aeae72d34880510c7221f35ab61171cec1ffd (patch) | |
| tree | c2519cdfebe24e6f3b731df878e24c178c9cc633 /src | |
| parent | 1471978bacf91bd9de273e6f265ed31def83711a (diff) | |
| download | go-9d2aeae72d34880510c7221f35ab61171cec1ffd.tar.xz | |
iter: propagate runtime.Goexit from iterator passed to Pull
This change propagates a runtime.Goexit initiated by the iterator into
the caller of next and/or stop.
Fixes #67712.
Change-Id: I5bb8d22f749fce39ce4f587148c5fc71aee2af65
Reviewed-on: https://go-review.googlesource.com/c/go/+/589137
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Austin Clements <austin@google.com>
Reviewed-by: David Chase <drchase@google.com>
Diffstat (limited to 'src')
| -rw-r--r-- | src/iter/iter.go | 55 | ||||
| -rw-r--r-- | src/iter/pull_test.go | 89 | ||||
| -rw-r--r-- | src/runtime/coro.go | 2 |
3 files changed, 133 insertions, 13 deletions
diff --git a/src/iter/iter.go b/src/iter/iter.go index 3e93f3bdb7..2ce129bb49 100644 --- a/src/iter/iter.go +++ b/src/iter/iter.go @@ -8,6 +8,7 @@ package iter import ( "internal/race" + "runtime" "unsafe" ) @@ -56,6 +57,7 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) { yieldNext bool racer int panicValue any + seqDone bool // to detect Goexit ) c := newcoro(func(c *coro) { race.Acquire(unsafe.Pointer(&racer)) @@ -76,15 +78,17 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) { // Recover and propagate panics from seq. defer func() { if p := recover(); p != nil { - done = true // Invalidate iterator. panicValue = p + } else if !seqDone { + panicValue = goexitPanicValue } + done = true // Invalidate iterator race.Release(unsafe.Pointer(&racer)) }() seq(yield) var v0 V v, ok = v0, false - done = true + seqDone = true }) next = func() (v1 V, ok1 bool) { race.Write(unsafe.Pointer(&racer)) // detect races @@ -100,9 +104,14 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) { coroswitch(c) race.Acquire(unsafe.Pointer(&racer)) - // Propagate panics from seq. + // Propagate panics and goexits from seq. if panicValue != nil { - panic(panicValue) + if panicValue == goexitPanicValue { + // Propagate runtime.Goexit from seq. + runtime.Goexit() + } else { + panic(panicValue) + } } return v, ok } @@ -115,9 +124,14 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) { coroswitch(c) race.Acquire(unsafe.Pointer(&racer)) - // Propagate panics from seq. + // Propagate panics and goexits from seq. if panicValue != nil { - panic(panicValue) + if panicValue == goexitPanicValue { + // Propagate runtime.Goexit from seq. + runtime.Goexit() + } else { + panic(panicValue) + } } } } @@ -152,6 +166,7 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) { yieldNext bool racer int panicValue any + seqDone bool ) c := newcoro(func(c *coro) { race.Acquire(unsafe.Pointer(&racer)) @@ -172,16 +187,18 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) { // Recover and propagate panics from seq. defer func() { if p := recover(); p != nil { - done = true // Invalidate iterator. panicValue = p + } else if !seqDone { + panicValue = goexitPanicValue } + done = true // Invalidate iterator. race.Release(unsafe.Pointer(&racer)) }() seq(yield) var k0 K var v0 V k, v, ok = k0, v0, false - done = true + seqDone = true }) next = func() (k1 K, v1 V, ok1 bool) { race.Write(unsafe.Pointer(&racer)) // detect races @@ -197,9 +214,14 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) { coroswitch(c) race.Acquire(unsafe.Pointer(&racer)) - // Propagate panics from seq. + // Propagate panics and goexits from seq. if panicValue != nil { - panic(panicValue) + if panicValue == goexitPanicValue { + // Propagate runtime.Goexit from seq. + runtime.Goexit() + } else { + panic(panicValue) + } } return k, v, ok } @@ -212,11 +234,20 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) { coroswitch(c) race.Acquire(unsafe.Pointer(&racer)) - // Propagate panics from seq. + // Propagate panics and goexits from seq. if panicValue != nil { - panic(panicValue) + if panicValue == goexitPanicValue { + // Propagate runtime.Goexit from seq. + runtime.Goexit() + } else { + panic(panicValue) + } } } } return next, stop } + +// goexitPanicValue is a sentinel value indicating that an iterator +// exited via runtime.Goexit. +var goexitPanicValue any = new(int) diff --git a/src/iter/pull_test.go b/src/iter/pull_test.go index 09f2270fa1..0d3f5ab26b 100644 --- a/src/iter/pull_test.go +++ b/src/iter/pull_test.go @@ -320,3 +320,92 @@ func panicsWith(v any, f func()) (panicked bool) { f() return } + +func TestPullGoexit(t *testing.T) { + t.Run("next", func(t *testing.T) { + var next func() (int, bool) + var stop func() + if !goexits(t, func() { + next, stop = Pull(goexitSeq()) + next() + }) { + t.Fatal("failed to Goexit from next") + } + if x, ok := next(); x != 0 || ok { + t.Fatal("iterator returned valid value after Goexit") + } + stop() + }) + t.Run("stop", func(t *testing.T) { + var next func() (int, bool) + var stop func() + if !goexits(t, func() { + next, stop = Pull(goexitSeq()) + stop() + }) { + t.Fatal("failed to Goexit from stop") + } + if x, ok := next(); x != 0 || ok { + t.Fatal("iterator returned valid value after Goexit") + } + stop() + }) +} + +func goexitSeq() Seq[int] { + return func(yield func(int) bool) { + runtime.Goexit() + } +} + +func TestPull2Goexit(t *testing.T) { + t.Run("next", func(t *testing.T) { + var next func() (int, int, bool) + var stop func() + if !goexits(t, func() { + next, stop = Pull2(goexitSeq2()) + next() + }) { + t.Fatal("failed to Goexit from next") + } + if x, y, ok := next(); x != 0 || y != 0 || ok { + t.Fatal("iterator returned valid value after Goexit") + } + stop() + }) + t.Run("stop", func(t *testing.T) { + var next func() (int, int, bool) + var stop func() + if !goexits(t, func() { + next, stop = Pull2(goexitSeq2()) + stop() + }) { + t.Fatal("failed to Goexit from stop") + } + if x, y, ok := next(); x != 0 || y != 0 || ok { + t.Fatal("iterator returned valid value after Goexit") + } + stop() + }) +} + +func goexitSeq2() Seq2[int, int] { + return func(yield func(int, int) bool) { + runtime.Goexit() + } +} + +func goexits(t *testing.T, f func()) bool { + t.Helper() + + exit := make(chan bool) + go func() { + cleanExit := false + defer func() { + exit <- recover() == nil && !cleanExit + }() + f() + cleanExit = true + }() + return <-exit +} diff --git a/src/runtime/coro.go b/src/runtime/coro.go index 3d39d13493..30ada455e4 100644 --- a/src/runtime/coro.go +++ b/src/runtime/coro.go @@ -68,8 +68,8 @@ func corostart() { c := gp.coroarg gp.coroarg = nil + defer coroexit(c) c.f(c) - coroexit(c) } // coroexit is like coroswitch but closes the coro |
