diff options
Diffstat (limited to 'src/database/sql/sql.go')
| -rw-r--r-- | src/database/sql/sql.go | 105 |
1 files changed, 95 insertions, 10 deletions
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 7c35710688..c17b2b543b 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -334,6 +334,7 @@ type DB struct { // It is closed during db.Close(). The close tells the connectionOpener // goroutine to exit. openerCh chan struct{} + resetterCh chan *driverConn closed bool dep map[finalCloser]depSet lastPut map[*driverConn]string // stacktrace of last conn's put; debug only @@ -341,6 +342,8 @@ type DB struct { maxOpen int // <= 0 means unlimited maxLifetime time.Duration // maximum amount of time a connection may be reused cleanerCh chan struct{} + + stop func() // stop cancels the connection opener and the session resetter. } // connReuseStrategy determines how (*DB).conn returns database connections. @@ -368,6 +371,7 @@ type driverConn struct { closed bool finalClosed bool // ci.Close has been called openStmt map[*driverStmt]bool + lastErr error // lastError captures the result of the session resetter. // guarded by db.mu inUse bool @@ -376,7 +380,7 @@ type driverConn struct { } func (dc *driverConn) releaseConn(err error) { - dc.db.putConn(dc, err) + dc.db.putConn(dc, err, true) } func (dc *driverConn) removeOpenStmt(ds *driverStmt) { @@ -417,6 +421,19 @@ func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, que return ds, nil } +// resetSession resets the connection session and sets the lastErr +// that is checked before returning the connection to another query. +// +// resetSession assumes that the embedded mutex is locked when the connection +// was returned to the pool. This unlocks the mutex. +func (dc *driverConn) resetSession(ctx context.Context) { + defer dc.Unlock() // In case of panic. + if dc.closed { // Check if the database has been closed. + return + } + dc.lastErr = dc.ci.(driver.ResetSessioner).ResetSession(ctx) +} + // the dc.db's Mutex is held. func (dc *driverConn) closeDBLocked() func() error { dc.Lock() @@ -604,14 +621,18 @@ func (t dsnConnector) Driver() driver.Driver { // function should be called just once. It is rarely necessary to // close a DB. func OpenDB(c driver.Connector) *DB { + ctx, cancel := context.WithCancel(context.Background()) db := &DB{ connector: c, openerCh: make(chan struct{}, connectionRequestQueueSize), + resetterCh: make(chan *driverConn, 50), lastPut: make(map[*driverConn]string), connRequests: make(map[uint64]chan connRequest), + stop: cancel, } - go db.connectionOpener() + go db.connectionOpener(ctx) + go db.connectionResetter(ctx) return db } @@ -693,7 +714,6 @@ func (db *DB) Close() error { db.mu.Unlock() return nil } - close(db.openerCh) if db.cleanerCh != nil { close(db.cleanerCh) } @@ -714,6 +734,7 @@ func (db *DB) Close() error { err = err1 } } + db.stop() return err } @@ -901,18 +922,39 @@ func (db *DB) maybeOpenNewConnections() { } // Runs in a separate goroutine, opens new connections when requested. -func (db *DB) connectionOpener() { - for range db.openerCh { - db.openNewConnection() +func (db *DB) connectionOpener(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-db.openerCh: + db.openNewConnection(ctx) + } + } +} + +// connectionResetter runs in a separate goroutine to reset connections async +// to exported API. +func (db *DB) connectionResetter(ctx context.Context) { + for { + select { + case <-ctx.Done(): + for dc := range db.resetterCh { + dc.Unlock() + } + return + case dc := <-db.resetterCh: + dc.resetSession(ctx) + } } } // Open one new connection -func (db *DB) openNewConnection() { +func (db *DB) openNewConnection(ctx context.Context) { // maybeOpenNewConnctions has already executed db.numOpen++ before it sent // on db.openerCh. This function must execute db.numOpen-- if the // connection fails or is closed before returning. - ci, err := db.connector.Connect(context.Background()) + ci, err := db.connector.Connect(ctx) db.mu.Lock() defer db.mu.Unlock() if db.closed { @@ -987,6 +1029,14 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn conn.Close() return nil, driver.ErrBadConn } + // Lock around reading lastErr to ensure the session resetter finished. + conn.Lock() + err := conn.lastErr + conn.Unlock() + if err == driver.ErrBadConn { + conn.Close() + return nil, driver.ErrBadConn + } return conn, nil } @@ -1012,7 +1062,7 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn default: case ret, ok := <-req: if ok { - db.putConn(ret.conn, ret.err) + db.putConn(ret.conn, ret.err, false) } } return nil, ctx.Err() @@ -1024,6 +1074,17 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn ret.conn.Close() return nil, driver.ErrBadConn } + if ret.conn == nil { + return nil, ret.err + } + // Lock around reading lastErr to ensure the session resetter finished. + ret.conn.Lock() + err := ret.conn.lastErr + ret.conn.Unlock() + if err == driver.ErrBadConn { + ret.conn.Close() + return nil, driver.ErrBadConn + } return ret.conn, ret.err } } @@ -1079,7 +1140,7 @@ const debugGetPut = false // putConn adds a connection to the db's free pool. // err is optionally the last error that occurred on this connection. -func (db *DB) putConn(dc *driverConn, err error) { +func (db *DB) putConn(dc *driverConn, err error, resetSession bool) { db.mu.Lock() if !dc.inUse { if debugGetPut { @@ -1110,11 +1171,35 @@ func (db *DB) putConn(dc *driverConn, err error) { if putConnHook != nil { putConnHook(db, dc) } + if resetSession { + if _, resetSession = dc.ci.(driver.ResetSessioner); resetSession { + // Lock the driverConn here so it isn't released until + // the connection is reset. + // The lock must be taken before the connection is put into + // the pool to prevent it from being taken out before it is reset. + dc.Lock() + } + } added := db.putConnDBLocked(dc, nil) db.mu.Unlock() if !added { + if resetSession { + dc.Unlock() + } dc.Close() + return + } + if !resetSession { + return + } + select { + default: + // If the resetterCh is blocking then mark the connection + // as bad and continue on. + dc.lastErr = driver.ErrBadConn + dc.Unlock() + case db.resetterCh <- dc: } } |
