test(db_api): increase coverage of db_api (#231)
* perf: increase coverage of db_api

* fix: lint

* fix: added missing unit tests
HemangChothani committed Mar 10, 2021
1 parent a2b53a3 commit 489ac0a
Showing 3 changed files with 193 additions and 0 deletions.
152 changes: 152 additions & 0 deletions tests/unit/spanner_dbapi/test_connection.py
@@ -183,6 +183,10 @@ def test_close(self):
mock_transaction.rollback = mock_rollback = mock.MagicMock()
connection.close()
mock_rollback.assert_called_once_with()
connection._transaction = mock.MagicMock()
connection._own_pool = False
connection.close()
self.assertTrue(connection.is_closed)

@mock.patch.object(warnings, "warn")
def test_commit(self, mock_warn):
@@ -379,6 +383,25 @@ def test_run_statement_dont_remember_retried_statements(self):

self.assertEqual(len(connection._statements), 0)

def test_run_statement_w_heterogenous_insert_statements(self):
"""Check that Connection executed heterogenous insert statements."""
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.cursor import Statement

sql = "INSERT INTO T (f1, f2) VALUES (1, 2)"
params = None
param_types = None

connection = self._make_connection()

statement = Statement(sql, params, param_types, ResultsChecksum(), True)
with mock.patch(
"google.cloud.spanner_dbapi.connection.Connection.transaction_checkout"
):
connection.run_statement(statement, retried=True)

self.assertEqual(len(connection._statements), 0)

def test_run_statement_w_homogeneous_insert_statements(self):
"""Check that Connection executed homogeneous insert statements."""
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
@@ -582,3 +605,132 @@ def test_retry_aborted_retry(self):
mock.call(statement, retried=True),
)
)

def test_retry_transaction_raise_max_internal_retries(self):
"""Check retrying raise an error of max internal retries."""
from google.cloud.spanner_dbapi import connection as conn
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.cursor import Statement

conn.MAX_INTERNAL_RETRIES = 0
row = ["field1", "field2"]
connection = self._make_connection()

checksum = ResultsChecksum()
checksum.consume_result(row)

statement = Statement("SELECT 1", [], {}, checksum, False)
connection._statements.append(statement)

with self.assertRaises(Exception):
connection.retry_transaction()

conn.MAX_INTERNAL_RETRIES = 50

def test_retry_aborted_retry_without_delay(self):
"""
Check that in case of a retried transaction failed,
the connection will retry it once again.
"""
from google.api_core.exceptions import Aborted
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.connection import connect
from google.cloud.spanner_dbapi.cursor import Statement

row = ["field1", "field2"]

with mock.patch(
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True,
):
with mock.patch(
"google.cloud.spanner_v1.database.Database.exists", return_value=True,
):
connection = connect("test-instance", "test-database")

cursor = connection.cursor()
cursor._checksum = ResultsChecksum()
cursor._checksum.consume_result(row)

statement = Statement("SELECT 1", [], {}, cursor._checksum, False)
connection._statements.append(statement)

metadata_mock = mock.Mock()
metadata_mock.trailing_metadata.return_value = {}

with mock.patch(
"google.cloud.spanner_dbapi.connection.Connection.run_statement",
side_effect=(
Aborted("Aborted", errors=[metadata_mock]),
([row], ResultsChecksum()),
),
) as retry_mock:
with mock.patch(
"google.cloud.spanner_dbapi.connection._get_retry_delay",
return_value=False,
):
connection.retry_transaction()

retry_mock.assert_has_calls(
(
mock.call(statement, retried=True),
mock.call(statement, retried=True),
)
)

def test_retry_transaction_w_multiple_statement(self):
"""Check retrying an aborted transaction."""
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.cursor import Statement

row = ["field1", "field2"]
connection = self._make_connection()

checksum = ResultsChecksum()
checksum.consume_result(row)
retried_checksum = ResultsChecksum()

statement = Statement("SELECT 1", [], {}, checksum, False)
statement1 = Statement("SELECT 2", [], {}, checksum, False)
connection._statements.append(statement)
connection._statements.append(statement1)

with mock.patch(
"google.cloud.spanner_dbapi.connection.Connection.run_statement",
return_value=([row], retried_checksum),
) as run_mock:
with mock.patch(
"google.cloud.spanner_dbapi.connection._compare_checksums"
) as compare_mock:
connection.retry_transaction()

compare_mock.assert_called_with(checksum, retried_checksum)

run_mock.assert_called_with(statement1, retried=True)

def test_retry_transaction_w_empty_response(self):
"""Check retrying an aborted transaction."""
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.cursor import Statement

row = []
connection = self._make_connection()

checksum = ResultsChecksum()
checksum.count = 1
retried_checksum = ResultsChecksum()

statement = Statement("SELECT 1", [], {}, checksum, False)
connection._statements.append(statement)

with mock.patch(
"google.cloud.spanner_dbapi.connection.Connection.run_statement",
return_value=(row, retried_checksum),
) as run_mock:
with mock.patch(
"google.cloud.spanner_dbapi.connection._compare_checksums"
) as compare_mock:
connection.retry_transaction()

compare_mock.assert_called_with(checksum, retried_checksum)

run_mock.assert_called_with(statement, retried=True)
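
Taken together, the retry tests above pin down the contract of Connection.retry_transaction: every remembered statement is re-run with retried=True, the retried checksum is compared against the original, an Aborted response triggers another pass (optionally after a delay from _get_retry_delay), and exceeding MAX_INTERNAL_RETRIES raises. The snippet below is a rough, hypothetical sketch of such a loop, inferred only from what these tests patch and assert; the control flow, the Statement field name, and the _get_retry_delay argument shape are assumptions, not the library's actual implementation.

import time

from google.api_core.exceptions import Aborted
from google.cloud.spanner_dbapi import connection as conn


def retry_transaction_sketch(connection):
    """Hypothetical retry loop mirroring the behaviour asserted above."""
    attempt = 0
    while True:
        attempt += 1
        if attempt > conn.MAX_INTERNAL_RETRIES:
            # test_retry_transaction_raise_max_internal_retries expects an error here
            raise Exception("Transaction was aborted more times than allowed.")
        try:
            for statement in connection._statements:
                # retried=True keeps the statement from being remembered a second time
                rows, retried_checksum = connection.run_statement(statement, retried=True)
                for row in rows:
                    retried_checksum.consume_result(row)
                # Raises if the retried results diverge from the original run;
                # mocked out as _compare_checksums in the tests above.
                # (The "checksum" field name on Statement is an assumption.)
                conn._compare_checksums(statement.checksum, retried_checksum)
        except Aborted as exc:
            # Argument shape of _get_retry_delay is assumed; the tests only
            # verify that a falsy delay skips sleeping.
            delay = conn._get_retry_delay(exc.errors[0], attempt)
            if delay:
                time.sleep(delay)
            continue
        return
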
25 changes: 25 additions & 0 deletions tests/unit/spanner_dbapi/test_cursor.py
@@ -140,6 +140,31 @@ def test_execute_autocommit_off(self):
self.assertIsInstance(cursor._result_set, mock.MagicMock)
self.assertIsInstance(cursor._itr, PeekIterator)

def test_execute_insert_statement_autocommit_off(self):
from google.cloud.spanner_dbapi import parse_utils
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.utils import PeekIterator

connection = self._make_connection(self.INSTANCE, mock.MagicMock())
cursor = self._make_one(connection)
cursor.connection._autocommit = False
cursor.connection.transaction_checkout = mock.MagicMock(autospec=True)

cursor._checksum = ResultsChecksum()
with mock.patch(
"google.cloud.spanner_dbapi.parse_utils.classify_stmt",
return_value=parse_utils.STMT_INSERT,
):
with mock.patch(
"google.cloud.spanner_dbapi.connection.Connection.run_statement",
return_value=(mock.MagicMock(), ResultsChecksum()),
):
cursor.execute(
sql="INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)"
)
self.assertIsInstance(cursor._result_set, mock.MagicMock)
self.assertIsInstance(cursor._itr, PeekIterator)

def test_execute_statement(self):
from google.cloud.spanner_dbapi import parse_utils

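
The new cursor test drives the non-autocommit INSERT path: classify_stmt tags the SQL as STMT_INSERT and the cursor hands it to Connection.run_statement inside the checked-out transaction. Below is a minimal usage sketch of that path; the instance and database names and the parameter values are placeholders, connection.autocommit is assumed to be a settable property (the test flips the private _autocommit flag directly), and a real Spanner backend would be required, whereas the unit test mocks Instance.exists and Database.exists.

from google.cloud.spanner_dbapi.connection import connect

# Placeholder instance/database names, matching the ones used in the tests.
connection = connect("test-instance", "test-database")
connection.autocommit = False  # assumed public setter; buffer INSERTs in one transaction

cursor = connection.cursor()
cursor.execute(
    "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)",
    ("auth", "0001_initial", "2021-03-10T00:00:00Z"),  # placeholder values
)

connection.commit()  # flush the buffered transaction
connection.close()
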
16 changes: 16 additions & 0 deletions tests/unit/spanner_dbapi/test_utils.py
@@ -85,3 +85,19 @@ def test_backtick_unicode(self):
with self.subTest(sql=sql):
got = backtick_unicode(sql)
self.assertEqual(got, want)

@unittest.skipIf(skip_condition, skip_message)
def test_StreamedManyResultSets(self):
from google.cloud.spanner_dbapi.utils import StreamedManyResultSets

cases = [
("iter_from_list", iter([1, 2, 3, 4, 6, 7]), [1, 2, 3, 4, 6, 7]),
("iter_from_tuple", iter(("a", 12, 0xFF)), ["a", 12, 0xFF]),
]

for name, data_in, expected in cases:
with self.subTest(name=name):
stream_result = StreamedManyResultSets()
stream_result._iterators.append(data_in)
actual = list(stream_result)
self.assertEqual(actual, expected)
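
The test documents the iteration contract of StreamedManyResultSets: appended iterators are drained in insertion order, like a chained iterator. The snippet below is a hypothetical stand-in written only to illustrate that contract; the real class in google.cloud.spanner_dbapi.utils wraps Spanner streamed result sets rather than bare Python iterators.

import itertools


class ChainedResultSets:
    """Hypothetical stand-in mirroring only the behaviour asserted above."""

    def __init__(self):
        self._iterators = []

    def add_iter(self, iterator):
        self._iterators.append(iterator)

    def __iter__(self):
        # Drain each iterator in insertion order, just as the test expects.
        return itertools.chain.from_iterable(self._iterators)


streams = ChainedResultSets()
streams.add_iter(iter([1, 2, 3, 4, 6, 7]))
streams.add_iter(iter(("a", 12, 0xFF)))
assert list(streams) == [1, 2, 3, 4, 6, 7, "a", 12, 0xFF]
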
