diff options
Diffstat (limited to 'src/database/sql')
| -rw-r--r-- | src/database/sql/sql.go | 35 | ||||
| -rw-r--r-- | src/database/sql/sql_test.go | 48 |
2 files changed, 83 insertions, 0 deletions
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 27adf69122..5c5b7dc7e9 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -1792,6 +1792,8 @@ type Conn struct { done int32 } +// grabConn takes a context to implement stmtConnGrabber +// but the context is not used. func (c *Conn) grabConn(context.Context) (*driverConn, releaseConn, error) { if atomic.LoadInt32(&c.done) != 0 { return nil, nil, ErrConnDone @@ -1856,6 +1858,39 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error) return c.db.prepareDC(ctx, dc, release, c, query) } +// Raw executes f exposing the underlying driver connection for the +// duration of f. The driverConn must not be used outside of f. +// +// Once f returns and err is nil, the Conn will continue to be usable +// until Conn.Close is called. +func (c *Conn) Raw(f func(driverConn interface{}) error) (err error) { + var dc *driverConn + var release releaseConn + + // grabConn takes a context to implement stmtConnGrabber, but the context is not used. + dc, release, err = c.grabConn(nil) + if err != nil { + return + } + fPanic := true + dc.Mutex.Lock() + defer func() { + dc.Mutex.Unlock() + + // If f panics fPanic will remain true. + // Ensure an error is passed to release so the connection + // may be discarded. + if fPanic { + err = driver.ErrBadConn + } + release(err) + }() + err = f(dc.ci) + fPanic = false + + return +} + // BeginTx starts a transaction. // // The provided context is used until the transaction is committed or rolled back. diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 260374d413..a95b70cadb 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -1339,6 +1339,54 @@ func TestConnQuery(t *testing.T) { } } +func TestConnRaw(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + conn.dc.ci.(*fakeConn).skipDirtySession = true + defer conn.Close() + + sawFunc := false + err = conn.Raw(func(dc interface{}) error { + sawFunc = true + if _, ok := dc.(*fakeConn); !ok { + return fmt.Errorf("got %T want *fakeConn", dc) + } + return nil + }) + if err != nil { + t.Fatal(err) + } + if !sawFunc { + t.Fatal("Raw func not called") + } + + func() { + defer func() { + x := recover() + if x == nil { + t.Fatal("expected panic") + } + conn.closemu.Lock() + closed := conn.dc == nil + conn.closemu.Unlock() + if !closed { + t.Fatal("expected connection to be closed after panic") + } + }() + err = conn.Raw(func(dc interface{}) error { + panic("Conn.Raw panic should return an error") + }) + t.Fatal("expected panic from Raw func") + }() +} + func TestCursorFake(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db) |
