feat: add support to log commit stats #205

Merged (9 commits) on Feb 23, 2021
22 changes: 15 additions & 7 deletions google/cloud/spanner_v1/batch.py
@@ -14,6 +14,7 @@

"""Context manager for Cloud Spanner batched writes."""

from google.cloud.spanner_v1 import CommitRequest
from google.cloud.spanner_v1 import Mutation
from google.cloud.spanner_v1 import TransactionOptions

@@ -123,6 +124,7 @@ class Batch(_BatchBase):
"""

committed = None
"""Timestamp at which the batch was successfully committed."""

commit_stats = None

def _check_state(self):
@@ -136,9 +138,13 @@ def _check_state(self):
if self.committed is not None:
raise ValueError("Batch already committed")

def commit(self):
def commit(self, return_commit_stats=False):
"""Commit mutations to the database.

:type return_commit_stats: bool
:param return_commit_stats:
If true, the response will return commit stats which can be accessed through commit_stats.

:rtype: datetime
:returns: timestamp of the committed changes.
"""
@@ -148,14 +154,16 @@ def commit(self):
metadata = _metadata_with_prefix(database.name)
txn_options = TransactionOptions(read_write=TransactionOptions.ReadWrite())
trace_attributes = {"num_mutations": len(self._mutations)}
request = CommitRequest(
session=self._session.name,
mutations=self._mutations,
single_use_transaction=txn_options,
return_commit_stats=return_commit_stats,
)
with trace_call("CloudSpanner.Commit", self._session, trace_attributes):
response = api.commit(
session=self._session.name,
mutations=self._mutations,
single_use_transaction=txn_options,
metadata=metadata,
)
response = api.commit(request=request, metadata=metadata,)
self.committed = response.commit_timestamp
self.commit_stats = response.commit_stats
return self.committed

def __enter__(self):
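For reference, a minimal usage sketch of the `return_commit_stats` flag added to `Batch.commit` above, driving a batch directly through a session. The instance, database, table, and column names are illustrative, not part of this change.

```python
from google.cloud import spanner

client = spanner.Client()
database = client.instance("my-instance").database("my-database")

session = database.session()
session.create()
try:
    batch = session.batch()
    batch.insert(
        table="Singers",
        columns=("SingerId", "FirstName"),
        values=[(1, "Marc")],
    )
    # Ask the backend for commit statistics along with the commit timestamp.
    batch.commit(return_commit_stats=True)
    # commit_stats now carries the CommitResponse.CommitStats message.
    print(batch.commit_stats.mutation_count)
finally:
    session.delete()
```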
40 changes: 38 additions & 2 deletions google/cloud/spanner_v1/database.py
@@ -17,6 +17,7 @@
import copy
import functools
import grpc
import logging
import re
import threading

@@ -99,14 +100,18 @@ class Database(object):

_spanner_api = None

def __init__(self, database_id, instance, ddl_statements=(), pool=None):
def __init__(
self, database_id, instance, ddl_statements=(), pool=None, logger=None
):
self.database_id = database_id
self._instance = instance
self._ddl_statements = _check_ddl_statements(ddl_statements)
self._local = threading.local()
self._state = None
self._create_time = None
self._restore_info = None
self.log_commit_stats = False
self._logger = logger

if pool is None:
pool = BurstyPool()
@@ -216,6 +221,31 @@ def ddl_statements(self):
"""
return self._ddl_statements

@property
def logger(self):
"""Logger used by the database.

The default logger will log commit stats at the log level INFO using
`sys.stderr`.

:rtype: :class:`logging.Logger` or `None`
:returns: the logger
"""
if self._logger is None:
self._logger = logging.getLogger(self.name)
self._logger.setLevel(logging.INFO)

ch = logging.StreamHandler()
ch.setLevel(logging.INFO)

formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
ch.setFormatter(formatter)

self._logger.addHandler(ch)
return self._logger

@property
def spanner_api(self):
"""Helper for session-related API calls."""
@@ -624,8 +654,14 @@ def __exit__(self, exc_type, exc_val, exc_tb):
"""End ``with`` block."""
try:
if exc_type is None:
self._batch.commit()
self._batch.commit(return_commit_stats=self._database.log_commit_stats)
finally:
if self._database.log_commit_stats and self._batch.commit_stats:
self._database.logger.info(
"Transaction mutation count: {}".format(
self._batch.commit_stats.mutation_count
)
)
self._database._pool.put(self._session)


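A minimal sketch of how the pieces added to `database.py` fit together: the `log_commit_stats` flag, the lazily created default logger, and the `BatchCheckout` exit hook that logs the mutation count. Instance, database, and table names are illustrative.

```python
import logging

from google.cloud import spanner

client = spanner.Client()
database = client.instance("my-instance").database("my-database")

# Opt in: every commit issued through this database now requests commit
# stats and logs the mutation count at INFO level.
database.log_commit_stats = True

# With no logger supplied, database.logger lazily builds a stderr logger
# named after the database resource.
assert isinstance(database.logger, logging.Logger)

with database.batch() as batch:
    batch.insert(
        table="Singers",
        columns=("SingerId", "FirstName"),
        values=[(1, "Marc")],
    )
# On exit the batch commits with return_commit_stats=True and the logger
# emits a line like "Transaction mutation count: <n>".
```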
6 changes: 4 additions & 2 deletions google/cloud/spanner_v1/instance.py
@@ -357,7 +357,7 @@ def delete(self):

api.delete_instance(name=self.name, metadata=metadata)

def database(self, database_id, ddl_statements=(), pool=None):
def database(self, database_id, ddl_statements=(), pool=None, logger=None):
"""Factory to create a database within this instance.

:type database_id: str
@@ -374,7 +374,9 @@ def database(self, database_id, ddl_statements=(), pool=None):
:rtype: :class:`~google.cloud.spanner_v1.database.Database`
:returns: a database owned by this instance.
"""
return Database(database_id, self, ddl_statements=ddl_statements, pool=pool)
return Database(
database_id, self, ddl_statements=ddl_statements, pool=pool, logger=logger
)

def list_databases(self, page_size=None):
"""List databases for the instance.
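A minimal sketch of routing commit stats to a custom logger through the new `logger` argument on `Instance.database`. It reuses the `client` from the sketches above; the logger name and file handler are illustrative.

```python
import logging

custom_logger = logging.getLogger("spanner.commit_stats")
custom_logger.setLevel(logging.INFO)
custom_logger.addHandler(logging.FileHandler("commit_stats.log"))

instance = client.instance("my-instance")
database = instance.database("my-database", logger=custom_logger)
database.log_commit_stats = True
# Mutation counts are now written to commit_stats.log instead of stderr.
```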
8 changes: 7 additions & 1 deletion google/cloud/spanner_v1/session.py
@@ -349,14 +349,20 @@ def run_in_transaction(self, func, *args, **kw):
raise

try:
txn.commit()
txn.commit(return_commit_stats=self._database.log_commit_stats)
except Aborted as exc:
del self._transaction
_delay_until_retry(exc, deadline, attempts)
except GoogleAPICallError:
del self._transaction
raise
else:
if self._database.log_commit_stats:
self._database.logger.info(
"Transaction mutation count: {}".format(
txn.commit_stats.mutation_count
)
)
return return_value


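A minimal sketch of the `run_in_transaction` path above with commit-stats logging enabled. Table and column names are illustrative; the log format matches this change.

```python
from google.cloud import spanner

client = spanner.Client()
database = client.instance("my-instance").database("my-database")
database.log_commit_stats = True


def insert_singer(transaction):
    transaction.insert(
        table="Singers",
        columns=("SingerId", "FirstName"),
        values=[(2, "Cat")],
    )


database.run_in_transaction(insert_singer)
# On a successful commit the database logger emits:
#   "Transaction mutation count: <n>"
```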
23 changes: 16 additions & 7 deletions google/cloud/spanner_v1/transaction.py
@@ -21,6 +21,7 @@
_merge_query_options,
_metadata_with_prefix,
)
from google.cloud.spanner_v1 import CommitRequest
from google.cloud.spanner_v1 import ExecuteBatchDmlRequest
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import TransactionSelector
@@ -42,6 +43,7 @@ class Transaction(_SnapshotBase, _BatchBase):
committed = None
"""Timestamp at which the transaction was successfully committed."""
rolled_back = False
commit_stats = None
_multi_use = True
_execute_sql_count = 0

@@ -119,9 +121,13 @@ def rollback(self):
self.rolled_back = True
del self._session._transaction

def commit(self):
def commit(self, return_commit_stats=False):
"""Commit mutations to the database.

:type return_commit_stats: bool
:param return_commit_stats:
If true, the response will return commit stats which can be accessed through commit_stats.

:rtype: datetime
:returns: timestamp of the committed changes.
:raises ValueError: if there are no mutations to commit.
@@ -132,14 +138,17 @@ def commit(self):
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
trace_attributes = {"num_mutations": len(self._mutations)}
request = CommitRequest(
session=self._session.name,
mutations=self._mutations,
transaction_id=self._transaction_id,
return_commit_stats=return_commit_stats,
)
with trace_call("CloudSpanner.Commit", self._session, trace_attributes):
response = api.commit(
session=self._session.name,
mutations=self._mutations,
transaction_id=self._transaction_id,
metadata=metadata,
)
response = api.commit(request=request, metadata=metadata,)
self.committed = response.commit_timestamp
if return_commit_stats:
self.commit_stats = response.commit_stats
del self._session._transaction
return self.committed

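A minimal sketch of requesting commit stats from a `Transaction` managed by hand, mirroring the `commit` change above; instance, database, and table names are illustrative.

```python
from google.cloud import spanner

client = spanner.Client()
database = client.instance("my-instance").database("my-database")

session = database.session()
session.create()
try:
    transaction = session.transaction()
    transaction.begin()
    transaction.insert(
        table="Singers",
        columns=("SingerId", "FirstName"),
        values=[(3, "Alice")],
    )
    transaction.commit(return_commit_stats=True)
    # commit_stats is only populated when return_commit_stats is True.
    print(transaction.commit_stats.mutation_count)
finally:
    session.delete()
```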
16 changes: 8 additions & 8 deletions tests/unit/test_batch.py
@@ -339,17 +339,17 @@ def __init__(self, **kwargs):
self.__dict__.update(**kwargs)

def commit(
self,
session,
mutations,
transaction_id="",
single_use_transaction=None,
metadata=None,
self, request=None, metadata=None,
):
from google.api_core.exceptions import Unknown

assert transaction_id == ""
self._committed = (session, mutations, single_use_transaction, metadata)
assert request.transaction_id == b""
self._committed = (
request.session,
request.mutations,
request.single_use_transaction,
metadata,
)
if self._rpc_error:
raise Unknown("error")
return self._commit_response
86 changes: 85 additions & 1 deletion tests/unit/test_database.py
@@ -104,6 +104,8 @@ def test_ctor_defaults(self):
self.assertIs(database._instance, instance)
self.assertEqual(list(database.ddl_statements), [])
self.assertIsInstance(database._pool, BurstyPool)
self.assertFalse(database.log_commit_stats)
self.assertIsNone(database._logger)
# BurstyPool does not create sessions during 'bind()'.
self.assertTrue(database._pool._sessions.empty())

@@ -145,6 +147,18 @@ def test_ctor_w_ddl_statements_ok(self):
self.assertIs(database._instance, instance)
self.assertEqual(list(database.ddl_statements), DDL_STATEMENTS)

def test_ctor_w_explicit_logger(self):
from logging import Logger

instance = _Instance(self.INSTANCE_NAME)
logger = mock.create_autospec(Logger, instance=True)
database = self._make_one(self.DATABASE_ID, instance, logger=logger)
self.assertEqual(database.database_id, self.DATABASE_ID)
self.assertIs(database._instance, instance)
self.assertEqual(list(database.ddl_statements), [])
self.assertFalse(database.log_commit_stats)
self.assertEqual(database._logger, logger)

def test_from_pb_bad_database_name(self):
from google.cloud.spanner_admin_database_v1 import Database

@@ -249,6 +263,24 @@ def test_restore_info(self):
)
self.assertEqual(database.restore_info, restore_info)

def test_logger_property_default(self):
import logging

instance = _Instance(self.INSTANCE_NAME)
pool = _Pool()
database = self._make_one(self.DATABASE_ID, instance, pool=pool)
logger = logging.getLogger(database.name)
self.assertEqual(database.logger, logger)

def test_logger_property_custom(self):
import logging

instance = _Instance(self.INSTANCE_NAME)
pool = _Pool()
database = self._make_one(self.DATABASE_ID, instance, pool=pool)
logger = database._logger = mock.create_autospec(logging.Logger, instance=True)
self.assertEqual(database.logger, logger)

def test_spanner_api_property_w_scopeless_creds(self):

client = _Client()
@@ -1263,6 +1295,7 @@ def test_ctor(self):

def test_context_mgr_success(self):
import datetime
from google.cloud.spanner_v1 import CommitRequest
from google.cloud.spanner_v1 import CommitResponse
from google.cloud.spanner_v1 import TransactionOptions
from google.cloud._helpers import UTC
@@ -1290,13 +1323,59 @@ def test_context_mgr_success(self):

expected_txn_options = TransactionOptions(read_write={})

request = CommitRequest(
session=self.SESSION_NAME,
mutations=[],
single_use_transaction=expected_txn_options,
)
api.commit.assert_called_once_with(
request=request, metadata=[("google-cloud-resource-prefix", database.name)],
)

def test_context_mgr_w_commit_stats(self):
import datetime
from google.cloud.spanner_v1 import CommitRequest
from google.cloud.spanner_v1 import CommitResponse
from google.cloud.spanner_v1 import TransactionOptions
from google.cloud._helpers import UTC
from google.cloud._helpers import _datetime_to_pb_timestamp
from google.cloud.spanner_v1.batch import Batch

now = datetime.datetime.utcnow().replace(tzinfo=UTC)
now_pb = _datetime_to_pb_timestamp(now)
commit_stats = CommitResponse.CommitStats(mutation_count=4)
response = CommitResponse(commit_timestamp=now_pb, commit_stats=commit_stats)
database = _Database(self.DATABASE_NAME)
database.log_commit_stats = True
api = database.spanner_api = self._make_spanner_client()
api.commit.return_value = response
pool = database._pool = _Pool()
session = _Session(database)
pool.put(session)
checkout = self._make_one(database)

with checkout as batch:
self.assertIsNone(pool._session)
self.assertIsInstance(batch, Batch)
self.assertIs(batch._session, session)

self.assertIs(pool._session, session)
self.assertEqual(batch.committed, now)

expected_txn_options = TransactionOptions(read_write={})

request = CommitRequest(
session=self.SESSION_NAME,
mutations=[],
single_use_transaction=expected_txn_options,
metadata=[("google-cloud-resource-prefix", database.name)],
return_commit_stats=True,
)
api.commit.assert_called_once_with(
request=request, metadata=[("google-cloud-resource-prefix", database.name)],
)

database.logger.info.assert_called_once_with("Transaction mutation count: 4")

def test_context_mgr_failure(self):
from google.cloud.spanner_v1.batch import Batch

@@ -1883,10 +1962,15 @@ def __init__(self, name):


class _Database(object):
log_commit_stats = False

def __init__(self, name, instance=None):
self.name = name
self.database_id = name.rsplit("/", 1)[1]
self._instance = instance
from logging import Logger

self.logger = mock.create_autospec(Logger, instance=True)


class _Pool(object):