From d2df8498f366669acbae24f38e3683b3acdab102 Mon Sep 17 00:00:00 2001 From: Daniel Theophanes Date: Wed, 28 Sep 2016 12:51:39 -0700 Subject: database/sql: close Rows when context is cancelled To prevent leaking connections, close any open Rows when the context is cancelled. Also enforce context cancel while reading rows off of the wire. Change-Id: I62237ecdb7d250d6734f6ce3d2b0bcb16dc6fda7 Reviewed-on: https://go-review.googlesource.com/29957 Reviewed-by: Brad Fitzpatrick --- src/database/sql/sql_test.go | 58 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) (limited to 'src/database/sql/sql_test.go') diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 9fcb2e38c1..ca14af79e7 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -261,6 +261,64 @@ func TestQuery(t *testing.T) { } } +func TestQueryContext(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + prepares0 := numPrepares(t, db) + + ctx, cancel := context.WithCancel(context.Background()) + + rows, err := db.QueryContext(ctx, "SELECT|people|age,name|") + if err != nil { + t.Fatalf("Query: %v", err) + } + type row struct { + age int + name string + } + got := []row{} + index := 0 + for rows.Next() { + if index == 2 { + cancel() + time.Sleep(10 * time.Millisecond) + } + var r row + err = rows.Scan(&r.age, &r.name) + if err != nil { + if index == 2 { + break + } + t.Fatalf("Scan: %v", err) + } + if index == 2 && err == nil { + t.Fatal("expected an error on last scan") + } + got = append(got, r) + index++ + } + err = rows.Err() + if err != nil { + t.Fatalf("Err: %v", err) + } + want := []row{ + {age: 1, name: "Alice"}, + {age: 2, name: "Bob"}, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want) + } + + // And verify that the final rows.Next() call, which hit EOF, + // also closed the rows connection. + if n := db.numFreeConns(); n != 1 { + t.Fatalf("free conns after query hitting EOF = %d; want 1", n) + } + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Errorf("executed %d Prepare statements; want 1", prepares) + } +} + func TestByteOwnership(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db) -- cgit v1.3