Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for NUMERIC type #86

Merged
merged 9 commits into from Sep 3, 2020
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions google/cloud/spanner_v1/_helpers.py
Expand Up @@ -15,6 +15,7 @@
"""Helper functions for Cloud Spanner."""

import datetime
import decimal
import math

import six
Expand Down Expand Up @@ -127,6 +128,8 @@ def _make_value_pb(value):
return Value(string_value=value)
if isinstance(value, ListValue):
return Value(list_value=value)
if isinstance(value, decimal.Decimal):
return Value(string_value=str(value))
raise ValueError("Unknown type: %s" % (value,))


Expand Down Expand Up @@ -201,6 +204,8 @@ def _parse_value_pb(value_pb, field_type):
_parse_value_pb(item_pb, field_type.struct_type.fields[i].type)
for (i, item_pb) in enumerate(value_pb.list_value.values)
]
elif field_type.code == type_pb2.NUMERIC:
result = decimal.Decimal(value_pb.string_value)
else:
raise ValueError("Unknown type: %s" % (field_type,))
return result
Expand Down
1 change: 1 addition & 0 deletions google/cloud/spanner_v1/param_types.py
Expand Up @@ -25,6 +25,7 @@
FLOAT64 = type_pb2.Type(code=type_pb2.FLOAT64)
DATE = type_pb2.Type(code=type_pb2.DATE)
TIMESTAMP = type_pb2.Type(code=type_pb2.TIMESTAMP)
NUMERIC = type_pb2.Type(code=type_pb2.NUMERIC)


def Array(element_type): # pylint: disable=invalid-name
Expand Down
53 changes: 53 additions & 0 deletions tests/_fixtures.py
Expand Up @@ -16,6 +16,58 @@


DDL = """\
CREATE TABLE contacts (
contact_id INT64,
first_name STRING(1024),
last_name STRING(1024),
email STRING(1024) )
PRIMARY KEY (contact_id);
CREATE TABLE contact_phones (
contact_id INT64,
phone_type STRING(1024),
phone_number STRING(1024) )
PRIMARY KEY (contact_id, phone_type),
INTERLEAVE IN PARENT contacts ON DELETE CASCADE;
CREATE TABLE all_types (
pkey INT64 NOT NULL,
int_value INT64,
int_array ARRAY<INT64>,
bool_value BOOL,
bool_array ARRAY<BOOL>,
bytes_value BYTES(16),
bytes_array ARRAY<BYTES(16)>,
date_value DATE,
date_array ARRAY<DATE>,
float_value FLOAT64,
float_array ARRAY<FLOAT64>,
string_value STRING(16),
string_array ARRAY<STRING(16)>,
timestamp_value TIMESTAMP,
timestamp_array ARRAY<TIMESTAMP>,
numeric_value NUMERIC,
numeric_array ARRAY<NUMERIC>)
PRIMARY KEY (pkey);
CREATE TABLE counters (
name STRING(1024),
value INT64 )
PRIMARY KEY (name);
CREATE TABLE string_plus_array_of_string (
id INT64,
name STRING(16),
tags ARRAY<STRING(16)> )
PRIMARY KEY (id);
CREATE INDEX name ON contacts(first_name, last_name);
CREATE TABLE users_history (
id INT64 NOT NULL,
commit_ts TIMESTAMP NOT NULL OPTIONS
(allow_commit_timestamp=true),
name STRING(MAX) NOT NULL,
email STRING(MAX),
deleted BOOL NOT NULL )
PRIMARY KEY(id, commit_ts DESC);
"""

EMULATOR_DDL = """\
larkee marked this conversation as resolved.
Show resolved Hide resolved
CREATE TABLE contacts (
contact_id INT64,
first_name STRING(1024),
Expand Down Expand Up @@ -66,3 +118,4 @@
"""

DDL_STATEMENTS = [stmt.strip() for stmt in DDL.split(";") if stmt.strip()]
EMULATOR_DDL_STATEMENTS = [stmt.strip() for stmt in EMULATOR_DDL.split(";") if stmt.strip()]
80 changes: 70 additions & 10 deletions tests/system/test_system.py
Expand Up @@ -14,6 +14,7 @@

import collections
import datetime
import decimal
import math
import operator
import os
Expand All @@ -38,6 +39,7 @@
from google.cloud.spanner_v1.proto.type_pb2 import INT64
from google.cloud.spanner_v1.proto.type_pb2 import STRING
from google.cloud.spanner_v1.proto.type_pb2 import TIMESTAMP
from google.cloud.spanner_v1.proto.type_pb2 import NUMERIC
from google.cloud.spanner_v1.proto.type_pb2 import Type

