diff options
Diffstat (limited to 'src/database/sql/sql.go')
| -rw-r--r-- | src/database/sql/sql.go | 106 |
1 files changed, 75 insertions, 31 deletions
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index adf964992d..04986a28ea 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -1554,19 +1554,6 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { tx.closemu.RLock() defer tx.closemu.RUnlock() - // TODO(bradfitz): We could be more efficient here and either - // provide a method to take an existing Stmt (created on - // perhaps a different Conn), and re-create it on this Conn if - // necessary. Or, better: keep a map in DB of query string to - // Stmts, and have Stmt.Execute do the right thing and - // re-prepare if the Conn in use doesn't have that prepared - // statement. But we'll want to avoid caching the statement - // in the case where we only call conn.Prepare implicitly - // (such as in db.Exec or tx.Exec), but the caller package - // can't be holding a reference to the returned statement. - // Perhaps just looking at the reference count (by noting - // Stmt.Close) would be enough. We might also want a finalizer - // on Stmt to drop the reference count. dc, err := tx.grabConn(ctx) if err != nil { return nil, err @@ -1621,11 +1608,6 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { tx.closemu.RLock() defer tx.closemu.RUnlock() - // TODO(bradfitz): optimize this. Currently this re-prepares - // each time. This is fine for now to illustrate the API but - // we should really cache already-prepared statements - // per-Conn. See also the big comment in Tx.Prepare. - if tx.db != stmt.db { return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")} } @@ -1634,9 +1616,45 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { return &Stmt{stickyErr: err} } var si driver.Stmt - withLock(dc, func() { - si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query) - }) + var parentStmt *Stmt + stmt.mu.Lock() + if stmt.closed || stmt.tx != nil { + // If the statement has been closed or already belongs to a + // transaction, we can't reuse it in this connection. + // Since tx.StmtContext should never need to be called with a + // Stmt already belonging to tx, we ignore this edge case and + // re-prepare the statement in this case. No need to add + // code-complexity for this. + stmt.mu.Unlock() + withLock(dc, func() { + si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query) + }) + if err != nil { + return &Stmt{stickyErr: err} + } + } else { + stmt.removeClosedStmtLocked() + // See if the statement has already been prepared on this connection, + // and reuse it if possible. + for _, v := range stmt.css { + if v.dc == dc { + si = v.ds.si + break + } + } + + stmt.mu.Unlock() + + if si == nil { + cs, err := stmt.prepareOnConnLocked(ctx, dc) + if err != nil { + return &Stmt{stickyErr: err} + } + si = cs.si + } + parentStmt = stmt + } + txs := &Stmt{ db: tx.db, tx: tx, @@ -1644,8 +1662,11 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { Locker: dc, si: si, }, - query: stmt.query, - stickyErr: err, + parentStmt: parentStmt, + query: stmt.query, + } + if parentStmt != nil { + tx.db.addDep(parentStmt, txs) } tx.stmts.Lock() tx.stmts.v = append(tx.stmts.v, txs) @@ -1769,13 +1790,21 @@ type Stmt struct { tx *Tx txds *driverStmt + // parentStmt is set when a transaction-specific statement + // is requested from an identical statement prepared on the same + // conn. parentStmt is used to track the dependency of this statement + // on its originating ("parent") statement so that parentStmt may + // be closed by the user without them having to know whether or not + // any transactions are still using it. + parentStmt *Stmt + mu sync.Mutex // protects the rest of the fields closed bool // css is a list of underlying driver statement interfaces // that are valid on particular connections. This is only // used if tx == nil and one is found that has idle - // connections. If tx != nil, txsi is always used. + // connections. If tx != nil, txds is always used. css []connStmt // lastNumClosed is copied from db.numClosed when Stmt is created @@ -1916,18 +1945,28 @@ func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(e // No luck; we need to prepare the statement on this connection withLock(dc, func() { - ds, err = dc.prepareLocked(ctx, s.query) + ds, err = s.prepareOnConnLocked(ctx, dc) }) if err != nil { s.db.putConn(dc, err) return nil, nil, nil, err } + + return dc, dc.releaseConn, ds, nil +} + +// prepareOnConnLocked prepares the query in Stmt s on dc and adds it to the list of +// open connStmt on the statement. It assumes the caller is holding the lock on dc. +func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) { + si, err := dc.prepareLocked(ctx, s.query) + if err != nil { + return nil, err + } + cs := connStmt{dc, si} s.mu.Lock() - cs := connStmt{dc, ds} s.css = append(s.css, cs) s.mu.Unlock() - - return dc, dc.releaseConn, ds, nil + return cs.ds, nil } // QueryContext executes a prepared query statement with the given arguments @@ -2056,11 +2095,16 @@ func (s *Stmt) Close() error { s.closed = true s.mu.Unlock() - if s.tx != nil { - return s.txds.Close() + if s.tx == nil { + return s.db.removeDep(s, s) } - return s.db.removeDep(s, s) + if s.parentStmt != nil { + // If parentStmt is set, we must not close s.txds since it's stored + // in the css array of the parentStmt. + return s.db.removeDep(s.parentStmt, s) + } + return s.txds.Close() } func (s *Stmt) finalClose() error { |
