Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(dbapi): remove string conversion for numeric fields #317

Merged
merged 9 commits into from Apr 26, 2021
16 changes: 1 addition & 15 deletions google/cloud/spanner_dbapi/parse_utils.py
Expand Up @@ -509,25 +509,11 @@ def sql_pyformat_args_to_spanner(sql, params):
resolved_value = pyfmt % params
named_args[key] = resolved_value
else:
named_args[key] = cast_for_spanner(params[i])
named_args[key] = params[i]

return sanitize_literals_for_upload(sql), named_args


def cast_for_spanner(value):
"""Convert the param to its Cloud Spanner equivalent type.

:type value: Any
:param value: The value to convert to a Cloud Spanner type.

:rtype: Any
:returns: The value converted to a Cloud Spanner type.
"""
if isinstance(value, decimal.Decimal):
return str(value)
return value


def get_param_types(params):
"""Determine Cloud Spanner types for the given parameters.

Expand Down
13 changes: 1 addition & 12 deletions tests/unit/spanner_dbapi/test_parse_utils.py
Expand Up @@ -307,7 +307,7 @@ def test_sql_pyformat_args_to_spanner(self):
),
(
"SELECT (an.p + @a0) AS np FROM an WHERE (an.p + @a1) = @a2",
{"a0": 1, "a1": 1.0, "a2": str(31)},
{"a0": 1, "a1": 1.0, "a2": decimal.Decimal("31")},
),
),
]
Expand Down Expand Up @@ -339,17 +339,6 @@ def test_sql_pyformat_args_to_spanner_invalid(self):
lambda: sql_pyformat_args_to_spanner(sql, params),
)

def test_cast_for_spanner(self):
import decimal

from google.cloud.spanner_dbapi.parse_utils import cast_for_spanner

dec = 3
value = decimal.Decimal(dec)
self.assertEqual(cast_for_spanner(value), str(dec))
self.assertEqual(cast_for_spanner(5), 5)
self.assertEqual(cast_for_spanner("string"), "string")

@unittest.skipIf(skip_condition, skip_message)
def test_get_param_types(self):
import datetime
Expand Down