From 45ddaca606a372d9293bf2e2b3dc6d4398166c43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Thu, 5 Aug 2021 10:25:29 +0200 Subject: [PATCH] fix(spanner): allow decoding null values to spanner.Decoder (#4558) Allow NULL values from the database to be passed in to the DecodeSpanner method of a struct that implements spanner.Decoder. Fixes #4552 --- spanner/value.go | 22 ++++++++++++++++++++-- spanner/value_test.go | 31 ++++++++++++++++++++++++++++++- 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/spanner/value.go b/spanner/value.go index 7a2a4033192..4f0ca8cfc37 100644 --- a/spanner/value.go +++ b/spanner/value.go @@ -1330,7 +1330,7 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { // Check if the pointer is a custom type that implements spanner.Decoder // interface. if decodedVal, ok := ptr.(Decoder); ok { - x, err := getGenericValue(v) + x, err := getGenericValue(t, v) if err != nil { return err } @@ -1909,7 +1909,7 @@ func getListValue(v *proto3.Value) (*proto3.ListValue, error) { } // getGenericValue returns the interface{} value encoded in proto3.Value. -func getGenericValue(v *proto3.Value) (interface{}, error) { +func getGenericValue(t *sppb.Type, v *proto3.Value) (interface{}, error) { switch x := v.GetKind().(type) { case *proto3.Value_NumberValue: return x.NumberValue, nil @@ -1917,11 +1917,29 @@ func getGenericValue(v *proto3.Value) (interface{}, error) { return x.BoolValue, nil case *proto3.Value_StringValue: return x.StringValue, nil + case *proto3.Value_NullValue: + return getTypedNil(t) default: return 0, errSrcVal(v, "Number, Bool, String") } } +func getTypedNil(t *sppb.Type) (interface{}, error) { + switch t.Code { + case sppb.TypeCode_FLOAT64: + var f *float64 + return f, nil + case sppb.TypeCode_BOOL: + var b *bool + return b, nil + default: + // The encoding for most types is string, except for the ones listed + // above. + var s *string + return s, nil + } +} + // errUnexpectedNumericStr returns error for decoder getting an unexpected // string for representing special numeric values. func errUnexpectedNumericStr(s string) error { diff --git a/spanner/value_test.go b/spanner/value_test.go index 1d40fe43f89..5dccbe9d738 100644 --- a/spanner/value_test.go +++ b/spanner/value_test.go @@ -179,6 +179,21 @@ func (c *customStructToDate) DecodeSpanner(val interface{}) (err error) { return nil } +type customStructToNull struct { + val interface{} +} + +func (c customStructToNull) EncodeSpanner() (interface{}, error) { + return c.val, nil +} + +func (c *customStructToNull) DecodeSpanner(val interface{}) (err error) { + if reflect.ValueOf(val).IsNil() { + return nil + } + return fmt.Errorf("val mismatch: expected nil, got %v", val) +} + // Test encoding Values. func TestEncodeValue(t *testing.T) { type CustomString string @@ -384,6 +399,14 @@ func TestEncodeValue(t *testing.T) { {customStructToBytes{[]byte("A"), []byte("B")}, bytesProto([]byte("AB")), tBytes, "a struct to bytes"}, {customStructToTime{"A", "B"}, timeProto(tValue), tTime, "a struct to time"}, {customStructToDate{"A", "B"}, dateProto(dValue), tDate, "a struct to date"}, + {customStructToNull{val: bNilPtr}, nullProto(), tBool, "a struct to null bool"}, + {customStructToNull{val: []byte(nil)}, nullProto(), tBytes, "a struct to null bytes"}, + {customStructToNull{val: sNilPtr}, nullProto(), tString, "a struct to null string"}, + {customStructToNull{val: iNilPtr}, nullProto(), tInt, "a struct to null int"}, + {customStructToNull{val: fNilPtr}, nullProto(), tFloat, "a struct to null float"}, + {customStructToNull{val: numNilPtr}, nullProto(), tNumeric, "a struct to null numeric"}, + {customStructToNull{val: dNilPtr}, nullProto(), tDate, "a struct to null date"}, + {customStructToNull{val: tNilPtr}, nullProto(), tTime, "a struct to null timestamp"}, // CUSTOM NUMERIC / CUSTOM NUMERIC ARRAY {CustomNumeric(*numValuePtr), numericProto(numValuePtr), tNumeric, "CustomNumeric"}, {CustomNullNumeric{*numValuePtr, true}, numericProto(numValuePtr), tNumeric, "CustomNullNumeric with value"}, @@ -1632,6 +1655,12 @@ func TestDecodeValue(t *testing.T) { {desc: "decode BYTES to CustomStructToBytes", proto: bytesProto([]byte("AB")), protoType: bytesType(), want: customStructToBytes{[]byte("A"), []byte("B")}}, {desc: "decode TIMESTAMP to CustomStructToTime", proto: timeProto(t1), protoType: timeType(), want: customStructToTime{"A", "B"}}, {desc: "decode DATE to CustomStructToDate", proto: dateProto(d1), protoType: dateType(), want: customStructToDate{"A", "B"}}, + {desc: "decode NULL bool to CustomStructToNull", proto: nullProto(), protoType: boolType(), want: customStructToNull{}}, + {desc: "decode NULL float to CustomStructToNull", proto: nullProto(), protoType: floatType(), want: customStructToNull{}}, + {desc: "decode NULL string to CustomStructToNull", proto: nullProto(), protoType: stringType(), want: customStructToNull{}}, + {desc: "decode NULL array of bool to CustomStructToNull", proto: nullProto(), protoType: listType(boolType()), want: customStructToNull{}}, + {desc: "decode NULL array of float to CustomStructToNull", proto: nullProto(), protoType: listType(floatType()), want: customStructToNull{}}, + {desc: "decode NULL array of string to CustomStructToNull", proto: nullProto(), protoType: listType(stringType()), want: customStructToNull{}}, } { gotp := reflect.New(reflect.TypeOf(test.want)) v := gotp.Interface() @@ -1665,7 +1694,7 @@ func TestDecodeValue(t *testing.T) { continue } got := reflect.Indirect(gotp).Interface() - if !testutil.Equal(got, test.want, cmp.AllowUnexported(CustomNumeric{}, CustomTime{}, CustomDate{}, Row{}, big.Rat{}, big.Int{})) { + if !testutil.Equal(got, test.want, cmp.AllowUnexported(CustomNumeric{}, CustomTime{}, CustomDate{}, Row{}, big.Rat{}, big.Int{}, customStructToNull{})) { t.Errorf("%s: unexpected decoding result - got %v (%T), want %v (%T)", test.desc, got, got, test.want, test.want) } }