class ResultsChecksum:
    """Cumulative checksum of the results produced within a transaction.

    Every result consumed through :meth:`consume_result` is pickled and
    folded into a single SHA-256 digest. While retrying an aborted
    transaction the checksum of the retried results is compared with the
    checksum of the original results to detect whether the underlying
    data has changed in between.
    """

    def __init__(self):
        self.checksum = hashlib.sha256()
        self.count = 0  # counter of consumed results

    def __len__(self):
        """Return the number of consumed results.

        NOTE: because of this method an instance which has not consumed
        anything yet is falsy; test against ``None`` explicitly instead
        of relying on truthiness.

        :rtype: :class:`int`
        :returns: The number of results.
        """
        return self.count

    def __eq__(self, other):
        """Check if checksums are equal.

        :type other: :class:`google.cloud.spanner_dbapi.checksum.ResultsChecksum`
        :param other: Another checksum to compare with this one.

        :rtype: :class:`bool`
        :returns: True if both digests are equal; ``NotImplemented`` for
                  objects which are not checksums, so that Python falls
                  back to its default comparison instead of failing
                  with an ``AttributeError``.
        """
        if not isinstance(other, ResultsChecksum):
            return NotImplemented

        return self.checksum.digest() == other.checksum.digest()

    # Instances are mutable and compared by value, so they must not be
    # hashable. Defining ``__eq__`` already implies this; it is spelled
    # out here to make the intent explicit.
    __hash__ = None

    def consume_result(self, result):
        """Add the given result into the checksum.

        :type result: Union[int, list]
        :param result: Streamed row or row count from an UPDATE operation.
        """
        self.checksum.update(pickle.dumps(result))
        self.count += 1


def _compare_checksums(original, retried):
    """Compare the given checksums.

    Raise an error if the given checksums are not equal.

    :type original: :class:`~google.cloud.spanner_dbapi.checksum.ResultsChecksum`
    :param original: results checksum of the original transaction.

    :type retried: :class:`~google.cloud.spanner_dbapi.checksum.ResultsChecksum`
    :param retried: results checksum of the retried transaction.

    :raises: :exc:`google.cloud.spanner_dbapi.exceptions.RetryAborted` in case if checksums are not equal.
    """
    if retried != original:
        raise RetryAborted(
            "The transaction was aborted and could not be retried due to a concurrent modification."
        )
def retry_transaction(self):
    """Retry the aborted transaction.

    All the statements executed in the original transaction
    will be re-executed in new one. Results checksums of the
    original statements and the retried ones will be compared.

    :raises: :class:`google.cloud.spanner_dbapi.exceptions.RetryAborted`
        If results checksum of the retried statement is
        not equal to the checksum of the original one.

    :raises: :class:`google.api_core.exceptions.Aborted`
        If the retried transaction was itself aborted
        ``MAX_INTERNAL_RETRIES`` times in a row.
    """
    attempt = 0
    while True:
        # drop the aborted transaction object before replaying
        self._transaction = None
        attempt += 1

        try:
            self._rerun_previous_statements()
            break
        except Aborted as exc:
            if attempt >= MAX_INTERNAL_RETRIES:
                # Out of attempts: surface the last ``Aborted`` error.
                # This re-raise has to live inside the ``except`` block,
                # a bare ``raise`` with no active exception fails with
                # ``RuntimeError`` instead.
                self._transaction = None
                raise

            # NOTE(review): assumes the error carries at least one
            # underlying cause in ``exc.errors`` - confirm.
            delay = _get_retry_delay(exc.errors[0], attempt)
            if delay:
                time.sleep(delay)


