Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added retry and timeout params to partition read in database and snapshot class #278

Merged
merged 7 commits into from Mar 24, 2021
23 changes: 23 additions & 0 deletions google/cloud/spanner_v1/database.py
Expand Up @@ -26,6 +26,7 @@
from google.api_core.retry import if_exception_type
from google.cloud.exceptions import NotFound
from google.api_core.exceptions import Aborted
from google.api_core import gapic_v1
import six

# pylint: disable=ungrouped-imports
Expand Down Expand Up @@ -862,6 +863,9 @@ def generate_read_batches(
index="",
partition_size_bytes=None,
max_partitions=None,
*,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
):
"""Start a partitioned batch read operation.

Expand Down Expand Up @@ -893,6 +897,12 @@ def generate_read_batches(
service uses this as a hint, the actual number of partitions may
differ.

:type retry: :class:`~google.api_core.retry.Retry`
:param retry: (Optional) The retry settings for this request.

:type timeout: float
:param timeout: (Optional) The timeout for this request.

:rtype: iterable of dict
:returns:
mappings of information used to perform actual partitioned reads via
Expand All @@ -905,6 +915,8 @@ def generate_read_batches(
index=index,
partition_size_bytes=partition_size_bytes,
max_partitions=max_partitions,
retry=retry,
timeout=timeout,
)

read_info = {
Expand Down Expand Up @@ -940,6 +952,9 @@ def generate_query_batches(
partition_size_bytes=None,
max_partitions=None,
query_options=None,
*,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
):
"""Start a partitioned query operation.

