Skip to content

Commit

Permalink
feat: adding unit tests for django spanner
Browse files Browse the repository at this point in the history
  • Loading branch information
vi3k6i5 committed Apr 21, 2021
1 parent 798e88d commit ec28c1c
Show file tree
Hide file tree
Showing 15 changed files with 1,256 additions and 8 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Expand Up @@ -24,3 +24,7 @@ django_tests_dir

# Built documentation
docs/_build

# mac hidden files.
.DS_Store

5 changes: 4 additions & 1 deletion django_spanner/lookups.py
Expand Up @@ -101,7 +101,10 @@ def iexact(self, compiler, connection):
# lhs_sql is the expression/column to use as the regular expression.
# Use concat to make the value case-insensitive.
lhs_sql = "CONCAT('^(?i)', " + lhs_sql + ", '$')"
rhs_sql = rhs_sql.replace("%%s", "%s")
if not self.rhs_is_direct_value() and not params:
# If rhs is not a direct value and parameter is not present we want
# to have only 1 formatable argument in rhs_sql else we need 2.
rhs_sql = rhs_sql.replace("%%s", "%s")
# rhs_sql is REGEXP_CONTAINS(%s, %%s), and lhs_sql is the column name.
return rhs_sql % lhs_sql, params

Expand Down
9 changes: 7 additions & 2 deletions noxfile.py
Expand Up @@ -66,7 +66,12 @@ def lint_setup_py(session):
def default(session):
# Install all test dependencies, then install this package in-place.
session.install(
"django~=2.2", "mock", "mock-import", "pytest", "pytest-cov"
"django~=2.2",
"mock",
"mock-import",
"pytest",
"pytest-cov",
"coverage",
)
session.install("-e", ".")

Expand All @@ -79,7 +84,7 @@ def default(session):
"--cov-append",
"--cov-config=.coveragerc",
"--cov-report=",
"--cov-fail-under=25",
"--cov-fail-under=80",
os.path.join("tests", "unit"),
*session.posargs
)
Expand Down
40 changes: 40 additions & 0 deletions tests/settings.py
@@ -0,0 +1,40 @@
DEBUG = True
USE_TZ = True

INSTALLED_APPS = [
"django_spanner", # Must be the first entry
"django.contrib.contenttypes",
"django.contrib.auth",
"django.contrib.sites",
"django.contrib.sessions",
"django.contrib.messages",
"django.contrib.staticfiles",
"tests",
]

TIME_ZONE = "UTC"

DATABASES = {
"default": {
"ENGINE": "django_spanner",
"PROJECT": "emulator-local",
"INSTANCE": "django-test-instance",
"NAME": "django-test-db",
}
}
SECRET_KEY = "spanner emulator secret key"

PASSWORD_HASHERS = [
"django.contrib.auth.hashers.MD5PasswordHasher",
]

SITE_ID = 1

CONN_MAX_AGE = 60

ENGINE = "django_spanner"
PROJECT = "emulator-local"
INSTANCE = "django-test-instance"
NAME = "django-test-db"
OPTIONS = {}
AUTOCOMMIT = True
Empty file.
64 changes: 64 additions & 0 deletions tests/unit/django_spanner/models.py
@@ -0,0 +1,64 @@
"""
Different models used for testing django-spanner code.
"""
import os
from django.db import models
import django
from django.db.models import Transform
from django.db.models import CharField, TextField

# Load django settings before loading dhango models.
os.environ["DJANGO_SETTINGS_MODULE"] = "tests.settings"
django.setup()


# Register transformations for model fields.
class UpperCase(Transform):
lookup_name = "upper"
function = "UPPER"
bilateral = True


CharField.register_lookup(UpperCase)
TextField.register_lookup(UpperCase)


# Models
class ModelDecimalField(models.Model):
field = models.DecimalField()


class ModelCharField(models.Model):
field = models.CharField()


class Item(models.Model):
item_id = models.IntegerField()
name = models.CharField(max_length=10)
created = models.DateTimeField()
modified = models.DateTimeField(blank=True, null=True)

class Meta:
ordering = ["name"]


class Number(models.Model):
num = models.IntegerField()
decimal_num = models.DecimalField(max_digits=5, decimal_places=2)
item = models.ForeignKey(Item, models.CASCADE)


