Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
fix(dbapi): executemany() hiding all the results except the last (#181)
  • Loading branch information
Ilya Gurov committed Dec 7, 2020
1 parent cbe6ec1 commit 020dc17
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 1 deletion.
13 changes: 13 additions & 0 deletions google/cloud/spanner_dbapi/cursor.py
Expand Up @@ -37,6 +37,7 @@
from google.cloud.spanner_dbapi.parse_utils import get_param_types
from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner
from google.cloud.spanner_dbapi.utils import PeekIterator
from google.cloud.spanner_dbapi.utils import StreamedManyResultSets

_UNSET_COUNT = -1

Expand Down Expand Up @@ -210,8 +211,20 @@ def executemany(self, operation, seq_of_params):
"""
self._raise_if_closed()

classification = parse_utils.classify_stmt(operation)
if classification == parse_utils.STMT_DDL:
raise ProgrammingError(
"Executing DDL statements with executemany() method is not allowed."
)

many_result_set = StreamedManyResultSets()

for params in seq_of_params:
self.execute(operation, params)
many_result_set.add_iter(self._itr)

self._result_set = many_result_set
self._itr = many_result_set

def fetchone(self):
"""Fetch the next row of a query result set, returning a single
Expand Down
40 changes: 39 additions & 1 deletion google/cloud/spanner_dbapi/utils.py
Expand Up @@ -14,6 +14,8 @@

import re

re_UNICODE_POINTS = re.compile(r"([^\s]*[\u0080-\uFFFF]+[^\s]*)")


class PeekIterator:
"""
Expand Down Expand Up @@ -55,7 +57,43 @@ def __iter__(self):
return self


re_UNICODE_POINTS = re.compile(r"([^\s]*[\u0080-\uFFFF]+[^\s]*)")
class StreamedManyResultSets:
"""Iterator to walk through several `StreamedResultsSet` iterators.
This type of iterator is used by `Cursor.executemany()`
method to iterate through several `StreamedResultsSet`
iterators like they all are merged into single iterator.
"""

def __init__(self):
self._iterators = []
self._index = 0

def add_iter(self, iterator):
"""Add new iterator into this one.
:type iterator: :class:`google.cloud.spanner_v1.streamed.StreamedResultSet`
:param iterator: Iterator to merge into this one.
"""
self._iterators.append(iterator)

def __next__(self):
"""Return the next value from the currently streamed iterator.
If the current iterator is streamed to the end,
start to stream the next one.
:rtype: list
:returns: The next result row.
"""
try:
res = next(self._iterators[self._index])
except StopIteration:
self._index += 1
res = self.__next__()
except IndexError:
raise StopIteration

return res

def __iter__(self):
return self


def backtick_unicode(sql):
Expand Down
40 changes: 40 additions & 0 deletions tests/system/test_system_dbapi.py
Expand Up @@ -305,6 +305,46 @@ def test_results_checksum(self):

self.assertEqual(cursor._checksum.checksum.digest(), checksum.digest())

def test_execute_many(self):
# connect to the test database
conn = Connection(Config.INSTANCE, self._db)
cursor = conn.cursor()

cursor.execute(
"""
INSERT INTO contacts (contact_id, first_name, last_name, email)
VALUES (1, 'first-name', 'last-name', 'test.email@example.com'),
(2, 'first-name2', 'last-name2', 'test.email2@example.com')
"""
)
conn.commit()

cursor.executemany(
"""
SELECT * FROM contacts WHERE contact_id = @a1
""",
({"a1": 1}, {"a1": 2}),
)
res = cursor.fetchall()
conn.commit()

self.assertEqual(len(res), 2)
self.assertEqual(res[0][0], 1)
self.assertEqual(res[1][0], 2)

# checking that execute() and executemany()
# results are not mixed together
cursor.execute(
"""
SELECT * FROM contacts WHERE contact_id = 1
""",
)
res = cursor.fetchone()
conn.commit()

self.assertEqual(res[0], 1)
conn.close()


def clear_table(transaction):
"""Clear the test table."""
Expand Down
19 changes: 19 additions & 0 deletions tests/unit/spanner_dbapi/test_cursor.py
Expand Up @@ -257,6 +257,22 @@ def test_executemany_on_closed_cursor(self):
with self.assertRaises(InterfaceError):
cursor.executemany("""SELECT * FROM table1 WHERE "col1" = @a1""", ())

def test_executemany_DLL(self):
from google.cloud.spanner_dbapi import connect, ProgrammingError

with mock.patch(
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True,
):
with mock.patch(
"google.cloud.spanner_v1.database.Database.exists", return_value=True,
):
connection = connect("test-instance", "test-database")

cursor = connection.cursor()

with self.assertRaises(ProgrammingError):
cursor.executemany("""DROP DATABASE database_name""", ())

def test_executemany(self):
from google.cloud.spanner_dbapi import connect

Expand All @@ -272,6 +288,9 @@ def test_executemany(self):
connection = connect("test-instance", "test-database")

cursor = connection.cursor()
cursor._result_set = [1, 2, 3]
cursor._itr = iter([1, 2, 3])

with mock.patch(
"google.cloud.spanner_dbapi.cursor.Cursor.execute"
) as execute_mock:
Expand Down

0 comments on commit 020dc17

Please sign in to comment.