def _rerun_previous_statements(self):
    """
    Helper to run all the remembered statements
    from the last transaction.

    Every statement but the last one is treated as completed: all of
    its results are consumed. The last statement is the one being
    processed when the transaction was aborted, so it is streamed only
    up to the number of results consumed originally. In both cases the
    retried checksum is compared with the original one.
    """
    # The last statement is picked by position: two remembered
    # statements can compare equal as tuples (same SQL, same
    # parameters, same results) without being the same statement.
    last_index = len(self._statements) - 1

    for index, statement in enumerate(self._statements):
        res_iter, retried_checksum = self.run_statement(statement, retried=True)

        if index != last_index:
            # executing all the completed statements
            for res in res_iter:
                retried_checksum.consume_result(res)
        else:
            # executing the failed statement: streaming up to
            # the failed result or to the end of the streaming
            # iterator. The iterator is created once - building a
            # new one on every step would restart (or drop rows
            # of) the underlying result set.
            rows = iter(res_iter)
            while len(retried_checksum) < len(statement.checksum):
                try:
                    retried_checksum.consume_result(next(rows))
                except StopIteration:
                    break

        _compare_checksums(statement.checksum, retried_checksum)
def run_statement(self, statement, retried=False):
    """Execute one SQL statement inside the current transaction.

    Only used when autocommit is off. Unless the call is a replay
    (``retried=True``), the statement is remembered on the connection
    so that it can be re-executed if the transaction gets aborted.

    :type statement: :class:`google.cloud.spanner_dbapi.cursor.Statement`
    :param statement: SQL statement to execute; its ``sql``, ``params``,
                      ``param_types`` and ``checksum`` fields are used.

    :type retried: bool
    :param retried: (Optional) True when the statement is being
                    re-executed as a part of a transaction retry.

    :rtype: :class:`google.cloud.spanner_v1.streamed.StreamedResultSet`,
            :class:`google.cloud.spanner_dbapi.checksum.ResultsChecksum`
    :returns: Streamed result set of the statement and the checksum to
              accumulate its results into: a new one for a replayed
              statement, the statement's own one otherwise.
    """
    txn = self.transaction_checkout()

    if not retried:
        self._statements.append(statement)

    result_set = txn.execute_sql(
        statement.sql, statement.params, param_types=statement.param_types
    )

    if retried:
        results_checksum = ResultsChecksum()
    else:
        results_checksum = statement.checksum

    return result_set, results_checksum
