diff options
Diffstat (limited to 'src/database/sql')
| -rw-r--r-- | src/database/sql/ctxutil.go | 24 | ||||
| -rw-r--r-- | src/database/sql/sql.go | 45 | ||||
| -rw-r--r-- | src/database/sql/sql_test.go | 58 |
3 files changed, 119 insertions, 8 deletions
diff --git a/src/database/sql/ctxutil.go b/src/database/sql/ctxutil.go index 65e1652657..e1d4c03c9a 100644 --- a/src/database/sql/ctxutil.go +++ b/src/database/sql/ctxutil.go @@ -14,6 +14,10 @@ func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver if ciCtx, is := ci.(driver.ConnPrepareContext); is { return ciCtx.PrepareContext(ctx, query) } + if ctx.Done() == context.Background().Done() { + return ci.Prepare(query) + } + type R struct { err error panic interface{} @@ -50,6 +54,10 @@ func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, darg if execerCtx, is := execer.(driver.ExecerContext); is { return execerCtx.ExecContext(ctx, query, dargs) } + if ctx.Done() == context.Background().Done() { + return execer.Exec(query, dargs) + } + type R struct { err error panic interface{} @@ -86,6 +94,10 @@ func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, d if queryerCtx, is := queryer.(driver.QueryerContext); is { return queryerCtx.QueryContext(ctx, query, dargs) } + if ctx.Done() == context.Background().Done() { + return queryer.Query(query, dargs) + } + type R struct { err error panic interface{} @@ -122,6 +134,10 @@ func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, dargs []driver.Value if siCtx, is := si.(driver.StmtExecContext); is { return siCtx.ExecContext(ctx, dargs) } + if ctx.Done() == context.Background().Done() { + return si.Exec(dargs) + } + type R struct { err error panic interface{} @@ -158,6 +174,10 @@ func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, dargs []driver.Valu if siCtx, is := si.(driver.StmtQueryContext); is { return siCtx.QueryContext(ctx, dargs) } + if ctx.Done() == context.Background().Done() { + return si.Query(dargs) + } + type R struct { err error panic interface{} @@ -196,6 +216,10 @@ func ctxDriverBegin(ctx context.Context, ci driver.Conn) (driver.Tx, error) { if ciCtx, is := ci.(driver.ConnBeginContext); is { return ciCtx.BeginContext(ctx) } + if ctx.Done() == context.Background().Done() { + return ci.Begin() + } + // TODO(kardianos): check the transaction level in ctx. If set and non-default // then return an error here as the BeginContext driver value is not supported. diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 4c44e2b6f4..f56c71a638 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -974,7 +974,8 @@ const maxBadConnRetries = 2 // returned statement. // The caller must call the statement's Close method // when the statement is no longer needed. -// Context is for the preparation of the statment, not for the execution of +// +// The provided context is for the preparation of the statment, not for the execution of // the statement. func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { var stmt *Stmt @@ -1148,6 +1149,7 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er releaseConn: releaseConn, rowsi: rowsi, } + rows.initContextClose(ctx) return rows, nil } } @@ -1180,6 +1182,7 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er rowsi: rowsi, closeStmt: si, } + rows.initContextClose(ctx) return rows, nil } @@ -1364,7 +1367,8 @@ func (tx *Tx) Rollback() error { // be used once the transaction has been committed or rolled back. // // To use an existing prepared statement on this transaction, see Tx.Stmt. -// Context will be used for the preparation of the context, not +// +// The provided context will be used for the preparation of the context, not // for the execution of the returned statement. The returned statement // will run in the transaction context. func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { @@ -1759,6 +1763,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er rowsi: rowsi, // releaseConn set below } + rows.initContextClose(ctx) s.db.addDep(s, rows) rows.releaseConn = func(err error) { releaseConn(err) @@ -1899,12 +1904,30 @@ type Rows struct { releaseConn func(error) rowsi driver.Rows - closed bool + // closed value is 1 when the Rows is closed. + // Use atomic operations on value when checking value. + closed int32 + ctxClose chan struct{} // closed when Rows is closed, may be null. lastcols []driver.Value lasterr error // non-nil only if closed is true closeStmt driver.Stmt // if non-nil, statement to Close on close } +func (rs *Rows) initContextClose(ctx context.Context) { + if ctx.Done() == context.Background().Done() { + return + } + + rs.ctxClose = make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + rs.Close() + case <-rs.ctxClose: + } + }() +} + // Next prepares the next result row for reading with the Scan method. It // returns true on success, or false if there is no next result row or an error // happened while preparing it. Err should be consulted to distinguish between @@ -1912,7 +1935,7 @@ type Rows struct { // // Every call to Scan, even the first one, must be preceded by a call to Next. func (rs *Rows) Next() bool { - if rs.closed { + if rs.isClosed() { return false } if rs.lastcols == nil { @@ -1939,7 +1962,7 @@ func (rs *Rows) Err() error { // Columns returns an error if the rows are closed, or if the rows // are from QueryRow and there was a deferred error. func (rs *Rows) Columns() ([]string, error) { - if rs.closed { + if rs.isClosed() { return nil, errors.New("sql: Rows are closed") } if rs.rowsi == nil { @@ -2000,7 +2023,7 @@ func (rs *Rows) Columns() ([]string, error) { // For scanning into *bool, the source may be true, false, 1, 0, or // string inputs parseable by strconv.ParseBool. func (rs *Rows) Scan(dest ...interface{}) error { - if rs.closed { + if rs.isClosed() { return errors.New("sql: Rows are closed") } if rs.lastcols == nil { @@ -2020,14 +2043,20 @@ func (rs *Rows) Scan(dest ...interface{}) error { var rowsCloseHook func(*Rows, *error) +func (rs *Rows) isClosed() bool { + return atomic.LoadInt32(&rs.closed) != 0 +} + // Close closes the Rows, preventing further enumeration. If Next returns // false, the Rows are closed automatically and it will suffice to check the // result of Err. Close is idempotent and does not affect the result of Err. func (rs *Rows) Close() error { - if rs.closed { + if !atomic.CompareAndSwapInt32(&rs.closed, 0, 1) { return nil } - rs.closed = true + if rs.ctxClose != nil { + close(rs.ctxClose) + } err := rs.rowsi.Close() if fn := rowsCloseHook; fn != nil { fn(rs, &err) diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 9fcb2e38c1..ca14af79e7 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -261,6 +261,64 @@ func TestQuery(t *testing.T) { } } +func TestQueryContext(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + prepares0 := numPrepares(t, db) + + ctx, cancel := context.WithCancel(context.Background()) + + rows, err := db.QueryContext(ctx, "SELECT|people|age,name|") + if err != nil { + t.Fatalf("Query: %v", err) + } + type row struct { + age int + name string + } + got := []row{} + index := 0 + for rows.Next() { + if index == 2 { + cancel() + time.Sleep(10 * time.Millisecond) + } + var r row + err = rows.Scan(&r.age, &r.name) + if err != nil { + if index == 2 { + break + } + t.Fatalf("Scan: %v", err) + } + if index == 2 && err == nil { + t.Fatal("expected an error on last scan") + } + got = append(got, r) + index++ + } + err = rows.Err() + if err != nil { + t.Fatalf("Err: %v", err) + } + want := []row{ + {age: 1, name: "Alice"}, + {age: 2, name: "Bob"}, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want) + } + + // And verify that the final rows.Next() call, which hit EOF, + // also closed the rows connection. + if n := db.numFreeConns(); n != 1 { + t.Fatalf("free conns after query hitting EOF = %d; want 1", n) + } + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Errorf("executed %d Prepare statements; want 1", prepares) + } +} + func TestByteOwnership(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db) |
