Skip to content

Commit

Permalink
feat: added test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
vi3k6i5 committed Apr 28, 2021
1 parent 129e41e commit a0fae75
Show file tree
Hide file tree
Showing 12 changed files with 1,114 additions and 20 deletions.
22 changes: 7 additions & 15 deletions noxfile.py
Expand Up @@ -66,7 +66,12 @@ def lint_setup_py(session):
def default(session):
# Install all test dependencies, then install this package in-place.
session.install(
"django~=2.2", "mock", "mock-import", "pytest", "pytest-cov"
"django~=2.2",
"mock",
"mock-import",
"pytest",
"pytest-cov",
"coverage",
)
session.install("-e", ".")

Expand All @@ -79,7 +84,7 @@ def default(session):
"--cov-append",
"--cov-config=.coveragerc",
"--cov-report=",
"--cov-fail-under=20",
"--cov-fail-under=70",
os.path.join("tests", "unit"),
*session.posargs
)
Expand All @@ -91,19 +96,6 @@ def unit(session):
default(session)


@nox.session(python=DEFAULT_PYTHON_VERSION)
def cover(session):
"""Run the final coverage report.
This outputs the coverage report aggregating coverage from the unit
test runs (not system test runs), and then erases coverage data.
"""
session.install("coverage", "pytest-cov")
session.run("coverage", "report", "--show-missing", "--fail-under=20")

session.run("coverage", "erase")


@nox.session(python=DEFAULT_PYTHON_VERSION)
def docs(session):
"""Build the docs for this library."""
Expand Down
40 changes: 40 additions & 0 deletions tests/settings.py
@@ -0,0 +1,40 @@
DEBUG = True
USE_TZ = True

INSTALLED_APPS = [
"django_spanner", # Must be the first entry
"django.contrib.contenttypes",
"django.contrib.auth",
"django.contrib.sites",
"django.contrib.sessions",
"django.contrib.messages",
"django.contrib.staticfiles",
"tests",
]

TIME_ZONE = "UTC"

DATABASES = {
"default": {
"ENGINE": "django_spanner",
"PROJECT": "emulator-local",
"INSTANCE": "django-test-instance",
"NAME": "django-test-db",
}
}
SECRET_KEY = "spanner emulator secret key"

PASSWORD_HASHERS = [
"django.contrib.auth.hashers.MD5PasswordHasher",
]

SITE_ID = 1

CONN_MAX_AGE = 60

ENGINE = "django_spanner"
PROJECT = "emulator-local"
INSTANCE = "django-test-instance"
NAME = "django-test-db"
OPTIONS = {}
AUTOCOMMIT = True
Empty file.
67 changes: 67 additions & 0 deletions tests/unit/django_spanner/models.py
@@ -0,0 +1,67 @@
# Copyright 2020 Google LLC
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd
"""
Different models used for testing django-spanner code.
"""
import os
from django.db import models
import django

# Load django settings before loading dhango models.
os.environ["DJANGO_SETTINGS_MODULE"] = "tests.settings"
django.setup()


# Register transformations for model fields.
class UpperCase(models.Transform):
lookup_name = "upper"
function = "UPPER"
bilateral = True


models.CharField.register_lookup(UpperCase)
models.TextField.register_lookup(UpperCase)


# Models
class ModelDecimalField(models.Model):
field = models.DecimalField()


class ModelCharField(models.Model):
field = models.CharField()


class Item(models.Model):
item_id = models.IntegerField()
name = models.CharField(max_length=10)
created = models.DateTimeField()
modified = models.DateTimeField(blank=True, null=True)

class Meta:
ordering = ["name"]


class Number(models.Model):
num = models.IntegerField()
decimal_num = models.DecimalField(max_digits=5, decimal_places=2)
item = models.ForeignKey(Item, models.CASCADE)


class Author(models.Model):
name = models.CharField(max_length=40)
last_name = models.CharField(max_length=40)
num = models.IntegerField(unique=True)
created = models.DateTimeField()
modified = models.DateTimeField(blank=True, null=True)


class Report(models.Model):
name = models.CharField(max_length=10)
creator = models.ForeignKey(Author, models.CASCADE, null=True)

