Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add decimal validation for numeric precision and scale supported by Spanner #340

Merged
merged 13 commits into from May 18, 2021
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
37 changes: 36 additions & 1 deletion google/cloud/spanner_dbapi/parse_utils.py
Expand Up @@ -22,7 +22,7 @@
import sqlparse
from google.cloud import spanner_v1 as spanner

from .exceptions import Error, ProgrammingError
from .exceptions import Error, ProgrammingError, NotSupportedError
from .parser import parse_values
from .types import DateStr, TimestampStr
from .utils import sanitize_literals_for_upload
Expand Down Expand Up @@ -144,6 +144,15 @@
STMT_UPDATING = "UPDATING"
STMT_INSERT = "INSERT"

# Validation error messages
NUMERIC_MAX_SCALE_ERR_MSG = (
"Max scale for a numeric is 9. The requested numeric has scale {}"
)
NUMERIC_MAX_PRECISION_ERR_MSG = (
"Max precision for the whole component of a numeric is 29. The requested "
+ "numeric has a whole component with precision {}"
)

# Heuristic for identifying statements that don't need to be run as updates.
RE_NON_UPDATE = re.compile(r"^\s*(SELECT)", re.IGNORECASE)

Expand Down Expand Up @@ -509,11 +518,37 @@ def sql_pyformat_args_to_spanner(sql, params):
resolved_value = pyfmt % params
named_args[key] = resolved_value
else:
assert_numeric_precision_and_scale(params[i])
named_args[key] = params[i]

return sanitize_literals_for_upload(sql), named_args


def assert_numeric_precision_and_scale(value):
vi3k6i5 marked this conversation as resolved.
Show resolved Hide resolved
vi3k6i5 marked this conversation as resolved.
Show resolved Hide resolved
"""
Spanner supports fixed 38 digits of precision and 9 digits of scale.
This number can be optionally prefixed with a plus or minus sign.
Read more here: https://cloud.google.com/spanner/docs/data-types#numeric_type

Asserts that input numeric field is within spanner supported range.
:type value: Any
:param value: The value to check for Cloud Spanner compatibility.

:raises NotSupportedError: if value is not within supporteed precision or scale of spanner.
"""
if isinstance(value, decimal.Decimal):

scale = value.as_tuple().exponent
precision = len(value.as_tuple().digits)

if scale < -9:
raise NotSupportedError(NUMERIC_MAX_SCALE_ERR_MSG.format(abs(scale)))
if precision + scale > 29:
raise NotSupportedError(
NUMERIC_MAX_PRECISION_ERR_MSG.format(precision + scale)
)


def get_param_types(params):
"""Determine Cloud Spanner types for the given parameters.

Expand Down
83 changes: 71 additions & 12 deletions tests/unit/spanner_dbapi/test_parse_utils.py
Expand Up @@ -254,8 +254,6 @@ def test_rows_for_insert_or_update(self):

@unittest.skipIf(skip_condition, skip_message)
def test_sql_pyformat_args_to_spanner(self):
import decimal

from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner

cases = [
Expand Down Expand Up @@ -300,16 +298,6 @@ def test_sql_pyformat_args_to_spanner(self):
("SELECT * from t WHERE id=10", {"f1": "app", "f2": "name"}),
("SELECT * from t WHERE id=10", {"f1": "app", "f2": "name"}),
),
(
(
"SELECT (an.p + %s) AS np FROM an WHERE (an.p + %s) = %s",
(1, 1.0, decimal.Decimal("31")),
),
(
"SELECT (an.p + @a0) AS np FROM an WHERE (an.p + @a1) = @a2",
{"a0": 1, "a1": 1.0, "a2": decimal.Decimal("31")},
),
),
]
for ((sql_in, params), sql_want) in cases:
with self.subTest(sql=sql_in):
Expand Down Expand Up @@ -339,6 +327,77 @@ def test_sql_pyformat_args_to_spanner_invalid(self):
lambda: sql_pyformat_args_to_spanner(sql, params),
)

@unittest.skipIf(skip_condition, skip_message)
def test_sql_pyformat_args_to_spanner_for_valid_decimal(self):
import decimal

from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner

sql_in = "SELECT an.p AS np FROM an WHERE an.p >= %s"
sql_want = "SELECT an.p AS np FROM an WHERE an.p >= @a0"

cases = [
decimal.Decimal("42"),
decimal.Decimal("9.9999999999999999999999999999999999999E+28"),
decimal.Decimal("-9.9999999999999999999999999999999999999E+28"),
decimal.Decimal("99999999999999999999999999999.999999999"),
decimal.Decimal("1E+28"),
decimal.Decimal("1E-9"),
]
for param in cases:
with self.subTest(sql=sql_in, sql_want=sql_want, param=param):
got_sql, got_named_args = sql_pyformat_args_to_spanner(sql_in, [param])
want_named_args = {"a0": param}
self.assertEqual(got_sql, sql_want, "SQL does not match")
self.assertEqual(
got_named_args, want_named_args, "Named args do not match"
)

@unittest.skipIf(skip_condition, skip_message)
def test_assert_numeric_precision_and_scale_invalid(self):
import decimal
from google.cloud.spanner_dbapi import exceptions
from google.cloud.spanner_dbapi.parse_utils import (
assert_numeric_precision_and_scale,
NUMERIC_MAX_SCALE_ERR_MSG,
NUMERIC_MAX_PRECISION_ERR_MSG,
)

max_precision_error_msg = NUMERIC_MAX_PRECISION_ERR_MSG.format("30")
max_scale_error_msg = NUMERIC_MAX_SCALE_ERR_MSG.format("10")

cases = [
(
decimal.Decimal("9.9999999999999999999999999999999999999E+29"),
max_precision_error_msg,
),
(
decimal.Decimal("-9.9999999999999999999999999999999999999E+29"),
max_precision_error_msg,
),
(
decimal.Decimal("999999999999999999999999999999.99999999"),
max_precision_error_msg,
),
(
decimal.Decimal("-999999999999999999999999999999.99999999"),
max_precision_error_msg,
),
(
decimal.Decimal("999999999999999999999999999999"),
max_precision_error_msg,
),
(decimal.Decimal("1E+29"), max_precision_error_msg),
(decimal.Decimal("1E-10"), max_scale_error_msg),
]
for param, err_msg in cases:
with self.subTest(param=param, err_msg=err_msg):
self.assertRaisesRegex(
exceptions.NotSupportedError,
err_msg,
lambda: assert_numeric_precision_and_scale(param),
)

@unittest.skipIf(skip_condition, skip_message)
def test_get_param_types(self):
import datetime
Expand Down