aboutsummaryrefslogtreecommitdiff
path: root/src/database
diff options
context:
space:
mode:
Diffstat (limited to 'src/database')
-rw-r--r--src/database/sql/driver/driver.go12
-rw-r--r--src/database/sql/fakedb_test.go62
-rw-r--r--src/database/sql/sql.go105
-rw-r--r--src/database/sql/sql_test.go63
4 files changed, 216 insertions, 26 deletions
diff --git a/src/database/sql/driver/driver.go b/src/database/sql/driver/driver.go
index f5a2e7c16c..6113af79c5 100644
--- a/src/database/sql/driver/driver.go
+++ b/src/database/sql/driver/driver.go
@@ -222,6 +222,18 @@ type ConnBeginTx interface {
BeginTx(ctx context.Context, opts TxOptions) (Tx, error)
}
+// ResetSessioner may be implemented by Conn to allow drivers to reset the
+// session state associated with the connection and to signal a bad connection.
+type ResetSessioner interface {
+ // ResetSession is called while a connection is in the connection
+ // pool. No queries will run on this connection until this method returns.
+ //
+ // If the connection is bad this should return driver.ErrBadConn to prevent
+ // the connection from being returned to the connection pool. Any other
+ // error will be discarded.
+ ResetSession(ctx context.Context) error
+}
+
// Result is the result of a query execution.
type Result interface {
// LastInsertId returns the database's auto-generated ID
diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go
index 4dcd096ca4..070b783453 100644
--- a/src/database/sql/fakedb_test.go
+++ b/src/database/sql/fakedb_test.go
@@ -55,6 +55,22 @@ type fakeDriver struct {
dbs map[string]*fakeDB
}
+type fakeConnector struct {
+ name string
+
+ waiter func(context.Context)
+}
+
+func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) {
+ conn, err := fdriver.Open(c.name)
+ conn.(*fakeConn).waiter = c.waiter
+ return conn, err
+}
+
+func (c *fakeConnector) Driver() driver.Driver {
+ return fdriver
+}
+
type fakeDB struct {
name string
@@ -107,6 +123,16 @@ type fakeConn struct {
// bad connection tests; see isBad()
bad bool
stickyBad bool
+
+ skipDirtySession bool // tests that use Conn should set this to true.
+
+ // dirtySession tests ResetSession, true if a query has executed
+ // until ResetSession is called.
+ dirtySession bool
+
+ // The waiter is called before each query. May be used in place of the "WAIT"
+ // directive.
+ waiter func(context.Context)
}
func (c *fakeConn) touchMem() {
@@ -298,6 +324,9 @@ func (c *fakeConn) isBad() bool {
if c.stickyBad {
return true
} else if c.bad {
+ if c.db == nil {
+ return false
+ }
// alternate between bad conn and not bad conn
c.db.badConn = !c.db.badConn
return c.db.badConn
@@ -306,6 +335,21 @@ func (c *fakeConn) isBad() bool {
}
}
+func (c *fakeConn) isDirtyAndMark() bool {
+ if c.skipDirtySession {
+ return false
+ }
+ if c.currTx != nil {
+ c.dirtySession = true
+ return false
+ }
+ if c.dirtySession {
+ return true
+ }
+ c.dirtySession = true
+ return false
+}
+
func (c *fakeConn) Begin() (driver.Tx, error) {
if c.isBad() {
return nil, driver.ErrBadConn
@@ -337,6 +381,14 @@ func setStrictFakeConnClose(t *testing.T) {
testStrictClose = t
}
+func (c *fakeConn) ResetSession(ctx context.Context) error {
+ c.dirtySession = false
+ if c.isBad() {
+ return driver.ErrBadConn
+ }
+ return nil
+}
+
func (c *fakeConn) Close() (err error) {
drv := fdriver.(*fakeDriver)
defer func() {
@@ -572,6 +624,10 @@ func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stm
stmt.cmd = cmd
parts = parts[1:]
+ if c.waiter != nil {
+ c.waiter(ctx)
+ }
+
if stmt.wait > 0 {
wait := time.NewTimer(stmt.wait)
select {
@@ -662,6 +718,9 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d
if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
return nil, driver.ErrBadConn
}
+ if s.c.isDirtyAndMark() {
+ return nil, errors.New("session is dirty")
+ }
err := checkSubsetTypes(s.c.db.allowAny, args)
if err != nil {
@@ -774,6 +833,9 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
return nil, driver.ErrBadConn
}
+ if s.c.isDirtyAndMark() {
+ return nil, errors.New("session is dirty")
+ }
err := checkSubsetTypes(s.c.db.allowAny, args)
if err != nil {
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go
index 7c35710688..c17b2b543b 100644
--- a/src/database/sql/sql.go
+++ b/src/database/sql/sql.go
@@ -334,6 +334,7 @@ type DB struct {
// It is closed during db.Close(). The close tells the connectionOpener
// goroutine to exit.
openerCh chan struct{}
+ resetterCh chan *driverConn
closed bool
dep map[finalCloser]depSet
lastPut map[*driverConn]string // stacktrace of last conn's put; debug only
@@ -341,6 +342,8 @@ type DB struct {
maxOpen int // <= 0 means unlimited
maxLifetime time.Duration // maximum amount of time a connection may be reused
cleanerCh chan struct{}
+
+ stop func() // stop cancels the connection opener and the session resetter.
}
// connReuseStrategy determines how (*DB).conn returns database connections.
@@ -368,6 +371,7 @@ type driverConn struct {
closed bool
finalClosed bool // ci.Close has been called
openStmt map[*driverStmt]bool
+ lastErr error // lastError captures the result of the session resetter.
// guarded by db.mu
inUse bool
@@ -376,7 +380,7 @@ type driverConn struct {
}
func (dc *driverConn) releaseConn(err error) {
- dc.db.putConn(dc, err)
+ dc.db.putConn(dc, err, true)
}
func (dc *driverConn) removeOpenStmt(ds *driverStmt) {
@@ -417,6 +421,19 @@ func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, que
return ds, nil
}
+// resetSession resets the connection session and sets the lastErr
+// that is checked before returning the connection to another query.
+//
+// resetSession assumes that the embedded mutex is locked when the connection
+// was returned to the pool. This unlocks the mutex.
+func (dc *driverConn) resetSession(ctx context.Context) {
+ defer dc.Unlock() // In case of panic.
+ if dc.closed { // Check if the database has been closed.
+ return
+ }
+ dc.lastErr = dc.ci.(driver.ResetSessioner).ResetSession(ctx)
+}
+
// the dc.db's Mutex is held.
func (dc *driverConn) closeDBLocked() func() error {
dc.Lock()
@@ -604,14 +621,18 @@ func (t dsnConnector) Driver() driver.Driver {
// function should be called just once. It is rarely necessary to
// close a DB.
func OpenDB(c driver.Connector) *DB {
+ ctx, cancel := context.WithCancel(context.Background())
db := &DB{
connector: c,
openerCh: make(chan struct{}, connectionRequestQueueSize),
+ resetterCh: make(chan *driverConn, 50),
lastPut: make(map[*driverConn]string),
connRequests: make(map[uint64]chan connRequest),
+ stop: cancel,
}
- go db.connectionOpener()
+ go db.connectionOpener(ctx)
+ go db.connectionResetter(ctx)
return db
}
@@ -693,7 +714,6 @@ func (db *DB) Close() error {
db.mu.Unlock()
return nil
}
- close(db.openerCh)
if db.cleanerCh != nil {
close(db.cleanerCh)
}
@@ -714,6 +734,7 @@ func (db *DB) Close() error {
err = err1
}
}
+ db.stop()
return err
}
@@ -901,18 +922,39 @@ func (db *DB) maybeOpenNewConnections() {
}
// Runs in a separate goroutine, opens new connections when requested.
-func (db *DB) connectionOpener() {
- for range db.openerCh {
- db.openNewConnection()
+func (db *DB) connectionOpener(ctx context.Context) {
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case <-db.openerCh:
+ db.openNewConnection(ctx)
+ }
+ }
+}
+
+// connectionResetter runs in a separate goroutine to reset connections async
+// to exported API.
+func (db *DB) connectionResetter(ctx context.Context) {
+ for {
+ select {
+ case <-ctx.Done():
+ for dc := range db.resetterCh {
+ dc.Unlock()
+ }
+ return
+ case dc := <-db.resetterCh:
+ dc.resetSession(ctx)
+ }
}
}
// Open one new connection
-func (db *DB) openNewConnection() {
+func (db *DB) openNewConnection(ctx context.Context) {
// maybeOpenNewConnctions has already executed db.numOpen++ before it sent
// on db.openerCh. This function must execute db.numOpen-- if the
// connection fails or is closed before returning.
- ci, err := db.connector.Connect(context.Background())
+ ci, err := db.connector.Connect(ctx)
db.mu.Lock()
defer db.mu.Unlock()
if db.closed {
@@ -987,6 +1029,14 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn
conn.Close()
return nil, driver.ErrBadConn
}
+ // Lock around reading lastErr to ensure the session resetter finished.
+ conn.Lock()
+ err := conn.lastErr
+ conn.Unlock()
+ if err == driver.ErrBadConn {
+ conn.Close()
+ return nil, driver.ErrBadConn
+ }
return conn, nil
}
@@ -1012,7 +1062,7 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn
default:
case ret, ok := <-req:
if ok {
- db.putConn(ret.conn, ret.err)
+ db.putConn(ret.conn, ret.err, false)
}
}
return nil, ctx.Err()
@@ -1024,6 +1074,17 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn
ret.conn.Close()
return nil, driver.ErrBadConn
}
+ if ret.conn == nil {
+ return nil, ret.err
+ }
+ // Lock around reading lastErr to ensure the session resetter finished.
+ ret.conn.Lock()
+ err := ret.conn.lastErr
+ ret.conn.Unlock()
+ if err == driver.ErrBadConn {
+ ret.conn.Close()
+ return nil, driver.ErrBadConn
+ }
return ret.conn, ret.err
}
}
@@ -1079,7 +1140,7 @@ const debugGetPut = false
// putConn adds a connection to the db's free pool.
// err is optionally the last error that occurred on this connection.
-func (db *DB) putConn(dc *driverConn, err error) {
+func (db *DB) putConn(dc *driverConn, err error, resetSession bool) {
db.mu.Lock()
if !dc.inUse {
if debugGetPut {
@@ -1110,11 +1171,35 @@ func (db *DB) putConn(dc *driverConn, err error) {
if putConnHook != nil {
putConnHook(db, dc)
}
+ if resetSession {
+ if _, resetSession = dc.ci.(driver.ResetSessioner); resetSession {
+ // Lock the driverConn here so it isn't released until
+ // the connection is reset.
+ // The lock must be taken before the connection is put into
+ // the pool to prevent it from being taken out before it is reset.
+ dc.Lock()
+ }
+ }
added := db.putConnDBLocked(dc, nil)
db.mu.Unlock()
if !added {
+ if resetSession {
+ dc.Unlock()
+ }
dc.Close()
+ return
+ }
+ if !resetSession {
+ return
+ }
+ select {
+ default:
+ // If the resetterCh is blocking then mark the connection
+ // as bad and continue on.
+ dc.lastErr = driver.ErrBadConn
+ dc.Unlock()
+ case db.resetterCh <- dc:
}
}
diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go
index 3551366369..dead273503 100644
--- a/src/database/sql/sql_test.go
+++ b/src/database/sql/sql_test.go
@@ -60,10 +60,12 @@ const fakeDBName = "foo"
var chrisBirthday = time.Unix(123456789, 0)
func newTestDB(t testing.TB, name string) *DB {
- db, err := Open("test", fakeDBName)
- if err != nil {
- t.Fatalf("Open: %v", err)
- }
+ return newTestDBConnector(t, &fakeConnector{name: fakeDBName}, name)
+}
+
+func newTestDBConnector(t testing.TB, fc *fakeConnector, name string) *DB {
+ fc.name = fakeDBName
+ db := OpenDB(fc)
if _, err := db.Exec("WIPE"); err != nil {
t.Fatalf("exec wipe: %v", err)
}
@@ -585,24 +587,46 @@ func TestPoolExhaustOnCancel(t *testing.T) {
if testing.Short() {
t.Skip("long test")
}
- db := newTestDB(t, "people")
- defer closeDB(t, db)
max := 3
+ var saturate, saturateDone sync.WaitGroup
+ saturate.Add(max)
+ saturateDone.Add(max)
+
+ donePing := make(chan bool)
+ state := 0
+
+ // waiter will be called for all queries, including
+ // initial setup queries. The state is only assigned when no
+ // no queries are made.
+ //
+ // Only allow the first batch of queries to finish once the
+ // second batch of Ping queries have finished.
+ waiter := func(ctx context.Context) {
+ switch state {
+ case 0:
+ // Nothing. Initial database setup.
+ case 1:
+ saturate.Done()
+ select {
+ case <-ctx.Done():
+ case <-donePing:
+ }
+ case 2:
+ }
+ }
+ db := newTestDBConnector(t, &fakeConnector{waiter: waiter}, "people")
+ defer closeDB(t, db)
db.SetMaxOpenConns(max)
// First saturate the connection pool.
// Then start new requests for a connection that is cancelled after it is requested.
- var saturate, saturateDone sync.WaitGroup
- saturate.Add(max)
- saturateDone.Add(max)
-
+ state = 1
for i := 0; i < max; i++ {
go func() {
- saturate.Done()
- rows, err := db.Query("WAIT|500ms|SELECT|people|name,photo|")
+ rows, err := db.Query("SELECT|people|name,photo|")
if err != nil {
t.Fatalf("Query: %v", err)
}
@@ -612,6 +636,7 @@ func TestPoolExhaustOnCancel(t *testing.T) {
}
saturate.Wait()
+ state = 2
// Now cancel the request while it is waiting.
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
@@ -628,7 +653,7 @@ func TestPoolExhaustOnCancel(t *testing.T) {
t.Fatalf("PingContext (Exhaust): %v", err)
}
}
-
+ close(donePing)
saturateDone.Wait()
// Now try to open a normal connection.
@@ -1332,6 +1357,7 @@ func TestConnQuery(t *testing.T) {
if err != nil {
t.Fatal(err)
}
+ conn.dc.ci.(*fakeConn).skipDirtySession = true
defer conn.Close()
var name string
@@ -1359,6 +1385,7 @@ func TestConnTx(t *testing.T) {
if err != nil {
t.Fatal(err)
}
+ conn.dc.ci.(*fakeConn).skipDirtySession = true
defer conn.Close()
tx, err := conn.BeginTx(ctx, nil)
@@ -2384,7 +2411,9 @@ func TestManyErrBadConn(t *testing.T) {
t.Fatalf("unexpected len(db.freeConn) %d (was expecting %d)", len(db.freeConn), nconn)
}
for _, conn := range db.freeConn {
+ conn.Lock()
conn.ci.(*fakeConn).stickyBad = true
+ conn.Unlock()
}
return db
}
@@ -2474,6 +2503,7 @@ func TestManyErrBadConn(t *testing.T) {
if err != nil {
t.Fatal(err)
}
+ conn.dc.ci.(*fakeConn).skipDirtySession = true
err = conn.Close()
if err != nil {
t.Fatal(err)
@@ -3238,9 +3268,8 @@ func TestIssue18719(t *testing.T) {
// This call will grab the connection and cancel the context
// after it has done so. Code after must deal with the canceled state.
- rows, err := tx.QueryContext(ctx, "SELECT|people|name|")
+ _, err = tx.QueryContext(ctx, "SELECT|people|name|")
if err != nil {
- rows.Close()
t.Fatalf("expected error %v but got %v", nil, err)
}
@@ -3263,6 +3292,7 @@ func TestIssue20647(t *testing.T) {
if err != nil {
t.Fatal(err)
}
+ conn.dc.ci.(*fakeConn).skipDirtySession = true
defer conn.Close()
stmt, err := conn.PrepareContext(ctx, "SELECT|people|name|")
@@ -3567,6 +3597,8 @@ func TestQueryExecContextOnly(t *testing.T) {
t.Fatal("db.Conn", err)
}
defer conn.Close()
+ coc := conn.dc.ci.(*ctxOnlyConn)
+ coc.fc.skipDirtySession = true
_, err = conn.ExecContext(ctx, "WIPE")
if err != nil {
@@ -3599,7 +3631,6 @@ func TestQueryExecContextOnly(t *testing.T) {
t.Fatalf("expected %q, got %q", expectedValue, v1)
}
- coc := conn.dc.ci.(*ctxOnlyConn)
if !coc.execCtxCalled {
t.Error("ExecContext not called")
}