Expand Down Expand Up @@ -983,6 +998,12 @@ def generate_query_batches(
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.QueryOptions`

:type retry: :class:`~google.api_core.retry.Retry`
:param retry: (Optional) The retry settings for this request.

:type timeout: float
:param timeout: (Optional) The timeout for this request.

:rtype: iterable of dict
:returns:
mappings of information used to perform actual partitioned reads via
Expand All @@ -994,6 +1015,8 @@ def generate_query_batches(
param_types=param_types,
partition_size_bytes=partition_size_bytes,
max_partitions=max_partitions,
retry=retry,
timeout=timeout,
)

query_info = {"sql": sql}
Expand Down
6 changes: 6 additions & 0 deletions google/cloud/spanner_v1/session.py
Expand Up @@ -258,6 +258,12 @@ def execute_sql(
or :class:`dict`
:param query_options: (Optional) Options that are provided for query plan stability.

:type retry: :class:`~google.api_core.retry.Retry`
:param retry: (Optional) The retry settings for this request.

:type timeout: float
:param timeout: (Optional) The timeout for this request.

:rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
:returns: a result set instance which can be used to consume rows.
"""
Expand Down
38 changes: 33 additions & 5 deletions google/cloud/spanner_v1/snapshot.py
Expand Up @@ -27,7 +27,7 @@

from google.api_core.exceptions import InternalServerError
from google.api_core.exceptions import ServiceUnavailable
import google.api_core.gapic_v1.method
from google.api_core import gapic_v1
from google.cloud.spanner_v1._helpers import _make_value_pb
from google.cloud.spanner_v1._helpers import _merge_query_options
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
Expand Down Expand Up @@ -186,8 +186,8 @@ def execute_sql(
query_mode=None,
query_options=None,
partition=None,
retry=google.api_core.gapic_v1.method.DEFAULT,
timeout=google.api_core.gapic_v1.method.DEFAULT,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
):
"""Perform an ``ExecuteStreamingSql`` API request.

Expand Down Expand Up @@ -224,6 +224,12 @@ def execute_sql(
:rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
:returns: a result set instance which can be used to consume rows.

:type retry: :class:`~google.api_core.retry.Retry`
:param retry: (Optional) The retry settings for this request.

:type timeout: float
:param timeout: (Optional) The timeout for this request.

:raises ValueError:
for reuse of single-use snapshots, or if a transaction ID is
already pending for multiple-use snapshots.
Expand Down Expand Up @@ -296,6 +302,9 @@ def partition_read(
index="",
partition_size_bytes=None,
max_partitions=None,
*,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
):
"""Perform a ``PartitionRead`` API request for rows in a table.

Expand Down Expand Up @@ -323,6 +332,12 @@ def partition_read(
service uses this as a hint, the actual number of partitions may
differ.

:type retry: :class:`~google.api_core.retry.Retry`
:param retry: (Optional) The retry settings for this request.

:type timeout: float
:param timeout: (Optional) The timeout for this request.

:rtype: iterable of bytes
:returns: a sequence of partition tokens

Expand Down Expand Up @@ -357,7 +372,9 @@ def partition_read(
with trace_call(
"CloudSpanner.PartitionReadOnlyTransaction", self._session, trace_attributes
):
response = api.partition_read(request=request, metadata=metadata,)
response = api.partition_read(
request=request, metadata=metadata, retry=retry, timeout=timeout,
)

return [partition.partition_token for partition in response.partitions]

Expand All @@ -368,6 +385,9 @@ def partition_query(
param_types=None,
partition_size_bytes=None,
max_partitions=None,
*,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
):
"""Perform a ``PartitionQuery`` API request.

Expand All @@ -394,6 +414,12 @@ def partition_query(
service uses this as a hint, the actual number of partitions may
differ.

:type retry: :class:`~google.api_core.retry.Retry`
:param retry: (Optional) The retry settings for this request.

:type timeout: float
:param timeout: (Optional) The timeout for this request.

:rtype: iterable of bytes
:returns: a sequence of partition tokens

Expand Down Expand Up @@ -438,7 +464,9 @@ def partition_query(
self._session,
trace_attributes,
):
response = api.partition_query(request=request, metadata=metadata,)
response = api.partition_query(
request=request, metadata=metadata, retry=retry, timeout=timeout,
)

return [partition.partition_token for partition in response.partitions]

Expand Down
96 changes: 96 additions & 0 deletions tests/unit/test_database.py
Expand Up @@ -16,6 +16,7 @@
import unittest

import mock
from google.api_core import gapic_v1

from google.cloud.spanner_v1.param_types import INT64

Expand Down Expand Up @@ -1768,6 +1769,49 @@ def test_generate_read_batches_w_max_partitions(self):
index="",
partition_size_bytes=None,
max_partitions=max_partitions,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
)

def test_generate_read_batches_w_retry_and_timeout_params(self):
max_partitions = len(self.TOKENS)
keyset = self._make_keyset()
database = self._make_database()
batch_txn = self._make_one(database)
snapshot = batch_txn._snapshot = self._make_snapshot()
snapshot.partition_read.return_value = self.TOKENS

batches = list(
batch_txn.generate_read_batches(
self.TABLE,
self.COLUMNS,
keyset,
max_partitions=max_partitions,
retry={},
timeout=2.0,
)
)

expected_read = {
"table": self.TABLE,
"columns": self.COLUMNS,
"keyset": {"all": True},
"index": "",
}
self.assertEqual(len(batches), len(self.TOKENS))
for batch, token in zip(batches, self.TOKENS):
self.assertEqual(batch["partition"], token)
self.assertEqual(batch["read"], expected_read)

snapshot.partition_read.assert_called_once_with(
table=self.TABLE,
columns=self.COLUMNS,
keyset=keyset,
index="",
partition_size_bytes=None,
max_partitions=max_partitions,
retry={},
timeout=2.0,
)

def test_generate_read_batches_w_index_w_partition_size_bytes(self):
Expand Down Expand Up @@ -1806,6 +1850,8 @@ def test_generate_read_batches_w_index_w_partition_size_bytes(self):
index=self.INDEX,
partition_size_bytes=size,
max_partitions=None,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
)

def test_process_read_batch(self):
Expand Down Expand Up @@ -1863,6 +1909,8 @@ def test_generate_query_batches_w_max_partitions(self):
param_types=None,
partition_size_bytes=None,
max_partitions=max_partitions,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
)

def test_generate_query_batches_w_params_w_partition_size_bytes(self):
Expand Down Expand Up @@ -1902,6 +1950,54 @@ def test_generate_query_batches_w_params_w_partition_size_bytes(self):
param_types=param_types,
partition_size_bytes=size,
max_partitions=None,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
)

def test_generate_query_batches_w_retry_and_timeout_params(self):
sql = (
"SELECT first_name, last_name, email FROM citizens " "WHERE age <= @max_age"
)
params = {"max_age": 30}
param_types = {"max_age": "INT64"}
size = 1 << 20
client = _Client(self.PROJECT_ID)
instance = _Instance(self.INSTANCE_NAME, client=client)
database = _Database(self.DATABASE_NAME, instance=instance)
batch_txn = self._make_one(database)
snapshot = batch_txn._snapshot = self._make_snapshot()
snapshot.partition_query.return_value = self.TOKENS

batches = list(
batch_txn.generate_query_batches(
sql,
params=params,
param_types=param_types,
partition_size_bytes=size,
retry={},
timeout=2.0,
)
)

expected_query = {
"sql": sql,
"params": params,
"param_types": param_types,
"query_options": client._query_options,
}
self.assertEqual(len(batches), len(self.TOKENS))
for batch, token in zip(batches, self.TOKENS):
self.assertEqual(batch["partition"], token)
self.assertEqual(batch["query"], expected_query)

snapshot.partition_query.assert_called_once_with(
sql=sql,
params=params,
param_types=param_types,
partition_size_bytes=size,
max_partitions=None,
retry={},
timeout=2.0,
)

def test_process_query_batch(self):
Expand Down