From ade8ab111315d84fa140ddde020387a78668dfa4 Mon Sep 17 00:00:00 2001 From: Hengfeng Li Date: Tue, 24 Aug 2021 14:46:41 +1000 Subject: [PATCH] feat(spanner): add support for JSON data type (#4104) * feat(spanner): add json support * Add NullJSON. * Add array and custom types support. * Add tests for json encoding and decoding for NullJSON. * Add integration test. * Only use NullJSON for encoding. * Update integration test. * Only decode Cloud Spanner JSON to NullJSON type. * Support decoding ARRAY to NullJSON. --- spanner/integration_test.go | 49 +++++++-- spanner/protoutils.go | 4 + spanner/value.go | 211 ++++++++++++++++++++++++++++++++++++ spanner/value_test.go | 90 +++++++++++++++ 4 files changed, 347 insertions(+), 7 deletions(-) diff --git a/spanner/integration_test.go b/spanner/integration_test.go index 00351f244b7..801bb86926d 100644 --- a/spanner/integration_test.go +++ b/spanner/integration_test.go @@ -18,6 +18,7 @@ package spanner import ( "context" + "encoding/json" "errors" "flag" "fmt" @@ -113,6 +114,8 @@ var ( DateArray ARRAY, Timestamp TIMESTAMP, TimestampArray ARRAY, + Numeric NUMERIC, + NumericArray ARRAY ) PRIMARY KEY (RowID)`, } @@ -169,6 +172,8 @@ var ( DateArray ARRAY, Timestamp TIMESTAMP, TimestampArray ARRAY, + Numeric NUMERIC, + NumericArray ARRAY ) PRIMARY KEY (RowID)`, } @@ -1561,21 +1566,22 @@ func TestIntegration_BasicTypes(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() stmts := singerDBStatements - stmts = []string{ - `CREATE TABLE Singers ( + if !isEmulatorEnvSet() { + stmts = []string{ + `CREATE TABLE Singers ( SingerId INT64 NOT NULL, FirstName STRING(1024), LastName STRING(1024), SingerInfo BYTES(MAX) ) PRIMARY KEY (SingerId)`, - `CREATE INDEX SingerByName ON Singers(FirstName, LastName)`, - `CREATE TABLE Accounts ( + `CREATE INDEX SingerByName ON Singers(FirstName, LastName)`, + `CREATE TABLE Accounts ( AccountId INT64 NOT NULL, Nickname STRING(100), Balance INT64 NOT NULL, ) PRIMARY KEY (AccountId)`, - `CREATE INDEX AccountByNickname ON Accounts(Nickname) STORING (Balance)`, - `CREATE TABLE Types ( + `CREATE INDEX AccountByNickname ON Accounts(Nickname) STORING (Balance)`, + `CREATE TABLE Types ( RowID INT64 NOT NULL, String STRING(MAX), StringArray ARRAY, @@ -1592,8 +1598,11 @@ func TestIntegration_BasicTypes(t *testing.T) { Timestamp TIMESTAMP, TimestampArray ARRAY, Numeric NUMERIC, - NumericArray ARRAY + NumericArray ARRAY, + JSON JSON, + JSONArray ARRAY ) PRIMARY KEY (RowID)`, + } } client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, stmts) defer cleanup() @@ -1613,6 +1622,16 @@ func TestIntegration_BasicTypes(t *testing.T) { n1 := *n1p n2 := *n2p + type Message struct { + Name string + Body string + Time int64 + } + msg := Message{"Alice", "Hello", 1294706395881547000} + jsonStr := `{"Name":"Alice","Body":"Hello","Time":1294706395881547000}` + var unmarshalledJSONstruct interface{} + json.Unmarshal([]byte(jsonStr), &unmarshalledJSONstruct) + tests := []struct { col string val interface{} @@ -1771,6 +1790,22 @@ func TestIntegration_BasicTypes(t *testing.T) { } } + if !isEmulatorEnvSet() { + tests = append(tests, []struct { + col string + val interface{} + want interface{} + }{ + {col: "JSON", val: NullJSON{msg, true}, want: msg}, + {col: "JSON", val: NullJSON{msg, true}, want: NullJSON{unmarshalledJSONstruct, true}}, + {col: "JSON", val: NullJSON{msg, false}}, + {col: "JSON", val: nil, want: NullJSON{}}, + {col: "JSONArray", val: []NullJSON(nil)}, + {col: "JSONArray", val: []NullJSON{}}, + {col: "JSONArray", val: []NullJSON{{msg, true}, {msg, true}, {}}}, + }...) + } + // Verify that we can insert the rows using mutations. var muts []*Mutation for i, test := range tests { diff --git a/spanner/protoutils.go b/spanner/protoutils.go index 6465b2c3342..66eea4b5c29 100644 --- a/spanner/protoutils.go +++ b/spanner/protoutils.go @@ -73,6 +73,10 @@ func numericType() *sppb.Type { return &sppb.Type{Code: sppb.TypeCode_NUMERIC} } +func jsonType() *sppb.Type { + return &sppb.Type{Code: sppb.TypeCode_JSON} +} + func bytesProto(b []byte) *proto3.Value { return &proto3.Value{Kind: &proto3.Value_StringValue{StringValue: base64.StdEncoding.EncodeToString(b)}} } diff --git a/spanner/value.go b/spanner/value.go index 02d67ccaa19..2a8225db37f 100644 --- a/spanner/value.go +++ b/spanner/value.go @@ -19,6 +19,7 @@ package spanner import ( "bytes" "encoding/base64" + "encoding/json" "fmt" "math" "math/big" @@ -500,6 +501,59 @@ func (n *NullNumeric) UnmarshalJSON(payload []byte) error { return nil } +// NullJSON represents a Cloud Spanner JSON that may be NULL. +// +// This type must always be used when encoding values to a JSON column in Cloud +// Spanner. +type NullJSON struct { + Value interface{} // Val contains the value when it is non-NULL, and nil when NULL. + Valid bool // Valid is true if Json is not NULL. +} + +// IsNull implements NullableValue.IsNull for NullJSON. +func (n NullJSON) IsNull() bool { + return !n.Valid +} + +// String implements Stringer.String for NullJSON. +func (n NullJSON) String() string { + if !n.Valid { + return nullString + } + b, err := json.Marshal(n.Value) + if err != nil { + return fmt.Sprintf("error: %v", err) + } + return fmt.Sprintf("%v", string(b)) +} + +// MarshalJSON implements json.Marshaler.MarshalJSON for NullJSON. +func (n NullJSON) MarshalJSON() ([]byte, error) { + if n.Valid { + return json.Marshal(n.Value) + } + return jsonNullBytes, nil +} + +// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullJSON. +func (n *NullJSON) UnmarshalJSON(payload []byte) error { + if payload == nil { + return fmt.Errorf("payload should not be nil") + } + if bytes.Equal(payload, jsonNullBytes) { + n.Valid = false + return nil + } + var v interface{} + err := json.Unmarshal(payload, &v) + if err != nil { + return fmt.Errorf("payload cannot be converted to a struct: got %v, err: %s", string(payload), err) + } + n.Value = v + n.Valid = true + return nil +} + // NullRow represents a Cloud Spanner STRUCT that may be NULL. // See also the document for Row. // Note that NullRow is not a valid Cloud Spanner column Type. @@ -1076,6 +1130,59 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}) error { return errUnexpectedNumericStr(x) } *p = *y + case *NullJSON: + if p == nil { + return errNilDst(p) + } + if code == sppb.TypeCode_ARRAY { + if acode != sppb.TypeCode_JSON { + return errTypeMismatch(code, acode, ptr) + } + x, err := getListValue(v) + if err != nil { + return err + } + y, err := decodeNullJSONArrayToNullJSON(x) + if err != nil { + return err + } + *p = *y + } else { + if code != sppb.TypeCode_JSON { + return errTypeMismatch(code, acode, ptr) + } + if isNull { + *p = NullJSON{} + break + } + x := v.GetStringValue() + var y interface{} + err := json.Unmarshal([]byte(x), &y) + if err != nil { + return err + } + *p = NullJSON{y, true} + } + case *[]NullJSON: + if p == nil { + return errNilDst(p) + } + if acode != sppb.TypeCode_JSON { + return errTypeMismatch(code, acode, ptr) + } + if isNull { + *p = nil + break + } + x, err := getListValue(v) + if err != nil { + return err + } + y, err := decodeNullJSONArray(x) + if err != nil { + return err + } + *p = y case *NullNumeric: if p == nil { return errNilDst(p) @@ -1446,6 +1553,7 @@ const ( spannerTypeNullTime spannerTypeNullDate spannerTypeNullNumeric + spannerTypeNullJSON spannerTypeArrayOfNonNullString spannerTypeArrayOfByteArray spannerTypeArrayOfNonNullInt64 @@ -1459,6 +1567,7 @@ const ( spannerTypeArrayOfNullBool spannerTypeArrayOfNullFloat64 spannerTypeArrayOfNullNumeric + spannerTypeArrayOfNullJSON spannerTypeArrayOfNullTime spannerTypeArrayOfNullDate ) @@ -1491,6 +1600,7 @@ var typeOfNullFloat64 = reflect.TypeOf(NullFloat64{}) var typeOfNullTime = reflect.TypeOf(NullTime{}) var typeOfNullDate = reflect.TypeOf(NullDate{}) var typeOfNullNumeric = reflect.TypeOf(NullNumeric{}) +var typeOfNullJSON = reflect.TypeOf(NullJSON{}) // getDecodableSpannerType returns the corresponding decodableSpannerType of // the given pointer. @@ -1522,6 +1632,9 @@ func getDecodableSpannerType(ptr interface{}, isPtr bool) decodableSpannerType { if t.ConvertibleTo(typeOfNullNumeric) { return spannerTypeNullNumeric } + if t.ConvertibleTo(typeOfNullJSON) { + return spannerTypeNullJSON + } case reflect.Struct: t := val.Type() if t.ConvertibleTo(typeOfNonNullNumeric) { @@ -1554,6 +1667,9 @@ func getDecodableSpannerType(ptr interface{}, isPtr bool) decodableSpannerType { if t.ConvertibleTo(typeOfNullNumeric) { return spannerTypeNullNumeric } + if t.ConvertibleTo(typeOfNullJSON) { + return spannerTypeNullJSON + } case reflect.Slice: kind := val.Type().Elem().Kind() switch kind { @@ -1606,6 +1722,9 @@ func getDecodableSpannerType(ptr interface{}, isPtr bool) decodableSpannerType { if t.ConvertibleTo(typeOfNullNumeric) { return spannerTypeArrayOfNullNumeric } + if t.ConvertibleTo(typeOfNullJSON) { + return spannerTypeArrayOfNullJSON + } case reflect.Slice: // The only array-of-array type that is supported is [][]byte. kind := val.Type().Elem().Elem().Kind() @@ -1741,6 +1860,21 @@ func (dsc decodableSpannerType) decodeValueToCustomType(v *proto3.Value, t *sppb } else { result = &NullNumeric{*y, true} } + case spannerTypeNullJSON: + if code != sppb.TypeCode_JSON { + return errTypeMismatch(code, acode, ptr) + } + if isNull { + result = &NullJSON{} + break + } + x := v.GetStringValue() + var y interface{} + err := json.Unmarshal([]byte(x), &y) + if err != nil { + return err + } + result = &NullJSON{y, true} case spannerTypeNonNullTime, spannerTypeNullTime: var nt NullTime err := parseNullTime(v, &nt, code, isNull) @@ -1875,6 +2009,23 @@ func (dsc decodableSpannerType) decodeValueToCustomType(v *proto3.Value, t *sppb return err } result = y + case spannerTypeArrayOfNullJSON: + if acode != sppb.TypeCode_JSON { + return errTypeMismatch(code, acode, ptr) + } + if isNull { + ptr = nil + return nil + } + x, err := getListValue(v) + if err != nil { + return err + } + y, err := decodeGenericArray(reflect.TypeOf(ptr).Elem(), x, jsonType(), "JSON") + if err != nil { + return err + } + result = y case spannerTypeArrayOfNonNullTime, spannerTypeArrayOfNullTime: if acode != sppb.TypeCode_TIMESTAMP { return errTypeMismatch(code, acode, ptr) @@ -2238,6 +2389,42 @@ func decodeNullNumericArray(pb *proto3.ListValue) ([]NullNumeric, error) { return a, nil } +// decodeNullJSONArray decodes proto3.ListValue pb into a NullJSON slice. +func decodeNullJSONArray(pb *proto3.ListValue) ([]NullJSON, error) { + if pb == nil { + return nil, errNilListValue("JSON") + } + a := make([]NullJSON, len(pb.Values)) + for i, v := range pb.Values { + if err := decodeValue(v, jsonType(), &a[i]); err != nil { + return nil, errDecodeArrayElement(i, v, "JSON", err) + } + } + return a, nil +} + +// decodeNullJSONArray decodes proto3.ListValue pb into a NullJSON pointer. +func decodeNullJSONArrayToNullJSON(pb *proto3.ListValue) (*NullJSON, error) { + if pb == nil { + return nil, errNilListValue("JSON") + } + strs := []string{} + for _, v := range pb.Values { + if _, ok := v.Kind.(*proto3.Value_NullValue); ok { + strs = append(strs, "null") + } else { + strs = append(strs, v.GetStringValue()) + } + } + s := fmt.Sprintf("[%s]", strings.Join(strs, ",")) + var y interface{} + err := json.Unmarshal([]byte(s), &y) + if err != nil { + return nil, err + } + return &NullJSON{y, true}, nil +} + // decodeNumericPointerArray decodes proto3.ListValue pb into a *big.Rat slice. func decodeNumericPointerArray(pb *proto3.ListValue) ([]*big.Rat, error) { if pb == nil { @@ -2747,6 +2934,23 @@ func encodeValue(v interface{}) (*proto3.Value, *sppb.Type, error) { } } pt = listType(numericType()) + case NullJSON: + if v.Valid { + b, err := json.Marshal(v.Value) + if err != nil { + return nil, nil, err + } + pb.Kind = stringKind(string(b)) + } + return pb, jsonType(), nil + case []NullJSON: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + } + pt = listType(jsonType()) case *big.Rat: switch LossOfPrecisionHandling { case NumericError: @@ -2936,6 +3140,8 @@ func convertCustomTypeValue(sourceType decodableSpannerType, v interface{}) (int destination = reflect.Indirect(reflect.New(reflect.TypeOf(big.Rat{}))) case spannerTypeNullNumeric: destination = reflect.Indirect(reflect.New(reflect.TypeOf(NullNumeric{}))) + case spannerTypeNullJSON: + destination = reflect.Indirect(reflect.New(reflect.TypeOf(NullJSON{}))) case spannerTypeArrayOfNonNullString: if reflect.ValueOf(v).IsNil() { return []string(nil), nil @@ -3011,6 +3217,11 @@ func convertCustomTypeValue(sourceType decodableSpannerType, v interface{}) (int return []NullNumeric(nil), nil } destination = reflect.MakeSlice(reflect.TypeOf([]NullNumeric{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap()) + case spannerTypeArrayOfNullJSON: + if reflect.ValueOf(v).IsNil() { + return []NullJSON(nil), nil + } + destination = reflect.MakeSlice(reflect.TypeOf([]NullJSON{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap()) default: // This should not be possible. return nil, fmt.Errorf("unknown decodable type found: %v", sourceType) diff --git a/spanner/value_test.go b/spanner/value_test.go index a689b779e64..b6f93b259ea 100644 --- a/spanner/value_test.go +++ b/spanner/value_test.go @@ -212,6 +212,21 @@ func TestEncodeValue(t *testing.T) { type CustomNullTime NullTime type CustomNullDate NullDate type CustomNullNumeric NullNumeric + type CustomNullJSON NullJSON + + type Message struct { + Name string + Body string + Time int64 + } + msg := Message{"Alice", "Hello", 1294706395881547000} + jsonStr := `{"Name":"Alice","Body":"Hello","Time":1294706395881547000}` + emptyArrayJSONStr := `[]` + type PtrMessage struct { + Key *string + } + ptrMsg := PtrMessage{} + nullValueJSONStr := `{"Key":null}` sValue := "abc" var sNilPtr *string @@ -240,6 +255,7 @@ func TestEncodeValue(t *testing.T) { tTime = timeType() tDate = dateType() tNumeric = numericType() + tJSON = jsonType() ) for i, test := range []struct { in interface{} @@ -308,6 +324,13 @@ func TestEncodeValue(t *testing.T) { {[]NullNumeric{{*numValuePtr, true}, {*numValuePtr, false}}, listProto(numericProto(numValuePtr), nullProto()), listType(tNumeric), "[]NullNumeric"}, {[]*big.Rat{nil, numValuePtr}, listProto(nullProto(), numericProto(numValuePtr)), listType(tNumeric), "[]*big.Rat"}, {[]*big.Rat(nil), nullProto(), listType(tNumeric), "null []*big.Rat"}, + // JSON + {NullJSON{msg, true}, stringProto(jsonStr), tJSON, "NullJSON with value"}, + {NullJSON{msg, false}, nullProto(), tJSON, "NullJSON with null"}, + {[]NullJSON(nil), nullProto(), listType(tJSON), "null []NullJSON"}, + {[]NullJSON{{msg, true}, {msg, false}}, listProto(stringProto(jsonStr), nullProto()), listType(tJSON), "[]NullJSON"}, + {NullJSON{[]Message{}, true}, stringProto(emptyArrayJSONStr), tJSON, "a json string with empty array to NullJSON"}, + {NullJSON{ptrMsg, true}, stringProto(nullValueJSONStr), tJSON, "a json string with null value to NullJSON"}, // TIMESTAMP / TIMESTAMP ARRAY {t1, timeProto(t1), tTime, "time"}, {NullTime{t1, true}, timeProto(t1), tTime, "NullTime with value"}, @@ -419,6 +442,11 @@ func TestEncodeValue(t *testing.T) { {[]CustomNumeric{CustomNumeric(*numValuePtr), CustomNumeric(*num2ValuePtr)}, listProto(numericProto(numValuePtr), numericProto(num2ValuePtr)), listType(tNumeric), "[]CustomNumeric"}, {[]CustomNullNumeric(nil), nullProto(), listType(tNumeric), "null []CustomNullNumeric"}, {[]CustomNullNumeric{{*numValuePtr, true}, {*num2ValuePtr, false}}, listProto(numericProto(numValuePtr), nullProto()), listType(tNumeric), "[]CustomNullNumeric"}, + // CUSTOM JSON + {CustomNullJSON{msg, true}, stringProto(jsonStr), tJSON, "CustomNullJSON with value"}, + {CustomNullJSON{msg, false}, nullProto(), tJSON, "CustomNullJSON with null"}, + {[]CustomNullJSON(nil), nullProto(), listType(tJSON), "null []CustomNullJSON"}, + {[]CustomNullJSON{{msg, true}, {msg, false}}, listProto(stringProto(jsonStr), nullProto()), listType(tJSON), "[]CustomNullJSON"}, } { got, gotType, err := encodeValue(test.in) if err != nil { @@ -1301,6 +1329,21 @@ func TestDecodeValue(t *testing.T) { type CustomNullTime NullTime type CustomNullDate NullDate type CustomNullNumeric NullNumeric + type CustomNullJSON NullJSON + + jsonStr := `{"Name":"Alice","Body":"Hello","Time":1294706395881547000}` + var unmarshalledJSONStruct interface{} + json.Unmarshal([]byte(jsonStr), &unmarshalledJSONStruct) + invalidJSONStr := `{wrong_json_string}` + emptyArrayJSONStr := `[]` + var unmarshalledEmptyJSONArray interface{} + json.Unmarshal([]byte(emptyArrayJSONStr), &unmarshalledEmptyJSONArray) + nullValueJSONStr := `{"Key":null}` + var unmarshalledStructWithNull interface{} + json.Unmarshal([]byte(nullValueJSONStr), &unmarshalledStructWithNull) + arrayJSONStr := `[{"Name":"Alice","Body":"Hello","Time":1294706395881547000},null,true]` + var unmarshalledJSONArray interface{} + json.Unmarshal([]byte(arrayJSONStr), &unmarshalledJSONArray) // Pointer values. sValue := "abc" @@ -1419,6 +1462,16 @@ func TestDecodeValue(t *testing.T) { // NUMERIC ARRAY with []*big.Rat {desc: "decode ARRAY to []*big.Rat", proto: listProto(numericProto(numValuePtr), nullProto(), numericProto(num2ValuePtr)), protoType: listType(numericType()), want: []*big.Rat{numValuePtr, nil, num2ValuePtr}}, {desc: "decode NULL to []*big.Rat", proto: nullProto(), protoType: listType(numericType()), want: []*big.Rat(nil)}, + // JSON + {desc: "decode json to NullJSON", proto: stringProto(jsonStr), protoType: jsonType(), want: NullJSON{unmarshalledJSONStruct, true}}, + {desc: "decode NULL to NullJSON", proto: nullProto(), protoType: jsonType(), want: NullJSON{}}, + {desc: "decode an invalid json string", proto: stringProto(invalidJSONStr), protoType: jsonType(), want: NullJSON{}, wantErr: true}, + {desc: "decode a json string with empty array to a NullJSON", proto: stringProto(emptyArrayJSONStr), protoType: jsonType(), want: NullJSON{unmarshalledEmptyJSONArray, true}}, + {desc: "decode a json string with null to a NullJSON", proto: stringProto(nullValueJSONStr), protoType: jsonType(), want: NullJSON{unmarshalledStructWithNull, true}}, + // JSON ARRAY with []NullJSON + {desc: "decode ARRAY to []NullJSON", proto: listProto(stringProto(jsonStr), stringProto(jsonStr), nullProto()), protoType: listType(jsonType()), want: []NullJSON{{unmarshalledJSONStruct, true}, {unmarshalledJSONStruct, true}, {}}}, + {desc: "decode ARRAY to NullJSON", proto: listProto(stringProto(jsonStr), nullProto(), stringProto("true")), protoType: listType(jsonType()), want: NullJSON{unmarshalledJSONArray, true}}, + {desc: "decode NULL to []NullJSON", proto: nullProto(), protoType: listType(jsonType()), want: []NullJSON(nil)}, // TIMESTAMP {desc: "decode TIMESTAMP to time.Time", proto: timeProto(t1), protoType: timeType(), want: t1}, {desc: "decode TIMESTAMP to NullTime", proto: timeProto(t1), protoType: timeType(), want: NullTime{t1, true}}, @@ -1629,6 +1682,7 @@ func TestDecodeValue(t *testing.T) { {desc: "decode BOOL to CustomNullBool", proto: boolProto(true), protoType: boolType(), want: CustomNullBool{true, true}}, {desc: "decode FLOAT64 to CustomNullFloat64", proto: floatProto(6.626), protoType: floatType(), want: CustomNullFloat64{6.626, true}}, {desc: "decode NUMERIC to CustomNullNumeric", proto: numericProto(numValuePtr), protoType: numericType(), want: CustomNullNumeric{*numValuePtr, true}}, + {desc: "decode JSON to CustomNullJSON", proto: stringProto(jsonStr), protoType: jsonType(), want: CustomNullJSON{unmarshalledJSONStruct, true}}, {desc: "decode TIMESTAMP to CustomNullTime", proto: timeProto(t1), protoType: timeType(), want: CustomNullTime{t1, true}}, {desc: "decode DATE to CustomNullDate", proto: dateProto(d1), protoType: dateType(), want: CustomNullDate{d1, true}}, @@ -1637,6 +1691,7 @@ func TestDecodeValue(t *testing.T) { {desc: "decode NULL to CustomNullBool", proto: nullProto(), protoType: boolType(), want: CustomNullBool{}}, {desc: "decode NULL to CustomNullFloat64", proto: nullProto(), protoType: floatType(), want: CustomNullFloat64{}}, {desc: "decode NULL to CustomNullNumeric", proto: nullProto(), protoType: numericType(), want: CustomNullNumeric{}}, + {desc: "decode NULL to CustomNullJSON", proto: nullProto(), protoType: jsonType(), want: CustomNullJSON{}}, {desc: "decode NULL to CustomNullTime", proto: nullProto(), protoType: timeType(), want: CustomNullTime{}}, {desc: "decode NULL to CustomNullDate", proto: nullProto(), protoType: dateType(), want: CustomNullDate{}}, @@ -1673,6 +1728,9 @@ func TestDecodeValue(t *testing.T) { {desc: "decode ARRAY to []CustomNumeric", proto: listProto(numericProto(numValuePtr), numericProto(num2ValuePtr)), protoType: listType(numericType()), want: []CustomNumeric{CustomNumeric(*numValuePtr), CustomNumeric(*num2ValuePtr)}}, {desc: "decode NULL to []CustomNullNumeric", proto: nullProto(), protoType: listType(numericType()), want: []CustomNullNumeric(nil)}, {desc: "decode ARRAY to []CustomNullNumeric", proto: listProto(numericProto(numValuePtr), nullProto(), numericProto(num2ValuePtr)), protoType: listType(numericType()), want: []CustomNullNumeric{{*numValuePtr, true}, {}, {*num2ValuePtr, true}}}, + // JSON ARRAY + {desc: "decode NULL to []CustomNullJSON", proto: nullProto(), protoType: listType(jsonType()), want: []CustomNullJSON(nil)}, + {desc: "decode ARRAY to []CustomNullJSON", proto: listProto(stringProto(jsonStr), stringProto(jsonStr), nullProto()), protoType: listType(jsonType()), want: []CustomNullJSON{{unmarshalledJSONStruct, true}, {unmarshalledJSONStruct, true}, {}}}, // TIME ARRAY {desc: "decode NULL to []CustomTime", proto: nullProto(), protoType: listType(timeType()), want: []CustomTime(nil)}, {desc: "decode ARRAY with NULL values to []CustomTime", proto: listProto(timeProto(t1), nullProto(), timeProto(t2)), protoType: listType(timeType()), want: []CustomTime{}, wantErr: true}, @@ -2429,6 +2487,14 @@ func TestBindParamsDynamic(t *testing.T) { // Test converting nullable types to json strings. func TestJSONMarshal_NullTypes(t *testing.T) { + type Message struct { + Name string + Body string + Time int64 + } + msg := Message{"Alice", "Hello", 1294706395881547000} + jsonStr := `{"Name":"Alice","Body":"Hello","Time":1294706395881547000}` + type testcase struct { input interface{} expect string @@ -2501,6 +2567,15 @@ func TestJSONMarshal_NullTypes(t *testing.T) { {input: NullNumeric{}, expect: "null"}, }, }, + { + "NullJSON", + []testcase{ + {input: NullJSON{msg, true}, expect: jsonStr}, + {input: &NullJSON{msg, true}, expect: jsonStr}, + {input: &NullJSON{msg, false}, expect: "null"}, + {input: NullJSON{}, expect: "null"}, + }, + }, } { t.Run(test.name, func(t *testing.T) { for _, tc := range test.cases { @@ -2516,6 +2591,8 @@ func TestJSONMarshal_NullTypes(t *testing.T) { // Test converting json strings to nullable types. func TestJSONUnmarshal_NullTypes(t *testing.T) { + jsonStr := `{"Body":"Hello","Name":"Alice","Time":1294706395881547000}` + type testcase struct { input []byte got interface{} @@ -2599,6 +2676,16 @@ func TestJSONUnmarshal_NullTypes(t *testing.T) { {input: []byte(`"1234.123456789`), got: NullNumeric{}, isNull: true, expect: nullString, expectError: true}, }, }, + { + "NullJSON", + []testcase{ + {input: []byte(jsonStr), got: NullJSON{}, isNull: false, expect: jsonStr, expectError: false}, + {input: []byte("null"), got: NullJSON{}, isNull: true, expect: nullString, expectError: false}, + {input: nil, got: NullJSON{}, isNull: true, expect: nullString, expectError: true}, + {input: []byte(""), got: NullJSON{}, isNull: true, expect: nullString, expectError: true}, + {input: []byte(`{invalid_json_string}`), got: NullJSON{}, isNull: true, expect: nullString, expectError: true}, + }, + }, } { t.Run(test.name, func(t *testing.T) { for _, tc := range test.cases { @@ -2624,6 +2711,9 @@ func TestJSONUnmarshal_NullTypes(t *testing.T) { case NullNumeric: err := json.Unmarshal(tc.input, &v) expectUnmarshalNullableTypes(t, err, v, tc.isNull, tc.expect, tc.expectError) + case NullJSON: + err := json.Unmarshal(tc.input, &v) + expectUnmarshalNullableTypes(t, err, v, tc.isNull, tc.expect, tc.expectError) default: t.Fatalf("Unknown type: %T", v) }