Skip to content

Commit

Permalink
fix: add support for row_count in cursor. (#675)
Browse files Browse the repository at this point in the history
* fix: add support for row_count

* docs: update rowcount property doc

* fix: updated tests for cursor to check row_count

* refactor: lint fixes

* test: add test for do_batch_update

* refactor: Empty commit
  • Loading branch information
vi3k6i5 committed Feb 4, 2022
1 parent 39ff137 commit d431339
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 15 deletions.
25 changes: 19 additions & 6 deletions google/cloud/spanner_dbapi/cursor.py
Expand Up @@ -44,6 +44,8 @@

from google.rpc.code_pb2 import ABORTED, OK

_UNSET_COUNT = -1

ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"])
Statement = namedtuple("Statement", "sql, params, param_types, checksum, is_insert")

Expand Down Expand Up @@ -80,6 +82,7 @@ class Cursor(object):
def __init__(self, connection):
self._itr = None
self._result_set = None
self._row_count = _UNSET_COUNT
self.lastrowid = None
self.connection = connection
self._is_closed = False
Expand Down Expand Up @@ -134,13 +137,14 @@ def description(self):

@property
def rowcount(self):
"""The number of rows produced by the last `execute()` call.
"""The number of rows updated by the last UPDATE, DELETE request's `execute()` call.
For SELECT requests the rowcount returns -1.
The property is non-operational and always returns -1. Request
resulting rows are streamed by the `fetch*()` methods and
can't be counted before they are all streamed.
:rtype: int
:returns: The number of rows updated by the last UPDATE, DELETE request's .execute*() call.
"""
return -1

return self._row_count

@check_not_closed
def callproc(self, procname, args=None):
Expand Down Expand Up @@ -170,7 +174,11 @@ def _do_execute_update(self, transaction, sql, params):
result = transaction.execute_update(
sql, params=params, param_types=get_param_types(params)
)
self._itr = iter([result])
self._itr = None
if type(result) == int:
self._row_count = result

return result

def _do_batch_update(self, transaction, statements, many_result_set):
status, res = transaction.batch_update(statements)
Expand All @@ -181,6 +189,8 @@ def _do_batch_update(self, transaction, statements, many_result_set):
elif status.code != OK:
raise OperationalError(status.message)

self._row_count = sum([max(val, 0) for val in res])

def _batch_DDLs(self, sql):
"""
Check that the given operation contains only DDL
Expand Down Expand Up @@ -414,6 +424,9 @@ def _handle_DQL_with_snapshot(self, snapshot, sql, params):
# Read the first element so that the StreamedResultSet can
# return the metadata after a DQL statement.
self._itr = PeekIterator(self._result_set)
# Unfortunately, Spanner doesn't seem to send back
# information about the number of rows available.
self._row_count = _UNSET_COUNT

def _handle_DQL(self, sql, params):
sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params)
Expand Down
57 changes: 48 additions & 9 deletions tests/unit/spanner_dbapi/test_cursor.py
Expand Up @@ -37,11 +37,13 @@ def _make_connection(self, *args, **kwargs):

return Connection(*args, **kwargs)

def _transaction_mock(self):
def _transaction_mock(self, mock_response=[]):
from google.rpc.code_pb2 import OK

transaction = mock.Mock(committed=False, rolled_back=False)
transaction.batch_update = mock.Mock(return_value=[mock.Mock(code=OK), []])
transaction.batch_update = mock.Mock(
return_value=[mock.Mock(code=OK), mock_response]
)
return transaction

def test_property_connection(self):
Expand All @@ -62,10 +64,12 @@ def test_property_description(self):
self.assertIsInstance(cursor.description[0], ColumnInfo)

def test_property_rowcount(self):
from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT

connection = self._make_connection(self.INSTANCE, self.DATABASE)
cursor = self._make_one(connection)

assert cursor.rowcount == -1
self.assertEqual(cursor.rowcount, _UNSET_COUNT)

def test_callproc(self):
from google.cloud.spanner_dbapi.exceptions import InterfaceError
Expand Down Expand Up @@ -93,25 +97,58 @@ def test_close(self, mock_client):
cursor.execute("SELECT * FROM database")

def test_do_execute_update(self):
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT

connection = self._make_connection(self.INSTANCE, self.DATABASE)
cursor = self._make_one(connection)
cursor._checksum = ResultsChecksum()
transaction = mock.MagicMock()

def run_helper(ret_value):
transaction.execute_update.return_value = ret_value
cursor._do_execute_update(
res = cursor._do_execute_update(
transaction=transaction, sql="SELECT * WHERE true", params={},
)
return cursor.fetchall()
return res

expected = "good"
self.assertEqual(run_helper(expected), [expected])
self.assertEqual(run_helper(expected), expected)
self.assertEqual(cursor._row_count, _UNSET_COUNT)

expected = 1234
self.assertEqual(run_helper(expected), [expected])
self.assertEqual(run_helper(expected), expected)
self.assertEqual(cursor._row_count, expected)

def test_do_batch_update(self):
from google.cloud.spanner_dbapi import connect
from google.cloud.spanner_v1.param_types import INT64
from google.cloud.spanner_v1.types.spanner import Session

sql = "DELETE FROM table WHERE col1 = %s"

connection = connect("test-instance", "test-database")

connection.autocommit = True
transaction = self._transaction_mock(mock_response=[1, 1, 1])
cursor = connection.cursor()

with mock.patch(
"google.cloud.spanner_v1.services.spanner.client.SpannerClient.create_session",
return_value=Session(),
):
with mock.patch(
"google.cloud.spanner_v1.session.Session.transaction",
return_value=transaction,
):
cursor.executemany(sql, [(1,), (2,), (3,)])

transaction.batch_update.assert_called_once_with(
[
("DELETE FROM table WHERE col1 = @a0", {"a0": 1}, {"a0": INT64}),
("DELETE FROM table WHERE col1 = @a0", {"a0": 2}, {"a0": INT64}),
("DELETE FROM table WHERE col1 = @a0", {"a0": 3}, {"a0": INT64}),
]
)
self.assertEqual(cursor._row_count, 3)

def test_execute_programming_error(self):
from google.cloud.spanner_dbapi.exceptions import ProgrammingError
Expand Down Expand Up @@ -704,6 +741,7 @@ def test_setoutputsize(self):

def test_handle_dql(self):
from google.cloud.spanner_dbapi import utils
from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT

connection = self._make_connection(self.INSTANCE, mock.MagicMock())
connection.database.snapshot.return_value.__enter__.return_value = (
Expand All @@ -715,6 +753,7 @@ def test_handle_dql(self):
cursor._handle_DQL("sql", params=None)
self.assertEqual(cursor._result_set, ["0"])
self.assertIsInstance(cursor._itr, utils.PeekIterator)
self.assertEqual(cursor._row_count, _UNSET_COUNT)

def test_context(self):
connection = self._make_connection(self.INSTANCE, self.DATABASE)
Expand Down

0 comments on commit d431339

Please sign in to comment.