aboutsummaryrefslogtreecommitdiff
path: root/src/database/sql/fakedb_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/database/sql/fakedb_test.go')
-rw-r--r--src/database/sql/fakedb_test.go59
1 files changed, 50 insertions, 9 deletions
diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go
index 07f50196a5..c42f23208f 100644
--- a/src/database/sql/fakedb_test.go
+++ b/src/database/sql/fakedb_test.go
@@ -11,6 +11,7 @@ import (
"fmt"
"io"
"log"
+ "reflect"
"sort"
"strconv"
"strings"
@@ -405,6 +406,7 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, err
return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
}
stmt.table = parts[0]
+
stmt.colName = strings.Split(parts[1], ",")
for n, colspec := range strings.Split(parts[2], ",") {
if colspec == "" {
@@ -725,6 +727,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
setMRows := make([][]*row, 0, 1)
setColumns := make([][]string, 0, 1)
+ setColType := make([][]string, 0, 1)
for {
db.mu.Lock()
@@ -794,10 +797,16 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
mrows = append(mrows, mrow)
}
+ var colType []string
+ for _, column := range s.colName {
+ colType = append(colType, t.coltype[t.columnIndex(column)])
+ }
+
t.mu.Unlock()
setMRows = append(setMRows, mrows)
setColumns = append(setColumns, s.colName)
+ setColType = append(setColType, colType)
if s.next == nil {
break
@@ -806,10 +815,11 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
}
cursor := &rowsCursor{
- posRow: -1,
- rows: setMRows,
- cols: setColumns,
- errPos: -1,
+ posRow: -1,
+ rows: setMRows,
+ cols: setColumns,
+ colType: setColType,
+ errPos: -1,
}
return cursor, nil
}
@@ -844,11 +854,12 @@ func (tx *fakeTx) Rollback() error {
}
type rowsCursor struct {
- cols [][]string
- posSet int
- posRow int
- rows [][]*row
- closed bool
+ cols [][]string
+ colType [][]string
+ posSet int
+ posRow int
+ rows [][]*row
+ closed bool
// errPos and err are for making Next return early with error.
errPos int
@@ -874,6 +885,10 @@ func (rc *rowsCursor) Columns() []string {
return rc.cols[rc.posSet]
}
+func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
+ return colTypeToReflectType(rc.colType[rc.posSet][index])
+}
+
var rowsCursorNextHook func(dest []driver.Value) error
func (rc *rowsCursor) Next(dest []driver.Value) error {
@@ -980,3 +995,29 @@ func converterForType(typ string) driver.ValueConverter {
}
panic("invalid fakedb column type of " + typ)
}
+
+func colTypeToReflectType(typ string) reflect.Type {
+ switch typ {
+ case "bool":
+ return reflect.TypeOf(false)
+ case "nullbool":
+ return reflect.TypeOf(NullBool{})
+ case "int32":
+ return reflect.TypeOf(int32(0))
+ case "string":
+ return reflect.TypeOf("")
+ case "nullstring":
+ return reflect.TypeOf(NullString{})
+ case "int64":
+ return reflect.TypeOf(int64(0))
+ case "nullint64":
+ return reflect.TypeOf(NullInt64{})
+ case "float64":
+ return reflect.TypeOf(float64(0))
+ case "nullfloat64":
+ return reflect.TypeOf(NullFloat64{})
+ case "datetime":
+ return reflect.TypeOf(time.Time{})
+ }
+ panic("invalid fakedb column type of " + typ)
+}