Skip to content

Commit

Permalink
feat: add RPC priority support (#324)
Browse files Browse the repository at this point in the history
* feat: add RPC priority support

* Review changes

* Review changes

* Update google/cloud/spanner_v1/database.py

Co-authored-by: larkee <31196561+larkee@users.noreply.github.com>

* Update google/cloud/spanner_v1/database.py

Co-authored-by: larkee <31196561+larkee@users.noreply.github.com>

* Update session.py

* update import

Co-authored-by: larkee <31196561+larkee@users.noreply.github.com>
  • Loading branch information
zoercai and larkee committed Jun 22, 2021
1 parent c1ee8c2 commit 51533b8
Show file tree
Hide file tree
Showing 12 changed files with 209 additions and 20 deletions.
2 changes: 2 additions & 0 deletions google/cloud/spanner_v1/__init__.py
Expand Up @@ -28,6 +28,7 @@
from .types.query_plan import PlanNode
from .types.query_plan import QueryPlan
from .types.result_set import PartialResultSet
from .types import RequestOptions
from .types.result_set import ResultSet
from .types.result_set import ResultSetMetadata
from .types.result_set import ResultSetStats
Expand Down Expand Up @@ -119,6 +120,7 @@
"PlanNode",
"QueryPlan",
"ReadRequest",
"RequestOptions",
"ResultSet",
"ResultSetMetadata",
"ResultSetStats",
Expand Down
15 changes: 14 additions & 1 deletion google/cloud/spanner_v1/batch.py
Expand Up @@ -23,6 +23,7 @@
from google.cloud.spanner_v1._helpers import _make_list_value_pbs
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.cloud.spanner_v1 import RequestOptions

# pylint: enable=ungrouped-imports

Expand Down Expand Up @@ -138,13 +139,20 @@ def _check_state(self):
if self.committed is not None:
raise ValueError("Batch already committed")

def commit(self, return_commit_stats=False):
def commit(self, return_commit_stats=False, request_options=None):
"""Commit mutations to the database.
:type return_commit_stats: bool
:param return_commit_stats:
If true, the response will return commit stats which can be accessed though commit_stats.
:type request_options:
:class:`google.cloud.spanner_v1.types.RequestOptions`
:param request_options:
(Optional) Common options for this request.
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.
:rtype: datetime
:returns: timestamp of the committed changes.
"""
Expand All @@ -154,11 +162,16 @@ def commit(self, return_commit_stats=False):
metadata = _metadata_with_prefix(database.name)
txn_options = TransactionOptions(read_write=TransactionOptions.ReadWrite())
trace_attributes = {"num_mutations": len(self._mutations)}

if type(request_options) == dict:
request_options = RequestOptions(request_options)

request = CommitRequest(
session=self._session.name,
mutations=self._mutations,
single_use_transaction=txn_options,
return_commit_stats=return_commit_stats,
request_options=request_options,
)
with trace_call("CloudSpanner.Commit", self._session, trace_attributes):
response = api.commit(request=request, metadata=metadata,)
Expand Down
46 changes: 40 additions & 6 deletions google/cloud/spanner_v1/database.py
Expand Up @@ -58,10 +58,10 @@
TransactionOptions,
)
from google.cloud.spanner_v1.table import Table
from google.cloud.spanner_v1 import RequestOptions

# pylint: enable=ungrouped-imports


SPANNER_DATA_SCOPE = "https://www.googleapis.com/auth/spanner.data"


Expand Down Expand Up @@ -454,7 +454,12 @@ def drop(self):
api.drop_database(database=self.name, metadata=metadata)

def execute_partitioned_dml(
self, dml, params=None, param_types=None, query_options=None
self,
dml,
params=None,
param_types=None,
query_options=None,
request_options=None,
):
"""Execute a partitionable DML statement.
Expand All @@ -478,12 +483,22 @@ def execute_partitioned_dml(
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.QueryOptions`
:type request_options:
:class:`google.cloud.spanner_v1.types.RequestOptions`
:param request_options:
(Optional) Common options for this request.
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.
:rtype: int
:returns: Count of rows affected by the DML statement.
"""
query_options = _merge_query_options(
self._instance._client._query_options, query_options
)
if type(request_options) == dict:
request_options = RequestOptions(request_options)

if params is not None:
from google.cloud.spanner_v1.transaction import Transaction

