diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 622f3d7b07..1e76bf218f 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -26,6 +26,7 @@ from google.api_core.retry import if_exception_type from google.cloud.exceptions import NotFound from google.api_core.exceptions import Aborted +from google.api_core import gapic_v1 import six # pylint: disable=ungrouped-imports @@ -915,6 +916,9 @@ def generate_read_batches( index="", partition_size_bytes=None, max_partitions=None, + *, + retry=gapic_v1.method.DEFAULT, + timeout=gapic_v1.method.DEFAULT, ): """Start a partitioned batch read operation. @@ -946,6 +950,12 @@ def generate_read_batches( service uses this as a hint, the actual number of partitions may differ. + :type retry: :class:`~google.api_core.retry.Retry` + :param retry: (Optional) The retry settings for this request. + + :type timeout: float + :param timeout: (Optional) The timeout for this request. + :rtype: iterable of dict :returns: mappings of information used perform actual partitioned reads via @@ -958,6 +968,8 @@ def generate_read_batches( index=index, partition_size_bytes=partition_size_bytes, max_partitions=max_partitions, + retry=retry, + timeout=timeout, ) read_info = { @@ -969,7 +981,9 @@ def generate_read_batches( for partition in partitions: yield {"partition": partition, "read": read_info.copy()} - def process_read_batch(self, batch): + def process_read_batch( + self, batch, *, retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT, + ): """Process a single, partitioned read. :type batch: mapping @@ -977,13 +991,22 @@ def process_read_batch(self, batch): one of the mappings returned from an earlier call to :meth:`generate_read_batches`. + :type retry: :class:`~google.api_core.retry.Retry` + :param retry: (Optional) The retry settings for this request. + + :type timeout: float + :param timeout: (Optional) The timeout for this request. + + :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. """ kwargs = copy.deepcopy(batch["read"]) keyset_dict = kwargs.pop("keyset") kwargs["keyset"] = KeySet._from_dict(keyset_dict) - return self._get_snapshot().read(partition=batch["partition"], **kwargs) + return self._get_snapshot().read( + partition=batch["partition"], **kwargs, retry=retry, timeout=timeout + ) def generate_query_batches( self, @@ -993,6 +1016,9 @@ def generate_query_batches( partition_size_bytes=None, max_partitions=None, query_options=None, + *, + retry=gapic_v1.method.DEFAULT, + timeout=gapic_v1.method.DEFAULT, ): """Start a partitioned query operation. @@ -1036,6 +1062,12 @@ def generate_query_batches( If a dict is provided, it must be of the same form as the protobuf message :class:`~google.cloud.spanner_v1.types.QueryOptions` + :type retry: :class:`~google.api_core.retry.Retry` + :param retry: (Optional) The retry settings for this request. + + :type timeout: float + :param timeout: (Optional) The timeout for this request. + :rtype: iterable of dict :returns: mappings of information used perform actual partitioned reads via @@ -1047,6 +1079,8 @@ def generate_query_batches( param_types=param_types, partition_size_bytes=partition_size_bytes, max_partitions=max_partitions, + retry=retry, + timeout=timeout, ) query_info = {"sql": sql} @@ -1064,7 +1098,9 @@ def generate_query_batches( for partition in partitions: yield {"partition": partition, "query": query_info} - def process_query_batch(self, batch): + def process_query_batch( + self, batch, *, retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT, + ): """Process a single, partitioned query. :type batch: mapping @@ -1072,11 +1108,17 @@ def process_query_batch(self, batch): one of the mappings returned from an earlier call to :meth:`generate_query_batches`. + :type retry: :class:`~google.api_core.retry.Retry` + :param retry: (Optional) The retry settings for this request. + + :type timeout: float + :param timeout: (Optional) The timeout for this request. + :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. """ return self._get_snapshot().execute_sql( - partition=batch["partition"], **batch["query"] + partition=batch["partition"], **batch["query"], retry=retry, timeout=timeout ) def process(self, batch): diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 853c5c5c18..1321308ace 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -258,6 +258,12 @@ def execute_sql( or :class:`dict` :param query_options: (Optional) Options that are provided for query plan stability. + :type retry: :class:`~google.api_core.retry.Retry` + :param retry: (Optional) The retry settings for this request. + + :type timeout: float + :param timeout: (Optional) The timeout for this request. + :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. """ diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 37638df6fa..1b3ae8097d 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -27,7 +27,7 @@ from google.api_core.exceptions import InternalServerError from google.api_core.exceptions import ServiceUnavailable -import google.api_core.gapic_v1.method +from google.api_core import gapic_v1 from google.cloud.spanner_v1._helpers import _make_value_pb from google.cloud.spanner_v1._helpers import _merge_query_options from google.cloud.spanner_v1._helpers import _metadata_with_prefix @@ -109,7 +109,18 @@ def _make_txn_selector(self): # pylint: disable=redundant-returns-doc """ raise NotImplementedError - def read(self, table, columns, keyset, index="", limit=0, partition=None): + def read( + self, + table, + columns, + keyset, + index="", + limit=0, + partition=None, + *, + retry=gapic_v1.method.DEFAULT, + timeout=gapic_v1.method.DEFAULT, + ): """Perform a ``StreamingRead`` API request for rows in a table. :type table: str @@ -134,6 +145,12 @@ def read(self, table, columns, keyset, index="", limit=0, partition=None): from :meth:`partition_read`. Incompatible with ``limit``. + :type retry: :class:`~google.api_core.retry.Retry` + :param retry: (Optional) The retry settings for this request. + + :type timeout: float + :param timeout: (Optional) The timeout for this request. + :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. @@ -163,7 +180,11 @@ def read(self, table, columns, keyset, index="", limit=0, partition=None): partition_token=partition, ) restart = functools.partial( - api.streaming_read, request=request, metadata=metadata, + api.streaming_read, + request=request, + metadata=metadata, + retry=retry, + timeout=timeout, ) trace_attributes = {"table_id": table, "columns": columns} @@ -186,8 +207,8 @@ def execute_sql( query_mode=None, query_options=None, partition=None, - retry=google.api_core.gapic_v1.method.DEFAULT, - timeout=google.api_core.gapic_v1.method.DEFAULT, + retry=gapic_v1.method.DEFAULT, + timeout=gapic_v1.method.DEFAULT, ): """Perform an ``ExecuteStreamingSql`` API request. @@ -224,6 +245,12 @@ def execute_sql( :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. + :type retry: :class:`~google.api_core.retry.Retry` + :param retry: (Optional) The retry settings for this request. + + :type timeout: float + :param timeout: (Optional) The timeout for this request. + :raises ValueError: for reuse of single-use snapshots, or if a transaction ID is already pending for multiple-use snapshots. @@ -296,6 +323,9 @@ def partition_read( index="", partition_size_bytes=None, max_partitions=None, + *, + retry=gapic_v1.method.DEFAULT, + timeout=gapic_v1.method.DEFAULT, ): """Perform a ``PartitionRead`` API request for rows in a table. @@ -323,6 +353,12 @@ def partition_read( service uses this as a hint, the actual number of partitions may differ. + :type retry: :class:`~google.api_core.retry.Retry` + :param retry: (Optional) The retry settings for this request. + + :type timeout: float + :param timeout: (Optional) The timeout for this request. + :rtype: iterable of bytes :returns: a sequence of partition tokens @@ -357,7 +393,9 @@ def partition_read( with trace_call( "CloudSpanner.PartitionReadOnlyTransaction", self._session, trace_attributes ): - response = api.partition_read(request=request, metadata=metadata,) + response = api.partition_read( + request=request, metadata=metadata, retry=retry, timeout=timeout, + ) return [partition.partition_token for partition in response.partitions] @@ -368,6 +406,9 @@ def partition_query( param_types=None, partition_size_bytes=None, max_partitions=None, + *, + retry=gapic_v1.method.DEFAULT, + timeout=gapic_v1.method.DEFAULT, ): """Perform a ``PartitionQuery`` API request. @@ -394,6 +435,12 @@ def partition_query( service uses this as a hint, the actual number of partitions may differ. + :type retry: :class:`~google.api_core.retry.Retry` + :param retry: (Optional) The retry settings for this request. + + :type timeout: float + :param timeout: (Optional) The timeout for this request. + :rtype: iterable of bytes :returns: a sequence of partition tokens @@ -438,7 +485,9 @@ def partition_query( self._session, trace_attributes, ): - response = api.partition_query(request=request, metadata=metadata,) + response = api.partition_query( + request=request, metadata=metadata, retry=retry, timeout=timeout, + ) return [partition.partition_token for partition in response.partitions] diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index c6ff5d3e74..c71bab2581 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -16,8 +16,10 @@ import unittest import mock +from google.api_core import gapic_v1 from google.cloud.spanner_v1.param_types import INT64 +from google.api_core.retry import Retry DML_WO_PARAM = """ DELETE FROM citizens @@ -1949,6 +1951,49 @@ def test_generate_read_batches_w_max_partitions(self): index="", partition_size_bytes=None, max_partitions=max_partitions, + retry=gapic_v1.method.DEFAULT, + timeout=gapic_v1.method.DEFAULT, + ) + + def test_generate_read_batches_w_retry_and_timeout_params(self): + max_partitions = len(self.TOKENS) + keyset = self._make_keyset() + database = self._make_database() + batch_txn = self._make_one(database) + snapshot = batch_txn._snapshot = self._make_snapshot() + snapshot.partition_read.return_value = self.TOKENS + retry = Retry(deadline=60) + batches = list( + batch_txn.generate_read_batches( + self.TABLE, + self.COLUMNS, + keyset, + max_partitions=max_partitions, + retry=retry, + timeout=2.0, + ) + ) + + expected_read = { + "table": self.TABLE, + "columns": self.COLUMNS, + "keyset": {"all": True}, + "index": "", + } + self.assertEqual(len(batches), len(self.TOKENS)) + for batch, token in zip(batches, self.TOKENS): + self.assertEqual(batch["partition"], token) + self.assertEqual(batch["read"], expected_read) + + snapshot.partition_read.assert_called_once_with( + table=self.TABLE, + columns=self.COLUMNS, + keyset=keyset, + index="", + partition_size_bytes=None, + max_partitions=max_partitions, + retry=retry, + timeout=2.0, ) def test_generate_read_batches_w_index_w_partition_size_bytes(self): @@ -1987,6 +2032,8 @@ def test_generate_read_batches_w_index_w_partition_size_bytes(self): index=self.INDEX, partition_size_bytes=size, max_partitions=None, + retry=gapic_v1.method.DEFAULT, + timeout=gapic_v1.method.DEFAULT, ) def test_process_read_batch(self): @@ -2016,6 +2063,39 @@ def test_process_read_batch(self): keyset=keyset, index=self.INDEX, partition=token, + retry=gapic_v1.method.DEFAULT, + timeout=gapic_v1.method.DEFAULT, + ) + + def test_process_read_batch_w_retry_timeout(self): + keyset = self._make_keyset() + token = b"TOKEN" + batch = { + "partition": token, + "read": { + "table": self.TABLE, + "columns": self.COLUMNS, + "keyset": {"all": True}, + "index": self.INDEX, + }, + } + database = self._make_database() + batch_txn = self._make_one(database) + snapshot = batch_txn._snapshot = self._make_snapshot() + expected = snapshot.read.return_value = object() + retry = Retry(deadline=60) + found = batch_txn.process_read_batch(batch, retry=retry, timeout=2.0) + + self.assertIs(found, expected) + + snapshot.read.assert_called_once_with( + table=self.TABLE, + columns=self.COLUMNS, + keyset=keyset, + index=self.INDEX, + partition=token, + retry=retry, + timeout=2.0, ) def test_generate_query_batches_w_max_partitions(self): @@ -2044,6 +2124,8 @@ def test_generate_query_batches_w_max_partitions(self): param_types=None, partition_size_bytes=None, max_partitions=max_partitions, + retry=gapic_v1.method.DEFAULT, + timeout=gapic_v1.method.DEFAULT, ) def test_generate_query_batches_w_params_w_partition_size_bytes(self): @@ -2083,6 +2165,54 @@ def test_generate_query_batches_w_params_w_partition_size_bytes(self): param_types=param_types, partition_size_bytes=size, max_partitions=None, + retry=gapic_v1.method.DEFAULT, + timeout=gapic_v1.method.DEFAULT, + ) + + def test_generate_query_batches_w_retry_and_timeout_params(self): + sql = ( + "SELECT first_name, last_name, email FROM citizens " "WHERE age <= @max_age" + ) + params = {"max_age": 30} + param_types = {"max_age": "INT64"} + size = 1 << 20 + client = _Client(self.PROJECT_ID) + instance = _Instance(self.INSTANCE_NAME, client=client) + database = _Database(self.DATABASE_NAME, instance=instance) + batch_txn = self._make_one(database) + snapshot = batch_txn._snapshot = self._make_snapshot() + snapshot.partition_query.return_value = self.TOKENS + retry = Retry(deadline=60) + batches = list( + batch_txn.generate_query_batches( + sql, + params=params, + param_types=param_types, + partition_size_bytes=size, + retry=retry, + timeout=2.0, + ) + ) + + expected_query = { + "sql": sql, + "params": params, + "param_types": param_types, + "query_options": client._query_options, + } + self.assertEqual(len(batches), len(self.TOKENS)) + for batch, token in zip(batches, self.TOKENS): + self.assertEqual(batch["partition"], token) + self.assertEqual(batch["query"], expected_query) + + snapshot.partition_query.assert_called_once_with( + sql=sql, + params=params, + param_types=param_types, + partition_size_bytes=size, + max_partitions=None, + retry=retry, + timeout=2.0, ) def test_process_query_batch(self): @@ -2106,7 +2236,41 @@ def test_process_query_batch(self): self.assertIs(found, expected) snapshot.execute_sql.assert_called_once_with( - sql=sql, params=params, param_types=param_types, partition=token + sql=sql, + params=params, + param_types=param_types, + partition=token, + retry=gapic_v1.method.DEFAULT, + timeout=gapic_v1.method.DEFAULT, + ) + + def test_process_query_batch_w_retry_timeout(self): + sql = ( + "SELECT first_name, last_name, email FROM citizens " "WHERE age <= @max_age" + ) + params = {"max_age": 30} + param_types = {"max_age": "INT64"} + token = b"TOKEN" + batch = { + "partition": token, + "query": {"sql": sql, "params": params, "param_types": param_types}, + } + database = self._make_database() + batch_txn = self._make_one(database) + snapshot = batch_txn._snapshot = self._make_snapshot() + expected = snapshot.execute_sql.return_value = object() + retry = Retry(deadline=60) + found = batch_txn.process_query_batch(batch, retry=retry, timeout=2.0) + + self.assertIs(found, expected) + + snapshot.execute_sql.assert_called_once_with( + sql=sql, + params=params, + param_types=param_types, + partition=token, + retry=retry, + timeout=2.0, ) def test_close_wo_session(self): @@ -2160,6 +2324,8 @@ def test_process_w_read_batch(self): keyset=keyset, index=self.INDEX, partition=token, + retry=gapic_v1.method.DEFAULT, + timeout=gapic_v1.method.DEFAULT, ) def test_process_w_query_batch(self): @@ -2183,7 +2349,12 @@ def test_process_w_query_batch(self): self.assertIs(found, expected) snapshot.execute_sql.assert_called_once_with( - sql=sql, params=params, param_types=param_types, partition=token + sql=sql, + params=params, + param_types=param_types, + partition=token, + retry=gapic_v1.method.DEFAULT, + timeout=gapic_v1.method.DEFAULT, ) diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 2305937204..cc9a67cb4d 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -13,7 +13,7 @@ # limitations under the License. -import google.api_core.gapic_v1.method +from google.api_core import gapic_v1 import mock from tests._helpers import ( OpenTelemetryBase, @@ -21,6 +21,7 @@ HAS_OPENTELEMETRY_INSTALLED, ) from google.cloud.spanner_v1.param_types import INT64 +from google.api_core.retry import Retry TABLE_NAME = "citizens" COLUMNS = ["email", "first_name", "last_name", "age"] @@ -375,7 +376,15 @@ def test_read_other_error(self): ), ) - def _read_helper(self, multi_use, first=True, count=0, partition=None): + def _read_helper( + self, + multi_use, + first=True, + count=0, + partition=None, + timeout=gapic_v1.method.DEFAULT, + retry=gapic_v1.method.DEFAULT, + ): from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1 import ( PartialResultSet, @@ -426,11 +435,23 @@ def _read_helper(self, multi_use, first=True, count=0, partition=None): if partition is not None: # 'limit' and 'partition' incompatible result_set = derived.read( - TABLE_NAME, COLUMNS, keyset, index=INDEX, partition=partition + TABLE_NAME, + COLUMNS, + keyset, + index=INDEX, + partition=partition, + retry=retry, + timeout=timeout, ) else: result_set = derived.read( - TABLE_NAME, COLUMNS, keyset, index=INDEX, limit=LIMIT + TABLE_NAME, + COLUMNS, + keyset, + index=INDEX, + limit=LIMIT, + retry=retry, + timeout=timeout, ) self.assertEqual(derived._read_request_count, count + 1) @@ -474,6 +495,8 @@ def _read_helper(self, multi_use, first=True, count=0, partition=None): api.streaming_read.assert_called_once_with( request=expected_request, metadata=[("google-cloud-resource-prefix", database.name)], + retry=retry, + timeout=timeout, ) self.assertSpanAttributes( @@ -504,6 +527,17 @@ def test_read_w_multi_use_w_first_w_count_gt_0(self): with self.assertRaises(ValueError): self._read_helper(multi_use=True, first=True, count=1) + def test_read_w_timeout_param(self): + self._read_helper(multi_use=True, first=False, timeout=2.0) + + def test_read_w_retry_param(self): + self._read_helper(multi_use=True, first=False, retry=Retry(deadline=60)) + + def test_read_w_timeout_and_retry_params(self): + self._read_helper( + multi_use=True, first=False, retry=Retry(deadline=60), timeout=2.0 + ) + def test_execute_sql_other_error(self): database = _Database() database.spanner_api = self._make_spanner_api() @@ -540,8 +574,8 @@ def _execute_sql_helper( partition=None, sql_count=0, query_options=None, - timeout=google.api_core.gapic_v1.method.DEFAULT, - retry=google.api_core.gapic_v1.method.DEFAULT, + timeout=gapic_v1.method.DEFAULT, + retry=gapic_v1.method.DEFAULT, ): from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1 import ( @@ -698,7 +732,14 @@ def test_execute_sql_w_query_options(self): ) def _partition_read_helper( - self, multi_use, w_txn, size=None, max_partitions=None, index=None + self, + multi_use, + w_txn, + size=None, + max_partitions=None, + index=None, + retry=gapic_v1.method.DEFAULT, + timeout=gapic_v1.method.DEFAULT, ): from google.cloud.spanner_v1.keyset import KeySet from google.cloud.spanner_v1 import Partition @@ -736,6 +777,8 @@ def _partition_read_helper( index=index, partition_size_bytes=size, max_partitions=max_partitions, + retry=retry, + timeout=timeout, ) ) @@ -759,6 +802,8 @@ def _partition_read_helper( api.partition_read.assert_called_once_with( request=expected_request, metadata=[("google-cloud-resource-prefix", database.name)], + retry=retry, + timeout=timeout, ) self.assertSpanAttributes( @@ -809,7 +854,28 @@ def test_partition_read_ok_w_size(self): def test_partition_read_ok_w_max_partitions(self): self._partition_read_helper(multi_use=True, w_txn=True, max_partitions=4) - def _partition_query_helper(self, multi_use, w_txn, size=None, max_partitions=None): + def test_partition_read_ok_w_timeout_param(self): + self._partition_read_helper(multi_use=True, w_txn=True, timeout=2.0) + + def test_partition_read_ok_w_retry_param(self): + self._partition_read_helper( + multi_use=True, w_txn=True, retry=Retry(deadline=60) + ) + + def test_partition_read_ok_w_timeout_and_retry_params(self): + self._partition_read_helper( + multi_use=True, w_txn=True, retry=Retry(deadline=60), timeout=2.0 + ) + + def _partition_query_helper( + self, + multi_use, + w_txn, + size=None, + max_partitions=None, + retry=gapic_v1.method.DEFAULT, + timeout=gapic_v1.method.DEFAULT, + ): from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1 import Partition from google.cloud.spanner_v1 import PartitionOptions @@ -845,6 +911,8 @@ def _partition_query_helper(self, multi_use, w_txn, size=None, max_partitions=No PARAM_TYPES, partition_size_bytes=size, max_partitions=max_partitions, + retry=retry, + timeout=timeout, ) ) @@ -871,6 +939,8 @@ def _partition_query_helper(self, multi_use, w_txn, size=None, max_partitions=No api.partition_query.assert_called_once_with( request=expected_request, metadata=[("google-cloud-resource-prefix", database.name)], + retry=retry, + timeout=timeout, ) self.assertSpanAttributes( @@ -926,6 +996,19 @@ def test_partition_query_ok_w_size(self): def test_partition_query_ok_w_max_partitions(self): self._partition_query_helper(multi_use=True, w_txn=True, max_partitions=4) + def test_partition_query_ok_w_timeout_param(self): + self._partition_query_helper(multi_use=True, w_txn=True, timeout=2.0) + + def test_partition_query_ok_w_retry_param(self): + self._partition_query_helper( + multi_use=True, w_txn=True, retry=Retry(deadline=30) + ) + + def test_partition_query_ok_w_timeout_and_retry_params(self): + self._partition_query_helper( + multi_use=True, w_txn=True, retry=Retry(deadline=60), timeout=2.0 + ) + class TestSnapshot(OpenTelemetryBase): diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index 3302f68d2d..923a6ec47d 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -17,6 +17,7 @@ from tests._helpers import OpenTelemetryBase, StatusCanonicalCode from google.cloud.spanner_v1 import Type from google.cloud.spanner_v1 import TypeCode +from google.api_core.retry import Retry from google.api_core import gapic_v1 TABLE_NAME = "citizens" @@ -492,10 +493,10 @@ def test_execute_update_w_timeout_param(self): self._execute_update_helper(timeout=2.0) def test_execute_update_w_retry_param(self): - self._execute_update_helper(retry=gapic_v1.method.DEFAULT) + self._execute_update_helper(retry=Retry(deadline=60)) def test_execute_update_w_timeout_and_retry_params(self): - self._execute_update_helper(retry=gapic_v1.method.DEFAULT, timeout=2.0) + self._execute_update_helper(retry=Retry(deadline=60), timeout=2.0) def test_execute_update_error(self): database = _Database()