From b8b24e17a74c1296ca5de75798a1a32597691b53 Mon Sep 17 00:00:00 2001 From: larkee <31196561+larkee@users.noreply.github.com> Date: Thu, 24 Jun 2021 09:22:54 +1000 Subject: [PATCH] fix: classify batched DDL statements (#360) * fix: classify batched DDL statements * docs: add comment * style: fix lint Co-authored-by: larkee --- google/cloud/spanner_dbapi/cursor.py | 7 ++++++- test.py | 11 +++++++++++ tests/unit/spanner_dbapi/test_cursor.py | 14 +++++++++++++- 3 files changed, 30 insertions(+), 2 deletions(-) create mode 100644 test.py diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 689ba8cf66..c5de13b370 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -176,11 +176,16 @@ def execute(self, sql, args=None): try: classification = parse_utils.classify_stmt(sql) if classification == parse_utils.STMT_DDL: + ddl_statements = [] for ddl in sqlparse.split(sql): if ddl: if ddl[-1] == ";": ddl = ddl[:-1] - self.connection._ddl_statements.append(ddl) + if parse_utils.classify_stmt(ddl) != parse_utils.STMT_DDL: + raise ValueError("Only DDL statements may be batched.") + ddl_statements.append(ddl) + # Only queue DDL statements if they are all correctly classified. + self.connection._ddl_statements.extend(ddl_statements) if self.connection.autocommit: self.connection.run_prior_DDL_statements() return diff --git a/test.py b/test.py new file mode 100644 index 0000000000..6032524b04 --- /dev/null +++ b/test.py @@ -0,0 +1,11 @@ +from google.cloud import spanner +from gooogle.cloud.spanner_v1 import RequestOptions + +client = spanner.Client() +instance = client.instance('test-instance') +database = instance.database('test-db') + +with database.snapshot() as snapshot: + results = snapshot.execute_sql("SELECT * in all_types LIMIT %s", ) + +database.drop() \ No newline at end of file diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 789ca06695..5b1cf12138 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -171,13 +171,25 @@ def test_execute_statement(self): connection = self._make_connection(self.INSTANCE, mock.MagicMock()) cursor = self._make_one(connection) + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + side_effect=[parse_utils.STMT_DDL, parse_utils.STMT_INSERT], + ) as mock_classify_stmt: + sql = "sql" + with self.assertRaises(ValueError): + cursor.execute(sql=sql) + mock_classify_stmt.assert_called_with(sql) + self.assertEqual(mock_classify_stmt.call_count, 2) + self.assertEqual(cursor.connection._ddl_statements, []) + with mock.patch( "google.cloud.spanner_dbapi.parse_utils.classify_stmt", return_value=parse_utils.STMT_DDL, ) as mock_classify_stmt: sql = "sql" cursor.execute(sql=sql) - mock_classify_stmt.assert_called_once_with(sql) + mock_classify_stmt.assert_called_with(sql) + self.assertEqual(mock_classify_stmt.call_count, 2) self.assertEqual(cursor.connection._ddl_statements, [sql]) with mock.patch(