class Author(models.Model):
name = models.CharField(max_length=40)
last_name = models.CharField(max_length=40)
num = models.IntegerField(unique=True)
created = models.DateTimeField()
modified = models.DateTimeField(blank=True, null=True)


class Report(models.Model):
name = models.CharField(max_length=10)
creator = models.ForeignKey(Author, models.CASCADE, null=True)

class Meta:
ordering = ["name"]
6 changes: 3 additions & 3 deletions tests/unit/django_spanner/test_base.py
Expand Up @@ -49,7 +49,7 @@ def test_property_instance(self):
_ = db_wrapper.instance
mock_instance.assert_called_once_with(settings_dict["INSTANCE"])

def test_property__nodb_connection(self):
def test_property_nodb_connection(self):
db_wrapper = self._make_one(None)
with self.assertRaises(NotImplementedError):
db_wrapper._nodb_connection()
Expand Down Expand Up @@ -86,7 +86,7 @@ def test_create_cursor(self):
db_wrapper.create_cursor()
mock_cursor.assert_called_once_with()

def test__set_autocommit(self):
def test_set_autocommit(self):
db_wrapper = self._make_one(self.settings_dict)
db_wrapper.connection = mock_connection = mock.MagicMock()
mock_connection.autocommit = False
Expand All @@ -110,7 +110,7 @@ def test_is_usable(self):
mock_connection.cursor = mock.MagicMock(side_effect=Error)
self.assertFalse(db_wrapper.is_usable())

def test__start_transaction_under_autocommit(self):
def test_start_transaction_under_autocommit(self):
db_wrapper = self._make_one(self.settings_dict)
db_wrapper.connection = mock_connection = mock.MagicMock()
mock_connection.cursor = mock_cursor = mock.MagicMock()
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/django_spanner/test_client.py
Expand Up @@ -7,6 +7,7 @@
import sys
import unittest
import os
from google.cloud.spanner_dbapi.exceptions import NotSupportedError


@unittest.skipIf(
Expand Down Expand Up @@ -36,8 +37,6 @@ def _make_one(self, *args, **kwargs):
return self._get_target_class()(*args, **kwargs)

def test_runshell(self):
from google.cloud.spanner_dbapi.exceptions import NotSupportedError

db_wrapper = self._make_one(self.settings_dict)

with self.assertRaises(NotSupportedError):
Expand Down
199 changes: 199 additions & 0 deletions tests/unit/django_spanner/test_compiler.py
@@ -0,0 +1,199 @@
# Copyright 2020 Google LLC
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

import sys
import unittest

from django.test import SimpleTestCase
from django.core.exceptions import EmptyResultSet
from django.db.utils import DatabaseError
from django_spanner.compiler import SQLCompiler
from django.db.models.query import QuerySet
from .models import Number


@unittest.skipIf(
sys.version_info < (3, 6), reason="Skipping Python versions <= 3.5"
)
class TestUtils(SimpleTestCase):
settings_dict = {"dummy_param": "dummy"}

def _get_target_class(self):
from django_spanner.base import DatabaseWrapper

return DatabaseWrapper

def _make_one(self, *args, **kwargs):
return self._get_target_class()(*args, **kwargs)

def test_unsupported_ordering_slicing_raises_db_error(self):
"""
Tries limit/offset and order by in subqueries which are not supported
by spanner.
"""
qs1 = Number.objects.all()
qs2 = Number.objects.all()
msg = "LIMIT/OFFSET not allowed in subqueries of compound statements"
with self.assertRaisesMessage(DatabaseError, msg):
list(qs1.union(qs2[:10]))
msg = "ORDER BY not allowed in subqueries of compound statements"
with self.assertRaisesMessage(DatabaseError, msg):
list(qs1.order_by("id").union(qs2))

def test_get_combinator_sql_all_union_sql_generated(self):
"""
Tries union sql generator.
"""
connection = self._make_one(self.settings_dict)

qs1 = Number.objects.filter(num__lte=1).values("num")
qs2 = Number.objects.filter(num__gte=8).values("num")
qs4 = qs1.union(qs2)

compiler = SQLCompiler(qs4.query, connection, "default")
sql_compiled, params = compiler.get_combinator_sql("union", True)
self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num <= %s UNION ALL SELECT tests_number.num "
+ "FROM tests_number WHERE tests_number.num >= %s"
],
)
self.assertEqual(params, [1, 8])

