aboutsummaryrefslogtreecommitdiff
path: root/src/database/sql/sql.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/database/sql/sql.go')
-rw-r--r--src/database/sql/sql.go163
1 files changed, 87 insertions, 76 deletions
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go
index a549e859a4..a02aa35b7b 100644
--- a/src/database/sql/sql.go
+++ b/src/database/sql/sql.go
@@ -330,7 +330,7 @@ type driverConn struct {
ci driver.Conn
closed bool
finalClosed bool // ci.Close has been called
- openStmt map[driver.Stmt]bool
+ openStmt map[*driverStmt]bool
// guarded by db.mu
inUse bool
@@ -342,10 +342,10 @@ func (dc *driverConn) releaseConn(err error) {
dc.db.putConn(dc, err)
}
-func (dc *driverConn) removeOpenStmt(si driver.Stmt) {
+func (dc *driverConn) removeOpenStmt(ds *driverStmt) {
dc.Lock()
defer dc.Unlock()
- delete(dc.openStmt, si)
+ delete(dc.openStmt, ds)
}
func (dc *driverConn) expired(timeout time.Duration) bool {
@@ -355,28 +355,23 @@ func (dc *driverConn) expired(timeout time.Duration) bool {
return dc.createdAt.Add(timeout).Before(nowFunc())
}
-func (dc *driverConn) prepareLocked(ctx context.Context, query string) (driver.Stmt, error) {
+func (dc *driverConn) prepareLocked(ctx context.Context, query string) (*driverStmt, error) {
si, err := ctxDriverPrepare(ctx, dc.ci, query)
- if err == nil {
- // Track each driverConn's open statements, so we can close them
- // before closing the conn.
- //
- // TODO(bradfitz): let drivers opt out of caring about
- // stmt closes if the conn is about to close anyway? For now
- // do the safe thing, in case stmts need to be closed.
- //
- // TODO(bradfitz): after Go 1.2, closing driver.Stmts
- // should be moved to driverStmt, using unique
- // *driverStmts everywhere (including from
- // *Stmt.connStmt, instead of returning a
- // driver.Stmt), using driverStmt as a pointer
- // everywhere, and making it a finalCloser.
- if dc.openStmt == nil {
- dc.openStmt = make(map[driver.Stmt]bool)
- }
- dc.openStmt[si] = true
+ if err != nil {
+ return nil, err
+ }
+
+ // Track each driverConn's open statements, so we can close them
+ // before closing the conn.
+ //
+ // Wrap all driver.Stmt is *driverStmt to ensure they are only closed once.
+ if dc.openStmt == nil {
+ dc.openStmt = make(map[*driverStmt]bool)
}
- return si, err
+ ds := &driverStmt{Locker: dc, si: si}
+ dc.openStmt[ds] = true
+
+ return ds, nil
}
// the dc.db's Mutex is held.
@@ -409,17 +404,24 @@ func (dc *driverConn) Close() error {
func (dc *driverConn) finalClose() error {
var err error
- withLock(dc, func() {
- defer func() { // In case si.Close panics.
- dc.openStmt = nil
- dc.finalClosed = true
- err = dc.ci.Close()
- dc.ci = nil
- }()
- for si := range dc.openStmt {
- si.Close()
+ // Each *driverStmt has a lock to the dc. Copy the list out of the dc
+ // before calling close on each stmt.
+ var openStmt []*driverStmt
+ withLock(dc, func() {
+ openStmt = make([]*driverStmt, 0, len(dc.openStmt))
+ for ds := range dc.openStmt {
+ openStmt = append(openStmt, ds)
}
+ dc.openStmt = nil
+ })
+ for _, ds := range openStmt {
+ ds.Close()
+ }
+ withLock(dc, func() {
+ dc.finalClosed = true
+ err = dc.ci.Close()
+ dc.ci = nil
})
dc.db.mu.Lock()
@@ -437,12 +439,21 @@ func (dc *driverConn) finalClose() error {
type driverStmt struct {
sync.Locker // the *driverConn
si driver.Stmt
+ closed bool
+ closeErr error // return value of previous Close call
}
+// Close ensures dirver.Stmt is only closed once any always returns the same
+// result.
func (ds *driverStmt) Close() error {
ds.Lock()
defer ds.Unlock()
- return ds.si.Close()
+ if ds.closed {
+ return ds.closeErr
+ }
+ ds.closed = true
+ ds.closeErr = ds.si.Close()
+ return ds.closeErr
}
// depSet is a finalCloser's outstanding dependencies
@@ -933,21 +944,22 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn
// putConnHook is a hook for testing.
var putConnHook func(*DB, *driverConn)
-// noteUnusedDriverStatement notes that si is no longer used and should
+// noteUnusedDriverStatement notes that ds is no longer used and should
// be closed whenever possible (when c is next not in use), unless c is
// already closed.
-func (db *DB) noteUnusedDriverStatement(c *driverConn, si driver.Stmt) {
+func (db *DB) noteUnusedDriverStatement(c *driverConn, ds *driverStmt) {
db.mu.Lock()
defer db.mu.Unlock()
if c.inUse {
c.onPut = append(c.onPut, func() {
- si.Close()
+ ds.Close()
})
} else {
c.Lock()
- defer c.Unlock()
- if !c.finalClosed {
- si.Close()
+ fc := c.finalClosed
+ c.Unlock()
+ if !fc {
+ ds.Close()
}
}
}
@@ -1084,9 +1096,9 @@ func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrat
if err != nil {
return nil, err
}
- var si driver.Stmt
+ var ds *driverStmt
withLock(dc, func() {
- si, err = dc.prepareLocked(ctx, query)
+ ds, err = dc.prepareLocked(ctx, query)
})
if err != nil {
db.putConn(dc, err)
@@ -1095,7 +1107,7 @@ func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrat
stmt := &Stmt{
db: db,
query: query,
- css: []connStmt{{dc, si}},
+ css: []connStmt{{dc, ds}},
lastNumClosed: atomic.LoadUint64(&db.numClosed),
}
db.addDep(stmt, stmt)
@@ -1160,8 +1172,9 @@ func (db *DB) exec(ctx context.Context, query string, args []interface{}, strate
if err != nil {
return nil, err
}
- defer withLock(dc, func() { si.Close() })
- return resultFromStatement(ctx, driverStmt{dc, si}, args...)
+ ds := &driverStmt{Locker: dc, si: si}
+ defer ds.Close()
+ return resultFromStatement(ctx, ds, args...)
}
// QueryContext executes a query that returns rows, typically a SELECT.
@@ -1236,12 +1249,10 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er
return nil, err
}
- ds := driverStmt{dc, si}
+ ds := &driverStmt{Locker: dc, si: si}
rowsi, err := rowsiFromStatement(ctx, ds, args...)
if err != nil {
- withLock(dc, func() {
- si.Close()
- })
+ ds.Close()
releaseConn(err)
return nil, err
}
@@ -1252,7 +1263,7 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er
dc: dc,
releaseConn: releaseConn,
rowsi: rowsi,
- closeStmt: si,
+ closeStmt: ds,
}
rows.initContextClose(ctx)
return rows, nil
@@ -1476,7 +1487,7 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
stmt := &Stmt{
db: tx.db,
tx: tx,
- txsi: &driverStmt{
+ txds: &driverStmt{
Locker: dc,
si: si,
},
@@ -1530,7 +1541,7 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
txs := &Stmt{
db: tx.db,
tx: tx,
- txsi: &driverStmt{
+ txds: &driverStmt{
Locker: dc,
si: si,
},
@@ -1591,9 +1602,10 @@ func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}
if err != nil {
return nil, err
}
- defer withLock(dc, func() { si.Close() })
+ ds := &driverStmt{Locker: dc, si: si}
+ defer ds.Close()
- return resultFromStatement(ctx, driverStmt{dc, si}, args...)
+ return resultFromStatement(ctx, ds, args...)
}
// Exec executes a query that doesn't return rows.
@@ -1635,7 +1647,7 @@ func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
// connStmt is a prepared statement on a particular connection.
type connStmt struct {
dc *driverConn
- si driver.Stmt
+ ds *driverStmt
}
// Stmt is a prepared statement.
@@ -1650,7 +1662,7 @@ type Stmt struct {
// If in a transaction, else both nil:
tx *Tx
- txsi *driverStmt
+ txds *driverStmt
mu sync.Mutex // protects the rest of the fields
closed bool
@@ -1674,7 +1686,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er
var res Result
for i := 0; i < maxBadConnRetries; i++ {
- dc, releaseConn, si, err := s.connStmt(ctx)
+ _, releaseConn, ds, err := s.connStmt(ctx)
if err != nil {
if err == driver.ErrBadConn {
continue
@@ -1682,7 +1694,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er
return nil, err
}
- res, err = resultFromStatement(ctx, driverStmt{dc, si}, args...)
+ res, err = resultFromStatement(ctx, ds, args...)
releaseConn(err)
if err != driver.ErrBadConn {
return res, err
@@ -1697,13 +1709,13 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
return s.ExecContext(context.Background(), args...)
}
-func driverNumInput(ds driverStmt) int {
+func driverNumInput(ds *driverStmt) int {
ds.Lock()
defer ds.Unlock() // in case NumInput panics
return ds.si.NumInput()
}
-func resultFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (Result, error) {
+func resultFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (Result, error) {
want := driverNumInput(ds)
// -1 means the driver doesn't know how to count the number of
@@ -1713,7 +1725,7 @@ func resultFromStatement(ctx context.Context, ds driverStmt, args ...interface{}
return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args))
}
- dargs, err := driverArgs(&ds, args)
+ dargs, err := driverArgs(ds, args)
if err != nil {
return nil, err
}
@@ -1757,7 +1769,7 @@ func (s *Stmt) removeClosedStmtLocked() {
// connStmt returns a free driver connection on which to execute the
// statement, a function to call to release the connection, and a
// statement bound to that connection.
-func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) {
+func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), ds *driverStmt, err error) {
if err = s.stickyErr; err != nil {
return
}
@@ -1777,7 +1789,7 @@ func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(e
return
}
releaseConn = func(error) {}
- return ci, releaseConn, s.txsi.si, nil
+ return ci, releaseConn, s.txds, nil
}
s.removeClosedStmtLocked()
@@ -1792,25 +1804,25 @@ func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(e
for _, v := range s.css {
if v.dc == dc {
s.mu.Unlock()
- return dc, dc.releaseConn, v.si, nil
+ return dc, dc.releaseConn, v.ds, nil
}
}
s.mu.Unlock()
// No luck; we need to prepare the statement on this connection
withLock(dc, func() {
- si, err = dc.prepareLocked(ctx, s.query)
+ ds, err = dc.prepareLocked(ctx, s.query)
})
if err != nil {
s.db.putConn(dc, err)
return nil, nil, nil, err
}
s.mu.Lock()
- cs := connStmt{dc, si}
+ cs := connStmt{dc, ds}
s.css = append(s.css, cs)
s.mu.Unlock()
- return dc, dc.releaseConn, si, nil
+ return dc, dc.releaseConn, ds, nil
}
// QueryContext executes a prepared query statement with the given arguments
@@ -1821,7 +1833,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er
var rowsi driver.Rows
for i := 0; i < maxBadConnRetries; i++ {
- dc, releaseConn, si, err := s.connStmt(ctx)
+ dc, releaseConn, ds, err := s.connStmt(ctx)
if err != nil {
if err == driver.ErrBadConn {
continue
@@ -1829,7 +1841,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er
return nil, err
}
- rowsi, err = rowsiFromStatement(ctx, driverStmt{dc, si}, args...)
+ rowsi, err = rowsiFromStatement(ctx, ds, args...)
if err == nil {
// Note: ownership of ci passes to the *Rows, to be freed
// with releaseConn.
@@ -1861,7 +1873,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
return s.QueryContext(context.Background(), args...)
}
-func rowsiFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (driver.Rows, error) {
+func rowsiFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (driver.Rows, error) {
var want int
withLock(ds, func() {
want = ds.si.NumInput()
@@ -1874,7 +1886,7 @@ func rowsiFromStatement(ctx context.Context, ds driverStmt, args ...interface{})
return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args))
}
- dargs, err := driverArgs(&ds, args)
+ dargs, err := driverArgs(ds, args)
if err != nil {
return nil, err
}
@@ -1937,12 +1949,11 @@ func (s *Stmt) Close() error {
return nil
}
s.closed = true
+ s.mu.Unlock()
if s.tx != nil {
- defer s.mu.Unlock()
- return s.txsi.Close()
+ return s.txds.Close()
}
- s.mu.Unlock()
return s.db.removeDep(s, s)
}
@@ -1952,8 +1963,8 @@ func (s *Stmt) finalClose() error {
defer s.mu.Unlock()
if s.css != nil {
for _, v := range s.css {
- s.db.noteUnusedDriverStatement(v.dc, v.si)
- v.dc.removeOpenStmt(v.si)
+ s.db.noteUnusedDriverStatement(v.dc, v.ds)
+ v.dc.removeOpenStmt(v.ds)
}
s.css = nil
}
@@ -1985,7 +1996,7 @@ type Rows struct {
ctxClose chan struct{} // closed when Rows is closed, may be null.
lastcols []driver.Value
lasterr error // non-nil only if closed is true
- closeStmt driver.Stmt // if non-nil, statement to Close on close
+ closeStmt *driverStmt // if non-nil, statement to Close on close
}
func (rs *Rows) initContextClose(ctx context.Context) {