diff --git a/spanner/client_test.go b/spanner/client_test.go index 86c5a0d8317..3e79710f1b3 100644 --- a/spanner/client_test.go +++ b/spanner/client_test.go @@ -2198,6 +2198,7 @@ func TestClient_DecodeCustomFieldType(t *testing.T) { defer iter.Stop() var results []typesTable + var lenientResults []typesTable for { row, err := iter.Next() if err == iterator.Done { @@ -2212,9 +2213,15 @@ func TestClient_DecodeCustomFieldType(t *testing.T) { t.Fatalf("failed to convert a row to a struct: %v", err) } results = append(results, d) + + var d2 typesTable + if err := row.ToStructLenient(&d2); err != nil { + t.Fatalf("failed to convert a row to a struct: %v", err) + } + lenientResults = append(lenientResults, d2) } - if len(results) > 1 { + if len(results) > 1 || len(lenientResults) > 1 { t.Fatalf("mismatch length of array: got %v, want 1", results) } @@ -2228,7 +2235,11 @@ func TestClient_DecodeCustomFieldType(t *testing.T) { } got := results[0] if !testEqual(got, want) { - t.Fatalf("mismatch result: got %v, want %v", got, want) + t.Fatalf("mismatch result from ToStruct: got %v, want %v", got, want) + } + got = lenientResults[0] + if !testEqual(got, want) { + t.Fatalf("mismatch result from ToStructLenient: got %v, want %v", got, want) } } diff --git a/spanner/examples_test.go b/spanner/examples_test.go index a1fedbbcc20..e2e6ee23cfb 100644 --- a/spanner/examples_test.go +++ b/spanner/examples_test.go @@ -404,6 +404,30 @@ func ExampleRow_ToStruct() { fmt.Println(acct) } +func ExampleRow_ToStructLenient() { + ctx := context.Background() + client, err := spanner.NewClient(ctx, myDB) + if err != nil { + // TODO: Handle error. + } + row, err := client.Single().ReadRow(ctx, "Accounts", spanner.Key{"alice"}, []string{"accountID", "name", "balance"}) + if err != nil { + // TODO: Handle error. + } + + type Account struct { + Name string + Balance int64 + NickName string + } + + var acct Account + if err := row.ToStructLenient(&acct); err != nil { + // TODO: Handle error. + } + fmt.Println(acct) +} + func ExampleReadOnlyTransaction_Read() { ctx := context.Background() client, err := spanner.NewClient(ctx, myDB) diff --git a/spanner/row.go b/spanner/row.go index c5c37430f2e..1aefb4a50ff 100644 --- a/spanner/row.go +++ b/spanner/row.go @@ -288,6 +288,10 @@ func errToStructArgType(p interface{}) error { // 2. Otherwise, if the name of a field matches the name of a column (ignoring case), // decode the column into the field. // +// 3. The number of columns in the row must match the number of exported fields in the struct. +// There must be exactly one match for each column in the row. The method will return an error +// if a column in the row cannot be assigned to a field in the struct. +// // The fields of the destination struct can be of any type that is acceptable // to spanner.Row.Column. // @@ -311,5 +315,48 @@ func (r *Row) ToStruct(p interface{}) error { &sppb.StructType{Fields: r.fields}, &proto3.ListValue{Values: r.vals}, p, + false, + ) +} + +// ToStructLenient fetches the columns in a row into the fields of a struct. +// The rules for mapping a row's columns into a struct's exported fields +// are: +// +// 1. If a field has a `spanner: "column_name"` tag, then decode column +// 'column_name' into the field. A special case is the `spanner: "-"` +// tag, which instructs ToStruct to ignore the field during decoding. +// +// 2. Otherwise, if the name of a field matches the name of a column (ignoring case), +// decode the column into the field. +// +// 3. The number of columns in the row and exported fields in the struct do not need to match. +// Any field in the struct that cannot not be assigned a value from the row is assigned its default value. +// Any column in the row that does not have a corresponding field in the struct is ignored. +// +// The fields of the destination struct can be of any type that is acceptable +// to spanner.Row.Column. +// +// Slice and pointer fields will be set to nil if the source column is NULL, and a +// non-nil value if the column is not NULL. To decode NULL values of other types, use +// one of the spanner.NullXXX types as the type of the destination field. +// +// If ToStructLenient returns an error, the contents of p are undefined. Some fields may +// have been successfully populated, while others were not; you should not use any of +// the fields. +func (r *Row) ToStructLenient(p interface{}) error { + // Check if p is a pointer to a struct + if t := reflect.TypeOf(p); t == nil || t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct { + return errToStructArgType(p) + } + if len(r.vals) != len(r.fields) { + return errFieldsMismatchVals(r) + } + // Call decodeStruct directly to decode the row as a typed proto.ListValue. + return decodeStruct( + &sppb.StructType{Fields: r.fields}, + &proto3.ListValue{Values: r.vals}, + p, + true, ) } diff --git a/spanner/row_test.go b/spanner/row_test.go index 0f7ca970c52..89a6e91bb62 100644 --- a/spanner/row_test.go +++ b/spanner/row_test.go @@ -339,22 +339,28 @@ func TestNilDst(t *testing.T) { })(nil)), }, } { - if gotErr := test.r.Column(0, test.dst); !testEqual(gotErr, test.wantErr) { - t.Errorf("%v: test.r.Column() returns error %v, want %v", i, gotErr, test.wantErr) - } - if gotErr := test.r.ColumnByName("Col0", test.dst); !testEqual(gotErr, test.wantErr) { - t.Errorf("%v: test.r.ColumnByName() returns error %v, want %v", i, gotErr, test.wantErr) - } - // Row.Columns(T) should return nil on T == nil, otherwise, it should return test.wantErr. - wantColumnsErr := test.wantErr - if test.dst == nil { - wantColumnsErr = nil - } - if gotErr := test.r.Columns(test.dst); !testEqual(gotErr, wantColumnsErr) { - t.Errorf("%v: test.r.Columns() returns error %v, want %v", i, gotErr, wantColumnsErr) - } - if gotErr := test.r.ToStruct(test.structDst); !testEqual(gotErr, test.wantToStructErr) { - t.Errorf("%v: test.r.ToStruct() returns error %v, want %v", i, gotErr, test.wantToStructErr) + for j, toStuct := range []func(ptr interface{}) error{test.r.ToStruct, test.r.ToStructLenient} { + if gotErr := test.r.Column(0, test.dst); !testEqual(gotErr, test.wantErr) { + t.Errorf("%v: test.r.Column() returns error %v, want %v", i, gotErr, test.wantErr) + } + if gotErr := test.r.ColumnByName("Col0", test.dst); !testEqual(gotErr, test.wantErr) { + t.Errorf("%v: test.r.ColumnByName() returns error %v, want %v", i, gotErr, test.wantErr) + } + // Row.Columns(T) should return nil on T == nil, otherwise, it should return test.wantErr. + wantColumnsErr := test.wantErr + if test.dst == nil { + wantColumnsErr = nil + } + if gotErr := test.r.Columns(test.dst); !testEqual(gotErr, wantColumnsErr) { + t.Errorf("%v: test.r.Columns() returns error %v, want %v", i, gotErr, wantColumnsErr) + } + if gotErr := toStuct(test.structDst); !testEqual(gotErr, test.wantToStructErr) { + if j == 0 { + t.Errorf("%v: test.r.ToStruct() returns error %v, want %v", i, gotErr, test.wantToStructErr) + } else { + t.Errorf("%v: test.r.ToStructLenient() returns error %v, want %v", i, gotErr, test.wantToStructErr) + } + } } } } @@ -526,6 +532,46 @@ func TestInvalidColumnRequest(t *testing.T) { {Name: "", Type: stringType()}, }}, 0), }, + { + "Call ToStructLenient on row with duplicated column names", + func() error { + s := &struct { + Val string + }{} + r := &Row{ + []*sppb.StructType_Field{ + {Name: "Val", Type: stringType()}, + {Name: "Val", Type: stringType()}, + }, + []*proto3.Value{stringProto("value1"), stringProto("value2")}, + } + return r.ToStructLenient(s) + }, + errDupSpannerField("Val", &sppb.StructType{ + Fields: []*sppb.StructType_Field{ + {Name: "Val", Type: stringType()}, + {Name: "Val", Type: stringType()}, + }, + }), + }, + { + "Call ToStructLenient on a row with unnamed field", + func() error { + s := &struct { + Val string + }{} + r := &Row{ + []*sppb.StructType_Field{ + {Name: "", Type: stringType()}, + }, + []*proto3.Value{stringProto("value1")}, + } + return r.ToStructLenient(s) + }, + errUnnamedField(&sppb.StructType{Fields: []*sppb.StructType_Field{ + {Name: "", Type: stringType()}, + }}, 0), + }, } { if gotErr := test.f(); !testEqual(gotErr, test.wantErr) { t.Errorf("%v: test.f() returns error %v, want %v", test.desc, gotErr, test.wantErr) @@ -536,22 +582,37 @@ func TestInvalidColumnRequest(t *testing.T) { // Test decoding the row with row.ToStruct into an invalid destination. func TestToStructInvalidDst(t *testing.T) { for _, test := range []struct { - desc string - dst interface{} - wantErr error + desc string + dst interface{} + wantErr error + toStruct func(ptr interface{}) error }{ { - "Decode row as STRUCT into int32", + "row.ToStruct(): Decode row as STRUCT into int32", + proto.Int(1), + errToStructArgType(proto.Int(1)), + row.ToStruct, + }, + { + "ToStructLenient(): Decode row as STRUCT into int32", proto.Int(1), errToStructArgType(proto.Int(1)), + row.ToStructLenient, }, { - "Decode row as STRUCT to nil Go struct", + "row.ToStruct(): Decode row as STRUCT to nil Go struct", (*struct{})(nil), errNilDst((*struct{})(nil)), + row.ToStruct, }, { - "Decode row as STRUCT to Go struct with duplicated fields for the PK column", + "row.ToStructLenient(): Decode row as STRUCT to nil Go struct", + (*struct{})(nil), + errNilDst((*struct{})(nil)), + row.ToStructLenient, + }, + { + "row.ToStruct(): Decode row as STRUCT to Go struct with duplicated fields for the PK column", &struct { PK1 string `spanner:"STRING"` PK2 string `spanner:"STRING"` @@ -560,26 +621,58 @@ func TestToStructInvalidDst(t *testing.T) { PK1 string `spanner:"STRING"` PK2 string `spanner:"STRING"` }{}, "STRING"), + row.ToStruct, }, { - "Decode row as STRUCT to Go struct with no field for the PK column", + "row.ToStructLenient(): Decode row as STRUCT to Go struct with duplicated fields for the PK column", + &struct { + PK1 string `spanner:"STRING"` + PK2 string `spanner:"STRING"` + }{}, + errDupGoField(&struct { + PK1 string `spanner:"STRING"` + PK2 string `spanner:"STRING"` + }{}, "STRING"), + row.ToStructLenient, + }, + { + "row.ToStruct(): Decode row as STRUCT to Go struct with no field for the PK column", &struct { PK1 string `spanner:"_STRING"` }{}, errNoOrDupGoField(&struct { PK1 string `spanner:"_STRING"` }{}, "STRING"), + row.ToStruct, + }, + { + "row.ToStructLenient(): Decode row as STRUCT to Go struct with no field for the PK column", + &struct { + PK1 string `spanner:"_STRING"` + }{}, + nil, + row.ToStructLenient, + }, + { + "row.ToStruct(): Decode row as STRUCT to Go struct with wrong type for the PK column", + &struct { + PK1 int64 `spanner:"STRING"` + }{}, + errDecodeStructField(&sppb.StructType{Fields: row.fields}, "STRING", + errTypeMismatch(sppb.TypeCode_STRING, sppb.TypeCode_TYPE_CODE_UNSPECIFIED, proto.Int64(0))), + row.ToStruct, }, { - "Decode row as STRUCT to Go struct with wrong type for the PK column", + "row.ToStructLenient(): Decode row as STRUCT to Go struct with wrong type for the PK column", &struct { PK1 int64 `spanner:"STRING"` }{}, errDecodeStructField(&sppb.StructType{Fields: row.fields}, "STRING", errTypeMismatch(sppb.TypeCode_STRING, sppb.TypeCode_TYPE_CODE_UNSPECIFIED, proto.Int64(0))), + row.ToStructLenient, }, } { - if gotErr := row.ToStruct(test.dst); !testEqual(gotErr, test.wantErr) { + if gotErr := test.toStruct(test.dst); !testEqual(gotErr, test.wantErr) { t.Errorf("%v: decoding:\ngot %v\nwant %v", test.desc, gotErr, test.wantErr) } } @@ -1408,115 +1501,122 @@ func TestBrokenRow(t *testing.T) { // Test Row.ToStruct(). func TestToStruct(t *testing.T) { - s := []struct { - // STRING / STRING ARRAY - PrimaryKey string `spanner:"STRING"` - NullString NullString `spanner:"NULL_STRING"` - StringArray []NullString `spanner:"STRING_ARRAY"` - NullStringArray []NullString `spanner:"NULL_STRING_ARRAY"` - // BYTES / BYTES ARRAY - Bytes []byte `spanner:"BYTES"` - NullBytes []byte `spanner:"NULL_BYTES"` - BytesArray [][]byte `spanner:"BYTES_ARRAY"` - NullBytesArray [][]byte `spanner:"NULL_BYTES_ARRAY"` - // INT64 / INT64 ARRAY - Int64 int64 `spanner:"INT64"` - NullInt64 NullInt64 `spanner:"NULL_INT64"` - Int64Array []NullInt64 `spanner:"INT64_ARRAY"` - NullInt64Array []NullInt64 `spanner:"NULL_INT64_ARRAY"` - // BOOL / BOOL ARRAY - Bool bool `spanner:"BOOL"` - NullBool NullBool `spanner:"NULL_BOOL"` - BoolArray []NullBool `spanner:"BOOL_ARRAY"` - NullBoolArray []NullBool `spanner:"NULL_BOOL_ARRAY"` - // FLOAT64 / FLOAT64 ARRAY - Float64 float64 `spanner:"FLOAT64"` - NullFloat64 NullFloat64 `spanner:"NULL_FLOAT64"` - Float64Array []NullFloat64 `spanner:"FLOAT64_ARRAY"` - NullFloat64Array []NullFloat64 `spanner:"NULL_FLOAT64_ARRAY"` - // TIMESTAMP / TIMESTAMP ARRAY - Timestamp time.Time `spanner:"TIMESTAMP"` - NullTimestamp NullTime `spanner:"NULL_TIMESTAMP"` - TimestampArray []NullTime `spanner:"TIMESTAMP_ARRAY"` - NullTimestampArray []NullTime `spanner:"NULL_TIMESTAMP_ARRAY"` - // DATE / DATE ARRAY - Date civil.Date `spanner:"DATE"` - NullDate NullDate `spanner:"NULL_DATE"` - DateArray []NullDate `spanner:"DATE_ARRAY"` - NullDateArray []NullDate `spanner:"NULL_DATE_ARRAY"` - // STRUCT ARRAY - StructArray []*struct { - Col1 int64 - Col2 float64 - Col3 string - } `spanner:"STRUCT_ARRAY"` - NullStructArray []*struct { - Col1 int64 - Col2 float64 - Col3 string - } `spanner:"NULL_STRUCT_ARRAY"` - }{ - {}, // got - { + for i, toStuct := range []func(ptr interface{}) error{row.ToStruct, row.ToStructLenient} { + s := []struct { // STRING / STRING ARRAY - "value", - NullString{}, - []NullString{{"value1", true}, {}, {"value3", true}}, - []NullString(nil), + PrimaryKey string `spanner:"STRING"` + NullString NullString `spanner:"NULL_STRING"` + StringArray []NullString `spanner:"STRING_ARRAY"` + NullStringArray []NullString `spanner:"NULL_STRING_ARRAY"` // BYTES / BYTES ARRAY - []byte("value"), - []byte(nil), - [][]byte{[]byte("value1"), nil, []byte("value3")}, - [][]byte(nil), + Bytes []byte `spanner:"BYTES"` + NullBytes []byte `spanner:"NULL_BYTES"` + BytesArray [][]byte `spanner:"BYTES_ARRAY"` + NullBytesArray [][]byte `spanner:"NULL_BYTES_ARRAY"` // INT64 / INT64 ARRAY - int64(17), - NullInt64{}, - []NullInt64{{int64(1), true}, {int64(2), true}, {}}, - []NullInt64(nil), + Int64 int64 `spanner:"INT64"` + NullInt64 NullInt64 `spanner:"NULL_INT64"` + Int64Array []NullInt64 `spanner:"INT64_ARRAY"` + NullInt64Array []NullInt64 `spanner:"NULL_INT64_ARRAY"` // BOOL / BOOL ARRAY - true, - NullBool{}, - []NullBool{{}, {true, true}, {false, true}}, - []NullBool(nil), + Bool bool `spanner:"BOOL"` + NullBool NullBool `spanner:"NULL_BOOL"` + BoolArray []NullBool `spanner:"BOOL_ARRAY"` + NullBoolArray []NullBool `spanner:"NULL_BOOL_ARRAY"` // FLOAT64 / FLOAT64 ARRAY - 1.7, - NullFloat64{}, - []NullFloat64{{}, {}, {1.7, true}}, - []NullFloat64(nil), + Float64 float64 `spanner:"FLOAT64"` + NullFloat64 NullFloat64 `spanner:"NULL_FLOAT64"` + Float64Array []NullFloat64 `spanner:"FLOAT64_ARRAY"` + NullFloat64Array []NullFloat64 `spanner:"NULL_FLOAT64_ARRAY"` // TIMESTAMP / TIMESTAMP ARRAY - tm, - NullTime{}, - []NullTime{{}, {tm, true}}, - []NullTime(nil), + Timestamp time.Time `spanner:"TIMESTAMP"` + NullTimestamp NullTime `spanner:"NULL_TIMESTAMP"` + TimestampArray []NullTime `spanner:"TIMESTAMP_ARRAY"` + NullTimestampArray []NullTime `spanner:"NULL_TIMESTAMP_ARRAY"` // DATE / DATE ARRAY - dt, - NullDate{}, - []NullDate{{}, {dt, true}}, - []NullDate(nil), + Date civil.Date `spanner:"DATE"` + NullDate NullDate `spanner:"NULL_DATE"` + DateArray []NullDate `spanner:"DATE_ARRAY"` + NullDateArray []NullDate `spanner:"NULL_DATE_ARRAY"` + // STRUCT ARRAY - []*struct { + StructArray []*struct { Col1 int64 Col2 float64 Col3 string - }{ - nil, - - {3, 33.3, "three"}, - nil, - }, - []*struct { + } `spanner:"STRUCT_ARRAY"` + NullStructArray []*struct { Col1 int64 Col2 float64 Col3 string - }(nil), - }, // want - } - err := row.ToStruct(&s[0]) - if err != nil { - t.Errorf("row.ToStruct() returns error: %v, want nil", err) - } else if !testEqual(s[0], s[1]) { - t.Errorf("row.ToStruct() fetches struct %v, want %v", s[0], s[1]) + } `spanner:"NULL_STRUCT_ARRAY"` + }{ + {}, // got + { + // STRING / STRING ARRAY + "value", + NullString{}, + []NullString{{"value1", true}, {}, {"value3", true}}, + []NullString(nil), + // BYTES / BYTES ARRAY + []byte("value"), + []byte(nil), + [][]byte{[]byte("value1"), nil, []byte("value3")}, + [][]byte(nil), + // INT64 / INT64 ARRAY + int64(17), + NullInt64{}, + []NullInt64{{int64(1), true}, {int64(2), true}, {}}, + []NullInt64(nil), + // BOOL / BOOL ARRAY + true, + NullBool{}, + []NullBool{{}, {true, true}, {false, true}}, + []NullBool(nil), + // FLOAT64 / FLOAT64 ARRAY + 1.7, + NullFloat64{}, + []NullFloat64{{}, {}, {1.7, true}}, + []NullFloat64(nil), + // TIMESTAMP / TIMESTAMP ARRAY + tm, + NullTime{}, + []NullTime{{}, {tm, true}}, + []NullTime(nil), + // DATE / DATE ARRAY + dt, + NullDate{}, + []NullDate{{}, {dt, true}}, + []NullDate(nil), + // STRUCT ARRAY + []*struct { + Col1 int64 + Col2 float64 + Col3 string + }{ + nil, + + {3, 33.3, "three"}, + nil, + }, + []*struct { + Col1 int64 + Col2 float64 + Col3 string + }(nil), + }, // want + } + funcName := "row.ToStruct()" + if i != 0 { + funcName = "row.ToStructLenient" + } + err := toStuct(&s[0]) + if err != nil { + t.Errorf("%s returns error: %v, want nil", funcName, err) + } else if !testEqual(s[0], s[1]) { + t.Errorf("%s fetches struct %v, want %v", funcName, s[0], s[1]) + } } } @@ -1536,115 +1636,121 @@ func TestToStructWithCustomTypes(t *testing.T) { type CustomDate civil.Date type CustomNullDate NullDate - s := []struct { - // STRING / STRING ARRAY - PrimaryKey CustomString `spanner:"STRING"` - NullString CustomNullString `spanner:"NULL_STRING"` - StringArray []CustomNullString `spanner:"STRING_ARRAY"` - NullStringArray []CustomNullString `spanner:"NULL_STRING_ARRAY"` - // BYTES / BYTES ARRAY - Bytes CustomBytes `spanner:"BYTES"` - NullBytes CustomBytes `spanner:"NULL_BYTES"` - BytesArray []CustomBytes `spanner:"BYTES_ARRAY"` - NullBytesArray []CustomBytes `spanner:"NULL_BYTES_ARRAY"` - // INT64 / INT64 ARRAY - Int64 CustomInt64 `spanner:"INT64"` - NullInt64 CustomNullInt64 `spanner:"NULL_INT64"` - Int64Array []CustomNullInt64 `spanner:"INT64_ARRAY"` - NullInt64Array []CustomNullInt64 `spanner:"NULL_INT64_ARRAY"` - // BOOL / BOOL ARRAY - Bool CustomBool `spanner:"BOOL"` - NullBool CustomNullBool `spanner:"NULL_BOOL"` - BoolArray []CustomNullBool `spanner:"BOOL_ARRAY"` - NullBoolArray []CustomNullBool `spanner:"NULL_BOOL_ARRAY"` - // FLOAT64 / FLOAT64 ARRAY - Float64 CustomFloat64 `spanner:"FLOAT64"` - NullFloat64 CustomNullFloat64 `spanner:"NULL_FLOAT64"` - Float64Array []CustomNullFloat64 `spanner:"FLOAT64_ARRAY"` - NullFloat64Array []CustomNullFloat64 `spanner:"NULL_FLOAT64_ARRAY"` - // TIMESTAMP / TIMESTAMP ARRAY - Timestamp CustomTime `spanner:"TIMESTAMP"` - NullTimestamp CustomNullTime `spanner:"NULL_TIMESTAMP"` - TimestampArray []CustomNullTime `spanner:"TIMESTAMP_ARRAY"` - NullTimestampArray []CustomNullTime `spanner:"NULL_TIMESTAMP_ARRAY"` - // DATE / DATE ARRAY - Date CustomDate `spanner:"DATE"` - NullDate CustomNullDate `spanner:"NULL_DATE"` - DateArray []CustomNullDate `spanner:"DATE_ARRAY"` - NullDateArray []CustomNullDate `spanner:"NULL_DATE_ARRAY"` - - // STRUCT ARRAY - StructArray []*struct { - Col1 CustomInt64 - Col2 CustomFloat64 - Col3 CustomString - } `spanner:"STRUCT_ARRAY"` - NullStructArray []*struct { - Col1 CustomInt64 - Col2 CustomFloat64 - Col3 CustomString - } `spanner:"NULL_STRUCT_ARRAY"` - }{ - {}, // got - { + for i, toStuct := range []func(ptr interface{}) error{row.ToStruct, row.ToStructLenient} { + s := []struct { // STRING / STRING ARRAY - "value", - CustomNullString{}, - []CustomNullString{{"value1", true}, {}, {"value3", true}}, - []CustomNullString(nil), + PrimaryKey CustomString `spanner:"STRING"` + NullString CustomNullString `spanner:"NULL_STRING"` + StringArray []CustomNullString `spanner:"STRING_ARRAY"` + NullStringArray []CustomNullString `spanner:"NULL_STRING_ARRAY"` // BYTES / BYTES ARRAY - CustomBytes("value"), - CustomBytes(nil), - []CustomBytes{[]byte("value1"), nil, []byte("value3")}, - []CustomBytes(nil), + Bytes CustomBytes `spanner:"BYTES"` + NullBytes CustomBytes `spanner:"NULL_BYTES"` + BytesArray []CustomBytes `spanner:"BYTES_ARRAY"` + NullBytesArray []CustomBytes `spanner:"NULL_BYTES_ARRAY"` // INT64 / INT64 ARRAY - CustomInt64(17), - CustomNullInt64{}, - []CustomNullInt64{{int64(1), true}, {int64(2), true}, {}}, - []CustomNullInt64(nil), + Int64 CustomInt64 `spanner:"INT64"` + NullInt64 CustomNullInt64 `spanner:"NULL_INT64"` + Int64Array []CustomNullInt64 `spanner:"INT64_ARRAY"` + NullInt64Array []CustomNullInt64 `spanner:"NULL_INT64_ARRAY"` // BOOL / BOOL ARRAY - true, - CustomNullBool{}, - []CustomNullBool{{}, {true, true}, {false, true}}, - []CustomNullBool(nil), + Bool CustomBool `spanner:"BOOL"` + NullBool CustomNullBool `spanner:"NULL_BOOL"` + BoolArray []CustomNullBool `spanner:"BOOL_ARRAY"` + NullBoolArray []CustomNullBool `spanner:"NULL_BOOL_ARRAY"` // FLOAT64 / FLOAT64 ARRAY - 1.7, - CustomNullFloat64{}, - []CustomNullFloat64{{}, {}, {1.7, true}}, - []CustomNullFloat64(nil), + Float64 CustomFloat64 `spanner:"FLOAT64"` + NullFloat64 CustomNullFloat64 `spanner:"NULL_FLOAT64"` + Float64Array []CustomNullFloat64 `spanner:"FLOAT64_ARRAY"` + NullFloat64Array []CustomNullFloat64 `spanner:"NULL_FLOAT64_ARRAY"` // TIMESTAMP / TIMESTAMP ARRAY - CustomTime(tm), - CustomNullTime{}, - []CustomNullTime{{}, {tm, true}}, - []CustomNullTime(nil), + Timestamp CustomTime `spanner:"TIMESTAMP"` + NullTimestamp CustomNullTime `spanner:"NULL_TIMESTAMP"` + TimestampArray []CustomNullTime `spanner:"TIMESTAMP_ARRAY"` + NullTimestampArray []CustomNullTime `spanner:"NULL_TIMESTAMP_ARRAY"` // DATE / DATE ARRAY - CustomDate(dt), - CustomNullDate{}, - []CustomNullDate{{}, {dt, true}}, - []CustomNullDate(nil), + Date CustomDate `spanner:"DATE"` + NullDate CustomNullDate `spanner:"NULL_DATE"` + DateArray []CustomNullDate `spanner:"DATE_ARRAY"` + NullDateArray []CustomNullDate `spanner:"NULL_DATE_ARRAY"` + // STRUCT ARRAY - []*struct { + StructArray []*struct { Col1 CustomInt64 Col2 CustomFloat64 Col3 CustomString - }{ - nil, - - {3, 33.3, "three"}, - nil, - }, - []*struct { + } `spanner:"STRUCT_ARRAY"` + NullStructArray []*struct { Col1 CustomInt64 Col2 CustomFloat64 Col3 CustomString - }(nil), - }, // want - } - err := row.ToStruct(&s[0]) - if err != nil { - t.Errorf("row.ToStruct() returns error: %v, want nil", err) - } else if !testutil.Equal(s[0], s[1], cmp.AllowUnexported(CustomTime{})) { - t.Errorf("row.ToStruct() fetches struct %v, want %v", s[0], s[1]) + } `spanner:"NULL_STRUCT_ARRAY"` + }{ + {}, // got + { + // STRING / STRING ARRAY + "value", + CustomNullString{}, + []CustomNullString{{"value1", true}, {}, {"value3", true}}, + []CustomNullString(nil), + // BYTES / BYTES ARRAY + CustomBytes("value"), + CustomBytes(nil), + []CustomBytes{[]byte("value1"), nil, []byte("value3")}, + []CustomBytes(nil), + // INT64 / INT64 ARRAY + CustomInt64(17), + CustomNullInt64{}, + []CustomNullInt64{{int64(1), true}, {int64(2), true}, {}}, + []CustomNullInt64(nil), + // BOOL / BOOL ARRAY + true, + CustomNullBool{}, + []CustomNullBool{{}, {true, true}, {false, true}}, + []CustomNullBool(nil), + // FLOAT64 / FLOAT64 ARRAY + 1.7, + CustomNullFloat64{}, + []CustomNullFloat64{{}, {}, {1.7, true}}, + []CustomNullFloat64(nil), + // TIMESTAMP / TIMESTAMP ARRAY + CustomTime(tm), + CustomNullTime{}, + []CustomNullTime{{}, {tm, true}}, + []CustomNullTime(nil), + // DATE / DATE ARRAY + CustomDate(dt), + CustomNullDate{}, + []CustomNullDate{{}, {dt, true}}, + []CustomNullDate(nil), + // STRUCT ARRAY + []*struct { + Col1 CustomInt64 + Col2 CustomFloat64 + Col3 CustomString + }{ + nil, + + {3, 33.3, "three"}, + nil, + }, + []*struct { + Col1 CustomInt64 + Col2 CustomFloat64 + Col3 CustomString + }(nil), + }, // want + } + funcName := "row.ToStruct()" + if i != 0 { + funcName = "row.ToStructLenient()" + } + err := toStuct(&s[0]) + if err != nil { + t.Errorf("%s returns error: %v, want nil", funcName, err) + } else if !testutil.Equal(s[0], s[1], cmp.AllowUnexported(CustomTime{})) { + t.Errorf("%s fetches struct %v, want %v", funcName, s[0], s[1]) + } } } @@ -1674,6 +1780,77 @@ func TestToStructEmbedded(t *testing.T) { if !testEqual(got, want) { t.Errorf("got %+v, want %+v", got, want) } + + var gotLenient S2 + if err := r.ToStructLenient(&gotLenient); err != nil { + t.Fatal(err) + } + if !testEqual(gotLenient, want) { + t.Errorf("gotLenient %+v, want %+v", got, want) + } +} + +func TestToStructWithUnEqualFields(t *testing.T) { + type ( + extraField struct { + F1 string + F2 string + F3 string + } + lessField struct { + F1 string + F3 string + } + ) + testCases := []struct { + name string + dst interface{} + want interface{} + row Row + }{ + { + name: "destination struct has extra field", + dst: &extraField{}, + row: Row{ + []*sppb.StructType_Field{ + {Name: "F1", Type: stringType()}, + {Name: "F2", Type: stringType()}, + }, + []*proto3.Value{ + stringProto("v1"), + stringProto("v2"), + }, + }, + want: &extraField{F1: "v1", F2: "v2", F3: ""}, + }, + { + name: "destination struct has less field", + dst: &lessField{}, + row: Row{ + []*sppb.StructType_Field{ + {Name: "F1", Type: stringType()}, + {Name: "F2", Type: stringType()}, + {Name: "F3", Type: stringType()}, + }, + []*proto3.Value{ + stringProto("v1"), + stringProto("v2"), + stringProto("v3"), + }, + }, + want: &lessField{F1: "v1", F3: "v3"}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if err := tc.row.ToStructLenient(tc.dst); err != nil { + t.Fatal(err) + } + if !testEqual(tc.dst, tc.want) { + t.Errorf("got %+v, want %+v", tc.dst, tc.want) + } + }) + } } func TestRowToString(t *testing.T) { diff --git a/spanner/value.go b/spanner/value.go index 55fc4c270bf..1c0a7ddf30c 100644 --- a/spanner/value.go +++ b/spanner/value.go @@ -2871,6 +2871,11 @@ func errNilSpannerStructType() error { return spannerErrorf(codes.FailedPrecondition, "unexpected nil StructType in decoding Cloud Spanner STRUCT") } +// errDupGoField returns error for duplicated Go STRUCT field names +func errDupGoField(s interface{}, name string) error { + return spannerErrorf(codes.InvalidArgument, "Go struct %+v(type %T) has duplicate fields for GO STRUCT field %s", s, s, name) +} + // errUnnamedField returns error for decoding a Cloud Spanner STRUCT with // unnamed field into a Go struct. func errUnnamedField(ty *sppb.StructType, i int) error { @@ -2905,7 +2910,7 @@ func errDecodeStructField(ty *sppb.StructType, f string, err error) error { // decodeStruct decodes proto3.ListValue pb into struct referenced by pointer // ptr, according to // the structural information given in sppb.StructType ty. -func decodeStruct(ty *sppb.StructType, pb *proto3.ListValue, ptr interface{}) error { +func decodeStruct(ty *sppb.StructType, pb *proto3.ListValue, ptr interface{}, lenient bool) error { if reflect.ValueOf(ptr).IsNil() { return errNilDst(ptr) } @@ -2921,6 +2926,15 @@ func decodeStruct(ty *sppb.StructType, pb *proto3.ListValue, ptr interface{}) er if err != nil { return ToSpannerError(err) } + // return error if lenient is true and destination has duplicate exported columns + if lenient { + fieldNames := getAllFieldNames(v) + for _, f := range fieldNames { + if fields.Match(f) == nil { + return errDupGoField(ptr, f) + } + } + } seen := map[string]bool{} for i, f := range ty.Fields { if f.Name == "" { @@ -2928,6 +2942,9 @@ func decodeStruct(ty *sppb.StructType, pb *proto3.ListValue, ptr interface{}) er } sf := fields.Match(f.Name) if sf == nil { + if lenient { + continue + } return errNoOrDupGoField(ptr, f.Name) } if seen[f.Name] { @@ -2986,7 +3003,7 @@ func decodeStructArray(ty *sppb.StructType, pb *proto3.ListValue, ptr interface{ return errDecodeArrayElement(i, pv, "STRUCT", err) } // Decode proto3.ListValue l into struct referenced by s.Interface(). - if err = decodeStruct(ty, l, s.Interface()); err != nil { + if err = decodeStruct(ty, l, s.Interface(), false); err != nil { return errDecodeArrayElement(i, pv, "STRUCT", err) } // Append the decoded struct back into the slice. @@ -2995,6 +3012,37 @@ func decodeStructArray(ty *sppb.StructType, pb *proto3.ListValue, ptr interface{ return nil } +func getAllFieldNames(v reflect.Value) []string { + var names []string + typeOfT := v.Type() + for i := 0; i < v.NumField(); i++ { + f := v.Field(i) + fieldType := typeOfT.Field(i) + exported := (fieldType.PkgPath == "") + // If a named field is unexported, ignore it. An anonymous + // unexported field is processed, because it may contain + // exported fields, which are visible. + if !exported && !fieldType.Anonymous { + continue + } + if f.Kind() == reflect.Struct { + if fieldType.Anonymous { + names = append(names, getAllFieldNames(reflect.ValueOf(f.Interface()))...) + } + continue + } + name, keep, _, _ := spannerTagParser(fieldType.Tag) + if !keep { + continue + } + if name == "" { + name = fieldType.Name + } + names = append(names, name) + } + return names +} + // errEncoderUnsupportedType returns error for not being able to encode a value // of certain type. func errEncoderUnsupportedType(v interface{}) error { diff --git a/spanner/value_test.go b/spanner/value_test.go index b6f93b259ea..aa3627556e8 100644 --- a/spanner/value_test.go +++ b/spanner/value_test.go @@ -2049,38 +2049,74 @@ func TestDecodeStruct(t *testing.T) { ) for _, test := range []struct { - desc string - ptr interface{} - want interface{} - fail bool + desc string + lenient bool + ptr interface{} + want interface{} + fail bool }{ { - desc: "decode to S1", - ptr: &s1, - want: &S1{ID: "id", Time: t1}, + desc: "decode to S1 with lenient enabled", + ptr: &s1, + want: &S1{ID: "id", Time: t1}, + lenient: true, + }, + { + desc: "decode to S1 with lenient disabled", + ptr: &s1, + want: &S1{ID: "id", Time: t1}, + lenient: false, + }, + { + desc: "decode to S2 with lenient enabled", + ptr: &s2, + fail: true, + lenient: true, + }, + { + desc: "decode to S2 with lenient disabled", + ptr: &s2, + fail: true, + lenient: false, + }, + { + desc: "decode to S3 with lenient enabled", + ptr: &s3, + want: &S3{ID: CustomString("id"), Time: CustomTime(t1)}, + lenient: true, + }, + { + desc: "decode to S3 with lenient disabled", + ptr: &s3, + want: &S3{ID: CustomString("id"), Time: CustomTime(t1)}, + lenient: false, }, { - desc: "decode to S2", - ptr: &s2, - fail: true, + desc: "decode to S4 with lenient enabled", + ptr: &s4, + fail: true, + lenient: true, }, { - desc: "decode to S3", - ptr: &s3, - want: &S3{ID: CustomString("id"), Time: CustomTime(t1)}, + desc: "decode to S4 with lenient disabled", + ptr: &s4, + fail: true, + lenient: false, }, { - desc: "decode to S4", - ptr: &s4, - fail: true, + desc: "decode to S5 with lenient enabled", + ptr: &s5, + want: &S5{NullString: NullString{}, Time: CustomTime(t1)}, + lenient: true, }, { - desc: "decode to S5", - ptr: &s5, - fail: true, + desc: "decode to S5 with lenient disabled", + ptr: &s5, + fail: true, + lenient: false, }, } { - err := decodeStruct(stype, lv, test.ptr) + err := decodeStruct(stype, lv, test.ptr, test.lenient) if (err != nil) != test.fail { t.Errorf("%s: got error %v, wanted fail: %v", test.desc, err, test.fail) } @@ -2195,7 +2231,7 @@ func TestDecodeStructWithPointers(t *testing.T) { want: &S1{Str: nil, Int: nil, Bool: nil, Float: nil, Time: nil, Date: nil, StrArray: nil, IntArray: nil, BoolArray: nil, FloatArray: nil, TimeArray: nil, DateArray: nil}, }, } { - err := decodeStruct(stype, lv[i], test.ptr) + err := decodeStruct(stype, lv[i], test.ptr, false) if (err != nil) != test.fail { t.Errorf("%s: got error %v, wanted fail: %v", test.desc, err, test.fail) }