aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCuong Manh Le <cuong.manhle.vn@gmail.com>2024-01-02 22:28:49 +0700
committerGopher Robot <gobot@golang.org>2024-01-08 16:00:53 +0000
commit881869dde0ddddf37151137421cd53d0c537671e (patch)
treeb1f853e5a7b591ecb989efe1b66888a4b9552ad5
parent1ae729e6d34040a84da8ef2fc0b9781efe9b0d95 (diff)
downloadgo-881869dde0ddddf37151137421cd53d0c537671e.tar.xz
cmd/compile: handle defined iter func type correctly
Fixed #64930 Change-Id: I916de7f97116fb20cb2f3f0b425ac34409afd494 Reviewed-on: https://go-review.googlesource.com/c/go/+/553436 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Dmitri Shuralyov <dmitshur@google.com> Reviewed-by: Matthew Dempsky <mdempsky@google.com> Auto-Submit: Cuong Manh Le <cuong.manhle.vn@gmail.com>
-rw-r--r--src/cmd/compile/internal/rangefunc/rewrite.go2
-rw-r--r--test/range4.go25
2 files changed, 26 insertions, 1 deletions
diff --git a/src/cmd/compile/internal/rangefunc/rewrite.go b/src/cmd/compile/internal/rangefunc/rewrite.go
index 7475c570aa..d439412ea8 100644
--- a/src/cmd/compile/internal/rangefunc/rewrite.go
+++ b/src/cmd/compile/internal/rangefunc/rewrite.go
@@ -934,7 +934,7 @@ func (r *rewriter) endLoop(loop *forLoop) {
if rfunc.Params().Len() != 1 {
base.Fatalf("invalid typecheck of range func")
}
- ftyp := rfunc.Params().At(0).Type().(*types2.Signature) // func(...) bool
+ ftyp := types2.CoreType(rfunc.Params().At(0).Type()).(*types2.Signature) // func(...) bool
if ftyp.Results().Len() != 1 {
base.Fatalf("invalid typecheck of range func")
}
diff --git a/test/range4.go b/test/range4.go
index 696b205ab7..0b051f6d3c 100644
--- a/test/range4.go
+++ b/test/range4.go
@@ -311,6 +311,30 @@ func testcalls() {
}
}
+type iter3YieldFunc func(int, int) bool
+
+func iter3(list ...int) func(iter3YieldFunc) {
+ return func(yield iter3YieldFunc) {
+ for k, v := range list {
+ if !yield(k, v) {
+ return
+ }
+ }
+ }
+}
+
+func testcalls1() {
+ ncalls := 0
+ for k, v := range iter3(1, 2, 3) {
+ _, _ = k, v
+ ncalls++
+ }
+ if ncalls != 3 {
+ println("wrong number of calls:", ncalls, "!= 3")
+ panic("fail")
+ }
+}
+
func main() {
testfunc0()
testfunc1()
@@ -323,4 +347,5 @@ func main() {
testfunc8()
testfunc9()
testcalls()
+ testcalls1()
}