Skip to content

Commit

Permalink
feat: add decimal/numeric support
Browse files Browse the repository at this point in the history
  • Loading branch information
vi3k6i5 committed May 14, 2021
1 parent ad8e43e commit 40a8fce
Show file tree
Hide file tree
Showing 12 changed files with 423 additions and 45 deletions.
2 changes: 1 addition & 1 deletion django_spanner/base.py
Expand Up @@ -34,7 +34,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
"CharField": "STRING(%(max_length)s)",
"DateField": "DATE",
"DateTimeField": "TIMESTAMP",
"DecimalField": "FLOAT64",
"DecimalField": "NUMERIC",
"DurationField": "INT64",
"EmailField": "STRING(%(max_length)s)",
"FileField": "STRING(%(max_length)s)",
Expand Down
9 changes: 5 additions & 4 deletions django_spanner/features.py
Expand Up @@ -233,10 +233,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"queries.test_bulk_update.BulkUpdateTests.test_large_batch",
# Spanner doesn't support random ordering.
"ordering.tests.OrderingTests.test_random_ordering",
# No matching signature for function MOD for argument types: FLOAT64,
# FLOAT64. Supported signatures: MOD(INT64, INT64)
"db_functions.math.test_mod.ModTests.test_decimal",
"db_functions.math.test_mod.ModTests.test_float",
# casting DateField to DateTimeField adds an unexpected hour:
# https://github.com/orijtech/spanner-orm/issues/260
"db_functions.comparison.test_cast.CastTests.test_cast_from_db_date_to_datetime",
Expand Down Expand Up @@ -364,6 +360,11 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"model_formsets.tests.ModelFormsetTest.test_prevent_change_outer_model_and_create_invalid_data",
"model_formsets_regress.tests.FormfieldShouldDeleteFormTests.test_no_delete",
"model_formsets_regress.tests.FormsetTests.test_extraneous_query_is_not_run",
# Numeric field is not supported in primary key/unique key.
"model_formsets.tests.ModelFormsetTest.test_inline_formsets_with_custom_pk",
"model_forms.tests.ModelFormBaseTest.test_exclude_and_validation",
"model_forms.tests.UniqueTest.test_unique_together",
"model_forms.tests.UniqueTest.test_override_unique_together_message",
# os.chmod() doesn't work on Kokoro?
"file_uploads.tests.DirectoryCreationTests.test_readonly_root",
# Tests that sometimes fail on Kokoro for unknown reasons.
Expand Down
1 change: 1 addition & 0 deletions django_spanner/introspection.py
Expand Up @@ -24,6 +24,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
TypeCode.INT64: "IntegerField",
TypeCode.STRING: "CharField",
TypeCode.TIMESTAMP: "DateTimeField",
TypeCode.NUMERIC: "DecimalField",
}

def get_field_type(self, data_type, description):
Expand Down
8 changes: 1 addition & 7 deletions django_spanner/lookups.py
Expand Up @@ -4,7 +4,6 @@
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

