diff options
Diffstat (limited to 'src/database/sql/sql.go')
| -rw-r--r-- | src/database/sql/sql.go | 67 |
1 files changed, 49 insertions, 18 deletions
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 2a9ae0b95a..4ef0fa7221 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -875,9 +875,11 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn return nil, errDBClosed } // Check if the context is expired. - if err := ctx.Err(); err != nil { + select { + default: + case <-ctx.Done(): db.mu.Unlock() - return nil, err + return nil, ctx.Err() } lifetime := db.maxLifetime @@ -1288,6 +1290,11 @@ func (db *DB) QueryRow(query string, args ...interface{}) *Row { // BeginContext starts a transaction. // +// The provided context is used until the transaction is committed or rolled back. +// If the context is canceled, the sql package will roll back +// the transaction. Tx.Commit will return an error if the context provided to +// BeginContext is canceled. +// // An isolation level may be set by setting the value in the context // before calling this. If a non-default isolation level is used // that the driver doesn't support an error will be returned. Different drivers @@ -1335,15 +1342,18 @@ func (db *DB) begin(ctx context.Context, strategy connReuseStrategy) (tx *Tx, er dc: dc, txi: txi, cancel: cancel, + ctx: ctx, } - go func() { + go func(tx *Tx) { select { - case <-ctx.Done(): - if !tx.done { - tx.Rollback() + case <-tx.ctx.Done(): + if !tx.isDone() { + // Discard and close the connection used to ensure the transaction + // is closed and the resources are released. + tx.rollback(true) } } - }() + }(tx) return tx, nil } @@ -1370,10 +1380,11 @@ type Tx struct { dc *driverConn txi driver.Tx - // done transitions from false to true exactly once, on Commit + // done transitions from 0 to 1 exactly once, on Commit // or Rollback. once done, all operations fail with // ErrTxDone. - done bool + // Use atomic operations on value when checking value. + done int32 // All Stmts prepared for this transaction. These will be closed after the // transaction has been committed or rolled back. @@ -1384,6 +1395,13 @@ type Tx struct { // cancel is called after done transitions from false to true. cancel func() + + // ctx lives for the life of the transaction. + ctx context.Context +} + +func (tx *Tx) isDone() bool { + return atomic.LoadInt32(&tx.done) != 0 } // ErrTxDone is returned by any operation that is performed on a transaction @@ -1391,10 +1409,9 @@ type Tx struct { var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back") func (tx *Tx) close(err error) { - if tx.done { + if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) { panic("double close") // internal error } - tx.done = true tx.db.putConn(tx.dc, err) tx.cancel() tx.dc = nil @@ -1402,7 +1419,7 @@ func (tx *Tx) close(err error) { } func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) { - if tx.done { + if tx.isDone() { return nil, ErrTxDone } return tx.dc, nil @@ -1419,7 +1436,12 @@ func (tx *Tx) closePrepared() { // Commit commits the transaction. func (tx *Tx) Commit() error { - if tx.done { + select { + default: + case <-tx.ctx.Done(): + return tx.ctx.Err() + } + if tx.isDone() { return ErrTxDone } var err error @@ -1433,9 +1455,10 @@ func (tx *Tx) Commit() error { return err } -// Rollback aborts the transaction. -func (tx *Tx) Rollback() error { - if tx.done { +// rollback aborts the transaction and optionally forces the pool to discard +// the connection. +func (tx *Tx) rollback(discardConn bool) error { + if tx.isDone() { return ErrTxDone } var err error @@ -1445,10 +1468,18 @@ func (tx *Tx) Rollback() error { if err != driver.ErrBadConn { tx.closePrepared() } + if discardConn { + err = driver.ErrBadConn + } tx.close(err) return err } +// Rollback aborts the transaction. +func (tx *Tx) Rollback() error { + return tx.rollback(false) +} + // Prepare creates a prepared statement for use within a transaction. // // The returned statement operates within the transaction and will be closed @@ -1480,7 +1511,7 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { var si driver.Stmt withLock(dc, func() { - si, err = dc.ci.Prepare(query) + si, err = ctxDriverPrepare(ctx, dc.ci, query) }) if err != nil { return nil, err @@ -1538,7 +1569,7 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { } var si driver.Stmt withLock(dc, func() { - si, err = dc.ci.Prepare(stmt.query) + si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query) }) txs := &Stmt{ db: tx.db, |
