Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: remove adding a dummy WHERE clause into UPDATE and DELETE statements #169

Merged
merged 2 commits into from Nov 20, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 7 additions & 1 deletion google/cloud/spanner_dbapi/parse_utils.py
Expand Up @@ -524,10 +524,16 @@ def ensure_where_clause(sql):
"""
Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements.
Add a dummy WHERE clause if necessary.
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved

:type sql: `str`
:param sql: SQL code to check.
"""
if any(isinstance(token, sqlparse.sql.Where) for token in sqlparse.parse(sql)[0]):
return sql
return sql + " WHERE 1=1"

raise ProgrammingError(
"Cloud Spanner requires a WHERE clause when executing DELETE or UPDATE query"
)


def escape_name(name):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/spanner_dbapi/test_cursor.py
Expand Up @@ -100,7 +100,7 @@ def test_do_execute_update(self):
def run_helper(ret_value):
transaction.execute_update.return_value = ret_value
res = cursor._do_execute_update(
transaction=transaction, sql="sql", params=None
transaction=transaction, sql="SELECT * WHERE true", params={},
)
return res

Expand Down
42 changes: 16 additions & 26 deletions tests/unit/spanner_dbapi/test_parse_utils.py
Expand Up @@ -391,36 +391,26 @@ def test_get_param_types_none(self):

@unittest.skipIf(skip_condition, skip_message)
def test_ensure_where_clause(self):
from google.cloud.spanner_dbapi.exceptions import ProgrammingError
from google.cloud.spanner_dbapi.parse_utils import ensure_where_clause

cases = [
(
"UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
"UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
),
(
"UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5",
"UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5 WHERE 1=1",
),
(
"UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
"UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
),
(
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
),
(
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
),
("DELETE * FROM TABLE", "DELETE * FROM TABLE WHERE 1=1"),
]
cases = (
"UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
"UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
)
err_cases = (
"UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5",
"DELETE * FROM TABLE",
)
for sql in cases:
with self.subTest(sql=sql):
ensure_where_clause(sql)

for sql, want in cases:
for sql in err_cases:
with self.subTest(sql=sql):
got = ensure_where_clause(sql)
self.assertEqual(got, want)
with self.assertRaises(ProgrammingError):
ensure_where_clause(sql)

@unittest.skipIf(skip_condition, skip_message)
def test_escape_name(self):
Expand Down