Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
1,081 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
DEBUG = True | ||
USE_TZ = True | ||
|
||
INSTALLED_APPS = [ | ||
"django_spanner", # Must be the first entry | ||
"django.contrib.contenttypes", | ||
"django.contrib.auth", | ||
"django.contrib.sites", | ||
"django.contrib.sessions", | ||
"django.contrib.messages", | ||
"django.contrib.staticfiles", | ||
"tests", | ||
] | ||
|
||
TIME_ZONE = "UTC" | ||
|
||
DATABASES = { | ||
"default": { | ||
"ENGINE": "django_spanner", | ||
"PROJECT": "emulator-local", | ||
"INSTANCE": "django-test-instance", | ||
"NAME": "django-test-db", | ||
} | ||
} | ||
SECRET_KEY = "spanner emulator secret key" | ||
|
||
PASSWORD_HASHERS = [ | ||
"django.contrib.auth.hashers.MD5PasswordHasher", | ||
] | ||
|
||
SITE_ID = 1 | ||
|
||
CONN_MAX_AGE = 60 | ||
|
||
ENGINE = "django_spanner" | ||
PROJECT = "emulator-local" | ||
INSTANCE = "django-test-instance" | ||
NAME = "django-test-db" | ||
OPTIONS = {} | ||
AUTOCOMMIT = True |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# Copyright 2020 Google LLC | ||
# | ||
# Use of this source code is governed by a BSD-style | ||
# license that can be found in the LICENSE file or at | ||
# https://developers.google.com/open-source/licenses/bsd | ||
""" | ||
Different models used for testing django-spanner code. | ||
""" | ||
import os | ||
from django.db import models | ||
import django | ||
|
||
# Load django settings before loading dhango models. | ||
os.environ["DJANGO_SETTINGS_MODULE"] = "tests.settings" | ||
django.setup() | ||
|
||
|
||
# Register transformations for model fields. | ||
class UpperCase(models.Transform): | ||
lookup_name = "upper" | ||
function = "UPPER" | ||
bilateral = True | ||
|
||
|
||
models.CharField.register_lookup(UpperCase) | ||
models.TextField.register_lookup(UpperCase) | ||
|
||
|
||
# Models | ||
class ModelDecimalField(models.Model): | ||
field = models.DecimalField() | ||
|
||
|
||
class ModelCharField(models.Model): | ||
field = models.CharField() | ||
|
||
|
||
class Item(models.Model): | ||
item_id = models.IntegerField() | ||
name = models.CharField(max_length=10) | ||
created = models.DateTimeField() | ||
modified = models.DateTimeField(blank=True, null=True) | ||
|
||
class Meta: | ||
ordering = ["name"] | ||
|
||
|
||
class Number(models.Model): | ||
num = models.IntegerField() | ||
decimal_num = models.DecimalField(max_digits=5, decimal_places=2) | ||
item = models.ForeignKey(Item, models.CASCADE) | ||
|
||
|
||
class Author(models.Model): | ||
name = models.CharField(max_length=40) | ||
last_name = models.CharField(max_length=40) | ||
num = models.IntegerField(unique=True) | ||
created = models.DateTimeField() | ||
modified = models.DateTimeField(blank=True, null=True) | ||
|
||
|
||
class Report(models.Model): | ||
name = models.CharField(max_length=10) | ||
creator = models.ForeignKey(Author, models.CASCADE, null=True) | ||
|
||
class Meta: | ||
ordering = ["name"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,193 @@ | ||
# Copyright 2020 Google LLC | ||
# | ||
# Use of this source code is governed by a BSD-style | ||
# license that can be found in the LICENSE file or at | ||
# https://developers.google.com/open-source/licenses/bsd | ||
|
||
from django.test import SimpleTestCase | ||
from django.core.exceptions import EmptyResultSet | ||
from django.db.utils import DatabaseError | ||
from django_spanner.compiler import SQLCompiler | ||
from django.db.models.query import QuerySet | ||
from .models import Number | ||
|
||
|
||
class TestUtils(SimpleTestCase): | ||
settings_dict = {"dummy_param": "dummy"} | ||
|
||
def _get_target_class(self): | ||
from django_spanner.base import DatabaseWrapper | ||
|
||
return DatabaseWrapper | ||
|
||
def _make_one(self, *args, **kwargs): | ||
return self._get_target_class()(*args, **kwargs) | ||
|
||
def test_unsupported_ordering_slicing_raises_db_error(self): | ||
""" | ||
Tries limit/offset and order by in subqueries which are not supported | ||
by spanner. | ||
""" | ||
qs1 = Number.objects.all() | ||
qs2 = Number.objects.all() | ||
msg = "LIMIT/OFFSET not allowed in subqueries of compound statements" | ||
with self.assertRaisesMessage(DatabaseError, msg): | ||
list(qs1.union(qs2[:10])) | ||
msg = "ORDER BY not allowed in subqueries of compound statements" | ||
with self.assertRaisesMessage(DatabaseError, msg): | ||
list(qs1.order_by("id").union(qs2)) | ||
|
||
def test_get_combinator_sql_all_union_sql_generated(self): | ||
""" | ||
Tries union sql generator. | ||
""" | ||
connection = self._make_one(self.settings_dict) | ||
|
||
qs1 = Number.objects.filter(num__lte=1).values("num") | ||
qs2 = Number.objects.filter(num__gte=8).values("num") | ||
qs4 = qs1.union(qs2) | ||
|
||
compiler = SQLCompiler(qs4.query, connection, "default") | ||
sql_compiled, params = compiler.get_combinator_sql("union", True) | ||
self.assertEqual( | ||
sql_compiled, | ||
[ | ||
"SELECT tests_number.num FROM tests_number WHERE " | ||
+ "tests_number.num <= %s UNION ALL SELECT tests_number.num " | ||
+ "FROM tests_number WHERE tests_number.num >= %s" | ||
], | ||
) | ||
self.assertEqual(params, [1, 8]) | ||
|
||
def test_get_combinator_sql_distinct_union_sql_generated(self): | ||
""" | ||
Tries union sql generator with distinct. | ||
""" | ||
connection = self._make_one(self.settings_dict) | ||
|
||
qs1 = Number.objects.filter(num__lte=1).values("num") | ||
qs2 = Number.objects.filter(num__gte=8).values("num") | ||
qs4 = qs1.union(qs2) | ||
|
||
compiler = SQLCompiler(qs4.query, connection, "default") | ||
sql_compiled, params = compiler.get_combinator_sql("union", False) | ||
self.assertEqual( | ||
sql_compiled, | ||
[ | ||
"SELECT tests_number.num FROM tests_number WHERE " | ||
+ "tests_number.num <= %s UNION DISTINCT SELECT " | ||
+ "tests_number.num FROM tests_number WHERE " | ||
+ "tests_number.num >= %s" | ||
], | ||
) | ||
self.assertEqual(params, [1, 8]) | ||
|
||
def test_get_combinator_sql_difference_all_sql_generated(self): | ||
""" | ||
Tries difference sql generator. | ||
""" | ||
connection = self._make_one(self.settings_dict) | ||
|
||
qs1 = Number.objects.filter(num__lte=1).values("num") | ||
qs2 = Number.objects.filter(num__gte=8).values("num") | ||
qs4 = qs1.difference(qs2) | ||
|
||
compiler = SQLCompiler(qs4.query, connection, "default") | ||
sql_compiled, params = compiler.get_combinator_sql("difference", True) | ||
|
||
self.assertEqual( | ||
sql_compiled, | ||
[ | ||
"SELECT tests_number.num FROM tests_number WHERE " | ||
+ "tests_number.num <= %s EXCEPT ALL SELECT tests_number.num " | ||
+ "FROM tests_number WHERE tests_number.num >= %s" | ||
], | ||
) | ||
self.assertEqual(params, [1, 8]) | ||
|
||
def test_get_combinator_sql_difference_distinct_sql_generated(self): | ||
""" | ||
Tries difference sql generator with distinct. | ||
""" | ||
connection = self._make_one(self.settings_dict) | ||
|
||
qs1 = Number.objects.filter(num__lte=1).values("num") | ||
qs2 = Number.objects.filter(num__gte=8).values("num") | ||
qs4 = qs1.difference(qs2) | ||
|
||
compiler = SQLCompiler(qs4.query, connection, "default") | ||
sql_compiled, params = compiler.get_combinator_sql("difference", False) | ||
|
||
self.assertEqual( | ||
sql_compiled, | ||
[ | ||
"SELECT tests_number.num FROM tests_number WHERE " | ||
+ "tests_number.num <= %s EXCEPT DISTINCT SELECT " | ||
+ "tests_number.num FROM tests_number WHERE " | ||
+ "tests_number.num >= %s" | ||
], | ||
) | ||
self.assertEqual(params, [1, 8]) | ||
|
||
def test_get_combinator_sql_union_and_difference_query_together(self): | ||
""" | ||
Tries sql generator with union of queryset with queryset of difference. | ||
""" | ||
connection = self._make_one(self.settings_dict) | ||
|
||
qs1 = Number.objects.filter(num__lte=1).values("num") | ||
qs2 = Number.objects.filter(num__gte=8).values("num") | ||
qs3 = Number.objects.filter(num__exact=10).values("num") | ||
qs4 = qs1.union(qs2.difference(qs3)) | ||
|
||
compiler = SQLCompiler(qs4.query, connection, "default") | ||
sql_compiled, params = compiler.get_combinator_sql("union", False) | ||
self.assertEqual( | ||
sql_compiled, | ||
[ | ||
"SELECT tests_number.num FROM tests_number WHERE " | ||
+ "tests_number.num <= %s UNION DISTINCT (" | ||
+ "SELECT tests_number.num FROM tests_number WHERE " | ||
+ "tests_number.num >= %s EXCEPT DISTINCT " | ||
+ "SELECT tests_number.num FROM tests_number " | ||
+ "WHERE tests_number.num = %s)" | ||
], | ||
) | ||
self.assertEqual(params, [1, 8, 10]) | ||
|
||
def test_get_combinator_sql_parentheses_in_compound_not_supported(self): | ||
""" | ||
Tries sql generator with union of queryset with queryset of difference, | ||
adding support for parentheses in compound sql statement. | ||
""" | ||
connection = self._make_one(self.settings_dict) | ||
|
||
qs1 = Number.objects.filter(num__lte=1).values("num") | ||
qs2 = Number.objects.filter(num__gte=8).values("num") | ||
qs3 = Number.objects.filter(num__exact=10).values("num") | ||
qs4 = qs1.union(qs2.difference(qs3)) | ||
|
||
compiler = SQLCompiler(qs4.query, connection, "default") | ||
compiler.connection.features.supports_parentheses_in_compound = False | ||
sql_compiled, params = compiler.get_combinator_sql("union", False) | ||
self.assertEqual( | ||
sql_compiled, | ||
[ | ||
"SELECT tests_number.num FROM tests_number WHERE " | ||
+ "tests_number.num <= %s UNION DISTINCT SELECT * FROM (" | ||
+ "SELECT tests_number.num FROM tests_number WHERE " | ||
+ "tests_number.num >= %s EXCEPT DISTINCT " | ||
+ "SELECT tests_number.num FROM tests_number " | ||
+ "WHERE tests_number.num = %s)" | ||
], | ||
) | ||
self.assertEqual(params, [1, 8, 10]) | ||
|
||
def test_get_combinator_sql_empty_queryset_raises_exception(self): | ||
""" | ||
Tries sql generator with empty queryset. | ||
""" | ||
connection = self._make_one(self.settings_dict) | ||
compiler = SQLCompiler(QuerySet().query, connection, "default") | ||
with self.assertRaises(EmptyResultSet): | ||
compiler.get_combinator_sql("union", False) |
Oops, something went wrong.