diff --git a/spanner/client_test.go b/spanner/client_test.go index 4fc6219a0f4..cd29a79367c 100644 --- a/spanner/client_test.go +++ b/spanner/client_test.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "io" + "math/big" "os" "strings" "testing" @@ -2267,3 +2268,27 @@ func TestClient_DoForEachRow_ShouldEndSpanWithQueryError(t *testing.T) { t.Errorf("Span status mismatch\nGot: %v\nWant: %v", s.Code, codes.InvalidArgument) } } + +func TestClient_Single_Read_WithNumericKey(t *testing.T) { + t.Parallel() + + _, client, teardown := setupMockedTestServer(t) + defer teardown() + ctx := context.Background() + iter := client.Single().Read(ctx, "Albums", KeySets(Key{*big.NewRat(1, 1)}), []string{"SingerId", "AlbumId", "AlbumTitle"}) + defer iter.Stop() + rowCount := int64(0) + for { + _, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + t.Fatal(err) + } + rowCount++ + } + if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount { + t.Fatalf("row count mismatch\nGot: %v\nWant: %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount) + } +} diff --git a/spanner/key.go b/spanner/key.go index d9819013bf4..aa876ec3d85 100644 --- a/spanner/key.go +++ b/spanner/key.go @@ -19,6 +19,7 @@ package spanner import ( "bytes" "fmt" + "math/big" "time" "cloud.google.com/go/civil" @@ -84,7 +85,7 @@ func keyPartValue(part interface{}) (pb *proto3.Value, err error) { pb, _, err = encodeValue(int64(v)) case float32: pb, _, err = encodeValue(float64(v)) - case int64, float64, NullInt64, NullFloat64, bool, NullBool, []byte, string, NullString, time.Time, civil.Date, NullTime, NullDate: + case int64, float64, NullInt64, NullFloat64, bool, NullBool, []byte, string, NullString, time.Time, civil.Date, NullTime, NullDate, big.Rat, NullNumeric: pb, _, err = encodeValue(v) case Encoder: part, err = v.EncodeSpanner() @@ -150,7 +151,7 @@ func (key Key) elemString(b *bytes.Buffer, part interface{}) { } else { fmt.Fprint(b, nullString) } - case NullInt64, NullFloat64, NullBool: + case NullInt64, NullFloat64, NullBool, NullNumeric: // The above types implement fmt.Stringer. fmt.Fprintf(b, "%s", v) case NullString, NullDate, NullTime: @@ -164,6 +165,8 @@ func (key Key) elemString(b *bytes.Buffer, part interface{}) { fmt.Fprintf(b, "%q", v) case time.Time: fmt.Fprintf(b, "%q", v.Format(time.RFC3339Nano)) + case big.Rat: + fmt.Fprintf(b, "%v", NumericString(&v)) case Encoder: var err error part, err = v.EncodeSpanner() diff --git a/spanner/key_test.go b/spanner/key_test.go index 0d5723e17a2..da103fde95e 100644 --- a/spanner/key_test.go +++ b/spanner/key_test.go @@ -18,6 +18,7 @@ package spanner import ( "errors" + "math/big" "testing" "time" @@ -132,6 +133,11 @@ func TestKey(t *testing.T) { wantProto: listValueProto(stringProto("2016-11-15")), wantStr: `("2016-11-15")`, }, + { + k: Key{*big.NewRat(1, 1)}, + wantProto: listValueProto(stringProto("1.000000000")), + wantStr: `(1.000000000)`, + }, { k: Key{[]byte("value")}, wantProto: listValueProto(bytesProto([]byte("value"))), @@ -203,6 +209,16 @@ func TestKey(t *testing.T) { wantProto: listValueProto(stringProto("1"), nullProto(), stringProto("value"), floatProto(1.5), boolProto(true)), wantStr: `(1,,"value",1.5,true)`, }, + { + k: Key{NullNumeric{*big.NewRat(2, 3), true}}, + wantProto: listValueProto(stringProto("0.666666667")), + wantStr: "(0.666666667)", + }, + { + k: Key{NullNumeric{big.Rat{}, false}}, + wantProto: listValueProto(nullProto()), + wantStr: "()", + }, { k: Key{customKeyToString("value")}, wantProto: listValueProto(stringProto("value")),