Skip to content

Commit

Permalink
fix(spanner): allow decoding null values to spanner.Decoder (#4558)
Browse files Browse the repository at this point in the history
Allow NULL values from the database to be passed in to the DecodeSpanner method
of a struct that implements spanner.Decoder.

Fixes #4552
  • Loading branch information
olavloite committed Aug 5, 2021
1 parent ed73554 commit 45ddaca
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 3 deletions.
22 changes: 20 additions & 2 deletions spanner/value.go
Expand Up @@ -1330,7 +1330,7 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error {
// Check if the pointer is a custom type that implements spanner.Decoder
// interface.
if decodedVal, ok := ptr.(Decoder); ok {
x, err := getGenericValue(v)
x, err := getGenericValue(t, v)
if err != nil {
return err
}
Expand Down Expand Up @@ -1909,19 +1909,37 @@ func getListValue(v *proto3.Value) (*proto3.ListValue, error) {
}

// getGenericValue returns the interface{} value encoded in proto3.Value.
func getGenericValue(v *proto3.Value) (interface{}, error) {
func getGenericValue(t *sppb.Type, v *proto3.Value) (interface{}, error) {
switch x := v.GetKind().(type) {
case *proto3.Value_NumberValue:
return x.NumberValue, nil
case *proto3.Value_BoolValue:
return x.BoolValue, nil
case *proto3.Value_StringValue:
return x.StringValue, nil
case *proto3.Value_NullValue:
return getTypedNil(t)
default:
return 0, errSrcVal(v, "Number, Bool, String")
}
}

func getTypedNil(t *sppb.Type) (interface{}, error) {
switch t.Code {
case sppb.TypeCode_FLOAT64:
var f *float64
return f, nil
case sppb.TypeCode_BOOL:
var b *bool
return b, nil
default:
// The encoding for most types is string, except for the ones listed
// above.
var s *string
return s, nil
}
}

// errUnexpectedNumericStr returns error for decoder getting an unexpected
// string for representing special numeric values.
func errUnexpectedNumericStr(s string) error {
Expand Down
31 changes: 30 additions & 1 deletion spanner/value_test.go
Expand Up @@ -179,6 +179,21 @@ func (c *customStructToDate) DecodeSpanner(val interface{}) (err error) {
return nil
}

type customStructToNull struct {
val interface{}
}

func (c customStructToNull) EncodeSpanner() (interface{}, error) {
return c.val, nil
}

func (c *customStructToNull) DecodeSpanner(val interface{}) (err error) {
if reflect.ValueOf(val).IsNil() {
return nil
}
return fmt.Errorf("val mismatch: expected nil, got %v", val)
}

// Test encoding Values.
func TestEncodeValue(t *testing.T) {
type CustomString string
Expand Down Expand Up @@ -384,6 +399,14 @@ func TestEncodeValue(t *testing.T) {
{customStructToBytes{[]byte("A"), []byte("B")}, bytesProto([]byte("AB")), tBytes, "a struct to bytes"},
{customStructToTime{"A", "B"}, timeProto(tValue), tTime, "a struct to time"},
{customStructToDate{"A", "B"}, dateProto(dValue), tDate, "a struct to date"},
{customStructToNull{val: bNilPtr}, nullProto(), tBool, "a struct to null bool"},
{customStructToNull{val: []byte(nil)}, nullProto(), tBytes, "a struct to null bytes"},
{customStructToNull{val: sNilPtr}, nullProto(), tString, "a struct to null string"},
{customStructToNull{val: iNilPtr}, nullProto(), tInt, "a struct to null int"},
{customStructToNull{val: fNilPtr}, nullProto(), tFloat, "a struct to null float"},
{customStructToNull{val: numNilPtr}, nullProto(), tNumeric, "a struct to null numeric"},
{customStructToNull{val: dNilPtr}, nullProto(), tDate, "a struct to null date"},
{customStructToNull{val: tNilPtr}, nullProto(), tTime, "a struct to null timestamp"},
// CUSTOM NUMERIC / CUSTOM NUMERIC ARRAY
{CustomNumeric(*numValuePtr), numericProto(numValuePtr), tNumeric, "CustomNumeric"},
{CustomNullNumeric{*numValuePtr, true}, numericProto(numValuePtr), tNumeric, "CustomNullNumeric with value"},
Expand Down Expand Up @@ -1632,6 +1655,12 @@ func TestDecodeValue(t *testing.T) {
{desc: "decode BYTES to CustomStructToBytes", proto: bytesProto([]byte("AB")), protoType: bytesType(), want: customStructToBytes{[]byte("A"), []byte("B")}},
{desc: "decode TIMESTAMP to CustomStructToTime", proto: timeProto(t1), protoType: timeType(), want: customStructToTime{"A", "B"}},
{desc: "decode DATE to CustomStructToDate", proto: dateProto(d1), protoType: dateType(), want: customStructToDate{"A", "B"}},
{desc: "decode NULL bool to CustomStructToNull", proto: nullProto(), protoType: boolType(), want: customStructToNull{}},
{desc: "decode NULL float to CustomStructToNull", proto: nullProto(), protoType: floatType(), want: customStructToNull{}},
{desc: "decode NULL string to CustomStructToNull", proto: nullProto(), protoType: stringType(), want: customStructToNull{}},
{desc: "decode NULL array of bool to CustomStructToNull", proto: nullProto(), protoType: listType(boolType()), want: customStructToNull{}},
{desc: "decode NULL array of float to CustomStructToNull", proto: nullProto(), protoType: listType(floatType()), want: customStructToNull{}},
{desc: "decode NULL array of string to CustomStructToNull", proto: nullProto(), protoType: listType(stringType()), want: customStructToNull{}},
} {
gotp := reflect.New(reflect.TypeOf(test.want))
v := gotp.Interface()
Expand Down Expand Up @@ -1665,7 +1694,7 @@ func TestDecodeValue(t *testing.T) {
continue
}
got := reflect.Indirect(gotp).Interface()
if !testutil.Equal(got, test.want, cmp.AllowUnexported(CustomNumeric{}, CustomTime{}, CustomDate{}, Row{}, big.Rat{}, big.Int{})) {
if !testutil.Equal(got, test.want, cmp.AllowUnexported(CustomNumeric{}, CustomTime{}, CustomDate{}, Row{}, big.Rat{}, big.Int{}, customStructToNull{})) {
t.Errorf("%s: unexpected decoding result - got %v (%T), want %v (%T)", test.desc, got, got, test.want, test.want)
}
}
Expand Down

0 comments on commit 45ddaca

Please sign in to comment.