def fetchall(self):
    """Fetch all (remaining) rows of a query result, returning them as
    a sequence of sequences.

    Every fetched row is added to the checksum of the running
    statement. If the transaction is aborted while streaming, it is
    retried through the connection and the fetch is started again.

    :rtype: list
    :returns: The fetched rows.
    """
    self._raise_if_closed()

    res = []
    try:
        for row in self:
            self._checksum.consume_result(row)
            res.append(row)
    except Aborted:
        # The cursor keeps its connection in ``self.connection``
        # (there is no ``_connection`` attribute).
        # NOTE(review): rows collected in ``res`` before the abort
        # are dropped here - confirm that the repeated fetch yields
        # them again after the retry.
        self.connection.retry_transaction()
        return self.fetchall()

    return res
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +import os +import pickle +import unittest + +from google.api_core import exceptions + +from google.cloud.spanner_v1 import Client +from google.cloud.spanner_v1 import BurstyPool + +from google.cloud.spanner_dbapi.connection import Connection + +from test_utils.retry import RetryErrors + +from .test_system import ( + CREATE_INSTANCE, + EXISTING_INSTANCES, + INSTANCE_ID, + USE_EMULATOR, + _list_instances, + Config, +) + + +def setUpModule(): + if USE_EMULATOR: + from google.auth.credentials import AnonymousCredentials + + emulator_project = os.getenv("GCLOUD_PROJECT", "emulator-test-project") + Config.CLIENT = Client( + project=emulator_project, credentials=AnonymousCredentials() + ) + else: + Config.CLIENT = Client() + retry = RetryErrors(exceptions.ServiceUnavailable) + + configs = list(retry(Config.CLIENT.list_instance_configs)()) + + instances = retry(_list_instances)() + EXISTING_INSTANCES[:] = instances + + if CREATE_INSTANCE: + if not USE_EMULATOR: + # Defend against back-end returning configs for regions we aren't + # actually allowed to use. 
+ configs = [config for config in configs if "-us-" in config.name] + + if not configs: + raise ValueError("List instance configs failed in module set up.") + + Config.INSTANCE_CONFIG = configs[0] + config_name = configs[0].name + + Config.INSTANCE = Config.CLIENT.instance(INSTANCE_ID, config_name) + created_op = Config.INSTANCE.create() + created_op.result(30) # block until completion + + else: + Config.INSTANCE = Config.CLIENT.instance(INSTANCE_ID) + Config.INSTANCE.reload() + + +def tearDownModule(): + if CREATE_INSTANCE: + Config.INSTANCE.delete() + + +class TestTransactionsManagement(unittest.TestCase): + """Transactions management support tests.""" + + DATABASE_NAME = "db-api-transactions-management" + + DDL_STATEMENTS = ( + """CREATE TABLE contacts ( + contact_id INT64, + first_name STRING(1024), + last_name STRING(1024), + email STRING(1024) + ) + PRIMARY KEY (contact_id)""", + ) + + @classmethod + def setUpClass(cls): + """Create a test database.""" + cls._db = Config.INSTANCE.database( + cls.DATABASE_NAME, + ddl_statements=cls.DDL_STATEMENTS, + pool=BurstyPool(labels={"testcase": "database_api"}), + ) + cls._db.create().result(30) # raises on failure / timeout. 
+ + @classmethod + def tearDownClass(cls): + """Delete the test database.""" + cls._db.drop() + + def tearDown(self): + """Clear the test table after every test.""" + self._db.run_in_transaction(clear_table) + + def test_commit(self): + """Test committing a transaction with several statements.""" + want_row = ( + 1, + "updated-first-name", + "last-name", + "test.email_updated@domen.ru", + ) + # connect to the test database + conn = Connection(Config.INSTANCE, self._db) + cursor = conn.cursor() + + # execute several DML statements within one transaction + cursor.execute( + """ +INSERT INTO contacts (contact_id, first_name, last_name, email) +VALUES (1, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + cursor.execute( + """ +UPDATE contacts +SET first_name = 'updated-first-name' +WHERE first_name = 'first-name' +""" + ) + cursor.execute( + """ +UPDATE contacts +SET email = 'test.email_updated@domen.ru' +WHERE email = 'test.email@domen.ru' +""" + ) + conn.commit() + + # read the resulting data from the database + cursor.execute("SELECT * FROM contacts") + got_rows = cursor.fetchall() + conn.commit() + + self.assertEqual(got_rows, [want_row]) + + cursor.close() + conn.close() + + def test_rollback(self): + """Test rollbacking a transaction with several statements.""" + want_row = (2, "first-name", "last-name", "test.email@domen.ru") + # connect to the test database + conn = Connection(Config.INSTANCE, self._db) + cursor = conn.cursor() + + cursor.execute( + """ +INSERT INTO contacts (contact_id, first_name, last_name, email) +VALUES (2, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + conn.commit() + + # execute several DMLs with one transaction + cursor.execute( + """ +UPDATE contacts +SET first_name = 'updated-first-name' +WHERE first_name = 'first-name' +""" + ) + cursor.execute( + """ +UPDATE contacts +SET email = 'test.email_updated@domen.ru' +WHERE email = 'test.email@domen.ru' +""" + ) + conn.rollback() + + # read the resulting data from 
the database + cursor.execute("SELECT * FROM contacts") + got_rows = cursor.fetchall() + conn.commit() + + self.assertEqual(got_rows, [want_row]) + + cursor.close() + conn.close() + + def test_autocommit_mode_change(self): + """Test auto committing a transaction on `autocommit` mode change.""" + want_row = ( + 2, + "updated-first-name", + "last-name", + "test.email@domen.ru", + ) + # connect to the test database + conn = Connection(Config.INSTANCE, self._db) + cursor = conn.cursor() + + cursor.execute( + """ +INSERT INTO contacts (contact_id, first_name, last_name, email) +VALUES (2, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + cursor.execute( + """ +UPDATE contacts +SET first_name = 'updated-first-name' +WHERE first_name = 'first-name' +""" + ) + conn.autocommit = True + + # read the resulting data from the database + cursor.execute("SELECT * FROM contacts") + got_rows = cursor.fetchall() + + self.assertEqual(got_rows, [want_row]) + + cursor.close() + conn.close() + + def test_rollback_on_connection_closing(self): + """ + When closing a connection all the pending transactions + must be rollbacked. Testing if it's working this way. 
+ """ + want_row = (1, "first-name", "last-name", "test.email@domen.ru") + # connect to the test database + conn = Connection(Config.INSTANCE, self._db) + cursor = conn.cursor() + + cursor.execute( + """ +INSERT INTO contacts (contact_id, first_name, last_name, email) +VALUES (1, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + conn.commit() + + cursor.execute( + """ +UPDATE contacts +SET first_name = 'updated-first-name' +WHERE first_name = 'first-name' +""" + ) + conn.close() + + # connect again, as the previous connection is no-op after closing + conn = Connection(Config.INSTANCE, self._db) + cursor = conn.cursor() + + # read the resulting data from the database + cursor.execute("SELECT * FROM contacts") + got_rows = cursor.fetchall() + conn.commit() + + self.assertEqual(got_rows, [want_row]) + + cursor.close() + conn.close() + + def test_results_checksum(self): + """Test that results checksum is calculated properly.""" + conn = Connection(Config.INSTANCE, self._db) + cursor = conn.cursor() + + cursor.execute( + """ +INSERT INTO contacts (contact_id, first_name, last_name, email) +VALUES + (1, 'first-name', 'last-name', 'test.email@domen.ru'), + (2, 'first-name2', 'last-name2', 'test.email2@domen.ru') + """ + ) + self.assertEqual(len(conn._statements), 1) + conn.commit() + + cursor.execute("SELECT * FROM contacts") + got_rows = cursor.fetchall() + + self.assertEqual(len(conn._statements), 1) + conn.commit() + + checksum = hashlib.sha256() + checksum.update(pickle.dumps(got_rows[0])) + checksum.update(pickle.dumps(got_rows[1])) + + self.assertEqual(cursor._checksum.checksum.digest(), checksum.digest()) + + +def clear_table(transaction): + """Clear the test table.""" + transaction.execute_update("DELETE FROM contacts WHERE true") diff --git a/tests/unit/spanner_dbapi/test_checksum.py b/tests/unit/spanner_dbapi/test_checksum.py new file mode 100644 index 0000000000..a90d0da370 --- /dev/null +++ b/tests/unit/spanner_dbapi/test_checksum.py @@ -0,0 +1,71 
@@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + + +class Test_compare_checksums(unittest.TestCase): + def test_equal(self): + from google.cloud.spanner_dbapi.checksum import _compare_checksums + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + + original = ResultsChecksum() + original.consume_result(5) + + retried = ResultsChecksum() + retried.consume_result(5) + + self.assertIsNone(_compare_checksums(original, retried)) + + def test_less_results(self): + from google.cloud.spanner_dbapi.checksum import _compare_checksums + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.exceptions import RetryAborted + + original = ResultsChecksum() + original.consume_result(5) + + retried = ResultsChecksum() + + with self.assertRaises(RetryAborted): + _compare_checksums(original, retried) + + def test_more_results(self): + from google.cloud.spanner_dbapi.checksum import _compare_checksums + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.exceptions import RetryAborted + + original = ResultsChecksum() + original.consume_result(5) + + retried = ResultsChecksum() + retried.consume_result(5) + retried.consume_result(2) + + with self.assertRaises(RetryAborted): + _compare_checksums(original, retried) + + def test_mismatch(self): + from google.cloud.spanner_dbapi.checksum import _compare_checksums + from 
google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.exceptions import RetryAborted + + original = ResultsChecksum() + original.consume_result(5) + + retried = ResultsChecksum() + retried.consume_result(2) + + with self.assertRaises(RetryAborted): + _compare_checksums(original, retried) diff --git a/tests/unit/spanner_dbapi/test_connect.py b/tests/unit/spanner_dbapi/test_connect.py new file mode 100644 index 0000000000..771b9d4a7f --- /dev/null +++ b/tests/unit/spanner_dbapi/test_connect.py @@ -0,0 +1,141 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""connect() module function unit tests.""" + +import unittest +from unittest import mock + +import google.auth.credentials + + +def _make_credentials(): + class _CredentialsWithScopes( + google.auth.credentials.Credentials, google.auth.credentials.Scoped + ): + pass + + return mock.Mock(spec=_CredentialsWithScopes) + + +class Test_connect(unittest.TestCase): + def test_connect(self): + from google.cloud.spanner_dbapi import connect + from google.cloud.spanner_dbapi import Connection + + PROJECT = "test-project" + USER_AGENT = "user-agent" + CREDENTIALS = _make_credentials() + + with mock.patch("google.cloud.spanner_v1.Client") as client_mock: + connection = connect( + "test-instance", + "test-database", + PROJECT, + CREDENTIALS, + user_agent=USER_AGENT, + ) + + self.assertIsInstance(connection, Connection) + + client_mock.assert_called_once_with( + project=PROJECT, credentials=CREDENTIALS, client_info=mock.ANY + ) + + def test_instance_not_found(self): + from google.cloud.spanner_dbapi import connect + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=False, + ) as exists_mock: + + with self.assertRaises(ValueError): + connect("test-instance", "test-database") + + exists_mock.assert_called_once_with() + + def test_database_not_found(self): + from google.cloud.spanner_dbapi import connect + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=False, + ) as exists_mock: + + with self.assertRaises(ValueError): + connect("test-instance", "test-database") + + exists_mock.assert_called_once_with() + + def test_connect_instance_id(self): + from google.cloud.spanner_dbapi import connect + from google.cloud.spanner_dbapi import Connection + + INSTANCE = "test-instance" + + with mock.patch( + "google.cloud.spanner_v1.client.Client.instance" + ) as instance_mock: + connection = connect(INSTANCE, 
"test-database") + + instance_mock.assert_called_once_with(INSTANCE) + + self.assertIsInstance(connection, Connection) + + def test_connect_database_id(self): + from google.cloud.spanner_dbapi import connect + from google.cloud.spanner_dbapi import Connection + + DATABASE = "test-database" + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.database" + ) as database_mock: + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, + ): + connection = connect("test-instance", DATABASE) + + database_mock.assert_called_once_with(DATABASE, pool=mock.ANY) + + self.assertIsInstance(connection, Connection) + + def test_default_sessions_pool(self): + from google.cloud.spanner_dbapi import connect + + with mock.patch("google.cloud.spanner_v1.instance.Instance.database"): + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, + ): + connection = connect("test-instance", "test-database") + + self.assertIsNotNone(connection.database._pool) + + def test_sessions_pool(self): + from google.cloud.spanner_dbapi import connect + from google.cloud.spanner_v1.pool import FixedSizePool + + database_id = "test-database" + pool = FixedSizePool() + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.database" + ) as database_mock: + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, + ): + connect("test-instance", database_id, pool=pool) + database_mock.assert_called_once_with(database_id, pool=pool) diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 8cd3bced16..213eb24d84 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -234,9 +234,7 @@ def test_run_prior_DDL_statements(self): connection.run_prior_DDL_statements() def test_context(self): - from google.cloud.spanner_dbapi import Connection - - connection = Connection(self.INSTANCE, 
self.DATABASE) + connection = self._make_connection() with connection as conn: self.assertEqual(conn, connection) @@ -306,3 +304,229 @@ def test_sessions_pool(self): ): connect("test-instance", database_id, pool=pool) database_mock.assert_called_once_with(database_id, pool=pool) + + def test_run_statement_remember_statements(self): + """Check that Connection remembers executed statements.""" + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.cursor import Statement + + sql = """SELECT 23 FROM table WHERE id = @a1""" + params = {"a1": "value"} + param_types = {"a1": str} + + connection = self._make_connection() + + statement = Statement(sql, params, param_types, ResultsChecksum(),) + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.transaction_checkout" + ): + connection.run_statement(statement) + + self.assertEqual(connection._statements[0].sql, sql) + self.assertEqual(connection._statements[0].params, params) + self.assertEqual(connection._statements[0].param_types, param_types) + self.assertIsInstance(connection._statements[0].checksum, ResultsChecksum) + + def test_run_statement_dont_remember_retried_statements(self): + """Check that Connection doesn't remember re-executed statements.""" + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.cursor import Statement + + sql = """SELECT 23 FROM table WHERE id = @a1""" + params = {"a1": "value"} + param_types = {"a1": str} + + connection = self._make_connection() + + statement = Statement(sql, params, param_types, ResultsChecksum(),) + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.transaction_checkout" + ): + connection.run_statement(statement, retried=True) + + self.assertEqual(len(connection._statements), 0) + + def test_clear_statements_on_commit(self): + """ + Check that all the saved statements are + cleared, when the transaction is commited. 
+ """ + connection = self._make_connection() + connection._transaction = mock.Mock() + connection._statements = [{}, {}] + + self.assertEqual(len(connection._statements), 2) + + with mock.patch("google.cloud.spanner_v1.transaction.Transaction.commit"): + connection.commit() + + self.assertEqual(len(connection._statements), 0) + + def test_clear_statements_on_rollback(self): + """ + Check that all the saved statements are + cleared, when the transaction is roll backed. + """ + connection = self._make_connection() + connection._transaction = mock.Mock() + connection._statements = [{}, {}] + + self.assertEqual(len(connection._statements), 2) + + with mock.patch("google.cloud.spanner_v1.transaction.Transaction.commit"): + connection.rollback() + + self.assertEqual(len(connection._statements), 0) + + def test_retry_transaction(self): + """Check retrying an aborted transaction.""" + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.cursor import Statement + + row = ["field1", "field2"] + connection = self._make_connection() + + checksum = ResultsChecksum() + checksum.consume_result(row) + retried_checkum = ResultsChecksum() + + statement = Statement("SELECT 1", [], {}, checksum,) + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([row], retried_checkum), + ) as run_mock: + with mock.patch( + "google.cloud.spanner_dbapi.connection._compare_checksums" + ) as compare_mock: + connection.retry_transaction() + + compare_mock.assert_called_with(checksum, retried_checkum) + + run_mock.assert_called_with(statement, retried=True) + + def test_retry_transaction_checksum_mismatch(self): + """ + Check retrying an aborted transaction + with results checksums mismatch. 
+ """ + from google.cloud.spanner_dbapi.exceptions import RetryAborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.cursor import Statement + + row = ["field1", "field2"] + retried_row = ["field3", "field4"] + connection = self._make_connection() + + checksum = ResultsChecksum() + checksum.consume_result(row) + retried_checkum = ResultsChecksum() + + statement = Statement("SELECT 1", [], {}, checksum,) + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([retried_row], retried_checkum), + ): + with self.assertRaises(RetryAborted): + connection.retry_transaction() + + def test_commit_retry_aborted_statements(self): + """Check that retried transaction executing the same statements.""" + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.connection import connect + from google.cloud.spanner_dbapi.cursor import Statement + + row = ["field1", "field2"] + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + cursor._checksum.consume_result(row) + + statement = Statement("SELECT 1", [], {}, cursor._checksum,) + connection._statements.append(statement) + connection._transaction = mock.Mock() + + with mock.patch.object( + connection._transaction, "commit", side_effect=(Aborted("Aborted"), None), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([row], ResultsChecksum()), + ) as run_mock: + + connection.commit() + + run_mock.assert_called_with(statement, retried=True) + + def 
test_retry_transaction_drop_transaction(self): + """ + Check that before retrying an aborted transaction + connection drops the original aborted transaction. + """ + connection = self._make_connection() + transaction_mock = mock.Mock() + connection._transaction = transaction_mock + + # as we didn't set any statements, the method + # will only drop the transaction object + connection.retry_transaction() + self.assertIsNone(connection._transaction) + + def test_retry_aborted_retry(self): + """ + Check that in case of a retried transaction failed, + the connection will retry it once again. + """ + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.connection import connect + from google.cloud.spanner_dbapi.cursor import Statement + + row = ["field1", "field2"] + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + cursor._checksum.consume_result(row) + + statement = Statement("SELECT 1", [], {}, cursor._checksum,) + connection._statements.append(statement) + + metadata_mock = mock.Mock() + metadata_mock.trailing_metadata.return_value = {} + + with mock.patch.object( + connection, + "run_statement", + side_effect=( + Aborted("Aborted", errors=[metadata_mock]), + ([row], ResultsChecksum()), + ), + ) as retry_mock: + + connection.retry_transaction() + + retry_mock.assert_has_calls( + ( + mock.call(statement, retried=True), + mock.call(statement, retried=True), + ) + ) diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 871214a360..43fc077abe 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -285,8 +285,11 @@ 
def test_executemany(self): sys.version_info[0] < 3, "Python 2 has an outdated iterator definition" ) def test_fetchone(self): + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) cursor = self._make_one(connection) + cursor._checksum = ResultsChecksum() lst = [1, 2, 3] cursor._itr = iter(lst) for i in range(len(lst)): @@ -294,8 +297,11 @@ def test_fetchone(self): self.assertIsNone(cursor.fetchone()) def test_fetchmany(self): + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) cursor = self._make_one(connection) + cursor._checksum = ResultsChecksum() lst = [(1,), (2,), (3,)] cursor._itr = iter(lst) @@ -305,8 +311,11 @@ def test_fetchmany(self): self.assertEqual(result, lst[1:]) def test_fetchall(self): + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) cursor = self._make_one(connection) + cursor._checksum = ResultsChecksum() lst = [(1,), (2,), (3,)] cursor._itr = iter(lst) self.assertEqual(cursor.fetchall(), lst) @@ -453,3 +462,108 @@ def test_get_table_column_schema(self): param_types={"table_name": param_types.STRING}, ) self.assertEqual(result, expected) + + def test_fetchone_retry_aborted(self): + """Check that an aborted fetch triggers a transaction retry.""" + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.connection import connect + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + + with mock.patch( +
"google.cloud.spanner_dbapi.cursor.Cursor.__next__", + side_effect=(Aborted("Aborted"), None), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.retry_transaction" + ) as retry_mock: + + cursor.fetchone() + + retry_mock.assert_called_with() + + def test_fetchone_retry_aborted_statements(self): + """Check that an aborted fetch re-runs the recorded statements.""" + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.connection import connect + from google.cloud.spanner_dbapi.cursor import Statement + + row = ["field1", "field2"] + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + cursor._checksum.consume_result(row) + + statement = Statement("SELECT 1", [], {}, cursor._checksum,) + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.__next__", + side_effect=(Aborted("Aborted"), None), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([row], ResultsChecksum()), + ) as run_mock: + + cursor.fetchone() + + run_mock.assert_called_with(statement, retried=True) + + def test_fetchone_retry_aborted_statements_checksums_mismatch(self): + """Check that a retry raises RetryAborted if the results changed.""" + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.exceptions import RetryAborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.connection import connect + from google.cloud.spanner_dbapi.cursor import Statement + + row = ["field1", "field2"] + row2 = ["updated_field1", "field2"] + + with
mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + cursor._checksum.consume_result(row) + + statement = Statement("SELECT 1", [], {}, cursor._checksum,) + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.__next__", + side_effect=(Aborted("Aborted"), None), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([row2], ResultsChecksum()), + ) as run_mock: + + with self.assertRaises(RetryAborted): + cursor.fetchone() + + run_mock.assert_called_with(statement, retried=True)