Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix black, isort compatibility #469

Merged
merged 11 commits into from Aug 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 11 additions & 5 deletions django_spanner/__init__.py
Expand Up @@ -5,13 +5,16 @@
# https://developers.google.com/open-source/licenses/bsd

import datetime

# Monkey-patch AutoField to generate a random value since Cloud Spanner can't
# do that.
from uuid import uuid4

import pkg_resources
from django.db.models.fields import AutoField, Field
# Monkey-patch google.DatetimeWithNanoseconds's __eq__ compare against datetime.datetime.

# Monkey-patch google.DatetimeWithNanoseconds's __eq__ compare against
# datetime.datetime.
from google.api_core.datetime_helpers import DatetimeWithNanoseconds

from .expressions import register_expressions
Expand All @@ -33,14 +36,16 @@ def gen_rand_int64():


def autofield_init(self, *args, **kwargs):
kwargs['blank'] = True
kwargs["blank"] = True
Field.__init__(self, *args, **kwargs)
self.default = gen_rand_int64


AutoField.__init__ = autofield_init

old_datetimewithnanoseconds_eq = getattr(DatetimeWithNanoseconds, '__eq__', None)
old_datetimewithnanoseconds_eq = getattr(
DatetimeWithNanoseconds, "__eq__", None
)


def datetimewithnanoseconds_eq(self, other):
Expand All @@ -62,12 +67,13 @@ def datetimewithnanoseconds_eq(self, other):
DatetimeWithNanoseconds.__eq__ = datetimewithnanoseconds_eq

# Sanity check here since tests can't easily be run for this file:
if __name__ == '__main__':
if __name__ == "__main__":
from django.utils import timezone

UTC = timezone.utc

dt = datetime.datetime(2020, 1, 10, 2, 44, 57, 999, UTC)
dtns = DatetimeWithNanoseconds(2020, 1, 10, 2, 44, 57, 999, UTC)
equal = dtns == dt
if not equal:
raise Exception('%s\n!=\n%s' % (dtns, dt))
raise Exception("%s\n!=\n%s" % (dtns, dt))
4 changes: 3 additions & 1 deletion django_spanner/base.py
Expand Up @@ -110,7 +110,9 @@ def instance(self):

@property
def _nodb_connection(self):
raise NotImplementedError('Spanner does not have a "no db" connection.')
raise NotImplementedError(
'Spanner does not have a "no db" connection.'
)

