Skip to content

Commit

Permalink
feat: add decimal/numeric support (#620)
Browse files Browse the repository at this point in the history
* fix: lint_setup_py was failing in Kokoro is not fixed

* feat: add decimal/numeric support

* fix: remove validation for decimal field not supported

* feat: updated decimal support error message in spanner to match error thrown by python spanner decimal/numeric validation

* fix: removed test_validation as decimal support is now added so validation is not required

* fix: Remove system tests. They will be added separately.

* fix: fixed tests related to decimal conversion in db operations

* fix: fixed tests related to decimal conversion in db operations

* refactor: lint corrections in test_operations file

* fix: corrected coverage number, lowered it t 65

* refactor: lint issues fixed in noxfile and import moved up to module level in test_lookups
  • Loading branch information
vi3k6i5 committed May 19, 2021
1 parent 92ad508 commit d09ad61
Show file tree
Hide file tree
Showing 10 changed files with 34 additions and 140 deletions.
4 changes: 1 addition & 3 deletions django_spanner/base.py
Expand Up @@ -17,7 +17,6 @@
from .introspection import DatabaseIntrospection
from .operations import DatabaseOperations
from .schema import DatabaseSchemaEditor
from .validation import DatabaseValidation


class DatabaseWrapper(BaseDatabaseWrapper):
Expand All @@ -34,7 +33,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
"CharField": "STRING(%(max_length)s)",
"DateField": "DATE",
"DateTimeField": "TIMESTAMP",
"DecimalField": "FLOAT64",
"DecimalField": "NUMERIC",
"DurationField": "INT64",
"EmailField": "STRING(%(max_length)s)",
"FileField": "STRING(%(max_length)s)",
Expand Down Expand Up @@ -104,7 +103,6 @@ class DatabaseWrapper(BaseDatabaseWrapper):
introspection_class = DatabaseIntrospection
ops_class = DatabaseOperations
client_class = DatabaseClient
validation_class = DatabaseValidation

@property
def instance(self):
Expand Down
19 changes: 15 additions & 4 deletions django_spanner/features.py
Expand Up @@ -233,10 +233,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"queries.test_bulk_update.BulkUpdateTests.test_large_batch",
# Spanner doesn't support random ordering.
"ordering.tests.OrderingTests.test_random_ordering",
# No matching signature for function MOD for argument types: FLOAT64,
# FLOAT64. Supported signatures: MOD(INT64, INT64)
"db_functions.math.test_mod.ModTests.test_decimal",
"db_functions.math.test_mod.ModTests.test_float",
# casting DateField to DateTimeField adds an unexpected hour:
# https://github.com/googleapis/python-spanner-django/issues/260
"db_functions.comparison.test_cast.CastTests.test_cast_from_db_date_to_datetime",
Expand Down Expand Up @@ -364,6 +360,11 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"model_formsets.tests.ModelFormsetTest.test_prevent_change_outer_model_and_create_invalid_data",
"model_formsets_regress.tests.FormfieldShouldDeleteFormTests.test_no_delete",
"model_formsets_regress.tests.FormsetTests.test_extraneous_query_is_not_run",
# Numeric field is not supported in primary key/unique key.
"model_formsets.tests.ModelFormsetTest.test_inline_formsets_with_custom_pk",
"model_forms.tests.ModelFormBaseTest.test_exclude_and_validation",
"model_forms.tests.UniqueTest.test_unique_together",
"model_forms.tests.UniqueTest.test_override_unique_together_message",
# os.chmod() doesn't work on Kokoro?
"file_uploads.tests.DirectoryCreationTests.test_readonly_root",
# Tests that sometimes fail on Kokoro for unknown reasons.
Expand Down Expand Up @@ -1026,12 +1027,20 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"db_functions.math.test_ceil.CeilTests.test_null", # noqa
"db_functions.math.test_ceil.CeilTests.test_transform", # noqa
"db_functions.math.test_cos.CosTests.test_null", # noqa
"db_functions.math.test_cos.CosTests.test_transform", # noqa
"db_functions.math.test_cot.CotTests.test_null", # noqa
"db_functions.math.test_degrees.DegreesTests.test_decimal", # noqa
"db_functions.math.test_degrees.DegreesTests.test_null", # noqa
"db_functions.math.test_exp.ExpTests.test_decimal", # noqa
"db_functions.math.test_exp.ExpTests.test_null", # noqa
"db_functions.math.test_exp.ExpTests.test_transform", # noqa
"db_functions.math.test_floor.FloorTests.test_null", # noqa
"db_functions.math.test_ln.LnTests.test_decimal", # noqa
"db_functions.math.test_ln.LnTests.test_null", # noqa
"db_functions.math.test_ln.LnTests.test_transform", # noqa
"db_functions.math.test_log.LogTests.test_decimal", # noqa
"db_functions.math.test_log.LogTests.test_null", # noqa
"db_functions.math.test_mod.ModTests.test_float", # noqa
"db_functions.math.test_mod.ModTests.test_null", # noqa
"db_functions.math.test_power.PowerTests.test_decimal", # noqa
"db_functions.math.test_power.PowerTests.test_float", # noqa
Expand All @@ -1040,7 +1049,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"db_functions.math.test_radians.RadiansTests.test_null", # noqa
"db_functions.math.test_round.RoundTests.test_null", # noqa
"db_functions.math.test_sin.SinTests.test_null", # noqa
"db_functions.math.test_sqrt.SqrtTests.test_decimal", # noqa
"db_functions.math.test_sqrt.SqrtTests.test_null", # noqa
"db_functions.math.test_sqrt.SqrtTests.test_transform", # noqa
"db_functions.math.test_tan.TanTests.test_null", # noqa
"db_functions.tests.FunctionTests.test_func_transform_bilateral", # noqa
"db_functions.tests.FunctionTests.test_func_transform_bilateral_multivalue", # noqa
Expand Down
1 change: 1 addition & 0 deletions django_spanner/introspection.py
Expand Up @@ -24,6 +24,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
TypeCode.INT64: "IntegerField",
TypeCode.STRING: "CharField",
TypeCode.TIMESTAMP: "DateTimeField",
TypeCode.NUMERIC: "DecimalField",
}

