Skip to content

Commit

Permalink
feat: support NUMERIC as key (#3627)
Browse files Browse the repository at this point in the history
Co-authored-by: skuruppu <skuruppu@google.com>
  • Loading branch information
olavloite and skuruppu committed Jan 29, 2021
1 parent 8617812 commit 84dd78c
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 2 deletions.
25 changes: 25 additions & 0 deletions spanner/client_test.go
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"fmt"
"io"
"math/big"
"os"
"strings"
"testing"
Expand Down Expand Up @@ -2267,3 +2268,27 @@ func TestClient_DoForEachRow_ShouldEndSpanWithQueryError(t *testing.T) {
t.Errorf("Span status mismatch\nGot: %v\nWant: %v", s.Code, codes.InvalidArgument)
}
}

func TestClient_Single_Read_WithNumericKey(t *testing.T) {
t.Parallel()

_, client, teardown := setupMockedTestServer(t)
defer teardown()
ctx := context.Background()
iter := client.Single().Read(ctx, "Albums", KeySets(Key{*big.NewRat(1, 1)}), []string{"SingerId", "AlbumId", "AlbumTitle"})
defer iter.Stop()
rowCount := int64(0)
for {
_, err := iter.Next()
if err == iterator.Done {
break
}
if err != nil {
t.Fatal(err)
}
rowCount++
}
if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
t.Fatalf("row count mismatch\nGot: %v\nWant: %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
}
}
7 changes: 5 additions & 2 deletions spanner/key.go
Expand Up @@ -19,6 +19,7 @@ package spanner
import (
"bytes"
"fmt"
"math/big"
"time"

"cloud.google.com/go/civil"
Expand Down Expand Up @@ -84,7 +85,7 @@ func keyPartValue(part interface{}) (pb *proto3.Value, err error) {
pb, _, err = encodeValue(int64(v))
case float32:
pb, _, err = encodeValue(float64(v))
case int64, float64, NullInt64, NullFloat64, bool, NullBool, []byte, string, NullString, time.Time, civil.Date, NullTime, NullDate:
case int64, float64, NullInt64, NullFloat64, bool, NullBool, []byte, string, NullString, time.Time, civil.Date, NullTime, NullDate, big.Rat, NullNumeric:
pb, _, err = encodeValue(v)
case Encoder:
part, err = v.EncodeSpanner()
Expand Down Expand Up @@ -150,7 +151,7 @@ func (key Key) elemString(b *bytes.Buffer, part interface{}) {
} else {
fmt.Fprint(b, nullString)
}
case NullInt64, NullFloat64, NullBool:
case NullInt64, NullFloat64, NullBool, NullNumeric:
// The above types implement fmt.Stringer.
fmt.Fprintf(b, "%s", v)
case NullString, NullDate, NullTime:
Expand All @@ -164,6 +165,8 @@ func (key Key) elemString(b *bytes.Buffer, part interface{}) {
fmt.Fprintf(b, "%q", v)
case time.Time:
fmt.Fprintf(b, "%q", v.Format(time.RFC3339Nano))
case big.Rat:
fmt.Fprintf(b, "%v", NumericString(&v))
case Encoder:
var err error
part, err = v.EncodeSpanner()
Expand Down
16 changes: 16 additions & 0 deletions spanner/key_test.go
Expand Up @@ -18,6 +18,7 @@ package spanner

import (
"errors"
"math/big"
"testing"
"time"

Expand Down Expand Up @@ -132,6 +133,11 @@ func TestKey(t *testing.T) {
wantProto: listValueProto(stringProto("2016-11-15")),
wantStr: `("2016-11-15")`,
},
{
k: Key{*big.NewRat(1, 1)},
wantProto: listValueProto(stringProto("1.000000000")),
wantStr: `(1.000000000)`,
},
{
k: Key{[]byte("value")},
wantProto: listValueProto(bytesProto([]byte("value"))),
Expand Down Expand Up @@ -203,6 +209,16 @@ func TestKey(t *testing.T) {
wantProto: listValueProto(stringProto("1"), nullProto(), stringProto("value"), floatProto(1.5), boolProto(true)),
wantStr: `(1,<null>,"value",1.5,true)`,
},
{
k: Key{NullNumeric{*big.NewRat(2, 3), true}},
wantProto: listValueProto(stringProto("0.666666667")),
wantStr: "(0.666666667)",
},
{
k: Key{NullNumeric{big.Rat{}, false}},
wantProto: listValueProto(nullProto()),
wantStr: "(<null>)",
},
{
k: Key{customKeyToString("value")},
wantProto: listValueProto(stringProto("value")),
Expand Down

0 comments on commit 84dd78c

Please sign in to comment.