Expand Down Expand Up @@ -517,6 +532,7 @@ def execute_pdml():
params=params_pb,
param_types=param_types,
query_options=query_options,
request_options=request_options,
)
method = functools.partial(
api.execute_streaming_sql, metadata=metadata,
Expand Down Expand Up @@ -561,16 +577,23 @@ def snapshot(self, **kw):
"""
return SnapshotCheckout(self, **kw)

def batch(self):
def batch(self, request_options=None):
"""Return an object which wraps a batch.
The wrapper *must* be used as a context manager, with the batch
as the value returned by the wrapper.
:type request_options:
:class:`google.cloud.spanner_v1.types.RequestOptions`
:param request_options:
(Optional) Common options for the commit request.
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.
:rtype: :class:`~google.cloud.spanner_v1.database.BatchCheckout`
:returns: new wrapper
"""
return BatchCheckout(self)
return BatchCheckout(self, request_options)

def batch_snapshot(self, read_timestamp=None, exact_staleness=None):
"""Return an object which wraps a batch read / query.
Expand Down Expand Up @@ -756,11 +779,19 @@ class BatchCheckout(object):
:type database: :class:`~google.cloud.spanner_v1.database.Database`
:param database: database to use
:type request_options:
:class:`google.cloud.spanner_v1.types.RequestOptions`
:param request_options:
(Optional) Common options for the commit request.
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.
"""

def __init__(self, database):
def __init__(self, database, request_options=None):
self._database = database
self._session = self._batch = None
self._request_options = request_options

def __enter__(self):
"""Begin ``with`` block."""
Expand All @@ -772,7 +803,10 @@ def __exit__(self, exc_type, exc_val, exc_tb):
"""End ``with`` block."""
try:
if exc_type is None:
self._batch.commit(return_commit_stats=self._database.log_commit_stats)
self._batch.commit(
return_commit_stats=self._database.log_commit_stats,
request_options=self._request_options,
)
finally:
if self._database.log_commit_stats and self._batch.commit_stats:
self._database.logger.info(
Expand Down
20 changes: 18 additions & 2 deletions google/cloud/spanner_v1/session.py
Expand Up @@ -230,6 +230,7 @@ def execute_sql(
param_types=None,
query_mode=None,
query_options=None,
request_options=None,
retry=google.api_core.gapic_v1.method.DEFAULT,
timeout=google.api_core.gapic_v1.method.DEFAULT,
):
Expand Down Expand Up @@ -258,6 +259,13 @@ def execute_sql(
or :class:`dict`
:param query_options: (Optional) Options that are provided for query plan stability.
:type request_options:
:class:`google.cloud.spanner_v1.types.RequestOptions`
:param request_options:
(Optional) Common options for this request.
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.
:type retry: :class:`~google.api_core.retry.Retry`
:param retry: (Optional) The retry settings for this request.
Expand All @@ -273,6 +281,7 @@ def execute_sql(
param_types,
query_mode,
query_options=query_options,
request_options=request_options,
retry=retry,
timeout=timeout,
)
Expand Down Expand Up @@ -319,9 +328,12 @@ def run_in_transaction(self, func, *args, **kw):
:type kw: dict
:param kw: (Optional) keyword arguments to be passed to ``func``.
If passed, "timeout_secs" will be removed and used to
If passed:
"timeout_secs" will be removed and used to
override the default retry timeout which defines maximum timestamp
to continue retrying the transaction.
"commit_request_options" will be removed and used to set the
request options for the commit request.
:rtype: Any
:returns: The return value of ``func``.
Expand All @@ -330,6 +342,7 @@ def run_in_transaction(self, func, *args, **kw):
reraises any non-ABORT exceptions raised by ``func``.
"""
deadline = time.time() + kw.pop("timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS)
commit_request_options = kw.pop("commit_request_options", None)
attempts = 0

while True:
Expand All @@ -355,7 +368,10 @@ def run_in_transaction(self, func, *args, **kw):
raise

try:
txn.commit(return_commit_stats=self._database.log_commit_stats)
txn.commit(
return_commit_stats=self._database.log_commit_stats,
request_options=commit_request_options,
)
except Aborted as exc:
del self._transaction
_delay_until_retry(exc, deadline, attempts)
Expand Down
25 changes: 25 additions & 0 deletions google/cloud/spanner_v1/snapshot.py
Expand Up @@ -34,6 +34,7 @@
from google.cloud.spanner_v1._helpers import _SessionWrapper
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.cloud.spanner_v1.streamed import StreamedResultSet
from google.cloud.spanner_v1 import RequestOptions

_STREAM_RESUMPTION_INTERNAL_ERROR_MESSAGES = (
"RST_STREAM",
Expand Down Expand Up @@ -124,6 +125,7 @@ def read(
index="",
limit=0,
partition=None,
request_options=None,
*,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
Expand Down Expand Up @@ -152,6 +154,13 @@ def read(
from :meth:`partition_read`. Incompatible with
``limit``.
:type request_options:
:class:`google.cloud.spanner_v1.types.RequestOptions`
:param request_options:
(Optional) Common options for this request.
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.
:type retry: :class:`~google.api_core.retry.Retry`
:param retry: (Optional) The retry settings for this request.
Expand All @@ -176,6 +185,9 @@ def read(
metadata = _metadata_with_prefix(database.name)
transaction = self._make_txn_selector()

if type(request_options) == dict:
request_options = RequestOptions(request_options)

request = ReadRequest(
session=self._session.name,
table=table,
Expand All @@ -185,6 +197,7 @@ def read(
index=index,
limit=limit,
partition_token=partition,
request_options=request_options,
)
restart = functools.partial(
api.streaming_read,
Expand Down Expand Up @@ -217,6 +230,7 @@ def execute_sql(
param_types=None,
query_mode=None,
query_options=None,
request_options=None,
partition=None,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
Expand Down Expand Up @@ -249,6 +263,13 @@ def execute_sql(
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.QueryOptions`
:type request_options:
:class:`google.cloud.spanner_v1.types.RequestOptions`
:param request_options:
(Optional) Common options for this request.
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.
:type partition: bytes
:param partition: (Optional) one of the partition tokens returned
from :meth:`partition_query`.
Expand Down Expand Up @@ -291,6 +312,9 @@ def execute_sql(
default_query_options = database._instance._client._query_options
query_options = _merge_query_options(default_query_options, query_options)

if type(request_options) == dict:
request_options = RequestOptions(request_options)

request = ExecuteSqlRequest(
session=self._session.name,
sql=sql,
Expand All @@ -301,6 +325,7 @@ def execute_sql(
partition_token=partition,
seqno=self._execute_sql_count,
query_options=query_options,
request_options=request_options,
)
restart = functools.partial(
api.execute_streaming_sql,
Expand Down

0 comments on commit 51533b8

Please sign in to comment.