diff --git a/django_spanner/compiler.py b/django_spanner/compiler.py index 202ef103dc..106686d445 100644 --- a/django_spanner/compiler.py +++ b/django_spanner/compiler.py @@ -12,7 +12,7 @@ SQLInsertCompiler as BaseSQLInsertCompiler, SQLUpdateCompiler as BaseSQLUpdateCompiler, ) -from django.db.utils import DatabaseError +from django.db.utils import DatabaseError, add_dummy_where class SQLCompiler(BaseSQLCompiler): @@ -90,6 +90,8 @@ def get_combinator_sql(self, combinator, all): params = [] for part in args_parts: params.extend(part) + + result = add_dummy_where(result) return result, params diff --git a/django_spanner/utils.py b/django_spanner/utils.py index 1136c33a87..444afe053d 100644 --- a/django_spanner/utils.py +++ b/django_spanner/utils.py @@ -1,4 +1,11 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + import django +import sqlparse from django.core.exceptions import ImproperlyConfigured from django.utils.version import get_version_tuple @@ -18,3 +25,17 @@ def check_django_compatability(): A=django.VERSION[0], B=django.VERSION[1], C=__version__ ) ) + + +def add_dummy_where(sql): + """ + Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements. + Add a dummy WHERE clause if necessary. + """ + if any( + isinstance(token, sqlparse.sql.Where) + for token in sqlparse.parse(sql)[0] + ): + return sql + + return sql + " WHERE 1=1" diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 6997752a42..e41f0f381a 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -117,7 +117,7 @@ def close(self): self._is_closed = True def _do_execute_update(self, transaction, sql, params, param_types=None): - sql = parse_utils.ensure_where_clause(sql) + parse_utils.ensure_where_clause(sql) sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params) result = transaction.execute_update( diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py index 084eea315e..0e69dbc0ca 100644 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -515,15 +515,18 @@ def get_param_types(params): def ensure_where_clause(sql): """ - Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements. - Add a dummy WHERE clause if necessary. + Raise unless `sql` includes a WHERE clause. + + :type sql: str + :param sql: SQL statement to check. """ - if any( + if not any( isinstance(token, sqlparse.sql.Where) for token in sqlparse.parse(sql)[0] ): - return sql - return sql + " WHERE 1=1" + raise ProgrammingError( + "Cloud Spanner requires a WHERE clause in UPDATE and DELETE statements" + ) def escape_name(name): diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 09288df94e..a73265e932 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -94,7 +94,7 @@ def test_do_execute_update(self): def run_helper(ret_value): transaction.execute_update.return_value = ret_value res = cursor._do_execute_update( - transaction=transaction, sql="sql", params=None, + transaction=transaction, sql="SELECT * WHERE true", params={}, ) return res diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py index 1bd38c85eb..d68e4118fd 100644 --- a/tests/unit/spanner_dbapi/test_parse_utils.py +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -407,36 +407,26 @@ def test_get_param_types_none(self): self.assertEqual(get_param_types(None), None) def test_ensure_where_clause(self): + from google.cloud.spanner_dbapi.exceptions import ProgrammingError from google.cloud.spanner_dbapi.parse_utils import ensure_where_clause - cases = [ - ( - "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1", - "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1", - ), - ( - "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5", - "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5 WHERE 1=1", - ), - ( - "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2", - "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2", - ), - ( - "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", - "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", - ), - ( - "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", - "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", - ), - ("DELETE * FROM TABLE", "DELETE * FROM TABLE WHERE 1=1"), - ] + cases = ( + "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1", + "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2", + "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", + ) + err_cases = ( + "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5", + "DELETE * FROM TABLE", + ) + for sql in cases: + with self.subTest(sql=sql): + ensure_where_clause(sql) - for sql, want in cases: + for sql in err_cases: with self.subTest(sql=sql): - got = ensure_where_clause(sql) - self.assertEqual(got, want) + with self.assertRaises(ProgrammingError): + ensure_where_clause(sql) def test_escape_name(self): from google.cloud.spanner_dbapi.parse_utils import escape_name