From 020dc17c823dfb65bfaacace14d2c9f491c97e11 Mon Sep 17 00:00:00 2001 From: Ilya Gurov Date: Mon, 7 Dec 2020 09:59:35 +0300 Subject: [PATCH] fix(dbapi): executemany() hiding all the results except the last (#181) --- google/cloud/spanner_dbapi/cursor.py | 13 ++++++++ google/cloud/spanner_dbapi/utils.py | 40 ++++++++++++++++++++++++- tests/system/test_system_dbapi.py | 40 +++++++++++++++++++++++++ tests/unit/spanner_dbapi/test_cursor.py | 19 ++++++++++++ 4 files changed, 111 insertions(+), 1 deletion(-) diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index e2667f0599..363c2c653c 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -37,6 +37,7 @@ from google.cloud.spanner_dbapi.parse_utils import get_param_types from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner from google.cloud.spanner_dbapi.utils import PeekIterator +from google.cloud.spanner_dbapi.utils import StreamedManyResultSets _UNSET_COUNT = -1 @@ -210,8 +211,20 @@ def executemany(self, operation, seq_of_params): """ self._raise_if_closed() + classification = parse_utils.classify_stmt(operation) + if classification == parse_utils.STMT_DDL: + raise ProgrammingError( + "Executing DDL statements with executemany() method is not allowed." + ) + + many_result_set = StreamedManyResultSets() + for params in seq_of_params: self.execute(operation, params) + many_result_set.add_iter(self._itr) + + self._result_set = many_result_set + self._itr = many_result_set def fetchone(self): """Fetch the next row of a query result set, returning a single diff --git a/google/cloud/spanner_dbapi/utils.py b/google/cloud/spanner_dbapi/utils.py index b0ad3922a5..7cafaaa609 100644 --- a/google/cloud/spanner_dbapi/utils.py +++ b/google/cloud/spanner_dbapi/utils.py @@ -14,6 +14,8 @@ import re +re_UNICODE_POINTS = re.compile(r"([^\s]*[\u0080-\uFFFF]+[^\s]*)") + class PeekIterator: """ @@ -55,7 +57,43 @@ def __iter__(self): return self -re_UNICODE_POINTS = re.compile(r"([^\s]*[\u0080-\uFFFF]+[^\s]*)") +class StreamedManyResultSets: + """Iterator to walk through several `StreamedResultsSet` iterators. + This type of iterator is used by `Cursor.executemany()` + method to iterate through several `StreamedResultsSet` + iterators like they all are merged into single iterator. + """ + + def __init__(self): + self._iterators = [] + self._index = 0 + + def add_iter(self, iterator): + """Add new iterator into this one. + :type iterator: :class:`google.cloud.spanner_v1.streamed.StreamedResultSet` + :param iterator: Iterator to merge into this one. + """ + self._iterators.append(iterator) + + def __next__(self): + """Return the next value from the currently streamed iterator. + If the current iterator is streamed to the end, + start to stream the next one. + :rtype: list + :returns: The next result row. + """ + try: + res = next(self._iterators[self._index]) + except StopIteration: + self._index += 1 + res = self.__next__() + except IndexError: + raise StopIteration + + return res + + def __iter__(self): + return self def backtick_unicode(sql): diff --git a/tests/system/test_system_dbapi.py b/tests/system/test_system_dbapi.py index be8e9f2a26..5e331cad8f 100644 --- a/tests/system/test_system_dbapi.py +++ b/tests/system/test_system_dbapi.py @@ -305,6 +305,46 @@ def test_results_checksum(self): self.assertEqual(cursor._checksum.checksum.digest(), checksum.digest()) + def test_execute_many(self): + # connect to the test database + conn = Connection(Config.INSTANCE, self._db) + cursor = conn.cursor() + + cursor.execute( + """ +INSERT INTO contacts (contact_id, first_name, last_name, email) +VALUES (1, 'first-name', 'last-name', 'test.email@example.com'), + (2, 'first-name2', 'last-name2', 'test.email2@example.com') + """ + ) + conn.commit() + + cursor.executemany( + """ +SELECT * FROM contacts WHERE contact_id = @a1 +""", + ({"a1": 1}, {"a1": 2}), + ) + res = cursor.fetchall() + conn.commit() + + self.assertEqual(len(res), 2) + self.assertEqual(res[0][0], 1) + self.assertEqual(res[1][0], 2) + + # checking that execute() and executemany() + # results are not mixed together + cursor.execute( + """ +SELECT * FROM contacts WHERE contact_id = 1 +""", + ) + res = cursor.fetchone() + conn.commit() + + self.assertEqual(res[0], 1) + conn.close() + def clear_table(transaction): """Clear the test table.""" diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 43fc077abe..81b290c4f1 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -257,6 +257,22 @@ def test_executemany_on_closed_cursor(self): with self.assertRaises(InterfaceError): cursor.executemany("""SELECT * FROM table1 WHERE "col1" = @a1""", ()) + def test_executemany_DLL(self): + from google.cloud.spanner_dbapi import connect, ProgrammingError + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + + with self.assertRaises(ProgrammingError): + cursor.executemany("""DROP DATABASE database_name""", ()) + def test_executemany(self): from google.cloud.spanner_dbapi import connect @@ -272,6 +288,9 @@ def test_executemany(self): connection = connect("test-instance", "test-database") cursor = connection.cursor() + cursor._result_set = [1, 2, 3] + cursor._itr = iter([1, 2, 3]) + with mock.patch( "google.cloud.spanner_dbapi.cursor.Cursor.execute" ) as execute_mock: