diff --git a/noxfile.py b/noxfile.py index a19bbc4360..89f7b2904e 100644 --- a/noxfile.py +++ b/noxfile.py @@ -66,7 +66,12 @@ def lint_setup_py(session): def default(session): # Install all test dependencies, then install this package in-place. session.install( - "django~=2.2", "mock", "mock-import", "pytest", "pytest-cov" + "django~=2.2", + "mock", + "mock-import", + "pytest", + "pytest-cov", + "coverage", ) session.install("-e", ".") @@ -79,7 +84,7 @@ def default(session): "--cov-append", "--cov-config=.coveragerc", "--cov-report=", - "--cov-fail-under=20", + "--cov-fail-under=68", os.path.join("tests", "unit"), *session.posargs ) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000..308a903387 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,13 @@ +import os +import django +from django.conf import settings + +# We manually designate which settings we will be using in an environment +# variable. This is similar to what occurs in the `manage.py` file. +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "tests.settings") + + +# `pytest` automatically calls this function once when tests are run. +def pytest_configure(): + settings.DEBUG = False + django.setup() diff --git a/tests/settings.py b/tests/settings.py new file mode 100644 index 0000000000..ecf6567d46 --- /dev/null +++ b/tests/settings.py @@ -0,0 +1,46 @@ +# Copyright 2021 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +DEBUG = True +USE_TZ = True + +INSTALLED_APPS = [ + "django_spanner", # Must be the first entry + "django.contrib.contenttypes", + "django.contrib.auth", + "django.contrib.sites", + "django.contrib.sessions", + "django.contrib.messages", + "django.contrib.staticfiles", + "tests", +] + +TIME_ZONE = "UTC" + +DATABASES = { + "default": { + "ENGINE": "django_spanner", + "PROJECT": "emulator-local", + "INSTANCE": "django-test-instance", + "NAME": "django-test-db", + } +} +SECRET_KEY = "spanner emulator secret key" + +PASSWORD_HASHERS = [ + "django.contrib.auth.hashers.MD5PasswordHasher", +] + +SITE_ID = 1 + +CONN_MAX_AGE = 60 + +ENGINE = "django_spanner" +PROJECT = "emulator-local" +INSTANCE = "django-test-instance" +NAME = "django-test-db" +OPTIONS = {} +AUTOCOMMIT = True diff --git a/tests/unit/django_spanner/__init__.py b/tests/unit/django_spanner/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/django_spanner/models.py b/tests/unit/django_spanner/models.py new file mode 100644 index 0000000000..8dfb9d8e48 --- /dev/null +++ b/tests/unit/django_spanner/models.py @@ -0,0 +1,61 @@ +# Copyright 2021 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd +""" +Different models used for testing django-spanner code. +""" +from django.db import models + + +# Register transformations for model fields. +class UpperCase(models.Transform): + lookup_name = "upper" + function = "UPPER" + bilateral = True + + +models.CharField.register_lookup(UpperCase) +models.TextField.register_lookup(UpperCase) + + +# Models +class ModelDecimalField(models.Model): + field = models.DecimalField() + + +class ModelCharField(models.Model): + field = models.CharField() + + +class Item(models.Model): + item_id = models.IntegerField() + name = models.CharField(max_length=10) + created = models.DateTimeField() + modified = models.DateTimeField(blank=True, null=True) + + class Meta: + ordering = ["name"] + + +class Number(models.Model): + num = models.IntegerField() + decimal_num = models.DecimalField(max_digits=5, decimal_places=2) + item = models.ForeignKey(Item, models.CASCADE) + + +class Author(models.Model): + name = models.CharField(max_length=40) + last_name = models.CharField(max_length=40) + num = models.IntegerField(unique=True) + created = models.DateTimeField() + modified = models.DateTimeField(blank=True, null=True) + + +class Report(models.Model): + name = models.CharField(max_length=10) + creator = models.ForeignKey(Author, models.CASCADE, null=True) + + class Meta: + ordering = ["name"] diff --git a/tests/unit/django_spanner/simple_test.py b/tests/unit/django_spanner/simple_test.py new file mode 100644 index 0000000000..1fcb92bd29 --- /dev/null +++ b/tests/unit/django_spanner/simple_test.py @@ -0,0 +1,33 @@ +# Copyright 2021 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +from django_spanner.client import DatabaseClient +from django_spanner.base import DatabaseWrapper +from django_spanner.operations import DatabaseOperations +from unittest import TestCase +import os + + +class SpannerSimpleTestClass(TestCase): + @classmethod + def setUpClass(cls): + cls.PROJECT = os.environ["GOOGLE_CLOUD_PROJECT"] + + cls.INSTANCE_ID = "instance_id" + cls.DATABASE_ID = "database_id" + cls.USER_AGENT = "django_spanner/2.2.0a1" + cls.OPTIONS = {"option": "dummy"} + + cls.settings_dict = { + "PROJECT": cls.PROJECT, + "INSTANCE": cls.INSTANCE_ID, + "NAME": cls.DATABASE_ID, + "user_agent": cls.USER_AGENT, + "OPTIONS": cls.OPTIONS, + } + cls.db_client = DatabaseClient(cls.settings_dict) + cls.db_wrapper = cls.connection = DatabaseWrapper(cls.settings_dict) + cls.db_operations = DatabaseOperations(cls.connection) diff --git a/tests/unit/django_spanner/test_base.py b/tests/unit/django_spanner/test_base.py index 32d965b9d1..9b2d60c1c4 100644 --- a/tests/unit/django_spanner/test_base.py +++ b/tests/unit/django_spanner/test_base.py @@ -4,59 +4,24 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -import sys -import unittest -import os - -from mock_import import mock_import from unittest import mock +from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass -@mock_import() -@unittest.skipIf( - sys.version_info < (3, 6), reason="Skipping Python versions <= 3.5" -) -class TestBase(unittest.TestCase): - PROJECT = os.environ["GOOGLE_CLOUD_PROJECT"] - INSTANCE_ID = "instance_id" - DATABASE_ID = "database_id" - USER_AGENT = "django_spanner/2.2.0a1" - OPTIONS = {"option": "dummy"} - - settings_dict = { - "PROJECT": PROJECT, - "INSTANCE": INSTANCE_ID, - "NAME": DATABASE_ID, - "user_agent": USER_AGENT, - "OPTIONS": OPTIONS, - } - - def _get_target_class(self): - from django_spanner.base import DatabaseWrapper - - return DatabaseWrapper - - def _make_one(self, *args, **kwargs): - return self._get_target_class()(*args, **kwargs) - +class TestBase(SpannerSimpleTestClass): def test_property_instance(self): - settings_dict = {"INSTANCE": "instance"} - db_wrapper = self._make_one(settings_dict=settings_dict) - with mock.patch("django_spanner.base.spanner") as mock_spanner: mock_spanner.Client = mock_client = mock.MagicMock() mock_client().instance = mock_instance = mock.MagicMock() - _ = db_wrapper.instance - mock_instance.assert_called_once_with(settings_dict["INSTANCE"]) + _ = self.db_wrapper.instance + mock_instance.assert_called_once_with(self.INSTANCE_ID) - def test_property__nodb_connection(self): - db_wrapper = self._make_one(None) + def test_property_nodb_connection(self): with self.assertRaises(NotImplementedError): - db_wrapper._nodb_connection() + self.db_wrapper._nodb_connection() def test_get_connection_params(self): - db_wrapper = self._make_one(self.settings_dict) - params = db_wrapper.get_connection_params() + params = self.db_wrapper.get_connection_params() self.assertEqual(params["project"], self.PROJECT) self.assertEqual(params["instance_id"], self.INSTANCE_ID) @@ -65,54 +30,50 @@ def test_get_connection_params(self): self.assertEqual(params["option"], self.OPTIONS["option"]) def test_get_new_connection(self): - db_wrapper = self._make_one(self.settings_dict) - db_wrapper.Database = mock_database = mock.MagicMock() + self.db_wrapper.Database = mock_database = mock.MagicMock() mock_database.connect = mock_connection = mock.MagicMock() conn_params = {"test_param": "dummy"} - db_wrapper.get_new_connection(conn_params) + self.db_wrapper.get_new_connection(conn_params) mock_connection.assert_called_once_with(**conn_params) def test_init_connection_state(self): - db_wrapper = self._make_one(self.settings_dict) - db_wrapper.connection = mock_connection = mock.MagicMock() + self.db_wrapper.connection = mock_connection = mock.MagicMock() mock_connection.close = mock_close = mock.MagicMock() - db_wrapper.init_connection_state() + self.db_wrapper.init_connection_state() mock_close.assert_called_once_with() def test_create_cursor(self): - db_wrapper = self._make_one(self.settings_dict) - db_wrapper.connection = mock_connection = mock.MagicMock() + self.db_wrapper.connection = mock_connection = mock.MagicMock() mock_connection.cursor = mock_cursor = mock.MagicMock() - db_wrapper.create_cursor() + self.db_wrapper.create_cursor() mock_cursor.assert_called_once_with() - def test__set_autocommit(self): - db_wrapper = self._make_one(self.settings_dict) - db_wrapper.connection = mock_connection = mock.MagicMock() + def test_set_autocommit(self): + self.db_wrapper.connection = mock_connection = mock.MagicMock() mock_connection.autocommit = False - db_wrapper._set_autocommit(True) + self.db_wrapper._set_autocommit(True) self.assertEqual(mock_connection.autocommit, True) def test_is_usable(self): - from google.cloud.spanner_dbapi.exceptions import Error - - db_wrapper = self._make_one(self.settings_dict) - db_wrapper.connection = None - self.assertFalse(db_wrapper.is_usable()) + self.db_wrapper.connection = None + self.assertFalse(self.db_wrapper.is_usable()) - db_wrapper.connection = mock_connection = mock.MagicMock() + self.db_wrapper.connection = mock_connection = mock.MagicMock() mock_connection.is_closed = True - self.assertFalse(db_wrapper.is_usable()) + self.assertFalse(self.db_wrapper.is_usable()) mock_connection.is_closed = False - self.assertTrue(db_wrapper.is_usable()) + self.assertTrue(self.db_wrapper.is_usable()) + + def test_is_usable_with_error(self): + from google.cloud.spanner_dbapi.exceptions import Error + self.db_wrapper.connection = mock_connection = mock.MagicMock() mock_connection.cursor = mock.MagicMock(side_effect=Error) - self.assertFalse(db_wrapper.is_usable()) + self.assertFalse(self.db_wrapper.is_usable()) - def test__start_transaction_under_autocommit(self): - db_wrapper = self._make_one(self.settings_dict) - db_wrapper.connection = mock_connection = mock.MagicMock() + def test_start_transaction_under_autocommit(self): + self.db_wrapper.connection = mock_connection = mock.MagicMock() mock_connection.cursor = mock_cursor = mock.MagicMock() - db_wrapper._start_transaction_under_autocommit() + self.db_wrapper._start_transaction_under_autocommit() mock_cursor.assert_called_once_with() diff --git a/tests/unit/django_spanner/test_client.py b/tests/unit/django_spanner/test_client.py index fd02434b04..10b38fb2f9 100644 --- a/tests/unit/django_spanner/test_client.py +++ b/tests/unit/django_spanner/test_client.py @@ -4,41 +4,12 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -import sys -import unittest -import os +from google.cloud.spanner_dbapi.exceptions import NotSupportedError +from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass -@unittest.skipIf( - sys.version_info < (3, 6), reason="Skipping Python versions <= 3.5" -) -class TestClient(unittest.TestCase): - PROJECT = os.environ["GOOGLE_CLOUD_PROJECT"] - INSTANCE_ID = "instance_id" - DATABASE_ID = "database_id" - USER_AGENT = "django_spanner/2.2.0a1" - OPTIONS = {"option": "dummy"} - - settings_dict = { - "PROJECT": PROJECT, - "INSTANCE": INSTANCE_ID, - "NAME": DATABASE_ID, - "user_agent": USER_AGENT, - "OPTIONS": OPTIONS, - } - - def _get_target_class(self): - from django_spanner.client import DatabaseClient - - return DatabaseClient - - def _make_one(self, *args, **kwargs): - return self._get_target_class()(*args, **kwargs) +class TestClient(SpannerSimpleTestClass): def test_runshell(self): - from google.cloud.spanner_dbapi.exceptions import NotSupportedError - - db_wrapper = self._make_one(self.settings_dict) - with self.assertRaises(NotSupportedError): - db_wrapper.runshell(parameters=self.settings_dict) + self.db_client.runshell(parameters=self.settings_dict) diff --git a/tests/unit/django_spanner/test_compiler.py b/tests/unit/django_spanner/test_compiler.py new file mode 100644 index 0000000000..11fd1222a0 --- /dev/null +++ b/tests/unit/django_spanner/test_compiler.py @@ -0,0 +1,173 @@ +# Copyright 2021 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +from django.core.exceptions import EmptyResultSet +from django.db.utils import DatabaseError +from django_spanner.compiler import SQLCompiler +from django.db.models.query import QuerySet +from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass +from .models import Number + + +class TestCompiler(SpannerSimpleTestClass): + def test_unsupported_ordering_slicing_raises_db_error(self): + """ + Tries limit/offset and order by in subqueries which are not supported + by spanner. + """ + qs1 = Number.objects.all() + qs2 = Number.objects.all() + msg = "LIMIT/OFFSET not allowed in subqueries of compound statements" + with self.assertRaisesRegex(DatabaseError, msg): + list(qs1.union(qs2[:10])) + msg = "ORDER BY not allowed in subqueries of compound statements" + with self.assertRaisesRegex(DatabaseError, msg): + list(qs1.order_by("id").union(qs2)) + + def test_get_combinator_sql_all_union_sql_generated(self): + """ + Tries union sql generator. + """ + + qs1 = Number.objects.filter(num__lte=1).values("num") + qs2 = Number.objects.filter(num__gte=8).values("num") + qs4 = qs1.union(qs2) + + compiler = SQLCompiler(qs4.query, self.connection, "default") + sql_compiled, params = compiler.get_combinator_sql("union", True) + self.assertEqual( + sql_compiled, + [ + "SELECT tests_number.num FROM tests_number WHERE " + + "tests_number.num <= %s UNION ALL SELECT tests_number.num " + + "FROM tests_number WHERE tests_number.num >= %s" + ], + ) + self.assertEqual(params, [1, 8]) + + def test_get_combinator_sql_distinct_union_sql_generated(self): + """ + Tries union sql generator with distinct. + """ + + qs1 = Number.objects.filter(num__lte=1).values("num") + qs2 = Number.objects.filter(num__gte=8).values("num") + qs4 = qs1.union(qs2) + + compiler = SQLCompiler(qs4.query, self.connection, "default") + sql_compiled, params = compiler.get_combinator_sql("union", False) + self.assertEqual( + sql_compiled, + [ + "SELECT tests_number.num FROM tests_number WHERE " + + "tests_number.num <= %s UNION DISTINCT SELECT " + + "tests_number.num FROM tests_number WHERE " + + "tests_number.num >= %s" + ], + ) + self.assertEqual(params, [1, 8]) + + def test_get_combinator_sql_difference_all_sql_generated(self): + """ + Tries difference sql generator. + """ + qs1 = Number.objects.filter(num__lte=1).values("num") + qs2 = Number.objects.filter(num__gte=8).values("num") + qs4 = qs1.difference(qs2) + + compiler = SQLCompiler(qs4.query, self.connection, "default") + sql_compiled, params = compiler.get_combinator_sql("difference", True) + + self.assertEqual( + sql_compiled, + [ + "SELECT tests_number.num FROM tests_number WHERE " + + "tests_number.num <= %s EXCEPT ALL SELECT tests_number.num " + + "FROM tests_number WHERE tests_number.num >= %s" + ], + ) + self.assertEqual(params, [1, 8]) + + def test_get_combinator_sql_difference_distinct_sql_generated(self): + """ + Tries difference sql generator with distinct. + """ + qs1 = Number.objects.filter(num__lte=1).values("num") + qs2 = Number.objects.filter(num__gte=8).values("num") + qs4 = qs1.difference(qs2) + + compiler = SQLCompiler(qs4.query, self.connection, "default") + sql_compiled, params = compiler.get_combinator_sql("difference", False) + + self.assertEqual( + sql_compiled, + [ + "SELECT tests_number.num FROM tests_number WHERE " + + "tests_number.num <= %s EXCEPT DISTINCT SELECT " + + "tests_number.num FROM tests_number WHERE " + + "tests_number.num >= %s" + ], + ) + self.assertEqual(params, [1, 8]) + + def test_get_combinator_sql_union_and_difference_query_together(self): + """ + Tries sql generator with union of queryset with queryset of difference. + """ + qs1 = Number.objects.filter(num__lte=1).values("num") + qs2 = Number.objects.filter(num__gte=8).values("num") + qs3 = Number.objects.filter(num__exact=10).values("num") + qs4 = qs1.union(qs2.difference(qs3)) + + compiler = SQLCompiler(qs4.query, self.connection, "default") + sql_compiled, params = compiler.get_combinator_sql("union", False) + self.assertEqual( + sql_compiled, + [ + "SELECT tests_number.num FROM tests_number WHERE " + + "tests_number.num <= %s UNION DISTINCT SELECT * FROM (" + + "SELECT tests_number.num FROM tests_number WHERE " + + "tests_number.num >= %s EXCEPT DISTINCT " + + "SELECT tests_number.num FROM tests_number " + + "WHERE tests_number.num = %s)" + ], + ) + self.assertEqual(params, [1, 8, 10]) + + def test_get_combinator_sql_parentheses_in_compound_not_supported(self): + """ + Tries sql generator with union of queryset with queryset of difference, + adding support for parentheses in compound sql statement. + """ + + qs1 = Number.objects.filter(num__lte=1).values("num") + qs2 = Number.objects.filter(num__gte=8).values("num") + qs3 = Number.objects.filter(num__exact=10).values("num") + qs4 = qs1.union(qs2.difference(qs3)) + + compiler = SQLCompiler(qs4.query, self.connection, "default") + compiler.connection.features.supports_parentheses_in_compound = False + sql_compiled, params = compiler.get_combinator_sql("union", False) + self.assertEqual( + sql_compiled, + [ + "SELECT tests_number.num FROM tests_number WHERE " + + "tests_number.num <= %s UNION DISTINCT SELECT * FROM (" + + "SELECT tests_number.num FROM tests_number WHERE " + + "tests_number.num >= %s EXCEPT DISTINCT " + + "SELECT tests_number.num FROM tests_number " + + "WHERE tests_number.num = %s)" + ], + ) + self.assertEqual(params, [1, 8, 10]) + + def test_get_combinator_sql_empty_queryset_raises_exception(self): + """ + Tries sql generator with empty queryset. + """ + compiler = SQLCompiler(QuerySet().query, self.connection, "default") + with self.assertRaises(EmptyResultSet): + compiler.get_combinator_sql("union", False) diff --git a/tests/unit/django_spanner/test_expressions.py b/tests/unit/django_spanner/test_expressions.py new file mode 100644 index 0000000000..0efc99ce08 --- /dev/null +++ b/tests/unit/django_spanner/test_expressions.py @@ -0,0 +1,46 @@ +# Copyright 2021 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +from django_spanner.compiler import SQLCompiler +from django.db.models import F +from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass +from .models import Report + + +class TestExpressions(SpannerSimpleTestClass): + def test_order_by_sql_query_with_order_by_null_last(self): + qs1 = Report.objects.values("name").order_by( + F("name").desc(nulls_last=True) + ) + compiler = SQLCompiler(qs1.query, self.connection, "default") + sql_compiled, _ = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_report.name FROM tests_report ORDER BY " + + "tests_report.name IS NULL, tests_report.name DESC", + ) + + def test_order_by_sql_query_with_order_by_null_first(self): + qs1 = Report.objects.values("name").order_by( + F("name").desc(nulls_first=True) + ) + compiler = SQLCompiler(qs1.query, self.connection, "default") + sql_compiled, _ = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_report.name FROM tests_report ORDER BY " + + "tests_report.name IS NOT NULL, tests_report.name DESC", + ) + + def test_order_by_sql_query_with_order_by_name(self): + qs1 = Report.objects.values("name") + compiler = SQLCompiler(qs1.query, self.connection, "default") + sql_compiled, _ = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_report.name FROM tests_report ORDER BY " + + "tests_report.name ASC", + ) diff --git a/tests/unit/django_spanner/test_lookups.py b/tests/unit/django_spanner/test_lookups.py new file mode 100644 index 0000000000..90b17e5515 --- /dev/null +++ b/tests/unit/django_spanner/test_lookups.py @@ -0,0 +1,282 @@ +# Copyright 2021 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +from django_spanner.compiler import SQLCompiler +from django.db.models import F +from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass +from .models import Number, Author + + +class TestLookups(SpannerSimpleTestClass): + def test_cast_param_to_float_lte_sql_query(self): + + qs1 = Number.objects.filter(decimal_num__lte=1.1).values("decimal_num") + compiler = SQLCompiler(qs1.query, self.connection, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_number.decimal_num FROM tests_number WHERE " + + "tests_number.decimal_num <= %s", + ) + self.assertEqual(params, (1.1,)) + + def test_cast_param_to_float_for_int_field_query(self): + + qs1 = Number.objects.filter(num__lte=1.1).values("num") + + compiler = SQLCompiler(qs1.query, self.connection, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_number.num FROM tests_number WHERE " + + "tests_number.num <= %s", + ) + self.assertEqual(params, (1,)) + + def test_cast_param_to_float_for_foreign_key_field_query(self): + + qs1 = Number.objects.filter(item_id__exact="10").values("num") + compiler = SQLCompiler(qs1.query, self.connection, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_number.num FROM tests_number WHERE " + + "tests_number.item_id = %s", + ) + self.assertEqual(params, (10,)) + + def test_cast_param_to_float_with_no_params_query(self): + + qs1 = Number.objects.filter(item_id__exact=F("num")).values("num") + compiler = SQLCompiler(qs1.query, self.connection, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_number.num FROM tests_number WHERE " + + "tests_number.item_id = (tests_number.num)", + ) + self.assertEqual(params, ()) + + def test_startswith_endswith_sql_query_with_startswith(self): + + qs1 = Author.objects.filter(name__startswith="abc").values("num") + compiler = SQLCompiler(qs1.query, self.connection, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_author.num FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(tests_author.name AS STRING), %s)", + ) + self.assertEqual(params, ("^abc",)) + + def test_startswith_endswith_sql_query_with_endswith(self): + + qs1 = Author.objects.filter(name__endswith="abc").values("num") + compiler = SQLCompiler(qs1.query, self.connection, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_author.num FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(tests_author.name AS STRING), %s)", + ) + self.assertEqual(params, ("abc$",)) + + def test_startswith_endswith_sql_query_case_insensitive(self): + + qs1 = Author.objects.filter(name__istartswith="abc").values("num") + compiler = SQLCompiler(qs1.query, self.connection, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_author.num FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(tests_author.name AS STRING), %s)", + ) + self.assertEqual(params, ("(?i)^abc",)) + + def test_startswith_endswith_sql_query_with_bileteral_transform(self): + + qs1 = Author.objects.filter(name__upper__startswith="abc").values( + "name" + ) + + compiler = SQLCompiler(qs1.query, self.connection, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_author.name FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(UPPER(tests_author.name) AS STRING), " + + "REPLACE(REPLACE(REPLACE(CONCAT('^', (UPPER(%s))), " + + '"\\\\", "\\\\\\\\"), "%%", r"\\%%"), "_", r"\\_"))', + ) + self.assertEqual(params, ("abc",)) + + def test_startswith_endswith_case_insensitive_transform_sql_query(self): + + qs1 = Author.objects.filter(name__upper__istartswith="abc").values( + "name" + ) + + compiler = SQLCompiler(qs1.query, self.connection, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_author.name FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(UPPER(tests_author.name) AS STRING), " + + "REPLACE(REPLACE(REPLACE(CONCAT('^(?i)', (UPPER(%s))), " + + '"\\\\", "\\\\\\\\"), "%%", r"\\%%"), "_", r"\\_"))', + ) + self.assertEqual(params, ("abc",)) + + def test_startswith_endswith_endswith_sql_query_with_transform(self): + + qs1 = Author.objects.filter(name__upper__endswith="abc").values("name") + + compiler = SQLCompiler(qs1.query, self.connection, "default") + sql_compiled, params = compiler.as_sql() + + self.assertEqual( + sql_compiled, + "SELECT tests_author.name FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(UPPER(tests_author.name) AS STRING), " + + "REPLACE(REPLACE(REPLACE(CONCAT('', (UPPER(%s)), '$'), " + + '"\\\\", "\\\\\\\\"), "%%", r"\\%%"), "_", r"\\_"))', + ) + self.assertEqual(params, ("abc",)) + + def test_regex_sql_query_case_sensitive(self): + + qs1 = Author.objects.filter(name__regex="abc").values("num") + compiler = SQLCompiler(qs1.query, self.connection, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_author.num FROM tests_author WHERE " + "REGEXP_CONTAINS(CAST(tests_author.name AS STRING), %s)", + ) + self.assertEqual(params, ("abc",)) + + def test_regex_sql_query_case_insensitive(self): + + qs1 = Author.objects.filter(name__iregex="abc").values("num") + compiler = SQLCompiler(qs1.query, self.connection, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_author.num FROM tests_author WHERE " + "REGEXP_CONTAINS(CAST(tests_author.name AS STRING), %s)", + ) + self.assertEqual(params, ("(?i)abc",)) + + def test_regex_sql_query_case_sensitive_with_transform(self): + + qs1 = Author.objects.filter(name__upper__regex="abc").values("num") + compiler = SQLCompiler(qs1.query, self.connection, "default") + sql_compiled, params = compiler.as_sql() + + self.assertEqual( + sql_compiled, + "SELECT tests_author.num FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(UPPER(tests_author.name) AS STRING), " + + "(UPPER(%s)))", + ) + self.assertEqual(params, ("abc",)) + + def test_regex_sql_query_case_insensitive_with_transform(self): + + qs1 = Author.objects.filter(name__upper__iregex="abc").values("num") + compiler = SQLCompiler(qs1.query, self.connection, "default") + sql_compiled, params = compiler.as_sql() + + self.assertEqual( + sql_compiled, + "SELECT tests_author.num FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(UPPER(tests_author.name) AS STRING), " + + "CONCAT('(?i)', (UPPER(%s))))", + ) + self.assertEqual(params, ("abc",)) + + def test_contains_sql_query_case_insensitive(self): + + qs1 = Author.objects.filter(name__icontains="abc").values("num") + compiler = SQLCompiler(qs1.query, self.connection, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_author.num FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(tests_author.name AS STRING), %s)", + ) + self.assertEqual(params, ("(?i)abc",)) + + def test_contains_sql_query_case_sensitive(self): + + qs1 = Author.objects.filter(name__contains="abc").values("num") + compiler = SQLCompiler(qs1.query, self.connection, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_author.num FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(tests_author.name AS STRING), %s)", + ) + self.assertEqual(params, ("abc",)) + + def test_contains_sql_query_case_insensitive_transform(self): + + qs1 = Author.objects.filter(name__upper__icontains="abc").values( + "name" + ) + compiler = SQLCompiler(qs1.query, self.connection, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_author.name FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(UPPER(tests_author.name) AS STRING), " + + "REPLACE(REPLACE(REPLACE(CONCAT('(?i)', (UPPER(%s))), " + + '"\\\\", "\\\\\\\\"), "%%", r"\\%%"), "_", r"\\_"))', + ) + self.assertEqual(params, ("abc",)) + + def test_contains_sql_query_case_sensitive_transform(self): + + qs1 = Author.objects.filter(name__upper__contains="abc").values("name") + compiler = SQLCompiler(qs1.query, self.connection, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_author.name FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(UPPER(tests_author.name) AS STRING), " + + 'REPLACE(REPLACE(REPLACE((UPPER(%s)), "\\\\", "\\\\\\\\"), ' + + '"%%", r"\\%%"), "_", r"\\_"))', + ) + self.assertEqual(params, ("abc",)) + + def test_iexact_sql_query_case_insensitive(self): + + qs1 = Author.objects.filter(name__iexact="abc").values("num") + compiler = SQLCompiler(qs1.query, self.connection, "default") + sql_compiled, params = compiler.as_sql() + + self.assertEqual( + sql_compiled, + "SELECT tests_author.num FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(tests_author.name AS STRING), %s)", + ) + self.assertEqual(params, ("^(?i)abc$",)) + + def test_iexact_sql_query_case_insensitive_function_transform(self): + + qs1 = Author.objects.filter(name__upper__iexact=F("last_name")).values( + "name" + ) + compiler = SQLCompiler(qs1.query, self.connection, "default") + sql_compiled, params = compiler.as_sql() + + self.assertEqual( + sql_compiled, + "SELECT tests_author.name FROM tests_author WHERE " + + "REGEXP_CONTAINS((UPPER(tests_author.last_name)), " + + "CONCAT('^(?i)', CAST(UPPER(tests_author.name) AS STRING), '$'))", + ) + self.assertEqual(params, ()) diff --git a/tests/unit/django_spanner/test_operations.py b/tests/unit/django_spanner/test_operations.py new file mode 100644 index 0000000000..ae6384233a --- /dev/null +++ b/tests/unit/django_spanner/test_operations.py @@ -0,0 +1,272 @@ +# Copyright 2021 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +from django.db.utils import DatabaseError +from datetime import timedelta +from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass + + +class TestOperations(SpannerSimpleTestClass): + def test_max_name_length(self): + self.assertEqual(self.db_operations.max_name_length(), 128) + + def test_quote_name(self): + quoted_name = self.db_operations.quote_name("abc") + self.assertEqual(quoted_name, "abc") + + def test_quote_name_spanner_reserved_keyword_escaped(self): + quoted_name = self.db_operations.quote_name("ALL") + self.assertEqual(quoted_name, "`ALL`") + + def test_bulk_batch_size(self): + self.assertEqual( + self.db_operations.bulk_batch_size(fields=None, objs=None), + self.db_operations.connection.features.max_query_params, + ) + + def test_sql_flush(self): + from django.core.management.color import no_style + + self.assertEqual( + self.db_operations.sql_flush( + style=no_style(), tables=["Table1, Table2"] + ), + ["DELETE FROM `Table1, Table2`"], + ) + + def test_sql_flush_empty_table_list(self): + from django.core.management.color import no_style + + self.assertEqual( + self.db_operations.sql_flush(style=no_style(), tables=[]), [], + ) + + def test_adapt_datefield_value(self): + from google.cloud.spanner_dbapi.types import DateStr + + self.assertIsInstance( + self.db_operations.adapt_datefield_value("dummy_date"), DateStr, + ) + + def test_adapt_datefield_value_none(self): + self.assertIsNone( + self.db_operations.adapt_datefield_value(value=None), + ) + + def test_adapt_decimalfield_value(self): + self.assertIsInstance( + self.db_operations.adapt_decimalfield_value(value=1), float, + ) + + def test_adapt_decimalfield_value_none(self): + self.assertIsNone( + self.db_operations.adapt_decimalfield_value(value=None), + ) + + def test_convert_binaryfield_value(self): + from base64 import b64encode + + self.assertEqual( + self.db_operations.convert_binaryfield_value( + value=b64encode(b"abc"), expression=None, connection=None + ), + b"abc", + ) + + def test_convert_binaryfield_value_none(self): + self.assertIsNone( + self.db_operations.convert_binaryfield_value( + value=None, expression=None, connection=None + ), + ) + + def test_adapt_datetimefield_value_none(self): + self.assertIsNone( + self.db_operations.adapt_datetimefield_value(value=None), + ) + + def test_adapt_timefield_value_none(self): + self.assertIsNone( + self.db_operations.adapt_timefield_value(value=None), + ) + + def test_convert_decimalfield_value(self): + from decimal import Decimal + + self.assertIsInstance( + self.db_operations.convert_decimalfield_value( + value=1.0, expression=None, connection=None + ), + Decimal, + ) + + def test_convert_decimalfield_value_none(self): + self.assertIsNone( + self.db_operations.convert_decimalfield_value( + value=None, expression=None, connection=None + ), + ) + + def test_convert_uuidfield_value(self): + import uuid + + uuid_obj = uuid.uuid4() + self.assertEqual( + self.db_operations.convert_uuidfield_value( + str(uuid_obj), expression=None, connection=None + ), + uuid_obj, + ) + + def test_convert_uuidfield_value_none(self): + self.assertIsNone( + self.db_operations.convert_uuidfield_value( + value=None, expression=None, connection=None + ), + ) + + def test_date_extract_sql(self): + self.assertEqual( + self.db_operations.date_extract_sql("week", "dummy_field"), + "EXTRACT(isoweek FROM dummy_field)", + ) + + def test_date_extract_sql_lookup_type_dayofweek(self): + self.assertEqual( + self.db_operations.date_extract_sql("dayofweek", "dummy_field"), + "EXTRACT(dayofweek FROM dummy_field)", + ) + + def test_datetime_extract_sql(self): + from django.conf import settings + + settings.USE_TZ = True + self.assertEqual( + self.db_operations.datetime_extract_sql( + "dayofweek", "dummy_field", "IST" + ), + 'EXTRACT(dayofweek FROM dummy_field AT TIME ZONE "IST")', + ) + + def test_datetime_extract_sql_use_tz_false(self): + from django.conf import settings + + settings.USE_TZ = False + self.assertEqual( + self.db_operations.datetime_extract_sql( + "dayofweek", "dummy_field", "IST" + ), + 'EXTRACT(dayofweek FROM dummy_field AT TIME ZONE "UTC")', + ) + settings.USE_TZ = True # reset changes. + + def test_time_extract_sql(self): + self.assertEqual( + self.db_operations.time_extract_sql("dayofweek", "dummy_field"), + 'EXTRACT(dayofweek FROM dummy_field AT TIME ZONE "UTC")', + ) + + def test_time_trunc_sql(self): + self.assertEqual( + self.db_operations.time_trunc_sql("dayofweek", "dummy_field"), + 'TIMESTAMP_TRUNC(dummy_field, dayofweek, "UTC")', + ) + + def test_datetime_cast_date_sql(self): + self.assertEqual( + self.db_operations.datetime_cast_date_sql("dummy_field", "IST"), + 'DATE(dummy_field, "IST")', + ) + + def test_datetime_cast_time_sql(self): + from django.conf import settings + + settings.USE_TZ = True + self.assertEqual( + self.db_operations.datetime_cast_time_sql("dummy_field", "IST"), + "TIMESTAMP(FORMAT_TIMESTAMP('%Y-%m-%d %R:%E9S %Z', dummy_field, 'IST'))", + ) + + def test_datetime_cast_time_sql_use_tz_false(self): + from django.conf import settings + + settings.USE_TZ = False + self.assertEqual( + self.db_operations.datetime_cast_time_sql("dummy_field", "IST"), + "TIMESTAMP(FORMAT_TIMESTAMP('%Y-%m-%d %R:%E9S %Z', dummy_field, 'UTC'))", + ) + settings.USE_TZ = True # reset changes. + + def test_date_interval_sql(self): + self.assertEqual( + self.db_operations.date_interval_sql(timedelta(days=1)), + "INTERVAL 86400000000 MICROSECOND", + ) + + def test_format_for_duration_arithmetic(self): + self.assertEqual( + self.db_operations.format_for_duration_arithmetic(1200), + "INTERVAL 1200 MICROSECOND", + ) + + def test_combine_expression_mod(self): + self.assertEqual( + self.db_operations.combine_expression("%%", ["10", "2"]), + "MOD(10, 2)", + ) + + def test_combine_expression_power(self): + self.assertEqual( + self.db_operations.combine_expression("^", ["10", "2"]), + "POWER(10, 2)", + ) + + def test_combine_expression_bit_extention(self): + self.assertEqual( + self.db_operations.combine_expression(">>", ["10", "2"]), + "CAST(FLOOR(10 / POW(2, 2)) AS INT64)", + ) + + def test_combine_expression_multiply(self): + self.assertEqual( + self.db_operations.combine_expression("*", ["10", "2"]), "10 * 2", + ) + + def test_combine_duration_expression_add(self): + self.assertEqual( + self.db_operations.combine_duration_expression( + "+", + ['TIMESTAMP "2008-12-25 15:30:00+00', "INTERVAL 10 MINUTE"], + ), + 'TIMESTAMP_ADD(TIMESTAMP "2008-12-25 15:30:00+00, INTERVAL 10 MINUTE)', + ) + + def test_combine_duration_expression_subtract(self): + self.assertEqual( + self.db_operations.combine_duration_expression( + "-", + ['TIMESTAMP "2008-12-25 15:30:00+00', "INTERVAL 10 MINUTE"], + ), + 'TIMESTAMP_SUB(TIMESTAMP "2008-12-25 15:30:00+00, INTERVAL 10 MINUTE)', + ) + + def test_combine_duration_expression_database_error(self): + msg = "Invalid connector for timedelta:" + with self.assertRaisesRegex(DatabaseError, msg): + self.db_operations.combine_duration_expression( + "*", + ['TIMESTAMP "2008-12-25 15:30:00+00', "INTERVAL 10 MINUTE"], + ) + + def test_lookup_cast_match_lookup_type(self): + self.assertEqual( + self.db_operations.lookup_cast("contains",), "CAST(%s AS STRING)", + ) + + def test_lookup_cast_unmatched_lookup_type(self): + self.assertEqual( + self.db_operations.lookup_cast("dummy",), "%s", + ) diff --git a/tests/unit/django_spanner/test_utils.py b/tests/unit/django_spanner/test_utils.py new file mode 100644 index 0000000000..e4d50861d0 --- /dev/null +++ b/tests/unit/django_spanner/test_utils.py @@ -0,0 +1,50 @@ +# Copyright 2021 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +from django_spanner.utils import check_django_compatability +from django.core.exceptions import ImproperlyConfigured +from django_spanner.utils import add_dummy_where +import django +import django_spanner +from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass + + +class TestUtils(SpannerSimpleTestClass): + SQL_WITH_WHERE = "Select 1 from Table WHERE 1=1" + SQL_WITHOUT_WHERE = "Select 1 from Table" + + def test_check_django_compatability_match(self): + """ + Checks django compatibility match. + """ + django_spanner.__version__ = "2.2" + django.VERSION = (2, 2, 19, "alpha", 0) + check_django_compatability() + + def test_check_django_compatability_mismatch(self): + """ + Checks django compatibility mismatch. + """ + django_spanner.__version__ = "2.2" + django.VERSION = (3, 2, 19, "alpha", 0) + with self.assertRaises(ImproperlyConfigured): + check_django_compatability() + + def test_add_dummy_where_with_where_present_and_not_added(self): + """ + Checks if dummy where clause is not added when present in select + statement. + """ + updated_sql = add_dummy_where(self.SQL_WITH_WHERE) + self.assertEqual(updated_sql, self.SQL_WITH_WHERE) + + def test_add_dummy_where_with_where_not_present_and_added(self): + """ + Checks if dummy where clause is added when not present in select + statement. + """ + updated_sql = add_dummy_where(self.SQL_WITHOUT_WHERE) + self.assertEqual(updated_sql, self.SQL_WITH_WHERE) diff --git a/tests/unit/django_spanner/test_validation.py b/tests/unit/django_spanner/test_validation.py new file mode 100644 index 0000000000..5a8946aef1 --- /dev/null +++ b/tests/unit/django_spanner/test_validation.py @@ -0,0 +1,41 @@ +# Copyright 2021 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +from django_spanner.validation import DatabaseValidation +from django.db import connection +from django.core.checks import Error as DjangoError +from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass +from .models import ModelDecimalField, ModelCharField + + +class TestValidation(SpannerSimpleTestClass): + def test_check_field_type_with_decimal_field_not_support_error(self): + """ + Checks if decimal field fails database validation as it's not + supported in spanner. + """ + field = ModelDecimalField._meta.get_field("field") + validator = DatabaseValidation(connection=connection) + self.assertEqual( + validator.check_field(field), + [ + DjangoError( + "DecimalField is not yet supported by Spanner.", + obj=field, + id="spanner.E001", + ) + ], + ) + + def test_check_field_type_with_char_field_no_error(self): + """ + Checks if string field passes database validation. + """ + field = ModelCharField._meta.get_field("field") + validator = DatabaseValidation(connection=connection) + self.assertEqual( + validator.check_field(field), [], + )