class Meta:
ordering = ["name"]
6 changes: 3 additions & 3 deletions tests/unit/django_spanner/test_base.py
Expand Up @@ -49,7 +49,7 @@ def test_property_instance(self):
_ = db_wrapper.instance
mock_instance.assert_called_once_with(settings_dict["INSTANCE"])

def test_property__nodb_connection(self):
def test_property_nodb_connection(self):
db_wrapper = self._make_one(None)
with self.assertRaises(NotImplementedError):
db_wrapper._nodb_connection()
Expand Down Expand Up @@ -86,7 +86,7 @@ def test_create_cursor(self):
db_wrapper.create_cursor()
mock_cursor.assert_called_once_with()

def test__set_autocommit(self):
def test_set_autocommit(self):
db_wrapper = self._make_one(self.settings_dict)
db_wrapper.connection = mock_connection = mock.MagicMock()
mock_connection.autocommit = False
Expand All @@ -110,7 +110,7 @@ def test_is_usable(self):
mock_connection.cursor = mock.MagicMock(side_effect=Error)
self.assertFalse(db_wrapper.is_usable())

def test__start_transaction_under_autocommit(self):
def test_start_transaction_under_autocommit(self):
db_wrapper = self._make_one(self.settings_dict)
db_wrapper.connection = mock_connection = mock.MagicMock()
mock_connection.cursor = mock_cursor = mock.MagicMock()
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/django_spanner/test_client.py
Expand Up @@ -7,6 +7,7 @@
import sys
import unittest
import os
from google.cloud.spanner_dbapi.exceptions import NotSupportedError


@unittest.skipIf(
Expand Down Expand Up @@ -36,8 +37,6 @@ def _make_one(self, *args, **kwargs):
return self._get_target_class()(*args, **kwargs)

def test_runshell(self):
from google.cloud.spanner_dbapi.exceptions import NotSupportedError

db_wrapper = self._make_one(self.settings_dict)

with self.assertRaises(NotSupportedError):
Expand Down
199 changes: 199 additions & 0 deletions tests/unit/django_spanner/test_compiler.py
@@ -0,0 +1,199 @@
# Copyright 2020 Google LLC
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

import sys
import unittest

from django.test import SimpleTestCase
from django.core.exceptions import EmptyResultSet
from django.db.utils import DatabaseError
from django_spanner.compiler import SQLCompiler
from django.db.models.query import QuerySet
from .models import Number


@unittest.skipIf(
sys.version_info < (3, 6), reason="Skipping Python versions <= 3.5"
)
class TestUtils(SimpleTestCase):
settings_dict = {"dummy_param": "dummy"}

def _get_target_class(self):
from django_spanner.base import DatabaseWrapper

return DatabaseWrapper

def _make_one(self, *args, **kwargs):
return self._get_target_class()(*args, **kwargs)

def test_unsupported_ordering_slicing_raises_db_error(self):
"""
Tries limit/offset and order by in subqueries which are not supported
by spanner.
"""
qs1 = Number.objects.all()
qs2 = Number.objects.all()
msg = "LIMIT/OFFSET not allowed in subqueries of compound statements"
with self.assertRaisesMessage(DatabaseError, msg):
list(qs1.union(qs2[:10]))
msg = "ORDER BY not allowed in subqueries of compound statements"
with self.assertRaisesMessage(DatabaseError, msg):
list(qs1.order_by("id").union(qs2))

def test_get_combinator_sql_all_union_sql_generated(self):
"""
Tries union sql generator.
"""
connection = self._make_one(self.settings_dict)

qs1 = Number.objects.filter(num__lte=1).values("num")
qs2 = Number.objects.filter(num__gte=8).values("num")
qs4 = qs1.union(qs2)

compiler = SQLCompiler(qs4.query, connection, "default")
sql_compiled, params = compiler.get_combinator_sql("union", True)
self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num <= %s UNION ALL SELECT tests_number.num "
+ "FROM tests_number WHERE tests_number.num >= %s"
],
)
self.assertEqual(params, [1, 8])

