aboutsummaryrefslogtreecommitdiff
path: root/src/database/sql/sql.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/database/sql/sql.go')
-rw-r--r--src/database/sql/sql.go67
1 files changed, 49 insertions, 18 deletions
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go
index 2a9ae0b95a..4ef0fa7221 100644
--- a/src/database/sql/sql.go
+++ b/src/database/sql/sql.go
@@ -875,9 +875,11 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn
return nil, errDBClosed
}
// Check if the context is expired.
- if err := ctx.Err(); err != nil {
+ select {
+ default:
+ case <-ctx.Done():
db.mu.Unlock()
- return nil, err
+ return nil, ctx.Err()
}
lifetime := db.maxLifetime
@@ -1288,6 +1290,11 @@ func (db *DB) QueryRow(query string, args ...interface{}) *Row {
// BeginContext starts a transaction.
//
+// The provided context is used until the transaction is committed or rolled back.
+// If the context is canceled, the sql package will roll back
+// the transaction. Tx.Commit will return an error if the context provided to
+// BeginContext is canceled.
+//
// An isolation level may be set by setting the value in the context
// before calling this. If a non-default isolation level is used
// that the driver doesn't support an error will be returned. Different drivers
@@ -1335,15 +1342,18 @@ func (db *DB) begin(ctx context.Context, strategy connReuseStrategy) (tx *Tx, er
dc: dc,
txi: txi,
cancel: cancel,
+ ctx: ctx,
}
- go func() {
+ go func(tx *Tx) {
select {
- case <-ctx.Done():
- if !tx.done {
- tx.Rollback()
+ case <-tx.ctx.Done():
+ if !tx.isDone() {
+ // Discard and close the connection used to ensure the transaction
+ // is closed and the resources are released.
+ tx.rollback(true)
}
}
- }()
+ }(tx)
return tx, nil
}
@@ -1370,10 +1380,11 @@ type Tx struct {
dc *driverConn
txi driver.Tx
- // done transitions from false to true exactly once, on Commit
+ // done transitions from 0 to 1 exactly once, on Commit
// or Rollback. once done, all operations fail with
// ErrTxDone.
- done bool
+ // Use atomic operations on value when checking value.
+ done int32
// All Stmts prepared for this transaction. These will be closed after the
// transaction has been committed or rolled back.
@@ -1384,6 +1395,13 @@ type Tx struct {
// cancel is called after done transitions from false to true.
cancel func()
+
+ // ctx lives for the life of the transaction.
+ ctx context.Context
+}
+
+func (tx *Tx) isDone() bool {
+ return atomic.LoadInt32(&tx.done) != 0
}
// ErrTxDone is returned by any operation that is performed on a transaction
@@ -1391,10 +1409,9 @@ type Tx struct {
var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back")
func (tx *Tx) close(err error) {
- if tx.done {
+ if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) {
panic("double close") // internal error
}
- tx.done = true
tx.db.putConn(tx.dc, err)
tx.cancel()
tx.dc = nil
@@ -1402,7 +1419,7 @@ func (tx *Tx) close(err error) {
}
func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) {
- if tx.done {
+ if tx.isDone() {
return nil, ErrTxDone
}
return tx.dc, nil
@@ -1419,7 +1436,12 @@ func (tx *Tx) closePrepared() {
// Commit commits the transaction.
func (tx *Tx) Commit() error {
- if tx.done {
+ select {
+ default:
+ case <-tx.ctx.Done():
+ return tx.ctx.Err()
+ }
+ if tx.isDone() {
return ErrTxDone
}
var err error
@@ -1433,9 +1455,10 @@ func (tx *Tx) Commit() error {
return err
}
-// Rollback aborts the transaction.
-func (tx *Tx) Rollback() error {
- if tx.done {
+// rollback aborts the transaction and optionally forces the pool to discard
+// the connection.
+func (tx *Tx) rollback(discardConn bool) error {
+ if tx.isDone() {
return ErrTxDone
}
var err error
@@ -1445,10 +1468,18 @@ func (tx *Tx) Rollback() error {
if err != driver.ErrBadConn {
tx.closePrepared()
}
+ if discardConn {
+ err = driver.ErrBadConn
+ }
tx.close(err)
return err
}
+// Rollback aborts the transaction.
+func (tx *Tx) Rollback() error {
+ return tx.rollback(false)
+}
+
// Prepare creates a prepared statement for use within a transaction.
//
// The returned statement operates within the transaction and will be closed
@@ -1480,7 +1511,7 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
var si driver.Stmt
withLock(dc, func() {
- si, err = dc.ci.Prepare(query)
+ si, err = ctxDriverPrepare(ctx, dc.ci, query)
})
if err != nil {
return nil, err
@@ -1538,7 +1569,7 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
}
var si driver.Stmt
withLock(dc, func() {
- si, err = dc.ci.Prepare(stmt.query)
+ si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query)
})
txs := &Stmt{
db: tx.db,