def test_get_combinator_sql_distinct_union_sql_generated(self):
"""
Tries union sql generator with distinct.
"""
connection = self._make_one(self.settings_dict)

qs1 = Number.objects.filter(num__lte=1).values("num")
qs2 = Number.objects.filter(num__gte=8).values("num")
qs4 = qs1.union(qs2)

compiler = SQLCompiler(qs4.query, connection, "default")
sql_compiled, params = compiler.get_combinator_sql("union", False)
self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num <= %s UNION DISTINCT SELECT "
+ "tests_number.num FROM tests_number WHERE "
+ "tests_number.num >= %s"
],
)
self.assertEqual(params, [1, 8])

def test_get_combinator_sql_difference_all_sql_generated(self):
"""
Tries difference sql generator.
"""
connection = self._make_one(self.settings_dict)

qs1 = Number.objects.filter(num__lte=1).values("num")
qs2 = Number.objects.filter(num__gte=8).values("num")
qs4 = qs1.difference(qs2)

compiler = SQLCompiler(qs4.query, connection, "default")
sql_compiled, params = compiler.get_combinator_sql("difference", True)

self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num <= %s EXCEPT ALL SELECT tests_number.num "
+ "FROM tests_number WHERE tests_number.num >= %s"
],
)
self.assertEqual(params, [1, 8])

def test_get_combinator_sql_difference_distinct_sql_generated(self):
"""
Tries difference sql generator with distinct.
"""
connection = self._make_one(self.settings_dict)

qs1 = Number.objects.filter(num__lte=1).values("num")
qs2 = Number.objects.filter(num__gte=8).values("num")
qs4 = qs1.difference(qs2)

compiler = SQLCompiler(qs4.query, connection, "default")
sql_compiled, params = compiler.get_combinator_sql("difference", False)

self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num <= %s EXCEPT DISTINCT SELECT "
+ "tests_number.num FROM tests_number WHERE "
+ "tests_number.num >= %s"
],
)
self.assertEqual(params, [1, 8])

def test_get_combinator_sql_union_and_difference_query_together(self):
"""
Tries sql generator with union of queryset with queryset of difference.
"""
connection = self._make_one(self.settings_dict)

qs1 = Number.objects.filter(num__lte=1).values("num")
qs2 = Number.objects.filter(num__gte=8).values("num")
qs3 = Number.objects.filter(num__exact=10).values("num")
qs4 = qs1.union(qs2.difference(qs3))

compiler = SQLCompiler(qs4.query, connection, "default")
sql_compiled, params = compiler.get_combinator_sql("union", False)
self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num <= %s UNION DISTINCT ("
+ "SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num >= %s EXCEPT DISTINCT "
+ "SELECT tests_number.num FROM tests_number "
+ "WHERE tests_number.num = %s)"
],
)
self.assertEqual(params, [1, 8, 10])

def test_get_combinator_sql_parentheses_in_compound_not_supported(self):
"""
Tries sql generator with union of queryset with queryset of difference,
adding support for parentheses in compound sql statement.
"""
connection = self._make_one(self.settings_dict)

qs1 = Number.objects.filter(num__lte=1).values("num")
qs2 = Number.objects.filter(num__gte=8).values("num")
qs3 = Number.objects.filter(num__exact=10).values("num")
qs4 = qs1.union(qs2.difference(qs3))

compiler = SQLCompiler(qs4.query, connection, "default")
compiler.connection.features.supports_parentheses_in_compound = False
sql_compiled, params = compiler.get_combinator_sql("union", False)
self.assertEqual(
sql_compiled,
[
"SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num <= %s UNION DISTINCT SELECT * FROM ("
+ "SELECT tests_number.num FROM tests_number WHERE "
+ "tests_number.num >= %s EXCEPT DISTINCT "
+ "SELECT tests_number.num FROM tests_number "
+ "WHERE tests_number.num = %s)"
],
)
self.assertEqual(params, [1, 8, 10])

def test_get_combinator_sql_empty_queryset_raises_exception(self):
"""
Tries sql generator with empty queryset.
"""
connection = self._make_one(self.settings_dict)
compiler = SQLCompiler(QuerySet().query, connection, "default")
with self.assertRaises(EmptyResultSet):
compiler.get_combinator_sql("union", False)

0 comments on commit ec28c1c

Please sign in to comment.