diff --git a/spanner/value.go b/spanner/value.go index 2a8225db37f..55fc4c270bf 100644 --- a/spanner/value.go +++ b/spanner/value.go @@ -18,6 +18,8 @@ package spanner import ( "bytes" + "database/sql" + "database/sql/driver" "encoding/base64" "encoding/json" "fmt" @@ -210,6 +212,43 @@ func (n *NullInt64) UnmarshalJSON(payload []byte) error { return nil } +// Value implements the driver.Valuer interface. +func (n NullInt64) Value() (driver.Value, error) { + if n.IsNull() { + return nil, nil + } + return n.Int64, nil +} + +// Scan implements the sql.Scanner interface. +func (n *NullInt64) Scan(value interface{}) error { + if value == nil { + n.Int64, n.Valid = 0, false + return nil + } + n.Valid = true + switch p := value.(type) { + default: + return spannerErrorf(codes.InvalidArgument, "invalid type for NullInt64: %v", p) + case *int64: + n.Int64 = *p + case int64: + n.Int64 = p + case *NullInt64: + n.Int64 = p.Int64 + n.Valid = p.Valid + case NullInt64: + n.Int64 = p.Int64 + n.Valid = p.Valid + } + return nil +} + +// GormDataType is used by gorm to determine the default data type for fields with this type. +func (n NullInt64) GormDataType() string { + return "INT64" +} + // NullString represents a Cloud Spanner STRING that may be NULL. type NullString struct { StringVal string // StringVal contains the value when it is non-NULL, and an empty string when NULL. @@ -256,6 +295,43 @@ func (n *NullString) UnmarshalJSON(payload []byte) error { return nil } +// Value implements the driver.Valuer interface. +func (n NullString) Value() (driver.Value, error) { + if n.IsNull() { + return nil, nil + } + return n.StringVal, nil +} + +// Scan implements the sql.Scanner interface. +func (n *NullString) Scan(value interface{}) error { + if value == nil { + n.StringVal, n.Valid = "", false + return nil + } + n.Valid = true + switch p := value.(type) { + default: + return spannerErrorf(codes.InvalidArgument, "invalid type for NullString: %v", p) + case *string: + n.StringVal = *p + case string: + n.StringVal = p + case *NullString: + n.StringVal = p.StringVal + n.Valid = p.Valid + case NullString: + n.StringVal = p.StringVal + n.Valid = p.Valid + } + return nil +} + +// GormDataType is used by gorm to determine the default data type for fields with this type. +func (n NullString) GormDataType() string { + return "STRING(MAX)" +} + // NullFloat64 represents a Cloud Spanner FLOAT64 that may be NULL. type NullFloat64 struct { Float64 float64 // Float64 contains the value when it is non-NULL, and zero when NULL. @@ -302,6 +378,43 @@ func (n *NullFloat64) UnmarshalJSON(payload []byte) error { return nil } +// Value implements the driver.Valuer interface. +func (n NullFloat64) Value() (driver.Value, error) { + if n.IsNull() { + return nil, nil + } + return n.Float64, nil +} + +// Scan implements the sql.Scanner interface. +func (n *NullFloat64) Scan(value interface{}) error { + if value == nil { + n.Float64, n.Valid = 0, false + return nil + } + n.Valid = true + switch p := value.(type) { + default: + return spannerErrorf(codes.InvalidArgument, "invalid type for NullFloat64: %v", p) + case *float64: + n.Float64 = *p + case float64: + n.Float64 = p + case *NullFloat64: + n.Float64 = p.Float64 + n.Valid = p.Valid + case NullFloat64: + n.Float64 = p.Float64 + n.Valid = p.Valid + } + return nil +} + +// GormDataType is used by gorm to determine the default data type for fields with this type. +func (n NullFloat64) GormDataType() string { + return "FLOAT64" +} + // NullBool represents a Cloud Spanner BOOL that may be NULL. type NullBool struct { Bool bool // Bool contains the value when it is non-NULL, and false when NULL. @@ -348,6 +461,43 @@ func (n *NullBool) UnmarshalJSON(payload []byte) error { return nil } +// Value implements the driver.Valuer interface. +func (n NullBool) Value() (driver.Value, error) { + if n.IsNull() { + return nil, nil + } + return n.Bool, nil +} + +// Scan implements the sql.Scanner interface. +func (n *NullBool) Scan(value interface{}) error { + if value == nil { + n.Bool, n.Valid = false, false + return nil + } + n.Valid = true + switch p := value.(type) { + default: + return spannerErrorf(codes.InvalidArgument, "invalid type for NullBool: %v", p) + case *bool: + n.Bool = *p + case bool: + n.Bool = p + case *NullBool: + n.Bool = p.Bool + n.Valid = p.Valid + case NullBool: + n.Bool = p.Bool + n.Valid = p.Valid + } + return nil +} + +// GormDataType is used by gorm to determine the default data type for fields with this type. +func (n NullBool) GormDataType() string { + return "BOOL" +} + // NullTime represents a Cloud Spanner TIMESTAMP that may be null. type NullTime struct { Time time.Time // Time contains the value when it is non-NULL, and a zero time.Time when NULL. @@ -399,6 +549,43 @@ func (n *NullTime) UnmarshalJSON(payload []byte) error { return nil } +// Value implements the driver.Valuer interface. +func (n NullTime) Value() (driver.Value, error) { + if n.IsNull() { + return nil, nil + } + return n.Time, nil +} + +// Scan implements the sql.Scanner interface. +func (n *NullTime) Scan(value interface{}) error { + if value == nil { + n.Time, n.Valid = time.Time{}, false + return nil + } + n.Valid = true + switch p := value.(type) { + default: + return spannerErrorf(codes.InvalidArgument, "invalid type for NullTime: %v", p) + case *time.Time: + n.Time = *p + case time.Time: + n.Time = p + case *NullTime: + n.Time = p.Time + n.Valid = p.Valid + case NullTime: + n.Time = p.Time + n.Valid = p.Valid + } + return nil +} + +// GormDataType is used by gorm to determine the default data type for fields with this type. +func (n NullTime) GormDataType() string { + return "TIMESTAMP" +} + // NullDate represents a Cloud Spanner DATE that may be null. type NullDate struct { Date civil.Date // Date contains the value when it is non-NULL, and a zero civil.Date when NULL. @@ -450,6 +637,43 @@ func (n *NullDate) UnmarshalJSON(payload []byte) error { return nil } +// Value implements the driver.Valuer interface. +func (n NullDate) Value() (driver.Value, error) { + if n.IsNull() { + return nil, nil + } + return n.Date, nil +} + +// Scan implements the sql.Scanner interface. +func (n *NullDate) Scan(value interface{}) error { + if value == nil { + n.Date, n.Valid = civil.Date{}, false + return nil + } + n.Valid = true + switch p := value.(type) { + default: + return spannerErrorf(codes.InvalidArgument, "invalid type for NullDate: %v", p) + case *civil.Date: + n.Date = *p + case civil.Date: + n.Date = p + case *NullDate: + n.Date = p.Date + n.Valid = p.Valid + case NullDate: + n.Date = p.Date + n.Valid = p.Valid + } + return nil +} + +// GormDataType is used by gorm to determine the default data type for fields with this type. +func (n NullDate) GormDataType() string { + return "DATE" +} + // NullNumeric represents a Cloud Spanner Numeric that may be NULL. type NullNumeric struct { Numeric big.Rat // Numeric contains the value when it is non-NULL, and a zero big.Rat when NULL. @@ -501,10 +725,52 @@ func (n *NullNumeric) UnmarshalJSON(payload []byte) error { return nil } +// Value implements the driver.Valuer interface. +func (n NullNumeric) Value() (driver.Value, error) { + if n.IsNull() { + return nil, nil + } + return n.Numeric, nil +} + +// Scan implements the sql.Scanner interface. +func (n *NullNumeric) Scan(value interface{}) error { + if value == nil { + n.Numeric, n.Valid = big.Rat{}, false + return nil + } + n.Valid = true + switch p := value.(type) { + default: + return spannerErrorf(codes.InvalidArgument, "invalid type for NullNumeric: %v", p) + case *big.Rat: + n.Numeric = *p + case big.Rat: + n.Numeric = p + case *NullNumeric: + n.Numeric = p.Numeric + n.Valid = p.Valid + case NullNumeric: + n.Numeric = p.Numeric + n.Valid = p.Valid + } + return nil +} + +// GormDataType is used by gorm to determine the default data type for fields with this type. +func (n NullNumeric) GormDataType() string { + return "NUMERIC" +} + // NullJSON represents a Cloud Spanner JSON that may be NULL. // // This type must always be used when encoding values to a JSON column in Cloud // Spanner. +// +// NullJSON does not implement the driver.Valuer and sql.Scanner interfaces, as +// the underlying value can be anything. This means that the type NullJSON must +// also be used when calling sql.Row#Scan(dest ...interface{}) for a JSON +// column. type NullJSON struct { Value interface{} // Val contains the value when it is non-NULL, and nil when NULL. Valid bool // Valid is true if Json is not NULL. @@ -554,6 +820,11 @@ func (n *NullJSON) UnmarshalJSON(payload []byte) error { return nil } +// GormDataType is used by gorm to determine the default data type for fields with this type. +func (n NullJSON) GormDataType() string { + return "JSON" +} + // NullRow represents a Cloud Spanner STRUCT that may be NULL. // See also the document for Row. // Note that NullRow is not a valid Cloud Spanner column Type. @@ -700,7 +971,12 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { return err } *p = x - case *NullString, **string: + case *NullString, **string, *sql.NullString: + // Most Null* types are automatically supported for both spanner.Null* and sql.Null* types, except for + // NullString, and we need to add explicit support for it here. The reason that the other types are + // automatically supported is that they use the same field names (e.g. spanner.NullBool and sql.NullBool both + // contain the fields Valid and Bool). spanner.NullString has a field StringVal, sql.NullString has a field + // String. if p == nil { return errNilDst(p) } @@ -713,6 +989,8 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { *sp = NullString{} case **string: *sp = nil + case *sql.NullString: + *sp = sql.NullString{} } break } @@ -726,6 +1004,9 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { sp.StringVal = x case **string: *sp = &x + case *sql.NullString: + sp.Valid = true + sp.String = x } case *[]NullString, *[]*string: if p == nil { @@ -2737,6 +3018,11 @@ func encodeValue(v interface{}) (*proto3.Value, *sppb.Type, error) { return encodeValue(v.StringVal) } pt = stringType() + case sql.NullString: + if v.Valid { + return encodeValue(v.String) + } + pt = stringType() case []string: if v != nil { pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] })