diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py
index 8d46b84cef..ba9fea3858 100644
--- a/google/cloud/spanner_dbapi/connection.py
+++ b/google/cloud/spanner_dbapi/connection.py
@@ -21,6 +21,7 @@
 from google.api_core.gapic_v1.client_info import ClientInfo
 from google.cloud import spanner_v1 as spanner
 from google.cloud.spanner_v1.session import _get_retry_delay
+from google.cloud.spanner_v1.snapshot import Snapshot

 from google.cloud.spanner_dbapi._helpers import _execute_insert_heterogenous
 from google.cloud.spanner_dbapi._helpers import _execute_insert_homogenous
@@ -50,15 +51,31 @@ class Connection:
     :type database: :class:`~google.cloud.spanner_v1.database.Database`
     :param database: The database to which the connection is linked.
+
+    :type read_only: bool
+    :param read_only:
+        Flag to indicate that the connection may only execute queries and may not execute update or DDL statements.
+        If True, the connection will use a single-use read-only transaction with a strong timestamp
+        bound for each new statement, and will immediately see any changes that have been committed by
+        any other transaction.
+        If autocommit is False, the connection will automatically start a new multi-use read-only transaction
+        with a strong timestamp bound when the first statement is executed. This read-only transaction will be
+        used for all subsequent statements until either commit() or rollback() is called on the connection. The
+        read-only transaction will read from a consistent snapshot of the database at the time that the
+        transaction started. This means that the transaction will not see any changes that have been
+        committed by other transactions since the start of the read-only transaction. Committing or rolling back
+        the read-only transaction is semantically the same, and only indicates that the read-only transaction
+        should end and that a new one should be started when the next statement is executed.
     """

-    def __init__(self, instance, database):
+    def __init__(self, instance, database, read_only=False):
         self._instance = instance
         self._database = database
         self._ddl_statements = []

         self._transaction = None
         self._session = None
+        self._snapshot = None

         # SQL statements, which were executed
         # within the current transaction
         self._statements = []
@@ -69,6 +86,7 @@ def __init__(self, instance, database):
         # this connection should be cleared on the
         # connection close
         self._own_pool = True
+        self._read_only = read_only

     @property
     def autocommit(self):
@@ -123,6 +141,30 @@ def instance(self):
         """
         return self._instance

+    @property
+    def read_only(self):
+        """Flag: the connection can be used only for database reads.
+
+        Returns:
+            bool:
+                True if the connection may only be used for database reads.
+        """
+        return self._read_only
+
+    @read_only.setter
+    def read_only(self, value):
+        """`read_only` flag setter.
+
+        Args:
+            value (bool): True for ReadOnly mode, False for ReadWrite.
+        """
+        if self.inside_transaction:
+            raise ValueError(
+                "Connection read/write mode can't be changed while a transaction is in progress. "
+                "Commit or rollback the current transaction and try again."
+            )
+        self._read_only = value
+
     def _session_checkout(self):
         """Get a Cloud Spanner session from the pool.

@@ -231,6 +273,22 @@ def transaction_checkout(self):
         return self._transaction

+    def snapshot_checkout(self):
+        """Get a Cloud Spanner snapshot.
+
+        Initiate a new multi-use snapshot, if there is no snapshot in
+        this connection yet. Return the existing one otherwise.
+
+        :rtype: :class:`google.cloud.spanner_v1.snapshot.Snapshot`
+        :returns: A Cloud Spanner snapshot object, ready to use.
+        """
+        if self.read_only and not self.autocommit:
+            if not self._snapshot:
+                self._snapshot = Snapshot(self._session_checkout(), multi_use=True)
+                self._snapshot.begin()
+
+            return self._snapshot
+
     def _raise_if_closed(self):
         """Helper to check the connection state before running a query.

         Raises an exception if this connection is closed.
@@ -259,6 +317,8 @@ def commit(self):

         This method is non-operational in autocommit mode.
         """
+        self._snapshot = None
+
         if self._autocommit:
             warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2)
             return
@@ -266,7 +326,9 @@ def commit(self):
         self.run_prior_DDL_statements()
         if self.inside_transaction:
             try:
-                self._transaction.commit()
+                if not self.read_only:
+                    self._transaction.commit()
+
                 self._release_session()
                 self._statements = []
             except Aborted:
@@ -279,10 +341,14 @@ def rollback(self):

         This is a no-op if there is no active transaction or
         if the connection is in autocommit mode.
         """
+        self._snapshot = None
+
         if self._autocommit:
             warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2)
         elif self._transaction:
-            self._transaction.rollback()
+            if not self.read_only:
+                self._transaction.rollback()
+
             self._release_session()
             self._statements = []
diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py
index cf15b99a55..64df68b362 100644
--- a/google/cloud/spanner_dbapi/cursor.py
+++ b/google/cloud/spanner_dbapi/cursor.py
@@ -186,6 +186,10 @@ def execute(self, sql, args=None):

         # Classify whether this is a read-only SQL statement.
         try:
+            if self.connection.read_only:
+                self._handle_DQL(sql, args or None)
+                return
+
             classification = parse_utils.classify_stmt(sql)
             if classification == parse_utils.STMT_DDL:
                 ddl_statements = []
@@ -325,14 +329,15 @@ def fetchone(self):

         try:
             res = next(self)
-            if not self.connection.autocommit:
+            if not self.connection.autocommit and not self.connection.read_only:
                 self._checksum.consume_result(res)
             return res
         except StopIteration:
             return
         except Aborted:
-            self.connection.retry_transaction()
-            return self.fetchone()
+            if not self.connection.read_only:
+                self.connection.retry_transaction()
+                return self.fetchone()

     def fetchall(self):
         """Fetch all (remaining) rows of a query result, returning them as
@@ -343,12 +348,13 @@ def fetchall(self):

         res = []
         try:
             for row in self:
-                if not self.connection.autocommit:
+                if not self.connection.autocommit and not self.connection.read_only:
                     self._checksum.consume_result(row)
                 res.append(row)
         except Aborted:
-            self.connection.retry_transaction()
-            return self.fetchall()
+            if not self.connection.read_only:
+                self.connection.retry_transaction()
+                return self.fetchall()

         return res
@@ -372,14 +378,15 @@ def fetchmany(self, size=None):

         for i in range(size):
             try:
                 res = next(self)
-                if not self.connection.autocommit:
+                if not self.connection.autocommit and not self.connection.read_only:
                     self._checksum.consume_result(res)
                 items.append(res)
             except StopIteration:
                 break
             except Aborted:
-                self.connection.retry_transaction()
-                return self.fetchmany(size)
+                if not self.connection.read_only:
+                    self.connection.retry_transaction()
+                    return self.fetchmany(size)

         return items
@@ -395,38 +402,39 @@ def setoutputsize(self, size, column=None):
         """A no-op, raising an error if the cursor or connection is closed."""
         self._raise_if_closed()

+    def _handle_DQL_with_snapshot(self, snapshot, sql, params):
+        # Reference
+        # https://googleapis.dev/python/spanner/latest/session-api.html#google.cloud.spanner_v1.session.Session.execute_sql
+        sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params)
+        res = snapshot.execute_sql(
+            sql, params=params, param_types=get_param_types(params)
+        )
+        # Immediately using:
+        #   iter(response)
+        # here, because this Spanner API doesn't provide
+        # easy mechanisms to detect when only a single item
+        # is returned or many, yet mixing results that
+        # are for .fetchone() with those that would result in
+        # many items returns a RuntimeError if .fetchone() is
+        # invoked and vice versa.
+        self._result_set = res
+        # Read the first element so that the StreamedResultSet can
+        # return the metadata after a DQL statement. See issue #155.
+        self._itr = PeekIterator(self._result_set)
+        # Unfortunately, Spanner doesn't seem to send back
+        # information about the number of rows available.
+        self._row_count = _UNSET_COUNT
+
     def _handle_DQL(self, sql, params):
-        with self.connection.database.snapshot() as snapshot:
-            # Reference
-            # https://googleapis.dev/python/spanner/latest/session-api.html#google.cloud.spanner_v1.session.Session.execute_sql
-            sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params)
-            res = snapshot.execute_sql(
-                sql, params=params, param_types=get_param_types(params)
+        if self.connection.read_only and not self.connection.autocommit:
+            # initiate or use the existing multi-use snapshot
+            self._handle_DQL_with_snapshot(
+                self.connection.snapshot_checkout(), sql, params
             )
-            if type(res) == int:
-                self._row_count = res
-                self._itr = None
-            else:
-                # Immediately using:
-                #   iter(response)
-                # here, because this Spanner API doesn't provide
-                # easy mechanisms to detect when only a single item
-                # is returned or many, yet mixing results that
-                # are for .fetchone() with those that would result in
-                # many items returns a RuntimeError if .fetchone() is
-                # invoked and vice versa.
-                self._result_set = res
-                # Read the first element so that the StreamedResultSet can
-                # return the metadata after a DQL statement. See issue #155.
-                while True:
-                    try:
-                        self._itr = PeekIterator(self._result_set)
-                        break
-                    except Aborted:
-                        self.connection.retry_transaction()
-                # Unfortunately, Spanner doesn't seem to send back
-                # information about the number of rows available.
-                self._row_count = _UNSET_COUNT
+        else:
+            # execute with single-use snapshot
+            with self.connection.database.snapshot() as snapshot:
+                self._handle_DQL_with_snapshot(snapshot, sql, params)

     def __enter__(self):
         return self
diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py
index 2d1b4097dc..4c3989a7a4 100644
--- a/tests/system/test_dbapi.py
+++ b/tests/system/test_dbapi.py
@@ -19,9 +19,11 @@

 from google.cloud import spanner_v1
 from google.cloud.spanner_dbapi.connection import connect, Connection
+from google.cloud.spanner_dbapi.exceptions import ProgrammingError
 from google.cloud.spanner_v1 import JsonObject
 from . import _helpers

+
 DATABASE_NAME = "dbapi-txn"

 DDL_STATEMENTS = (
@@ -406,3 +408,24 @@ def test_user_agent(shared_instance, dbapi_database):
         conn.instance._client._client_info.user_agent
         == "dbapi/" + pkg_resources.get_distribution("google-cloud-spanner").version
     )
+
+
+def test_read_only(shared_instance, dbapi_database):
+    """
+    Check that a connection set to `read_only=True` uses
+    ReadOnly transactions.
+ """ + conn = Connection(shared_instance, dbapi_database, read_only=True) + cur = conn.cursor() + + with pytest.raises(ProgrammingError): + cur.execute( + """ +UPDATE contacts +SET first_name = 'updated-first-name' +WHERE first_name = 'first-name' +""" + ) + + cur.execute("SELECT * FROM contacts") + conn.commit() diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index abdd3357dd..34e50255f9 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -39,14 +39,14 @@ def _get_client_info(self): return ClientInfo(user_agent=USER_AGENT) - def _make_connection(self): + def _make_connection(self, **kwargs): from google.cloud.spanner_dbapi import Connection from google.cloud.spanner_v1.instance import Instance # We don't need a real Client object to test the constructor instance = Instance(INSTANCE, client=None) database = instance.database(DATABASE) - return Connection(instance, database) + return Connection(instance, database, **kwargs) @mock.patch("google.cloud.spanner_dbapi.connection.Connection.commit") def test_autocommit_setter_transaction_not_started(self, mock_commit): @@ -105,6 +105,42 @@ def test_property_instance(self): self.assertIsInstance(connection.instance, Instance) self.assertEqual(connection.instance, connection._instance) + def test_read_only_connection(self): + connection = self._make_connection(read_only=True) + self.assertTrue(connection.read_only) + + connection._transaction = mock.Mock(committed=False, rolled_back=False) + with self.assertRaisesRegex( + ValueError, + "Connection read/write mode can't be changed while a transaction is in progress. " + "Commit or rollback the current transaction and try again.", + ): + connection.read_only = False + + connection._transaction = None + connection.read_only = False + self.assertFalse(connection.read_only) + + def test_read_only_not_retried(self): + """ + Testing the unlikely case of a read-only transaction + failed with Aborted exception. In this case the + transaction should not be automatically retried. 
+ """ + from google.api_core.exceptions import Aborted + + connection = self._make_connection(read_only=True) + connection.retry_transaction = mock.Mock() + + cursor = connection.cursor() + cursor._itr = mock.Mock(__next__=mock.Mock(side_effect=Aborted("Aborted"),)) + + cursor.fetchone() + cursor.fetchall() + cursor.fetchmany(5) + + connection.retry_transaction.assert_not_called() + @staticmethod def _make_pool(): from google.cloud.spanner_v1.pool import AbstractSessionPool @@ -160,6 +196,32 @@ def test_transaction_checkout(self): connection._autocommit = True self.assertIsNone(connection.transaction_checkout()) + def test_snapshot_checkout(self): + from google.cloud.spanner_dbapi import Connection + + connection = Connection(INSTANCE, DATABASE, read_only=True) + connection.autocommit = False + + session_checkout = mock.MagicMock(autospec=True) + connection._session_checkout = session_checkout + + snapshot = connection.snapshot_checkout() + session_checkout.assert_called_once() + + self.assertEqual(snapshot, connection.snapshot_checkout()) + + connection.commit() + self.assertIsNone(connection._snapshot) + + connection.snapshot_checkout() + self.assertIsNotNone(connection._snapshot) + + connection.rollback() + self.assertIsNone(connection._snapshot) + + connection.autocommit = True + self.assertIsNone(connection.snapshot_checkout()) + @mock.patch("google.cloud.spanner_v1.Client") def test_close(self, mock_client): from google.cloud.spanner_dbapi import connect diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 038f419351..1a79c64e1b 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -714,14 +714,9 @@ def test_handle_dql(self): ) = mock.MagicMock() cursor = self._make_one(connection) - mock_snapshot.execute_sql.return_value = int(0) + mock_snapshot.execute_sql.return_value = ["0"] cursor._handle_DQL("sql", params=None) - self.assertEqual(cursor._row_count, 0) - self.assertIsNone(cursor._itr) - - mock_snapshot.execute_sql.return_value = "0" - cursor._handle_DQL("sql", params=None) - self.assertEqual(cursor._result_set, "0") + self.assertEqual(cursor._result_set, ["0"]) self.assertIsInstance(cursor._itr, utils.PeekIterator) self.assertEqual(cursor._row_count, _UNSET_COUNT) @@ -838,37 +833,6 @@ def test_peek_iterator_aborted(self, mock_client): retry_mock.assert_called_with() - @mock.patch("google.cloud.spanner_v1.Client") - def test_peek_iterator_aborted_autocommit(self, mock_client): - """ - Checking that an Aborted exception is retried in case it happened while - streaming the first element with a PeekIterator in autocommit mode. 
-        """
-        from google.api_core.exceptions import Aborted
-        from google.cloud.spanner_dbapi.connection import connect
-
-        connection = connect("test-instance", "test-database")
-
-        connection.autocommit = True
-        cursor = connection.cursor()
-        with mock.patch(
-            "google.cloud.spanner_dbapi.utils.PeekIterator.__init__",
-            side_effect=(Aborted("Aborted"), None),
-        ):
-            with mock.patch(
-                "google.cloud.spanner_dbapi.connection.Connection.retry_transaction"
-            ) as retry_mock:
-                with mock.patch(
-                    "google.cloud.spanner_dbapi.connection.Connection.run_statement",
-                    return_value=((1, 2, 3), None),
-                ):
-                    with mock.patch(
-                        "google.cloud.spanner_v1.database.Database.snapshot"
-                    ):
-                        cursor.execute("SELECT * FROM table_name")
-
-        retry_mock.assert_called_with()
-
     @mock.patch("google.cloud.spanner_v1.Client")
     def test_fetchone_retry_aborted(self, mock_client):
         """Check that aborted fetch re-executing transaction."""
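Usage sketch (not part of the patch itself): how the `read_only` flag introduced above would be exercised through the DB-API surface. The instance/database names and the `contacts` table are placeholders taken from the tests, and the snippet assumes a reachable Cloud Spanner database; it illustrates the intended behaviour rather than code from this change.

```python
# Hypothetical usage of the read_only connection flag added in this change.
# "test-instance", "test-database" and the "contacts" table are placeholders.
from google.cloud import spanner_v1
from google.cloud.spanner_dbapi.connection import Connection

client = spanner_v1.Client()
instance = client.instance("test-instance")
database = instance.database("test-database")

# A read-only connection may only execute queries; UPDATE/DDL statements
# raise ProgrammingError (see tests/system/test_dbapi.py::test_read_only).
conn = Connection(instance, database, read_only=True)
cursor = conn.cursor()

# With autocommit left False (the default), the first statement starts a
# multi-use read-only snapshot via Connection.snapshot_checkout().
cursor.execute("SELECT * FROM contacts")
rows = cursor.fetchall()

# For read-only connections commit()/rollback() never touch a read-write
# transaction; they only drop the snapshot so the next statement reads a
# fresh, strongly consistent view of the database.
conn.commit()
conn.close()
```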