def get_field_type(self, data_type, description):
Expand Down
8 changes: 1 addition & 7 deletions django_spanner/lookups.py
Expand Up @@ -4,7 +4,6 @@
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

from django.db.models import DecimalField
from django.db.models.lookups import (
Contains,
EndsWith,
Expand Down Expand Up @@ -233,13 +232,8 @@ def cast_param_to_float(self, compiler, connection):
"""
sql, params = self.as_sql(compiler, connection)
if params:
# Cast to DecimaField lookup values to float because
# google.cloud.spanner_v1._helpers._make_value_pb() doesn't serialize
# decimal.Decimal.
if isinstance(self.lhs.output_field, DecimalField):
params[0] = float(params[0])
# Cast remote field lookups that must be integer but come in as string.
elif hasattr(self.lhs.output_field, "get_path_info"):
if hasattr(self.lhs.output_field, "get_path_info"):
for i, field in enumerate(
self.lhs.output_field.get_path_info()[-1].target_fields
):
Expand Down
38 changes: 7 additions & 31 deletions django_spanner/operations.py
Expand Up @@ -8,7 +8,6 @@
import re
from base64 import b64decode
from datetime import datetime, time
from decimal import Decimal
from uuid import UUID

from django.conf import settings
Expand Down Expand Up @@ -190,10 +189,11 @@ def adapt_decimalfield_value(
self, value, max_digits=None, decimal_places=None
):
"""
Convert value from decimal.Decimal into float, for a direct mapping
and correct serialization with RPCs to Cloud Spanner.
Convert value from decimal.Decimal to spanner compatible value.
Since spanner supports Numeric storage of decimal and python spanner
takes care of the conversion so this is a no-op method call.
:type value: :class:`~google.cloud.spanner_v1.types.Numeric`
:type value: :class:`decimal.Decimal`
:param value: A decimal field value.
:type max_digits: int
Expand All @@ -203,12 +203,10 @@ def adapt_decimalfield_value(
:param decimal_places: (Optional) The number of decimal places to store
with the number.
:rtype: float
:returns: Formatted value.
:rtype: decimal.Decimal
:returns: decimal value.
"""
if value is None:
return None
return float(value)
return value

def adapt_timefield_value(self, value):
"""
Expand Down Expand Up @@ -244,8 +242,6 @@ def get_db_converters(self, expression):
internal_type = expression.output_field.get_internal_type()
if internal_type == "DateTimeField":
converters.append(self.convert_datetimefield_value)
elif internal_type == "DecimalField":
converters.append(self.convert_decimalfield_value)
elif internal_type == "TimeField":
converters.append(self.convert_timefield_value)
elif internal_type == "BinaryField":
Expand Down Expand Up @@ -311,26 +307,6 @@ def convert_datetimefield_value(self, value, expression, connection):
else dt
)

def convert_decimalfield_value(self, value, expression, connection):
"""Convert Spanner DecimalField value for Django.
:type value: float
:param value: A decimal field.
:type expression: :class:`django.db.models.expressions.BaseExpression`
:param expression: A query expression.
:type connection: :class:`~google.cloud.cpanner_dbapi.connection.Connection`
:param connection: Reference to a Spanner database connection.
:rtype: :class:`Decimal`
:returns: A converted decimal field.
"""
if value is None:
return value
# Cloud Spanner returns a float.
return Decimal(str(value))

def convert_timefield_value(self, value, expression, connection):
"""Convert Spanner TimeField value for Django.
Expand Down
33 changes: 0 additions & 33 deletions django_spanner/validation.py

This file was deleted.

2 changes: 1 addition & 1 deletion noxfile.py
Expand Up @@ -84,7 +84,7 @@ def default(session):
"--cov-append",
"--cov-config=.coveragerc",
"--cov-report=",
"--cov-fail-under=68",
"--cov-fail-under=65",
os.path.join("tests", "unit"),
*session.posargs
)
Expand Down
7 changes: 5 additions & 2 deletions tests/unit/django_spanner/test_lookups.py
Expand Up @@ -7,21 +7,24 @@
from django_spanner.compiler import SQLCompiler
from django.db.models import F
from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass
from decimal import Decimal
from .models import Number, Author


class TestLookups(SpannerSimpleTestClass):
def test_cast_param_to_float_lte_sql_query(self):

qs1 = Number.objects.filter(decimal_num__lte=1.1).values("decimal_num")
qs1 = Number.objects.filter(decimal_num__lte=Decimal("1.1")).values(
"decimal_num"
)
compiler = SQLCompiler(qs1.query, self.connection, "default")
sql_compiled, params = compiler.as_sql()
self.assertEqual(
sql_compiled,
"SELECT tests_number.decimal_num FROM tests_number WHERE "
+ "tests_number.decimal_num <= %s",
)
self.assertEqual(params, (1.1,))
self.assertEqual(params, (Decimal("1.1"),))

def test_cast_param_to_float_for_int_field_query(self):

Expand Down
21 changes: 3 additions & 18 deletions tests/unit/django_spanner/test_operations.py
Expand Up @@ -7,6 +7,7 @@
from django.db.utils import DatabaseError
from datetime import timedelta
from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass
from decimal import Decimal


class TestOperations(SpannerSimpleTestClass):
Expand Down Expand Up @@ -58,7 +59,8 @@ def test_adapt_datefield_value_none(self):

def test_adapt_decimalfield_value(self):
self.assertIsInstance(
self.db_operations.adapt_decimalfield_value(value=1), float,
self.db_operations.adapt_decimalfield_value(value=Decimal("1")),
Decimal,
)

def test_adapt_decimalfield_value_none(self):
Expand Down Expand Up @@ -93,23 +95,6 @@ def test_adapt_timefield_value_none(self):
self.db_operations.adapt_timefield_value(value=None),
)

def test_convert_decimalfield_value(self):
from decimal import Decimal

self.assertIsInstance(
self.db_operations.convert_decimalfield_value(
value=1.0, expression=None, connection=None
),
Decimal,
)

def test_convert_decimalfield_value_none(self):
self.assertIsNone(
self.db_operations.convert_decimalfield_value(
value=None, expression=None, connection=None
),
)

def test_convert_uuidfield_value(self):
import uuid

Expand Down
41 changes: 0 additions & 41 deletions tests/unit/django_spanner/test_validation.py

This file was deleted.

0 comments on commit d09ad61

Please sign in to comment.