Skip to content

Commit

Permalink
feat: Add dummy WHERE clause to certain statements (#516)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ilya Gurov committed Nov 18, 2020
1 parent 196c449 commit af5d8e3
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 34 deletions.
4 changes: 3 additions & 1 deletion django_spanner/compiler.py
Expand Up @@ -12,7 +12,7 @@
SQLInsertCompiler as BaseSQLInsertCompiler,
SQLUpdateCompiler as BaseSQLUpdateCompiler,
)
from django.db.utils import DatabaseError
from django.db.utils import DatabaseError, add_dummy_where


class SQLCompiler(BaseSQLCompiler):
Expand Down Expand Up @@ -90,6 +90,8 @@ def get_combinator_sql(self, combinator, all):
params = []
for part in args_parts:
params.extend(part)

result = add_dummy_where(result)
return result, params


Expand Down
21 changes: 21 additions & 0 deletions django_spanner/utils.py
@@ -1,4 +1,11 @@
# Copyright 2020 Google LLC
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

import django
import sqlparse
from django.core.exceptions import ImproperlyConfigured
from django.utils.version import get_version_tuple

Expand All @@ -18,3 +25,17 @@ def check_django_compatability():
A=django.VERSION[0], B=django.VERSION[1], C=__version__
)
)


def add_dummy_where(sql):
"""
Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements.
Add a dummy WHERE clause if necessary.
"""
if any(
isinstance(token, sqlparse.sql.Where)
for token in sqlparse.parse(sql)[0]
):
return sql

return sql + " WHERE 1=1"
2 changes: 1 addition & 1 deletion google/cloud/spanner_dbapi/cursor.py
Expand Up @@ -117,7 +117,7 @@ def close(self):
self._is_closed = True

def _do_execute_update(self, transaction, sql, params, param_types=None):
sql = parse_utils.ensure_where_clause(sql)
parse_utils.ensure_where_clause(sql)
sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params)

result = transaction.execute_update(
Expand Down
13 changes: 8 additions & 5 deletions google/cloud/spanner_dbapi/parse_utils.py
Expand Up @@ -515,15 +515,18 @@ def get_param_types(params):

def ensure_where_clause(sql):
"""
Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements.
Add a dummy WHERE clause if necessary.
Raise unless `sql` includes a WHERE clause.
:type sql: str
:param sql: SQL statement to check.
"""
if any(
if not any(
isinstance(token, sqlparse.sql.Where)
for token in sqlparse.parse(sql)[0]
):
return sql
return sql + " WHERE 1=1"
raise ProgrammingError(
"Cloud Spanner requires a WHERE clause in UPDATE and DELETE statements"
)


def escape_name(name):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/spanner_dbapi/test_cursor.py
Expand Up @@ -94,7 +94,7 @@ def test_do_execute_update(self):
def run_helper(ret_value):
transaction.execute_update.return_value = ret_value
res = cursor._do_execute_update(
transaction=transaction, sql="sql", params=None,
transaction=transaction, sql="SELECT * WHERE true", params={},
)
return res

Expand Down
42 changes: 16 additions & 26 deletions tests/unit/spanner_dbapi/test_parse_utils.py
Expand Up @@ -407,36 +407,26 @@ def test_get_param_types_none(self):
self.assertEqual(get_param_types(None), None)

def test_ensure_where_clause(self):
from google.cloud.spanner_dbapi.exceptions import ProgrammingError
from google.cloud.spanner_dbapi.parse_utils import ensure_where_clause

cases = [
(
"UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
"UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
),
(
"UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5",
"UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5 WHERE 1=1",
),
(
"UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
"UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
),
(
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
),
(
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
),
("DELETE * FROM TABLE", "DELETE * FROM TABLE WHERE 1=1"),
]
cases = (
"UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
"UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
)
err_cases = (
"UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5",
"DELETE * FROM TABLE",
)
for sql in cases:
with self.subTest(sql=sql):
ensure_where_clause(sql)

for sql, want in cases:
for sql in err_cases:
with self.subTest(sql=sql):
got = ensure_where_clause(sql)
self.assertEqual(got, want)
with self.assertRaises(ProgrammingError):
ensure_where_clause(sql)

def test_escape_name(self):
from google.cloud.spanner_dbapi.parse_utils import escape_name
Expand Down

0 comments on commit af5d8e3

Please sign in to comment.