diff --git a/django_spanner/base.py b/django_spanner/base.py index 9b0824a25c..4a4b86ff7d 100644 --- a/django_spanner/base.py +++ b/django_spanner/base.py @@ -17,7 +17,6 @@ from .introspection import DatabaseIntrospection from .operations import DatabaseOperations from .schema import DatabaseSchemaEditor -from .validation import DatabaseValidation class DatabaseWrapper(BaseDatabaseWrapper): @@ -34,7 +33,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): "CharField": "STRING(%(max_length)s)", "DateField": "DATE", "DateTimeField": "TIMESTAMP", - "DecimalField": "FLOAT64", + "DecimalField": "NUMERIC", "DurationField": "INT64", "EmailField": "STRING(%(max_length)s)", "FileField": "STRING(%(max_length)s)", @@ -104,7 +103,6 @@ class DatabaseWrapper(BaseDatabaseWrapper): introspection_class = DatabaseIntrospection ops_class = DatabaseOperations client_class = DatabaseClient - validation_class = DatabaseValidation @property def instance(self): diff --git a/django_spanner/features.py b/django_spanner/features.py index 34fc258159..06d4bb03f2 100644 --- a/django_spanner/features.py +++ b/django_spanner/features.py @@ -233,10 +233,6 @@ class DatabaseFeatures(BaseDatabaseFeatures): "queries.test_bulk_update.BulkUpdateTests.test_large_batch", # Spanner doesn't support random ordering. "ordering.tests.OrderingTests.test_random_ordering", - # No matching signature for function MOD for argument types: FLOAT64, - # FLOAT64. Supported signatures: MOD(INT64, INT64) - "db_functions.math.test_mod.ModTests.test_decimal", - "db_functions.math.test_mod.ModTests.test_float", # casting DateField to DateTimeField adds an unexpected hour: # https://github.com/googleapis/python-spanner-django/issues/260 "db_functions.comparison.test_cast.CastTests.test_cast_from_db_date_to_datetime", @@ -364,6 +360,11 @@ class DatabaseFeatures(BaseDatabaseFeatures): "model_formsets.tests.ModelFormsetTest.test_prevent_change_outer_model_and_create_invalid_data", "model_formsets_regress.tests.FormfieldShouldDeleteFormTests.test_no_delete", "model_formsets_regress.tests.FormsetTests.test_extraneous_query_is_not_run", + # Numeric field is not supported in primary key/unique key. + "model_formsets.tests.ModelFormsetTest.test_inline_formsets_with_custom_pk", + "model_forms.tests.ModelFormBaseTest.test_exclude_and_validation", + "model_forms.tests.UniqueTest.test_unique_together", + "model_forms.tests.UniqueTest.test_override_unique_together_message", # os.chmod() doesn't work on Kokoro? "file_uploads.tests.DirectoryCreationTests.test_readonly_root", # Tests that sometimes fail on Kokoro for unknown reasons. @@ -1026,12 +1027,20 @@ class DatabaseFeatures(BaseDatabaseFeatures): "db_functions.math.test_ceil.CeilTests.test_null", # noqa "db_functions.math.test_ceil.CeilTests.test_transform", # noqa "db_functions.math.test_cos.CosTests.test_null", # noqa + "db_functions.math.test_cos.CosTests.test_transform", # noqa "db_functions.math.test_cot.CotTests.test_null", # noqa + "db_functions.math.test_degrees.DegreesTests.test_decimal", # noqa "db_functions.math.test_degrees.DegreesTests.test_null", # noqa + "db_functions.math.test_exp.ExpTests.test_decimal", # noqa "db_functions.math.test_exp.ExpTests.test_null", # noqa + "db_functions.math.test_exp.ExpTests.test_transform", # noqa "db_functions.math.test_floor.FloorTests.test_null", # noqa + "db_functions.math.test_ln.LnTests.test_decimal", # noqa "db_functions.math.test_ln.LnTests.test_null", # noqa + "db_functions.math.test_ln.LnTests.test_transform", # noqa + "db_functions.math.test_log.LogTests.test_decimal", # noqa "db_functions.math.test_log.LogTests.test_null", # noqa + "db_functions.math.test_mod.ModTests.test_float", # noqa "db_functions.math.test_mod.ModTests.test_null", # noqa "db_functions.math.test_power.PowerTests.test_decimal", # noqa "db_functions.math.test_power.PowerTests.test_float", # noqa @@ -1040,7 +1049,9 @@ class DatabaseFeatures(BaseDatabaseFeatures): "db_functions.math.test_radians.RadiansTests.test_null", # noqa "db_functions.math.test_round.RoundTests.test_null", # noqa "db_functions.math.test_sin.SinTests.test_null", # noqa + "db_functions.math.test_sqrt.SqrtTests.test_decimal", # noqa "db_functions.math.test_sqrt.SqrtTests.test_null", # noqa + "db_functions.math.test_sqrt.SqrtTests.test_transform", # noqa "db_functions.math.test_tan.TanTests.test_null", # noqa "db_functions.tests.FunctionTests.test_func_transform_bilateral", # noqa "db_functions.tests.FunctionTests.test_func_transform_bilateral_multivalue", # noqa diff --git a/django_spanner/introspection.py b/django_spanner/introspection.py index e84996bd5d..b95ea3e629 100644 --- a/django_spanner/introspection.py +++ b/django_spanner/introspection.py @@ -24,6 +24,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): TypeCode.INT64: "IntegerField", TypeCode.STRING: "CharField", TypeCode.TIMESTAMP: "DateTimeField", + TypeCode.NUMERIC: "DecimalField", } def get_field_type(self, data_type, description): diff --git a/django_spanner/lookups.py b/django_spanner/lookups.py index cad536c914..c2e642d26a 100644 --- a/django_spanner/lookups.py +++ b/django_spanner/lookups.py @@ -4,7 +4,6 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -from django.db.models import DecimalField from django.db.models.lookups import ( Contains, EndsWith, @@ -233,13 +232,8 @@ def cast_param_to_float(self, compiler, connection): """ sql, params = self.as_sql(compiler, connection) if params: - # Cast to DecimaField lookup values to float because - # google.cloud.spanner_v1._helpers._make_value_pb() doesn't serialize - # decimal.Decimal. - if isinstance(self.lhs.output_field, DecimalField): - params[0] = float(params[0]) # Cast remote field lookups that must be integer but come in as string. - elif hasattr(self.lhs.output_field, "get_path_info"): + if hasattr(self.lhs.output_field, "get_path_info"): for i, field in enumerate( self.lhs.output_field.get_path_info()[-1].target_fields ): diff --git a/django_spanner/operations.py b/django_spanner/operations.py index e3ff7471ec..48a3e3cef3 100644 --- a/django_spanner/operations.py +++ b/django_spanner/operations.py @@ -8,7 +8,6 @@ import re from base64 import b64decode from datetime import datetime, time -from decimal import Decimal from uuid import UUID from django.conf import settings @@ -190,10 +189,11 @@ def adapt_decimalfield_value( self, value, max_digits=None, decimal_places=None ): """ - Convert value from decimal.Decimal into float, for a direct mapping - and correct serialization with RPCs to Cloud Spanner. + Convert value from decimal.Decimal to spanner compatible value. + Since spanner supports Numeric storage of decimal and python spanner + takes care of the conversion so this is a no-op method call. - :type value: :class:`~google.cloud.spanner_v1.types.Numeric` + :type value: :class:`decimal.Decimal` :param value: A decimal field value. :type max_digits: int @@ -203,12 +203,10 @@ def adapt_decimalfield_value( :param decimal_places: (Optional) The number of decimal places to store with the number. - :rtype: float - :returns: Formatted value. + :rtype: decimal.Decimal + :returns: decimal value. """ - if value is None: - return None - return float(value) + return value def adapt_timefield_value(self, value): """ @@ -244,8 +242,6 @@ def get_db_converters(self, expression): internal_type = expression.output_field.get_internal_type() if internal_type == "DateTimeField": converters.append(self.convert_datetimefield_value) - elif internal_type == "DecimalField": - converters.append(self.convert_decimalfield_value) elif internal_type == "TimeField": converters.append(self.convert_timefield_value) elif internal_type == "BinaryField": @@ -311,26 +307,6 @@ def convert_datetimefield_value(self, value, expression, connection): else dt ) - def convert_decimalfield_value(self, value, expression, connection): - """Convert Spanner DecimalField value for Django. - - :type value: float - :param value: A decimal field. - - :type expression: :class:`django.db.models.expressions.BaseExpression` - :param expression: A query expression. - - :type connection: :class:`~google.cloud.cpanner_dbapi.connection.Connection` - :param connection: Reference to a Spanner database connection. - - :rtype: :class:`Decimal` - :returns: A converted decimal field. - """ - if value is None: - return value - # Cloud Spanner returns a float. - return Decimal(str(value)) - def convert_timefield_value(self, value, expression, connection): """Convert Spanner TimeField value for Django. diff --git a/django_spanner/validation.py b/django_spanner/validation.py deleted file mode 100644 index 99f270d3d6..0000000000 --- a/django_spanner/validation.py +++ /dev/null @@ -1,33 +0,0 @@ -import os - -from django.core import checks -from django.db.backends.base.validation import BaseDatabaseValidation -from django.db.models import DecimalField - - -class DatabaseValidation(BaseDatabaseValidation): - def check_field_type(self, field, field_type): - """Check field type and collect errors. - - :type field: :class:`~django.db.migrations.operations.models.fields.FieldOperation` - :param field: The field of the table. - - :type field_type: str - :param field_type: The type of the field. - - :rtype: list - :return: A list of errors. - """ - errors = [] - # Disable the error when running the Django test suite. - if os.environ.get( - "RUNNING_SPANNER_BACKEND_TESTS" - ) != "1" and isinstance(field, DecimalField): - errors.append( - checks.Error( - "DecimalField is not yet supported by Spanner.", - obj=field, - id="spanner.E001", - ) - ) - return errors diff --git a/noxfile.py b/noxfile.py index 89f7b2904e..1f70054f5f 100644 --- a/noxfile.py +++ b/noxfile.py @@ -84,7 +84,7 @@ def default(session): "--cov-append", "--cov-config=.coveragerc", "--cov-report=", - "--cov-fail-under=68", + "--cov-fail-under=65", os.path.join("tests", "unit"), *session.posargs ) diff --git a/tests/unit/django_spanner/test_lookups.py b/tests/unit/django_spanner/test_lookups.py index 90b17e5515..1931d255aa 100644 --- a/tests/unit/django_spanner/test_lookups.py +++ b/tests/unit/django_spanner/test_lookups.py @@ -7,13 +7,16 @@ from django_spanner.compiler import SQLCompiler from django.db.models import F from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass +from decimal import Decimal from .models import Number, Author class TestLookups(SpannerSimpleTestClass): def test_cast_param_to_float_lte_sql_query(self): - qs1 = Number.objects.filter(decimal_num__lte=1.1).values("decimal_num") + qs1 = Number.objects.filter(decimal_num__lte=Decimal("1.1")).values( + "decimal_num" + ) compiler = SQLCompiler(qs1.query, self.connection, "default") sql_compiled, params = compiler.as_sql() self.assertEqual( @@ -21,7 +24,7 @@ def test_cast_param_to_float_lte_sql_query(self): "SELECT tests_number.decimal_num FROM tests_number WHERE " + "tests_number.decimal_num <= %s", ) - self.assertEqual(params, (1.1,)) + self.assertEqual(params, (Decimal("1.1"),)) def test_cast_param_to_float_for_int_field_query(self): diff --git a/tests/unit/django_spanner/test_operations.py b/tests/unit/django_spanner/test_operations.py index ae6384233a..e2a77148b2 100644 --- a/tests/unit/django_spanner/test_operations.py +++ b/tests/unit/django_spanner/test_operations.py @@ -7,6 +7,7 @@ from django.db.utils import DatabaseError from datetime import timedelta from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass +from decimal import Decimal class TestOperations(SpannerSimpleTestClass): @@ -58,7 +59,8 @@ def test_adapt_datefield_value_none(self): def test_adapt_decimalfield_value(self): self.assertIsInstance( - self.db_operations.adapt_decimalfield_value(value=1), float, + self.db_operations.adapt_decimalfield_value(value=Decimal("1")), + Decimal, ) def test_adapt_decimalfield_value_none(self): @@ -93,23 +95,6 @@ def test_adapt_timefield_value_none(self): self.db_operations.adapt_timefield_value(value=None), ) - def test_convert_decimalfield_value(self): - from decimal import Decimal - - self.assertIsInstance( - self.db_operations.convert_decimalfield_value( - value=1.0, expression=None, connection=None - ), - Decimal, - ) - - def test_convert_decimalfield_value_none(self): - self.assertIsNone( - self.db_operations.convert_decimalfield_value( - value=None, expression=None, connection=None - ), - ) - def test_convert_uuidfield_value(self): import uuid diff --git a/tests/unit/django_spanner/test_validation.py b/tests/unit/django_spanner/test_validation.py deleted file mode 100644 index 5a8946aef1..0000000000 --- a/tests/unit/django_spanner/test_validation.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2021 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -from django_spanner.validation import DatabaseValidation -from django.db import connection -from django.core.checks import Error as DjangoError -from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass -from .models import ModelDecimalField, ModelCharField - - -class TestValidation(SpannerSimpleTestClass): - def test_check_field_type_with_decimal_field_not_support_error(self): - """ - Checks if decimal field fails database validation as it's not - supported in spanner. - """ - field = ModelDecimalField._meta.get_field("field") - validator = DatabaseValidation(connection=connection) - self.assertEqual( - validator.check_field(field), - [ - DjangoError( - "DecimalField is not yet supported by Spanner.", - obj=field, - id="spanner.E001", - ) - ], - ) - - def test_check_field_type_with_char_field_no_error(self): - """ - Checks if string field passes database validation. - """ - field = ModelCharField._meta.get_field("field") - validator = DatabaseValidation(connection=connection) - self.assertEqual( - validator.check_field(field), [], - )