From 65e2b7a804ce25942210f8db0e9552db9d7d6ff5 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Tue, 19 Nov 2019 19:04:01 -0500 Subject: internal/database, internal/testing/dbtest: site-agnostic DB functionality Extract into a separate package the core functionality from internal/postgres that doesn't depend on our particular schema. This makes it available for other uses, like devtools commands and etl autocomplete. Do the same for testing functionality. We now have three packages where before we had only one: - internal/postgres: discovery-specific DB operations and test support - internal/database: discovery-agnostic DB operations - internal/testing/dbtest: discovery-agnostic DB test support Change-Id: I54c59aee328dae71ba6c77170a72e7a83da7c785 Reviewed-on: https://team-review.git.corp.google.com/c/golang/discovery/+/602327 Reviewed-by: Robert Findley --- internal/database/database_test.go | 185 +++++++++++++++++++++++++++++++++++++ 1 file changed, 185 insertions(+) create mode 100644 internal/database/database_test.go (limited to 'internal/database/database_test.go') diff --git a/internal/database/database_test.go b/internal/database/database_test.go new file mode 100644 index 00000000..ef72ace7 --- /dev/null +++ b/internal/database/database_test.go @@ -0,0 +1,185 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package database + +import ( + "context" + "database/sql" + "fmt" + "log" + "os" + "testing" + "time" + + "golang.org/x/discovery/internal/testing/dbtest" +) + +const testTimeout = 5 * time.Second + +var testDB *DB + +func TestMain(m *testing.M) { + const dbName = "discovery_postgres_test" + + if err := dbtest.CreateDBIfNotExists(dbName); err != nil { + log.Fatal(err) + } + var err error + testDB, err = Open("postgres", dbtest.DBConnURI(dbName)) + if err != nil { + log.Fatal(err) + } + code := m.Run() + if err := testDB.Close(); err != nil { + log.Fatal(err) + } + os.Exit(code) +} + +func TestBulkInsert(t *testing.T) { + table := "test_bulk_insert" + + for _, tc := range []struct { + name string + columns []string + values []interface{} + conflictAction string + wantErr bool + wantCount int + }{ + { + + name: "test-one-row", + columns: []string{"colA"}, + values: []interface{}{"valueA"}, + wantCount: 1, + }, + { + + name: "test-multiple-rows", + columns: []string{"colA"}, + values: []interface{}{"valueA1", "valueA2", "valueA3"}, + wantCount: 3, + }, + { + + name: "test-invalid-column-name", + columns: []string{"invalid_col"}, + values: []interface{}{"valueA"}, + wantErr: true, + }, + { + + name: "test-mismatch-num-cols-and-vals", + columns: []string{"colA", "colB"}, + values: []interface{}{"valueA1", "valueB1", "valueA2"}, + wantErr: true, + }, + { + + name: "test-conflict-no-action-true", + columns: []string{"colA"}, + values: []interface{}{"valueA", "valueA"}, + conflictAction: OnConflictDoNothing, + wantCount: 1, + }, + { + + name: "test-conflict-no-action-false", + columns: []string{"colA"}, + values: []interface{}{"valueA", "valueA"}, + wantErr: true, + }, + { + + // This should execute the statement + // INSERT INTO series (path) VALUES ('''); TRUNCATE series CASCADE;)'); + // which will insert a row with path value: + // '); TRUNCATE series CASCADE;) + // Rather than the statement + // INSERT INTO series (path) VALUES (''); TRUNCATE series CASCADE;)); + // which would truncate most tables in the database. + name: "test-sql-injection", + columns: []string{"colA"}, + values: []interface{}{fmt.Sprintf("''); TRUNCATE %s CASCADE;))", table)}, + conflictAction: OnConflictDoNothing, + wantCount: 1, + }, + } { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + createQuery := fmt.Sprintf(`CREATE TABLE %s ( + colA TEXT NOT NULL, + colB TEXT, + PRIMARY KEY (colA) + );`, table) + if _, err := testDB.Exec(ctx, createQuery); err != nil { + t.Fatal(err) + } + defer func() { + dropTableQuery := fmt.Sprintf("DROP TABLE %s;", table) + if _, err := testDB.Exec(ctx, dropTableQuery); err != nil { + t.Fatal(err) + } + }() + + if err := testDB.Transact(func(tx *sql.Tx) error { + return BulkInsert(ctx, tx, table, tc.columns, tc.values, tc.conflictAction) + }); tc.wantErr && err == nil || !tc.wantErr && err != nil { + t.Errorf("testDB.Transact: %v | wantErr = %t", err, tc.wantErr) + } + + if tc.wantCount != 0 { + var count int + query := "SELECT COUNT(*) FROM " + table + row := testDB.QueryRow(ctx, query) + err := row.Scan(&count) + if err != nil { + t.Fatalf("testDB.queryRow(%q): %v", query, err) + } + if count != tc.wantCount { + t.Errorf("testDB.queryRow(%q) = %d; want = %d", query, count, tc.wantCount) + } + } + }) + } +} + +func TestLargeBulkInsert(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + if _, err := testDB.Exec(ctx, `CREATE TEMPORARY TABLE test_large_bulk (i BIGINT);`); err != nil { + t.Fatal(err) + } + const size = 150000 + vals := make([]interface{}, size) + for i := 0; i < size; i++ { + vals[i] = i + 1 + } + if err := testDB.Transact(func(tx *sql.Tx) error { + return BulkInsert(ctx, tx, "test_large_bulk", []string{"i"}, vals, "") + }); err != nil { + t.Fatal(err) + } + rows, err := testDB.Query(ctx, `SELECT i FROM test_large_bulk;`) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + sum := int64(0) + for rows.Next() { + var i int64 + if err := rows.Scan(&i); err != nil { + t.Fatal(err) + } + sum += i + } + var want int64 = size * (size + 1) / 2 + if sum != want { + t.Errorf("sum = %d, want %d", sum, want) + } +} -- cgit v1.3