diff --git a/django_spanner/__init__.py b/django_spanner/__init__.py index 861e3abb94..a26703d5a5 100644 --- a/django_spanner/__init__.py +++ b/django_spanner/__init__.py @@ -5,6 +5,7 @@ # https://developers.google.com/open-source/licenses/bsd import datetime +import os # Monkey-patch AutoField to generate a random value since Cloud Spanner can't # do that. @@ -24,6 +25,8 @@ __version__ = pkg_resources.get_distribution("django-google-spanner").version +USE_EMULATOR = os.getenv("SPANNER_EMULATOR_HOST") is not None + check_django_compatability() register_expressions() register_functions() diff --git a/django_spanner/features.py b/django_spanner/features.py index af7e4c1131..050ba9c7b9 100644 --- a/django_spanner/features.py +++ b/django_spanner/features.py @@ -184,7 +184,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): "db_functions.comparison.test_cast.CastTests.test_cast_to_decimal_field", "model_fields.test_decimalfield.DecimalFieldTests.test_fetch_from_db_without_float_rounding", "model_fields.test_decimalfield.DecimalFieldTests.test_roundtrip_with_trailing_zeros", - # No CHECK constraints in Spanner. + # Spanner does not support unsigned integer field. "model_fields.test_integerfield.PositiveIntegerFieldTests.test_negative_values", # Spanner doesn't support the variance the standard deviation database # functions: diff --git a/django_spanner/schema.py b/django_spanner/schema.py index d28dcc4f6e..247358857a 100644 --- a/django_spanner/schema.py +++ b/django_spanner/schema.py @@ -7,6 +7,7 @@ from django.db import NotSupportedError from django.db.backends.base.schema import BaseDatabaseSchemaEditor from django_spanner._opentelemetry_tracing import trace_call +from django_spanner import USE_EMULATOR class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): @@ -472,8 +473,13 @@ def _alter_column_type_sql(self, model, old_field, new_field, new_type): ) def _check_sql(self, name, check): - # Spanner doesn't support CHECK constraints. - return None + # Emulator does not support check constraints yet. + if USE_EMULATOR: + return None + return self.sql_constraint % { + "name": self.quote_name(name), + "constraint": self.sql_check_constraint % {"check": check}, + } def _unique_sql(self, model, fields, name, condition=None): # Inline constraints aren't supported, so create the index separately. diff --git a/noxfile.py b/noxfile.py index a5c05e7a02..3b51d73841 100644 --- a/noxfile.py +++ b/noxfile.py @@ -43,7 +43,7 @@ def lint(session): session.run("flake8", "django_spanner", "tests") -@nox.session(python="3.6") +@nox.session(python=DEFAULT_PYTHON_VERSION) def blacken(session): """Run black. diff --git a/tests/system/django_spanner/models.py b/tests/system/django_spanner/models.py index 5524ad8ec9..f7153ba994 100644 --- a/tests/system/django_spanner/models.py +++ b/tests/system/django_spanner/models.py @@ -21,3 +21,16 @@ class Number(models.Model): def __str__(self): return str(self.num) + + +class Event(models.Model): + start_date = models.DateTimeField() + end_date = models.DateTimeField() + + class Meta: + constraints = [ + models.CheckConstraint( + check=models.Q(end_date__gt=models.F("start_date")), + name="check_start_date", + ), + ] diff --git a/tests/system/django_spanner/test_check_constraint.py b/tests/system/django_spanner/test_check_constraint.py new file mode 100644 index 0000000000..9177166ce9 --- /dev/null +++ b/tests/system/django_spanner/test_check_constraint.py @@ -0,0 +1,64 @@ +# Copyright 2021 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +from .models import Event +from django.test import TransactionTestCase +import datetime +import unittest +from django.utils import timezone +from google.api_core.exceptions import OutOfRange +from django.db import connection +from django_spanner import USE_EMULATOR +from tests.system.django_spanner.utils import ( + setup_instance, + teardown_instance, + setup_database, + teardown_database, +) + + +@unittest.skipIf( + USE_EMULATOR, "Check Constraint is not implemented in emulator." +) +class TestCheckConstraint(TransactionTestCase): + @classmethod + def setUpClass(cls): + setup_instance() + setup_database() + with connection.schema_editor() as editor: + # Create the table + editor.create_model(Event) + + @classmethod + def tearDownClass(cls): + with connection.schema_editor() as editor: + # delete the table + editor.delete_model(Event) + teardown_database() + teardown_instance() + + def test_insert_valid_value(self): + """ + Tests model object creation with Event model. + """ + now = timezone.now() + now_plus_10 = now + datetime.timedelta(minutes=10) + event_valid = Event(start_date=now, end_date=now_plus_10) + event_valid.save() + qs1 = Event.objects.filter().values("start_date") + self.assertEqual(qs1[0]["start_date"], now) + # Delete data from Event table. + Event.objects.all().delete() + + def test_insert_invalid_value(self): + """ + Tests model object creation with invalid data in Event model. + """ + now = timezone.now() + now_minus_1_day = now - timezone.timedelta(days=1) + event_invalid = Event(start_date=now, end_date=now_minus_1_day) + with self.assertRaises(OutOfRange): + event_invalid.save() diff --git a/tests/system/django_spanner/test_decimal.py b/tests/system/django_spanner/test_decimal.py index 73df7e796b..4155599af1 100644 --- a/tests/system/django_spanner/test_decimal.py +++ b/tests/system/django_spanner/test_decimal.py @@ -6,14 +6,13 @@ from .models import Author, Number from django.test import TransactionTestCase -from django.db import connection, ProgrammingError +from django.db import connection from decimal import Decimal from tests.system.django_spanner.utils import ( setup_instance, teardown_instance, setup_database, teardown_database, - USE_EMULATOR, ) @@ -87,12 +86,8 @@ def test_decimal_precision_limit(self): Tests decimal object precission limit. """ num_val = Number(num=Decimal(1) / Decimal(3)) - if USE_EMULATOR: - with self.assertRaises(ValueError): - num_val.save() - else: - with self.assertRaises(ProgrammingError): - num_val.save() + with self.assertRaises(ValueError): + num_val.save() def test_decimal_update(self): """ diff --git a/tests/system/django_spanner/utils.py b/tests/system/django_spanner/utils.py index 7fac5166e0..3dca9db9b8 100644 --- a/tests/system/django_spanner/utils.py +++ b/tests/system/django_spanner/utils.py @@ -15,11 +15,12 @@ from test_utils.retry import RetryErrors from django_spanner.creation import DatabaseCreation +from django_spanner import USE_EMULATOR CREATE_INSTANCE = ( os.getenv("GOOGLE_CLOUD_TESTS_CREATE_SPANNER_INSTANCE") is not None ) -USE_EMULATOR = os.getenv("SPANNER_EMULATOR_HOST") is not None + SPANNER_OPERATION_TIMEOUT_IN_SECONDS = int( os.getenv("SPANNER_OPERATION_TIMEOUT_IN_SECONDS", 60) )