aboutsummaryrefslogtreecommitdiff
path: root/src/database/sql/fakedb_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/database/sql/fakedb_test.go')
-rw-r--r--src/database/sql/fakedb_test.go62
1 files changed, 62 insertions, 0 deletions
diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go
index 4dcd096ca4..070b783453 100644
--- a/src/database/sql/fakedb_test.go
+++ b/src/database/sql/fakedb_test.go
@@ -55,6 +55,22 @@ type fakeDriver struct {
dbs map[string]*fakeDB
}
+type fakeConnector struct {
+ name string
+
+ waiter func(context.Context)
+}
+
+func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) {
+ conn, err := fdriver.Open(c.name)
+ conn.(*fakeConn).waiter = c.waiter
+ return conn, err
+}
+
+func (c *fakeConnector) Driver() driver.Driver {
+ return fdriver
+}
+
type fakeDB struct {
name string
@@ -107,6 +123,16 @@ type fakeConn struct {
// bad connection tests; see isBad()
bad bool
stickyBad bool
+
+ skipDirtySession bool // tests that use Conn should set this to true.
+
+ // dirtySession tests ResetSession, true if a query has executed
+ // until ResetSession is called.
+ dirtySession bool
+
+ // The waiter is called before each query. May be used in place of the "WAIT"
+ // directive.
+ waiter func(context.Context)
}
func (c *fakeConn) touchMem() {
@@ -298,6 +324,9 @@ func (c *fakeConn) isBad() bool {
if c.stickyBad {
return true
} else if c.bad {
+ if c.db == nil {
+ return false
+ }
// alternate between bad conn and not bad conn
c.db.badConn = !c.db.badConn
return c.db.badConn
@@ -306,6 +335,21 @@ func (c *fakeConn) isBad() bool {
}
}
+func (c *fakeConn) isDirtyAndMark() bool {
+ if c.skipDirtySession {
+ return false
+ }
+ if c.currTx != nil {
+ c.dirtySession = true
+ return false
+ }
+ if c.dirtySession {
+ return true
+ }
+ c.dirtySession = true
+ return false
+}
+
func (c *fakeConn) Begin() (driver.Tx, error) {
if c.isBad() {
return nil, driver.ErrBadConn
@@ -337,6 +381,14 @@ func setStrictFakeConnClose(t *testing.T) {
testStrictClose = t
}
+func (c *fakeConn) ResetSession(ctx context.Context) error {
+ c.dirtySession = false
+ if c.isBad() {
+ return driver.ErrBadConn
+ }
+ return nil
+}
+
func (c *fakeConn) Close() (err error) {
drv := fdriver.(*fakeDriver)
defer func() {
@@ -572,6 +624,10 @@ func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stm
stmt.cmd = cmd
parts = parts[1:]
+ if c.waiter != nil {
+ c.waiter(ctx)
+ }
+
if stmt.wait > 0 {
wait := time.NewTimer(stmt.wait)
select {
@@ -662,6 +718,9 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d
if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
return nil, driver.ErrBadConn
}
+ if s.c.isDirtyAndMark() {
+ return nil, errors.New("session is dirty")
+ }
err := checkSubsetTypes(s.c.db.allowAny, args)
if err != nil {
@@ -774,6 +833,9 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
return nil, driver.ErrBadConn
}
+ if s.c.isDirtyAndMark() {
+ return nil, errors.New("session is dirty")
+ }
err := checkSubsetTypes(s.c.db.allowAny, args)
if err != nil {