Skip to content

Commit

Permalink
fix: remove DB-API dependency on pyarrow with decimal query parameters (#551)
Browse files Browse the repository at this point in the history

* fix: DB API pyarrow dependency with decimal values

DB API should gracefully handle the case when the optional pyarrow
dependency is not installed.

* Blacken DB API helpers tests

* Refine the logic for recognizing NUMERIC Decimals
  • Loading branch information
plamut committed Mar 16, 2021
1 parent a460f93 commit 1b946ba
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 29 deletions.
23 changes: 15 additions & 8 deletions google/cloud/bigquery/dbapi/_helpers.py
Expand Up @@ -19,16 +19,15 @@
import functools
import numbers

try:
import pyarrow
except ImportError: # pragma: NO COVER
pyarrow = None

from google.cloud import bigquery
from google.cloud.bigquery import table
from google.cloud.bigquery.dbapi import exceptions


_NUMERIC_SERVER_MIN = decimal.Decimal("-9.9999999999999999999999999999999999999E+28")
_NUMERIC_SERVER_MAX = decimal.Decimal("9.9999999999999999999999999999999999999E+28")


def scalar_to_query_parameter(value, name=None):
"""Convert a scalar value into a query parameter.
Expand Down Expand Up @@ -189,12 +188,20 @@ def bigquery_scalar_type(value):
elif isinstance(value, numbers.Real):
return "FLOAT64"
elif isinstance(value, decimal.Decimal):
# We check for NUMERIC before BIGNUMERIC in order to support pyarrow < 3.0.
scalar_object = pyarrow.scalar(value)
if isinstance(scalar_object, pyarrow.Decimal128Scalar):
vtuple = value.as_tuple()
# NUMERIC values have precision of 38 (number of digits) and scale of 9 (number
# of fractional digits), and their max absolute value must be strictly smaller
# than 1.0E+29.
# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#decimal_types
if (
len(vtuple.digits) <= 38 # max precision: 38
and vtuple.exponent >= -9 # max scale: 9
and _NUMERIC_SERVER_MIN <= value <= _NUMERIC_SERVER_MAX
):
return "NUMERIC"
else:
return "BIGNUMERIC"

elif isinstance(value, str):
return "STRING"
elif isinstance(value, bytes):
Expand Down
66 changes: 45 additions & 21 deletions tests/unit/test_dbapi__helpers.py
Expand Up @@ -25,7 +25,6 @@

import google.cloud._helpers
from google.cloud.bigquery import table
from google.cloud.bigquery._pandas_helpers import _BIGNUMERIC_SUPPORT
from google.cloud.bigquery.dbapi import _helpers
from google.cloud.bigquery.dbapi import exceptions
from tests.unit.helpers import _to_pyarrow
Expand All @@ -39,9 +38,8 @@ def test_scalar_to_query_parameter(self):
(123, "INT64"),
(-123456789, "INT64"),
(1.25, "FLOAT64"),
(decimal.Decimal("1.25"), "NUMERIC"),
(b"I am some bytes", "BYTES"),
(u"I am a string", "STRING"),
("I am a string", "STRING"),
(datetime.date(2017, 4, 1), "DATE"),
(datetime.time(12, 34, 56), "TIME"),
(datetime.datetime(2012, 3, 4, 5, 6, 7), "DATETIME"),
Expand All @@ -51,14 +49,17 @@ def test_scalar_to_query_parameter(self):
),
"TIMESTAMP",
),
(decimal.Decimal("1.25"), "NUMERIC"),
(decimal.Decimal("9.9999999999999999999999999999999999999E+28"), "NUMERIC"),
(decimal.Decimal("1.0E+29"), "BIGNUMERIC"), # more than max NUMERIC value
(decimal.Decimal("1.123456789"), "NUMERIC"),
(decimal.Decimal("1.1234567891"), "BIGNUMERIC"), # scale > 9
(decimal.Decimal("12345678901234567890123456789.012345678"), "NUMERIC"),
(
decimal.Decimal("12345678901234567890123456789012345678"),
"BIGNUMERIC", # larger than max NUMERIC value, despite precision <=38
),
]
if _BIGNUMERIC_SUPPORT:
expected_types.append(
(
decimal.Decimal("1.1234567890123456789012345678901234567890"),
"BIGNUMERIC",
)
)

for value, expected_type in expected_types:
msg = "value: {} expected_type: {}".format(value, expected_type)
Expand All @@ -71,6 +72,33 @@ def test_scalar_to_query_parameter(self):
self.assertEqual(named_parameter.type_, expected_type, msg=msg)
self.assertEqual(named_parameter.value, value, msg=msg)

def test_decimal_to_query_parameter(self):
    """Check the query parameter type chosen for ``decimal.Decimal`` values.

    Every case is converted twice: once without a name and once with the
    name "myvar". Both results must report the expected type and carry the
    original value unchanged.
    """
    # TODO: merge with previous test
    numeric, bignumeric = "NUMERIC", "BIGNUMERIC"
    cases = (
        ("9.9999999999999999999999999999999999999E+28", numeric),
        ("1.0E+29", bignumeric),  # more than max value
        ("1.123456789", numeric),
        ("1.1234567891", bignumeric),  # scale > 9
        ("12345678901234567890123456789.012345678", numeric),
        # larger than max size, even if precision <=38
        ("12345678901234567890123456789012345678", bignumeric),
    )

    for literal, expected_type in cases:
        value = decimal.Decimal(literal)
        msg = f"value: {value} expected_type: {expected_type}"

        # Positional (unnamed) parameter.
        unnamed = _helpers.scalar_to_query_parameter(value)
        self.assertIsNone(unnamed.name, msg=msg)
        self.assertEqual(unnamed.type_, expected_type, msg=msg)
        self.assertEqual(unnamed.value, value, msg=msg)

        # Same value, passed as a named parameter.
        named = _helpers.scalar_to_query_parameter(value, name="myvar")
        self.assertEqual(named.name, "myvar", msg=msg)
        self.assertEqual(named.type_, expected_type, msg=msg)
        self.assertEqual(named.value, value, msg=msg)

def test_scalar_to_query_parameter_w_unexpected_type(self):
    """Passing a dict as the scalar value must raise ``ProgrammingError``."""
    unsupported_value = {"a": "dictionary"}
    with self.assertRaises(exceptions.ProgrammingError):
        _helpers.scalar_to_query_parameter(value=unsupported_value)
Expand All @@ -89,8 +117,9 @@ def test_array_to_query_parameter_valid_argument(self):
([123, -456, 0], "INT64"),
([1.25, 2.50], "FLOAT64"),
([decimal.Decimal("1.25")], "NUMERIC"),
([decimal.Decimal("{d38}.{d38}".format(d38="9" * 38))], "BIGNUMERIC"),
([b"foo", b"bar"], "BYTES"),
([u"foo", u"bar"], "STRING"),
(["foo", "bar"], "STRING"),
([datetime.date(2017, 4, 1), datetime.date(2018, 4, 1)], "DATE"),
([datetime.time(12, 34, 56), datetime.time(10, 20, 30)], "TIME"),
(
Expand All @@ -113,11 +142,6 @@ def test_array_to_query_parameter_valid_argument(self):
),
]

if _BIGNUMERIC_SUPPORT:
expected_types.append(
([decimal.Decimal("{d38}.{d38}".format(d38="9" * 38))], "BIGNUMERIC")
)

for values, expected_type in expected_types:
msg = "value: {} expected_type: {}".format(values, expected_type)
parameter = _helpers.array_to_query_parameter(values)
Expand All @@ -134,7 +158,7 @@ def test_array_to_query_parameter_empty_argument(self):
_helpers.array_to_query_parameter([])

def test_array_to_query_parameter_unsupported_sequence(self):
    """Each unsupported iterable must make the conversion raise ``ProgrammingError``.

    The inputs exercised are a set, a text string, a bytes string and a
    bytearray.
    """
    # Only the current list is kept; the older ``u"foo"`` spelling was a dead
    # assignment that this one immediately overwrote.
    unsupported_iterables = [{10, 20, 30}, "foo", b"bar", bytearray([65, 75, 85])]
    for iterable in unsupported_iterables:
        with self.assertRaises(exceptions.ProgrammingError):
            _helpers.array_to_query_parameter(iterable)
Expand All @@ -144,7 +168,7 @@ def test_array_to_query_parameter_sequence_w_invalid_elements(self):
_helpers.array_to_query_parameter([object(), 2, 7])

def test_to_query_parameters_w_dict(self):
parameters = {"somebool": True, "somestring": u"a-string-value"}
parameters = {"somebool": True, "somestring": "a-string-value"}
query_parameters = _helpers.to_query_parameters(parameters)
query_parameter_tuples = []
for param in query_parameters:
Expand All @@ -154,7 +178,7 @@ def test_to_query_parameters_w_dict(self):
sorted(
[
("somebool", "BOOL", True),
("somestring", "STRING", u"a-string-value"),
("somestring", "STRING", "a-string-value"),
]
),
)
Expand All @@ -177,14 +201,14 @@ def test_to_query_parameters_w_dict_dict_param(self):
_helpers.to_query_parameters(parameters)

def test_to_query_parameters_w_list(self):
    """A list of scalars converts to query parameters whose name is ``None``.

    The (name, type, value) triples of the result are compared, order
    independent, against the expected BOOL and STRING parameters.
    """
    parameters = [True, "a-string-value"]
    query_parameters = _helpers.to_query_parameters(parameters)
    query_parameter_tuples = [
        (param.name, param.type_, param.value) for param in query_parameters
    ]
    # Exactly two sequences are compared. A duplicated expected list would be
    # taken as the ``msg`` argument of assertSequenceEqual.
    self.assertSequenceEqual(
        sorted(query_parameter_tuples),
        sorted([(None, "BOOL", True), (None, "STRING", "a-string-value")]),
    )

def test_to_query_parameters_w_list_array_param(self):
Expand Down

0 comments on commit 1b946ba

Please sign in to comment.