Commit

fix: Fix black, isort compatibility (#469)
c24t committed Sep 15, 2020
1 parent 45d6b97 commit dd005d5
Showing 31 changed files with 2,973 additions and 2,472 deletions.
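The hunks below are mechanical reformatting: black rewrites string literals to double quotes and wraps statements that run past the line limit, while isort reorders and regroups the import blocks in __init__.py and compiler.py. The commit itself does not include the formatter configuration, so as a rough sketch only (not taken from this repository; the 79-character limit is just inferred from how the lines are wrapped below), a pyproject.toml that keeps the two tools from fighting over the same files could look like this:

[tool.black]
line-length = 79

[tool.isort]
profile = "black"
line_length = 79

With isort 5's "black" profile, isort adopts black's bracket and trailing-comma style, so running one formatter after the other no longer rewrites the previous tool's output.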
16 changes: 11 additions & 5 deletions django_spanner/__init__.py
@@ -5,13 +5,16 @@
# https://developers.google.com/open-source/licenses/bsd

import datetime

# Monkey-patch AutoField to generate a random value since Cloud Spanner can't
# do that.
from uuid import uuid4

import pkg_resources
from django.db.models.fields import AutoField, Field
-# Monkey-patch google.DatetimeWithNanoseconds's __eq__ compare against datetime.datetime.

+# Monkey-patch google.DatetimeWithNanoseconds's __eq__ compare against
+# datetime.datetime.
from google.api_core.datetime_helpers import DatetimeWithNanoseconds

from .expressions import register_expressions
@@ -33,14 +36,16 @@ def gen_rand_int64():


def autofield_init(self, *args, **kwargs):
-    kwargs['blank'] = True
+    kwargs["blank"] = True
    Field.__init__(self, *args, **kwargs)
    self.default = gen_rand_int64


AutoField.__init__ = autofield_init

-old_datetimewithnanoseconds_eq = getattr(DatetimeWithNanoseconds, '__eq__', None)
+old_datetimewithnanoseconds_eq = getattr(
+    DatetimeWithNanoseconds, "__eq__", None
+)


def datetimewithnanoseconds_eq(self, other):
@@ -62,12 +67,13 @@ def datetimewithnanoseconds_eq(self, other):
DatetimeWithNanoseconds.__eq__ = datetimewithnanoseconds_eq

# Sanity check here since tests can't easily be run for this file:
-if __name__ == '__main__':
+if __name__ == "__main__":
    from django.utils import timezone

    UTC = timezone.utc

    dt = datetime.datetime(2020, 1, 10, 2, 44, 57, 999, UTC)
    dtns = DatetimeWithNanoseconds(2020, 1, 10, 2, 44, 57, 999, UTC)
    equal = dtns == dt
    if not equal:
-        raise Exception('%s\n!=\n%s' % (dtns, dt))
+        raise Exception("%s\n!=\n%s" % (dtns, dt))
4 changes: 3 additions & 1 deletion django_spanner/base.py
@@ -110,7 +110,9 @@ def instance(self):

    @property
    def _nodb_connection(self):
-        raise NotImplementedError('Spanner does not have a "no db" connection.')
+        raise NotImplementedError(
+            'Spanner does not have a "no db" connection.'
+        )

    def get_connection_params(self):
        return {
2 changes: 1 addition & 1 deletion django_spanner/client.py
@@ -9,4 +9,4 @@

class DatabaseClient(BaseDatabaseClient):
    def runshell(self):
-        raise NotImplementedError('dbshell is not implemented.')
+        raise NotImplementedError("dbshell is not implemented.")
59 changes: 42 additions & 17 deletions django_spanner/compiler.py
@@ -7,8 +7,15 @@
from django.core.exceptions import EmptyResultSet
from django.db.models.sql.compiler import (
    SQLAggregateCompiler as BaseSQLAggregateCompiler,
-    SQLCompiler as BaseSQLCompiler, SQLDeleteCompiler as BaseSQLDeleteCompiler,
+)
+from django.db.models.sql.compiler import SQLCompiler as BaseSQLCompiler
+from django.db.models.sql.compiler import (
+    SQLDeleteCompiler as BaseSQLDeleteCompiler,
+)
+from django.db.models.sql.compiler import (
    SQLInsertCompiler as BaseSQLInsertCompiler,
+)
+from django.db.models.sql.compiler import (
    SQLUpdateCompiler as BaseSQLUpdateCompiler,
)
from django.db.utils import DatabaseError
@@ -24,50 +31,68 @@ def get_combinator_sql(self, combinator, all):
        features = self.connection.features
        compilers = [
            query.get_compiler(self.using, self.connection)
-            for query in self.query.combined_queries if not query.is_empty()
+            for query in self.query.combined_queries
+            if not query.is_empty()
        ]
        if not features.supports_slicing_ordering_in_compound:
            for query, compiler in zip(self.query.combined_queries, compilers):
                if query.low_mark or query.high_mark:
-                    raise DatabaseError('LIMIT/OFFSET not allowed in subqueries of compound statements.')
+                    raise DatabaseError(
+                        "LIMIT/OFFSET not allowed in subqueries of compound "
+                        "statements."
+                    )
                if compiler.get_order_by():
-                    raise DatabaseError('ORDER BY not allowed in subqueries of compound statements.')
+                    raise DatabaseError(
+                        "ORDER BY not allowed in subqueries of compound "
+                        "statements."
+                    )
        parts = ()
        for compiler in compilers:
            try:
                # If the columns list is limited, then all combined queries
                # must have the same columns list. Set the selects defined on
                # the query on all combined queries, if not already set.
-                if not compiler.query.values_select and self.query.values_select:
-                    compiler.query.set_values((
-                        *self.query.extra_select,
-                        *self.query.values_select,
-                        *self.query.annotation_select,
-                    ))
+                if (
+                    not compiler.query.values_select
+                    and self.query.values_select
+                ):
+                    compiler.query.set_values(
+                        (
+                            *self.query.extra_select,
+                            *self.query.values_select,
+                            *self.query.annotation_select,
+                        )
+                    )
                part_sql, part_args = compiler.as_sql()
                if compiler.query.combinator:
                    # Wrap in a subquery if wrapping in parentheses isn't
                    # supported.
                    if not features.supports_parentheses_in_compound:
-                        part_sql = 'SELECT * FROM ({})'.format(part_sql)
+                        part_sql = "SELECT * FROM ({})".format(part_sql)
                    # Add parentheses when combining with compound query if not
                    # already added for all compound queries.
                    elif not features.supports_slicing_ordering_in_compound:
-                        part_sql = '({})'.format(part_sql)
+                        part_sql = "({})".format(part_sql)
                parts += ((part_sql, part_args),)
            except EmptyResultSet:
                # Omit the empty queryset with UNION and with DIFFERENCE if the
                # first queryset is nonempty.
-                if combinator == 'union' or (combinator == 'difference' and parts):
+                if combinator == "union" or (
+                    combinator == "difference" and parts
+                ):
                    continue
                raise
        if not parts:
            raise EmptyResultSet
        combinator_sql = self.connection.ops.set_operators[combinator]
-        combinator_sql += ' ALL' if all else ' DISTINCT'
-        braces = '({})' if features.supports_slicing_ordering_in_compound else '{}'
-        sql_parts, args_parts = zip(*((braces.format(sql), args) for sql, args in parts))
-        result = [' {} '.format(combinator_sql).join(sql_parts)]
+        combinator_sql += " ALL" if all else " DISTINCT"
+        braces = (
+            "({})" if features.supports_slicing_ordering_in_compound else "{}"
+        )
+        sql_parts, args_parts = zip(
+            *((braces.format(sql), args) for sql, args in parts)
+        )
+        result = [" {} ".format(combinator_sql).join(sql_parts)]
        params = []
        for part in args_parts:
            params.extend(part)
40 changes: 26 additions & 14 deletions django_spanner/creation.py
@@ -17,18 +17,22 @@ class DatabaseCreation(BaseDatabaseCreation):
    def mark_skips(self):
        """Skip tests that don't work on Spanner."""
        for test_name in self.connection.features.skip_tests:
-            test_case_name, _, method_name = test_name.rpartition('.')
-            test_app = test_name.split('.')[0]
+            test_case_name, _, method_name = test_name.rpartition(".")
+            test_app = test_name.split(".")[0]
            # Importing a test app that isn't installed raises RuntimeError.
            if test_app in settings.INSTALLED_APPS:
                test_case = import_string(test_case_name)
                method = getattr(test_case, method_name)
-                setattr(test_case, method_name, skip('unsupported by Spanner')(method))
+                setattr(
+                    test_case,
+                    method_name,
+                    skip("unsupported by Spanner")(method),
+                )

    def create_test_db(self, *args, **kwargs):
        # This environment variable is set by the Travis build script or
        # by a developer running the tests locally.
-        if os.environ.get('RUNNING_SPANNER_BACKEND_TESTS') == '1':
+        if os.environ.get("RUNNING_SPANNER_BACKEND_TESTS") == "1":
            self.mark_skips()
        super().create_test_db(*args, **kwargs)

@@ -38,7 +42,7 @@ def _create_test_db(self, verbosity, autoclobber, keepdb=False):
        test_database_name = self._get_test_db_name()
        # Don't quote the test database name because google.cloud.spanner_v1
        # does it.
-        test_db_params = {'dbname': test_database_name}
+        test_db_params = {"dbname": test_database_name}
        # Create the test database.
        try:
            self._execute_create_test_db(None, test_db_params, keepdb)
@@ -47,29 +51,37 @@ def _create_test_db(self, verbosity, autoclobber, keepdb=False):
            # just return and skip it all.
            if keepdb:
                return test_database_name
-            self.log('Got an error creating the test database: %s' % e)
+            self.log("Got an error creating the test database: %s" % e)
            if not autoclobber:
                confirm = input(
                    "Type 'yes' if you would like to try deleting the test "
-                    "database '%s', or 'no' to cancel: " % test_database_name)
-            if autoclobber or confirm == 'yes':
+                    "database '%s', or 'no' to cancel: " % test_database_name
+                )
+            if autoclobber or confirm == "yes":
                try:
                    if verbosity >= 1:
-                        self.log('Destroying old test database for alias %s...' % (
-                            self._get_database_display_str(verbosity, test_database_name),
-                        ))
+                        self.log(
+                            "Destroying old test database for alias %s..."
+                            % (
+                                self._get_database_display_str(
+                                    verbosity, test_database_name
+                                ),
+                            )
+                        )
                    self._destroy_test_db(test_database_name, verbosity)
                    self._execute_create_test_db(None, test_db_params, keepdb)
                except Exception as e:
-                    self.log('Got an error recreating the test database: %s' % e)
+                    self.log(
+                        "Got an error recreating the test database: %s" % e
+                    )
                    sys.exit(2)
            else:
-                self.log('Tests cancelled.')
+                self.log("Tests cancelled.")
                sys.exit(1)
        return test_database_name

    def _execute_create_test_db(self, cursor, parameters, keepdb=False):
-        self.connection.instance.database(parameters['dbname']).create()
+        self.connection.instance.database(parameters["dbname"]).create()

    def _destroy_test_db(self, test_database_name, verbosity):
        self.connection.instance.database(test_database_name).drop()
8 changes: 5 additions & 3 deletions django_spanner/expressions.py
@@ -12,10 +12,12 @@ def order_by(self, compiler, connection, **extra_context):
    # DatabaseFeatures.supports_order_by_nulls_modifier = False.
    template = None
    if self.nulls_last:
-        template = '%(expression)s IS NULL, %(expression)s %(ordering)s'
+        template = "%(expression)s IS NULL, %(expression)s %(ordering)s"
    elif self.nulls_first:
-        template = '%(expression)s IS NOT NULL, %(expression)s %(ordering)s'
-    return self.as_sql(compiler, connection, template=template, **extra_context)
+        template = "%(expression)s IS NOT NULL, %(expression)s %(ordering)s"
+    return self.as_sql(
+        compiler, connection, template=template, **extra_context
+    )


def register_expressions():
