diff options
Diffstat (limited to 'src/database/sql/fakedb_test.go')
| -rw-r--r-- | src/database/sql/fakedb_test.go | 59 |
1 files changed, 50 insertions, 9 deletions
diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go index 07f50196a5..c42f23208f 100644 --- a/src/database/sql/fakedb_test.go +++ b/src/database/sql/fakedb_test.go @@ -11,6 +11,7 @@ import ( "fmt" "io" "log" + "reflect" "sort" "strconv" "strings" @@ -405,6 +406,7 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, err return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) } stmt.table = parts[0] + stmt.colName = strings.Split(parts[1], ",") for n, colspec := range strings.Split(parts[2], ",") { if colspec == "" { @@ -725,6 +727,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ( setMRows := make([][]*row, 0, 1) setColumns := make([][]string, 0, 1) + setColType := make([][]string, 0, 1) for { db.mu.Lock() @@ -794,10 +797,16 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ( mrows = append(mrows, mrow) } + var colType []string + for _, column := range s.colName { + colType = append(colType, t.coltype[t.columnIndex(column)]) + } + t.mu.Unlock() setMRows = append(setMRows, mrows) setColumns = append(setColumns, s.colName) + setColType = append(setColType, colType) if s.next == nil { break @@ -806,10 +815,11 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ( } cursor := &rowsCursor{ - posRow: -1, - rows: setMRows, - cols: setColumns, - errPos: -1, + posRow: -1, + rows: setMRows, + cols: setColumns, + colType: setColType, + errPos: -1, } return cursor, nil } @@ -844,11 +854,12 @@ func (tx *fakeTx) Rollback() error { } type rowsCursor struct { - cols [][]string - posSet int - posRow int - rows [][]*row - closed bool + cols [][]string + colType [][]string + posSet int + posRow int + rows [][]*row + closed bool // errPos and err are for making Next return early with error. errPos int @@ -874,6 +885,10 @@ func (rc *rowsCursor) Columns() []string { return rc.cols[rc.posSet] } +func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type { + return colTypeToReflectType(rc.colType[rc.posSet][index]) +} + var rowsCursorNextHook func(dest []driver.Value) error func (rc *rowsCursor) Next(dest []driver.Value) error { @@ -980,3 +995,29 @@ func converterForType(typ string) driver.ValueConverter { } panic("invalid fakedb column type of " + typ) } + +func colTypeToReflectType(typ string) reflect.Type { + switch typ { + case "bool": + return reflect.TypeOf(false) + case "nullbool": + return reflect.TypeOf(NullBool{}) + case "int32": + return reflect.TypeOf(int32(0)) + case "string": + return reflect.TypeOf("") + case "nullstring": + return reflect.TypeOf(NullString{}) + case "int64": + return reflect.TypeOf(int64(0)) + case "nullint64": + return reflect.TypeOf(NullInt64{}) + case "float64": + return reflect.TypeOf(float64(0)) + case "nullfloat64": + return reflect.TypeOf(NullFloat64{}) + case "datetime": + return reflect.TypeOf(time.Time{}) + } + panic("invalid fakedb column type of " + typ) +} |
