Skip to content

Commit

Permalink
feat: remove adding a dummy WHERE clause into UPDATE and DELETE state…
Browse files Browse the repository at this point in the history
…ments (#169)

* feat: don't add dummy WHERE clause into UPDATE and DELETE queries

* fix docstrings
  • Loading branch information
Ilya Gurov committed Nov 20, 2020
1 parent 8cfea48 commit 7f4d478
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 29 deletions.
12 changes: 10 additions & 2 deletions google/cloud/spanner_dbapi/parse_utils.py
Expand Up @@ -523,11 +523,19 @@ def get_param_types(params):
def ensure_where_clause(sql):
"""
Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements.
Add a dummy WHERE clause if necessary.
Raise an error, if the given sql doesn't include it.
:type sql: `str`
:param sql: SQL code to check.
:raises: :class:`ProgrammingError` if the given sql doesn't include a WHERE clause.
"""
if any(isinstance(token, sqlparse.sql.Where) for token in sqlparse.parse(sql)[0]):
return sql
return sql + " WHERE 1=1"

raise ProgrammingError(
"Cloud Spanner requires a WHERE clause when executing DELETE or UPDATE query"
)


def escape_name(name):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/spanner_dbapi/test_cursor.py
Expand Up @@ -100,7 +100,7 @@ def test_do_execute_update(self):
def run_helper(ret_value):
transaction.execute_update.return_value = ret_value
res = cursor._do_execute_update(
transaction=transaction, sql="sql", params=None
transaction=transaction, sql="SELECT * WHERE true", params={},
)
return res

Expand Down
42 changes: 16 additions & 26 deletions tests/unit/spanner_dbapi/test_parse_utils.py
Expand Up @@ -391,36 +391,26 @@ def test_get_param_types_none(self):

@unittest.skipIf(skip_condition, skip_message)
def test_ensure_where_clause(self):
from google.cloud.spanner_dbapi.exceptions import ProgrammingError
from google.cloud.spanner_dbapi.parse_utils import ensure_where_clause

cases = [
(
"UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
"UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
),
(
"UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5",
"UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5 WHERE 1=1",
),
(
"UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
"UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
),
(
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
),
(
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
),
("DELETE * FROM TABLE", "DELETE * FROM TABLE WHERE 1=1"),
]
cases = (
"UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
"UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
)
err_cases = (
"UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5",
"DELETE * FROM TABLE",
)
for sql in cases:
with self.subTest(sql=sql):
ensure_where_clause(sql)

for sql, want in cases:
for sql in err_cases:
with self.subTest(sql=sql):
got = ensure_where_clause(sql)
self.assertEqual(got, want)
with self.assertRaises(ProgrammingError):
ensure_where_clause(sql)

@unittest.skipIf(skip_condition, skip_message)
def test_escape_name(self):
Expand Down

0 comments on commit 7f4d478

Please sign in to comment.