Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support aborted transactions internal retry #544

Closed
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
72 changes: 72 additions & 0 deletions google/cloud/spanner_dbapi/checksum.py
@@ -0,0 +1,72 @@
# Copyright 2020 Google LLC
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

"""API to calculate checksums of SQL statements results."""

import hashlib
import pickle


class ResultsChecksum:
    """Cumulative checksum of the results consumed within a transaction.

    Every result passed to :meth:`consume_result` is pickled and fed
    into a running SHA-256 hash, and a counter of consumed results is
    kept alongside it. The checksums are compared while retrying an
    aborted transaction to check that the results of the retried
    transaction are equal to the results of the original transaction.

    Note: because :meth:`__len__` is defined, an instance which has not
    consumed anything yet is falsy. Test for a missing checksum with
    ``is None`` rather than by truthiness.
    """

    # Instances are mutable and define ``__eq__``, so they must not be
    # used as dict keys or set members. Python already sets this
    # implicitly; spelling it out documents the intent.
    __hash__ = None

    def __init__(self):
        self.checksum = hashlib.sha256()
        self.count = 0  # counter of consumed results

    def __len__(self):
        """Return the number of consumed results.

        :rtype: :class:`int`
        :returns: The number of results.
        """
        return self.count

    def __eq__(self, other):
        """Check if checksums are equal.

        Only the digests are compared; the number of consumed
        results is not taken into account.

        :type other: :class:`google.cloud.spanner_dbapi.checksum.ResultsChecksum`
        :param other: Another checksum to compare with this one.

        :rtype: :class:`bool`
        :returns: True if both digests match, False otherwise.
                  ``NotImplemented`` is returned for objects of
                  other types, so that ``==`` and ``!=`` fall back
                  to the default comparison instead of raising
                  :exc:`AttributeError`.
        """
        if not isinstance(other, ResultsChecksum):
            return NotImplemented

        return self.checksum.digest() == other.checksum.digest()

    def consume_result(self, result):
        """Add the given result into the checksum.

        :type result: Union[int, list]
        :param result: Streamed row or row count from an UPDATE operation.
        """
        self.checksum.update(pickle.dumps(result))
        self.count += 1


def _compare_checksums(original, retried):
    """Verify that a retried transaction reproduced the original results.

    Nothing is checked when there is no original checksum, or when
    the two checksums have consumed a different number of results.
    Only checksums of the same length are compared.

    :type original: :class:`~google.cloud.spanner_dbapi.checksum.ResultsChecksum`
    :param original: results checksum of the original transaction.

    :type retried: :class:`~google.cloud.spanner_dbapi.checksum.ResultsChecksum`
    :param retried: results checksum of the retried transaction.

    :raises: :exc:`RuntimeError` if both checksums have consumed the
             same number of results, but are not equal.
    """
    if original is None:
        return

    # Checksums of different lengths are not comparable at this point.
    if len(retried) != len(original):
        return

    if retried != original:
        message = (
            "The underlying data being changed while retrying an aborted transaction."
        )
        raise RuntimeError(message)
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved
42 changes: 42 additions & 0 deletions google/cloud/spanner_dbapi/connection.py
Expand Up @@ -11,6 +11,7 @@

from google.cloud import spanner_v1

from .checksum import ResultsChecksum
from .cursor import Cursor
from .exceptions import InterfaceError

Expand Down Expand Up @@ -39,6 +40,9 @@ def __init__(self, instance, database):
self._ddl_statements = []
self._transaction = None
self._session = None
# SQL statements, which were executed
# within the current transaction
self._statements = []

self.is_closed = False
self._autocommit = False
Expand Down Expand Up @@ -178,6 +182,42 @@ def run_prior_DDL_statements(self):

return self.__handle_update_ddl(ddl_statements)

def run_statement(self, sql, params, param_types):
    """Execute one SQL statement inside the current transaction.

    The statement, its parameters and a fresh results checksum are
    appended to ``self._statements`` before the statement is executed.
    This method is not used in autocommit mode.

    :type sql: :class:`str`
    :param sql: SQL statement to execute.

    :type params: :class:`dict`
    :param params: Params to be used by the given statement.

    :type param_types: :class:`dict`
    :param param_types: Statement parameters types description.

    :rtype: :class:`google.cloud.spanner_v1.streamed.StreamedResultSet`,
            :class:`google.cloud.spanner_dbapi.checksum.ResultsChecksum`
    :returns: Streamed result set of the statement and a
              checksum of this statement results.
    """
    transaction = self.transaction_checkout()

    # The same checksum object is stored with the statement and
    # handed back to the caller.
    checksum = ResultsChecksum()
    self._statements.append(
        {
            "sql": sql,
            "params": params,
            "param_types": param_types,
            "checksum": checksum,
        }
    )

    result_set = transaction.execute_sql(sql, params, param_types=param_types)
    return result_set, checksum
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved

def list_tables(self):
return self.run_sql_in_snapshot(
"""
Expand Down Expand Up @@ -244,6 +284,7 @@ def commit(self):
elif self._transaction:
self._transaction.commit()
self._release_session()
self._statements = []

def rollback(self):
"""Rollback all the pending transactions."""
Expand All @@ -252,6 +293,7 @@ def rollback(self):
elif self._transaction:
self._transaction.rollback()
self._release_session()
self._statements = []
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved

def __enter__(self):
return self
Expand Down
19 changes: 13 additions & 6 deletions google/cloud/spanner_dbapi/cursor.py
Expand Up @@ -67,6 +67,8 @@ def __init__(self, connection):
self._row_count = _UNSET_COUNT
self._connection = connection
self._is_closed = False
# the currently running SQL statement results checksum
self._checksum = None
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved

# the number of rows to fetch at a time with fetchmany()
self.arraysize = 1
Expand Down Expand Up @@ -101,11 +103,9 @@ def execute(self, sql, args=None):
self._run_prior_DDL_statements()

if not self._connection.autocommit:
transaction = self._connection.transaction_checkout()

sql, params = sql_pyformat_args_to_spanner(sql, args)

self._res = transaction.execute_sql(
self._res, self._checksum = self._connection.run_statement(
sql, params, param_types=get_param_types(params)
)
self._itr = PeekIterator(self._res)
Expand Down Expand Up @@ -305,14 +305,19 @@ def fetchone(self):
self._raise_if_closed()

try:
return next(self)
res = next(self)
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved
self._checksum.consume_result(res)
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved
return res
except StopIteration:
return None

def fetchall(self):
    """Fetch all the remaining rows of the current result set.

    Every returned row is also added to the checksum of the currently
    running statement, when there is one.

    :rtype: :class:`list`
    :returns: The remaining rows; an empty list if there are none.
    """
    self._raise_if_closed()

    res = list(self.__iter__())

    # ``_checksum`` stays ``None`` until a statement has been run via
    # ``Connection.run_statement``, which ``execute`` only does when
    # the connection is not in autocommit mode.
    if self._checksum is not None:
        for row in res:
            self._checksum.consume_result(row)

    return res

def fetchmany(self, size=None):
"""
Expand All @@ -335,7 +340,9 @@ def fetchmany(self, size=None):
items = []
for i in range(size):
try:
items.append(tuple(self.__next__()))
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved
res = next(self)
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved
self._checksum.consume_result(res)
items.append(res)
except StopIteration:
break

Expand Down
44 changes: 44 additions & 0 deletions tests/spanner_dbapi/test_checksum.py
@@ -0,0 +1,44 @@
# Copyright 2020 Google LLC
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

import unittest

from google.cloud.spanner_dbapi.checksum import (
_compare_checksums,
ResultsChecksum,
)


class Test_compare_checksums(unittest.TestCase):
    """Unit tests for the ``_compare_checksums`` helper."""

    @staticmethod
    def _checksum_of(*results):
        """Build a checksum which has consumed ``results`` in order."""
        checksum = ResultsChecksum()
        for result in results:
            checksum.consume_result(result)
        return checksum

    def test_no_original_checksum(self):
        """Nothing is compared when there is no original checksum."""
        outcome = _compare_checksums(None, ResultsChecksum())

        self.assertIsNone(outcome)

    def test_equal(self):
        """Checksums of identical results pass the comparison."""
        first_run = self._checksum_of(5)
        second_run = self._checksum_of(5)

        outcome = _compare_checksums(first_run, second_run)

        self.assertIsNone(outcome)

    def test_less_results(self):
        """A retried checksum with fewer results is not compared."""
        first_run = self._checksum_of(5)
        second_run = self._checksum_of()

        outcome = _compare_checksums(first_run, second_run)

        self.assertIsNone(outcome)

    def test_mismatch(self):
        """Same number of results with different content is an error."""
        first_run = self._checksum_of(5)
        second_run = self._checksum_of(2)

        self.assertRaises(
            RuntimeError, _compare_checksums, first_run, second_run
        )
57 changes: 57 additions & 0 deletions tests/spanner_dbapi/test_connection.py
Expand Up @@ -12,6 +12,7 @@
# import google.cloud.spanner_dbapi.exceptions as dbapi_exceptions

from google.cloud.spanner_dbapi import Connection, InterfaceError
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.connection import AUTOCOMMIT_MODE_WARNING
from google.cloud.spanner_v1.database import Database
from google.cloud.spanner_v1.instance import Instance
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are these tests still here? They were copied into the unit directory, so I suppose they should be removed from this directory.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. Looks like this (and test_connect) weren't moved in #532?

https://github.com/q-logic/python-spanner-django/blob/41abaebb6f2e0b1cf16704aa1e394acc5a47e68b/tests/spanner_dbapi/test_connection.py

The test files weren't exactly copied, #532 changed them and added some new tests. E.g. the version on master now doesn't include test_transaction_autocommit_warnings.

@mf2199 can you confirm that you meant to change/remove these tests before removing tests/spanner_dbapi in this PR?

Expand Down Expand Up @@ -77,3 +78,59 @@ def test_instance_property(self):

with self.assertRaises(AttributeError):
connection.instance = None

def test_run_statement(self):
    """Check that Connection remembers executed statements."""
    sql = """SELECT 23 FROM table WHERE id = @a1"""
    params = {"a1": "value"}
    param_types = {"a1": str}
    checkout_path = (
        "google.cloud.spanner_dbapi.connection.Connection.transaction_checkout"
    )

    connection = self._make_connection()

    # The transaction checkout is mocked out, so only the
    # bookkeeping done by ``run_statement`` is checked here.
    with mock.patch(checkout_path):
        connection.run_statement(sql, params, param_types)

    recorded = connection._statements[0]
    self.assertEqual(recorded["sql"], sql)
    self.assertEqual(recorded["params"], params)
    self.assertEqual(recorded["param_types"], param_types)
    self.assertIsInstance(recorded["checksum"], ResultsChecksum)

def test_clear_statements_on_commit(self):
    """
    Check that all the saved statements are
    cleared, when the transaction is committed.
    """
    commit_path = "google.cloud.spanner_v1.transaction.Transaction.commit"

    connection = self._make_connection()
    connection._transaction = mock.Mock()
    connection._statements = [{}, {}]

    # Sanity check of the state before committing.
    self.assertEqual(len(connection._statements), 2)

    with mock.patch(commit_path):
        connection.commit()

    self.assertEqual(len(connection._statements), 0)

def test_clear_statements_on_rollback(self):
    """
    Check that all the saved statements are
    cleared, when the transaction is rolled back.
    """
    connection = self._make_connection()
    connection._transaction = mock.Mock()
    connection._statements = [{}, {}]

    self.assertEqual(len(connection._statements), 2)

    # Patch the method this test actually exercises: ``rollback``,
    # not ``commit``.
    with mock.patch(
        "google.cloud.spanner_v1.transaction.Transaction.rollback"
    ):
        connection.rollback()

    self.assertEqual(len(connection._statements), 0)
32 changes: 31 additions & 1 deletion tests/system/test_system.py
Expand Up @@ -4,8 +4,10 @@
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

import unittest
import hashlib
import pickle
import os
import unittest

from google.api_core import exceptions

Expand Down Expand Up @@ -289,6 +291,34 @@ def test_rollback_on_connection_closing(self):
cursor.close()
conn.close()

def test_results_checksum(self):
    """Test that results checksum is calculated properly."""
    conn = Connection(Config.INSTANCE, self._db)
    cursor = conn.cursor()

    # Each executed statement must be remembered by the connection
    # until the transaction is committed.
    cursor.execute(
        """
INSERT INTO contacts (contact_id, first_name, last_name, email)
VALUES
(1, 'first-name', 'last-name', 'test.email@domen.ru'),
(2, 'first-name2', 'last-name2', 'test.email2@domen.ru')
"""
    )
    self.assertEqual(len(conn._statements), 1)
    conn.commit()

    cursor.execute("SELECT * FROM contacts")
    got_rows = cursor.fetchall()

    self.assertEqual(len(conn._statements), 1)
    conn.commit()

    # Rebuild the expected checksum by hand from the two fetched rows.
    expected = hashlib.sha256()
    for index in (0, 1):
        expected.update(pickle.dumps(got_rows[index]))

    actual_digest = cursor._checksum.checksum.digest()
    self.assertEqual(actual_digest, expected.digest())


def clear_table(transaction):
"""Clear the test table."""
Expand Down