diff options
Diffstat (limited to 'src/database/sql')
| -rw-r--r-- | src/database/sql/sql.go | 19 | ||||
| -rw-r--r-- | src/database/sql/sql_test.go | 19 |
2 files changed, 36 insertions, 2 deletions
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 3db387e841..a77d63dc5e 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -2916,6 +2916,12 @@ type Rows struct { // It is only used by Scan, Next, and NextResultSet which are expected // not to be called concurrently. closemuScanHold bool + + // hitEOF is whether Next hit the end of the rows without + // encountering an error. It's set in Next before + // returning. It's only used by Next and Err which are + // expected not to be called concurrently. + hitEOF bool } // lasterrOrErrLocked returns either lasterr or the provided err. @@ -2985,6 +2991,9 @@ func (rs *Rows) Next() bool { if doClose { rs.Close() } + if doClose && !ok { + rs.hitEOF = true + } return ok } @@ -3073,8 +3082,14 @@ func (rs *Rows) NextResultSet() bool { // Err returns the error, if any, that was encountered during iteration. // Err may be called after an explicit or implicit Close. func (rs *Rows) Err() error { - if errp := rs.contextDone.Load(); errp != nil { - return *errp + // Return any context error that might've happened during row iteration, + // but only if we haven't reported the final Next() = false after rows + // are done, in which case the user might've canceled their own context + // before calling Rows.Err. + if !rs.hitEOF { + if errp := rs.contextDone.Load(); errp != nil { + return *errp + } } rs.closemu.RLock() diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 29a6709f23..4f2a2d83ef 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -4443,6 +4443,25 @@ func TestContextCancelDuringRawBytesScan(t *testing.T) { } } +func TestContextCancelBetweenNextAndErr(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + r, err := db.QueryContext(ctx, "SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + for r.Next() { + } + cancel() // wake up the awaitDone goroutine + time.Sleep(10 * time.Millisecond) // increase odds of seeing failure + if err := r.Err(); err != nil { + t.Fatal(err) + } +} + // badConn implements a bad driver.Conn, for TestBadDriver. // The Exec method panics. type badConn struct{} |
