diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 84b35292f0..7c8c5bdbc5 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -44,6 +44,8 @@ from google.rpc.code_pb2 import ABORTED, OK +_UNSET_COUNT = -1 + ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"]) Statement = namedtuple("Statement", "sql, params, param_types, checksum, is_insert") @@ -80,6 +82,7 @@ class Cursor(object): def __init__(self, connection): self._itr = None self._result_set = None + self._row_count = _UNSET_COUNT self.lastrowid = None self.connection = connection self._is_closed = False @@ -134,13 +137,14 @@ def description(self): @property def rowcount(self): - """The number of rows produced by the last `execute()` call. + """The number of rows updated by the last UPDATE, DELETE request's `execute()` call. + For SELECT requests the rowcount returns -1. - The property is non-operational and always returns -1. Request - resulting rows are streamed by the `fetch*()` methods and - can't be counted before they are all streamed. + :rtype: int + :returns: The number of rows updated by the last UPDATE, DELETE request's .execute*() call. """ - return -1 + + return self._row_count @check_not_closed def callproc(self, procname, args=None): @@ -170,7 +174,11 @@ def _do_execute_update(self, transaction, sql, params): result = transaction.execute_update( sql, params=params, param_types=get_param_types(params) ) - self._itr = iter([result]) + self._itr = None + if type(result) == int: + self._row_count = result + + return result def _do_batch_update(self, transaction, statements, many_result_set): status, res = transaction.batch_update(statements) @@ -181,6 +189,8 @@ def _do_batch_update(self, transaction, statements, many_result_set): elif status.code != OK: raise OperationalError(status.message) + self._row_count = sum([max(val, 0) for val in res]) + def _batch_DDLs(self, sql): """ Check that the given operation contains only DDL @@ -414,6 +424,9 @@ def _handle_DQL_with_snapshot(self, snapshot, sql, params): # Read the first element so that the StreamedResultSet can # return the metadata after a DQL statement. self._itr = PeekIterator(self._result_set) + # Unfortunately, Spanner doesn't seem to send back + # information about the number of rows available. + self._row_count = _UNSET_COUNT def _handle_DQL(self, sql, params): sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params) diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index f7607b79bd..51732bc1b0 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -37,11 +37,13 @@ def _make_connection(self, *args, **kwargs): return Connection(*args, **kwargs) - def _transaction_mock(self): + def _transaction_mock(self, mock_response=[]): from google.rpc.code_pb2 import OK transaction = mock.Mock(committed=False, rolled_back=False) - transaction.batch_update = mock.Mock(return_value=[mock.Mock(code=OK), []]) + transaction.batch_update = mock.Mock( + return_value=[mock.Mock(code=OK), mock_response] + ) return transaction def test_property_connection(self): @@ -62,10 +64,12 @@ def test_property_description(self): self.assertIsInstance(cursor.description[0], ColumnInfo) def test_property_rowcount(self): + from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT + connection = self._make_connection(self.INSTANCE, self.DATABASE) cursor = self._make_one(connection) - assert cursor.rowcount == -1 + self.assertEqual(cursor.rowcount, _UNSET_COUNT) def test_callproc(self): from google.cloud.spanner_dbapi.exceptions import InterfaceError @@ -93,25 +97,58 @@ def test_close(self, mock_client): cursor.execute("SELECT * FROM database") def test_do_execute_update(self): - from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT connection = self._make_connection(self.INSTANCE, self.DATABASE) cursor = self._make_one(connection) - cursor._checksum = ResultsChecksum() transaction = mock.MagicMock() def run_helper(ret_value): transaction.execute_update.return_value = ret_value - cursor._do_execute_update( + res = cursor._do_execute_update( transaction=transaction, sql="SELECT * WHERE true", params={}, ) - return cursor.fetchall() + return res expected = "good" - self.assertEqual(run_helper(expected), [expected]) + self.assertEqual(run_helper(expected), expected) + self.assertEqual(cursor._row_count, _UNSET_COUNT) expected = 1234 - self.assertEqual(run_helper(expected), [expected]) + self.assertEqual(run_helper(expected), expected) + self.assertEqual(cursor._row_count, expected) + + def test_do_batch_update(self): + from google.cloud.spanner_dbapi import connect + from google.cloud.spanner_v1.param_types import INT64 + from google.cloud.spanner_v1.types.spanner import Session + + sql = "DELETE FROM table WHERE col1 = %s" + + connection = connect("test-instance", "test-database") + + connection.autocommit = True + transaction = self._transaction_mock(mock_response=[1, 1, 1]) + cursor = connection.cursor() + + with mock.patch( + "google.cloud.spanner_v1.services.spanner.client.SpannerClient.create_session", + return_value=Session(), + ): + with mock.patch( + "google.cloud.spanner_v1.session.Session.transaction", + return_value=transaction, + ): + cursor.executemany(sql, [(1,), (2,), (3,)]) + + transaction.batch_update.assert_called_once_with( + [ + ("DELETE FROM table WHERE col1 = @a0", {"a0": 1}, {"a0": INT64}), + ("DELETE FROM table WHERE col1 = @a0", {"a0": 2}, {"a0": INT64}), + ("DELETE FROM table WHERE col1 = @a0", {"a0": 3}, {"a0": INT64}), + ] + ) + self.assertEqual(cursor._row_count, 3) def test_execute_programming_error(self): from google.cloud.spanner_dbapi.exceptions import ProgrammingError @@ -704,6 +741,7 @@ def test_setoutputsize(self): def test_handle_dql(self): from google.cloud.spanner_dbapi import utils + from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT connection = self._make_connection(self.INSTANCE, mock.MagicMock()) connection.database.snapshot.return_value.__enter__.return_value = ( @@ -715,6 +753,7 @@ def test_handle_dql(self): cursor._handle_DQL("sql", params=None) self.assertEqual(cursor._result_set, ["0"]) self.assertIsInstance(cursor._itr, utils.PeekIterator) + self.assertEqual(cursor._row_count, _UNSET_COUNT) def test_context(self): connection = self._make_connection(self.INSTANCE, self.DATABASE)