from google.cloud._helpers import UTC
Expand All @@ -52,10 +54,12 @@
from test_utils.retry import RetryResult
from test_utils.system import unique_resource_id
from tests._fixtures import DDL_STATEMENTS
from tests._fixtures import EMULATOR_DDL_STATEMENTS


CREATE_INSTANCE = os.getenv("GOOGLE_CLOUD_TESTS_CREATE_SPANNER_INSTANCE") is not None
USE_EMULATOR = os.getenv("SPANNER_EMULATOR_HOST") is not None
SKIP_BACKUP_TESTS = os.getenv("SKIP_BACKUP_TESTS") is not None

if CREATE_INSTANCE:
INSTANCE_ID = "google-cloud" + unique_resource_id("-")
Expand Down Expand Up @@ -85,7 +89,8 @@ class Config(object):


def _has_all_ddl(database):
return len(database.ddl_statements) == len(DDL_STATEMENTS)
ddl_statements = EMULATOR_DDL_STATEMENTS if USE_EMULATOR else DDL_STATEMENTS
return len(database.ddl_statements) == len(ddl_statements)


def _list_instances():
Expand Down Expand Up @@ -277,8 +282,9 @@ class TestDatabaseAPI(unittest.TestCase, _TestData):
@classmethod
def setUpClass(cls):
pool = BurstyPool(labels={"testcase": "database_api"})
ddl_statements = EMULATOR_DDL_STATEMENTS if USE_EMULATOR else DDL_STATEMENTS
cls._db = Config.INSTANCE.database(
cls.DATABASE_NAME, ddl_statements=DDL_STATEMENTS, pool=pool
cls.DATABASE_NAME, ddl_statements=ddl_statements, pool=pool
)
operation = cls._db.create()
operation.result(30) # raises on failure / timeout.
Expand Down Expand Up @@ -352,12 +358,13 @@ def test_update_database_ddl_with_operation_id(self):
temp_db = Config.INSTANCE.database(temp_db_id, pool=pool)
create_op = temp_db.create()
self.to_delete.append(temp_db)
ddl_statements = EMULATOR_DDL_STATEMENTS if USE_EMULATOR else DDL_STATEMENTS

# We want to make sure the operation completes.
create_op.result(240) # raises on failure / timeout.
# random but shortish always start with letter
operation_id = "a" + str(uuid.uuid4())[:8]
operation = temp_db.update_ddl(DDL_STATEMENTS, operation_id=operation_id)
operation = temp_db.update_ddl(ddl_statements, operation_id=operation_id)

self.assertEqual(operation_id, operation.operation.name.split("/")[-1])

Expand All @@ -366,7 +373,7 @@ def test_update_database_ddl_with_operation_id(self):

temp_db.reload()

self.assertEqual(len(temp_db.ddl_statements), len(DDL_STATEMENTS))
self.assertEqual(len(temp_db.ddl_statements), len(ddl_statements))

def test_db_batch_insert_then_db_snapshot_read(self):
retry = RetryInstanceState(_has_all_ddl)
Expand Down Expand Up @@ -440,15 +447,17 @@ def _unit_of_work(transaction, name):


@unittest.skipIf(USE_EMULATOR, "Skipping backup tests")
@unittest.skipIf(SKIP_BACKUP_TESTS, "Skipping backup tests")
class TestBackupAPI(unittest.TestCase, _TestData):
DATABASE_NAME = "test_database" + unique_resource_id("_")
DATABASE_NAME_2 = "test_database2" + unique_resource_id("_")

@classmethod
def setUpClass(cls):
pool = BurstyPool(labels={"testcase": "database_api"})
ddl_statements = EMULATOR_DDL_STATEMENTS if USE_EMULATOR else DDL_STATEMENTS
db1 = Config.INSTANCE.database(
cls.DATABASE_NAME, ddl_statements=DDL_STATEMENTS, pool=pool
cls.DATABASE_NAME, ddl_statements=ddl_statements, pool=pool
)
db2 = Config.INSTANCE.database(cls.DATABASE_NAME_2, pool=pool)
cls._db = db1
Expand Down Expand Up @@ -729,6 +738,8 @@ def test_list_backups(self):
OTHER_NAN, = struct.unpack("<d", b"\x01\x00\x01\x00\x00\x00\xf8\xff")
BYTES_1 = b"Ymlu"
BYTES_2 = b"Ym9vdHM="
NUMERIC_1 = decimal.Decimal("0.123456789")
NUMERIC_2 = decimal.Decimal("1234567890")
ALL_TYPES_TABLE = "all_types"
ALL_TYPES_COLUMNS = (
"pkey",
Expand All @@ -746,9 +757,14 @@ def test_list_backups(self):
"string_array",
"timestamp_value",
"timestamp_array",
"numeric_value",
"numeric_array",
)
EMULATOR_ALL_TYPES_COLUMNS = ALL_TYPES_COLUMNS[:-2]
AllTypesRowData = collections.namedtuple("AllTypesRowData", ALL_TYPES_COLUMNS)
AllTypesRowData.__new__.__defaults__ = tuple([None for colum in ALL_TYPES_COLUMNS])
EmulatorAllTypesRowData = collections.namedtuple("EmulatorAllTypesRowData", EMULATOR_ALL_TYPES_COLUMNS)
EmulatorAllTypesRowData.__new__.__defaults__ = tuple([None for colum in EMULATOR_ALL_TYPES_COLUMNS])

