aboutsummaryrefslogtreecommitdiff
path: root/src/database/sql/sql.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/database/sql/sql.go')
-rw-r--r--src/database/sql/sql.go106
1 files changed, 75 insertions, 31 deletions
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go
index adf964992d..04986a28ea 100644
--- a/src/database/sql/sql.go
+++ b/src/database/sql/sql.go
@@ -1554,19 +1554,6 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
tx.closemu.RLock()
defer tx.closemu.RUnlock()
- // TODO(bradfitz): We could be more efficient here and either
- // provide a method to take an existing Stmt (created on
- // perhaps a different Conn), and re-create it on this Conn if
- // necessary. Or, better: keep a map in DB of query string to
- // Stmts, and have Stmt.Execute do the right thing and
- // re-prepare if the Conn in use doesn't have that prepared
- // statement. But we'll want to avoid caching the statement
- // in the case where we only call conn.Prepare implicitly
- // (such as in db.Exec or tx.Exec), but the caller package
- // can't be holding a reference to the returned statement.
- // Perhaps just looking at the reference count (by noting
- // Stmt.Close) would be enough. We might also want a finalizer
- // on Stmt to drop the reference count.
dc, err := tx.grabConn(ctx)
if err != nil {
return nil, err
@@ -1621,11 +1608,6 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
tx.closemu.RLock()
defer tx.closemu.RUnlock()
- // TODO(bradfitz): optimize this. Currently this re-prepares
- // each time. This is fine for now to illustrate the API but
- // we should really cache already-prepared statements
- // per-Conn. See also the big comment in Tx.Prepare.
-
if tx.db != stmt.db {
return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
}
@@ -1634,9 +1616,45 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
return &Stmt{stickyErr: err}
}
var si driver.Stmt
- withLock(dc, func() {
- si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query)
- })
+ var parentStmt *Stmt
+ stmt.mu.Lock()
+ if stmt.closed || stmt.tx != nil {
+ // If the statement has been closed or already belongs to a
+ // transaction, we can't reuse it in this connection.
+ // Since tx.StmtContext should never need to be called with a
+ // Stmt already belonging to tx, we ignore this edge case and
+ // re-prepare the statement in this case. No need to add
+ // code-complexity for this.
+ stmt.mu.Unlock()
+ withLock(dc, func() {
+ si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query)
+ })
+ if err != nil {
+ return &Stmt{stickyErr: err}
+ }
+ } else {
+ stmt.removeClosedStmtLocked()
+ // See if the statement has already been prepared on this connection,
+ // and reuse it if possible.
+ for _, v := range stmt.css {
+ if v.dc == dc {
+ si = v.ds.si
+ break
+ }
+ }
+
+ stmt.mu.Unlock()
+
+ if si == nil {
+ cs, err := stmt.prepareOnConnLocked(ctx, dc)
+ if err != nil {
+ return &Stmt{stickyErr: err}
+ }
+ si = cs.si
+ }
+ parentStmt = stmt
+ }
+
txs := &Stmt{
db: tx.db,
tx: tx,
@@ -1644,8 +1662,11 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
Locker: dc,
si: si,
},
- query: stmt.query,
- stickyErr: err,
+ parentStmt: parentStmt,
+ query: stmt.query,
+ }
+ if parentStmt != nil {
+ tx.db.addDep(parentStmt, txs)
}
tx.stmts.Lock()
tx.stmts.v = append(tx.stmts.v, txs)
@@ -1769,13 +1790,21 @@ type Stmt struct {
tx *Tx
txds *driverStmt
+ // parentStmt is set when a transaction-specific statement
+ // is requested from an identical statement prepared on the same
+ // conn. parentStmt is used to track the dependency of this statement
+ // on its originating ("parent") statement so that parentStmt may
+ // be closed by the user without them having to know whether or not
+ // any transactions are still using it.
+ parentStmt *Stmt
+
mu sync.Mutex // protects the rest of the fields
closed bool
// css is a list of underlying driver statement interfaces
// that are valid on particular connections. This is only
// used if tx == nil and one is found that has idle
- // connections. If tx != nil, txsi is always used.
+ // connections. If tx != nil, txds is always used.
css []connStmt
// lastNumClosed is copied from db.numClosed when Stmt is created
@@ -1916,18 +1945,28 @@ func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(e
// No luck; we need to prepare the statement on this connection
withLock(dc, func() {
- ds, err = dc.prepareLocked(ctx, s.query)
+ ds, err = s.prepareOnConnLocked(ctx, dc)
})
if err != nil {
s.db.putConn(dc, err)
return nil, nil, nil, err
}
+
+ return dc, dc.releaseConn, ds, nil
+}
+
+// prepareOnConnLocked prepares the query in Stmt s on dc and adds it to the list of
+// open connStmt on the statement. It assumes the caller is holding the lock on dc.
+func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) {
+ si, err := dc.prepareLocked(ctx, s.query)
+ if err != nil {
+ return nil, err
+ }
+ cs := connStmt{dc, si}
s.mu.Lock()
- cs := connStmt{dc, ds}
s.css = append(s.css, cs)
s.mu.Unlock()
-
- return dc, dc.releaseConn, ds, nil
+ return cs.ds, nil
}
// QueryContext executes a prepared query statement with the given arguments
@@ -2056,11 +2095,16 @@ func (s *Stmt) Close() error {
s.closed = true
s.mu.Unlock()
- if s.tx != nil {
- return s.txds.Close()
+ if s.tx == nil {
+ return s.db.removeDep(s, s)
}
- return s.db.removeDep(s, s)
+ if s.parentStmt != nil {
+ // If parentStmt is set, we must not close s.txds since it's stored
+ // in the css array of the parentStmt.
+ return s.db.removeDep(s.parentStmt, s)
+ }
+ return s.txds.Close()
}
func (s *Stmt) finalClose() error {