Skip to content

Commit

Permalink
feat(spanner): add custom type support to spanner.Key (#2748)
Browse files Browse the repository at this point in the history
fixes #2747
  • Loading branch information
vvakame committed Aug 20, 2020
1 parent 8bb85ec commit f799758
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 29 deletions.
74 changes: 46 additions & 28 deletions spanner/key.go
Expand Up @@ -86,6 +86,12 @@ func keyPartValue(part interface{}) (pb *proto3.Value, err error) {
pb, _, err = encodeValue(float64(v))
case int64, float64, NullInt64, NullFloat64, bool, NullBool, []byte, string, NullString, time.Time, civil.Date, NullTime, NullDate:
pb, _, err = encodeValue(v)
case Encoder:
part, err = v.EncodeSpanner()
if err != nil {
return nil, err
}
pb, err = keyPartValue(part)
default:
return nil, errInvdKeyPartType(v)
}
Expand Down Expand Up @@ -125,38 +131,50 @@ func (key Key) String() string {
if i != 0 {
fmt.Fprint(b, ",")
}
switch v := part.(type) {
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, float32, float64, bool:
// Use %v to print numeric types and bool.
fmt.Fprintf(b, "%v", v)
case string:
key.elemString(b, part)
}
fmt.Fprint(b, ")")
return b.String()
}

func (key Key) elemString(b *bytes.Buffer, part interface{}) {
switch v := part.(type) {
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, float32, float64, bool:
// Use %v to print numeric types and bool.
fmt.Fprintf(b, "%v", v)
case string:
fmt.Fprintf(b, "%q", v)
case []byte:
if v != nil {
fmt.Fprintf(b, "%q", v)
case []byte:
if v != nil {
fmt.Fprintf(b, "%q", v)
} else {
fmt.Fprint(b, nullString)
}
case NullInt64, NullFloat64, NullBool:
// The above types implement fmt.Stringer.
fmt.Fprintf(b, "%s", v)
case NullString, NullDate, NullTime:
// Quote the returned string if it is not null.
if v.(NullableValue).IsNull() {
fmt.Fprintf(b, "%s", nullString)
} else {
fmt.Fprintf(b, "%q", v)
}
case civil.Date:
} else {
fmt.Fprint(b, nullString)
}
case NullInt64, NullFloat64, NullBool:
// The above types implement fmt.Stringer.
fmt.Fprintf(b, "%s", v)
case NullString, NullDate, NullTime:
// Quote the returned string if it is not null.
if v.(NullableValue).IsNull() {
fmt.Fprintf(b, "%s", nullString)
} else {
fmt.Fprintf(b, "%q", v)
case time.Time:
fmt.Fprintf(b, "%q", v.Format(time.RFC3339Nano))
default:
fmt.Fprintf(b, "%v", v)
}
case civil.Date:
fmt.Fprintf(b, "%q", v)
case time.Time:
fmt.Fprintf(b, "%q", v.Format(time.RFC3339Nano))
case Encoder:
var err error
part, err = v.EncodeSpanner()
if err != nil {
fmt.Fprintf(b, "error")
} else {
key.elemString(b, part)
}
default:
fmt.Fprintf(b, "%v", v)
}
fmt.Fprint(b, ")")
return b.String()
}

// AsPrefix returns a KeyRange for all keys where k is the prefix.
Expand Down
56 changes: 55 additions & 1 deletion spanner/key_test.go
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package spanner

import (
"errors"
"testing"
"time"

Expand All @@ -25,6 +26,24 @@ import (
sppb "google.golang.org/genproto/googleapis/spanner/v1"
)

type customKeyToString string

func (k customKeyToString) EncodeSpanner() (interface{}, error) {
return string(k), nil
}

type customKeyToInt int

func (k customKeyToInt) EncodeSpanner() (interface{}, error) {
return int(k), nil
}

type customKeyToError struct{}

func (k customKeyToError) EncodeSpanner() (interface{}, error) {
return nil, errors.New("always error")
}

// Test Key.String() and Key.proto().
func TestKey(t *testing.T) {
tm, _ := time.Parse(time.RFC3339Nano, "2016-11-15T15:04:05.999999999Z")
Expand Down Expand Up @@ -184,12 +203,27 @@ func TestKey(t *testing.T) {
wantProto: listValueProto(stringProto("1"), nullProto(), stringProto("value"), floatProto(1.5), boolProto(true)),
wantStr: `(1,<null>,"value",1.5,true)`,
},
{
k: Key{customKeyToString("value")},
wantProto: listValueProto(stringProto("value")),
wantStr: `("value")`,
},
{
k: Key{customKeyToInt(1)},
wantProto: listValueProto(intProto(1)),
wantStr: `(1)`,
},
{
k: Key{customKeyToError{}},
wantProto: nil,
wantStr: `(error)`,
},
} {
if got := test.k.String(); got != test.wantStr {
t.Errorf("%v.String() = %v, want %v", test.k, got, test.wantStr)
}
gotProto, err := test.k.proto()
if err != nil {
if test.wantProto != nil && err != nil {
t.Errorf("%v.proto() returns error %v; want nil error", test.k, err)
}
if !testEqual(gotProto, test.wantProto) {
Expand Down Expand Up @@ -237,6 +271,14 @@ func TestKeyRange(t *testing.T) {
},
wantStr: "[(1),(10)]",
},
{
kr: KeyRange{Key{customKeyToString("A")}, Key{customKeyToString("D")}, OpenOpen},
wantProto: &sppb.KeyRange{
StartKeyType: &sppb.KeyRange_StartOpen{StartOpen: listValueProto(stringProto("A"))},
EndKeyType: &sppb.KeyRange_EndOpen{EndOpen: listValueProto(stringProto("D"))},
},
wantStr: `(("A"),("D"))`,
},
} {
if got := test.kr.String(); got != test.wantStr {
t.Errorf("%v.String() = %v, want %v", test.kr, got, test.wantStr)
Expand Down Expand Up @@ -360,6 +402,18 @@ func TestKeySets(t *testing.T) {
Key{6}),
&sppb.KeySet{All: true},
},
{
KeySets(
Key{customKeyToInt(1), customKeyToInt(2)},
Key{customKeyToInt(3), customKeyToInt(4)},
),
&sppb.KeySet{
Keys: []*proto3.ListValue{
listValueProto(int1, int2),
listValueProto(int3, int4),
},
},
},
} {
gotProto, err := test.ks.keySetProto()
if err != nil {
Expand Down

0 comments on commit f799758

Please sign in to comment.