Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(dbapi): executemany() hiding all the results except the last #181

Merged
merged 2 commits into from Dec 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 13 additions & 0 deletions google/cloud/spanner_dbapi/cursor.py
Expand Up @@ -37,6 +37,7 @@
from google.cloud.spanner_dbapi.parse_utils import get_param_types
from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner
from google.cloud.spanner_dbapi.utils import PeekIterator
from google.cloud.spanner_dbapi.utils import StreamedManyResultSets

_UNSET_COUNT = -1

Expand Down Expand Up @@ -210,8 +211,20 @@ def executemany(self, operation, seq_of_params):
"""
self._raise_if_closed()

classification = parse_utils.classify_stmt(operation)
if classification == parse_utils.STMT_DDL:
raise ProgrammingError(
"Executing DDL statements with executemany() method is not allowed."
)

many_result_set = StreamedManyResultSets()

for params in seq_of_params:
self.execute(operation, params)
many_result_set.add_iter(self._itr)

self._result_set = many_result_set
self._itr = many_result_set

def fetchone(self):
"""Fetch the next row of a query result set, returning a single
Expand Down
40 changes: 39 additions & 1 deletion google/cloud/spanner_dbapi/utils.py
Expand Up @@ -14,6 +14,8 @@

import re

re_UNICODE_POINTS = re.compile(r"([^\s]*[\u0080-\uFFFF]+[^\s]*)")


class PeekIterator:
"""
Expand Down Expand Up @@ -55,7 +57,43 @@ def __iter__(self):
return self


re_UNICODE_POINTS = re.compile(r"([^\s]*[\u0080-\uFFFF]+[^\s]*)")
class StreamedManyResultSets:
"""Iterator to walk through several `StreamedResultsSet` iterators.
This type of iterator is used by `Cursor.executemany()`
method to iterate through several `StreamedResultsSet`
iterators like they all are merged into single iterator.
"""

def __init__(self):
self._iterators = []
self._index = 0

def add_iter(self, iterator):
"""Add new iterator into this one.
:type iterator: :class:`google.cloud.spanner_v1.streamed.StreamedResultSet`
:param iterator: Iterator to merge into this one.
"""
self._iterators.append(iterator)

def __next__(self):
"""Return the next value from the currently streamed iterator.
If the current iterator is streamed to the end,
start to stream the next one.
:rtype: list
:returns: The next result row.
"""
try:
res = next(self._iterators[self._index])
except StopIteration:
self._index += 1
res = self.__next__()
except IndexError:
raise StopIteration

return res

def __iter__(self):
return self


def backtick_unicode(sql):
Expand Down
40 changes: 40 additions & 0 deletions tests/system/test_system_dbapi.py
Expand Up @@ -305,6 +305,46 @@ def test_results_checksum(self):

self.assertEqual(cursor._checksum.checksum.digest(), checksum.digest())

def test_execute_many(self):
# connect to the test database
conn = Connection(Config.INSTANCE, self._db)
cursor = conn.cursor()

cursor.execute(
"""
INSERT INTO contacts (contact_id, first_name, last_name, email)
VALUES (1, 'first-name', 'last-name', 'test.email@example.com'),
(2, 'first-name2', 'last-name2', 'test.email2@example.com')
"""
)
conn.commit()

cursor.executemany(
"""
SELECT * FROM contacts WHERE contact_id = @a1
""",
({"a1": 1}, {"a1": 2}),
)
res = cursor.fetchall()
conn.commit()

self.assertEqual(len(res), 2)
self.assertEqual(res[0][0], 1)
self.assertEqual(res[1][0], 2)

# checking that execute() and executemany()
# results are not mixed together
cursor.execute(
"""
SELECT * FROM contacts WHERE contact_id = 1
""",
)
res = cursor.fetchone()
conn.commit()

self.assertEqual(res[0], 1)
conn.close()


def clear_table(transaction):
"""Clear the test table."""
Expand Down
19 changes: 19 additions & 0 deletions tests/unit/spanner_dbapi/test_cursor.py
Expand Up @@ -257,6 +257,22 @@ def test_executemany_on_closed_cursor(self):
with self.assertRaises(InterfaceError):
cursor.executemany("""SELECT * FROM table1 WHERE "col1" = @a1""", ())

def test_executemany_DLL(self):
from google.cloud.spanner_dbapi import connect, ProgrammingError

with mock.patch(
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True,
):
with mock.patch(
"google.cloud.spanner_v1.database.Database.exists", return_value=True,
):
connection = connect("test-instance", "test-database")

cursor = connection.cursor()

with self.assertRaises(ProgrammingError):
cursor.executemany("""DROP DATABASE database_name""", ())

def test_executemany(self):
from google.cloud.spanner_dbapi import connect

Expand All @@ -272,6 +288,9 @@ def test_executemany(self):
connection = connect("test-instance", "test-database")

cursor = connection.cursor()
cursor._result_set = [1, 2, 3]
cursor._itr = iter([1, 2, 3])

with mock.patch(
"google.cloud.spanner_dbapi.cursor.Cursor.execute"
) as execute_mock:
Expand Down