feat(dbapi): add aborted transactions retry support (#168)
Ilya Gurov committed Nov 23, 2020
1 parent e801a2e commit d59d502
Showing 9 changed files with 1,109 additions and 24 deletions.
80 changes: 80 additions & 0 deletions google/cloud/spanner_dbapi/checksum.py
@@ -0,0 +1,80 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""API to calculate checksums of SQL statements results."""

import hashlib
import pickle

from google.cloud.spanner_dbapi.exceptions import RetryAborted


class ResultsChecksum:
"""Cumulative checksum.
Used to calculate a total checksum of all the results
returned by operations executed within a transaction.
Includes methods for checksum comparison.
These checksums are used while retrying an aborted
transaction to check if the results of a retried transaction
are equal to the results of the original transaction.
"""

def __init__(self):
self.checksum = hashlib.sha256()
self.count = 0 # counter of consumed results

def __len__(self):
"""Return the number of consumed results.
:rtype: :class:`int`
:returns: The number of results.
"""
return self.count

def __eq__(self, other):
"""Check if checksums are equal.
:type other: :class:`google.cloud.spanner_dbapi.checksum.ResultsChecksum`
:param other: Another checksum to compare with this one.
"""
return self.checksum.digest() == other.checksum.digest()

def consume_result(self, result):
"""Add the given result into the checksum.
:type result: Union[int, list]
:param result: Streamed row or row count from an UPDATE operation.
"""
self.checksum.update(pickle.dumps(result))
self.count += 1


def _compare_checksums(original, retried):
"""Compare the given checksums.
Raise an error if the given checksums are not equal.
:type original: :class:`~google.cloud.spanner_dbapi.checksum.ResultsChecksum`
:param original: results checksum of the original transaction.
:type retried: :class:`~google.cloud.spanner_dbapi.checksum.ResultsChecksum`
:param retried: results checksum of the retried transaction.
:raises: :exc:`google.cloud.spanner_dbapi.exceptions.RetryAborted` if the checksums are not equal.
"""
if retried != original:
raise RetryAborted(
"The transaction was aborted and could not be retried due to a concurrent modification."
)
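
The new checksum helpers are self-contained and can be exercised in isolation. Below is a minimal, illustrative sketch (not part of the commit) of how they fit together, assuming the modules are importable exactly as shown in this diff:

from google.cloud.spanner_dbapi.checksum import ResultsChecksum, _compare_checksums
from google.cloud.spanner_dbapi.exceptions import RetryAborted

original = ResultsChecksum()
retried = ResultsChecksum()

# Both transactions consume the same rows, so the digests stay in sync.
for row in ([1, "a"], [2, "b"]):
    original.consume_result(row)
    retried.consume_result(row)

_compare_checksums(original, retried)  # identical results, no error raised

# Simulate a retry that observed different data.
retried.consume_result([3, "c"])
try:
    _compare_checksums(original, retried)
except RetryAborted:
    print("data changed between the original run and the retry")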
117 changes: 113 additions & 4 deletions google/cloud/spanner_dbapi/connection.py
@@ -14,18 +14,24 @@

"""DB-API Connection for the Google Cloud Spanner."""

import time
import warnings

from google.api_core.exceptions import Aborted
from google.api_core.gapic_v1.client_info import ClientInfo
from google.cloud import spanner_v1 as spanner
from google.cloud.spanner_v1.session import _get_retry_delay

from google.cloud.spanner_dbapi.checksum import _compare_checksums
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.cursor import Cursor
from google.cloud.spanner_dbapi.exceptions import InterfaceError
from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT
from google.cloud.spanner_dbapi.version import PY_VERSION


AUTOCOMMIT_MODE_WARNING = "This method is non-operational in autocommit mode"
MAX_INTERNAL_RETRIES = 50


class Connection:
@@ -48,9 +54,16 @@ def __init__(self, instance, database):

self._transaction = None
self._session = None
# SQL statements that were executed
# within the current transaction
self._statements = []

self.is_closed = False
self._autocommit = False
# indicates whether the session pool used by
# this connection should be cleared when the
# connection is closed
self._own_pool = True

@property
def autocommit(self):
@@ -114,6 +127,58 @@ def _release_session(self):
self.database._pool.put(self._session)
self._session = None

def retry_transaction(self):
"""Retry the aborted transaction.
All the statements executed in the original transaction
will be re-executed in a new one. The results checksums of
the original statements and the retried ones will be compared.
:raises: :class:`google.cloud.spanner_dbapi.exceptions.RetryAborted`
If the results checksum of a retried statement is
not equal to the checksum of the original one.
"""
attempt = 0
while True:
self._transaction = None
attempt += 1
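# retry_transaction() is only entered from an `except Aborted:` handler,
# so this bare `raise` re-raises the original Aborted error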
if attempt > MAX_INTERNAL_RETRIES:
raise

try:
self._rerun_previous_statements()
break
except Aborted as exc:
delay = _get_retry_delay(exc.errors[0], attempt)
if delay:
time.sleep(delay)

def _rerun_previous_statements(self):
"""
Helper to run all the remembered statements
from the last transaction.
"""
for statement in self._statements:
res_iter, retried_checksum = self.run_statement(statement, retried=True)
# executing all the completed statements
if statement != self._statements[-1]:
for res in res_iter:
retried_checksum.consume_result(res)

_compare_checksums(statement.checksum, retried_checksum)
# executing the failed statement
else:
# streaming up to the failed result or
# to the end of the streaming iterator
while len(retried_checksum) < len(statement.checksum):
try:
res = next(iter(res_iter))
retried_checksum.consume_result(res)
except StopIteration:
break

_compare_checksums(statement.checksum, retried_checksum)

def transaction_checkout(self):
"""Get a Cloud Spanner transaction.
@@ -158,6 +223,9 @@ def close(self):
):
self._transaction.rollback()

if self._own_pool:
self.database._pool.clear()

self.is_closed = True

def commit(self):
@@ -168,8 +236,13 @@ def commit(self):
if self._autocommit:
warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2)
elif self._transaction:
self._transaction.commit()
self._release_session()
try:
self._transaction.commit()
self._release_session()
self._statements = []
except Aborted:
self.retry_transaction()
self.commit()

def rollback(self):
"""Rolls back any pending transaction.
@@ -182,6 +255,7 @@
elif self._transaction:
self._transaction.rollback()
self._release_session()
self._statements = []

def cursor(self):
"""Factory to create a DB-API Cursor."""
@@ -198,6 +272,32 @@ def run_prior_DDL_statements(self):

return self.database.update_ddl(ddl_statements).result()

def run_statement(self, statement, retried=False):
"""Run single SQL statement in begun transaction.
This method is never used in autocommit mode. In
!autocommit mode however it remembers every executed
SQL statement with its parameters.
:type statement: :class:`dict`
:param statement: SQL statement to execute.
:rtype: :class:`google.cloud.spanner_v1.streamed.StreamedResultSet`,
:class:`google.cloud.spanner_dbapi.checksum.ResultsChecksum`
:returns: Streamed result set of the statement and a
checksum of this statement results.
"""
transaction = self.transaction_checkout()
if not retried:
self._statements.append(statement)

return (
transaction.execute_sql(
statement.sql, statement.params, param_types=statement.param_types,
),
ResultsChecksum() if retried else statement.checksum,
)

def __enter__(self):
return self

@@ -207,7 +307,12 @@ def __exit__(self, etype, value, traceback):


def connect(
instance_id, database_id, project=None, credentials=None, pool=None, user_agent=None
instance_id,
database_id,
project=None,
credentials=None,
pool=None,
user_agent=None,
):
"""Creates a connection to a Google Cloud Spanner database.
@@ -261,4 +366,8 @@
if not database.exists():
raise ValueError("database '%s' does not exist." % database_id)

return Connection(instance, database)
conn = Connection(instance, database)
if pool is not None:
conn._own_pool = False

return conn
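
For orientation, here is a hedged sketch (illustrative only, not taken from the commit) of how the new internal Connection.run_statement() API is driven. The Statement namedtuple is defined in cursor.py below; the instance, database, and table identifiers are hypothetical:

from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.connection import connect
from google.cloud.spanner_dbapi.cursor import Statement
from google.cloud.spanner_dbapi.parse_utils import get_param_types

connection = connect("my-instance", "my-database")  # hypothetical identifiers

sql = "SELECT name FROM singers WHERE id = @a0"  # already in Spanner parameter form
params = {"a0": 42}
statement = Statement(sql, params, get_param_types(params), ResultsChecksum())

# run_statement() remembers the statement for a possible retry and returns
# the streamed results together with the checksum to fold each row into.
result_set, checksum = connection.run_statement(statement)
for row in result_set:
    checksum.consume_result(row)

connection.commit()  # retried transparently via retry_transaction() on Aborted
connection.close()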
59 changes: 42 additions & 17 deletions google/cloud/spanner_dbapi/cursor.py
@@ -14,6 +14,7 @@

"""Database cursor for Google Cloud Spanner DB-API."""

from google.api_core.exceptions import Aborted
from google.api_core.exceptions import AlreadyExists
from google.api_core.exceptions import FailedPrecondition
from google.api_core.exceptions import InternalServerError
@@ -22,7 +23,7 @@
from collections import namedtuple

from google.cloud import spanner_v1 as spanner

from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.exceptions import IntegrityError
from google.cloud.spanner_dbapi.exceptions import InterfaceError
from google.cloud.spanner_dbapi.exceptions import OperationalError
@@ -34,11 +35,13 @@

from google.cloud.spanner_dbapi import parse_utils
from google.cloud.spanner_dbapi.parse_utils import get_param_types
from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner
from google.cloud.spanner_dbapi.utils import PeekIterator

_UNSET_COUNT = -1

ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"])
Statement = namedtuple("Statement", "sql, params, param_types, checksum")


class Cursor(object):
@@ -54,6 +57,8 @@ def __init__(self, connection):
self._row_count = _UNSET_COUNT
self.connection = connection
self._is_closed = False
# results checksum of the currently running SQL statement
self._checksum = None

# the number of rows to fetch at a time with fetchmany()
self.arraysize = 1
@@ -166,12 +171,13 @@ def execute(self, sql, args=None):
self.connection.run_prior_DDL_statements()

if not self.connection.autocommit:
transaction = self.connection.transaction_checkout()

sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, args)
sql, params = sql_pyformat_args_to_spanner(sql, args)

self._result_set = transaction.execute_sql(
sql, params, param_types=get_param_types(params)
statement = Statement(
sql, params, get_param_types(params), ResultsChecksum(),
)
(self._result_set, self._checksum,) = self.connection.run_statement(
statement
)
self._itr = PeekIterator(self._result_set)
return
@@ -213,9 +219,31 @@ def fetchone(self):
self._raise_if_closed()

try:
return next(self)
res = next(self)
self._checksum.consume_result(res)
return res
except StopIteration:
return None
return
except Aborted:
self.connection.retry_transaction()
return self.fetchone()

def fetchall(self):
"""Fetch all (remaining) rows of a query result, returning them as
a sequence of sequences.
"""
self._raise_if_closed()

res = []
try:
for row in self:
self._checksum.consume_result(row)
res.append(row)
except Aborted:
self.connection.retry_transaction()
return self.fetchall()

return res

def fetchmany(self, size=None):
"""Fetch the next set of rows of a query result, returning a sequence
@@ -236,20 +264,17 @@ def fetchmany(self, size=None):
items = []
for i in range(size):
try:
items.append(tuple(self.__next__()))
res = next(self)
self._checksum.consume_result(res)
items.append(res)
except StopIteration:
break
except Aborted:
self.connection.retry_transaction()
return self.fetchmany(size)

return items

def fetchall(self):
"""Fetch all (remaining) rows of a query result, returning them as
a sequence of sequences.
"""
self._raise_if_closed()

return list(self.__iter__())

def nextset(self):
"""A no-op, raising an error if the cursor or connection is closed."""
self._raise_if_closed()
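
To round out the picture, a hedged end-to-end sketch (illustrative only, not part of the commit) of what the retry support looks like from the DB-API surface; the connection identifiers and the singers table are hypothetical:

from google.cloud.spanner_dbapi.connection import connect

connection = connect("my-instance", "my-database")  # hypothetical identifiers
cursor = connection.cursor()

cursor.execute("SELECT id, name FROM singers")

# Every fetched row is folded into the cursor's ResultsChecksum. If Spanner
# aborts the transaction mid-stream, fetchone()/fetchmany()/fetchall() call
# connection.retry_transaction(), which replays the remembered statements and
# verifies the replayed rows hash to the same checksum before resuming.
first = cursor.fetchone()
remaining = cursor.fetchall()

connection.commit()  # also retried on Aborted, with the same checksum check
connection.close()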
