diff options
| author | Daniel Theophanes <kardianos@gmail.com> | 2016-10-03 09:49:25 -0700 |
|---|---|---|
| committer | Brad Fitzpatrick <bradfitz@golang.org> | 2016-10-17 07:56:35 +0000 |
| commit | 707a83341b8c7973f4e0fce731fa279c618f233b (patch) | |
| tree | 3436b372be1d863a3e5f2f2e837fbabe577252c8 /src/database/sql/fakedb_test.go | |
| parent | 99df54f19696e26bea8d6a052d8d91ddb1e4ea65 (diff) | |
| download | go-707a83341b8c7973f4e0fce731fa279c618f233b.tar.xz | |
database/sql: add option to use named parameter in query arguments
Modify the new Context methods to take a name-value driver struct.
This will require more modifications to drivers to use, but will
reduce the overall number of structures that need to be maintained
over time.
Fixes #12381
Change-Id: I30747533ce418a1be5991a0c8767a26e8451adbd
Reviewed-on: https://go-review.googlesource.com/30166
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Diffstat (limited to 'src/database/sql/fakedb_test.go')
| -rw-r--r-- | src/database/sql/fakedb_test.go | 83 |
1 files changed, 65 insertions, 18 deletions
diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go index aaa13a6799..07f50196a5 100644 --- a/src/database/sql/fakedb_test.go +++ b/src/database/sql/fakedb_test.go @@ -5,6 +5,7 @@ package sql import ( + "context" "database/sql/driver" "errors" "fmt" @@ -32,6 +33,7 @@ var _ = log.Printf // where types are: "string", [u]int{8,16,32,64}, "bool" // INSERT|<tablename>|col=val,col2=val2,col3=? // SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=? +// SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2 // // Any of these can be preceded by PANIC|<method>|, to cause the // named method on fakeStmt to panic. @@ -103,6 +105,12 @@ type fakeTx struct { c *fakeConn } +type boundCol struct { + Column string + Placeholder string + Ordinal int +} + type fakeStmt struct { c *fakeConn q string // just for debugging @@ -120,7 +128,7 @@ type fakeStmt struct { colValue []interface{} // used by INSERT (mix of strings and "?" for bound params) placeholders int // used by INSERT/SELECT: number of ? params - whereCol []string // used by SELECT (all placeholders) + whereCol []boundCol // used by SELECT (all placeholders) placeholderConverter []driver.ValueConverter // used by INSERT } @@ -339,18 +347,23 @@ func (c *fakeConn) Close() (err error) { return nil } -func checkSubsetTypes(args []driver.Value) error { - for n, arg := range args { - switch arg.(type) { +func checkSubsetTypes(args []driver.NamedValue) error { + for _, arg := range args { + switch arg.Value.(type) { case int64, float64, bool, nil, []byte, string, time.Time: default: - return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg) + return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value) } } return nil } func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) { + // Ensure that ExecContext is called if available. + panic("ExecContext was not called.") +} + +func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { // This is an optional interface, but it's implemented here // just to check that all the args are of the proper types. // ErrSkip is returned so the caller acts as if we didn't @@ -363,6 +376,11 @@ func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error } func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) { + // Ensure that ExecContext is called if available. + panic("QueryContext was not called.") +} + +func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { // This is an optional interface, but it's implemented here // just to check that all the args are of the proper types. // ErrSkip is returned so the caller acts as if we didn't @@ -403,13 +421,13 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, err stmt.Close() return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) } - if value != "?" { + if !strings.HasPrefix(value, "?") { stmt.Close() return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", stmt.table, column) } - stmt.whereCol = append(stmt.whereCol, column) stmt.placeholders++ + stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders}) } return stmt, nil } @@ -454,7 +472,7 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, err } stmt.colName = append(stmt.colName, column) - if value != "?" { + if !strings.HasPrefix(value, "?") { var subsetVal interface{} // Convert to driver subset type switch ctype { @@ -477,7 +495,7 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, err } else { stmt.placeholders++ stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) - stmt.colValue = append(stmt.colValue, "?") + stmt.colValue = append(stmt.colValue, value) } } return stmt, nil @@ -580,6 +598,9 @@ var errClosed = errors.New("fakedb: statement has been closed") var hookExecBadConn func() bool func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { + panic("Using ExecContext") +} +func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { if s.panic == "Exec" { panic(s.panic) } @@ -620,7 +641,7 @@ func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { // When doInsert is true, add the row to the table. // When doInsert is false do prep-work and error checking, but don't // actually add the row to the table. -func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result, error) { +func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) { db := s.c.db if len(args) != s.placeholders { panic("error in pkg db; should only get here if size is correct") @@ -646,8 +667,18 @@ func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) } var val interface{} - if strvalue, ok := s.colValue[n].(string); ok && strvalue == "?" { - val = args[argPos] + if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") { + if strvalue == "?" { + val = args[argPos].Value + } else { + // Assign value from argument placeholder name. + for _, a := range args { + if a.Name == strvalue { + val = a.Value + break + } + } + } argPos++ } else { val = s.colValue[n] @@ -667,6 +698,10 @@ func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result var hookQueryBadConn func() bool func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { + panic("Use QueryContext") +} + +func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { if s.panic == "Query" { panic(s.panic) } @@ -700,9 +735,9 @@ func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { } if s.table == "magicquery" { - if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" { - if args[0] == "sleep" { - time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond) + if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" { + if args[0].Value == "sleep" { + time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond) } } } @@ -725,8 +760,8 @@ func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { // Process the where clause, skipping non-match rows. This is lazy // and just uses fmt.Sprintf("%v") to test equality. Good enough // for test code. - for widx, wcol := range s.whereCol { - idx := t.columnIndex(wcol) + for _, wcol := range s.whereCol { + idx := t.columnIndex(wcol.Column) if idx == -1 { t.mu.Unlock() return nil, fmt.Errorf("db: invalid where clause column %q", wcol) @@ -736,7 +771,19 @@ func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { // lazy hack to avoid sprintf %v on a []byte tcol = string(bs) } - if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) { + var argValue interface{} + if wcol.Placeholder == "?" { + argValue = args[wcol.Ordinal-1].Value + } else { + // Assign arg value from placeholder name. + for _, a := range args { + if a.Name == wcol.Placeholder { + argValue = a.Value + break + } + } + } + if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) { continue rows } } |
