From 2f2cd8631817c9f3d898c60e38778ae533c3f803 Mon Sep 17 00:00:00 2001 From: Ilya Gurov Date: Thu, 22 Oct 2020 21:28:50 +0300 Subject: [PATCH] feat: support transactions management (#535) Add transaction management, including utils for handling spanner sessions and connections. Co-authored-by: MF2199 <38331387+mf2199@users.noreply.github.com> Co-authored-by: Chris Kleinknecht --- google/cloud/spanner_dbapi/__init__.py | 18 +- google/cloud/spanner_dbapi/connection.py | 122 +++++++++- google/cloud/spanner_dbapi/cursor.py | 12 + tests/spanner_dbapi/test_connect.py | 27 ++- tests/spanner_dbapi/test_connection.py | 19 +- tests/system/test_system.py | 293 ++++++++++++++++++++++- 6 files changed, 463 insertions(+), 28 deletions(-) diff --git a/google/cloud/spanner_dbapi/__init__.py b/google/cloud/spanner_dbapi/__init__.py index 014d82d3cc..0bb37492db 100644 --- a/google/cloud/spanner_dbapi/__init__.py +++ b/google/cloud/spanner_dbapi/__init__.py @@ -49,7 +49,12 @@ def connect( - instance_id, database_id, project=None, credentials=None, user_agent=None + instance_id, + database_id, + project=None, + credentials=None, + pool=None, + user_agent=None, ): """ Create a connection to Cloud Spanner database. @@ -71,6 +76,13 @@ def connect( If none are specified, the client will attempt to ascertain the credentials from the environment. + :type pool: Concrete subclass of + :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool`. + :param pool: (Optional). Session pool to be used by database. + + :type user_agent: :class:`str` + :param user_agent: (Optional) User agent to be used with this connection requests. + :rtype: :class:`google.cloud.spanner_dbapi.connection.Connection` :returns: Connection object associated with the given Cloud Spanner resource. @@ -87,9 +99,7 @@ def connect( if not instance.exists(): raise ValueError("instance '%s' does not exist." % instance_id) - database = instance.database( - database_id, pool=spanner_v1.pool.BurstyPool() - ) + database = instance.database(database_id, pool=pool) if not database.exists(): raise ValueError("database '%s' does not exist." % database_id) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 70ec5a0365..8907e65c03 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -14,11 +14,7 @@ from .cursor import Cursor from .exceptions import InterfaceError -AUTOCOMMIT_MODE_WARNING = ( - "This method is non-operational, as Cloud Spanner" - "DB API always works in `autocommit` mode." - "See https://github.com/googleapis/python-spanner-django#transaction-management-isnt-supported" -) +AUTOCOMMIT_MODE_WARNING = "This method is non-operational in autocommit mode" ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"]) @@ -37,11 +33,98 @@ class Connection: """ def __init__(self, instance, database): - self.instance = instance - self.database = database - self.is_closed = False + self._instance = instance + self._database = database self._ddl_statements = [] + self._transaction = None + self._session = None + + self.is_closed = False + self._autocommit = False + + @property + def autocommit(self): + """Autocommit mode flag for this connection. + + :rtype: bool + :returns: Autocommit mode flag value. + """ + return self._autocommit + + @autocommit.setter + def autocommit(self, value): + """Change this connection autocommit mode. + + :type value: bool + :param value: New autocommit mode state. + """ + if value and not self._autocommit: + self.commit() + + self._autocommit = value + + @property + def database(self): + """Database to which this connection relates. + + :rtype: :class:`~google.cloud.spanner_v1.database.Database` + :returns: The related database object. + """ + return self._database + + @property + def instance(self): + """Instance to which this connection relates. + + :rtype: :class:`~google.cloud.spanner_v1.instance.Instance` + :returns: The related instance object. + """ + return self._instance + + def _session_checkout(self): + """Get a Cloud Spanner session from the pool. + + If there is already a session associated with + this connection, it'll be used instead. + + :rtype: :class:`google.cloud.spanner_v1.session.Session` + :returns: Cloud Spanner session object ready to use. + """ + if not self._session: + self._session = self.database._pool.get() + + return self._session + + def _release_session(self): + """Release the currently used Spanner session. + + The session will be returned into the sessions pool. + """ + self.database._pool.put(self._session) + self._session = None + + def transaction_checkout(self): + """Get a Cloud Spanner transaction. + + Begin a new transaction, if there is no transaction in + this connection yet. Return the begun one otherwise. + + The method is non operational in autocommit mode. + + :rtype: :class:`google.cloud.spanner_v1.transaction.Transaction` + :returns: A Cloud Spanner transaction object, ready to use. + """ + if not self.autocommit: + if ( + not self._transaction + or self._transaction.committed + or self._transaction.rolled_back + ): + self._transaction = self._session_checkout().transaction() + self._transaction.begin() + + return self._transaction def cursor(self): self._raise_if_closed() @@ -142,18 +225,33 @@ def get_table_column_schema(self, table_name): def close(self): """Close this connection. - The connection will be unusable from this point forward. + The connection will be unusable from this point forward. If the + connection has an active transaction, it will be rolled back. """ - self.__dbhandle = None + if ( + self._transaction + and not self._transaction.committed + and not self._transaction.rolled_back + ): + self._transaction.rollback() + self.is_closed = True def commit(self): """Commit all the pending transactions.""" - warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) + if self.autocommit: + warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) + elif self._transaction: + self._transaction.commit() + self._release_session() def rollback(self): """Rollback all the pending transactions.""" - warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) + if self.autocommit: + warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) + elif self._transaction: + self._transaction.rollback() + self._release_session() def __enter__(self): return self diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 73764b4c26..95eae50e1a 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -91,6 +91,7 @@ def execute(self, sql, args=None): # Classify whether this is a read-only SQL statement. try: classification = classify_stmt(sql) + if classification == STMT_DDL: self._connection.append_ddl_statement(sql) return @@ -99,6 +100,17 @@ def execute(self, sql, args=None): # any prior DDL statements were run. self._run_prior_DDL_statements() + if not self._connection.autocommit: + transaction = self._connection.transaction_checkout() + + sql, params = sql_pyformat_args_to_spanner(sql, args) + + self._res = transaction.execute_sql( + sql, params, param_types=get_param_types(params) + ) + self._itr = PeekIterator(self._res) + return + if classification == STMT_NON_UPDATING: self.__handle_DQL(sql, args or None) elif classification == STMT_INSERT: diff --git a/tests/spanner_dbapi/test_connect.py b/tests/spanner_dbapi/test_connect.py index 260d3a0993..fb4d89c373 100644 --- a/tests/spanner_dbapi/test_connect.py +++ b/tests/spanner_dbapi/test_connect.py @@ -12,6 +12,7 @@ import google.auth.credentials from google.api_core.gapic_v1.client_info import ClientInfo from google.cloud.spanner_dbapi import connect, Connection +from google.cloud.spanner_v1.pool import FixedSizePool def _make_credentials(): @@ -43,7 +44,7 @@ def test_connect(self): "test-database", PROJECT, CREDENTIALS, - USER_AGENT, + user_agent=USER_AGENT, ) self.assertIsInstance(connection, Connection) @@ -108,3 +109,27 @@ def test_connect_database_id(self): database_mock.assert_called_once_with(DATABASE, pool=mock.ANY) self.assertIsInstance(connection, Connection) + + def test_default_sessions_pool(self): + with mock.patch("google.cloud.spanner_v1.instance.Instance.database"): + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + self.assertIsNotNone(connection.database._pool) + + def test_sessions_pool(self): + database_id = "test-database" + pool = FixedSizePool() + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.database" + ) as database_mock: + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + connect("test-instance", database_id, pool=pool) + database_mock.assert_called_once_with(database_id, pool=pool) diff --git a/tests/spanner_dbapi/test_connection.py b/tests/spanner_dbapi/test_connection.py index 1b285a933d..24260de12e 100644 --- a/tests/spanner_dbapi/test_connection.py +++ b/tests/spanner_dbapi/test_connection.py @@ -49,8 +49,9 @@ def test_close(self): connection.cursor() @mock.patch("warnings.warn") - def test_transaction_management_warnings(self, warn_mock): + def test_transaction_autocommit_warnings(self, warn_mock): connection = self._make_connection() + connection.autocommit = True connection.commit() warn_mock.assert_called_with( @@ -60,3 +61,19 @@ def test_transaction_management_warnings(self, warn_mock): warn_mock.assert_called_with( AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2 ) + + def test_database_property(self): + connection = self._make_connection() + self.assertIsInstance(connection.database, Database) + self.assertEqual(connection.database, connection._database) + + with self.assertRaises(AttributeError): + connection.database = None + + def test_instance_property(self): + connection = self._make_connection() + self.assertIsInstance(connection.instance, Instance) + self.assertEqual(connection.instance, connection._instance) + + with self.assertRaises(AttributeError): + connection.instance = None diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 5710ba6ce6..f3ee345e15 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -5,18 +5,291 @@ # https://developers.google.com/open-source/licenses/bsd import unittest +import os +from google.api_core import exceptions -class TestSpannerDjangoDBAPI(unittest.TestCase): - def setUp(self): - # TODO: Implement this method - pass +from google.cloud.spanner import Client +from google.cloud.spanner import BurstyPool +from google.cloud.spanner_dbapi.connection import Connection + +from test_utils.retry import RetryErrors +from test_utils.system import unique_resource_id + + +CREATE_INSTANCE = ( + os.getenv("GOOGLE_CLOUD_TESTS_CREATE_SPANNER_INSTANCE") is not None +) +USE_EMULATOR = os.getenv("SPANNER_EMULATOR_HOST") is not None + +if CREATE_INSTANCE: + INSTANCE_ID = "google-cloud" + unique_resource_id("-") +else: + INSTANCE_ID = os.environ.get( + "GOOGLE_CLOUD_TESTS_SPANNER_INSTANCE", "google-cloud-python-systest" + ) +EXISTING_INSTANCES = [] + +DDL_STATEMENTS = ( + """CREATE TABLE contacts ( + contact_id INT64, + first_name STRING(1024), + last_name STRING(1024), + email STRING(1024) + ) + PRIMARY KEY (contact_id)""", +) + + +class Config(object): + """Run-time configuration to be modified at set-up. + + This is a mutable stand-in to allow test set-up to modify + global state. + """ + + CLIENT = None + INSTANCE_CONFIG = None + INSTANCE = None + + +def _list_instances(): + return list(Config.CLIENT.list_instances()) + + +def setUpModule(): + if USE_EMULATOR: + from google.auth.credentials import AnonymousCredentials + + emulator_project = os.getenv("GCLOUD_PROJECT", "emulator-test-project") + Config.CLIENT = Client( + project=emulator_project, credentials=AnonymousCredentials() + ) + else: + Config.CLIENT = Client() + + retry = RetryErrors(exceptions.ServiceUnavailable) + + configs = list(retry(Config.CLIENT.list_instance_configs)()) + + instances = retry(_list_instances)() + EXISTING_INSTANCES[:] = instances + + if CREATE_INSTANCE: + if not USE_EMULATOR: + # Defend against back-end returning configs for regions we aren't + # actually allowed to use. + configs = [config for config in configs if "-us-" in config.name] + + if not configs: + raise ValueError("List instance configs failed in module set up.") + + Config.INSTANCE_CONFIG = configs[0] + config_name = configs[0].name + + Config.INSTANCE = Config.CLIENT.instance(INSTANCE_ID, config_name) + created_op = Config.INSTANCE.create() + created_op.result(30) # block until completion + else: + Config.INSTANCE = Config.CLIENT.instance(INSTANCE_ID) + Config.INSTANCE.reload() + + +def tearDownModule(): + """Delete the test instance, if it was created.""" + if CREATE_INSTANCE: + Config.INSTANCE.delete() + + +class TestTransactionsManagement(unittest.TestCase): + """Transactions management support tests.""" + + DATABASE_NAME = "db-api-transactions-management" + + @classmethod + def setUpClass(cls): + """Create a test database.""" + cls._db = Config.INSTANCE.database( + cls.DATABASE_NAME, + ddl_statements=DDL_STATEMENTS, + pool=BurstyPool(labels={"testcase": "database_api"}), + ) + cls._db.create().result(30) # raises on failure / timeout. + + @classmethod + def tearDownClass(cls): + """Delete the test database.""" + cls._db.drop() def tearDown(self): - # TODO: Implement this method - pass + """Clear the test table after every test.""" + self._db.run_in_transaction(clear_table) + + def test_commit(self): + """Test committing a transaction with several statements.""" + want_row = ( + 1, + "updated-first-name", + "last-name", + "test.email_updated@domen.ru", + ) + # connect to the test database + conn = Connection(Config.INSTANCE, self._db) + cursor = conn.cursor() + + # execute several DML statements within one transaction + cursor.execute( + """ +INSERT INTO contacts (contact_id, first_name, last_name, email) +VALUES (1, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + cursor.execute( + """ +UPDATE contacts +SET first_name = 'updated-first-name' +WHERE first_name = 'first-name' +""" + ) + cursor.execute( + """ +UPDATE contacts +SET email = 'test.email_updated@domen.ru' +WHERE email = 'test.email@domen.ru' +""" + ) + conn.commit() + + # read the resulting data from the database + cursor.execute("SELECT * FROM contacts") + got_rows = cursor.fetchall() + conn.commit() + + self.assertEqual(got_rows, [want_row]) + + cursor.close() + conn.close() + + def test_rollback(self): + """Test rollbacking a transaction with several statements.""" + want_row = (2, "first-name", "last-name", "test.email@domen.ru") + # connect to the test database + conn = Connection(Config.INSTANCE, self._db) + cursor = conn.cursor() + + cursor.execute( + """ +INSERT INTO contacts (contact_id, first_name, last_name, email) +VALUES (2, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + conn.commit() + + # execute several DMLs with one transaction + cursor.execute( + """ +UPDATE contacts +SET first_name = 'updated-first-name' +WHERE first_name = 'first-name' +""" + ) + cursor.execute( + """ +UPDATE contacts +SET email = 'test.email_updated@domen.ru' +WHERE email = 'test.email@domen.ru' +""" + ) + conn.rollback() + + # read the resulting data from the database + cursor.execute("SELECT * FROM contacts") + got_rows = cursor.fetchall() + conn.commit() + + self.assertEqual(got_rows, [want_row]) + + cursor.close() + conn.close() + + def test_autocommit_mode_change(self): + """Test auto committing a transaction on `autocommit` mode change.""" + want_row = ( + 2, + "updated-first-name", + "last-name", + "test.email@domen.ru", + ) + # connect to the test database + conn = Connection(Config.INSTANCE, self._db) + cursor = conn.cursor() + + cursor.execute( + """ +INSERT INTO contacts (contact_id, first_name, last_name, email) +VALUES (2, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + cursor.execute( + """ +UPDATE contacts +SET first_name = 'updated-first-name' +WHERE first_name = 'first-name' +""" + ) + conn.autocommit = True + + # read the resulting data from the database + cursor.execute("SELECT * FROM contacts") + got_rows = cursor.fetchall() + + self.assertEqual(got_rows, [want_row]) + + cursor.close() + conn.close() + + def test_rollback_on_connection_closing(self): + """ + When closing a connection all the pending transactions + must be rollbacked. Testing if it's working this way. + """ + want_row = (1, "first-name", "last-name", "test.email@domen.ru") + # connect to the test database + conn = Connection(Config.INSTANCE, self._db) + cursor = conn.cursor() + + cursor.execute( + """ +INSERT INTO contacts (contact_id, first_name, last_name, email) +VALUES (1, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + conn.commit() + + cursor.execute( + """ +UPDATE contacts +SET first_name = 'updated-first-name' +WHERE first_name = 'first-name' +""" + ) + conn.close() + + # connect again, as the previous connection is no-op after closing + conn = Connection(Config.INSTANCE, self._db) + cursor = conn.cursor() + + # read the resulting data from the database + cursor.execute("SELECT * FROM contacts") + got_rows = cursor.fetchall() + conn.commit() + + self.assertEqual(got_rows, [want_row]) + + cursor.close() + conn.close() + - def test_api(self): - # An dummy stub to avoid `exit code 5` errors - # TODO: Replace this with an actual system test method - self.assertTrue(True) +def clear_table(transaction): + """Clear the test table.""" + transaction.execute_update("DELETE FROM contacts WHERE true")