diff --git a/google/cloud/spanner_dbapi/checksum.py b/google/cloud/spanner_dbapi/checksum.py new file mode 100644 index 0000000000..3cae7cfb62 --- /dev/null +++ b/google/cloud/spanner_dbapi/checksum.py @@ -0,0 +1,72 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""API to calculate checksums of SQL statements results.""" + +import hashlib +import pickle + +from google.cloud.spanner_dbapi.exceptions import RetryAborted + + +class ResultsChecksum: + """Cumulative checksum. + + Used to calculate a total checksum of all the results + returned by operations executed within transaction. + Includes methods for checksums comparison. + These checksums are used while retrying an aborted + transaction to check if the results of a retried transaction + are equal to the results of the original transaction. + """ + + def __init__(self): + self.checksum = hashlib.sha256() + self.count = 0 # counter of consumed results + + def __len__(self): + """Return the number of consumed results. + + :rtype: :class:`int` + :returns: The number of results. + """ + return self.count + + def __eq__(self, other): + """Check if checksums are equal. + + :type other: :class:`google.cloud.spanner_dbapi.checksum.ResultsChecksum` + :param other: Another checksum to compare with this one. + """ + return self.checksum.digest() == other.checksum.digest() + + def consume_result(self, result): + """Add the given result into the checksum. + + :type result: Union[int, list] + :param result: Streamed row or row count from an UPDATE operation. + """ + self.checksum.update(pickle.dumps(result)) + self.count += 1 + + +def _compare_checksums(original, retried): + """Compare the given checksums. + + Raise an error if the given checksums are not equal. + + :type original: :class:`~google.cloud.spanner_dbapi.checksum.ResultsChecksum` + :param original: results checksum of the original transaction. + + :type retried: :class:`~google.cloud.spanner_dbapi.checksum.ResultsChecksum` + :param retried: results checksum of the retried transaction. + + :raises: :exc:`google.cloud.spanner_dbapi.exceptions.RetryAborted` in case if checksums are not equal. + """ + if retried != original: + raise RetryAborted( + "The transaction was aborted and could not be retried due to a concurrent modification." + ) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index beb05a3173..5c1be8f724 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -6,11 +6,16 @@ """DB-API Connection for the Google Cloud Spanner.""" +import time import warnings +from google.api_core.exceptions import Aborted from google.api_core.gapic_v1.client_info import ClientInfo from google.cloud import spanner_v1 as spanner +from google.cloud.spanner_v1.session import _get_retry_delay +from google.cloud.spanner_dbapi.checksum import _compare_checksums +from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.cursor import Cursor from google.cloud.spanner_dbapi.exceptions import InterfaceError from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT @@ -18,6 +23,7 @@ AUTOCOMMIT_MODE_WARNING = "This method is non-operational in autocommit mode" +MAX_INTERNAL_RETRIES = 50 class Connection: @@ -40,6 +46,9 @@ def __init__(self, instance, database): self._transaction = None self._session = None + # SQL statements, which were executed + # within the current transaction + self._statements = [] self.is_closed = False self._autocommit = False @@ -110,6 +119,60 @@ def _release_session(self): self.database._pool.put(self._session) self._session = None + def retry_transaction(self): + """Retry the aborted transaction. + + All the statements executed in the original transaction + will be re-executed in new one. Results checksums of the + original statements and the retried ones will be compared. + + :raises: :class:`google.cloud.spanner_dbapi.exceptions.RetryAborted` + If results checksum of the retried statement is + not equal to the checksum of the original one. + """ + attempt = 0 + while True: + self._transaction = None + attempt += 1 + if attempt > MAX_INTERNAL_RETRIES: + raise + + try: + self._rerun_previous_statements() + break + except Aborted as exc: + delay = _get_retry_delay(exc.errors[0], attempt) + if delay: + time.sleep(delay) + + def _rerun_previous_statements(self): + """ + Helper to run all the remembered statements + from the last transaction. + """ + for statement in self._statements: + res_iter, retried_checksum = self.run_statement( + statement, retried=True + ) + # executing all the completed statements + if statement != self._statements[-1]: + for res in res_iter: + retried_checksum.consume_result(res) + + _compare_checksums(statement.checksum, retried_checksum) + # executing the failed statement + else: + # streaming up to the failed result or + # to the end of the streaming iterator + while len(retried_checksum) < len(statement.checksum): + try: + res = next(iter(res_iter)) + retried_checksum.consume_result(res) + except StopIteration: + break + + _compare_checksums(statement.checksum, retried_checksum) + def transaction_checkout(self): """Get a Cloud Spanner transaction. @@ -167,8 +230,13 @@ def commit(self): if self._autocommit: warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) elif self._transaction: - self._transaction.commit() - self._release_session() + try: + self._transaction.commit() + self._release_session() + self._statements = [] + except Aborted: + self.retry_transaction() + self.commit() def rollback(self): """Rolls back any pending transaction. @@ -181,6 +249,7 @@ def rollback(self): elif self._transaction: self._transaction.rollback() self._release_session() + self._statements = [] def cursor(self): """Factory to create a DB-API Cursor.""" @@ -197,6 +266,34 @@ def run_prior_DDL_statements(self): return self.database.update_ddl(ddl_statements).result() + def run_statement(self, statement, retried=False): + """Run single SQL statement in begun transaction. + + This method is never used in autocommit mode. In + !autocommit mode however it remembers every executed + SQL statement with its parameters. + + :type statement: :class:`dict` + :param statement: SQL statement to execute. + + :rtype: :class:`google.cloud.spanner_v1.streamed.StreamedResultSet`, + :class:`google.cloud.spanner_dbapi.checksum.ResultsChecksum` + :returns: Streamed result set of the statement and a + checksum of this statement results. + """ + transaction = self.transaction_checkout() + if not retried: + self._statements.append(statement) + + return ( + transaction.execute_sql( + statement.sql, + statement.params, + param_types=statement.param_types, + ), + ResultsChecksum() if retried else statement.checksum, + ) + def __enter__(self): return self @@ -250,11 +347,11 @@ def connect( """ client_info = ClientInfo( - user_agent=user_agent or DEFAULT_USER_AGENT, python_version=PY_VERSION, + user_agent=user_agent or DEFAULT_USER_AGENT, python_version=PY_VERSION ) client = spanner.Client( - project=project, credentials=credentials, client_info=client_info, + project=project, credentials=credentials, client_info=client_info ) instance = client.instance(instance_id) diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 6997752a42..20d241f2d7 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -6,6 +6,7 @@ """Database cursor for Google Cloud Spanner DB-API.""" +from google.api_core.exceptions import Aborted from google.api_core.exceptions import AlreadyExists from google.api_core.exceptions import FailedPrecondition from google.api_core.exceptions import InternalServerError @@ -14,7 +15,7 @@ from collections import namedtuple from google.cloud import spanner_v1 as spanner - +from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.exceptions import IntegrityError from google.cloud.spanner_dbapi.exceptions import InterfaceError from google.cloud.spanner_dbapi.exceptions import OperationalError @@ -26,11 +27,13 @@ from google.cloud.spanner_dbapi import parse_utils from google.cloud.spanner_dbapi.parse_utils import get_param_types +from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner from google.cloud.spanner_dbapi.utils import PeekIterator _UNSET_COUNT = -1 ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"]) +Statement = namedtuple("Statement", "sql, params, param_types, checksum") class Cursor(object): @@ -46,6 +49,8 @@ def __init__(self, connection): self._row_count = _UNSET_COUNT self.connection = connection self._is_closed = False + # the currently running SQL statement results checksum + self._checksum = None # the number of rows to fetch at a time with fetchmany() self.arraysize = 1 @@ -158,15 +163,15 @@ def execute(self, sql, args=None): self.connection.run_prior_DDL_statements() if not self.connection.autocommit: - transaction = self.connection.transaction_checkout() - - sql, params = parse_utils.sql_pyformat_args_to_spanner( - sql, args - ) + sql, params = sql_pyformat_args_to_spanner(sql, args) - self._result_set = transaction.execute_sql( - sql, params, param_types=get_param_types(params) + statement = Statement( + sql, params, get_param_types(params), ResultsChecksum(), ) + ( + self._result_set, + self._checksum, + ) = self.connection.run_statement(statement) self._itr = PeekIterator(self._result_set) return @@ -207,9 +212,31 @@ def fetchone(self): self._raise_if_closed() try: - return next(self) + res = next(self) + self._checksum.consume_result(res) + return res except StopIteration: - return None + return + except Aborted: + self.connection.retry_transaction() + return self.fetchone() + + def fetchall(self): + """Fetch all (remaining) rows of a query result, returning them as + a sequence of sequences. + """ + self._raise_if_closed() + + res = [] + try: + for row in self: + self._checksum.consume_result(row) + res.append(row) + except Aborted: + self._connection.retry_transaction() + return self.fetchall() + + return res def fetchmany(self, size=None): """Fetch the next set of rows of a query result, returning a sequence @@ -230,20 +257,17 @@ def fetchmany(self, size=None): items = [] for i in range(size): try: - items.append(tuple(self.__next__())) + res = next(self) + self._checksum.consume_result(res) + items.append(res) except StopIteration: break + except Aborted: + self._connection.retry_transaction() + return self.fetchmany(size) return items - def fetchall(self): - """Fetch all (remaining) rows of a query result, returning them as - a sequence of sequences. - """ - self._raise_if_closed() - - return list(self.__iter__()) - def nextset(self): """A no-op, raising an error if the cursor or connection is closed.""" self._raise_if_closed() diff --git a/google/cloud/spanner_dbapi/exceptions.py b/google/cloud/spanner_dbapi/exceptions.py index b21be2c949..2b021f6b98 100644 --- a/google/cloud/spanner_dbapi/exceptions.py +++ b/google/cloud/spanner_dbapi/exceptions.py @@ -92,3 +92,13 @@ class NotSupportedError(DatabaseError): """ pass + + +class RetryAborted(OperationalError): + """ + Error for case of no aborted transaction retry + is available, because of underlying data being + changed during a retry. + """ + + pass diff --git a/tests/spanner_dbapi/test_connection.py b/tests/spanner_dbapi/test_connection.py index 24260de12e..0ae01225ef 100644 --- a/tests/spanner_dbapi/test_connection.py +++ b/tests/spanner_dbapi/test_connection.py @@ -11,7 +11,9 @@ # import google.cloud.spanner_dbapi.exceptions as dbapi_exceptions +from google.api_core.exceptions import Aborted from google.cloud.spanner_dbapi import Connection, InterfaceError +from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.connection import AUTOCOMMIT_MODE_WARNING from google.cloud.spanner_v1.database import Database from google.cloud.spanner_v1.instance import Instance @@ -77,3 +79,124 @@ def test_instance_property(self): with self.assertRaises(AttributeError): connection.instance = None + + def test_run_statement(self): + """Check that Connection remembers executed statements.""" + sql = """SELECT 23 FROM table WHERE id = @a1""" + params = {"a1": "value"} + param_types = {"a1": str} + + connection = self._make_connection() + + statement = { + "sql": sql, + "params": params, + "param_types": param_types, + "checksum": ResultsChecksum(), + } + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.transaction_checkout" + ): + connection.run_statement(statement) + + self.assertEqual(connection._statements[0]["sql"], sql) + self.assertEqual(connection._statements[0]["params"], params) + self.assertEqual(connection._statements[0]["param_types"], param_types) + self.assertIsInstance( + connection._statements[0]["checksum"], ResultsChecksum + ) + + def test_clear_statements_on_commit(self): + """ + Check that all the saved statements are + cleared, when the transaction is commited. + """ + connection = self._make_connection() + connection._transaction = mock.Mock() + connection._statements = [{}, {}] + + self.assertEqual(len(connection._statements), 2) + + with mock.patch( + "google.cloud.spanner_v1.transaction.Transaction.commit" + ): + connection.commit() + + self.assertEqual(len(connection._statements), 0) + + def test_clear_statements_on_rollback(self): + """ + Check that all the saved statements are + cleared, when the transaction is roll backed. + """ + connection = self._make_connection() + connection._transaction = mock.Mock() + connection._statements = [{}, {}] + + self.assertEqual(len(connection._statements), 2) + + with mock.patch( + "google.cloud.spanner_v1.transaction.Transaction.commit" + ): + connection.rollback() + + self.assertEqual(len(connection._statements), 0) + + def test_retry_transaction(self): + """Check retrying an aborted transaction.""" + row = ["field1", "field2"] + connection = self._make_connection() + + checksum = ResultsChecksum() + checksum.consume_result(row) + retried_checkum = ResultsChecksum() + + statement = { + "sql": "SELECT 1", + "params": [], + "param_types": {}, + "checksum": checksum, + } + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([row], retried_checkum), + ) as run_mock: + with mock.patch( + "google.cloud.spanner_dbapi.connection._compare_checksums" + ) as compare_mock: + connection.retry_transaction() + + compare_mock.assert_called_with(checksum, retried_checkum) + + run_mock.assert_called_with(statement, retried=True) + + def test_retry_transaction_checksum_mismatch(self): + """ + Check retrying an aborted transaction + with results checksums mismatch. + """ + row = ["field1", "field2"] + retried_row = ["field3", "field4"] + connection = self._make_connection() + + checksum = ResultsChecksum() + checksum.consume_result(row) + retried_checkum = ResultsChecksum() + + statement = { + "sql": "SELECT 1", + "params": [], + "param_types": {}, + "checksum": checksum, + } + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([retried_row], retried_checkum), + ): + with self.assertRaises(Aborted): + connection.retry_transaction() diff --git a/tests/spanner_dbapi/test_cursor.py b/tests/spanner_dbapi/test_cursor.py new file mode 100644 index 0000000000..4baa409070 --- /dev/null +++ b/tests/spanner_dbapi/test_cursor.py @@ -0,0 +1,262 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""Cursor() class unit tests.""" + +import unittest +from unittest import mock + +from google.api_core.exceptions import Aborted +from google.cloud.spanner_dbapi import connect, InterfaceError +from google.cloud.spanner_dbapi.checksum import ResultsChecksum +from google.cloud.spanner_dbapi.cursor import ColumnInfo + + +class TestCursor(unittest.TestCase): + def test_close(self): + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + self.assertFalse(cursor.is_closed) + + cursor.close() + + self.assertTrue(cursor.is_closed) + with self.assertRaises(InterfaceError): + cursor.execute("SELECT * FROM database") + + def test_connection_closed(self): + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + self.assertFalse(cursor.is_closed) + + connection.close() + + self.assertTrue(cursor.is_closed) + with self.assertRaises(InterfaceError): + cursor.execute("SELECT * FROM database") + + def test_executemany_on_closed_cursor(self): + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor.close() + + with self.assertRaises(InterfaceError): + cursor.executemany( + """SELECT * FROM table1 WHERE "col1" = @a1""", () + ) + + def test_executemany(self): + operation = """SELECT * FROM table1 WHERE "col1" = @a1""" + params_seq = ((1,), (2,)) + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.execute" + ) as execute_mock: + cursor.executemany(operation, params_seq) + + execute_mock.assert_has_calls( + (mock.call(operation, (1,)), mock.call(operation, (2,))) + ) + + def test_fetchone_retry_aborted(self): + """Check that aborted fetch re-executing transaction.""" + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.__next__", + side_effect=(Aborted("Aborted"), None), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.retry_transaction" + ) as retry_mock: + + cursor.fetchone() + + retry_mock.assert_called_once() + + def test_fetchone_retry_aborted_statements(self): + """Check that retried transaction executing the same statements.""" + row = ["field1", "field2"] + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + cursor._checksum.consume_result(row) + + statement = { + "sql": "SELECT 1", + "params": [], + "param_types": {}, + "checksum": cursor._checksum, + } + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.__next__", + side_effect=(Aborted("Aborted"), None), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([row], ResultsChecksum()), + ) as run_mock: + + cursor.fetchone() + + run_mock.assert_called_with(statement, retried=True) + + def test_fetchone_retry_aborted_statements_checksums_mismatch(self): + """Check transaction retrying with underlying data being changed.""" + row = ["field1", "field2"] + row2 = ["updated_field1", "field2"] + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + cursor._checksum.consume_result(row) + + statement = { + "sql": "SELECT 1", + "params": [], + "param_types": {}, + "checksum": cursor._checksum, + } + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.__next__", + side_effect=Aborted("Aborted"), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([row2], ResultsChecksum()), + ) as run_mock: + + with self.assertRaises(Aborted): + cursor.fetchone() + + run_mock.assert_called_with(statement, retried=True) + + +class TestColumns(unittest.TestCase): + def test_ctor(self): + name = "col-name" + type_code = 8 + display_size = 5 + internal_size = 10 + precision = 3 + scale = None + null_ok = False + + cols = ColumnInfo( + name, + type_code, + display_size, + internal_size, + precision, + scale, + null_ok, + ) + + self.assertEqual(cols.name, name) + self.assertEqual(cols.type_code, type_code) + self.assertEqual(cols.display_size, display_size) + self.assertEqual(cols.internal_size, internal_size) + self.assertEqual(cols.precision, precision) + self.assertEqual(cols.scale, scale) + self.assertEqual(cols.null_ok, null_ok) + self.assertEqual( + cols.fields, + ( + name, + type_code, + display_size, + internal_size, + precision, + scale, + null_ok, + ), + ) + + def test___get_item__(self): + fields = ("col-name", 8, 5, 10, 3, None, False) + cols = ColumnInfo(*fields) + + for i in range(0, 7): + self.assertEqual(cols[i], fields[i]) + + def test___str__(self): + cols = ColumnInfo("col-name", 8, None, 10, 3, None, False) + + self.assertEqual( + str(cols), + "ColumnInfo(name='col-name', type_code=8, internal_size=10, precision='3')", + ) diff --git a/tests/system/test_system.py b/tests/system/test_system.py index f3ee345e15..158092b31c 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -4,8 +4,10 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -import unittest +import hashlib +import pickle import os +import unittest from google.api_core import exceptions @@ -289,6 +291,34 @@ def test_rollback_on_connection_closing(self): cursor.close() conn.close() + def test_results_checksum(self): + """Test that results checksum is calculated properly.""" + conn = Connection(Config.INSTANCE, self._db) + cursor = conn.cursor() + + cursor.execute( + """ +INSERT INTO contacts (contact_id, first_name, last_name, email) +VALUES + (1, 'first-name', 'last-name', 'test.email@domen.ru'), + (2, 'first-name2', 'last-name2', 'test.email2@domen.ru') + """ + ) + self.assertEqual(len(conn._statements), 1) + conn.commit() + + cursor.execute("SELECT * FROM contacts") + got_rows = cursor.fetchall() + + self.assertEqual(len(conn._statements), 1) + conn.commit() + + checksum = hashlib.sha256() + checksum.update(pickle.dumps(got_rows[0])) + checksum.update(pickle.dumps(got_rows[1])) + + self.assertEqual(cursor._checksum.checksum.digest(), checksum.digest()) + def clear_table(transaction): """Clear the test table.""" diff --git a/tests/unit/spanner_dbapi/test_checksum.py b/tests/unit/spanner_dbapi/test_checksum.py new file mode 100644 index 0000000000..3e7780bd6e --- /dev/null +++ b/tests/unit/spanner_dbapi/test_checksum.py @@ -0,0 +1,63 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +import unittest + + +class Test_compare_checksums(unittest.TestCase): + def test_equal(self): + from google.cloud.spanner_dbapi.checksum import _compare_checksums + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + + original = ResultsChecksum() + original.consume_result(5) + + retried = ResultsChecksum() + retried.consume_result(5) + + self.assertIsNone(_compare_checksums(original, retried)) + + def test_less_results(self): + from google.cloud.spanner_dbapi.checksum import _compare_checksums + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.exceptions import RetryAborted + + original = ResultsChecksum() + original.consume_result(5) + + retried = ResultsChecksum() + + with self.assertRaises(RetryAborted): + _compare_checksums(original, retried) + + def test_more_results(self): + from google.cloud.spanner_dbapi.checksum import _compare_checksums + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.exceptions import RetryAborted + + original = ResultsChecksum() + original.consume_result(5) + + retried = ResultsChecksum() + retried.consume_result(5) + retried.consume_result(2) + + with self.assertRaises(RetryAborted): + _compare_checksums(original, retried) + + def test_mismatch(self): + from google.cloud.spanner_dbapi.checksum import _compare_checksums + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.exceptions import RetryAborted + + original = ResultsChecksum() + original.consume_result(5) + + retried = ResultsChecksum() + retried.consume_result(2) + + with self.assertRaises(RetryAborted): + _compare_checksums(original, retried) diff --git a/tests/spanner_dbapi/test_connect.py b/tests/unit/spanner_dbapi/test_connect.py similarity index 76% rename from tests/spanner_dbapi/test_connect.py rename to tests/unit/spanner_dbapi/test_connect.py index fb4d89c373..5d545d7d90 100644 --- a/tests/spanner_dbapi/test_connect.py +++ b/tests/unit/spanner_dbapi/test_connect.py @@ -10,9 +10,6 @@ from unittest import mock import google.auth.credentials -from google.api_core.gapic_v1.client_info import ClientInfo -from google.cloud.spanner_dbapi import connect, Connection -from google.cloud.spanner_v1.pool import FixedSizePool def _make_credentials(): @@ -26,37 +23,31 @@ class _CredentialsWithScopes( class Test_connect(unittest.TestCase): def test_connect(self): + from google.cloud.spanner_dbapi import connect + from google.cloud.spanner_dbapi import Connection + PROJECT = "test-project" USER_AGENT = "user-agent" CREDENTIALS = _make_credentials() - CLIENT_INFO = ClientInfo(user_agent=USER_AGENT) - - with mock.patch( - "google.cloud.spanner_dbapi.spanner_v1.Client" - ) as client_mock: - with mock.patch( - "google.cloud.spanner_dbapi.google_client_info", - return_value=CLIENT_INFO, - ) as client_info_mock: - connection = connect( - "test-instance", - "test-database", - PROJECT, - CREDENTIALS, - user_agent=USER_AGENT, - ) + with mock.patch("google.cloud.spanner_v1.Client") as client_mock: + connection = connect( + "test-instance", + "test-database", + PROJECT, + CREDENTIALS, + user_agent=USER_AGENT, + ) - self.assertIsInstance(connection, Connection) - client_info_mock.assert_called_once_with(USER_AGENT) + self.assertIsInstance(connection, Connection) client_mock.assert_called_once_with( - project=PROJECT, - credentials=CREDENTIALS, - client_info=CLIENT_INFO, + project=PROJECT, credentials=CREDENTIALS, client_info=mock.ANY ) def test_instance_not_found(self): + from google.cloud.spanner_dbapi import connect + with mock.patch( "google.cloud.spanner_v1.instance.Instance.exists", return_value=False, @@ -68,6 +59,8 @@ def test_instance_not_found(self): exists_mock.assert_called_once_with() def test_database_not_found(self): + from google.cloud.spanner_dbapi import connect + with mock.patch( "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, @@ -83,6 +76,9 @@ def test_database_not_found(self): exists_mock.assert_called_once_with() def test_connect_instance_id(self): + from google.cloud.spanner_dbapi import connect + from google.cloud.spanner_dbapi import Connection + INSTANCE = "test-instance" with mock.patch( @@ -95,6 +91,9 @@ def test_connect_instance_id(self): self.assertIsInstance(connection, Connection) def test_connect_database_id(self): + from google.cloud.spanner_dbapi import connect + from google.cloud.spanner_dbapi import Connection + DATABASE = "test-database" with mock.patch( @@ -111,6 +110,8 @@ def test_connect_database_id(self): self.assertIsInstance(connection, Connection) def test_default_sessions_pool(self): + from google.cloud.spanner_dbapi import connect + with mock.patch("google.cloud.spanner_v1.instance.Instance.database"): with mock.patch( "google.cloud.spanner_v1.instance.Instance.exists", @@ -121,6 +122,9 @@ def test_default_sessions_pool(self): self.assertIsNotNone(connection.database._pool) def test_sessions_pool(self): + from google.cloud.spanner_dbapi import connect + from google.cloud.spanner_v1.pool import FixedSizePool + database_id = "test-database" pool = FixedSizePool() diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 99aa0aa47b..79415bca55 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -85,7 +85,7 @@ def test__session_checkout(self): from google.cloud.spanner_dbapi import Connection with mock.patch( - "google.cloud.spanner_v1.database.Database", + "google.cloud.spanner_v1.database.Database" ) as mock_database: mock_database._pool = mock.MagicMock() mock_database._pool.get = mock.MagicMock( @@ -105,7 +105,7 @@ def test__release_session(self): from google.cloud.spanner_dbapi import Connection with mock.patch( - "google.cloud.spanner_v1.database.Database", + "google.cloud.spanner_v1.database.Database" ) as mock_database: mock_database._pool = mock.MagicMock() mock_database._pool.put = mock.MagicMock() @@ -225,7 +225,7 @@ def test_run_prior_DDL_statements(self): from google.cloud.spanner_dbapi import Connection, InterfaceError with mock.patch( - "google.cloud.spanner_v1.database.Database", autospec=True, + "google.cloud.spanner_v1.database.Database", autospec=True ) as mock_database: connection = Connection(self.INSTANCE, mock_database) @@ -335,3 +335,241 @@ def test_global_pool(self): ) as pool_clear_mock: connection.close() assert not pool_clear_mock.called + + def test_run_statement_remember_statements(self): + """Check that Connection remembers executed statements.""" + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.cursor import Statement + + sql = """SELECT 23 FROM table WHERE id = @a1""" + params = {"a1": "value"} + param_types = {"a1": str} + + connection = self._make_connection() + + statement = Statement(sql, params, param_types, ResultsChecksum(),) + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.transaction_checkout" + ): + connection.run_statement(statement) + + self.assertEqual(connection._statements[0].sql, sql) + self.assertEqual(connection._statements[0].params, params) + self.assertEqual(connection._statements[0].param_types, param_types) + self.assertIsInstance( + connection._statements[0].checksum, ResultsChecksum + ) + + def test_run_statement_dont_remember_retried_statements(self): + """Check that Connection doesn't remember re-executed statements.""" + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.cursor import Statement + + sql = """SELECT 23 FROM table WHERE id = @a1""" + params = {"a1": "value"} + param_types = {"a1": str} + + connection = self._make_connection() + + statement = Statement(sql, params, param_types, ResultsChecksum(),) + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.transaction_checkout" + ): + connection.run_statement(statement, retried=True) + + self.assertEqual(len(connection._statements), 0) + + def test_clear_statements_on_commit(self): + """ + Check that all the saved statements are + cleared, when the transaction is commited. + """ + connection = self._make_connection() + connection._transaction = mock.Mock() + connection._statements = [{}, {}] + + self.assertEqual(len(connection._statements), 2) + + with mock.patch( + "google.cloud.spanner_v1.transaction.Transaction.commit" + ): + connection.commit() + + self.assertEqual(len(connection._statements), 0) + + def test_clear_statements_on_rollback(self): + """ + Check that all the saved statements are + cleared, when the transaction is roll backed. + """ + connection = self._make_connection() + connection._transaction = mock.Mock() + connection._statements = [{}, {}] + + self.assertEqual(len(connection._statements), 2) + + with mock.patch( + "google.cloud.spanner_v1.transaction.Transaction.commit" + ): + connection.rollback() + + self.assertEqual(len(connection._statements), 0) + + def test_retry_transaction(self): + """Check retrying an aborted transaction.""" + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.cursor import Statement + + row = ["field1", "field2"] + connection = self._make_connection() + + checksum = ResultsChecksum() + checksum.consume_result(row) + retried_checkum = ResultsChecksum() + + statement = Statement("SELECT 1", [], {}, checksum,) + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([row], retried_checkum), + ) as run_mock: + with mock.patch( + "google.cloud.spanner_dbapi.connection._compare_checksums" + ) as compare_mock: + connection.retry_transaction() + + compare_mock.assert_called_with(checksum, retried_checkum) + + run_mock.assert_called_with(statement, retried=True) + + def test_retry_transaction_checksum_mismatch(self): + """ + Check retrying an aborted transaction + with results checksums mismatch. + """ + from google.cloud.spanner_dbapi.exceptions import RetryAborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.cursor import Statement + + row = ["field1", "field2"] + retried_row = ["field3", "field4"] + connection = self._make_connection() + + checksum = ResultsChecksum() + checksum.consume_result(row) + retried_checkum = ResultsChecksum() + + statement = Statement("SELECT 1", [], {}, checksum,) + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([retried_row], retried_checkum), + ): + with self.assertRaises(RetryAborted): + connection.retry_transaction() + + def test_commit_retry_aborted_statements(self): + """Check that retried transaction executing the same statements.""" + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.connection import connect + from google.cloud.spanner_dbapi.cursor import Statement + + row = ["field1", "field2"] + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + cursor._checksum.consume_result(row) + + statement = Statement("SELECT 1", [], {}, cursor._checksum,) + connection._statements.append(statement) + connection._transaction = mock.Mock() + + with mock.patch.object( + connection._transaction, + "commit", + side_effect=(Aborted("Aborted"), None), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([row], ResultsChecksum()), + ) as run_mock: + + connection.commit() + + run_mock.assert_called_with(statement, retried=True) + + def test_retry_transaction_drop_transaction(self): + """ + Check that before retrying an aborted transaction + connection drops the original aborted transaction. + """ + connection = self._make_connection() + transaction_mock = mock.Mock() + connection._transaction = transaction_mock + + # as we didn't set any statements, the method + # will only drop the transaction object + connection.retry_transaction() + self.assertIsNone(connection._transaction) + + def test_retry_aborted_retry(self): + """ + Check that in case of a retried transaction failed, + the connection will retry it once again. + """ + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.connection import connect + from google.cloud.spanner_dbapi.cursor import Statement + + row = ["field1", "field2"] + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + cursor._checksum.consume_result(row) + + statement = Statement("SELECT 1", [], {}, cursor._checksum,) + connection._statements.append(statement) + + metadata_mock = mock.Mock() + metadata_mock.trailing_metadata.return_value = {} + + with mock.patch.object( + connection, + "run_statement", + side_effect=( + Aborted("Aborted", errors=[metadata_mock]), + ([row], ResultsChecksum()), + ), + ) as retry_mock: + + connection.retry_transaction() + + retry_mock.assert_has_calls( + ( + mock.call(statement, retried=True), + mock.call(statement, retried=True), + ) + ) diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 09288df94e..f7dd712ddd 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -94,7 +94,7 @@ def test_do_execute_update(self): def run_helper(ret_value): transaction.execute_update.return_value = ret_value res = cursor._do_execute_update( - transaction=transaction, sql="sql", params=None, + transaction=transaction, sql="sql", params=None ) return res @@ -286,17 +286,25 @@ def test_executemany(self): ) def test_fetchone(self): + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) cursor = self._make_one(connection) + cursor._checksum = ResultsChecksum() lst = [1, 2, 3] cursor._itr = iter(lst) + for i in range(len(lst)): self.assertEqual(cursor.fetchone(), lst[i]) + self.assertIsNone(cursor.fetchone()) def test_fetchmany(self): + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) cursor = self._make_one(connection) + cursor._checksum = ResultsChecksum() lst = [(1,), (2,), (3,)] cursor._itr = iter(lst) @@ -306,8 +314,11 @@ def test_fetchmany(self): self.assertEqual(result, lst[1:]) def test_fetchall(self): + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) cursor = self._make_one(connection) + cursor._checksum = ResultsChecksum() lst = [(1,), (2,), (3,)] cursor._itr = iter(lst) self.assertEqual(cursor.fetchall(), lst) @@ -442,9 +453,7 @@ def test_get_table_column_schema(self): spanner_type = "spanner_type" rows = [(column_name, is_nullable, spanner_type)] expected = { - column_name: ColumnDetails( - null_ok=True, spanner_type=spanner_type, - ) + column_name: ColumnDetails(null_ok=True, spanner_type=spanner_type) } with mock.patch( "google.cloud.spanner_dbapi.cursor.Cursor.run_sql_in_snapshot", @@ -458,3 +467,114 @@ def test_get_table_column_schema(self): param_types={"table_name": param_types.STRING}, ) self.assertEqual(result, expected) + + def test_fetchone_retry_aborted(self): + """Check that aborted fetch re-executing transaction.""" + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.connection import connect + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.__next__", + side_effect=(Aborted("Aborted"), None), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.retry_transaction" + ) as retry_mock: + + cursor.fetchone() + + retry_mock.assert_called_with() + + def test_fetchone_retry_aborted_statements(self): + """Check that retried transaction executing the same statements.""" + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.connection import connect + from google.cloud.spanner_dbapi.cursor import Statement + + row = ["field1", "field2"] + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + cursor._checksum.consume_result(row) + + statement = Statement("SELECT 1", [], {}, cursor._checksum,) + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.__next__", + side_effect=(Aborted("Aborted"), None), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([row], ResultsChecksum()), + ) as run_mock: + + cursor.fetchone() + + run_mock.assert_called_with(statement, retried=True) + + def test_fetchone_retry_aborted_statements_checksums_mismatch(self): + """Check transaction retrying with underlying data being changed.""" + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.exceptions import RetryAborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.connection import connect + from google.cloud.spanner_dbapi.cursor import Statement + + row = ["field1", "field2"] + row2 = ["updated_field1", "field2"] + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + cursor._checksum.consume_result(row) + + statement = Statement("SELECT 1", [], {}, cursor._checksum,) + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.__next__", + side_effect=(Aborted("Aborted"), None), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([row2], ResultsChecksum()), + ) as run_mock: + + with self.assertRaises(RetryAborted): + cursor.fetchone() + + run_mock.assert_called_with(statement, retried=True)