aboutsummaryrefslogtreecommitdiff
path: root/src/database/sql
diff options
context:
space:
mode:
Diffstat (limited to 'src/database/sql')
-rw-r--r--src/database/sql/sql.go19
-rw-r--r--src/database/sql/sql_test.go19
2 files changed, 36 insertions, 2 deletions
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go
index 3db387e841..a77d63dc5e 100644
--- a/src/database/sql/sql.go
+++ b/src/database/sql/sql.go
@@ -2916,6 +2916,12 @@ type Rows struct {
// It is only used by Scan, Next, and NextResultSet which are expected
// not to be called concurrently.
closemuScanHold bool
+
+ // hitEOF is whether Next hit the end of the rows without
+ // encountering an error. It's set in Next before
+ // returning. It's only used by Next and Err which are
+ // expected not to be called concurrently.
+ hitEOF bool
}
// lasterrOrErrLocked returns either lasterr or the provided err.
@@ -2985,6 +2991,9 @@ func (rs *Rows) Next() bool {
if doClose {
rs.Close()
}
+ if doClose && !ok {
+ rs.hitEOF = true
+ }
return ok
}
@@ -3073,8 +3082,14 @@ func (rs *Rows) NextResultSet() bool {
// Err returns the error, if any, that was encountered during iteration.
// Err may be called after an explicit or implicit Close.
func (rs *Rows) Err() error {
- if errp := rs.contextDone.Load(); errp != nil {
- return *errp
+ // Return any context error that might've happened during row iteration,
+ // but only if we haven't reported the final Next() = false after rows
+ // are done, in which case the user might've canceled their own context
+ // before calling Rows.Err.
+ if !rs.hitEOF {
+ if errp := rs.contextDone.Load(); errp != nil {
+ return *errp
+ }
}
rs.closemu.RLock()
diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go
index 29a6709f23..4f2a2d83ef 100644
--- a/src/database/sql/sql_test.go
+++ b/src/database/sql/sql_test.go
@@ -4443,6 +4443,25 @@ func TestContextCancelDuringRawBytesScan(t *testing.T) {
}
}
+func TestContextCancelBetweenNextAndErr(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ r, err := db.QueryContext(ctx, "SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ for r.Next() {
+ }
+ cancel() // wake up the awaitDone goroutine
+ time.Sleep(10 * time.Millisecond) // increase odds of seeing failure
+ if err := r.Err(); err != nil {
+ t.Fatal(err)
+ }
+}
+
// badConn implements a bad driver.Conn, for TestBadDriver.
// The Exec method panics.
type badConn struct{}