diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index ba9fea3858..e6d1d64db1 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -87,6 +87,7 @@ def __init__(self, instance, database, read_only=False): # connection close self._own_pool = True self._read_only = read_only + self._staleness = None @property def autocommit(self): @@ -165,6 +166,42 @@ def read_only(self, value): ) self._read_only = value + @property + def staleness(self): + """Current read staleness option value of this `Connection`. + + Returns: + dict: Staleness type and value. + """ + return self._staleness or {} + + @staleness.setter + def staleness(self, value): + """Read staleness option setter. + + Args: + value (dict): Staleness type and value. + """ + if self.inside_transaction: + raise ValueError( + "`staleness` option can't be changed while a transaction is in progress. " + "Commit or rollback the current transaction and try again." + ) + + possible_opts = ( + "read_timestamp", + "min_read_timestamp", + "max_staleness", + "exact_staleness", + ) + if value is not None and sum([opt in value for opt in possible_opts]) != 1: + raise ValueError( + "Expected one of the following staleness options: " + "read_timestamp, min_read_timestamp, max_staleness, exact_staleness." + ) + + self._staleness = value + def _session_checkout(self): """Get a Cloud Spanner session from the pool. @@ -284,7 +321,9 @@ def snapshot_checkout(self): """ if self.read_only and not self.autocommit: if not self._snapshot: - self._snapshot = Snapshot(self._session_checkout(), multi_use=True) + self._snapshot = Snapshot( + self._session_checkout(), multi_use=True, **self.staleness + ) self._snapshot.begin() return self._snapshot diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 27303a09a6..e9e4862281 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -426,7 +426,9 @@ def _handle_DQL(self, sql, params): ) else: # execute with single-use snapshot - with self.connection.database.snapshot() as snapshot: + with self.connection.database.snapshot( + **self.connection.staleness + ) as snapshot: self._handle_DQL_with_snapshot(snapshot, sql, params) def __enter__(self): diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py index 4c3989a7a4..d0ad26e79f 100644 --- a/tests/system/test_dbapi.py +++ b/tests/system/test_dbapi.py @@ -12,13 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import datetime import hashlib import pickle import pkg_resources import pytest from google.cloud import spanner_v1 -from google.cloud.spanner_dbapi.connection import connect, Connection +from google.cloud._helpers import UTC +from google.cloud.spanner_dbapi.connection import connect +from google.cloud.spanner_dbapi.connection import Connection from google.cloud.spanner_dbapi.exceptions import ProgrammingError from google.cloud.spanner_v1 import JsonObject from . import _helpers @@ -429,3 +432,32 @@ def test_read_only(shared_instance, dbapi_database): cur.execute("SELECT * FROM contacts") conn.commit() + + +def test_staleness(shared_instance, dbapi_database): + """Check the DB API `staleness` option.""" + conn = Connection(shared_instance, dbapi_database) + cursor = conn.cursor() + + before_insert = datetime.datetime.utcnow().replace(tzinfo=UTC) + + cursor.execute( + """ +INSERT INTO contacts (contact_id, first_name, last_name, email) +VALUES (1, 'first-name', 'last-name', 'test.email@example.com') + """ + ) + conn.commit() + + conn.read_only = True + conn.staleness = {"read_timestamp": before_insert} + cursor.execute("SELECT * FROM contacts") + conn.commit() + assert len(cursor.fetchall()) == 0 + + conn.staleness = None + cursor.execute("SELECT * FROM contacts") + conn.commit() + assert len(cursor.fetchall()) == 1 + + conn.close() diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 34e50255f9..0eea3eaf5b 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -14,6 +14,7 @@ """Cloud Spanner DB-API Connection class unit tests.""" +import datetime import mock import unittest import warnings @@ -688,9 +689,6 @@ def test_retry_transaction_w_empty_response(self): run_mock.assert_called_with(statement, retried=True) def test_validate_ok(self): - def exit_func(self, exc_type, exc_value, traceback): - pass - connection = self._make_connection() # mock snapshot context manager @@ -699,7 +697,7 @@ def exit_func(self, exc_type, exc_value, traceback): snapshot_ctx = mock.Mock() snapshot_ctx.__enter__ = mock.Mock(return_value=snapshot_obj) - snapshot_ctx.__exit__ = exit_func + snapshot_ctx.__exit__ = exit_ctx_func snapshot_method = mock.Mock(return_value=snapshot_ctx) connection.database.snapshot = snapshot_method @@ -710,9 +708,6 @@ def exit_func(self, exc_type, exc_value, traceback): def test_validate_fail(self): from google.cloud.spanner_dbapi.exceptions import OperationalError - def exit_func(self, exc_type, exc_value, traceback): - pass - connection = self._make_connection() # mock snapshot context manager @@ -721,7 +716,7 @@ def exit_func(self, exc_type, exc_value, traceback): snapshot_ctx = mock.Mock() snapshot_ctx.__enter__ = mock.Mock(return_value=snapshot_obj) - snapshot_ctx.__exit__ = exit_func + snapshot_ctx.__exit__ = exit_ctx_func snapshot_method = mock.Mock(return_value=snapshot_ctx) connection.database.snapshot = snapshot_method @@ -734,9 +729,6 @@ def exit_func(self, exc_type, exc_value, traceback): def test_validate_error(self): from google.cloud.exceptions import NotFound - def exit_func(self, exc_type, exc_value, traceback): - pass - connection = self._make_connection() # mock snapshot context manager @@ -745,7 +737,7 @@ def exit_func(self, exc_type, exc_value, traceback): snapshot_ctx = mock.Mock() snapshot_ctx.__enter__ = mock.Mock(return_value=snapshot_obj) - snapshot_ctx.__exit__ = exit_func + snapshot_ctx.__exit__ = exit_ctx_func snapshot_method = mock.Mock(return_value=snapshot_ctx) connection.database.snapshot = snapshot_method @@ -763,3 +755,117 @@ def test_validate_closed(self): with self.assertRaises(InterfaceError): connection.validate() + + def test_staleness_invalid_value(self): + """Check that `staleness` property accepts only correct values.""" + connection = self._make_connection() + + # incorrect staleness type + with self.assertRaises(ValueError): + connection.staleness = {"something": 4} + + # no expected staleness types + with self.assertRaises(ValueError): + connection.staleness = {} + + def test_staleness_inside_transaction(self): + """ + Check that it's impossible to change the `staleness` + option if a transaction is in progress. + """ + connection = self._make_connection() + connection._transaction = mock.Mock(committed=False, rolled_back=False) + + with self.assertRaises(ValueError): + connection.staleness = {"read_timestamp": datetime.datetime(2021, 9, 21)} + + def test_staleness_multi_use(self): + """ + Check that `staleness` option is correctly + sent to the `Snapshot()` constructor. + + READ_ONLY, NOT AUTOCOMMIT + """ + timestamp = datetime.datetime(2021, 9, 20) + + connection = self._make_connection() + connection._session = "session" + connection.read_only = True + connection.staleness = {"read_timestamp": timestamp} + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Snapshot" + ) as snapshot_mock: + connection.snapshot_checkout() + + snapshot_mock.assert_called_with( + "session", multi_use=True, read_timestamp=timestamp + ) + + def test_staleness_single_use_autocommit(self): + """ + Check that `staleness` option is correctly + sent to the snapshot context manager. + + NOT READ_ONLY, AUTOCOMMIT + """ + timestamp = datetime.datetime(2021, 9, 20) + + connection = self._make_connection() + connection._session_checkout = mock.MagicMock(autospec=True) + + connection.autocommit = True + connection.staleness = {"read_timestamp": timestamp} + + # mock snapshot context manager + snapshot_obj = mock.Mock() + snapshot_obj.execute_sql = mock.Mock(return_value=[1]) + + snapshot_ctx = mock.Mock() + snapshot_ctx.__enter__ = mock.Mock(return_value=snapshot_obj) + snapshot_ctx.__exit__ = exit_ctx_func + snapshot_method = mock.Mock(return_value=snapshot_ctx) + + connection.database.snapshot = snapshot_method + + cursor = connection.cursor() + cursor.execute("SELECT 1") + + connection.database.snapshot.assert_called_with(read_timestamp=timestamp) + + def test_staleness_single_use_readonly_autocommit(self): + """ + Check that `staleness` option is correctly sent to the + snapshot context manager while in `autocommit` mode. + + READ_ONLY, AUTOCOMMIT + """ + timestamp = datetime.datetime(2021, 9, 20) + + connection = self._make_connection() + connection.autocommit = True + connection.read_only = True + connection._session_checkout = mock.MagicMock(autospec=True) + + connection.staleness = {"read_timestamp": timestamp} + + # mock snapshot context manager + snapshot_obj = mock.Mock() + snapshot_obj.execute_sql = mock.Mock(return_value=[1]) + + snapshot_ctx = mock.Mock() + snapshot_ctx.__enter__ = mock.Mock(return_value=snapshot_obj) + snapshot_ctx.__exit__ = exit_ctx_func + snapshot_method = mock.Mock(return_value=snapshot_ctx) + + connection.database.snapshot = snapshot_method + + cursor = connection.cursor() + cursor.execute("SELECT 1") + + connection.database.snapshot.assert_called_with(read_timestamp=timestamp) + + +def exit_ctx_func(self, exc_type, exc_value, traceback): + """Context __exit__ method mock.""" + pass