diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py index d88dcafb0d..8848233d45 100644 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -523,11 +523,19 @@ def get_param_types(params): def ensure_where_clause(sql): """ Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements. - Add a dummy WHERE clause if necessary. + Raise an error, if the given sql doesn't include it. + + :type sql: `str` + :param sql: SQL code to check. + + :raises: :class:`ProgrammingError` if the given sql doesn't include a WHERE clause. """ if any(isinstance(token, sqlparse.sql.Where) for token in sqlparse.parse(sql)[0]): return sql - return sql + " WHERE 1=1" + + raise ProgrammingError( + "Cloud Spanner requires a WHERE clause when executing DELETE or UPDATE query" + ) def escape_name(name): diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 23ed5010d1..871214a360 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -100,7 +100,7 @@ def test_do_execute_update(self): def run_helper(ret_value): transaction.execute_update.return_value = ret_value res = cursor._do_execute_update( - transaction=transaction, sql="sql", params=None + transaction=transaction, sql="SELECT * WHERE true", params={}, ) return res diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py index a79ad8dc51..6d89a8a46a 100644 --- a/tests/unit/spanner_dbapi/test_parse_utils.py +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -391,36 +391,26 @@ def test_get_param_types_none(self): @unittest.skipIf(skip_condition, skip_message) def test_ensure_where_clause(self): + from google.cloud.spanner_dbapi.exceptions import ProgrammingError from google.cloud.spanner_dbapi.parse_utils import ensure_where_clause - cases = [ - ( - "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1", - "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1", - ), - ( - "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5", - "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5 WHERE 1=1", - ), - ( - "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2", - "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2", - ), - ( - "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", - "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", - ), - ( - "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", - "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", - ), - ("DELETE * FROM TABLE", "DELETE * FROM TABLE WHERE 1=1"), - ] + cases = ( + "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1", + "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2", + "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", + ) + err_cases = ( + "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5", + "DELETE * FROM TABLE", + ) + for sql in cases: + with self.subTest(sql=sql): + ensure_where_clause(sql) - for sql, want in cases: + for sql in err_cases: with self.subTest(sql=sql): - got = ensure_where_clause(sql) - self.assertEqual(got, want) + with self.assertRaises(ProgrammingError): + ensure_where_clause(sql) @unittest.skipIf(skip_condition, skip_message) def test_escape_name(self):