aboutsummaryrefslogtreecommitdiff
path: root/internal/database/database_test.go
diff options
context:
space:
mode:
authorJonathan Amsterdam <jba@google.com>2020-04-13 14:55:37 -0400
committerJonathan Amsterdam <jba@google.com>2020-04-14 15:38:30 +0000
commit86348f6125d8422c830a3ca3c67f247962fbea1c (patch)
tree8513c06ec40404213be6b38b4629ba476d991a13 /internal/database/database_test.go
parent90861f00a20f0c21fabcc46ff145201279e67da6 (diff)
downloadgo-x-pkgsite-86348f6125d8422c830a3ca3c67f247962fbea1c.tar.xz
internal/database: use the DB for inside a transaction as well as out
The database.DB type now can represent a DB connection in the middle of a transaction. Such a DB is created only by calling DB.Transact. The resulting API is much simpler, since all the ...Tx methods disappear. Change-Id: I41afada87738e1eacdec2fcf115902edddeff867 Reviewed-on: https://team-review.git.corp.google.com/c/golang/discovery/+/716719 Reviewed-by: Julie Qiu <julieqiu@google.com>
Diffstat (limited to 'internal/database/database_test.go')
-rw-r--r--internal/database/database_test.go25
1 files changed, 20 insertions, 5 deletions
diff --git a/internal/database/database_test.go b/internal/database/database_test.go
index ef72ace7..bb7fd27c 100644
--- a/internal/database/database_test.go
+++ b/internal/database/database_test.go
@@ -6,7 +6,6 @@ package database
import (
"context"
- "database/sql"
"fmt"
"log"
"os"
@@ -127,8 +126,8 @@ func TestBulkInsert(t *testing.T) {
}
}()
- if err := testDB.Transact(func(tx *sql.Tx) error {
- return BulkInsert(ctx, tx, table, tc.columns, tc.values, tc.conflictAction)
+ if err := testDB.Transact(func(db *DB) error {
+ return db.BulkInsert(ctx, table, tc.columns, tc.values, tc.conflictAction)
}); tc.wantErr && err == nil || !tc.wantErr && err != nil {
t.Errorf("testDB.Transact: %v | wantErr = %t", err, tc.wantErr)
}
@@ -160,8 +159,8 @@ func TestLargeBulkInsert(t *testing.T) {
for i := 0; i < size; i++ {
vals[i] = i + 1
}
- if err := testDB.Transact(func(tx *sql.Tx) error {
- return BulkInsert(ctx, tx, "test_large_bulk", []string{"i"}, vals, "")
+ if err := testDB.Transact(func(db *DB) error {
+ return db.BulkInsert(ctx, "test_large_bulk", []string{"i"}, vals, "")
}); err != nil {
t.Fatal(err)
}
@@ -183,3 +182,19 @@ func TestLargeBulkInsert(t *testing.T) {
t.Errorf("sum = %d, want %d", sum, want)
}
}
+
+func TestDBAfterTransactFails(t *testing.T) {
+ var tx *DB
+ err := testDB.Transact(func(d *DB) error {
+ tx = d
+ return nil
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ var i int
+ err = tx.QueryRow(context.Background(), `SELECT 1`).Scan(&i)
+ if err == nil {
+ t.Fatal("got nil, want error")
+ }
+}