def get_connection_params(self):
return {
Expand Down
2 changes: 1 addition & 1 deletion django_spanner/client.py
Expand Up @@ -9,4 +9,4 @@

class DatabaseClient(BaseDatabaseClient):
def runshell(self):
raise NotImplementedError('dbshell is not implemented.')
raise NotImplementedError("dbshell is not implemented.")
59 changes: 42 additions & 17 deletions django_spanner/compiler.py
Expand Up @@ -7,8 +7,15 @@
from django.core.exceptions import EmptyResultSet
from django.db.models.sql.compiler import (
SQLAggregateCompiler as BaseSQLAggregateCompiler,
SQLCompiler as BaseSQLCompiler, SQLDeleteCompiler as BaseSQLDeleteCompiler,
)
from django.db.models.sql.compiler import SQLCompiler as BaseSQLCompiler
from django.db.models.sql.compiler import (
SQLDeleteCompiler as BaseSQLDeleteCompiler,
)
from django.db.models.sql.compiler import (
SQLInsertCompiler as BaseSQLInsertCompiler,
)
from django.db.models.sql.compiler import (
SQLUpdateCompiler as BaseSQLUpdateCompiler,
)
from django.db.utils import DatabaseError
Expand All @@ -24,50 +31,68 @@ def get_combinator_sql(self, combinator, all):
features = self.connection.features
compilers = [
query.get_compiler(self.using, self.connection)
for query in self.query.combined_queries if not query.is_empty()
for query in self.query.combined_queries
if not query.is_empty()
]
if not features.supports_slicing_ordering_in_compound:
for query, compiler in zip(self.query.combined_queries, compilers):
if query.low_mark or query.high_mark:
raise DatabaseError('LIMIT/OFFSET not allowed in subqueries of compound statements.')
raise DatabaseError(
"LIMIT/OFFSET not allowed in subqueries of compound "
"statements."
)
if compiler.get_order_by():
raise DatabaseError('ORDER BY not allowed in subqueries of compound statements.')
raise DatabaseError(
"ORDER BY not allowed in subqueries of compound "
"statements."
)
parts = ()
for compiler in compilers:
try:
# If the columns list is limited, then all combined queries
# must have the same columns list. Set the selects defined on
# the query on all combined queries, if not already set.
if not compiler.query.values_select and self.query.values_select:
compiler.query.set_values((
*self.query.extra_select,
*self.query.values_select,
*self.query.annotation_select,
))
if (
not compiler.query.values_select
and self.query.values_select
):
compiler.query.set_values(
(
*self.query.extra_select,
*self.query.values_select,
*self.query.annotation_select,
)
)
part_sql, part_args = compiler.as_sql()
if compiler.query.combinator:
# Wrap in a subquery if wrapping in parentheses isn't
# supported.
if not features.supports_parentheses_in_compound:
part_sql = 'SELECT * FROM ({})'.format(part_sql)
part_sql = "SELECT * FROM ({})".format(part_sql)
# Add parentheses when combining with compound query if not
# already added for all compound queries.
elif not features.supports_slicing_ordering_in_compound:
part_sql = '({})'.format(part_sql)
part_sql = "({})".format(part_sql)
parts += ((part_sql, part_args),)
except EmptyResultSet:
# Omit the empty queryset with UNION and with DIFFERENCE if the
# first queryset is nonempty.
if combinator == 'union' or (combinator == 'difference' and parts):
if combinator == "union" or (
combinator == "difference" and parts
):
continue
raise
if not parts:
raise EmptyResultSet
combinator_sql = self.connection.ops.set_operators[combinator]
combinator_sql += ' ALL' if all else ' DISTINCT'
braces = '({})' if features.supports_slicing_ordering_in_compound else '{}'
sql_parts, args_parts = zip(*((braces.format(sql), args) for sql, args in parts))
result = [' {} '.format(combinator_sql).join(sql_parts)]
combinator_sql += " ALL" if all else " DISTINCT"
braces = (
"({})" if features.supports_slicing_ordering_in_compound else "{}"
)
sql_parts, args_parts = zip(
*((braces.format(sql), args) for sql, args in parts)
)
result = [" {} ".format(combinator_sql).join(sql_parts)]
params = []
for part in args_parts:
params.extend(part)
Expand Down
40 changes: 26 additions & 14 deletions django_spanner/creation.py
Expand Up @@ -17,18 +17,22 @@ class DatabaseCreation(BaseDatabaseCreation):
def mark_skips(self):
"""Skip tests that don't work on Spanner."""
for test_name in self.connection.features.skip_tests:
test_case_name, _, method_name = test_name.rpartition('.')
test_app = test_name.split('.')[0]
test_case_name, _, method_name = test_name.rpartition(".")
test_app = test_name.split(".")[0]
# Importing a test app that isn't installed raises RuntimeError.
if test_app in settings.INSTALLED_APPS:
test_case = import_string(test_case_name)
method = getattr(test_case, method_name)
setattr(test_case, method_name, skip('unsupported by Spanner')(method))
setattr(
test_case,
method_name,
skip("unsupported by Spanner")(method),
)

def create_test_db(self, *args, **kwargs):
# This environment variable is set by the Travis build script or
# by a developer running the tests locally.
if os.environ.get('RUNNING_SPANNER_BACKEND_TESTS') == '1':
if os.environ.get("RUNNING_SPANNER_BACKEND_TESTS") == "1":
self.mark_skips()
super().create_test_db(*args, **kwargs)

Expand All @@ -38,7 +42,7 @@ def _create_test_db(self, verbosity, autoclobber, keepdb=False):
test_database_name = self._get_test_db_name()
# Don't quote the test database name because google.cloud.spanner_v1
# does it.
test_db_params = {'dbname': test_database_name}
test_db_params = {"dbname": test_database_name}
# Create the test database.
try:
self._execute_create_test_db(None, test_db_params, keepdb)
Expand All @@ -47,29 +51,37 @@ def _create_test_db(self, verbosity, autoclobber, keepdb=False):
# just return and skip it all.
if keepdb:
return test_database_name
self.log('Got an error creating the test database: %s' % e)
self.log("Got an error creating the test database: %s" % e)
if not autoclobber:
confirm = input(
"Type 'yes' if you would like to try deleting the test "
"database '%s', or 'no' to cancel: " % test_database_name)
if autoclobber or confirm == 'yes':
"database '%s', or 'no' to cancel: " % test_database_name
)
if autoclobber or confirm == "yes":
try:
if verbosity >= 1:
self.log('Destroying old test database for alias %s...' % (
self._get_database_display_str(verbosity, test_database_name),
))
self.log(
"Destroying old test database for alias %s..."
% (
self._get_database_display_str(
verbosity, test_database_name
),
)
)
self._destroy_test_db(test_database_name, verbosity)
self._execute_create_test_db(None, test_db_params, keepdb)
except Exception as e:
self.log('Got an error recreating the test database: %s' % e)
self.log(
"Got an error recreating the test database: %s" % e
)
sys.exit(2)
else:
self.log('Tests cancelled.')
self.log("Tests cancelled.")
sys.exit(1)
return test_database_name

def _execute_create_test_db(self, cursor, parameters, keepdb=False):
self.connection.instance.database(parameters['dbname']).create()
self.connection.instance.database(parameters["dbname"]).create()

def _destroy_test_db(self, test_database_name, verbosity):
self.connection.instance.database(test_database_name).drop()
8 changes: 5 additions & 3 deletions django_spanner/expressions.py
Expand Up @@ -12,10 +12,12 @@ def order_by(self, compiler, connection, **extra_context):
# DatabaseFeatures.supports_order_by_nulls_modifier = False.
template = None
if self.nulls_last:
template = '%(expression)s IS NULL, %(expression)s %(ordering)s'
template = "%(expression)s IS NULL, %(expression)s %(ordering)s"
elif self.nulls_first:
template = '%(expression)s IS NOT NULL, %(expression)s %(ordering)s'
return self.as_sql(compiler, connection, template=template, **extra_context)
template = "%(expression)s IS NOT NULL, %(expression)s %(ordering)s"
return self.as_sql(
compiler, connection, template=template, **extra_context
)


def register_expressions():
Expand Down