diff options
Diffstat (limited to 'src/pkg/database/sql/sql.go')
| -rw-r--r-- | src/pkg/database/sql/sql.go | 313 |
1 files changed, 207 insertions, 106 deletions
diff --git a/src/pkg/database/sql/sql.go b/src/pkg/database/sql/sql.go index d351fbc243..a4c410267c 100644 --- a/src/pkg/database/sql/sql.go +++ b/src/pkg/database/sql/sql.go @@ -191,12 +191,35 @@ type DB struct { dsn string mu sync.Mutex // protects following fields - outConn map[driver.Conn]bool // whether the conn is in use - freeConn []driver.Conn + outConn map[*driverConn]bool // whether the conn is in use + freeConn []*driverConn closed bool dep map[finalCloser]depSet - onConnPut map[driver.Conn][]func() // code (with mu held) run when conn is next returned - lastPut map[driver.Conn]string // stacktrace of last conn's put; debug only + onConnPut map[*driverConn][]func() // code (with mu held) run when conn is next returned + lastPut map[*driverConn]string // stacktrace of last conn's put; debug only +} + +// driverConn wraps a driver.Conn with a mutex, to +// be held during all calls into the Conn. (including any calls onto +// interfaces returned via that Conn, such as calls on Tx, Stmt, +// Result, Rows) +type driverConn struct { + sync.Mutex + ci driver.Conn +} + +// driverStmt associates a driver.Stmt with the +// *driverConn from which it came, so the driverConn's lock can be +// held during calls. +type driverStmt struct { + sync.Locker // the *driverConn + si driver.Stmt +} + +func (ds *driverStmt) Close() error { + ds.Lock() + defer ds.Unlock() + return ds.si.Close() } // depSet is a finalCloser's outstanding dependencies @@ -270,9 +293,9 @@ func Open(driverName, dataSourceName string) (*DB, error) { db := &DB{ driver: driveri, dsn: dataSourceName, - outConn: make(map[driver.Conn]bool), - lastPut: make(map[driver.Conn]string), - onConnPut: make(map[driver.Conn][]func()), + outConn: make(map[*driverConn]bool), + lastPut: make(map[*driverConn]string), + onConnPut: make(map[*driverConn][]func()), } return db, nil } @@ -283,11 +306,11 @@ func (db *DB) Ping() error { // TODO(bradfitz): give drivers an optional hook to implement // this in a more efficient or more reliable way, if they // have one. - c, err := db.conn() + dc, err := db.conn() if err != nil { return err } - db.putConn(c, nil) + db.putConn(dc, nil) return nil } @@ -296,8 +319,10 @@ func (db *DB) Close() error { db.mu.Lock() defer db.mu.Unlock() var err error - for _, c := range db.freeConn { - err1 := c.Close() + for _, dc := range db.freeConn { + dc.Lock() + err1 := dc.ci.Close() + dc.Unlock() if err1 != nil { err = err1 } @@ -314,8 +339,8 @@ func (db *DB) maxIdleConns() int { return defaultMaxIdleConns } -// conn returns a newly-opened or cached driver.Conn -func (db *DB) conn() (driver.Conn, error) { +// conn returns a newly-opened or cached *driverConn +func (db *DB) conn() (*driverConn, error) { db.mu.Lock() if db.closed { db.mu.Unlock() @@ -329,13 +354,16 @@ func (db *DB) conn() (driver.Conn, error) { return conn, nil } db.mu.Unlock() - conn, err := db.driver.Open(db.dsn) - if err == nil { - db.mu.Lock() - db.outConn[conn] = true - db.mu.Unlock() + + ci, err := db.driver.Open(db.dsn) + if err != nil { + return nil, err } - return conn, err + dc := &driverConn{ci: ci} + db.mu.Lock() + db.outConn[dc] = true + db.mu.Unlock() + return dc, nil } // connIfFree returns (wanted, true) if wanted is still a valid conn and @@ -343,7 +371,7 @@ func (db *DB) conn() (driver.Conn, error) { // // If wanted is valid but in use, connIfFree returns (wanted, false). // If wanted is invalid, connIfFre returns (nil, false). -func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) { +func (db *DB) connIfFree(wanted *driverConn) (conn *driverConn, ok bool) { db.mu.Lock() defer db.mu.Unlock() if db.outConn[wanted] { @@ -362,12 +390,12 @@ func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) { } // putConnHook is a hook for testing. -var putConnHook func(*DB, driver.Conn) +var putConnHook func(*DB, *driverConn) // noteUnusedDriverStatement notes that si is no longer used and should // be closed whenever possible (when c is next not in use), unless c is // already closed. -func (db *DB) noteUnusedDriverStatement(c driver.Conn, si driver.Stmt) { +func (db *DB) noteUnusedDriverStatement(c *driverConn, si driver.Stmt) { db.mu.Lock() defer db.mu.Unlock() if db.outConn[c] { @@ -385,24 +413,24 @@ const debugGetPut = false // putConn adds a connection to the db's free pool. // err is optionally the last error that occurred on this connection. -func (db *DB) putConn(c driver.Conn, err error) { +func (db *DB) putConn(dc *driverConn, err error) { db.mu.Lock() - if !db.outConn[c] { + if !db.outConn[dc] { if debugGetPut { - fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", c, stack(), db.lastPut[c]) + fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", dc, stack(), db.lastPut[dc]) } panic("sql: connection returned that was never out") } if debugGetPut { - db.lastPut[c] = stack() + db.lastPut[dc] = stack() } - delete(db.outConn, c) + delete(db.outConn, dc) - if fns, ok := db.onConnPut[c]; ok { + if fns, ok := db.onConnPut[dc]; ok { for _, fn := range fns { fn() } - delete(db.onConnPut, c) + delete(db.onConnPut, dc) } if err == driver.ErrBadConn { @@ -411,17 +439,20 @@ func (db *DB) putConn(c driver.Conn, err error) { return } if putConnHook != nil { - putConnHook(db, c) + putConnHook(db, dc) } if n := len(db.freeConn); !db.closed && n < db.maxIdleConns() { - db.freeConn = append(db.freeConn, c) + db.freeConn = append(db.freeConn, dc) db.mu.Unlock() return } // TODO: check to see if we need this Conn for any prepared // statements which are still active? db.mu.Unlock() - c.Close() + + dc.Lock() + dc.ci.Close() + dc.Unlock() } // Prepare creates a prepared statement for later queries or executions. @@ -446,22 +477,24 @@ func (db *DB) prepare(query string) (*Stmt, error) { // to a connection, and to execute this prepared statement // we either need to use this connection (if it's free), else // get a new connection + re-prepare + execute on that one. - ci, err := db.conn() + dc, err := db.conn() if err != nil { return nil, err } - si, err := ci.Prepare(query) + dc.Lock() + si, err := dc.ci.Prepare(query) + dc.Unlock() if err != nil { - db.putConn(ci, err) + db.putConn(dc, err) return nil, err } stmt := &Stmt{ db: db, query: query, - css: []connStmt{{ci, si}}, + css: []connStmt{{dc, si}}, } db.addDep(stmt, stmt) - db.putConn(ci, nil) + db.putConn(dc, nil) return stmt, nil } @@ -480,35 +513,39 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) { } func (db *DB) exec(query string, args []interface{}) (res Result, err error) { - ci, err := db.conn() + dc, err := db.conn() if err != nil { return nil, err } defer func() { - db.putConn(ci, err) + db.putConn(dc, err) }() - if execer, ok := ci.(driver.Execer); ok { + if execer, ok := dc.ci.(driver.Execer); ok { dargs, err := driverArgs(nil, args) if err != nil { return nil, err } + dc.Lock() resi, err := execer.Exec(query, dargs) + dc.Unlock() if err != driver.ErrSkip { if err != nil { return nil, err } - return result{resi}, nil + return driverResult{dc, resi}, nil } } - sti, err := ci.Prepare(query) + dc.Lock() + si, err := dc.ci.Prepare(query) + dc.Unlock() if err != nil { return nil, err } - defer sti.Close() + defer withLock(dc, func() { si.Close() }) - return resultFromStatement(sti, args...) + return resultFromStatement(driverStmt{dc, si}, args...) } // Query executes a query that returns rows, typically a SELECT. @@ -538,14 +575,16 @@ func (db *DB) query(query string, args []interface{}) (*Rows, error) { // queryConn executes a query on the given connection. // The connection gets released by the releaseConn function. -func (db *DB) queryConn(ci driver.Conn, releaseConn func(error), query string, args []interface{}) (*Rows, error) { - if queryer, ok := ci.(driver.Queryer); ok { +func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) { + if queryer, ok := dc.ci.(driver.Queryer); ok { dargs, err := driverArgs(nil, args) if err != nil { releaseConn(err) return nil, err } + dc.Lock() rowsi, err := queryer.Query(query, dargs) + dc.Unlock() if err != driver.ErrSkip { if err != nil { releaseConn(err) @@ -555,7 +594,7 @@ func (db *DB) queryConn(ci driver.Conn, releaseConn func(error), query string, a // with releaseConn. rows := &Rows{ db: db, - ci: ci, + dc: dc, releaseConn: releaseConn, rowsi: rowsi, } @@ -563,16 +602,21 @@ func (db *DB) queryConn(ci driver.Conn, releaseConn func(error), query string, a } } - sti, err := ci.Prepare(query) + dc.Lock() + si, err := dc.ci.Prepare(query) + dc.Unlock() if err != nil { releaseConn(err) return nil, err } - rowsi, err := rowsiFromStatement(sti, args...) + ds := driverStmt{dc, si} + rowsi, err := rowsiFromStatement(ds, args...) if err != nil { releaseConn(err) - sti.Close() + dc.Lock() + si.Close() + dc.Unlock() return nil, err } @@ -580,10 +624,10 @@ func (db *DB) queryConn(ci driver.Conn, releaseConn func(error), query string, a // with releaseConn. rows := &Rows{ db: db, - ci: ci, + dc: dc, releaseConn: releaseConn, rowsi: rowsi, - closeStmt: sti, + closeStmt: si, } return rows, nil } @@ -611,18 +655,20 @@ func (db *DB) Begin() (*Tx, error) { } func (db *DB) begin() (tx *Tx, err error) { - ci, err := db.conn() + dc, err := db.conn() if err != nil { return nil, err } - txi, err := ci.Begin() + dc.Lock() + txi, err := dc.ci.Begin() + dc.Unlock() if err != nil { - db.putConn(ci, err) + db.putConn(dc, err) return nil, err } return &Tx{ db: db, - ci: ci, + dc: dc, txi: txi, }, nil } @@ -641,13 +687,15 @@ func (db *DB) Driver() driver.Driver { type Tx struct { db *DB - // ci is owned exclusively until Commit or Rollback, at which point + // dc is owned exclusively until Commit or Rollback, at which point // it's returned with putConn. - ci driver.Conn + // TODO(bradfitz): golang.org/issue/3857 + dc *driverConn txi driver.Tx // cimu is held while somebody is using ci (between grabConn // and releaseConn) + // TODO(bradfitz): golang.org/issue/3857 cimu sync.Mutex // done transitions from false to true exactly once, on Commit @@ -663,17 +711,17 @@ func (tx *Tx) close() { panic("double close") // internal error } tx.done = true - tx.db.putConn(tx.ci, nil) - tx.ci = nil + tx.db.putConn(tx.dc, nil) + tx.dc = nil tx.txi = nil } -func (tx *Tx) grabConn() (driver.Conn, error) { +func (tx *Tx) grabConn() (*driverConn, error) { if tx.done { return nil, ErrTxDone } tx.cimu.Lock() - return tx.ci, nil + return tx.dc, nil } func (tx *Tx) releaseConn() { @@ -686,6 +734,8 @@ func (tx *Tx) Commit() error { return ErrTxDone } defer tx.close() + tx.dc.Lock() + defer tx.dc.Unlock() return tx.txi.Commit() } @@ -695,6 +745,8 @@ func (tx *Tx) Rollback() error { return ErrTxDone } defer tx.close() + tx.dc.Lock() + defer tx.dc.Unlock() return tx.txi.Rollback() } @@ -718,21 +770,26 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { // Perhaps just looking at the reference count (by noting // Stmt.Close) would be enough. We might also want a finalizer // on Stmt to drop the reference count. - ci, err := tx.grabConn() + dc, err := tx.grabConn() if err != nil { return nil, err } defer tx.releaseConn() - si, err := ci.Prepare(query) + dc.Lock() + si, err := dc.ci.Prepare(query) + dc.Unlock() if err != nil { return nil, err } stmt := &Stmt{ - db: tx.db, - tx: tx, - txsi: si, + db: tx.db, + tx: tx, + txsi: &driverStmt{ + Locker: dc, + si: si, + }, query: query, } return stmt, nil @@ -756,16 +813,21 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt { if tx.db != stmt.db { return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")} } - ci, err := tx.grabConn() + dc, err := tx.grabConn() if err != nil { return &Stmt{stickyErr: err} } defer tx.releaseConn() - si, err := ci.Prepare(stmt.query) + dc.Lock() + si, err := dc.ci.Prepare(stmt.query) + dc.Unlock() return &Stmt{ - db: tx.db, - tx: tx, - txsi: si, + db: tx.db, + tx: tx, + txsi: &driverStmt{ + Locker: dc, + si: si, + }, query: stmt.query, stickyErr: err, } @@ -774,33 +836,37 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt { // Exec executes a query that doesn't return rows. // For example: an INSERT and UPDATE. func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { - ci, err := tx.grabConn() + dc, err := tx.grabConn() if err != nil { return nil, err } defer tx.releaseConn() - if execer, ok := ci.(driver.Execer); ok { + if execer, ok := dc.ci.(driver.Execer); ok { dargs, err := driverArgs(nil, args) if err != nil { return nil, err } + dc.Lock() resi, err := execer.Exec(query, dargs) + dc.Unlock() if err == nil { - return result{resi}, nil + return driverResult{dc, resi}, nil } if err != driver.ErrSkip { return nil, err } } - sti, err := ci.Prepare(query) + dc.Lock() + si, err := dc.ci.Prepare(query) + dc.Unlock() if err != nil { return nil, err } - defer sti.Close() + defer withLock(dc, func() { si.Close() }) - return resultFromStatement(sti, args...) + return resultFromStatement(driverStmt{dc, si}, args...) } // Query executes a query that returns rows, typically a SELECT. @@ -825,7 +891,7 @@ func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { // connStmt is a prepared statement on a particular connection. type connStmt struct { - ci driver.Conn + dc *driverConn si driver.Stmt } @@ -840,7 +906,7 @@ type Stmt struct { // If in a transaction, else both nil: tx *Tx - txsi driver.Stmt + txsi *driverStmt mu sync.Mutex // protects the rest of the fields closed bool @@ -857,39 +923,45 @@ type Stmt struct { func (s *Stmt) Exec(args ...interface{}) (Result, error) { s.closemu.RLock() defer s.closemu.RUnlock() - _, releaseConn, si, err := s.connStmt() + dc, releaseConn, si, err := s.connStmt() if err != nil { return nil, err } defer releaseConn(nil) - return resultFromStatement(si, args...) + return resultFromStatement(driverStmt{dc, si}, args...) } -func resultFromStatement(si driver.Stmt, args ...interface{}) (Result, error) { +func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) { + ds.Lock() + want := ds.si.NumInput() + ds.Unlock() + // -1 means the driver doesn't know how to count the number of // placeholders, so we won't sanity check input here and instead let the // driver deal with errors. - if want := si.NumInput(); want != -1 && len(args) != want { + if want != -1 && len(args) != want { return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args)) } - dargs, err := driverArgs(si, args) + dargs, err := driverArgs(&ds, args) if err != nil { return nil, err } - resi, err := si.Exec(dargs) + ds.Lock() + resi, err := ds.si.Exec(dargs) + ds.Unlock() if err != nil { return nil, err } - return result{resi}, nil + return driverResult{ds.Locker, resi}, nil } // connStmt returns a free driver connection on which to execute the // statement, a function to call to release the connection, and a // statement bound to that connection. -func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.Stmt, err error) { +func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) { if err = s.stickyErr; err != nil { return } @@ -909,7 +981,7 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.St return } releaseConn = func(error) { s.tx.releaseConn() } - return ci, releaseConn, s.txsi, nil + return ci, releaseConn, s.txsi.si, nil } var cs connStmt @@ -917,7 +989,7 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.St for _, v := range s.css { // TODO(bradfitz): lazily clean up entries in this // list with dead conns while enumerating - if _, match = s.db.connIfFree(v.ci); match { + if _, match = s.db.connIfFree(v.dc); match { cs = v break } @@ -928,11 +1000,13 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.St // TODO(bradfitz): or wait for one? make configurable later? if !match { for i := 0; ; i++ { - ci, err := s.db.conn() + dc, err := s.db.conn() if err != nil { return nil, nil, nil, err } - si, err := ci.Prepare(s.query) + dc.Lock() + si, err := dc.ci.Prepare(s.query) + dc.Unlock() if err == driver.ErrBadConn && i < 10 { continue } @@ -940,14 +1014,14 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.St return nil, nil, nil, err } s.mu.Lock() - cs = connStmt{ci, si} + cs = connStmt{dc, si} s.css = append(s.css, cs) s.mu.Unlock() break } } - conn := cs.ci + conn := cs.dc releaseConn = func(err error) { s.db.putConn(conn, err) } return conn, releaseConn, cs.si, nil } @@ -958,12 +1032,13 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { s.closemu.RLock() defer s.closemu.RUnlock() - ci, releaseConn, si, err := s.connStmt() + dc, releaseConn, si, err := s.connStmt() if err != nil { return nil, err } - rowsi, err := rowsiFromStatement(si, args...) + ds := driverStmt{dc, si} + rowsi, err := rowsiFromStatement(ds, args...) if err != nil { releaseConn(err) return nil, err @@ -973,7 +1048,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { // with releaseConn. rows := &Rows{ db: s.db, - ci: ci, + dc: dc, rowsi: rowsi, // releaseConn set below } @@ -985,20 +1060,26 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { return rows, nil } -func rowsiFromStatement(si driver.Stmt, args ...interface{}) (driver.Rows, error) { +func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) { + ds.Lock() + want := ds.si.NumInput() + ds.Unlock() + // -1 means the driver doesn't know how to count the number of // placeholders, so we won't sanity check input here and instead let the // driver deal with errors. - if want := si.NumInput(); want != -1 && len(args) != want { - return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", si.NumInput(), len(args)) + if want != -1 && len(args) != want { + return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args)) } - dargs, err := driverArgs(si, args) + dargs, err := driverArgs(&ds, args) if err != nil { return nil, err } - rowsi, err := si.Query(dargs) + ds.Lock() + rowsi, err := ds.si.Query(dargs) + ds.Unlock() if err != nil { return nil, err } @@ -1049,7 +1130,7 @@ func (s *Stmt) Close() error { func (s *Stmt) finalClose() error { for _, v := range s.css { - s.db.noteUnusedDriverStatement(v.ci, v.si) + s.db.noteUnusedDriverStatement(v.dc, v.si) } s.css = nil return nil @@ -1070,7 +1151,7 @@ func (s *Stmt) finalClose() error { // ... type Rows struct { db *DB - ci driver.Conn // owned; must call releaseConn when closed to release + dc *driverConn // owned; must call releaseConn when closed to release releaseConn func(error) rowsi driver.Rows @@ -1243,11 +1324,31 @@ type Result interface { RowsAffected() (int64, error) } -type result struct { - driver.Result +type driverResult struct { + sync.Locker // the *driverConn + resi driver.Result +} + +func (dr driverResult) LastInsertId() (int64, error) { + dr.Lock() + defer dr.Unlock() + return dr.resi.LastInsertId() +} + +func (dr driverResult) RowsAffected() (int64, error) { + dr.Lock() + defer dr.Unlock() + return dr.resi.RowsAffected() } func stack() string { var buf [1024]byte return string(buf[:runtime.Stack(buf[:], false)]) } + +// withLock runs while holding lk. +func withLock(lk sync.Locker, fn func()) { + lk.Lock() + fn() + lk.Unlock() +} |