from django.db.models import DecimalField
from django.db.models.lookups import (
Contains,
EndsWith,
Expand Down Expand Up @@ -233,13 +232,8 @@ def cast_param_to_float(self, compiler, connection):
"""
sql, params = self.as_sql(compiler, connection)
if params:
# Cast to DecimaField lookup values to float because
# google.cloud.spanner_v1._helpers._make_value_pb() doesn't serialize
# decimal.Decimal.
if isinstance(self.lhs.output_field, DecimalField):
params[0] = float(params[0])
# Cast remote field lookups that must be integer but come in as string.
elif hasattr(self.lhs.output_field, "get_path_info"):
if hasattr(self.lhs.output_field, "get_path_info"):
for i, field in enumerate(
self.lhs.output_field.get_path_info()[-1].target_fields
):
Expand Down
38 changes: 7 additions & 31 deletions django_spanner/operations.py
Expand Up @@ -8,7 +8,6 @@
import re
from base64 import b64decode
from datetime import datetime, time
from decimal import Decimal
from uuid import UUID

from django.conf import settings
Expand Down Expand Up @@ -190,10 +189,11 @@ def adapt_decimalfield_value(
self, value, max_digits=None, decimal_places=None
):
"""
Convert value from decimal.Decimal into float, for a direct mapping
and correct serialization with RPCs to Cloud Spanner.
Convert value from decimal.Decimal to spanner compatible value.
Since spanner supports Numeric storage of decimal and python spanner
takes care of the conversion so this is a no-op method call.
:type value: :class:`~google.cloud.spanner_v1.types.Numeric`
:type value: :class:`decimal.Decimal`
:param value: A decimal field value.
:type max_digits: int
Expand All @@ -203,12 +203,10 @@ def adapt_decimalfield_value(
:param decimal_places: (Optional) The number of decimal places to store
with the number.
:rtype: float
:returns: Formatted value.
:rtype: decimal.Decimal
:returns: decimal value.
"""
if value is None:
return None
return float(value)
return value

def adapt_timefield_value(self, value):
"""
Expand Down Expand Up @@ -244,8 +242,6 @@ def get_db_converters(self, expression):
internal_type = expression.output_field.get_internal_type()
if internal_type == "DateTimeField":
converters.append(self.convert_datetimefield_value)
elif internal_type == "DecimalField":
converters.append(self.convert_decimalfield_value)
elif internal_type == "TimeField":
converters.append(self.convert_timefield_value)
elif internal_type == "BinaryField":
Expand Down Expand Up @@ -311,26 +307,6 @@ def convert_datetimefield_value(self, value, expression, connection):
else dt
)

def convert_decimalfield_value(self, value, expression, connection):
"""Convert Spanner DecimalField value for Django.
:type value: float
:param value: A decimal field.
:type expression: :class:`django.db.models.expressions.BaseExpression`
:param expression: A query expression.
:type connection: :class:`~google.cloud.cpanner_dbapi.connection.Connection`
:param connection: Reference to a Spanner database connection.
:rtype: :class:`Decimal`
:returns: A converted decimal field.
"""
if value is None:
return value
# Cloud Spanner returns a float.
return Decimal(str(value))

def convert_timefield_value(self, value, expression, connection):
"""Convert Spanner TimeField value for Django.
Expand Down
57 changes: 55 additions & 2 deletions noxfile.py
Expand Up @@ -10,6 +10,7 @@
from __future__ import absolute_import

import os
import pathlib
import shutil

import nox
Expand All @@ -25,7 +26,9 @@

DEFAULT_PYTHON_VERSION = "3.8"
SYSTEM_TEST_PYTHON_VERSIONS = ["3.8"]
UNIT_TEST_PYTHON_VERSIONS = ["3.6", "3.7", "3.8"]
UNIT_TEST_PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"]

CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute()


@nox.session(python=DEFAULT_PYTHON_VERSION)
Expand Down Expand Up @@ -81,7 +84,7 @@ def default(session):
"--cov-report=",
"--cov-fail-under=20",
os.path.join("tests", "unit"),
*session.posargs
*session.posargs,
)


Expand All @@ -91,6 +94,56 @@ def unit(session):
default(session)


@nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS)
def system(session):
"""Run the system test suite."""
constraints_path = str(
CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt"
)
system_test_path = os.path.join("tests", "system.py")
system_test_folder_path = os.path.join("tests", "system")

# Check the value of `RUN_SYSTEM_TESTS` env var. It defaults to true.
if os.environ.get("RUN_SYSTEM_TESTS", "true") == "false":
session.skip("RUN_SYSTEM_TESTS is set to false, skipping")
# Sanity check: Only run tests if the environment variable is set.
if not os.environ.get(
"GOOGLE_APPLICATION_CREDENTIALS", ""
) and not os.environ.get("SPANNER_EMULATOR_HOST", ""):
session.skip(
"Credentials or emulator host must be set via environment variable"
)

system_test_exists = os.path.exists(system_test_path)
system_test_folder_exists = os.path.exists(system_test_folder_path)
# Sanity check: only run tests if found.
if not system_test_exists and not system_test_folder_exists:
session.skip("System tests were not found")

# Use pre-release gRPC for system tests.
session.install("--pre", "grpcio")

# Install all test dependencies, then install this package into the
# virtualenv's dist-packages.
session.install(
"django~=2.2",
"mock",
"pytest",
"google-cloud-testutils",
"-c",
constraints_path,
)
session.install("-e", ".[tracing]", "-c", constraints_path)

# Run py.test against the system tests.
if system_test_exists:
session.run("py.test", "--quiet", system_test_path, *session.posargs)
if system_test_folder_exists:
session.run(
"py.test", "--quiet", system_test_folder_path, *session.posargs
)


@nox.session(python=DEFAULT_PYTHON_VERSION)
def cover(session):
"""Run the final coverage report.
Expand Down
19 changes: 19 additions & 0 deletions tests/system/conftest.py
@@ -0,0 +1,19 @@
# Copyright 2021 Google LLC
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

import os
import django
from django.conf import settings

# We manually designate which settings we will be using in an environment
# variable. This is similar to what occurs in the `manage.py` file.
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "tests.system.settings")


