aboutsummaryrefslogtreecommitdiff
path: root/src/database/sql/convert.go
diff options
context:
space:
mode:
authorDaniel Theophanes <kardianos@gmail.com>2017-03-23 13:17:59 -0700
committerDaniel Theophanes <kardianos@gmail.com>2017-05-18 22:22:31 +0000
commita9bf3b2e19920f6f516bc2e0211ae2e2f7ce395c (patch)
tree3ef9485c394a6713242e556086237a09d6cd92d4 /src/database/sql/convert.go
parent9044cb04f2c0379e907d0b2e944043e81888033e (diff)
downloadgo-a9bf3b2e19920f6f516bc2e0211ae2e2f7ce395c.tar.xz
database/sql: allow drivers to support custom arg types
Previously all arguments were passed through driver.IsValid. This checked arguments against a few fundamental go types and prevented others from being passed in as arguments. The new interface driver.NamedValueChecker may be implemented by both driver.Stmt and driver.Conn. This allows this new interface to completely supersede the driver.ColumnConverter interface as it can be used for checking arguments known to a prepared statement and arbitrary query arguments. The NamedValueChecker may be skipped with driver.ErrSkip after all special cases are exhausted to use the default argument converter. In addition if driver.ErrRemoveArgument is returned the argument will not be passed to the query at all, useful for passing in driver specific per-query options. Add a canonical Out argument wrapper to be passed to OUTPUT parameters. This will unify checks that need to be written in the NameValueChecker. The statement number check is also moved to the argument converter so the NamedValueChecker may remove arguments passed to the query. Fixes #13567 Fixes #18079 Updates #18417 Updates #17834 Updates #16235 Updates #13067 Updates #19797 Change-Id: I89088bd9cca4596a48bba37bfd20d987453ef237 Reviewed-on: https://go-review.googlesource.com/38533 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org> Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org>
Diffstat (limited to 'src/database/sql/convert.go')
-rw-r--r--src/database/sql/convert.go201
1 files changed, 148 insertions, 53 deletions
diff --git a/src/database/sql/convert.go b/src/database/sql/convert.go
index 630a585ab2..4983181fe7 100644
--- a/src/database/sql/convert.go
+++ b/src/database/sql/convert.go
@@ -12,6 +12,7 @@ import (
"fmt"
"reflect"
"strconv"
+ "sync"
"time"
"unicode"
"unicode/utf8"
@@ -37,86 +38,180 @@ func validateNamedValueName(name string) error {
return fmt.Errorf("name %q does not begin with a letter", name)
}
+func driverNumInput(ds *driverStmt) int {
+ ds.Lock()
+ defer ds.Unlock() // in case NumInput panics
+ return ds.si.NumInput()
+}
+
+// ccChecker wraps the driver.ColumnConverter and allows it to be used
+// as if it were a NamedValueChecker. If the driver ColumnConverter
+// is not present then the NamedValueChecker will return driver.ErrSkip.
+type ccChecker struct {
+ sync.Locker
+ cci driver.ColumnConverter
+ want int
+}
+
+func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error {
+ if c.cci == nil {
+ return driver.ErrSkip
+ }
+ // The column converter shouldn't be called on any index
+ // it isn't expecting. The final error will be thrown
+ // in the argument converter loop.
+ index := nv.Ordinal - 1
+ if c.want <= index {
+ return nil
+ }
+
+ // First, see if the value itself knows how to convert
+ // itself to a driver type. For example, a NullString
+ // struct changing into a string or nil.
+ if vr, ok := nv.Value.(driver.Valuer); ok {
+ sv, err := callValuerValue(vr)
+ if err != nil {
+ return err
+ }
+ if !driver.IsValue(sv) {
+ return fmt.Errorf("non-subset type %T returned from Value", sv)
+ }
+ nv.Value = sv
+ }
+
+ // Second, ask the column to sanity check itself. For
+ // example, drivers might use this to make sure that
+ // an int64 values being inserted into a 16-bit
+ // integer field is in range (before getting
+ // truncated), or that a nil can't go into a NOT NULL
+ // column before going across the network to get the
+ // same error.
+ var err error
+ arg := nv.Value
+ c.Lock()
+ nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg)
+ c.Unlock()
+ if err != nil {
+ return err
+ }
+ if !driver.IsValue(nv.Value) {
+ return fmt.Errorf("driver ColumnConverter error converted %T to unsupported type %T", arg, nv.Value)
+ }
+ return nil
+}
+
+// defaultCheckNamedValue wraps the default ColumnConverter to have the same
+// function signature as the CheckNamedValue in the driver.NamedValueChecker
+// interface.
+func defaultCheckNamedValue(nv *driver.NamedValue) (err error) {
+ nv.Value, err = driver.DefaultParameterConverter.ConvertValue(nv.Value)
+ return err
+}
+
// driverArgs converts arguments from callers of Stmt.Exec and
// Stmt.Query into driver Values.
//
// The statement ds may be nil, if no statement is available.
-func driverArgs(ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) {
+func driverArgs(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) {
nvargs := make([]driver.NamedValue, len(args))
+
+ // -1 means the driver doesn't know how to count the number of
+ // placeholders, so we won't sanity check input here and instead let the
+ // driver deal with errors.
+ want := -1
+
var si driver.Stmt
+ var cc ccChecker
if ds != nil {
si = ds.si
+ want = driverNumInput(ds)
+ cc.Locker = ds.Locker
+ cc.want = want
}
- cc, ok := si.(driver.ColumnConverter)
- // Normal path, for a driver.Stmt that is not a ColumnConverter.
+ // Check all types of interfaces from the start.
+ // Drivers may opt to use the NamedValueChecker for special
+ // argument types, then return driver.ErrSkip to pass it along
+ // to the column converter.
+ nvc, ok := si.(driver.NamedValueChecker)
if !ok {
- for n, arg := range args {
- var err error
- nv := &nvargs[n]
- nv.Ordinal = n + 1
- if np, ok := arg.(NamedArg); ok {
- if err := validateNamedValueName(np.Name); err != nil {
- return nil, err
- }
- arg = np.Value
- nvargs[n].Name = np.Name
- }
- nv.Value, err = driver.DefaultParameterConverter.ConvertValue(arg)
-
- if err != nil {
- return nil, fmt.Errorf("sql: converting Exec argument %s type: %v", describeNamedValue(nv), err)
- }
- }
- return nvargs, nil
+ nvc, ok = ci.(driver.NamedValueChecker)
+ }
+ cci, ok := si.(driver.ColumnConverter)
+ if ok {
+ cc.cci = cci
}
- // Let the Stmt convert its own arguments.
- for n, arg := range args {
+ // Loop through all the arguments, checking each one.
+ // If no error is returned simply increment the index
+ // and continue. However if driver.ErrRemoveArgument
+ // is returned the argument is not included in the query
+ // argument list.
+ var err error
+ var n int
+ for _, arg := range args {
nv := &nvargs[n]
- nv.Ordinal = n + 1
if np, ok := arg.(NamedArg); ok {
- if err := validateNamedValueName(np.Name); err != nil {
+ if err = validateNamedValueName(np.Name); err != nil {
return nil, err
}
arg = np.Value
nv.Name = np.Name
}
- // First, see if the value itself knows how to convert
- // itself to a driver type. For example, a NullString
- // struct changing into a string or nil.
- if vr, ok := arg.(driver.Valuer); ok {
- sv, err := callValuerValue(vr)
- if err != nil {
- return nil, fmt.Errorf("sql: argument %s from Value: %v", describeNamedValue(nv), err)
- }
- if !driver.IsValue(sv) {
- return nil, fmt.Errorf("sql: argument %s: non-subset type %T returned from Value", describeNamedValue(nv), sv)
- }
- arg = sv
+ nv.Ordinal = n + 1
+ nv.Value = arg
+
+ // Checking sequence has four routes:
+ // A: 1. Default
+ // B: 1. NamedValueChecker 2. Column Converter 3. Default
+ // C: 1. NamedValueChecker 3. Default
+ // D: 1. Column Converter 2. Default
+ //
+ // The only time a Column Converter is called is first
+ // or after NamedValueConverter. If first it is handled before
+ // the nextCheck label. Thus for repeats tries only when the
+ // NamedValueConverter is selected should the Column Converter
+ // be used in the retry.
+ checker := defaultCheckNamedValue
+ nextCC := false
+ switch {
+ case nvc != nil:
+ nextCC = cci != nil
+ checker = nvc.CheckNamedValue
+ case cci != nil:
+ checker = cc.CheckNamedValue
}
- // Second, ask the column to sanity check itself. For
- // example, drivers might use this to make sure that
- // an int64 values being inserted into a 16-bit
- // integer field is in range (before getting
- // truncated), or that a nil can't go into a NOT NULL
- // column before going across the network to get the
- // same error.
- var err error
- ds.Lock()
- nv.Value, err = cc.ColumnConverter(n).ConvertValue(arg)
- ds.Unlock()
- if err != nil {
+ nextCheck:
+ err = checker(nv)
+ switch err {
+ case nil:
+ n++
+ continue
+ case driver.ErrRemoveArgument:
+ nvargs = nvargs[:len(nvargs)-1]
+ continue
+ case driver.ErrSkip:
+ if nextCC {
+ nextCC = false
+ checker = cc.CheckNamedValue
+ } else {
+ checker = defaultCheckNamedValue
+ }
+ goto nextCheck
+ default:
return nil, fmt.Errorf("sql: converting argument %s type: %v", describeNamedValue(nv), err)
}
- if !driver.IsValue(nv.Value) {
- return nil, fmt.Errorf("sql: for argument %s, driver ColumnConverter error converted %T to unsupported type %T",
- describeNamedValue(nv), arg, nv.Value)
- }
+ }
+
+ // Check the length of arguments after convertion to allow for omitted
+ // arguments.
+ if want != -1 && len(nvargs) != want {
+ return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(nvargs))
}
return nvargs, nil
+
}
// convertAssign copies to dest the value in src, converting it if possible.