aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMichael Anthony Knyszek <mknyszek@google.com>2024-05-31 20:22:32 +0000
committerMichael Knyszek <mknyszek@google.com>2024-06-07 19:09:18 +0000
commit9d2aeae72d34880510c7221f35ab61171cec1ffd (patch)
treec2519cdfebe24e6f3b731df878e24c178c9cc633 /src
parent1471978bacf91bd9de273e6f265ed31def83711a (diff)
downloadgo-9d2aeae72d34880510c7221f35ab61171cec1ffd.tar.xz
iter: propagate runtime.Goexit from iterator passed to Pull
This change propagates a runtime.Goexit initiated by the iterator into the caller of next and/or stop. Fixes #67712. Change-Id: I5bb8d22f749fce39ce4f587148c5fc71aee2af65 Reviewed-on: https://go-review.googlesource.com/c/go/+/589137 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Austin Clements <austin@google.com> Reviewed-by: David Chase <drchase@google.com>
Diffstat (limited to 'src')
-rw-r--r--src/iter/iter.go55
-rw-r--r--src/iter/pull_test.go89
-rw-r--r--src/runtime/coro.go2
3 files changed, 133 insertions, 13 deletions
diff --git a/src/iter/iter.go b/src/iter/iter.go
index 3e93f3bdb7..2ce129bb49 100644
--- a/src/iter/iter.go
+++ b/src/iter/iter.go
@@ -8,6 +8,7 @@ package iter
import (
"internal/race"
+ "runtime"
"unsafe"
)
@@ -56,6 +57,7 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) {
yieldNext bool
racer int
panicValue any
+ seqDone bool // to detect Goexit
)
c := newcoro(func(c *coro) {
race.Acquire(unsafe.Pointer(&racer))
@@ -76,15 +78,17 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) {
// Recover and propagate panics from seq.
defer func() {
if p := recover(); p != nil {
- done = true // Invalidate iterator.
panicValue = p
+ } else if !seqDone {
+ panicValue = goexitPanicValue
}
+ done = true // Invalidate iterator
race.Release(unsafe.Pointer(&racer))
}()
seq(yield)
var v0 V
v, ok = v0, false
- done = true
+ seqDone = true
})
next = func() (v1 V, ok1 bool) {
race.Write(unsafe.Pointer(&racer)) // detect races
@@ -100,9 +104,14 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) {
coroswitch(c)
race.Acquire(unsafe.Pointer(&racer))
- // Propagate panics from seq.
+ // Propagate panics and goexits from seq.
if panicValue != nil {
- panic(panicValue)
+ if panicValue == goexitPanicValue {
+ // Propagate runtime.Goexit from seq.
+ runtime.Goexit()
+ } else {
+ panic(panicValue)
+ }
}
return v, ok
}
@@ -115,9 +124,14 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) {
coroswitch(c)
race.Acquire(unsafe.Pointer(&racer))
- // Propagate panics from seq.
+ // Propagate panics and goexits from seq.
if panicValue != nil {
- panic(panicValue)
+ if panicValue == goexitPanicValue {
+ // Propagate runtime.Goexit from seq.
+ runtime.Goexit()
+ } else {
+ panic(panicValue)
+ }
}
}
}
@@ -152,6 +166,7 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) {
yieldNext bool
racer int
panicValue any
+ seqDone bool
)
c := newcoro(func(c *coro) {
race.Acquire(unsafe.Pointer(&racer))
@@ -172,16 +187,18 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) {
// Recover and propagate panics from seq.
defer func() {
if p := recover(); p != nil {
- done = true // Invalidate iterator.
panicValue = p
+ } else if !seqDone {
+ panicValue = goexitPanicValue
}
+ done = true // Invalidate iterator.
race.Release(unsafe.Pointer(&racer))
}()
seq(yield)
var k0 K
var v0 V
k, v, ok = k0, v0, false
- done = true
+ seqDone = true
})
next = func() (k1 K, v1 V, ok1 bool) {
race.Write(unsafe.Pointer(&racer)) // detect races
@@ -197,9 +214,14 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) {
coroswitch(c)
race.Acquire(unsafe.Pointer(&racer))
- // Propagate panics from seq.
+ // Propagate panics and goexits from seq.
if panicValue != nil {
- panic(panicValue)
+ if panicValue == goexitPanicValue {
+ // Propagate runtime.Goexit from seq.
+ runtime.Goexit()
+ } else {
+ panic(panicValue)
+ }
}
return k, v, ok
}
@@ -212,11 +234,20 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) {
coroswitch(c)
race.Acquire(unsafe.Pointer(&racer))
- // Propagate panics from seq.
+ // Propagate panics and goexits from seq.
if panicValue != nil {
- panic(panicValue)
+ if panicValue == goexitPanicValue {
+ // Propagate runtime.Goexit from seq.
+ runtime.Goexit()
+ } else {
+ panic(panicValue)
+ }
}
}
}
return next, stop
}
+
+// goexitPanicValue is a sentinel value indicating that an iterator
+// exited via runtime.Goexit.
+var goexitPanicValue any = new(int)
diff --git a/src/iter/pull_test.go b/src/iter/pull_test.go
index 09f2270fa1..0d3f5ab26b 100644
--- a/src/iter/pull_test.go
+++ b/src/iter/pull_test.go
@@ -320,3 +320,92 @@ func panicsWith(v any, f func()) (panicked bool) {
f()
return
}
+
+func TestPullGoexit(t *testing.T) {
+ t.Run("next", func(t *testing.T) {
+ var next func() (int, bool)
+ var stop func()
+ if !goexits(t, func() {
+ next, stop = Pull(goexitSeq())
+ next()
+ }) {
+ t.Fatal("failed to Goexit from next")
+ }
+ if x, ok := next(); x != 0 || ok {
+ t.Fatal("iterator returned valid value after Goexit")
+ }
+ stop()
+ })
+ t.Run("stop", func(t *testing.T) {
+ var next func() (int, bool)
+ var stop func()
+ if !goexits(t, func() {
+ next, stop = Pull(goexitSeq())
+ stop()
+ }) {
+ t.Fatal("failed to Goexit from stop")
+ }
+ if x, ok := next(); x != 0 || ok {
+ t.Fatal("iterator returned valid value after Goexit")
+ }
+ stop()
+ })
+}
+
+func goexitSeq() Seq[int] {
+ return func(yield func(int) bool) {
+ runtime.Goexit()
+ }
+}
+
+func TestPull2Goexit(t *testing.T) {
+ t.Run("next", func(t *testing.T) {
+ var next func() (int, int, bool)
+ var stop func()
+ if !goexits(t, func() {
+ next, stop = Pull2(goexitSeq2())
+ next()
+ }) {
+ t.Fatal("failed to Goexit from next")
+ }
+ if x, y, ok := next(); x != 0 || y != 0 || ok {
+ t.Fatal("iterator returned valid value after Goexit")
+ }
+ stop()
+ })
+ t.Run("stop", func(t *testing.T) {
+ var next func() (int, int, bool)
+ var stop func()
+ if !goexits(t, func() {
+ next, stop = Pull2(goexitSeq2())
+ stop()
+ }) {
+ t.Fatal("failed to Goexit from stop")
+ }
+ if x, y, ok := next(); x != 0 || y != 0 || ok {
+ t.Fatal("iterator returned valid value after Goexit")
+ }
+ stop()
+ })
+}
+
+func goexitSeq2() Seq2[int, int] {
+ return func(yield func(int, int) bool) {
+ runtime.Goexit()
+ }
+}
+
+func goexits(t *testing.T, f func()) bool {
+ t.Helper()
+
+ exit := make(chan bool)
+ go func() {
+ cleanExit := false
+ defer func() {
+ exit <- recover() == nil && !cleanExit
+ }()
+ f()
+ cleanExit = true
+ }()
+ return <-exit
+}
diff --git a/src/runtime/coro.go b/src/runtime/coro.go
index 3d39d13493..30ada455e4 100644
--- a/src/runtime/coro.go
+++ b/src/runtime/coro.go
@@ -68,8 +68,8 @@ func corostart() {
c := gp.coroarg
gp.coroarg = nil
+ defer coroexit(c)
c.f(c)
- coroexit(c)
}
// coroexit is like coroswitch but closes the coro