
feat: implement aborted transactions retry mechanism #1

Closed
wants to merge 13 commits
17 changes: 14 additions & 3 deletions google/cloud/spanner_v1/session.py
@@ -278,9 +278,13 @@ def batch(self):

return Batch(self)

def transaction(self):
def transaction(self, original_results_checksum=None):
"""Create a transaction to perform a set of reads with shared staleness.

:type original_results_checksum: :class:`~google.cloud.spanner_v1.transaction.ResultsChecksum`
:param original_results_checksum: original transaction results
checksum.

:rtype: :class:`~google.cloud.spanner_v1.transaction.Transaction`
:returns: a transaction bound to this session
:raises ValueError: if the session has not yet been created.
@@ -292,7 +296,9 @@ def transaction(self):
self._transaction.rolled_back = True
del self._transaction

txn = self._transaction = Transaction(self)
txn = self._transaction = Transaction(
self, original_results_checksum=original_results_checksum
)
return txn

def run_in_transaction(self, func, *args, **kw):
@@ -319,11 +325,12 @@ def run_in_transaction(self, func, *args, **kw):
reraises any non-ABORT exceptions raised by ``func``.
"""
deadline = time.time() + kw.pop("timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS)
original_results_checksum = None
attempts = 0

while True:
if self._transaction is None:
txn = self.transaction()
txn = self.transaction(original_results_checksum)
else:
txn = self._transaction
if txn._transaction_id is None:
@@ -333,6 +340,8 @@
attempts += 1
return_value = func(txn, *args, **kw)
except Aborted as exc:
if attempts == 1:
original_results_checksum = self._transaction.results_checksum
del self._transaction
_delay_until_retry(exc, deadline, attempts)
continue
@@ -346,6 +355,8 @@
try:
txn.commit()
except Aborted as exc:
if attempts == 1:
original_results_checksum = self._transaction.results_checksum
del self._transaction
_delay_until_retry(exc, deadline, attempts)
except GoogleAPICallError:
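
For context (not part of the diff): a minimal sketch of how an application exercises this retry path. The session object, table, and unit-of-work function are hypothetical.

def transfer_funds(txn):
    # On a retry after an Aborted error, this function re-runs inside a
    # fresh Transaction that carries the original attempt's checksum; if
    # the re-run produces different results, Aborted is raised again
    # rather than committing against data that changed between attempts.
    return txn.execute_update(
        "UPDATE accounts SET balance = balance - 100 WHERE id = 1"
    )

row_count = session.run_in_transaction(transfer_funds, timeout_secs=30)
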
34 changes: 30 additions & 4 deletions google/cloud/spanner_v1/snapshot.py
@@ -171,9 +171,22 @@ def read(self, table, columns, keyset, index="", limit=0, partition=None):
self._read_request_count += 1

if self._multi_use:
return StreamedResultSet(iterator, source=self)
return StreamedResultSet(
iterator,
source=self,
results_checksum=getattr(self, "results_checksum", None),
original_results_checksum=getattr(
self, "_original_results_checksum", None
),
)
else:
return StreamedResultSet(iterator)
return StreamedResultSet(
iterator,
results_checksum=getattr(self, "results_checksum", None),
original_results_checksum=getattr(
self, "_original_results_checksum", None
),
)

def execute_sql(
self,
@@ -278,9 +291,22 @@ def execute_sql(
self._execute_sql_count += 1

if self._multi_use:
return StreamedResultSet(iterator, source=self)
return StreamedResultSet(
iterator,
source=self,
results_checksum=getattr(self, "results_checksum", None),
original_results_checksum=getattr(
self, "_original_results_checksum", None
),
)
else:
return StreamedResultSet(iterator)
return StreamedResultSet(
iterator,
results_checksum=getattr(self, "results_checksum", None),
original_results_checksum=getattr(
self, "_original_results_checksum", None
),
)

def partition_read(
self,
34 changes: 32 additions & 2 deletions google/cloud/spanner_v1/streamed.py
@@ -14,6 +14,7 @@

"""Wrapper for streaming results."""

from google.api_core.exceptions import Aborted
from google.protobuf.struct_pb2 import ListValue
from google.protobuf.struct_pb2 import Value
from google.cloud import exceptions
@@ -37,16 +38,32 @@ class StreamedResultSet(object):

:type source: :class:`~google.cloud.spanner_v1.snapshot.Snapshot`
:param source: Snapshot from which the result set was fetched.

:type results_checksum: :class:`~google.cloud.spanner_v1.transaction.ResultsChecksum`
:param results_checksum: A checksum to which streamed rows from this
result set must be added.

:type original_results_checksum: :class:`~google.cloud.spanner_v1.transaction.ResultsChecksum`
:param original_results_checksum: Results checksum of the original
transaction.
"""

def __init__(self, response_iterator, source=None):
def __init__(
self,
response_iterator,
source=None,
results_checksum=None,
original_results_checksum=None,
):
self._response_iterator = response_iterator
self._rows = [] # Fully-processed rows
self._metadata = None # Until set from first PRS
self._stats = None # Until set from last PRS
self._current_row = [] # Accumulated values for incomplete row
self._pending_chunk = None # Incomplete value
self._source = source # Source snapshot
self._results_checksum = results_checksum
self._original_results_checksum = original_results_checksum

@property
def fields(self):
@@ -143,7 +160,20 @@ def __iter__(self):
return
iter_rows, self._rows[:] = self._rows[:], ()
while iter_rows:
yield iter_rows.pop(0)
row = iter_rows.pop(0)
if self._results_checksum is not None:
self._results_checksum.consume_result(row)

if self._original_results_checksum is not None:
if self._results_checksum != self._original_results_checksum:
if (
not self._results_checksum
< self._original_results_checksum
):
raise Aborted(
"The underlying data being changed while retrying."
)
yield row

def one(self):
"""Return exactly one result, or raise an exception.
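
Restated outside the diff for clarity, the per-row verification above reduces to one rule (a sketch, not code from this PR):

from google.api_core.exceptions import Aborted

def verify_checksums(current, original):
    # A retried stream that is still behind the original (fewer results
    # consumed so far) cannot be compared yet; once it has caught up, any
    # digest difference means the data changed between attempts.
    if original is not None and current != original and not current < original:
        raise Aborted("The underlying data was changed while retrying.")
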
92 changes: 91 additions & 1 deletion google/cloud/spanner_v1/transaction.py
@@ -14,8 +14,12 @@

"""Spanner read-write transaction support."""

import hashlib
import pickle

from google.protobuf.struct_pb2 import Struct

from google.api_core.exceptions import Aborted
from google.cloud._helpers import _pb_timestamp_to_datetime
from google.cloud.spanner_v1._helpers import (
_make_value_pb,
@@ -35,6 +39,10 @@ class Transaction(_SnapshotBase, _BatchBase):
:type session: :class:`~google.cloud.spanner_v1.session.Session`
:param session: the session used to perform the commit

:type original_results_checksum: :class:`~google.cloud.spanner_v1.transaction.ResultsChecksum`
:param original_results_checksum: results checksum of the
original transaction.

:raises ValueError: if session has an existing transaction
"""

@@ -44,11 +52,21 @@
_multi_use = True
_execute_sql_count = 0

def __init__(self, session):
def __init__(self, session, original_results_checksum=None):
if session._transaction is not None:
raise ValueError("Session has existing transaction.")

super(Transaction, self).__init__(session)
self._results_checksum = ResultsChecksum()
self._original_results_checksum = original_results_checksum

@property
def results_checksum(self):
"""
Cumulative checksum of all the results returned
by all the operations run within this transaction.
"""
return self._results_checksum

def _check_state(self):
"""Helper for :meth:`commit` et al.
@@ -232,6 +250,13 @@ def execute_update(
seqno=seqno,
metadata=metadata,
)
self._results_checksum.consume_result(response.stats.row_count_exact)

if self._original_results_checksum is not None:
if self._results_checksum != self._original_results_checksum:
if not self._results_checksum < self._original_results_checksum:
raise Aborted("The underlying data being changed while retrying.")

return response.stats.row_count_exact

def batch_update(self, statements):
@@ -292,6 +317,13 @@ def batch_update(self, statements):
row_counts = [
result_set.stats.row_count_exact for result_set in response.result_sets
]
self._results_checksum.consume_result(row_counts)

if self._original_results_checksum is not None:
if self._results_checksum != self._original_results_checksum:
if not self._results_checksum < self._original_results_checksum:
raise Aborted("The underlying data being changed while retrying.")

return response.status, row_counts

def __enter__(self):
@@ -305,3 +337,61 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.commit()
else:
self.rollback()


class ResultsChecksum:
"""Cumulative checksum.

Used to calculate a total checksum of all the results
returned by operations executed within transaction.
Includes methods for checksums comparison.

These checksums are used while retrying an aborted
transaction to check if the results of a retried transaction
are equal to the results of the original transaction.
"""

def __init__(self):
self.checksum = hashlib.sha256()
self.count = 0 # counter of consumed results

def __eq__(self, other):
"""Check if checksums are equal.

Args:
other (ResultsChecksum):
Another checksum to compare with this one.
"""
same_count = self.count == other.count
same_checksum = self.checksum.digest() == other.checksum.digest()
return same_count and same_checksum

def __ne__(self, other):
"""Check if checksums aren't equal.

Args:
other (ResultsChecksum):
Another checksum to compare with this one.
"""
different_count = self.count != other.count
different_checksum = self.checksum.digest() != other.checksum.digest()
return different_count or different_checksum

def __lt__(self, other):
"""Check if this checksum have less results than the given one.

Args:
other (ResultsChecksum):
Another checksum to compare with this one.
"""
return self.count < other.count

def consume_result(self, result):
"""Add the given result into the checksum.

Args:
result (Union[int, list]):
Streamed row or row count from an UPDATE operation.
"""
self.checksum.update(pickle.dumps(result))
self.count += 1
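
A quick demonstration of the comparison semantics (a sketch assuming this branch is installed; the consumed values are illustrative):

from google.cloud.spanner_v1.transaction import ResultsChecksum

original = ResultsChecksum()
original.consume_result(["Phred Phlyntstone", 42])  # a streamed row
original.consume_result(1)  # a row count from an UPDATE

retried = ResultsChecksum()
retried.consume_result(["Phred Phlyntstone", 42])

# Still behind the original transaction, so the retry may continue.
assert retried != original and retried < original

# Caught up, but with a diverging result: counts match, digests differ,
# which is exactly the condition that raises Aborted in the code above.
retried.consume_result(2)
assert retried != original and not retried < original
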
63 changes: 63 additions & 0 deletions tests/unit/test_streamed.py
@@ -35,6 +35,8 @@ def test_ctor_defaults(self):
self.assertEqual(list(streamed), [])
self.assertIsNone(streamed.metadata)
self.assertIsNone(streamed.stats)
self.assertIsNone(streamed._results_checksum)
self.assertIsNone(streamed._original_results_checksum)

def test_ctor_w_source(self):
iterator = _MockCancellableIterator()
@@ -46,6 +48,23 @@
self.assertIsNone(streamed.metadata)
self.assertIsNone(streamed.stats)

def test_ctor_w_checksums(self):
from google.cloud.spanner_v1.transaction import ResultsChecksum

checksum = ResultsChecksum()
orig_checksum = ResultsChecksum()
iterator = _MockCancellableIterator()
streamed = self._make_one(
iterator, results_checksum=checksum, original_results_checksum=orig_checksum
)

self.assertIs(streamed._response_iterator, iterator)
self.assertEqual(list(streamed), [])
self.assertEqual(streamed._results_checksum, checksum)
self.assertEqual(streamed._original_results_checksum, orig_checksum)
self.assertIsNone(streamed.metadata)
self.assertIsNone(streamed.stats)

def test_fields_unset(self):
iterator = _MockCancellableIterator()
streamed = self._make_one(iterator)
@@ -745,6 +764,50 @@ def test___iter___empty(self):
found = list(streamed)
self.assertEqual(found, [])

def test___iter___checksum(self):
from google.cloud.spanner_v1.transaction import ResultsChecksum

BARE = [u"Phred Phlyntstone", 42]
VALUES = [self._make_value(bare) for bare in BARE]
FIELDS = [
self._make_scalar_field("full_name", "STRING"),
self._make_scalar_field("age", "INT64"),
]

etalon_cs = ResultsChecksum()
etalon_cs.consume_result(BARE)

metadata = self._make_result_set_metadata(FIELDS)
result_set = self._make_partial_result_set(VALUES, metadata=metadata)
iterator = _MockCancellableIterator(result_set)
streamed = self._make_one(iterator, results_checksum=ResultsChecksum())
found = list(streamed)

self.assertEqual(found, [BARE])
self.assertTrue(streamed._results_checksum == etalon_cs)

def test___iter___checksum_mismatch(self):
from google.api_core.exceptions import Aborted
from google.cloud.spanner_v1.transaction import ResultsChecksum

BARE = [u"Phred Phlyntstone", 42]
VALUES = [self._make_value(bare) for bare in BARE]
FIELDS = [
self._make_scalar_field("full_name", "STRING"),
self._make_scalar_field("age", "INT64"),
]

metadata = self._make_result_set_metadata(FIELDS)
result_set = self._make_partial_result_set(VALUES, metadata=metadata)
iterator = _MockCancellableIterator(result_set)

streamed = self._make_one(iterator, results_checksum=ResultsChecksum())
streamed._original_results_checksum = ResultsChecksum()
streamed._original_results_checksum.consume_result(2)

with self.assertRaises(Aborted):
list(streamed)

def test___iter___one_result_set_partial(self):
FIELDS = [
self._make_scalar_field("full_name", "STRING"),