ALL_TYPES_ROWDATA = (
# all nulls
Expand All @@ -762,6 +778,7 @@ def test_list_backups(self):
AllTypesRowData(pkey=106, string_value=u"VALUE"),
AllTypesRowData(pkey=107, timestamp_value=SOME_TIME),
AllTypesRowData(pkey=108, timestamp_value=NANO_TIME),
AllTypesRowData(pkey=109, numeric_value=NUMERIC_1),
# empty array values
AllTypesRowData(pkey=201, int_array=[]),
AllTypesRowData(pkey=202, bool_array=[]),
Expand All @@ -770,6 +787,7 @@ def test_list_backups(self):
AllTypesRowData(pkey=205, float_array=[]),
AllTypesRowData(pkey=206, string_array=[]),
AllTypesRowData(pkey=207, timestamp_array=[]),
AllTypesRowData(pkey=208, numeric_array=[]),
# non-empty array values, including nulls
AllTypesRowData(pkey=301, int_array=[123, 456, None]),
AllTypesRowData(pkey=302, bool_array=[True, False, None]),
Expand All @@ -778,6 +796,36 @@ def test_list_backups(self):
AllTypesRowData(pkey=305, float_array=[3.1415926, 2.71828, None]),
AllTypesRowData(pkey=306, string_array=[u"One", u"Two", None]),
AllTypesRowData(pkey=307, timestamp_array=[SOME_TIME, NANO_TIME, None]),
AllTypesRowData(pkey=308, numeric_array=[NUMERIC_1, NUMERIC_2, None]),
)
EMULATOR_ALL_TYPES_ROWDATA = (
# all nulls
EmulatorAllTypesRowData(pkey=0),
# Non-null values
EmulatorAllTypesRowData(pkey=101, int_value=123),
EmulatorAllTypesRowData(pkey=102, bool_value=False),
EmulatorAllTypesRowData(pkey=103, bytes_value=BYTES_1),
EmulatorAllTypesRowData(pkey=104, date_value=SOME_DATE),
EmulatorAllTypesRowData(pkey=105, float_value=1.4142136),
EmulatorAllTypesRowData(pkey=106, string_value=u"VALUE"),
EmulatorAllTypesRowData(pkey=107, timestamp_value=SOME_TIME),
EmulatorAllTypesRowData(pkey=108, timestamp_value=NANO_TIME),
# empty array values
EmulatorAllTypesRowData(pkey=201, int_array=[]),
EmulatorAllTypesRowData(pkey=202, bool_array=[]),
EmulatorAllTypesRowData(pkey=203, bytes_array=[]),
EmulatorAllTypesRowData(pkey=204, date_array=[]),
EmulatorAllTypesRowData(pkey=205, float_array=[]),
EmulatorAllTypesRowData(pkey=206, string_array=[]),
EmulatorAllTypesRowData(pkey=207, timestamp_array=[]),
# non-empty array values, including nulls
EmulatorAllTypesRowData(pkey=301, int_array=[123, 456, None]),
EmulatorAllTypesRowData(pkey=302, bool_array=[True, False, None]),
EmulatorAllTypesRowData(pkey=303, bytes_array=[BYTES_1, BYTES_2, None]),
EmulatorAllTypesRowData(pkey=304, date_array=[SOME_DATE, None]),
EmulatorAllTypesRowData(pkey=305, float_array=[3.1415926, 2.71828, None]),
EmulatorAllTypesRowData(pkey=306, string_array=[u"One", u"Two", None]),
EmulatorAllTypesRowData(pkey=307, timestamp_array=[SOME_TIME, NANO_TIME, None]),
)


