aboutsummaryrefslogtreecommitdiff
path: root/src/database/sql
diff options
context:
space:
mode:
Diffstat (limited to 'src/database/sql')
-rw-r--r--src/database/sql/convert.go15
-rw-r--r--src/database/sql/convert_test.go2
-rw-r--r--src/database/sql/fakedb_test.go1
-rw-r--r--src/database/sql/sql.go70
-rw-r--r--src/database/sql/sql_test.go6
5 files changed, 55 insertions, 39 deletions
diff --git a/src/database/sql/convert.go b/src/database/sql/convert.go
index c349a96edf..b44bed559d 100644
--- a/src/database/sql/convert.go
+++ b/src/database/sql/convert.go
@@ -12,7 +12,6 @@ import (
"fmt"
"reflect"
"strconv"
- "sync"
"time"
"unicode"
"unicode/utf8"
@@ -38,17 +37,10 @@ func validateNamedValueName(name string) error {
return fmt.Errorf("name %q does not begin with a letter", name)
}
-func driverNumInput(ds *driverStmt) int {
- ds.Lock()
- defer ds.Unlock() // in case NumInput panics
- return ds.si.NumInput()
-}
-
// ccChecker wraps the driver.ColumnConverter and allows it to be used
// as if it were a NamedValueChecker. If the driver ColumnConverter
// is not present then the NamedValueChecker will return driver.ErrSkip.
type ccChecker struct {
- sync.Locker
cci driver.ColumnConverter
want int
}
@@ -88,9 +80,7 @@ func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error {
// same error.
var err error
arg := nv.Value
- c.Lock()
nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg)
- c.Unlock()
if err != nil {
return err
}
@@ -112,7 +102,7 @@ func defaultCheckNamedValue(nv *driver.NamedValue) (err error) {
// Stmt.Query into driver Values.
//
// The statement ds may be nil, if no statement is available.
-func driverArgs(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) {
+func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) {
nvargs := make([]driver.NamedValue, len(args))
// -1 means the driver doesn't know how to count the number of
@@ -124,8 +114,7 @@ func driverArgs(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.Na
var cc ccChecker
if ds != nil {
si = ds.si
- want = driverNumInput(ds)
- cc.Locker = ds.Locker
+ want = ds.si.NumInput()
cc.want = want
}
diff --git a/src/database/sql/convert_test.go b/src/database/sql/convert_test.go
index 35dbab3339..c198177760 100644
--- a/src/database/sql/convert_test.go
+++ b/src/database/sql/convert_test.go
@@ -481,7 +481,7 @@ func TestDriverArgs(t *testing.T) {
}
for i, tt := range tests {
ds := &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{nil}}
- got, err := driverArgs(nil, ds, tt.args)
+ got, err := driverArgsConnLocked(nil, ds, tt.args)
if err != nil {
t.Errorf("test[%d]: %v", i, err)
continue
diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go
index 070b783453..31e22a7a74 100644
--- a/src/database/sql/fakedb_test.go
+++ b/src/database/sql/fakedb_test.go
@@ -1005,6 +1005,7 @@ type rowsCursor struct {
}
func (rc *rowsCursor) touchMem() {
+ rc.parentMem.touchMem()
rc.line++
}
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go
index c17b2b543b..be73b5e372 100644
--- a/src/database/sql/sql.go
+++ b/src/database/sql/sql.go
@@ -1368,12 +1368,12 @@ func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), q
}
if ok {
var nvdargs []driver.NamedValue
- nvdargs, err = driverArgs(dc.ci, nil, args)
- if err != nil {
- return nil, err
- }
var resi driver.Result
withLock(dc, func() {
+ nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
+ if err != nil {
+ return
+ }
resi, err = ctxDriverExec(ctx, execerCtx, execer, query, nvdargs)
})
if err != driver.ErrSkip {
@@ -1439,13 +1439,14 @@ func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn fu
queryer, ok = dc.ci.(driver.Queryer)
}
if ok {
- nvdargs, err := driverArgs(dc.ci, nil, args)
- if err != nil {
- releaseConn(err)
- return nil, err
- }
+ var nvdargs []driver.NamedValue
var rowsi driver.Rows
+ var err error
withLock(dc, func() {
+ nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
+ if err != nil {
+ return
+ }
rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs)
})
if err != driver.ErrSkip {
@@ -2034,11 +2035,14 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
stmt.mu.Unlock()
if si == nil {
- cs, err := stmt.prepareOnConnLocked(ctx, dc)
+ withLock(dc, func() {
+ var ds *driverStmt
+ ds, err = stmt.prepareOnConnLocked(ctx, dc)
+ si = ds.si
+ })
if err != nil {
return &Stmt{stickyErr: err}
}
- si = cs.si
}
parentStmt = stmt
}
@@ -2230,14 +2234,14 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
}
func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (Result, error) {
- dargs, err := driverArgs(ci, ds, args)
+ ds.Lock()
+ defer ds.Unlock()
+
+ dargs, err := driverArgsConnLocked(ci, ds, args)
if err != nil {
return nil, err
}
- ds.Lock()
- defer ds.Unlock()
-
resi, err := ctxDriverStmtExec(ctx, ds.si, dargs)
if err != nil {
return nil, err
@@ -2401,10 +2405,10 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
}
func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (driver.Rows, error) {
- var want int
- withLock(ds, func() {
- want = ds.si.NumInput()
- })
+ ds.Lock()
+ defer ds.Unlock()
+
+ want := ds.si.NumInput()
// -1 means the driver doesn't know how to count the number of
// placeholders, so we won't sanity check input here and instead let the
@@ -2413,14 +2417,11 @@ func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, arg
return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args))
}
- dargs, err := driverArgs(ci, ds, args)
+ dargs, err := driverArgsConnLocked(ci, ds, args)
if err != nil {
return nil, err
}
- ds.Lock()
- defer ds.Unlock()
-
rowsi, err := ctxDriverStmtQuery(ctx, ds.si, dargs)
if err != nil {
return nil, err
@@ -2583,9 +2584,16 @@ func (rs *Rows) nextLocked() (doClose, ok bool) {
if rs.closed {
return false, false
}
+
+ // Lock the driver connection before calling the driver interface
+ // rowsi to prevent a Tx from rolling back the connection at the same time.
+ rs.dc.Lock()
+ defer rs.dc.Unlock()
+
if rs.lastcols == nil {
rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
}
+
rs.lasterr = rs.rowsi.Next(rs.lastcols)
if rs.lasterr != nil {
// Close the connection if there is a driver error.
@@ -2635,6 +2643,12 @@ func (rs *Rows) NextResultSet() bool {
doClose = true
return false
}
+
+ // Lock the driver connection before calling the driver interface
+ // rowsi to prevent a Tx from rolling back the connection at the same time.
+ rs.dc.Lock()
+ defer rs.dc.Unlock()
+
rs.lasterr = nextResultSet.NextResultSet()
if rs.lasterr != nil {
doClose = true
@@ -2666,6 +2680,9 @@ func (rs *Rows) Columns() ([]string, error) {
if rs.rowsi == nil {
return nil, errors.New("sql: no Rows available")
}
+ rs.dc.Lock()
+ defer rs.dc.Unlock()
+
return rs.rowsi.Columns(), nil
}
@@ -2680,7 +2697,10 @@ func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
if rs.rowsi == nil {
return nil, errors.New("sql: no Rows available")
}
- return rowsColumnInfoSetup(rs.rowsi), nil
+ rs.dc.Lock()
+ defer rs.dc.Unlock()
+
+ return rowsColumnInfoSetupConnLocked(rs.rowsi), nil
}
// ColumnType contains the name and type of a column.
@@ -2741,7 +2761,7 @@ func (ci *ColumnType) DatabaseTypeName() string {
return ci.databaseType
}
-func rowsColumnInfoSetup(rowsi driver.Rows) []*ColumnType {
+func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType {
names := rowsi.Columns()
list := make([]*ColumnType, len(names))
diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go
index dead273503..f7b7d988e1 100644
--- a/src/database/sql/sql_test.go
+++ b/src/database/sql/sql_test.go
@@ -3157,6 +3157,9 @@ func TestIssue6081(t *testing.T) {
// In the test, a context is canceled while the query is in process so
// the internal rollback will run concurrently with the explicitly called
// Tx.Rollback.
+//
+// The addition of calling rows.Next also tests
+// Issue 21117.
func TestIssue18429(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
@@ -3189,6 +3192,9 @@ func TestIssue18429(t *testing.T) {
// reported.
rows, _ := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|")
if rows != nil {
+ // Call Next to test Issue 21117 and check for races.
+ for rows.Next() {
+ }
rows.Close()
}
// This call will race with the context cancel rollback to complete