Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add support for row_count in cursor. #675

Merged
merged 6 commits into from Feb 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 19 additions & 6 deletions google/cloud/spanner_dbapi/cursor.py
Expand Up @@ -44,6 +44,8 @@

from google.rpc.code_pb2 import ABORTED, OK

_UNSET_COUNT = -1

ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"])
Statement = namedtuple("Statement", "sql, params, param_types, checksum, is_insert")

Expand Down Expand Up @@ -80,6 +82,7 @@ class Cursor(object):
def __init__(self, connection):
self._itr = None
self._result_set = None
self._row_count = _UNSET_COUNT
self.lastrowid = None
self.connection = connection
self._is_closed = False
Expand Down Expand Up @@ -134,13 +137,14 @@ def description(self):

@property
def rowcount(self):
"""The number of rows produced by the last `execute()` call.
"""The number of rows updated by the last UPDATE, DELETE request's `execute()` call.
For SELECT requests the rowcount returns -1.

The property is non-operational and always returns -1. Request
resulting rows are streamed by the `fetch*()` methods and
can't be counted before they are all streamed.
:rtype: int
:returns: The number of rows updated by the last UPDATE, DELETE request's .execute*() call.
"""
return -1

return self._row_count

@check_not_closed
def callproc(self, procname, args=None):
Expand Down Expand Up @@ -170,7 +174,11 @@ def _do_execute_update(self, transaction, sql, params):
result = transaction.execute_update(
sql, params=params, param_types=get_param_types(params)
)
self._itr = iter([result])
self._itr = None
if type(result) == int:
self._row_count = result

return result

def _do_batch_update(self, transaction, statements, many_result_set):
status, res = transaction.batch_update(statements)
vi3k6i5 marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -181,6 +189,8 @@ def _do_batch_update(self, transaction, statements, many_result_set):
elif status.code != OK:
raise OperationalError(status.message)

self._row_count = sum([max(val, 0) for val in res])

def _batch_DDLs(self, sql):
"""
Check that the given operation contains only DDL
Expand Down Expand Up @@ -414,6 +424,9 @@ def _handle_DQL_with_snapshot(self, snapshot, sql, params):
# Read the first element so that the StreamedResultSet can
# return the metadata after a DQL statement.
self._itr = PeekIterator(self._result_set)
# Unfortunately, Spanner doesn't seem to send back
# information about the number of rows available.
self._row_count = _UNSET_COUNT

def _handle_DQL(self, sql, params):
sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params)
Expand Down
57 changes: 48 additions & 9 deletions tests/unit/spanner_dbapi/test_cursor.py
Expand Up @@ -37,11 +37,13 @@ def _make_connection(self, *args, **kwargs):

return Connection(*args, **kwargs)

def _transaction_mock(self):
def _transaction_mock(self, mock_response=[]):
from google.rpc.code_pb2 import OK

transaction = mock.Mock(committed=False, rolled_back=False)
transaction.batch_update = mock.Mock(return_value=[mock.Mock(code=OK), []])
transaction.batch_update = mock.Mock(
return_value=[mock.Mock(code=OK), mock_response]
)
return transaction

def test_property_connection(self):
Expand All @@ -62,10 +64,12 @@ def test_property_description(self):
self.assertIsInstance(cursor.description[0], ColumnInfo)

def test_property_rowcount(self):
from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT

connection = self._make_connection(self.INSTANCE, self.DATABASE)
cursor = self._make_one(connection)

assert cursor.rowcount == -1
self.assertEqual(cursor.rowcount, _UNSET_COUNT)

def test_callproc(self):
from google.cloud.spanner_dbapi.exceptions import InterfaceError
Expand Down Expand Up @@ -93,25 +97,58 @@ def test_close(self, mock_client):
cursor.execute("SELECT * FROM database")

def test_do_execute_update(self):
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT

connection = self._make_connection(self.INSTANCE, self.DATABASE)
cursor = self._make_one(connection)
cursor._checksum = ResultsChecksum()
transaction = mock.MagicMock()

def run_helper(ret_value):
transaction.execute_update.return_value = ret_value
cursor._do_execute_update(
res = cursor._do_execute_update(
transaction=transaction, sql="SELECT * WHERE true", params={},
)
return cursor.fetchall()
return res

expected = "good"
self.assertEqual(run_helper(expected), [expected])
self.assertEqual(run_helper(expected), expected)
self.assertEqual(cursor._row_count, _UNSET_COUNT)

expected = 1234
self.assertEqual(run_helper(expected), [expected])
self.assertEqual(run_helper(expected), expected)
self.assertEqual(cursor._row_count, expected)

def test_do_batch_update(self):
from google.cloud.spanner_dbapi import connect
from google.cloud.spanner_v1.param_types import INT64
from google.cloud.spanner_v1.types.spanner import Session

sql = "DELETE FROM table WHERE col1 = %s"

connection = connect("test-instance", "test-database")

connection.autocommit = True
transaction = self._transaction_mock(mock_response=[1, 1, 1])
cursor = connection.cursor()

with mock.patch(
"google.cloud.spanner_v1.services.spanner.client.SpannerClient.create_session",
return_value=Session(),
):
with mock.patch(
"google.cloud.spanner_v1.session.Session.transaction",
return_value=transaction,
):
cursor.executemany(sql, [(1,), (2,), (3,)])

transaction.batch_update.assert_called_once_with(
[
("DELETE FROM table WHERE col1 = @a0", {"a0": 1}, {"a0": INT64}),
("DELETE FROM table WHERE col1 = @a0", {"a0": 2}, {"a0": INT64}),
("DELETE FROM table WHERE col1 = @a0", {"a0": 3}, {"a0": INT64}),
]
)
self.assertEqual(cursor._row_count, 3)

def test_execute_programming_error(self):
from google.cloud.spanner_dbapi.exceptions import ProgrammingError
Expand Down Expand Up @@ -704,6 +741,7 @@ def test_setoutputsize(self):

def test_handle_dql(self):
from google.cloud.spanner_dbapi import utils
from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT

connection = self._make_connection(self.INSTANCE, mock.MagicMock())
connection.database.snapshot.return_value.__enter__.return_value = (
Expand All @@ -715,6 +753,7 @@ def test_handle_dql(self):
cursor._handle_DQL("sql", params=None)
self.assertEqual(cursor._result_set, ["0"])
self.assertIsInstance(cursor._itr, utils.PeekIterator)
self.assertEqual(cursor._row_count, _UNSET_COUNT)

def test_context(self):
connection = self._make_connection(self.INSTANCE, self.DATABASE)
Expand Down