Expand All @@ -787,8 +835,9 @@ class TestSessionAPI(unittest.TestCase, _TestData):
@classmethod
def setUpClass(cls):
pool = BurstyPool(labels={"testcase": "session_api"})
ddl_statements = EMULATOR_DDL_STATEMENTS if USE_EMULATOR else DDL_STATEMENTS
cls._db = Config.INSTANCE.database(
cls.DATABASE_NAME, ddl_statements=DDL_STATEMENTS, pool=pool
cls.DATABASE_NAME, ddl_statements=ddl_statements, pool=pool
)
operation = cls._db.create()
operation.result(30) # raises on failure / timeout.
Expand Down Expand Up @@ -850,13 +899,19 @@ def test_batch_insert_then_read_all_datatypes(self):
retry = RetryInstanceState(_has_all_ddl)
retry(self._db.reload)()

if USE_EMULATOR:
all_types_columns = EMULATOR_ALL_TYPES_COLUMNS
all_types_rowdata = EMULATOR_ALL_TYPES_ROWDATA
else:
all_types_columns = ALL_TYPES_COLUMNS
all_types_rowdata = ALL_TYPES_ROWDATA
with self._db.batch() as batch:
batch.delete(ALL_TYPES_TABLE, self.ALL)
batch.insert(ALL_TYPES_TABLE, ALL_TYPES_COLUMNS, ALL_TYPES_ROWDATA)
batch.insert(ALL_TYPES_TABLE, all_types_columns, all_types_rowdata)

with self._db.snapshot(read_timestamp=batch.committed) as snapshot:
rows = list(snapshot.read(ALL_TYPES_TABLE, ALL_TYPES_COLUMNS, self.ALL))
self._check_rows_data(rows, expected=ALL_TYPES_ROWDATA)
rows = list(snapshot.read(ALL_TYPES_TABLE, all_types_columns, self.ALL))
self._check_rows_data(rows, expected=all_types_rowdata)

def test_batch_insert_or_update_then_query(self):
retry = RetryInstanceState(_has_all_ddl)
Expand Down Expand Up @@ -1524,9 +1579,10 @@ def test_read_w_index(self):
MY_COLUMNS = self.COLUMNS[0], self.COLUMNS[2]
EXTRA_DDL = ["CREATE INDEX contacts_by_last_name ON contacts(last_name)"]
pool = BurstyPool(labels={"testcase": "read_w_index"})
ddl_statements = EMULATOR_DDL_STATEMENTS if USE_EMULATOR else DDL_STATEMENTS
temp_db = Config.INSTANCE.database(
"test_read" + unique_resource_id("_"),
ddl_statements=DDL_STATEMENTS + EXTRA_DDL,
ddl_statements=ddl_statements + EXTRA_DDL,
pool=pool,
)
operation = temp_db.create()
Expand Down Expand Up @@ -2102,6 +2158,10 @@ def test_execute_sql_w_date_bindings(self):
dates = [SOME_DATE, SOME_DATE + datetime.timedelta(days=1)]
self._bind_test_helper(DATE, SOME_DATE, dates)

@unittest.skipIf(USE_EMULATOR, "Skipping NUMERIC")
def test_execute_sql_w_numeric_bindings(self):
self._bind_test_helper(NUMERIC, NUMERIC_1, [NUMERIC_1, NUMERIC_2])

def test_execute_sql_w_query_param_struct(self):
NAME = "Phred"
COUNT = 123
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/test__helpers.py
Expand Up @@ -208,6 +208,15 @@ def test_w_datetime(self):
self.assertIsInstance(value_pb, Value)
self.assertEqual(value_pb.string_value, datetime_helpers.to_rfc3339(now))

def test_w_numeric(self):
import decimal
from google.protobuf.struct_pb2 import Value

value = decimal.Decimal("9999999999999999999999999999.999999999")
value_pb = self._callFUT(value)
self.assertIsInstance(value_pb, Value)
self.assertEqual(value_pb.string_value, str(value))

def test_w_unknown_type(self):
with self.assertRaises(ValueError):
self._callFUT(object())
Expand Down Expand Up @@ -431,6 +440,17 @@ def test_w_struct(self):

self.assertEqual(self._callFUT(value_pb, field_type), VALUES)

def test_w_numeric(self):
import decimal
from google.protobuf.struct_pb2 import Value
from google.cloud.spanner_v1.proto.type_pb2 import Type, NUMERIC

VALUE = decimal.Decimal("99999999999999999999999999999.999999999")
larkee marked this conversation as resolved.
Show resolved Hide resolved
field_type = Type(code=NUMERIC)
value_pb = Value(string_value=str(VALUE))

self.assertEqual(self._callFUT(value_pb, field_type), VALUE)

def test_w_unknown_type(self):
from google.protobuf.struct_pb2 import Value
from google.cloud.spanner_v1.proto.type_pb2 import Type
Expand Down