# `pytest` automatically calls this function once when tests are run.
def pytest_configure():
settings.DEBUG = False
django.setup()
Empty file.
23 changes: 23 additions & 0 deletions tests/system/django_spanner/models.py
@@ -0,0 +1,23 @@
# Copyright 2021 Google LLC
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

"""
Different models used by system tests in django-spanner code.
"""
from django.db import models


class Author(models.Model):
first_name = models.CharField(max_length=20)
last_name = models.CharField(max_length=20)
rating = models.DecimalField()


class Number(models.Model):
num = models.DecimalField()

def __str__(self):
return str(self.num)
117 changes: 117 additions & 0 deletions tests/system/django_spanner/test_decimal.py
@@ -0,0 +1,117 @@
# Copyright 2021 Google LLC
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

from .models import Author, Number
from django.test import TransactionTestCase
from django.db import connection, ProgrammingError
from django.db.utils import IntegrityError
from decimal import Decimal
from tests.system.django_spanner.utils import (
setup_instance,
teardown_instance,
setup_database,
teardown_database,
USE_EMULATOR,
)


class TestDecimal(TransactionTestCase):
@classmethod
def setUpClass(cls):
setup_instance()
setup_database()
with connection.schema_editor() as editor:
# Create the tables
editor.create_model(Author)
editor.create_model(Number)

@classmethod
def tearDownClass(cls):
with connection.schema_editor() as editor:
# delete the table
editor.delete_model(Author)
editor.delete_model(Number)
teardown_database()
teardown_instance()

def rating_transform(self, value):
return value["rating"]

def values_transform(self, value):
return value.num

def assertValuesEqual(
self, queryset, expected_values, transformer, ordered=True
):
self.assertQuerysetEqual(
queryset, expected_values, transformer, ordered
)

def test_insert_and_search_decimal_value(self):
"""
Tests model object creation with Author model.
"""
author_kent = Author(
first_name="Arthur", last_name="Kent", rating=Decimal("4.1"),
)
author_kent.save()
qs1 = Author.objects.filter(rating__gte=3).values("rating")
self.assertValuesEqual(
qs1, [Decimal("4.1")], self.rating_transform,
)
# Delete data from Author table.
Author.objects.all().delete()

def test_decimal_filter(self):
"""
Tests decimal filter query.
"""
# Insert data into Number table.
Number.objects.bulk_create(
Number(num=Decimal(i) / Decimal(10)) for i in range(10)
)
qs1 = Number.objects.filter(num__lte=Decimal(2) / Decimal(10))
self.assertValuesEqual(
qs1,
[Decimal(i) / Decimal(10) for i in range(3)],
self.values_transform,
ordered=False,
)
# Delete data from Number table.
Number.objects.all().delete()

def test_decimal_precision_limit(self):
"""
Tests decimal object precission limit.
"""
num_val = Number(num=Decimal(1) / Decimal(3))
if USE_EMULATOR:
msg = "The NUMERIC type supports 38 digits of precision and 9 digits of scale."
with self.assertRaisesRegex(IntegrityError, msg):
num_val.save()
else:
msg = "400 Invalid value for bind parameter a0: Expected NUMERIC."
with self.assertRaisesRegex(ProgrammingError, msg):
num_val.save()

def test_decimal_update(self):
"""
Tests decimal object update.
"""
author_kent = Author(
first_name="Arthur", last_name="Kent", rating=Decimal("4.1"),
)
author_kent.save()
author_kent.rating = Decimal("4.2")
author_kent.save()
qs1 = Author.objects.filter(rating__gte=Decimal("4.2")).values(
"rating"
)
self.assertValuesEqual(
qs1, [Decimal("4.2")], self.rating_transform,
)
# Delete data from Author table.
Author.objects.all().delete()

0 comments on commit 40a8fce

Please sign in to comment.