diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 689ba8cf66..c5de13b370 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -176,11 +176,16 @@ def execute(self, sql, args=None): try: classification = parse_utils.classify_stmt(sql) if classification == parse_utils.STMT_DDL: + ddl_statements = [] for ddl in sqlparse.split(sql): if ddl: if ddl[-1] == ";": ddl = ddl[:-1] - self.connection._ddl_statements.append(ddl) + if parse_utils.classify_stmt(ddl) != parse_utils.STMT_DDL: + raise ValueError("Only DDL statements may be batched.") + ddl_statements.append(ddl) + # Only queue DDL statements if they are all correctly classified. + self.connection._ddl_statements.extend(ddl_statements) if self.connection.autocommit: self.connection.run_prior_DDL_statements() return diff --git a/test.py b/test.py new file mode 100644 index 0000000000..6032524b04 --- /dev/null +++ b/test.py @@ -0,0 +1,11 @@ +from google.cloud import spanner +from gooogle.cloud.spanner_v1 import RequestOptions + +client = spanner.Client() +instance = client.instance('test-instance') +database = instance.database('test-db') + +with database.snapshot() as snapshot: + results = snapshot.execute_sql("SELECT * in all_types LIMIT %s", ) + +database.drop() \ No newline at end of file diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 789ca06695..5b1cf12138 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -171,13 +171,25 @@ def test_execute_statement(self): connection = self._make_connection(self.INSTANCE, mock.MagicMock()) cursor = self._make_one(connection) + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + side_effect=[parse_utils.STMT_DDL, parse_utils.STMT_INSERT], + ) as mock_classify_stmt: + sql = "sql" + with self.assertRaises(ValueError): + cursor.execute(sql=sql) + mock_classify_stmt.assert_called_with(sql) + self.assertEqual(mock_classify_stmt.call_count, 2) + self.assertEqual(cursor.connection._ddl_statements, []) + with mock.patch( "google.cloud.spanner_dbapi.parse_utils.classify_stmt", return_value=parse_utils.STMT_DDL, ) as mock_classify_stmt: sql = "sql" cursor.execute(sql=sql) - mock_classify_stmt.assert_called_once_with(sql) + mock_classify_stmt.assert_called_with(sql) + self.assertEqual(mock_classify_stmt.call_count, 2) self.assertEqual(cursor.connection._ddl_statements, [sql]) with mock.patch(