diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index a9ae36d0d6..1385809162 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -30,6 +30,16 @@ from google.cloud.spanner_v1 import ExecuteSqlRequest +# Validation error messages +NUMERIC_MAX_SCALE_ERR_MSG = ( + "Max scale for a numeric is 9. The requested numeric has scale {}" +) +NUMERIC_MAX_PRECISION_ERR_MSG = ( + "Max precision for the whole component of a numeric is 29. The requested " + + "numeric has a whole component with precision {}" +) + + def _try_to_coerce_bytes(bytestring): """Try to coerce a byte string into the right thing based on Python version and whether or not it is base64 encoded. @@ -87,6 +97,28 @@ def _merge_query_options(base, merge): return combined +def _assert_numeric_precision_and_scale(value): + """ + Asserts that input numeric field is within Spanner supported range. + + Spanner supports fixed 38 digits of precision and 9 digits of scale. + This number can be optionally prefixed with a plus or minus sign. + Read more here: https://cloud.google.com/spanner/docs/data-types#numeric_type + + :type value: decimal.Decimal + :param value: The value to check for Cloud Spanner compatibility. + + :raises NotSupportedError: If value is not within supported precision or scale of Spanner. + """ + scale = value.as_tuple().exponent + precision = len(value.as_tuple().digits) + + if scale < -9: + raise ValueError(NUMERIC_MAX_SCALE_ERR_MSG.format(abs(scale))) + if precision + scale > 29: + raise ValueError(NUMERIC_MAX_PRECISION_ERR_MSG.format(precision + scale)) + + # pylint: disable=too-many-return-statements,too-many-branches def _make_value_pb(value): """Helper for :func:`_make_list_value_pbs`. @@ -129,6 +161,7 @@ def _make_value_pb(value): if isinstance(value, ListValue): return Value(list_value=value) if isinstance(value, decimal.Decimal): + _assert_numeric_precision_and_scale(value) return Value(string_value=str(value)) raise ValueError("Unknown type: %s" % (value,)) diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py index 73277a7de3..11239d730e 100644 --- a/tests/unit/spanner_dbapi/test_parse_utils.py +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -254,8 +254,6 @@ def test_rows_for_insert_or_update(self): @unittest.skipIf(skip_condition, skip_message) def test_sql_pyformat_args_to_spanner(self): - import decimal - from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner cases = [ @@ -300,16 +298,6 @@ def test_sql_pyformat_args_to_spanner(self): ("SELECT * from t WHERE id=10", {"f1": "app", "f2": "name"}), ("SELECT * from t WHERE id=10", {"f1": "app", "f2": "name"}), ), - ( - ( - "SELECT (an.p + %s) AS np FROM an WHERE (an.p + %s) = %s", - (1, 1.0, decimal.Decimal("31")), - ), - ( - "SELECT (an.p + @a0) AS np FROM an WHERE (an.p + @a1) = @a2", - {"a0": 1, "a1": 1.0, "a2": decimal.Decimal("31")}, - ), - ), ] for ((sql_in, params), sql_want) in cases: with self.subTest(sql=sql_in): diff --git a/tests/unit/test__helpers.py b/tests/unit/test__helpers.py index fecf2581de..305a6ce7c3 100644 --- a/tests/unit/test__helpers.py +++ b/tests/unit/test__helpers.py @@ -233,6 +233,65 @@ def test_w_unknown_type(self): with self.assertRaises(ValueError): self._callFUT(object()) + def test_w_numeric_precision_and_scale_valid(self): + import decimal + from google.protobuf.struct_pb2 import Value + + cases = [ + decimal.Decimal("42"), + decimal.Decimal("9.9999999999999999999999999999999999999E+28"), + decimal.Decimal("-9.9999999999999999999999999999999999999E+28"), + decimal.Decimal("99999999999999999999999999999.999999999"), + decimal.Decimal("1E+28"), + decimal.Decimal("1E-9"), + ] + for value in cases: + with self.subTest(value=value): + value_pb = self._callFUT(value) + self.assertIsInstance(value_pb, Value) + self.assertEqual(value_pb.string_value, str(value)) + + def test_w_numeric_precision_and_scale_invalid(self): + import decimal + from google.cloud.spanner_v1._helpers import ( + NUMERIC_MAX_SCALE_ERR_MSG, + NUMERIC_MAX_PRECISION_ERR_MSG, + ) + + max_precision_error_msg = NUMERIC_MAX_PRECISION_ERR_MSG.format("30") + max_scale_error_msg = NUMERIC_MAX_SCALE_ERR_MSG.format("10") + + cases = [ + ( + decimal.Decimal("9.9999999999999999999999999999999999999E+29"), + max_precision_error_msg, + ), + ( + decimal.Decimal("-9.9999999999999999999999999999999999999E+29"), + max_precision_error_msg, + ), + ( + decimal.Decimal("999999999999999999999999999999.99999999"), + max_precision_error_msg, + ), + ( + decimal.Decimal("-999999999999999999999999999999.99999999"), + max_precision_error_msg, + ), + ( + decimal.Decimal("999999999999999999999999999999"), + max_precision_error_msg, + ), + (decimal.Decimal("1E+29"), max_precision_error_msg), + (decimal.Decimal("1E-10"), max_scale_error_msg), + ] + + for value, err_msg in cases: + with self.subTest(value=value, err_msg=err_msg): + self.assertRaisesRegex( + ValueError, err_msg, lambda: self._callFUT(value), + ) + class Test_make_list_value_pb(unittest.TestCase): def _callFUT(self, *args, **kw):