aboutsummaryrefslogtreecommitdiff
path: root/src/database/sql
diff options
context:
space:
mode:
Diffstat (limited to 'src/database/sql')
-rw-r--r--src/database/sql/sql.go35
-rw-r--r--src/database/sql/sql_test.go149
2 files changed, 170 insertions, 14 deletions
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go
index 4be450ca87..c8ec91c1ec 100644
--- a/src/database/sql/sql.go
+++ b/src/database/sql/sql.go
@@ -202,8 +202,9 @@ func (ns *NullString) Scan(value any) error {
ns.String, ns.Valid = "", false
return nil
}
- ns.Valid = true
- return convertAssign(&ns.String, value)
+ err := convertAssign(&ns.String, value)
+ ns.Valid = err == nil
+ return err
}
// Value implements the [driver.Valuer] interface.
@@ -228,8 +229,9 @@ func (n *NullInt64) Scan(value any) error {
n.Int64, n.Valid = 0, false
return nil
}
- n.Valid = true
- return convertAssign(&n.Int64, value)
+ err := convertAssign(&n.Int64, value)
+ n.Valid = err == nil
+ return err
}
// Value implements the [driver.Valuer] interface.
@@ -254,8 +256,9 @@ func (n *NullInt32) Scan(value any) error {
n.Int32, n.Valid = 0, false
return nil
}
- n.Valid = true
- return convertAssign(&n.Int32, value)
+ err := convertAssign(&n.Int32, value)
+ n.Valid = err == nil
+ return err
}
// Value implements the [driver.Valuer] interface.
@@ -334,8 +337,9 @@ func (n *NullFloat64) Scan(value any) error {
n.Float64, n.Valid = 0, false
return nil
}
- n.Valid = true
- return convertAssign(&n.Float64, value)
+ err := convertAssign(&n.Float64, value)
+ n.Valid = err == nil
+ return err
}
// Value implements the [driver.Valuer] interface.
@@ -360,8 +364,9 @@ func (n *NullBool) Scan(value any) error {
n.Bool, n.Valid = false, false
return nil
}
- n.Valid = true
- return convertAssign(&n.Bool, value)
+ err := convertAssign(&n.Bool, value)
+ n.Valid = err == nil
+ return err
}
// Value implements the [driver.Valuer] interface.
@@ -386,8 +391,9 @@ func (n *NullTime) Scan(value any) error {
n.Time, n.Valid = time.Time{}, false
return nil
}
- n.Valid = true
- return convertAssign(&n.Time, value)
+ err := convertAssign(&n.Time, value)
+ n.Valid = err == nil
+ return err
}
// Value implements the [driver.Valuer] interface.
@@ -422,8 +428,9 @@ func (n *Null[T]) Scan(value any) error {
n.V, n.Valid = *new(T), false
return nil
}
- n.Valid = true
- return convertAssign(&n.V, value)
+ err := convertAssign(&n.V, value)
+ n.Valid = err == nil
+ return err
}
func (n Null[T]) Value() (driver.Value, error) {
diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go
index e8a6560097..5f093a2d6d 100644
--- a/src/database/sql/sql_test.go
+++ b/src/database/sql/sql_test.go
@@ -5086,3 +5086,152 @@ type unknownInputsValueConverter struct{}
func (unknownInputsValueConverter) ConvertValue(v any) (driver.Value, error) {
return "string", nil
}
+
+func TestNullTypeScanErrorConsistency(t *testing.T) {
+ // Issue #45662: Null* types should have Valid=false when Scan returns an error.
+ // Previously, Valid was set to true before convertAssign was called,
+ // so if conversion failed, Valid would still be true despite the error.
+
+ tests := []struct {
+ name string
+ scanner Scanner
+ input any
+ wantErr bool
+ }{
+ {
+ name: "NullInt32 with invalid input",
+ scanner: &NullInt32{},
+ input: []byte("not_a_number"),
+ wantErr: true,
+ },
+ {
+ name: "NullInt64 with invalid input",
+ scanner: &NullInt64{},
+ input: []byte("not_a_number"),
+ wantErr: true,
+ },
+ {
+ name: "NullFloat64 with invalid input",
+ scanner: &NullFloat64{},
+ input: []byte("not_a_float"),
+ wantErr: true,
+ },
+ {
+ name: "NullBool with invalid input",
+ scanner: &NullBool{},
+ input: []byte("not_a_bool"),
+ wantErr: true,
+ },
+ // Valid cases should still work
+ {
+ name: "NullInt32 with valid input",
+ scanner: &NullInt32{},
+ input: int64(42),
+ wantErr: false,
+ },
+ {
+ name: "NullInt64 with valid input",
+ scanner: &NullInt64{},
+ input: int64(42),
+ wantErr: false,
+ },
+ {
+ name: "NullFloat64 with valid input",
+ scanner: &NullFloat64{},
+ input: float64(3.14),
+ wantErr: false,
+ },
+ {
+ name: "NullBool with valid input",
+ scanner: &NullBool{},
+ input: true,
+ wantErr: false,
+ },
+ {
+ name: "NullString with valid input",
+ scanner: &NullString{},
+ input: "hello",
+ wantErr: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := tt.scanner.Scan(tt.input)
+
+ // Check that error matches expectation
+ if (err != nil) != tt.wantErr {
+ t.Errorf("Scan() error = %v, wantErr %v", err, tt.wantErr)
+ }
+
+ // The key invariant: Valid should be the opposite of whether we got an error
+ // (assuming non-nil input)
+ var valid bool
+ switch s := tt.scanner.(type) {
+ case *NullInt32:
+ valid = s.Valid
+ case *NullInt64:
+ valid = s.Valid
+ case *NullFloat64:
+ valid = s.Valid
+ case *NullBool:
+ valid = s.Valid
+ case *NullString:
+ valid = s.Valid
+ case *NullTime:
+ valid = s.Valid
+ }
+
+ if err != nil && valid {
+ t.Errorf("Scan() returned error but Valid=true; want Valid=false when err!=nil")
+ }
+ if err == nil && !valid {
+ t.Errorf("Scan() returned nil error but Valid=false; want Valid=true when err==nil")
+ }
+ })
+ }
+}
+
+// TestNullTypeScanNil verifies that scanning nil sets Valid=false without error.
+func TestNullTypeScanNil(t *testing.T) {
+ tests := []struct {
+ name string
+ scanner Scanner
+ }{
+ {"NullString", &NullString{String: "preset", Valid: true}},
+ {"NullInt64", &NullInt64{Int64: 42, Valid: true}},
+ {"NullInt32", &NullInt32{Int32: 42, Valid: true}},
+ {"NullFloat64", &NullFloat64{Float64: 3.14, Valid: true}},
+ {"NullBool", &NullBool{Bool: true, Valid: true}},
+ {"NullTime", &NullTime{Time: time.Now(), Valid: true}},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := tt.scanner.Scan(nil)
+ if err != nil {
+ t.Errorf("Scan(nil) error = %v; want nil", err)
+ }
+
+ var valid bool
+ switch s := tt.scanner.(type) {
+ case *NullString:
+ valid = s.Valid
+ case *NullInt64:
+ valid = s.Valid
+ case *NullInt32:
+ valid = s.Valid
+ case *NullFloat64:
+ valid = s.Valid
+ case *NullBool:
+ valid = s.Valid
+ case *NullTime:
+ valid = s.Valid
+ }
+
+ if valid {
+ t.Errorf("Scan(nil) left Valid=true; want Valid=false")
+ }
+ })
+ }
+}