aboutsummaryrefslogtreecommitdiff
path: root/src/database/sql
diff options
context:
space:
mode:
Diffstat (limited to 'src/database/sql')
-rw-r--r--src/database/sql/fakedb_test.go13
-rw-r--r--src/database/sql/sql.go72
-rw-r--r--src/database/sql/sql_test.go58
3 files changed, 141 insertions, 2 deletions
diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go
index 2fe5ea42da..cfeb3b3437 100644
--- a/src/database/sql/fakedb_test.go
+++ b/src/database/sql/fakedb_test.go
@@ -15,6 +15,7 @@ import (
"strconv"
"strings"
"sync"
+ "sync/atomic"
"testing"
"time"
)
@@ -90,6 +91,8 @@ func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) {
type fakeDB struct {
name string
+ useRawBytes atomic.Bool
+
mu sync.Mutex
tables map[string]*table
badConn bool
@@ -697,6 +700,8 @@ func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stm
switch cmd {
case "WIPE":
// Nothing
+ case "USE_RAWBYTES":
+ c.db.useRawBytes.Store(true)
case "SELECT":
stmt, err = c.prepareSelect(stmt, parts)
case "CREATE":
@@ -800,6 +805,9 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d
case "WIPE":
db.wipe()
return driver.ResultNoRows, nil
+ case "USE_RAWBYTES":
+ s.c.db.useRawBytes.Store(true)
+ return driver.ResultNoRows, nil
case "CREATE":
if err := db.createTable(s.table, s.colName, s.colType); err != nil {
return nil, err
@@ -929,6 +937,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
txStatus = "transaction"
}
cursor := &rowsCursor{
+ db: s.c.db,
parentMem: s.c,
posRow: -1,
rows: [][]*row{
@@ -1025,6 +1034,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
}
cursor := &rowsCursor{
+ db: s.c.db,
parentMem: s.c,
posRow: -1,
rows: setMRows,
@@ -1067,6 +1077,7 @@ func (tx *fakeTx) Rollback() error {
}
type rowsCursor struct {
+ db *fakeDB
parentMem memToucher
cols [][]string
colType [][]string
@@ -1141,7 +1152,7 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
// messing up conversions or doing them differently.
dest[i] = v
- if bs, ok := v.([]byte); ok {
+ if bs, ok := v.([]byte); ok && !rc.db.useRawBytes.Load() {
if rc.bytesClone == nil {
rc.bytesClone = make(map[*byte][]byte)
}
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go
index 8dd48107a6..3db387e841 100644
--- a/src/database/sql/sql.go
+++ b/src/database/sql/sql.go
@@ -2893,6 +2893,8 @@ type Rows struct {
cancel func() // called when Rows is closed, may be nil.
closeStmt *driverStmt // if non-nil, statement to Close on close
+ contextDone atomic.Pointer[error] // error that awaitDone saw; set before close attempt
+
// closemu prevents Rows from closing while there
// is an active streaming result. It is held for read during non-close operations
// and exclusively during close.
@@ -2905,6 +2907,15 @@ type Rows struct {
// lastcols is only used in Scan, Next, and NextResultSet which are expected
// not to be called concurrently.
lastcols []driver.Value
+
+ // closemuScanHold is whether the previous call to Scan kept closemu RLock'ed
+ // without unlocking it. It does that when the user passes a *RawBytes scan
+ // target. In that case, we need to prevent awaitDone from closing the Rows
+ // while the user's still using the memory. See go.dev/issue/60304.
+ //
+ // It is only used by Scan, Next, and NextResultSet which are expected
+ // not to be called concurrently.
+ closemuScanHold bool
}
// lasterrOrErrLocked returns either lasterr or the provided err.
@@ -2942,7 +2953,11 @@ func (rs *Rows) awaitDone(ctx, txctx context.Context) {
}
select {
case <-ctx.Done():
+ err := ctx.Err()
+ rs.contextDone.Store(&err)
case <-txctxDone:
+ err := txctx.Err()
+ rs.contextDone.Store(&err)
}
rs.close(ctx.Err())
}
@@ -2954,6 +2969,15 @@ func (rs *Rows) awaitDone(ctx, txctx context.Context) {
//
// Every call to Scan, even the first one, must be preceded by a call to Next.
func (rs *Rows) Next() bool {
+ // If the user's calling Next, they're done with their previous row's Scan
+ // results (any RawBytes memory), so we can release the read lock that would
+ // be preventing awaitDone from calling close.
+ rs.closemuRUnlockIfHeldByScan()
+
+ if rs.contextDone.Load() != nil {
+ return false
+ }
+
var doClose, ok bool
withLock(rs.closemu.RLocker(), func() {
doClose, ok = rs.nextLocked()
@@ -3008,6 +3032,11 @@ func (rs *Rows) nextLocked() (doClose, ok bool) {
// scanning. If there are further result sets they may not have rows in the result
// set.
func (rs *Rows) NextResultSet() bool {
+ // If the user's calling NextResultSet, they're done with their previous
+ // row's Scan results (any RawBytes memory), so we can release the read lock
+ // that would be preventing awaitDone from calling close.
+ rs.closemuRUnlockIfHeldByScan()
+
var doClose bool
defer func() {
if doClose {
@@ -3044,6 +3073,10 @@ func (rs *Rows) NextResultSet() bool {
// Err returns the error, if any, that was encountered during iteration.
// Err may be called after an explicit or implicit Close.
func (rs *Rows) Err() error {
+ if errp := rs.contextDone.Load(); errp != nil {
+ return *errp
+ }
+
rs.closemu.RLock()
defer rs.closemu.RUnlock()
return rs.lasterrOrErrLocked(nil)
@@ -3237,6 +3270,11 @@ func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType {
// If any of the first arguments implementing Scanner returns an error,
// that error will be wrapped in the returned error.
func (rs *Rows) Scan(dest ...any) error {
+ if rs.closemuScanHold {
+ // This should only be possible if the user calls Scan twice in a row
+ // without calling Next.
+ return fmt.Errorf("sql: Scan called without calling Next (closemuScanHold)")
+ }
rs.closemu.RLock()
if rs.lasterr != nil && rs.lasterr != io.EOF {
@@ -3248,23 +3286,50 @@ func (rs *Rows) Scan(dest ...any) error {
rs.closemu.RUnlock()
return err
}
- rs.closemu.RUnlock()
+
+ if scanArgsContainRawBytes(dest) {
+ rs.closemuScanHold = true
+ } else {
+ rs.closemu.RUnlock()
+ }
if rs.lastcols == nil {
+ rs.closemuRUnlockIfHeldByScan()
return errors.New("sql: Scan called without calling Next")
}
if len(dest) != len(rs.lastcols) {
+ rs.closemuRUnlockIfHeldByScan()
return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest))
}
+
for i, sv := range rs.lastcols {
err := convertAssignRows(dest[i], sv, rs)
if err != nil {
+ rs.closemuRUnlockIfHeldByScan()
return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err)
}
}
return nil
}
+// closemuRUnlockIfHeldByScan releases any closemu.RLock held open by a previous
+// call to Scan with *RawBytes.
+func (rs *Rows) closemuRUnlockIfHeldByScan() {
+ if rs.closemuScanHold {
+ rs.closemuScanHold = false
+ rs.closemu.RUnlock()
+ }
+}
+
+func scanArgsContainRawBytes(args []any) bool {
+ for _, a := range args {
+ if _, ok := a.(*RawBytes); ok {
+ return true
+ }
+ }
+ return false
+}
+
// rowsCloseHook returns a function so tests may install the
// hook through a test only mutex.
var rowsCloseHook = func() func(*Rows, *error) { return nil }
@@ -3274,6 +3339,11 @@ var rowsCloseHook = func() func(*Rows, *error) { return nil }
// the Rows are closed automatically and it will suffice to check the
// result of Err. Close is idempotent and does not affect the result of Err.
func (rs *Rows) Close() error {
+ // If the user's calling Close, they're done with their previous row's Scan
+ // results (any RawBytes memory), so we can release the read lock that would
+ // be preventing awaitDone from calling the unexported close before we do so.
+ rs.closemuRUnlockIfHeldByScan()
+
return rs.close(nil)
}
diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go
index 2b3d76f513..29a6709f23 100644
--- a/src/database/sql/sql_test.go
+++ b/src/database/sql/sql_test.go
@@ -4385,6 +4385,64 @@ func TestRowsScanProperlyWrapsErrors(t *testing.T) {
}
}
+// From go.dev/issue/60304
+func TestContextCancelDuringRawBytesScan(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ if _, err := db.Exec("USE_RAWBYTES"); err != nil {
+ t.Fatal(err)
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ r, err := db.QueryContext(ctx, "SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ numRows := 0
+ var sink byte
+ for r.Next() {
+ numRows++
+ var s RawBytes
+ err = r.Scan(&s)
+ if !r.closemuScanHold {
+ t.Errorf("expected closemu to be held")
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Logf("read %q", s)
+ if numRows == 2 {
+ cancel() // invalidate the context, which used to call close asynchronously
+ }
+ for _, b := range s { // some operation reading from the raw memory
+ sink += b
+ }
+ }
+ if r.closemuScanHold {
+ t.Errorf("closemu held; should not be")
+ }
+
+ // There are 3 rows. We canceled after reading 2 so we expect either
+ // 2 or 3 depending on how the awaitDone goroutine schedules.
+ switch numRows {
+ case 0, 1:
+ t.Errorf("got %d rows; want 2+", numRows)
+ case 2:
+ if err := r.Err(); err != context.Canceled {
+ t.Errorf("unexpected error: %v (%T)", err, err)
+ }
+ default:
+ // Made it to the end. This is rare, but fine. Permit it.
+ }
+
+ if err := r.Close(); err != nil {
+ t.Fatal(err)
+ }
+}
+
// badConn implements a bad driver.Conn, for TestBadDriver.
// The Exec method panics.
type badConn struct{}