def test_get_combinator_sql_distinct_union_sql_generated(self):
"""
Tries union sql generator with distinct.
"""
connection = self._make_one(self.settings_dict)

qs1 = Number.objects.filter(num__lte=1).values("num")
qs2 = Number.objects.filter(num__gte=8).values("num")
qs4 = qs1.union(qs2)

compiler = SQLCompiler(qs4.query, connection, "default")
sql_compiled, params = compiler.get_combinator_sql("union", False)
self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num <= %s UNION DISTINCT SELECT "
+ "tests_number.num FROM tests_number WHERE "
+ "tests_number.num >= %s"
],
)
self.assertEqual(params, [1, 8])

def test_get_combinator_sql_difference_all_sql_generated(self):
"""
Tries difference sql generator.
"""
connection = self._make_one(self.settings_dict)

qs1 = Number.objects.filter(num__lte=1).values("num")
qs2 = Number.objects.filter(num__gte=8).values("num")
qs4 = qs1.difference(qs2)

compiler = SQLCompiler(qs4.query, connection, "default")
sql_compiled, params = compiler.get_combinator_sql("difference", True)

self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num <= %s EXCEPT ALL SELECT tests_number.num "
+ "FROM tests_number WHERE tests_number.num >= %s"
],
)
self.assertEqual(params, [1, 8])

def test_get_combinator_sql_difference_distinct_sql_generated(self):
"""
Tries difference sql generator with distinct.
"""
connection = self._make_one(self.settings_dict)

qs1 = Number.objects.filter(num__lte=1).values("num")
qs2 = Number.objects.filter(num__gte=8).values("num")
qs4 = qs1.difference(qs2)

compiler = SQLCompiler(qs4.query, connection, "default")
sql_compiled, params = compiler.get_combinator_sql("difference", False)

self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num <= %s EXCEPT DISTINCT SELECT "
+ "tests_number.num FROM tests_number WHERE "
+ "tests_number.num >= %s"
],
)
self.assertEqual(params, [1, 8])

def test_get_combinator_sql_union_and_difference_query_together(self):
"""
Tries sql generator with union of queryset with queryset of difference.
"""
connection = self._make_one(self.settings_dict)

qs1 = Number.objects.filter(num__lte=1).values("num")
qs2 = Number.objects.filter(num__gte=8).values("num")
qs3 = Number.objects.filter(num__exact=10).values("num")
qs4 = qs1.union(qs2.difference(qs3))

compiler = SQLCompiler(qs4.query, connection, "default")
sql_compiled, params = compiler.get_combinator_sql("union", False)
self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num <= %s UNION DISTINCT ("
+ "SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num >= %s EXCEPT DISTINCT "
+ "SELECT tests_number.num FROM tests_number "
+ "WHERE tests_number.num = %s)"
],
)
self.assertEqual(params, [1, 8, 10])

def test_get_combinator_sql_parentheses_in_compound_not_supported(self):
"""
Tries sql generator with union of queryset with queryset of difference,
adding support for parentheses in compound sql statement.
"""
connection = self._make_one(self.settings_dict)

qs1 = Number.objects.filter(num__lte=1).values("num")
qs2 = Number.objects.filter(num__gte=8).values("num")
qs3 = Number.objects.filter(num__exact=10).values("num")
qs4 = qs1.union(qs2.difference(qs3))

compiler = SQLCompiler(qs4.query, connection, "default")
compiler.connection.features.supports_parentheses_in_compound = False
sql_compiled, params = compiler.get_combinator_sql("union", False)
self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num <= %s UNION DISTINCT SELECT * FROM ("
+ "SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num >= %s EXCEPT DISTINCT "
+ "SELECT tests_number.num FROM tests_number "
+ "WHERE tests_number.num = %s)"
],
)
self.assertEqual(params, [1, 8, 10])

def test_get_combinator_sql_empty_queryset_raises_exception(self):
"""
Tries sql generator with empty queryset.
"""
connection = self._make_one(self.settings_dict)
compiler = SQLCompiler(QuerySet().query, connection, "default")
with self.assertRaises(EmptyResultSet):
compiler.get_combinator_sql("union", False)

0 comments on commit a0fae75

Please sign in to comment.