diff options
Diffstat (limited to 'src/database/sql/convert.go')
| -rw-r--r-- | src/database/sql/convert.go | 201 |
1 files changed, 148 insertions, 53 deletions
diff --git a/src/database/sql/convert.go b/src/database/sql/convert.go index 630a585ab2..4983181fe7 100644 --- a/src/database/sql/convert.go +++ b/src/database/sql/convert.go @@ -12,6 +12,7 @@ import ( "fmt" "reflect" "strconv" + "sync" "time" "unicode" "unicode/utf8" @@ -37,86 +38,180 @@ func validateNamedValueName(name string) error { return fmt.Errorf("name %q does not begin with a letter", name) } +func driverNumInput(ds *driverStmt) int { + ds.Lock() + defer ds.Unlock() // in case NumInput panics + return ds.si.NumInput() +} + +// ccChecker wraps the driver.ColumnConverter and allows it to be used +// as if it were a NamedValueChecker. If the driver ColumnConverter +// is not present then the NamedValueChecker will return driver.ErrSkip. +type ccChecker struct { + sync.Locker + cci driver.ColumnConverter + want int +} + +func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error { + if c.cci == nil { + return driver.ErrSkip + } + // The column converter shouldn't be called on any index + // it isn't expecting. The final error will be thrown + // in the argument converter loop. + index := nv.Ordinal - 1 + if c.want <= index { + return nil + } + + // First, see if the value itself knows how to convert + // itself to a driver type. For example, a NullString + // struct changing into a string or nil. + if vr, ok := nv.Value.(driver.Valuer); ok { + sv, err := callValuerValue(vr) + if err != nil { + return err + } + if !driver.IsValue(sv) { + return fmt.Errorf("non-subset type %T returned from Value", sv) + } + nv.Value = sv + } + + // Second, ask the column to sanity check itself. For + // example, drivers might use this to make sure that + // an int64 values being inserted into a 16-bit + // integer field is in range (before getting + // truncated), or that a nil can't go into a NOT NULL + // column before going across the network to get the + // same error. + var err error + arg := nv.Value + c.Lock() + nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg) + c.Unlock() + if err != nil { + return err + } + if !driver.IsValue(nv.Value) { + return fmt.Errorf("driver ColumnConverter error converted %T to unsupported type %T", arg, nv.Value) + } + return nil +} + +// defaultCheckNamedValue wraps the default ColumnConverter to have the same +// function signature as the CheckNamedValue in the driver.NamedValueChecker +// interface. +func defaultCheckNamedValue(nv *driver.NamedValue) (err error) { + nv.Value, err = driver.DefaultParameterConverter.ConvertValue(nv.Value) + return err +} + // driverArgs converts arguments from callers of Stmt.Exec and // Stmt.Query into driver Values. // // The statement ds may be nil, if no statement is available. -func driverArgs(ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) { +func driverArgs(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) { nvargs := make([]driver.NamedValue, len(args)) + + // -1 means the driver doesn't know how to count the number of + // placeholders, so we won't sanity check input here and instead let the + // driver deal with errors. + want := -1 + var si driver.Stmt + var cc ccChecker if ds != nil { si = ds.si + want = driverNumInput(ds) + cc.Locker = ds.Locker + cc.want = want } - cc, ok := si.(driver.ColumnConverter) - // Normal path, for a driver.Stmt that is not a ColumnConverter. + // Check all types of interfaces from the start. + // Drivers may opt to use the NamedValueChecker for special + // argument types, then return driver.ErrSkip to pass it along + // to the column converter. + nvc, ok := si.(driver.NamedValueChecker) if !ok { - for n, arg := range args { - var err error - nv := &nvargs[n] - nv.Ordinal = n + 1 - if np, ok := arg.(NamedArg); ok { - if err := validateNamedValueName(np.Name); err != nil { - return nil, err - } - arg = np.Value - nvargs[n].Name = np.Name - } - nv.Value, err = driver.DefaultParameterConverter.ConvertValue(arg) - - if err != nil { - return nil, fmt.Errorf("sql: converting Exec argument %s type: %v", describeNamedValue(nv), err) - } - } - return nvargs, nil + nvc, ok = ci.(driver.NamedValueChecker) + } + cci, ok := si.(driver.ColumnConverter) + if ok { + cc.cci = cci } - // Let the Stmt convert its own arguments. - for n, arg := range args { + // Loop through all the arguments, checking each one. + // If no error is returned simply increment the index + // and continue. However if driver.ErrRemoveArgument + // is returned the argument is not included in the query + // argument list. + var err error + var n int + for _, arg := range args { nv := &nvargs[n] - nv.Ordinal = n + 1 if np, ok := arg.(NamedArg); ok { - if err := validateNamedValueName(np.Name); err != nil { + if err = validateNamedValueName(np.Name); err != nil { return nil, err } arg = np.Value nv.Name = np.Name } - // First, see if the value itself knows how to convert - // itself to a driver type. For example, a NullString - // struct changing into a string or nil. - if vr, ok := arg.(driver.Valuer); ok { - sv, err := callValuerValue(vr) - if err != nil { - return nil, fmt.Errorf("sql: argument %s from Value: %v", describeNamedValue(nv), err) - } - if !driver.IsValue(sv) { - return nil, fmt.Errorf("sql: argument %s: non-subset type %T returned from Value", describeNamedValue(nv), sv) - } - arg = sv + nv.Ordinal = n + 1 + nv.Value = arg + + // Checking sequence has four routes: + // A: 1. Default + // B: 1. NamedValueChecker 2. Column Converter 3. Default + // C: 1. NamedValueChecker 3. Default + // D: 1. Column Converter 2. Default + // + // The only time a Column Converter is called is first + // or after NamedValueConverter. If first it is handled before + // the nextCheck label. Thus for repeats tries only when the + // NamedValueConverter is selected should the Column Converter + // be used in the retry. + checker := defaultCheckNamedValue + nextCC := false + switch { + case nvc != nil: + nextCC = cci != nil + checker = nvc.CheckNamedValue + case cci != nil: + checker = cc.CheckNamedValue } - // Second, ask the column to sanity check itself. For - // example, drivers might use this to make sure that - // an int64 values being inserted into a 16-bit - // integer field is in range (before getting - // truncated), or that a nil can't go into a NOT NULL - // column before going across the network to get the - // same error. - var err error - ds.Lock() - nv.Value, err = cc.ColumnConverter(n).ConvertValue(arg) - ds.Unlock() - if err != nil { + nextCheck: + err = checker(nv) + switch err { + case nil: + n++ + continue + case driver.ErrRemoveArgument: + nvargs = nvargs[:len(nvargs)-1] + continue + case driver.ErrSkip: + if nextCC { + nextCC = false + checker = cc.CheckNamedValue + } else { + checker = defaultCheckNamedValue + } + goto nextCheck + default: return nil, fmt.Errorf("sql: converting argument %s type: %v", describeNamedValue(nv), err) } - if !driver.IsValue(nv.Value) { - return nil, fmt.Errorf("sql: for argument %s, driver ColumnConverter error converted %T to unsupported type %T", - describeNamedValue(nv), arg, nv.Value) - } + } + + // Check the length of arguments after convertion to allow for omitted + // arguments. + if want != -1 && len(nvargs) != want { + return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(nvargs)) } return nvargs, nil + } // convertAssign copies to dest the value in src, converting it if possible. |
