From df919bedb48b6ee4c1942384d614abe9ea19e4df Mon Sep 17 00:00:00 2001 From: Vikash Singh <3116482+vi3k6i5@users.noreply.github.com> Date: Tue, 6 Apr 2021 15:41:48 +0530 Subject: [PATCH 1/3] fix: lint_setup_py was failing in Kokoro is not fixed --- README.rst | 3 +- code-of-conduct.md | 63 --------------------------------- django_spanner/functions.py | 1 + django_spanner/introspection.py | 1 + django_spanner/operations.py | 5 ++- django_spanner/schema.py | 1 + docs/conf.py | 3 -- noxfile.py | 20 ++++++----- 8 files changed, 20 insertions(+), 77 deletions(-) delete mode 100644 code-of-conduct.md diff --git a/README.rst b/README.rst index 91be439d9e..8f1ef0440d 100644 --- a/README.rst +++ b/README.rst @@ -134,8 +134,7 @@ Contributing Contributions to this library are always welcome and highly encouraged. -See `CONTRIBUTING `_ for more information on how to get -started. +See [CONTRIBUTING][contributing] for more information on how to get started. Please note that this project is released with a Contributor Code of Conduct. By participating in this project you agree to abide by its terms. See the `Code diff --git a/code-of-conduct.md b/code-of-conduct.md deleted file mode 100644 index b24eed38ad..0000000000 --- a/code-of-conduct.md +++ /dev/null @@ -1,63 +0,0 @@ -# Google Open Source Community Guidelines - -At Google, we recognize and celebrate the creativity and collaboration of open -source contributors and the diversity of skills, experiences, cultures, and -opinions they bring to the projects and communities they participate in. - -Every one of Google's open source projects and communities are inclusive -environments, based on treating all individuals respectfully, regardless of -gender identity and expression, sexual orientation, disabilities, -neurodiversity, physical appearance, body size, ethnicity, nationality, race, -age, religion, or similar personal characteristic. - -We value diverse opinions, but we value respectful behavior more. - -Respectful behavior includes: - -* Being considerate, kind, constructive, and helpful. -* Not engaging in demeaning, discriminatory, harassing, hateful, sexualized, or - physically threatening behavior, speech, and imagery. -* Not engaging in unwanted physical contact. - -Some Google open source projects [may adopt][] an explicit project code of -conduct, which may have additional detailed expectations for participants. Most -of those projects will use our [modified Contributor Covenant][]. - -[may adopt]: https://opensource.google/docs/releasing/preparing/#conduct -[modified Contributor Covenant]: https://opensource.google/docs/releasing/template/CODE_OF_CONDUCT/ - -## Resolve peacefully - -We do not believe that all conflict is necessarily bad; healthy debate and -disagreement often yields positive results. However, it is never okay to be -disrespectful. - -If you see someone behaving disrespectfully, you are encouraged to address the -behavior directly with those involved. Many issues can be resolved quickly and -easily, and this gives people more control over the outcome of their dispute. -If you are unable to resolve the matter for any reason, or if the behavior is -threatening or harassing, report it. We are dedicated to providing an -environment where participants feel welcome and safe. - -## Reporting problems - -Some Google open source projects may adopt a project-specific code of conduct. -In those cases, a Google employee will be identified as the Project Steward, -who will receive and handle reports of code of conduct violations. In the event -that a project hasn’t identified a Project Steward, you can report problems by -emailing opensource@google.com. - -We will investigate every complaint, but you may not receive a direct response. -We will use our discretion in determining when and how to follow up on reported -incidents, which may range from not taking action to permanent expulsion from -the project and project-sponsored spaces. We will notify the accused of the -report and provide them an opportunity to discuss it before any action is -taken. The identity of the reporter will be omitted from the details of the -report supplied to the accused. In potentially harmful situations, such as -ongoing harassment or threats to anyone's safety, we may take action without -notice. - -*This document was adapted from the [IndieWeb Code of Conduct][] and can also -be found at .* - -[IndieWeb Code of Conduct]: https://indieweb.org/code-of-conduct \ No newline at end of file diff --git a/django_spanner/functions.py b/django_spanner/functions.py index bc02d0b5d8..3cf3ec73b9 100644 --- a/django_spanner/functions.py +++ b/django_spanner/functions.py @@ -28,6 +28,7 @@ class IfNull(Func): """Represent SQL `IFNULL` function.""" + function = "IFNULL" arity = 2 diff --git a/django_spanner/introspection.py b/django_spanner/introspection.py index 2dd7341972..9cefd0687f 100644 --- a/django_spanner/introspection.py +++ b/django_spanner/introspection.py @@ -15,6 +15,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): """A Spanner-specific version of Django introspection utilities.""" + data_types_reverse = { TypeCode.BOOL: "BooleanField", TypeCode.BYTES: "BinaryField", diff --git a/django_spanner/operations.py b/django_spanner/operations.py index 6ce0260c81..e3ff7471ec 100644 --- a/django_spanner/operations.py +++ b/django_spanner/operations.py @@ -25,6 +25,7 @@ class DatabaseOperations(BaseDatabaseOperations): """A Spanner-specific version of Django database operations.""" + cast_data_types = {"CharField": "STRING", "TextField": "STRING"} cast_char_field_without_max_length = "STRING" compiler_module = "django_spanner.compiler" @@ -108,7 +109,9 @@ def bulk_insert_sql(self, fields, placeholder_rows): values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql) return "VALUES " + values_sql - def sql_flush(self, style, tables, reset_sequences=False, allow_cascade=False): + def sql_flush( + self, style, tables, reset_sequences=False, allow_cascade=False + ): """ Override the base class method. Returns a list of SQL statements required to remove all data from the given database tables (without diff --git a/django_spanner/schema.py b/django_spanner/schema.py index b6c859c466..6d71f31673 100644 --- a/django_spanner/schema.py +++ b/django_spanner/schema.py @@ -13,6 +13,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): The database abstraction layer that turns things like “create a model” or “delete a field” into SQL. """ + sql_create_table = ( "CREATE TABLE %(table)s (%(definition)s) PRIMARY KEY(%(primary_key)s)" ) diff --git a/docs/conf.py b/docs/conf.py index d26c0698e6..1cffc0625d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -100,9 +100,6 @@ # directories to ignore when looking for source files. exclude_patterns = [ "_build", - "samples/AUTHORING_GUIDE.md", - "samples/CONTRIBUTING.md", - "samples/snippets/README.rst", ] # The reST default role (used for this markup: `text`) to use for all diff --git a/noxfile.py b/noxfile.py index 2c1edbe573..7bea0b8dda 100644 --- a/noxfile.py +++ b/noxfile.py @@ -17,13 +17,18 @@ BLACK_VERSION = "black==19.10b0" BLACK_PATHS = [ "docs", + "django_spanner", "tests", "noxfile.py", "setup.py", ] +DEFAULT_PYTHON_VERSION = "3.8" +SYSTEM_TEST_PYTHON_VERSIONS = ["3.8"] +UNIT_TEST_PYTHON_VERSIONS = ["3.6", "3.7", "3.8"] -@nox.session(python="3.8") + +@nox.session(python=DEFAULT_PYTHON_VERSION) def lint(session): """Run linters. @@ -35,7 +40,7 @@ def lint(session): session.run("flake8", "django_spanner", "tests") -@nox.session(python="3.8") +@nox.session(python="3.6") def blacken(session): """Run black. @@ -49,7 +54,7 @@ def blacken(session): session.run("black", *BLACK_PATHS) -@nox.session(python="3.8") +@nox.session(python=DEFAULT_PYTHON_VERSION) def lint_setup_py(session): """Verify that setup.py is valid (including RST check).""" session.install("docutils", "pygments") @@ -70,23 +75,22 @@ def default(session): "py.test", "--quiet", "--cov=django_spanner", - "--cov=google.cloud", "--cov=tests.unit", "--cov-append", "--cov-config=.coveragerc", "--cov-report=", - "--cov-fail-under=60", + "--cov-fail-under=20", os.path.join("tests", "unit"), *session.posargs ) -@nox.session(python="3.8") +@nox.session(python=DEFAULT_PYTHON_VERSION) def docs(session): """Build the docs for this library.""" - session.install("-e", ".") - session.install("sphinx<3.0.0", "alabaster", "recommonmark") + session.install("-e", ".[tracing]") + session.install("sphinx", "alabaster", "recommonmark") shutil.rmtree(os.path.join("docs", "_build"), ignore_errors=True) session.run( From 57e4be458ebd1b125218b3603822d2544b02813c Mon Sep 17 00:00:00 2001 From: Vikash Singh <3116482+vi3k6i5@users.noreply.github.com> Date: Thu, 22 Apr 2021 22:06:59 +0530 Subject: [PATCH 2/3] feat: added test cases --- noxfile.py | 9 +- tests/settings.py | 46 +++ tests/unit/django_spanner/__init__.py | 0 tests/unit/django_spanner/models.py | 67 ++++ tests/unit/django_spanner/test_base.py | 10 +- tests/unit/django_spanner/test_client.py | 7 +- tests/unit/django_spanner/test_compiler.py | 193 +++++++++++ tests/unit/django_spanner/test_expressions.py | 62 ++++ tests/unit/django_spanner/test_lookups.py | 312 ++++++++++++++++++ tests/unit/django_spanner/test_operations.py | 305 +++++++++++++++++ tests/unit/django_spanner/test_utils.py | 50 +++ tests/unit/django_spanner/test_validation.py | 41 +++ 12 files changed, 1087 insertions(+), 15 deletions(-) create mode 100644 tests/settings.py create mode 100644 tests/unit/django_spanner/__init__.py create mode 100644 tests/unit/django_spanner/models.py create mode 100644 tests/unit/django_spanner/test_compiler.py create mode 100644 tests/unit/django_spanner/test_expressions.py create mode 100644 tests/unit/django_spanner/test_lookups.py create mode 100644 tests/unit/django_spanner/test_operations.py create mode 100644 tests/unit/django_spanner/test_utils.py create mode 100644 tests/unit/django_spanner/test_validation.py diff --git a/noxfile.py b/noxfile.py index a19bbc4360..36b68cfc81 100644 --- a/noxfile.py +++ b/noxfile.py @@ -66,7 +66,12 @@ def lint_setup_py(session): def default(session): # Install all test dependencies, then install this package in-place. session.install( - "django~=2.2", "mock", "mock-import", "pytest", "pytest-cov" + "django~=2.2", + "mock", + "mock-import", + "pytest", + "pytest-cov", + "coverage", ) session.install("-e", ".") @@ -79,7 +84,7 @@ def default(session): "--cov-append", "--cov-config=.coveragerc", "--cov-report=", - "--cov-fail-under=20", + "--cov-fail-under=70", os.path.join("tests", "unit"), *session.posargs ) diff --git a/tests/settings.py b/tests/settings.py new file mode 100644 index 0000000000..ecf6567d46 --- /dev/null +++ b/tests/settings.py @@ -0,0 +1,46 @@ +# Copyright 2021 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +DEBUG = True +USE_TZ = True + +INSTALLED_APPS = [ + "django_spanner", # Must be the first entry + "django.contrib.contenttypes", + "django.contrib.auth", + "django.contrib.sites", + "django.contrib.sessions", + "django.contrib.messages", + "django.contrib.staticfiles", + "tests", +] + +TIME_ZONE = "UTC" + +DATABASES = { + "default": { + "ENGINE": "django_spanner", + "PROJECT": "emulator-local", + "INSTANCE": "django-test-instance", + "NAME": "django-test-db", + } +} +SECRET_KEY = "spanner emulator secret key" + +PASSWORD_HASHERS = [ + "django.contrib.auth.hashers.MD5PasswordHasher", +] + +SITE_ID = 1 + +CONN_MAX_AGE = 60 + +ENGINE = "django_spanner" +PROJECT = "emulator-local" +INSTANCE = "django-test-instance" +NAME = "django-test-db" +OPTIONS = {} +AUTOCOMMIT = True diff --git a/tests/unit/django_spanner/__init__.py b/tests/unit/django_spanner/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/django_spanner/models.py b/tests/unit/django_spanner/models.py new file mode 100644 index 0000000000..d856051937 --- /dev/null +++ b/tests/unit/django_spanner/models.py @@ -0,0 +1,67 @@ +# Copyright 2021 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd +""" +Different models used for testing django-spanner code. +""" +import os +from django.db import models +import django + +# Load django settings before loading django models. +os.environ["DJANGO_SETTINGS_MODULE"] = "tests.settings" +django.setup() + + +# Register transformations for model fields. +class UpperCase(models.Transform): + lookup_name = "upper" + function = "UPPER" + bilateral = True + + +models.CharField.register_lookup(UpperCase) +models.TextField.register_lookup(UpperCase) + + +# Models +class ModelDecimalField(models.Model): + field = models.DecimalField() + + +class ModelCharField(models.Model): + field = models.CharField() + + +class Item(models.Model): + item_id = models.IntegerField() + name = models.CharField(max_length=10) + created = models.DateTimeField() + modified = models.DateTimeField(blank=True, null=True) + + class Meta: + ordering = ["name"] + + +class Number(models.Model): + num = models.IntegerField() + decimal_num = models.DecimalField(max_digits=5, decimal_places=2) + item = models.ForeignKey(Item, models.CASCADE) + + +class Author(models.Model): + name = models.CharField(max_length=40) + last_name = models.CharField(max_length=40) + num = models.IntegerField(unique=True) + created = models.DateTimeField() + modified = models.DateTimeField(blank=True, null=True) + + +class Report(models.Model): + name = models.CharField(max_length=10) + creator = models.ForeignKey(Author, models.CASCADE, null=True) + + class Meta: + ordering = ["name"] diff --git a/tests/unit/django_spanner/test_base.py b/tests/unit/django_spanner/test_base.py index 32d965b9d1..9df06afd44 100644 --- a/tests/unit/django_spanner/test_base.py +++ b/tests/unit/django_spanner/test_base.py @@ -4,7 +4,6 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -import sys import unittest import os @@ -13,9 +12,6 @@ @mock_import() -@unittest.skipIf( - sys.version_info < (3, 6), reason="Skipping Python versions <= 3.5" -) class TestBase(unittest.TestCase): PROJECT = os.environ["GOOGLE_CLOUD_PROJECT"] INSTANCE_ID = "instance_id" @@ -49,7 +45,7 @@ def test_property_instance(self): _ = db_wrapper.instance mock_instance.assert_called_once_with(settings_dict["INSTANCE"]) - def test_property__nodb_connection(self): + def test_property_nodb_connection(self): db_wrapper = self._make_one(None) with self.assertRaises(NotImplementedError): db_wrapper._nodb_connection() @@ -86,7 +82,7 @@ def test_create_cursor(self): db_wrapper.create_cursor() mock_cursor.assert_called_once_with() - def test__set_autocommit(self): + def test_set_autocommit(self): db_wrapper = self._make_one(self.settings_dict) db_wrapper.connection = mock_connection = mock.MagicMock() mock_connection.autocommit = False @@ -110,7 +106,7 @@ def test_is_usable(self): mock_connection.cursor = mock.MagicMock(side_effect=Error) self.assertFalse(db_wrapper.is_usable()) - def test__start_transaction_under_autocommit(self): + def test_start_transaction_under_autocommit(self): db_wrapper = self._make_one(self.settings_dict) db_wrapper.connection = mock_connection = mock.MagicMock() mock_connection.cursor = mock_cursor = mock.MagicMock() diff --git a/tests/unit/django_spanner/test_client.py b/tests/unit/django_spanner/test_client.py index fd02434b04..c892bd2993 100644 --- a/tests/unit/django_spanner/test_client.py +++ b/tests/unit/django_spanner/test_client.py @@ -4,14 +4,11 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -import sys import unittest import os +from google.cloud.spanner_dbapi.exceptions import NotSupportedError -@unittest.skipIf( - sys.version_info < (3, 6), reason="Skipping Python versions <= 3.5" -) class TestClient(unittest.TestCase): PROJECT = os.environ["GOOGLE_CLOUD_PROJECT"] INSTANCE_ID = "instance_id" @@ -36,8 +33,6 @@ def _make_one(self, *args, **kwargs): return self._get_target_class()(*args, **kwargs) def test_runshell(self): - from google.cloud.spanner_dbapi.exceptions import NotSupportedError - db_wrapper = self._make_one(self.settings_dict) with self.assertRaises(NotSupportedError): diff --git a/tests/unit/django_spanner/test_compiler.py b/tests/unit/django_spanner/test_compiler.py new file mode 100644 index 0000000000..7b18124976 --- /dev/null +++ b/tests/unit/django_spanner/test_compiler.py @@ -0,0 +1,193 @@ +# Copyright 2021 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +from django.test import SimpleTestCase +from django.core.exceptions import EmptyResultSet +from django.db.utils import DatabaseError +from django_spanner.compiler import SQLCompiler +from django.db.models.query import QuerySet +from .models import Number + + +class TestCompiler(SimpleTestCase): + settings_dict = {"dummy_param": "dummy"} + + def _get_target_class(self): + from django_spanner.base import DatabaseWrapper + + return DatabaseWrapper + + def _make_one(self, *args, **kwargs): + return self._get_target_class()(*args, **kwargs) + + def test_unsupported_ordering_slicing_raises_db_error(self): + """ + Tries limit/offset and order by in subqueries which are not supported + by spanner. + """ + qs1 = Number.objects.all() + qs2 = Number.objects.all() + msg = "LIMIT/OFFSET not allowed in subqueries of compound statements" + with self.assertRaisesMessage(DatabaseError, msg): + list(qs1.union(qs2[:10])) + msg = "ORDER BY not allowed in subqueries of compound statements" + with self.assertRaisesMessage(DatabaseError, msg): + list(qs1.order_by("id").union(qs2)) + + def test_get_combinator_sql_all_union_sql_generated(self): + """ + Tries union sql generator. + """ + connection = self._make_one(self.settings_dict) + + qs1 = Number.objects.filter(num__lte=1).values("num") + qs2 = Number.objects.filter(num__gte=8).values("num") + qs4 = qs1.union(qs2) + + compiler = SQLCompiler(qs4.query, connection, "default") + sql_compiled, params = compiler.get_combinator_sql("union", True) + self.assertEqual( + sql_compiled, + [ + "SELECT tests_number.num FROM tests_number WHERE " + + "tests_number.num <= %s UNION ALL SELECT tests_number.num " + + "FROM tests_number WHERE tests_number.num >= %s" + ], + ) + self.assertEqual(params, [1, 8]) + + def test_get_combinator_sql_distinct_union_sql_generated(self): + """ + Tries union sql generator with distinct. + """ + connection = self._make_one(self.settings_dict) + + qs1 = Number.objects.filter(num__lte=1).values("num") + qs2 = Number.objects.filter(num__gte=8).values("num") + qs4 = qs1.union(qs2) + + compiler = SQLCompiler(qs4.query, connection, "default") + sql_compiled, params = compiler.get_combinator_sql("union", False) + self.assertEqual( + sql_compiled, + [ + "SELECT tests_number.num FROM tests_number WHERE " + + "tests_number.num <= %s UNION DISTINCT SELECT " + + "tests_number.num FROM tests_number WHERE " + + "tests_number.num >= %s" + ], + ) + self.assertEqual(params, [1, 8]) + + def test_get_combinator_sql_difference_all_sql_generated(self): + """ + Tries difference sql generator. + """ + connection = self._make_one(self.settings_dict) + + qs1 = Number.objects.filter(num__lte=1).values("num") + qs2 = Number.objects.filter(num__gte=8).values("num") + qs4 = qs1.difference(qs2) + + compiler = SQLCompiler(qs4.query, connection, "default") + sql_compiled, params = compiler.get_combinator_sql("difference", True) + + self.assertEqual( + sql_compiled, + [ + "SELECT tests_number.num FROM tests_number WHERE " + + "tests_number.num <= %s EXCEPT ALL SELECT tests_number.num " + + "FROM tests_number WHERE tests_number.num >= %s" + ], + ) + self.assertEqual(params, [1, 8]) + + def test_get_combinator_sql_difference_distinct_sql_generated(self): + """ + Tries difference sql generator with distinct. + """ + connection = self._make_one(self.settings_dict) + + qs1 = Number.objects.filter(num__lte=1).values("num") + qs2 = Number.objects.filter(num__gte=8).values("num") + qs4 = qs1.difference(qs2) + + compiler = SQLCompiler(qs4.query, connection, "default") + sql_compiled, params = compiler.get_combinator_sql("difference", False) + + self.assertEqual( + sql_compiled, + [ + "SELECT tests_number.num FROM tests_number WHERE " + + "tests_number.num <= %s EXCEPT DISTINCT SELECT " + + "tests_number.num FROM tests_number WHERE " + + "tests_number.num >= %s" + ], + ) + self.assertEqual(params, [1, 8]) + + def test_get_combinator_sql_union_and_difference_query_together(self): + """ + Tries sql generator with union of queryset with queryset of difference. + """ + connection = self._make_one(self.settings_dict) + + qs1 = Number.objects.filter(num__lte=1).values("num") + qs2 = Number.objects.filter(num__gte=8).values("num") + qs3 = Number.objects.filter(num__exact=10).values("num") + qs4 = qs1.union(qs2.difference(qs3)) + + compiler = SQLCompiler(qs4.query, connection, "default") + sql_compiled, params = compiler.get_combinator_sql("union", False) + self.assertEqual( + sql_compiled, + [ + "SELECT tests_number.num FROM tests_number WHERE " + + "tests_number.num <= %s UNION DISTINCT (" + + "SELECT tests_number.num FROM tests_number WHERE " + + "tests_number.num >= %s EXCEPT DISTINCT " + + "SELECT tests_number.num FROM tests_number " + + "WHERE tests_number.num = %s)" + ], + ) + self.assertEqual(params, [1, 8, 10]) + + def test_get_combinator_sql_parentheses_in_compound_not_supported(self): + """ + Tries sql generator with union of queryset with queryset of difference, + adding support for parentheses in compound sql statement. + """ + connection = self._make_one(self.settings_dict) + + qs1 = Number.objects.filter(num__lte=1).values("num") + qs2 = Number.objects.filter(num__gte=8).values("num") + qs3 = Number.objects.filter(num__exact=10).values("num") + qs4 = qs1.union(qs2.difference(qs3)) + + compiler = SQLCompiler(qs4.query, connection, "default") + compiler.connection.features.supports_parentheses_in_compound = False + sql_compiled, params = compiler.get_combinator_sql("union", False) + self.assertEqual( + sql_compiled, + [ + "SELECT tests_number.num FROM tests_number WHERE " + + "tests_number.num <= %s UNION DISTINCT SELECT * FROM (" + + "SELECT tests_number.num FROM tests_number WHERE " + + "tests_number.num >= %s EXCEPT DISTINCT " + + "SELECT tests_number.num FROM tests_number " + + "WHERE tests_number.num = %s)" + ], + ) + self.assertEqual(params, [1, 8, 10]) + + def test_get_combinator_sql_empty_queryset_raises_exception(self): + """ + Tries sql generator with empty queryset. + """ + connection = self._make_one(self.settings_dict) + compiler = SQLCompiler(QuerySet().query, connection, "default") + with self.assertRaises(EmptyResultSet): + compiler.get_combinator_sql("union", False) diff --git a/tests/unit/django_spanner/test_expressions.py b/tests/unit/django_spanner/test_expressions.py new file mode 100644 index 0000000000..bf1bbd59a1 --- /dev/null +++ b/tests/unit/django_spanner/test_expressions.py @@ -0,0 +1,62 @@ +# Copyright 2021 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +from django.test import SimpleTestCase +from django_spanner.compiler import SQLCompiler +from django.db.models import F +from .models import Report + + +class TestExpressions(SimpleTestCase): + settings_dict = {"dummy_param": "dummy"} + + def _get_target_class(self): + from django_spanner.base import DatabaseWrapper + + return DatabaseWrapper + + def _make_one(self, *args, **kwargs): + return self._get_target_class()(*args, **kwargs) + + def test_order_by_sql_query_with_order_by_null_last(self): + connection = self._make_one(self.settings_dict) + + qs1 = Report.objects.values("name").order_by( + F("name").desc(nulls_last=True) + ) + compiler = SQLCompiler(qs1.query, connection, "default") + sql_compiled, _ = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_report.name FROM tests_report ORDER BY " + + "tests_report.name IS NULL, tests_report.name DESC", + ) + + def test_order_by_sql_query_with_order_by_null_first(self): + connection = self._make_one(self.settings_dict) + + qs1 = Report.objects.values("name").order_by( + F("name").desc(nulls_first=True) + ) + compiler = SQLCompiler(qs1.query, connection, "default") + sql_compiled, _ = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_report.name FROM tests_report ORDER BY " + + "tests_report.name IS NOT NULL, tests_report.name DESC", + ) + + def test_order_by_sql_query_with_order_by_name(self): + connection = self._make_one(self.settings_dict) + + qs1 = Report.objects.values("name") + compiler = SQLCompiler(qs1.query, connection, "default") + sql_compiled, _ = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_report.name FROM tests_report ORDER BY " + + "tests_report.name ASC", + ) diff --git a/tests/unit/django_spanner/test_lookups.py b/tests/unit/django_spanner/test_lookups.py new file mode 100644 index 0000000000..5802a8cde3 --- /dev/null +++ b/tests/unit/django_spanner/test_lookups.py @@ -0,0 +1,312 @@ +# Copyright 2021 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +from django.test import SimpleTestCase +from django_spanner.compiler import SQLCompiler +from django.db.models import F +from .models import Number, Author + + +class TestLookups(SimpleTestCase): + settings_dict = {"dummy_instance": "instance"} + + def _get_target_class(self): + from django_spanner.base import DatabaseWrapper + + return DatabaseWrapper + + def _make_one(self, *args, **kwargs): + return self._get_target_class()(*args, **kwargs) + + def test_cast_param_to_float_lte_sql_query(self): + db_wrapper = self._make_one(self.settings_dict) + + qs1 = Number.objects.filter(decimal_num__lte=1.1).values("decimal_num") + compiler = SQLCompiler(qs1.query, db_wrapper, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_number.decimal_num FROM tests_number WHERE " + + "tests_number.decimal_num <= %s", + ) + self.assertEqual(params, (1.1,)) + + def test_cast_param_to_float_for_int_field_query(self): + db_wrapper = self._make_one(self.settings_dict) + + qs1 = Number.objects.filter(num__lte=1.1).values("num") + + compiler = SQLCompiler(qs1.query, db_wrapper, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_number.num FROM tests_number WHERE " + + "tests_number.num <= %s", + ) + self.assertEqual(params, (1,)) + + def test_cast_param_to_float_for_foreign_key_field_query(self): + db_wrapper = self._make_one(self.settings_dict) + + qs1 = Number.objects.filter(item_id__exact="10").values("num") + compiler = SQLCompiler(qs1.query, db_wrapper, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_number.num FROM tests_number WHERE " + + "tests_number.item_id = %s", + ) + self.assertEqual(params, (10,)) + + def test_cast_param_to_float_with_no_params_query(self): + db_wrapper = self._make_one(self.settings_dict) + + qs1 = Number.objects.filter(item_id__exact=F("num")).values("num") + compiler = SQLCompiler(qs1.query, db_wrapper, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_number.num FROM tests_number WHERE " + + "tests_number.item_id = (tests_number.num)", + ) + self.assertEqual(params, ()) + + def test_startswith_endswith_sql_query_with_startswith(self): + db_wrapper = self._make_one(self.settings_dict) + + qs1 = Author.objects.filter(name__startswith="abc").values("num") + compiler = SQLCompiler(qs1.query, db_wrapper, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_author.num FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(tests_author.name AS STRING), %s)", + ) + self.assertEqual(params, ("^abc",)) + + def test_startswith_endswith_sql_query_with_endswith(self): + db_wrapper = self._make_one(self.settings_dict) + + qs1 = Author.objects.filter(name__endswith="abc").values("num") + compiler = SQLCompiler(qs1.query, db_wrapper, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_author.num FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(tests_author.name AS STRING), %s)", + ) + self.assertEqual(params, ("abc$",)) + + def test_startswith_endswith_sql_query_case_insensitive(self): + db_wrapper = self._make_one(self.settings_dict) + + qs1 = Author.objects.filter(name__istartswith="abc").values("num") + compiler = SQLCompiler(qs1.query, db_wrapper, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_author.num FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(tests_author.name AS STRING), %s)", + ) + self.assertEqual(params, ("(?i)^abc",)) + + def test_startswith_endswith_sql_query_with_bileteral_transform(self): + db_wrapper = self._make_one(self.settings_dict) + + qs1 = Author.objects.filter(name__upper__startswith="abc").values( + "name" + ) + + compiler = SQLCompiler(qs1.query, db_wrapper, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_author.name FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(UPPER(tests_author.name) AS STRING), " + + "REPLACE(REPLACE(REPLACE(CONCAT('^', (UPPER(%s))), " + + '"\\\\", "\\\\\\\\"), "%%", r"\\%%"), "_", r"\\_"))', + ) + self.assertEqual(params, ("abc",)) + + def test_startswith_endswith_case_insensitive_transform_sql_query(self): + db_wrapper = self._make_one(self.settings_dict) + + qs1 = Author.objects.filter(name__upper__istartswith="abc").values( + "name" + ) + + compiler = SQLCompiler(qs1.query, db_wrapper, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_author.name FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(UPPER(tests_author.name) AS STRING), " + + "REPLACE(REPLACE(REPLACE(CONCAT('^(?i)', (UPPER(%s))), " + + '"\\\\", "\\\\\\\\"), "%%", r"\\%%"), "_", r"\\_"))', + ) + self.assertEqual(params, ("abc",)) + + def test_startswith_endswith_endswith_sql_query_with_transform(self): + db_wrapper = self._make_one(self.settings_dict) + + qs1 = Author.objects.filter(name__upper__endswith="abc").values("name") + + compiler = SQLCompiler(qs1.query, db_wrapper, "default") + sql_compiled, params = compiler.as_sql() + + self.assertEqual( + sql_compiled, + "SELECT tests_author.name FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(UPPER(tests_author.name) AS STRING), " + + "REPLACE(REPLACE(REPLACE(CONCAT('', (UPPER(%s)), '$'), " + + '"\\\\", "\\\\\\\\"), "%%", r"\\%%"), "_", r"\\_"))', + ) + self.assertEqual(params, ("abc",)) + + def test_regex_sql_query_case_sensitive(self): + db_wrapper = self._make_one(self.settings_dict) + + qs1 = Author.objects.filter(name__regex="abc").values("num") + compiler = SQLCompiler(qs1.query, db_wrapper, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_author.num FROM tests_author WHERE " + "REGEXP_CONTAINS(CAST(tests_author.name AS STRING), %s)", + ) + self.assertEqual(params, ("abc",)) + + def test_regex_sql_query_case_insensitive(self): + db_wrapper = self._make_one(self.settings_dict) + + qs1 = Author.objects.filter(name__iregex="abc").values("num") + compiler = SQLCompiler(qs1.query, db_wrapper, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_author.num FROM tests_author WHERE " + "REGEXP_CONTAINS(CAST(tests_author.name AS STRING), %s)", + ) + self.assertEqual(params, ("(?i)abc",)) + + def test_regex_sql_query_case_sensitive_with_transform(self): + db_wrapper = self._make_one(self.settings_dict) + + qs1 = Author.objects.filter(name__upper__regex="abc").values("num") + compiler = SQLCompiler(qs1.query, db_wrapper, "default") + sql_compiled, params = compiler.as_sql() + + self.assertEqual( + sql_compiled, + "SELECT tests_author.num FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(UPPER(tests_author.name) AS STRING), " + + "(UPPER(%s)))", + ) + self.assertEqual(params, ("abc",)) + + def test_regex_sql_query_case_insensitive_with_transform(self): + db_wrapper = self._make_one(self.settings_dict) + + qs1 = Author.objects.filter(name__upper__iregex="abc").values("num") + compiler = SQLCompiler(qs1.query, db_wrapper, "default") + sql_compiled, params = compiler.as_sql() + + self.assertEqual( + sql_compiled, + "SELECT tests_author.num FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(UPPER(tests_author.name) AS STRING), " + + "CONCAT('(?i)', (UPPER(%s))))", + ) + self.assertEqual(params, ("abc",)) + + def test_contains_sql_query_case_insensitive(self): + db_wrapper = self._make_one(self.settings_dict) + + qs1 = Author.objects.filter(name__icontains="abc").values("num") + compiler = SQLCompiler(qs1.query, db_wrapper, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_author.num FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(tests_author.name AS STRING), %s)", + ) + self.assertEqual(params, ("(?i)abc",)) + + def test_contains_sql_query_case_sensitive(self): + db_wrapper = self._make_one(self.settings_dict) + + qs1 = Author.objects.filter(name__contains="abc").values("num") + compiler = SQLCompiler(qs1.query, db_wrapper, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_author.num FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(tests_author.name AS STRING), %s)", + ) + self.assertEqual(params, ("abc",)) + + def test_contains_sql_query_case_insensitive_transform(self): + db_wrapper = self._make_one(self.settings_dict) + + qs1 = Author.objects.filter(name__upper__icontains="abc").values( + "name" + ) + compiler = SQLCompiler(qs1.query, db_wrapper, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_author.name FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(UPPER(tests_author.name) AS STRING), " + + "REPLACE(REPLACE(REPLACE(CONCAT('(?i)', (UPPER(%s))), " + + '"\\\\", "\\\\\\\\"), "%%", r"\\%%"), "_", r"\\_"))', + ) + self.assertEqual(params, ("abc",)) + + def test_contains_sql_query_case_sensitive_transform(self): + db_wrapper = self._make_one(self.settings_dict) + + qs1 = Author.objects.filter(name__upper__contains="abc").values("name") + compiler = SQLCompiler(qs1.query, db_wrapper, "default") + sql_compiled, params = compiler.as_sql() + self.assertEqual( + sql_compiled, + "SELECT tests_author.name FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(UPPER(tests_author.name) AS STRING), " + + 'REPLACE(REPLACE(REPLACE((UPPER(%s)), "\\\\", "\\\\\\\\"), ' + + '"%%", r"\\%%"), "_", r"\\_"))', + ) + self.assertEqual(params, ("abc",)) + + def test_iexact_sql_query_case_insensitive(self): + db_wrapper = self._make_one(self.settings_dict) + + qs1 = Author.objects.filter(name__iexact="abc").values("num") + compiler = SQLCompiler(qs1.query, db_wrapper, "default") + sql_compiled, params = compiler.as_sql() + + self.assertEqual( + sql_compiled, + "SELECT tests_author.num FROM tests_author WHERE " + + "REGEXP_CONTAINS(CAST(tests_author.name AS STRING), %s)", + ) + self.assertEqual(params, ("^(?i)abc$",)) + + def test_iexact_sql_query_case_insensitive_function_transform(self): + db_wrapper = self._make_one(self.settings_dict) + + qs1 = Author.objects.filter(name__upper__iexact=F("last_name")).values( + "name" + ) + compiler = SQLCompiler(qs1.query, db_wrapper, "default") + sql_compiled, params = compiler.as_sql() + + self.assertEqual( + sql_compiled, + "SELECT tests_author.name FROM tests_author WHERE " + + "REGEXP_CONTAINS((UPPER(tests_author.last_name)), " + + "CONCAT('^(?i)', CAST(UPPER(tests_author.name) AS STRING), '$'))", + ) + self.assertEqual(params, ()) diff --git a/tests/unit/django_spanner/test_operations.py b/tests/unit/django_spanner/test_operations.py new file mode 100644 index 0000000000..d359bb084b --- /dev/null +++ b/tests/unit/django_spanner/test_operations.py @@ -0,0 +1,305 @@ +# Copyright 2021 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +from django.test import SimpleTestCase +from django.db.utils import DatabaseError +from datetime import timedelta +from django_spanner.operations import DatabaseOperations + + +class TestOperations(SimpleTestCase): + def _get_target_class(self): + from django_spanner.base import DatabaseWrapper + + return DatabaseWrapper + + def _make_one(self, *args, **kwargs): + dummy_settings = {"dummy_param": "dummy"} + conn = self._get_target_class()(settings_dict=dummy_settings) + return DatabaseOperations(conn) + + def test_max_name_length(self): + db_op = self._make_one() + self.assertEqual(db_op.max_name_length(), 128) + + def test_quote_name(self): + db_op = self._make_one() + quoted_name = db_op.quote_name("abc") + self.assertEqual(quoted_name, "abc") + + def test_quote_name_spanner_reserved_keyword_escaped(self): + db_op = self._make_one() + quoted_name = db_op.quote_name("ALL") + self.assertEqual(quoted_name, "`ALL`") + + def test_bulk_batch_size(self): + db_op = self._make_one() + self.assertEqual( + db_op.bulk_batch_size(fields=None, objs=None), + db_op.connection.features.max_query_params, + ) + + def test_sql_flush(self): + from django.core.management.color import no_style + + db_op = self._make_one() + self.assertEqual( + db_op.sql_flush(style=no_style(), tables=["Table1, Table2"]), + ["DELETE FROM `Table1, Table2`"], + ) + + def test_sql_flush_empty_table_list(self): + from django.core.management.color import no_style + + db_op = self._make_one() + self.assertEqual( + db_op.sql_flush(style=no_style(), tables=[]), [], + ) + + def test_adapt_datefield_value(self): + from google.cloud.spanner_dbapi.types import DateStr + + db_op = self._make_one() + self.assertIsInstance( + db_op.adapt_datefield_value("dummy_date"), DateStr, + ) + + def test_adapt_datefield_value_none(self): + db_op = self._make_one() + self.assertIsNone(db_op.adapt_datefield_value(value=None),) + + def test_adapt_decimalfield_value(self): + db_op = self._make_one() + self.assertIsInstance( + db_op.adapt_decimalfield_value(value=1), float, + ) + + def test_adapt_decimalfield_value_none(self): + db_op = self._make_one() + self.assertIsNone(db_op.adapt_decimalfield_value(value=None),) + + def test_convert_binaryfield_value(self): + from base64 import b64encode + + db_op = self._make_one() + self.assertEqual( + db_op.convert_binaryfield_value( + value=b64encode(b"abc"), expression=None, connection=None + ), + b"abc", + ) + + def test_convert_binaryfield_value_none(self): + db_op = self._make_one() + self.assertIsNone( + db_op.convert_binaryfield_value( + value=None, expression=None, connection=None + ), + ) + + def test_adapt_datetimefield_value_none(self): + db_op = self._make_one() + self.assertIsNone(db_op.adapt_datetimefield_value(value=None),) + + def test_adapt_timefield_value_none(self): + db_op = self._make_one() + self.assertIsNone(db_op.adapt_timefield_value(value=None),) + + def test_convert_decimalfield_value(self): + from decimal import Decimal + + db_op = self._make_one() + self.assertIsInstance( + db_op.convert_decimalfield_value( + value=1.0, expression=None, connection=None + ), + Decimal, + ) + + def test_convert_decimalfield_value_none(self): + db_op = self._make_one() + self.assertIsNone( + db_op.convert_decimalfield_value( + value=None, expression=None, connection=None + ), + ) + + def test_convert_uuidfield_value(self): + import uuid + + db_op = self._make_one() + uuid_obj = uuid.uuid4() + self.assertEqual( + db_op.convert_uuidfield_value( + str(uuid_obj), expression=None, connection=None + ), + uuid_obj, + ) + + def test_convert_uuidfield_value_none(self): + db_op = self._make_one() + self.assertIsNone( + db_op.convert_uuidfield_value( + value=None, expression=None, connection=None + ), + ) + + def test_date_extract_sql(self): + db_op = self._make_one() + self.assertEqual( + db_op.date_extract_sql("week", "dummy_field"), + "EXTRACT(isoweek FROM dummy_field)", + ) + + def test_date_extract_sql_lookup_type_dayofweek(self): + db_op = self._make_one() + self.assertEqual( + db_op.date_extract_sql("dayofweek", "dummy_field"), + "EXTRACT(dayofweek FROM dummy_field)", + ) + + def test_datetime_extract_sql(self): + from django.conf import settings + + settings.USE_TZ = True + db_op = self._make_one() + self.assertEqual( + db_op.datetime_extract_sql("dayofweek", "dummy_field", "IST"), + 'EXTRACT(dayofweek FROM dummy_field AT TIME ZONE "IST")', + ) + + def test_datetime_extract_sql_use_tz_false(self): + from django.conf import settings + + settings.USE_TZ = False + db_op = self._make_one() + self.assertEqual( + db_op.datetime_extract_sql("dayofweek", "dummy_field", "IST"), + 'EXTRACT(dayofweek FROM dummy_field AT TIME ZONE "UTC")', + ) + settings.USE_TZ = True # reset changes. + + def test_time_extract_sql(self): + db_op = self._make_one() + self.assertEqual( + db_op.time_extract_sql("dayofweek", "dummy_field"), + 'EXTRACT(dayofweek FROM dummy_field AT TIME ZONE "UTC")', + ) + + def test_time_trunc_sql(self): + db_op = self._make_one() + self.assertEqual( + db_op.time_trunc_sql("dayofweek", "dummy_field"), + 'TIMESTAMP_TRUNC(dummy_field, dayofweek, "UTC")', + ) + + def test_datetime_cast_date_sql(self): + db_op = self._make_one() + self.assertEqual( + db_op.datetime_cast_date_sql("dummy_field", "IST"), + 'DATE(dummy_field, "IST")', + ) + + def test_datetime_cast_time_sql(self): + from django.conf import settings + + settings.USE_TZ = True + db_op = self._make_one() + self.assertEqual( + db_op.datetime_cast_time_sql("dummy_field", "IST"), + "TIMESTAMP(FORMAT_TIMESTAMP('%Y-%m-%d %R:%E9S %Z', dummy_field, 'IST'))", + ) + + def test_datetime_cast_time_sql_use_tz_false(self): + from django.conf import settings + + settings.USE_TZ = False + db_op = self._make_one() + self.assertEqual( + db_op.datetime_cast_time_sql("dummy_field", "IST"), + "TIMESTAMP(FORMAT_TIMESTAMP('%Y-%m-%d %R:%E9S %Z', dummy_field, 'UTC'))", + ) + settings.USE_TZ = True # reset changes. + + def test_date_interval_sql(self): + db_op = self._make_one() + self.assertEqual( + db_op.date_interval_sql(timedelta(days=1)), + "INTERVAL 86400000000 MICROSECOND", + ) + + def test_format_for_duration_arithmetic(self): + db_op = self._make_one() + self.assertEqual( + db_op.format_for_duration_arithmetic(1200), + "INTERVAL 1200 MICROSECOND", + ) + + def test_combine_expression_mod(self): + db_op = self._make_one() + self.assertEqual( + db_op.combine_expression("%%", ["10", "2"]), "MOD(10, 2)", + ) + + def test_combine_expression_power(self): + db_op = self._make_one() + self.assertEqual( + db_op.combine_expression("^", ["10", "2"]), "POWER(10, 2)", + ) + + def test_combine_expression_bit_extention(self): + db_op = self._make_one() + self.assertEqual( + db_op.combine_expression(">>", ["10", "2"]), + "CAST(FLOOR(10 / POW(2, 2)) AS INT64)", + ) + + def test_combine_expression_multiply(self): + db_op = self._make_one() + self.assertEqual( + db_op.combine_expression("*", ["10", "2"]), "10 * 2", + ) + + def test_combine_duration_expression_add(self): + db_op = self._make_one() + self.assertEqual( + db_op.combine_duration_expression( + "+", + ['TIMESTAMP "2008-12-25 15:30:00+00', "INTERVAL 10 MINUTE"], + ), + 'TIMESTAMP_ADD(TIMESTAMP "2008-12-25 15:30:00+00, INTERVAL 10 MINUTE)', + ) + + def test_combine_duration_expression_subtract(self): + db_op = self._make_one() + self.assertEqual( + db_op.combine_duration_expression( + "-", + ['TIMESTAMP "2008-12-25 15:30:00+00', "INTERVAL 10 MINUTE"], + ), + 'TIMESTAMP_SUB(TIMESTAMP "2008-12-25 15:30:00+00, INTERVAL 10 MINUTE)', + ) + + def test_combine_duration_expression_database_error(self): + db_op = self._make_one() + msg = "Invalid connector for timedelta:" + with self.assertRaisesMessage(DatabaseError, msg): + db_op.combine_duration_expression( + "*", + ['TIMESTAMP "2008-12-25 15:30:00+00', "INTERVAL 10 MINUTE"], + ) + + def test_lookup_cast_match_lookup_type(self): + db_op = self._make_one() + self.assertEqual( + db_op.lookup_cast("contains",), "CAST(%s AS STRING)", + ) + + def test_lookup_cast_unmatched_lookup_type(self): + db_op = self._make_one() + self.assertEqual( + db_op.lookup_cast("dummy",), "%s", + ) diff --git a/tests/unit/django_spanner/test_utils.py b/tests/unit/django_spanner/test_utils.py new file mode 100644 index 0000000000..0ebb0a4212 --- /dev/null +++ b/tests/unit/django_spanner/test_utils.py @@ -0,0 +1,50 @@ +# Copyright 2021 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +import unittest +from django_spanner.utils import check_django_compatability +from django.core.exceptions import ImproperlyConfigured +from django_spanner.utils import add_dummy_where +import django +import django_spanner + + +class TestUtils(unittest.TestCase): + SQL_WITH_WHERE = "Select 1 from Table WHERE 1=1" + SQL_WITHOUT_WHERE = "Select 1 from Table" + + def test_check_django_compatability_match(self): + """ + Checks django compatibility match. + """ + django_spanner.__version__ = "2.2" + django.VERSION = (2, 2, 19, "alpha", 0) + check_django_compatability() + + def test_check_django_compatability_mismatch(self): + """ + Checks django compatibility mismatch. + """ + django_spanner.__version__ = "2.2" + django.VERSION = (3, 2, 19, "alpha", 0) + with self.assertRaises(ImproperlyConfigured): + check_django_compatability() + + def test_add_dummy_where_with_where_present_and_not_added(self): + """ + Checks if dummy where clause is not added when present in select + statement. + """ + updated_sql = add_dummy_where(self.SQL_WITH_WHERE) + self.assertEqual(updated_sql, self.SQL_WITH_WHERE) + + def test_add_dummy_where_with_where_not_present_and_added(self): + """ + Checks if dummy where clause is added when not present in select + statement. + """ + updated_sql = add_dummy_where(self.SQL_WITHOUT_WHERE) + self.assertEqual(updated_sql, self.SQL_WITH_WHERE) diff --git a/tests/unit/django_spanner/test_validation.py b/tests/unit/django_spanner/test_validation.py new file mode 100644 index 0000000000..88f62c3c54 --- /dev/null +++ b/tests/unit/django_spanner/test_validation.py @@ -0,0 +1,41 @@ +# Copyright 2021 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +from django.test import SimpleTestCase +from django_spanner.validation import DatabaseValidation +from django.db import connection +from django.core.checks import Error as DjangoError +from .models import ModelDecimalField, ModelCharField + + +class TestValidation(SimpleTestCase): + def test_check_field_type_with_decimal_field_not_support_error(self): + """ + Checks if decimal field fails database validation as it's not + supported in spanner. + """ + field = ModelDecimalField._meta.get_field("field") + validator = DatabaseValidation(connection=connection) + self.assertEqual( + validator.check_field(field), + [ + DjangoError( + "DecimalField is not yet supported by Spanner.", + obj=field, + id="spanner.E001", + ) + ], + ) + + def test_check_field_type_with_char_field_no_error(self): + """ + Checks if string field passes database validation. + """ + field = ModelCharField._meta.get_field("field") + validator = DatabaseValidation(connection=connection) + self.assertEqual( + validator.check_field(field), [], + ) From c418e161ef480d4f5abd74bc6a6145874226e308 Mon Sep 17 00:00:00 2001 From: Vikash Singh <3116482+vi3k6i5@users.noreply.github.com> Date: Thu, 6 May 2021 21:32:03 +0530 Subject: [PATCH 3/3] test: changed test settings as per review comments --- noxfile.py | 2 +- tests/conftest.py | 13 ++ tests/unit/django_spanner/models.py | 6 - tests/unit/django_spanner/simple_test.py | 33 ++++ tests/unit/django_spanner/test_base.py | 87 +++------- tests/unit/django_spanner/test_client.py | 32 +--- tests/unit/django_spanner/test_compiler.py | 44 ++---- tests/unit/django_spanner/test_expressions.py | 26 +-- tests/unit/django_spanner/test_lookups.py | 74 +++------ tests/unit/django_spanner/test_operations.py | 149 +++++++----------- tests/unit/django_spanner/test_utils.py | 4 +- tests/unit/django_spanner/test_validation.py | 4 +- 12 files changed, 178 insertions(+), 296 deletions(-) create mode 100644 tests/conftest.py create mode 100644 tests/unit/django_spanner/simple_test.py diff --git a/noxfile.py b/noxfile.py index 36b68cfc81..89f7b2904e 100644 --- a/noxfile.py +++ b/noxfile.py @@ -84,7 +84,7 @@ def default(session): "--cov-append", "--cov-config=.coveragerc", "--cov-report=", - "--cov-fail-under=70", + "--cov-fail-under=68", os.path.join("tests", "unit"), *session.posargs ) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000..308a903387 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,13 @@ +import os +import django +from django.conf import settings + +# We manually designate which settings we will be using in an environment +# variable. This is similar to what occurs in the `manage.py` file. +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "tests.settings") + + +# `pytest` automatically calls this function once when tests are run. +def pytest_configure(): + settings.DEBUG = False + django.setup() diff --git a/tests/unit/django_spanner/models.py b/tests/unit/django_spanner/models.py index d856051937..8dfb9d8e48 100644 --- a/tests/unit/django_spanner/models.py +++ b/tests/unit/django_spanner/models.py @@ -6,13 +6,7 @@ """ Different models used for testing django-spanner code. """ -import os from django.db import models -import django - -# Load django settings before loading django models. -os.environ["DJANGO_SETTINGS_MODULE"] = "tests.settings" -django.setup() # Register transformations for model fields. diff --git a/tests/unit/django_spanner/simple_test.py b/tests/unit/django_spanner/simple_test.py new file mode 100644 index 0000000000..1fcb92bd29 --- /dev/null +++ b/tests/unit/django_spanner/simple_test.py @@ -0,0 +1,33 @@ +# Copyright 2021 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +from django_spanner.client import DatabaseClient +from django_spanner.base import DatabaseWrapper +from django_spanner.operations import DatabaseOperations +from unittest import TestCase +import os + + +class SpannerSimpleTestClass(TestCase): + @classmethod + def setUpClass(cls): + cls.PROJECT = os.environ["GOOGLE_CLOUD_PROJECT"] + + cls.INSTANCE_ID = "instance_id" + cls.DATABASE_ID = "database_id" + cls.USER_AGENT = "django_spanner/2.2.0a1" + cls.OPTIONS = {"option": "dummy"} + + cls.settings_dict = { + "PROJECT": cls.PROJECT, + "INSTANCE": cls.INSTANCE_ID, + "NAME": cls.DATABASE_ID, + "user_agent": cls.USER_AGENT, + "OPTIONS": cls.OPTIONS, + } + cls.db_client = DatabaseClient(cls.settings_dict) + cls.db_wrapper = cls.connection = DatabaseWrapper(cls.settings_dict) + cls.db_operations = DatabaseOperations(cls.connection) diff --git a/tests/unit/django_spanner/test_base.py b/tests/unit/django_spanner/test_base.py index 9df06afd44..9b2d60c1c4 100644 --- a/tests/unit/django_spanner/test_base.py +++ b/tests/unit/django_spanner/test_base.py @@ -4,55 +4,24 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -import unittest -import os - -from mock_import import mock_import from unittest import mock +from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass -@mock_import() -class TestBase(unittest.TestCase): - PROJECT = os.environ["GOOGLE_CLOUD_PROJECT"] - INSTANCE_ID = "instance_id" - DATABASE_ID = "database_id" - USER_AGENT = "django_spanner/2.2.0a1" - OPTIONS = {"option": "dummy"} - - settings_dict = { - "PROJECT": PROJECT, - "INSTANCE": INSTANCE_ID, - "NAME": DATABASE_ID, - "user_agent": USER_AGENT, - "OPTIONS": OPTIONS, - } - - def _get_target_class(self): - from django_spanner.base import DatabaseWrapper - - return DatabaseWrapper - - def _make_one(self, *args, **kwargs): - return self._get_target_class()(*args, **kwargs) - +class TestBase(SpannerSimpleTestClass): def test_property_instance(self): - settings_dict = {"INSTANCE": "instance"} - db_wrapper = self._make_one(settings_dict=settings_dict) - with mock.patch("django_spanner.base.spanner") as mock_spanner: mock_spanner.Client = mock_client = mock.MagicMock() mock_client().instance = mock_instance = mock.MagicMock() - _ = db_wrapper.instance - mock_instance.assert_called_once_with(settings_dict["INSTANCE"]) + _ = self.db_wrapper.instance + mock_instance.assert_called_once_with(self.INSTANCE_ID) def test_property_nodb_connection(self): - db_wrapper = self._make_one(None) with self.assertRaises(NotImplementedError): - db_wrapper._nodb_connection() + self.db_wrapper._nodb_connection() def test_get_connection_params(self): - db_wrapper = self._make_one(self.settings_dict) - params = db_wrapper.get_connection_params() + params = self.db_wrapper.get_connection_params() self.assertEqual(params["project"], self.PROJECT) self.assertEqual(params["instance_id"], self.INSTANCE_ID) @@ -61,54 +30,50 @@ def test_get_connection_params(self): self.assertEqual(params["option"], self.OPTIONS["option"]) def test_get_new_connection(self): - db_wrapper = self._make_one(self.settings_dict) - db_wrapper.Database = mock_database = mock.MagicMock() + self.db_wrapper.Database = mock_database = mock.MagicMock() mock_database.connect = mock_connection = mock.MagicMock() conn_params = {"test_param": "dummy"} - db_wrapper.get_new_connection(conn_params) + self.db_wrapper.get_new_connection(conn_params) mock_connection.assert_called_once_with(**conn_params) def test_init_connection_state(self): - db_wrapper = self._make_one(self.settings_dict) - db_wrapper.connection = mock_connection = mock.MagicMock() + self.db_wrapper.connection = mock_connection = mock.MagicMock() mock_connection.close = mock_close = mock.MagicMock() - db_wrapper.init_connection_state() + self.db_wrapper.init_connection_state() mock_close.assert_called_once_with() def test_create_cursor(self): - db_wrapper = self._make_one(self.settings_dict) - db_wrapper.connection = mock_connection = mock.MagicMock() + self.db_wrapper.connection = mock_connection = mock.MagicMock() mock_connection.cursor = mock_cursor = mock.MagicMock() - db_wrapper.create_cursor() + self.db_wrapper.create_cursor() mock_cursor.assert_called_once_with() def test_set_autocommit(self): - db_wrapper = self._make_one(self.settings_dict) - db_wrapper.connection = mock_connection = mock.MagicMock() + self.db_wrapper.connection = mock_connection = mock.MagicMock() mock_connection.autocommit = False - db_wrapper._set_autocommit(True) + self.db_wrapper._set_autocommit(True) self.assertEqual(mock_connection.autocommit, True) def test_is_usable(self): - from google.cloud.spanner_dbapi.exceptions import Error - - db_wrapper = self._make_one(self.settings_dict) - db_wrapper.connection = None - self.assertFalse(db_wrapper.is_usable()) + self.db_wrapper.connection = None + self.assertFalse(self.db_wrapper.is_usable()) - db_wrapper.connection = mock_connection = mock.MagicMock() + self.db_wrapper.connection = mock_connection = mock.MagicMock() mock_connection.is_closed = True - self.assertFalse(db_wrapper.is_usable()) + self.assertFalse(self.db_wrapper.is_usable()) mock_connection.is_closed = False - self.assertTrue(db_wrapper.is_usable()) + self.assertTrue(self.db_wrapper.is_usable()) + + def test_is_usable_with_error(self): + from google.cloud.spanner_dbapi.exceptions import Error + self.db_wrapper.connection = mock_connection = mock.MagicMock() mock_connection.cursor = mock.MagicMock(side_effect=Error) - self.assertFalse(db_wrapper.is_usable()) + self.assertFalse(self.db_wrapper.is_usable()) def test_start_transaction_under_autocommit(self): - db_wrapper = self._make_one(self.settings_dict) - db_wrapper.connection = mock_connection = mock.MagicMock() + self.db_wrapper.connection = mock_connection = mock.MagicMock() mock_connection.cursor = mock_cursor = mock.MagicMock() - db_wrapper._start_transaction_under_autocommit() + self.db_wrapper._start_transaction_under_autocommit() mock_cursor.assert_called_once_with() diff --git a/tests/unit/django_spanner/test_client.py b/tests/unit/django_spanner/test_client.py index c892bd2993..10b38fb2f9 100644 --- a/tests/unit/django_spanner/test_client.py +++ b/tests/unit/django_spanner/test_client.py @@ -4,36 +4,12 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -import unittest -import os -from google.cloud.spanner_dbapi.exceptions import NotSupportedError - - -class TestClient(unittest.TestCase): - PROJECT = os.environ["GOOGLE_CLOUD_PROJECT"] - INSTANCE_ID = "instance_id" - DATABASE_ID = "database_id" - USER_AGENT = "django_spanner/2.2.0a1" - OPTIONS = {"option": "dummy"} - - settings_dict = { - "PROJECT": PROJECT, - "INSTANCE": INSTANCE_ID, - "NAME": DATABASE_ID, - "user_agent": USER_AGENT, - "OPTIONS": OPTIONS, - } - def _get_target_class(self): - from django_spanner.client import DatabaseClient - - return DatabaseClient +from google.cloud.spanner_dbapi.exceptions import NotSupportedError +from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass - def _make_one(self, *args, **kwargs): - return self._get_target_class()(*args, **kwargs) +class TestClient(SpannerSimpleTestClass): def test_runshell(self): - db_wrapper = self._make_one(self.settings_dict) - with self.assertRaises(NotSupportedError): - db_wrapper.runshell(parameters=self.settings_dict) + self.db_client.runshell(parameters=self.settings_dict) diff --git a/tests/unit/django_spanner/test_compiler.py b/tests/unit/django_spanner/test_compiler.py index 7b18124976..11fd1222a0 100644 --- a/tests/unit/django_spanner/test_compiler.py +++ b/tests/unit/django_spanner/test_compiler.py @@ -4,25 +4,15 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -from django.test import SimpleTestCase from django.core.exceptions import EmptyResultSet from django.db.utils import DatabaseError from django_spanner.compiler import SQLCompiler from django.db.models.query import QuerySet +from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass from .models import Number -class TestCompiler(SimpleTestCase): - settings_dict = {"dummy_param": "dummy"} - - def _get_target_class(self): - from django_spanner.base import DatabaseWrapper - - return DatabaseWrapper - - def _make_one(self, *args, **kwargs): - return self._get_target_class()(*args, **kwargs) - +class TestCompiler(SpannerSimpleTestClass): def test_unsupported_ordering_slicing_raises_db_error(self): """ Tries limit/offset and order by in subqueries which are not supported @@ -31,23 +21,22 @@ def test_unsupported_ordering_slicing_raises_db_error(self): qs1 = Number.objects.all() qs2 = Number.objects.all() msg = "LIMIT/OFFSET not allowed in subqueries of compound statements" - with self.assertRaisesMessage(DatabaseError, msg): + with self.assertRaisesRegex(DatabaseError, msg): list(qs1.union(qs2[:10])) msg = "ORDER BY not allowed in subqueries of compound statements" - with self.assertRaisesMessage(DatabaseError, msg): + with self.assertRaisesRegex(DatabaseError, msg): list(qs1.order_by("id").union(qs2)) def test_get_combinator_sql_all_union_sql_generated(self): """ Tries union sql generator. """ - connection = self._make_one(self.settings_dict) qs1 = Number.objects.filter(num__lte=1).values("num") qs2 = Number.objects.filter(num__gte=8).values("num") qs4 = qs1.union(qs2) - compiler = SQLCompiler(qs4.query, connection, "default") + compiler = SQLCompiler(qs4.query, self.connection, "default") sql_compiled, params = compiler.get_combinator_sql("union", True) self.assertEqual( sql_compiled, @@ -63,13 +52,12 @@ def test_get_combinator_sql_distinct_union_sql_generated(self): """ Tries union sql generator with distinct. """ - connection = self._make_one(self.settings_dict) qs1 = Number.objects.filter(num__lte=1).values("num") qs2 = Number.objects.filter(num__gte=8).values("num") qs4 = qs1.union(qs2) - compiler = SQLCompiler(qs4.query, connection, "default") + compiler = SQLCompiler(qs4.query, self.connection, "default") sql_compiled, params = compiler.get_combinator_sql("union", False) self.assertEqual( sql_compiled, @@ -86,13 +74,11 @@ def test_get_combinator_sql_difference_all_sql_generated(self): """ Tries difference sql generator. """ - connection = self._make_one(self.settings_dict) - qs1 = Number.objects.filter(num__lte=1).values("num") qs2 = Number.objects.filter(num__gte=8).values("num") qs4 = qs1.difference(qs2) - compiler = SQLCompiler(qs4.query, connection, "default") + compiler = SQLCompiler(qs4.query, self.connection, "default") sql_compiled, params = compiler.get_combinator_sql("difference", True) self.assertEqual( @@ -109,13 +95,11 @@ def test_get_combinator_sql_difference_distinct_sql_generated(self): """ Tries difference sql generator with distinct. """ - connection = self._make_one(self.settings_dict) - qs1 = Number.objects.filter(num__lte=1).values("num") qs2 = Number.objects.filter(num__gte=8).values("num") qs4 = qs1.difference(qs2) - compiler = SQLCompiler(qs4.query, connection, "default") + compiler = SQLCompiler(qs4.query, self.connection, "default") sql_compiled, params = compiler.get_combinator_sql("difference", False) self.assertEqual( @@ -133,20 +117,18 @@ def test_get_combinator_sql_union_and_difference_query_together(self): """ Tries sql generator with union of queryset with queryset of difference. """ - connection = self._make_one(self.settings_dict) - qs1 = Number.objects.filter(num__lte=1).values("num") qs2 = Number.objects.filter(num__gte=8).values("num") qs3 = Number.objects.filter(num__exact=10).values("num") qs4 = qs1.union(qs2.difference(qs3)) - compiler = SQLCompiler(qs4.query, connection, "default") + compiler = SQLCompiler(qs4.query, self.connection, "default") sql_compiled, params = compiler.get_combinator_sql("union", False) self.assertEqual( sql_compiled, [ "SELECT tests_number.num FROM tests_number WHERE " - + "tests_number.num <= %s UNION DISTINCT (" + + "tests_number.num <= %s UNION DISTINCT SELECT * FROM (" + "SELECT tests_number.num FROM tests_number WHERE " + "tests_number.num >= %s EXCEPT DISTINCT " + "SELECT tests_number.num FROM tests_number " @@ -160,14 +142,13 @@ def test_get_combinator_sql_parentheses_in_compound_not_supported(self): Tries sql generator with union of queryset with queryset of difference, adding support for parentheses in compound sql statement. """ - connection = self._make_one(self.settings_dict) qs1 = Number.objects.filter(num__lte=1).values("num") qs2 = Number.objects.filter(num__gte=8).values("num") qs3 = Number.objects.filter(num__exact=10).values("num") qs4 = qs1.union(qs2.difference(qs3)) - compiler = SQLCompiler(qs4.query, connection, "default") + compiler = SQLCompiler(qs4.query, self.connection, "default") compiler.connection.features.supports_parentheses_in_compound = False sql_compiled, params = compiler.get_combinator_sql("union", False) self.assertEqual( @@ -187,7 +168,6 @@ def test_get_combinator_sql_empty_queryset_raises_exception(self): """ Tries sql generator with empty queryset. """ - connection = self._make_one(self.settings_dict) - compiler = SQLCompiler(QuerySet().query, connection, "default") + compiler = SQLCompiler(QuerySet().query, self.connection, "default") with self.assertRaises(EmptyResultSet): compiler.get_combinator_sql("union", False) diff --git a/tests/unit/django_spanner/test_expressions.py b/tests/unit/django_spanner/test_expressions.py index bf1bbd59a1..0efc99ce08 100644 --- a/tests/unit/django_spanner/test_expressions.py +++ b/tests/unit/django_spanner/test_expressions.py @@ -4,30 +4,18 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -from django.test import SimpleTestCase from django_spanner.compiler import SQLCompiler from django.db.models import F +from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass from .models import Report -class TestExpressions(SimpleTestCase): - settings_dict = {"dummy_param": "dummy"} - - def _get_target_class(self): - from django_spanner.base import DatabaseWrapper - - return DatabaseWrapper - - def _make_one(self, *args, **kwargs): - return self._get_target_class()(*args, **kwargs) - +class TestExpressions(SpannerSimpleTestClass): def test_order_by_sql_query_with_order_by_null_last(self): - connection = self._make_one(self.settings_dict) - qs1 = Report.objects.values("name").order_by( F("name").desc(nulls_last=True) ) - compiler = SQLCompiler(qs1.query, connection, "default") + compiler = SQLCompiler(qs1.query, self.connection, "default") sql_compiled, _ = compiler.as_sql() self.assertEqual( sql_compiled, @@ -36,12 +24,10 @@ def test_order_by_sql_query_with_order_by_null_last(self): ) def test_order_by_sql_query_with_order_by_null_first(self): - connection = self._make_one(self.settings_dict) - qs1 = Report.objects.values("name").order_by( F("name").desc(nulls_first=True) ) - compiler = SQLCompiler(qs1.query, connection, "default") + compiler = SQLCompiler(qs1.query, self.connection, "default") sql_compiled, _ = compiler.as_sql() self.assertEqual( sql_compiled, @@ -50,10 +36,8 @@ def test_order_by_sql_query_with_order_by_null_first(self): ) def test_order_by_sql_query_with_order_by_name(self): - connection = self._make_one(self.settings_dict) - qs1 = Report.objects.values("name") - compiler = SQLCompiler(qs1.query, connection, "default") + compiler = SQLCompiler(qs1.query, self.connection, "default") sql_compiled, _ = compiler.as_sql() self.assertEqual( sql_compiled, diff --git a/tests/unit/django_spanner/test_lookups.py b/tests/unit/django_spanner/test_lookups.py index 5802a8cde3..90b17e5515 100644 --- a/tests/unit/django_spanner/test_lookups.py +++ b/tests/unit/django_spanner/test_lookups.py @@ -4,28 +4,17 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -from django.test import SimpleTestCase from django_spanner.compiler import SQLCompiler from django.db.models import F +from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass from .models import Number, Author -class TestLookups(SimpleTestCase): - settings_dict = {"dummy_instance": "instance"} - - def _get_target_class(self): - from django_spanner.base import DatabaseWrapper - - return DatabaseWrapper - - def _make_one(self, *args, **kwargs): - return self._get_target_class()(*args, **kwargs) - +class TestLookups(SpannerSimpleTestClass): def test_cast_param_to_float_lte_sql_query(self): - db_wrapper = self._make_one(self.settings_dict) qs1 = Number.objects.filter(decimal_num__lte=1.1).values("decimal_num") - compiler = SQLCompiler(qs1.query, db_wrapper, "default") + compiler = SQLCompiler(qs1.query, self.connection, "default") sql_compiled, params = compiler.as_sql() self.assertEqual( sql_compiled, @@ -35,11 +24,10 @@ def test_cast_param_to_float_lte_sql_query(self): self.assertEqual(params, (1.1,)) def test_cast_param_to_float_for_int_field_query(self): - db_wrapper = self._make_one(self.settings_dict) qs1 = Number.objects.filter(num__lte=1.1).values("num") - compiler = SQLCompiler(qs1.query, db_wrapper, "default") + compiler = SQLCompiler(qs1.query, self.connection, "default") sql_compiled, params = compiler.as_sql() self.assertEqual( sql_compiled, @@ -49,10 +37,9 @@ def test_cast_param_to_float_for_int_field_query(self): self.assertEqual(params, (1,)) def test_cast_param_to_float_for_foreign_key_field_query(self): - db_wrapper = self._make_one(self.settings_dict) qs1 = Number.objects.filter(item_id__exact="10").values("num") - compiler = SQLCompiler(qs1.query, db_wrapper, "default") + compiler = SQLCompiler(qs1.query, self.connection, "default") sql_compiled, params = compiler.as_sql() self.assertEqual( sql_compiled, @@ -62,10 +49,9 @@ def test_cast_param_to_float_for_foreign_key_field_query(self): self.assertEqual(params, (10,)) def test_cast_param_to_float_with_no_params_query(self): - db_wrapper = self._make_one(self.settings_dict) qs1 = Number.objects.filter(item_id__exact=F("num")).values("num") - compiler = SQLCompiler(qs1.query, db_wrapper, "default") + compiler = SQLCompiler(qs1.query, self.connection, "default") sql_compiled, params = compiler.as_sql() self.assertEqual( sql_compiled, @@ -75,10 +61,9 @@ def test_cast_param_to_float_with_no_params_query(self): self.assertEqual(params, ()) def test_startswith_endswith_sql_query_with_startswith(self): - db_wrapper = self._make_one(self.settings_dict) qs1 = Author.objects.filter(name__startswith="abc").values("num") - compiler = SQLCompiler(qs1.query, db_wrapper, "default") + compiler = SQLCompiler(qs1.query, self.connection, "default") sql_compiled, params = compiler.as_sql() self.assertEqual( sql_compiled, @@ -88,10 +73,9 @@ def test_startswith_endswith_sql_query_with_startswith(self): self.assertEqual(params, ("^abc",)) def test_startswith_endswith_sql_query_with_endswith(self): - db_wrapper = self._make_one(self.settings_dict) qs1 = Author.objects.filter(name__endswith="abc").values("num") - compiler = SQLCompiler(qs1.query, db_wrapper, "default") + compiler = SQLCompiler(qs1.query, self.connection, "default") sql_compiled, params = compiler.as_sql() self.assertEqual( sql_compiled, @@ -101,10 +85,9 @@ def test_startswith_endswith_sql_query_with_endswith(self): self.assertEqual(params, ("abc$",)) def test_startswith_endswith_sql_query_case_insensitive(self): - db_wrapper = self._make_one(self.settings_dict) qs1 = Author.objects.filter(name__istartswith="abc").values("num") - compiler = SQLCompiler(qs1.query, db_wrapper, "default") + compiler = SQLCompiler(qs1.query, self.connection, "default") sql_compiled, params = compiler.as_sql() self.assertEqual( sql_compiled, @@ -114,13 +97,12 @@ def test_startswith_endswith_sql_query_case_insensitive(self): self.assertEqual(params, ("(?i)^abc",)) def test_startswith_endswith_sql_query_with_bileteral_transform(self): - db_wrapper = self._make_one(self.settings_dict) qs1 = Author.objects.filter(name__upper__startswith="abc").values( "name" ) - compiler = SQLCompiler(qs1.query, db_wrapper, "default") + compiler = SQLCompiler(qs1.query, self.connection, "default") sql_compiled, params = compiler.as_sql() self.assertEqual( sql_compiled, @@ -132,13 +114,12 @@ def test_startswith_endswith_sql_query_with_bileteral_transform(self): self.assertEqual(params, ("abc",)) def test_startswith_endswith_case_insensitive_transform_sql_query(self): - db_wrapper = self._make_one(self.settings_dict) qs1 = Author.objects.filter(name__upper__istartswith="abc").values( "name" ) - compiler = SQLCompiler(qs1.query, db_wrapper, "default") + compiler = SQLCompiler(qs1.query, self.connection, "default") sql_compiled, params = compiler.as_sql() self.assertEqual( sql_compiled, @@ -150,11 +131,10 @@ def test_startswith_endswith_case_insensitive_transform_sql_query(self): self.assertEqual(params, ("abc",)) def test_startswith_endswith_endswith_sql_query_with_transform(self): - db_wrapper = self._make_one(self.settings_dict) qs1 = Author.objects.filter(name__upper__endswith="abc").values("name") - compiler = SQLCompiler(qs1.query, db_wrapper, "default") + compiler = SQLCompiler(qs1.query, self.connection, "default") sql_compiled, params = compiler.as_sql() self.assertEqual( @@ -167,10 +147,9 @@ def test_startswith_endswith_endswith_sql_query_with_transform(self): self.assertEqual(params, ("abc",)) def test_regex_sql_query_case_sensitive(self): - db_wrapper = self._make_one(self.settings_dict) qs1 = Author.objects.filter(name__regex="abc").values("num") - compiler = SQLCompiler(qs1.query, db_wrapper, "default") + compiler = SQLCompiler(qs1.query, self.connection, "default") sql_compiled, params = compiler.as_sql() self.assertEqual( sql_compiled, @@ -180,10 +159,9 @@ def test_regex_sql_query_case_sensitive(self): self.assertEqual(params, ("abc",)) def test_regex_sql_query_case_insensitive(self): - db_wrapper = self._make_one(self.settings_dict) qs1 = Author.objects.filter(name__iregex="abc").values("num") - compiler = SQLCompiler(qs1.query, db_wrapper, "default") + compiler = SQLCompiler(qs1.query, self.connection, "default") sql_compiled, params = compiler.as_sql() self.assertEqual( sql_compiled, @@ -193,10 +171,9 @@ def test_regex_sql_query_case_insensitive(self): self.assertEqual(params, ("(?i)abc",)) def test_regex_sql_query_case_sensitive_with_transform(self): - db_wrapper = self._make_one(self.settings_dict) qs1 = Author.objects.filter(name__upper__regex="abc").values("num") - compiler = SQLCompiler(qs1.query, db_wrapper, "default") + compiler = SQLCompiler(qs1.query, self.connection, "default") sql_compiled, params = compiler.as_sql() self.assertEqual( @@ -208,10 +185,9 @@ def test_regex_sql_query_case_sensitive_with_transform(self): self.assertEqual(params, ("abc",)) def test_regex_sql_query_case_insensitive_with_transform(self): - db_wrapper = self._make_one(self.settings_dict) qs1 = Author.objects.filter(name__upper__iregex="abc").values("num") - compiler = SQLCompiler(qs1.query, db_wrapper, "default") + compiler = SQLCompiler(qs1.query, self.connection, "default") sql_compiled, params = compiler.as_sql() self.assertEqual( @@ -223,10 +199,9 @@ def test_regex_sql_query_case_insensitive_with_transform(self): self.assertEqual(params, ("abc",)) def test_contains_sql_query_case_insensitive(self): - db_wrapper = self._make_one(self.settings_dict) qs1 = Author.objects.filter(name__icontains="abc").values("num") - compiler = SQLCompiler(qs1.query, db_wrapper, "default") + compiler = SQLCompiler(qs1.query, self.connection, "default") sql_compiled, params = compiler.as_sql() self.assertEqual( sql_compiled, @@ -236,10 +211,9 @@ def test_contains_sql_query_case_insensitive(self): self.assertEqual(params, ("(?i)abc",)) def test_contains_sql_query_case_sensitive(self): - db_wrapper = self._make_one(self.settings_dict) qs1 = Author.objects.filter(name__contains="abc").values("num") - compiler = SQLCompiler(qs1.query, db_wrapper, "default") + compiler = SQLCompiler(qs1.query, self.connection, "default") sql_compiled, params = compiler.as_sql() self.assertEqual( sql_compiled, @@ -249,12 +223,11 @@ def test_contains_sql_query_case_sensitive(self): self.assertEqual(params, ("abc",)) def test_contains_sql_query_case_insensitive_transform(self): - db_wrapper = self._make_one(self.settings_dict) qs1 = Author.objects.filter(name__upper__icontains="abc").values( "name" ) - compiler = SQLCompiler(qs1.query, db_wrapper, "default") + compiler = SQLCompiler(qs1.query, self.connection, "default") sql_compiled, params = compiler.as_sql() self.assertEqual( sql_compiled, @@ -266,10 +239,9 @@ def test_contains_sql_query_case_insensitive_transform(self): self.assertEqual(params, ("abc",)) def test_contains_sql_query_case_sensitive_transform(self): - db_wrapper = self._make_one(self.settings_dict) qs1 = Author.objects.filter(name__upper__contains="abc").values("name") - compiler = SQLCompiler(qs1.query, db_wrapper, "default") + compiler = SQLCompiler(qs1.query, self.connection, "default") sql_compiled, params = compiler.as_sql() self.assertEqual( sql_compiled, @@ -281,10 +253,9 @@ def test_contains_sql_query_case_sensitive_transform(self): self.assertEqual(params, ("abc",)) def test_iexact_sql_query_case_insensitive(self): - db_wrapper = self._make_one(self.settings_dict) qs1 = Author.objects.filter(name__iexact="abc").values("num") - compiler = SQLCompiler(qs1.query, db_wrapper, "default") + compiler = SQLCompiler(qs1.query, self.connection, "default") sql_compiled, params = compiler.as_sql() self.assertEqual( @@ -295,12 +266,11 @@ def test_iexact_sql_query_case_insensitive(self): self.assertEqual(params, ("^(?i)abc$",)) def test_iexact_sql_query_case_insensitive_function_transform(self): - db_wrapper = self._make_one(self.settings_dict) qs1 = Author.objects.filter(name__upper__iexact=F("last_name")).values( "name" ) - compiler = SQLCompiler(qs1.query, db_wrapper, "default") + compiler = SQLCompiler(qs1.query, self.connection, "default") sql_compiled, params = compiler.as_sql() self.assertEqual( diff --git a/tests/unit/django_spanner/test_operations.py b/tests/unit/django_spanner/test_operations.py index d359bb084b..ae6384233a 100644 --- a/tests/unit/django_spanner/test_operations.py +++ b/tests/unit/django_spanner/test_operations.py @@ -4,125 +4,108 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -from django.test import SimpleTestCase from django.db.utils import DatabaseError from datetime import timedelta -from django_spanner.operations import DatabaseOperations +from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass -class TestOperations(SimpleTestCase): - def _get_target_class(self): - from django_spanner.base import DatabaseWrapper - - return DatabaseWrapper - - def _make_one(self, *args, **kwargs): - dummy_settings = {"dummy_param": "dummy"} - conn = self._get_target_class()(settings_dict=dummy_settings) - return DatabaseOperations(conn) - +class TestOperations(SpannerSimpleTestClass): def test_max_name_length(self): - db_op = self._make_one() - self.assertEqual(db_op.max_name_length(), 128) + self.assertEqual(self.db_operations.max_name_length(), 128) def test_quote_name(self): - db_op = self._make_one() - quoted_name = db_op.quote_name("abc") + quoted_name = self.db_operations.quote_name("abc") self.assertEqual(quoted_name, "abc") def test_quote_name_spanner_reserved_keyword_escaped(self): - db_op = self._make_one() - quoted_name = db_op.quote_name("ALL") + quoted_name = self.db_operations.quote_name("ALL") self.assertEqual(quoted_name, "`ALL`") def test_bulk_batch_size(self): - db_op = self._make_one() self.assertEqual( - db_op.bulk_batch_size(fields=None, objs=None), - db_op.connection.features.max_query_params, + self.db_operations.bulk_batch_size(fields=None, objs=None), + self.db_operations.connection.features.max_query_params, ) def test_sql_flush(self): from django.core.management.color import no_style - db_op = self._make_one() self.assertEqual( - db_op.sql_flush(style=no_style(), tables=["Table1, Table2"]), + self.db_operations.sql_flush( + style=no_style(), tables=["Table1, Table2"] + ), ["DELETE FROM `Table1, Table2`"], ) def test_sql_flush_empty_table_list(self): from django.core.management.color import no_style - db_op = self._make_one() self.assertEqual( - db_op.sql_flush(style=no_style(), tables=[]), [], + self.db_operations.sql_flush(style=no_style(), tables=[]), [], ) def test_adapt_datefield_value(self): from google.cloud.spanner_dbapi.types import DateStr - db_op = self._make_one() self.assertIsInstance( - db_op.adapt_datefield_value("dummy_date"), DateStr, + self.db_operations.adapt_datefield_value("dummy_date"), DateStr, ) def test_adapt_datefield_value_none(self): - db_op = self._make_one() - self.assertIsNone(db_op.adapt_datefield_value(value=None),) + self.assertIsNone( + self.db_operations.adapt_datefield_value(value=None), + ) def test_adapt_decimalfield_value(self): - db_op = self._make_one() self.assertIsInstance( - db_op.adapt_decimalfield_value(value=1), float, + self.db_operations.adapt_decimalfield_value(value=1), float, ) def test_adapt_decimalfield_value_none(self): - db_op = self._make_one() - self.assertIsNone(db_op.adapt_decimalfield_value(value=None),) + self.assertIsNone( + self.db_operations.adapt_decimalfield_value(value=None), + ) def test_convert_binaryfield_value(self): from base64 import b64encode - db_op = self._make_one() self.assertEqual( - db_op.convert_binaryfield_value( + self.db_operations.convert_binaryfield_value( value=b64encode(b"abc"), expression=None, connection=None ), b"abc", ) def test_convert_binaryfield_value_none(self): - db_op = self._make_one() self.assertIsNone( - db_op.convert_binaryfield_value( + self.db_operations.convert_binaryfield_value( value=None, expression=None, connection=None ), ) def test_adapt_datetimefield_value_none(self): - db_op = self._make_one() - self.assertIsNone(db_op.adapt_datetimefield_value(value=None),) + self.assertIsNone( + self.db_operations.adapt_datetimefield_value(value=None), + ) def test_adapt_timefield_value_none(self): - db_op = self._make_one() - self.assertIsNone(db_op.adapt_timefield_value(value=None),) + self.assertIsNone( + self.db_operations.adapt_timefield_value(value=None), + ) def test_convert_decimalfield_value(self): from decimal import Decimal - db_op = self._make_one() self.assertIsInstance( - db_op.convert_decimalfield_value( + self.db_operations.convert_decimalfield_value( value=1.0, expression=None, connection=None ), Decimal, ) def test_convert_decimalfield_value_none(self): - db_op = self._make_one() self.assertIsNone( - db_op.convert_decimalfield_value( + self.db_operations.convert_decimalfield_value( value=None, expression=None, connection=None ), ) @@ -130,34 +113,30 @@ def test_convert_decimalfield_value_none(self): def test_convert_uuidfield_value(self): import uuid - db_op = self._make_one() uuid_obj = uuid.uuid4() self.assertEqual( - db_op.convert_uuidfield_value( + self.db_operations.convert_uuidfield_value( str(uuid_obj), expression=None, connection=None ), uuid_obj, ) def test_convert_uuidfield_value_none(self): - db_op = self._make_one() self.assertIsNone( - db_op.convert_uuidfield_value( + self.db_operations.convert_uuidfield_value( value=None, expression=None, connection=None ), ) def test_date_extract_sql(self): - db_op = self._make_one() self.assertEqual( - db_op.date_extract_sql("week", "dummy_field"), + self.db_operations.date_extract_sql("week", "dummy_field"), "EXTRACT(isoweek FROM dummy_field)", ) def test_date_extract_sql_lookup_type_dayofweek(self): - db_op = self._make_one() self.assertEqual( - db_op.date_extract_sql("dayofweek", "dummy_field"), + self.db_operations.date_extract_sql("dayofweek", "dummy_field"), "EXTRACT(dayofweek FROM dummy_field)", ) @@ -165,9 +144,10 @@ def test_datetime_extract_sql(self): from django.conf import settings settings.USE_TZ = True - db_op = self._make_one() self.assertEqual( - db_op.datetime_extract_sql("dayofweek", "dummy_field", "IST"), + self.db_operations.datetime_extract_sql( + "dayofweek", "dummy_field", "IST" + ), 'EXTRACT(dayofweek FROM dummy_field AT TIME ZONE "IST")', ) @@ -175,31 +155,29 @@ def test_datetime_extract_sql_use_tz_false(self): from django.conf import settings settings.USE_TZ = False - db_op = self._make_one() self.assertEqual( - db_op.datetime_extract_sql("dayofweek", "dummy_field", "IST"), + self.db_operations.datetime_extract_sql( + "dayofweek", "dummy_field", "IST" + ), 'EXTRACT(dayofweek FROM dummy_field AT TIME ZONE "UTC")', ) settings.USE_TZ = True # reset changes. def test_time_extract_sql(self): - db_op = self._make_one() self.assertEqual( - db_op.time_extract_sql("dayofweek", "dummy_field"), + self.db_operations.time_extract_sql("dayofweek", "dummy_field"), 'EXTRACT(dayofweek FROM dummy_field AT TIME ZONE "UTC")', ) def test_time_trunc_sql(self): - db_op = self._make_one() self.assertEqual( - db_op.time_trunc_sql("dayofweek", "dummy_field"), + self.db_operations.time_trunc_sql("dayofweek", "dummy_field"), 'TIMESTAMP_TRUNC(dummy_field, dayofweek, "UTC")', ) def test_datetime_cast_date_sql(self): - db_op = self._make_one() self.assertEqual( - db_op.datetime_cast_date_sql("dummy_field", "IST"), + self.db_operations.datetime_cast_date_sql("dummy_field", "IST"), 'DATE(dummy_field, "IST")', ) @@ -207,9 +185,8 @@ def test_datetime_cast_time_sql(self): from django.conf import settings settings.USE_TZ = True - db_op = self._make_one() self.assertEqual( - db_op.datetime_cast_time_sql("dummy_field", "IST"), + self.db_operations.datetime_cast_time_sql("dummy_field", "IST"), "TIMESTAMP(FORMAT_TIMESTAMP('%Y-%m-%d %R:%E9S %Z', dummy_field, 'IST'))", ) @@ -217,56 +194,50 @@ def test_datetime_cast_time_sql_use_tz_false(self): from django.conf import settings settings.USE_TZ = False - db_op = self._make_one() self.assertEqual( - db_op.datetime_cast_time_sql("dummy_field", "IST"), + self.db_operations.datetime_cast_time_sql("dummy_field", "IST"), "TIMESTAMP(FORMAT_TIMESTAMP('%Y-%m-%d %R:%E9S %Z', dummy_field, 'UTC'))", ) settings.USE_TZ = True # reset changes. def test_date_interval_sql(self): - db_op = self._make_one() self.assertEqual( - db_op.date_interval_sql(timedelta(days=1)), + self.db_operations.date_interval_sql(timedelta(days=1)), "INTERVAL 86400000000 MICROSECOND", ) def test_format_for_duration_arithmetic(self): - db_op = self._make_one() self.assertEqual( - db_op.format_for_duration_arithmetic(1200), + self.db_operations.format_for_duration_arithmetic(1200), "INTERVAL 1200 MICROSECOND", ) def test_combine_expression_mod(self): - db_op = self._make_one() self.assertEqual( - db_op.combine_expression("%%", ["10", "2"]), "MOD(10, 2)", + self.db_operations.combine_expression("%%", ["10", "2"]), + "MOD(10, 2)", ) def test_combine_expression_power(self): - db_op = self._make_one() self.assertEqual( - db_op.combine_expression("^", ["10", "2"]), "POWER(10, 2)", + self.db_operations.combine_expression("^", ["10", "2"]), + "POWER(10, 2)", ) def test_combine_expression_bit_extention(self): - db_op = self._make_one() self.assertEqual( - db_op.combine_expression(">>", ["10", "2"]), + self.db_operations.combine_expression(">>", ["10", "2"]), "CAST(FLOOR(10 / POW(2, 2)) AS INT64)", ) def test_combine_expression_multiply(self): - db_op = self._make_one() self.assertEqual( - db_op.combine_expression("*", ["10", "2"]), "10 * 2", + self.db_operations.combine_expression("*", ["10", "2"]), "10 * 2", ) def test_combine_duration_expression_add(self): - db_op = self._make_one() self.assertEqual( - db_op.combine_duration_expression( + self.db_operations.combine_duration_expression( "+", ['TIMESTAMP "2008-12-25 15:30:00+00', "INTERVAL 10 MINUTE"], ), @@ -274,9 +245,8 @@ def test_combine_duration_expression_add(self): ) def test_combine_duration_expression_subtract(self): - db_op = self._make_one() self.assertEqual( - db_op.combine_duration_expression( + self.db_operations.combine_duration_expression( "-", ['TIMESTAMP "2008-12-25 15:30:00+00', "INTERVAL 10 MINUTE"], ), @@ -284,22 +254,19 @@ def test_combine_duration_expression_subtract(self): ) def test_combine_duration_expression_database_error(self): - db_op = self._make_one() msg = "Invalid connector for timedelta:" - with self.assertRaisesMessage(DatabaseError, msg): - db_op.combine_duration_expression( + with self.assertRaisesRegex(DatabaseError, msg): + self.db_operations.combine_duration_expression( "*", ['TIMESTAMP "2008-12-25 15:30:00+00', "INTERVAL 10 MINUTE"], ) def test_lookup_cast_match_lookup_type(self): - db_op = self._make_one() self.assertEqual( - db_op.lookup_cast("contains",), "CAST(%s AS STRING)", + self.db_operations.lookup_cast("contains",), "CAST(%s AS STRING)", ) def test_lookup_cast_unmatched_lookup_type(self): - db_op = self._make_one() self.assertEqual( - db_op.lookup_cast("dummy",), "%s", + self.db_operations.lookup_cast("dummy",), "%s", ) diff --git a/tests/unit/django_spanner/test_utils.py b/tests/unit/django_spanner/test_utils.py index 0ebb0a4212..e4d50861d0 100644 --- a/tests/unit/django_spanner/test_utils.py +++ b/tests/unit/django_spanner/test_utils.py @@ -4,15 +4,15 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -import unittest from django_spanner.utils import check_django_compatability from django.core.exceptions import ImproperlyConfigured from django_spanner.utils import add_dummy_where import django import django_spanner +from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass -class TestUtils(unittest.TestCase): +class TestUtils(SpannerSimpleTestClass): SQL_WITH_WHERE = "Select 1 from Table WHERE 1=1" SQL_WITHOUT_WHERE = "Select 1 from Table" diff --git a/tests/unit/django_spanner/test_validation.py b/tests/unit/django_spanner/test_validation.py index 88f62c3c54..5a8946aef1 100644 --- a/tests/unit/django_spanner/test_validation.py +++ b/tests/unit/django_spanner/test_validation.py @@ -4,14 +4,14 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -from django.test import SimpleTestCase from django_spanner.validation import DatabaseValidation from django.db import connection from django.core.checks import Error as DjangoError +from tests.unit.django_spanner.simple_test import SpannerSimpleTestClass from .models import ModelDecimalField, ModelCharField -class TestValidation(SimpleTestCase): +class TestValidation(SpannerSimpleTestClass): def test_check_field_type_with_decimal_field_not_support_error(self): """ Checks if decimal field fails database validation as it's not