Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: avoid adding dummy WHERE clause into UPDATE and DELETE queires #516

Merged
merged 12 commits into from Nov 18, 2020
3 changes: 3 additions & 0 deletions google/cloud/spanner_dbapi/__init__.py
Expand Up @@ -71,6 +71,9 @@ def connect(
If none are specified, the client will attempt to ascertain
the credentials from the environment.

:type user_agent: :class:`str`
:param user_agent: (Optional) The user agent to be used with API requests.

:rtype: :class:`google.cloud.spanner_dbapi.connection.Connection`
:returns: Connection object associated with the given Cloud Spanner resource.

Expand Down
204 changes: 105 additions & 99 deletions google/cloud/spanner_dbapi/parse_utils.py
Expand Up @@ -8,6 +8,7 @@

import datetime
import decimal
import os
import re
from functools import reduce

Expand Down Expand Up @@ -53,6 +54,105 @@

RE_PYFORMAT = re.compile(r"(%s|%\([^\(\)]+\)s)+", re.DOTALL)

SPANNER_RESERVED_KEYWORDS = {
"ALL",
"AND",
"ANY",
"ARRAY",
"AS",
"ASC",
"ASSERT_ROWS_MODIFIED",
"AT",
"BETWEEN",
"BY",
"CASE",
"CAST",
"COLLATE",
"CONTAINS",
"CREATE",
"CROSS",
"CUBE",
"CURRENT",
"DEFAULT",
"DEFINE",
"DESC",
"DISTINCT",
"DROP",
"ELSE",
"END",
"ENUM",
"ESCAPE",
"EXCEPT",
"EXCLUDE",
"EXISTS",
"EXTRACT",
"FALSE",
"FETCH",
"FOLLOWING",
"FOR",
"FROM",
"FULL",
"GROUP",
"GROUPING",
"GROUPS",
"HASH",
"HAVING",
"IF",
"IGNORE",
"IN",
"INNER",
"INTERSECT",
"INTERVAL",
"INTO",
"IS",
"JOIN",
"LATERAL",
"LEFT",
"LIKE",
"LIMIT",
"LOOKUP",
"MERGE",
"NATURAL",
"NEW",
"NO",
"NOT",
"NULL",
"NULLS",
"OF",
"ON",
"OR",
"ORDER",
"OUTER",
"OVER",
"PARTITION",
"PRECEDING",
"PROTO",
"RANGE",
"RECURSIVE",
"RESPECT",
"RIGHT",
"ROLLUP",
"ROWS",
"SELECT",
"SET",
"SOME",
"STRUCT",
"TABLESAMPLE",
"THEN",
"TO",
"TREAT",
"TRUE",
"UNBOUNDED",
"UNION",
"UNNEST",
"USING",
"WHEN",
"WHERE",
"WINDOW",
"WITH",
"WITHIN",
}

IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved

def classify_stmt(query):
"""Determine SQL query type.
Expand Down Expand Up @@ -402,107 +502,13 @@ def ensure_where_clause(sql):
for token in sqlparse.parse(sql)[0]
):
return sql
return sql + " WHERE 1=1"

if os.environ.get("RUNNING_SPANNER_BACKEND_TESTS") == "1":
return sql + " WHERE 1=1"

SPANNER_RESERVED_KEYWORDS = {
"ALL",
"AND",
"ANY",
"ARRAY",
"AS",
"ASC",
"ASSERT_ROWS_MODIFIED",
"AT",
"BETWEEN",
"BY",
"CASE",
"CAST",
"COLLATE",
"CONTAINS",
"CREATE",
"CROSS",
"CUBE",
"CURRENT",
"DEFAULT",
"DEFINE",
"DESC",
"DISTINCT",
"DROP",
"ELSE",
"END",
"ENUM",
"ESCAPE",
"EXCEPT",
"EXCLUDE",
"EXISTS",
"EXTRACT",
"FALSE",
"FETCH",
"FOLLOWING",
"FOR",
"FROM",
"FULL",
"GROUP",
"GROUPING",
"GROUPS",
"HASH",
"HAVING",
"IF",
"IGNORE",
"IN",
"INNER",
"INTERSECT",
"INTERVAL",
"INTO",
"IS",
"JOIN",
"LATERAL",
"LEFT",
"LIKE",
"LIMIT",
"LOOKUP",
"MERGE",
"NATURAL",
"NEW",
"NO",
"NOT",
"NULL",
"NULLS",
"OF",
"ON",
"OR",
"ORDER",
"OUTER",
"OVER",
"PARTITION",
"PRECEDING",
"PROTO",
"RANGE",
"RECURSIVE",
"RESPECT",
"RIGHT",
"ROLLUP",
"ROWS",
"SELECT",
"SET",
"SOME",
"STRUCT",
"TABLESAMPLE",
"THEN",
"TO",
"TREAT",
"TRUE",
"UNBOUNDED",
"UNION",
"UNNEST",
"USING",
"WHEN",
"WHERE",
"WINDOW",
"WITH",
"WITHIN",
}
raise ProgrammingError(
"Cloud Spanner requires a WHERE clause when executing DELETE or UPDATE query"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies if I misunderstood something here.

I believe this clause is added by default because Django will construct DELETE and UPDATE queries that don't have the WHERE clause as this is not a requirement in other databases. It ensure that customers can smoothly transition from their existing applications without having to write raw queries with the WHERE clause.

The clause wasn't added just to get the tests to pass.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@skuruppu, hiding this useful restriction doesn't look right. But I see the point in smooth transition as well. Do you mean that only for Django? Or we should consider that any user can have queries without WHERE, so this dummy should be kept for the whole connection API? (we have spanner_dbapi directory, which seems to be intended to be default connection API - it's from where I'm proposing to erase the dummy, and django_spanner directory, where Django-specific parts are held).

The main thing that makes me nervous is that queries can be big and branched (for example, UNIONS are little bit hard to interpret, as WHERE clause can be relative to the last of the union parts, but also can be relative to the result of the UNION on the whole, and they both will be at the about same place). I'm not sure if there is a way to check 100% of variants with several regexes, and insert this dummy correctly. So, I think it'll be exploding things from time to time.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that it may be good to notify the user about this restriction from the connection API. Is it possible to instead modify the Django implementation to add the WHERE clause in the generated query? I think this is where it matters the most in order to support users migrating from other databases.



def escape_name(name):
Expand Down
41 changes: 15 additions & 26 deletions tests/spanner_dbapi/test_parse_utils.py
Expand Up @@ -445,34 +445,23 @@ def test_get_param_types(self):
self.assertEqual(got_param_types, want_param_types)

def test_ensure_where_clause(self):
cases = [
(
"UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
"UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
),
(
"UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5",
"UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5 WHERE 1=1",
),
(
"UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
"UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
),
(
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
),
(
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
),
("DELETE * FROM TABLE", "DELETE * FROM TABLE WHERE 1=1"),
]
cases = (
"UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1",
"UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2",
"UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)",
)
err_cases = (
"UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5",
"DELETE * FROM TABLE",
)
for sql in cases:
with self.subTest(sql=sql):
ensure_where_clause(sql)

for sql, want in cases:
for sql in err_cases:
with self.subTest(sql=sql):
got = ensure_where_clause(sql)
self.assertEqual(got, want)
with self.assertRaises(ProgrammingError):
ensure_where_clause(sql)

def test_escape_name(self):
cases = [
Expand Down