diff options
Diffstat (limited to 'src/database/sql/fakedb_test.go')
| -rw-r--r-- | src/database/sql/fakedb_test.go | 31 |
1 files changed, 22 insertions, 9 deletions
diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go index 4b15f5bec7..1c95c35a68 100644 --- a/src/database/sql/fakedb_test.go +++ b/src/database/sql/fakedb_test.go @@ -58,9 +58,10 @@ type fakeDriver struct { type fakeDB struct { name string - mu sync.Mutex - tables map[string]*table - badConn bool + mu sync.Mutex + tables map[string]*table + badConn bool + allowAny bool } type table struct { @@ -352,12 +353,14 @@ func (c *fakeConn) Close() (err error) { return nil } -func checkSubsetTypes(args []driver.NamedValue) error { +func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error { for _, arg := range args { switch arg.Value.(type) { case int64, float64, bool, nil, []byte, string, time.Time: default: - return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value) + if !allowAny { + return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value) + } } } return nil @@ -373,7 +376,7 @@ func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver. // just to check that all the args are of the proper types. // ErrSkip is returned so the caller acts as if we didn't // implement this at all. - err := checkSubsetTypes(args) + err := checkSubsetTypes(c.db.allowAny, args) if err != nil { return nil, err } @@ -390,7 +393,7 @@ func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver // just to check that all the args are of the proper types. // ErrSkip is returned so the caller acts as if we didn't // implement this at all. - err := checkSubsetTypes(args) + err := checkSubsetTypes(c.db.allowAny, args) if err != nil { return nil, err } @@ -642,7 +645,7 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d return nil, driver.ErrBadConn } - err := checkSubsetTypes(args) + err := checkSubsetTypes(s.c.db.allowAny, args) if err != nil { return nil, err } @@ -753,7 +756,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ( return nil, driver.ErrBadConn } - err := checkSubsetTypes(args) + err := checkSubsetTypes(s.c.db.allowAny, args) if err != nil { return nil, err } @@ -1004,6 +1007,12 @@ func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) { return fmt.Sprintf("%v", v), nil } +type anyTypeConverter struct{} + +func (anyTypeConverter) ConvertValue(v interface{}) (driver.Value, error) { + return v, nil +} + func converterForType(typ string) driver.ValueConverter { switch typ { case "bool": @@ -1030,6 +1039,8 @@ func converterForType(typ string) driver.ValueConverter { return driver.Null{Converter: driver.DefaultParameterConverter} case "datetime": return driver.DefaultParameterConverter + case "any": + return anyTypeConverter{} } panic("invalid fakedb column type of " + typ) } @@ -1056,6 +1067,8 @@ func colTypeToReflectType(typ string) reflect.Type { return reflect.TypeOf(NullFloat64{}) case "datetime": return reflect.TypeOf(time.Time{}) + case "any": + return reflect.TypeOf(new(interface{})).Elem() } panic("invalid fakedb column type of " + typ) } |
