diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py
index 926408c928..110e0f9b9b 100644
--- a/google/cloud/spanner_dbapi/connection.py
+++ b/google/cloud/spanner_dbapi/connection.py
@@ -32,6 +32,8 @@
 from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT
 from google.cloud.spanner_dbapi.version import PY_VERSION
 
+from google.rpc.code_pb2 import ABORTED
+
 AUTOCOMMIT_MODE_WARNING = "This method is non-operational in autocommit mode"
 
 MAX_INTERNAL_RETRIES = 50
@@ -175,25 +177,41 @@ def _rerun_previous_statements(self):
         from the last transaction.
         """
         for statement in self._statements:
-            res_iter, retried_checksum = self.run_statement(statement, retried=True)
-            # executing all the completed statements
-            if statement != self._statements[-1]:
-                for res in res_iter:
-                    retried_checksum.consume_result(res)
-
-                _compare_checksums(statement.checksum, retried_checksum)
-            # executing the failed statement
+            if isinstance(statement, list):
+                statements, checksum = statement
+
+                transaction = self.transaction_checkout()
+                status, res = transaction.batch_update(statements)
+
+                if status.code == ABORTED:
+                    self.connection._transaction = None
+                    raise Aborted(status.details)
+
+                retried_checksum = ResultsChecksum()
+                retried_checksum.consume_result(res)
+                retried_checksum.consume_result(status.code)
+
+                _compare_checksums(checksum, retried_checksum)
             else:
-                # streaming up to the failed result or
-                # to the end of the streaming iterator
-                while len(retried_checksum) < len(statement.checksum):
-                    try:
-                        res = next(iter(res_iter))
+                res_iter, retried_checksum = self.run_statement(statement, retried=True)
+                # executing all the completed statements
+                if statement != self._statements[-1]:
+                    for res in res_iter:
                         retried_checksum.consume_result(res)
-                    except StopIteration:
-                        break
-
-                _compare_checksums(statement.checksum, retried_checksum)
+
+                    _compare_checksums(statement.checksum, retried_checksum)
+                # executing the failed statement
+                else:
+                    # streaming up to the failed result or
+                    # to the end of the streaming iterator
+                    while len(retried_checksum) < len(statement.checksum):
+                        try:
+                            res = next(iter(res_iter))
+                            retried_checksum.consume_result(res)
+                        except StopIteration:
+                            break
+
+                    _compare_checksums(statement.checksum, retried_checksum)
 
     def transaction_checkout(self):
         """Get a Cloud Spanner transaction.
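Note on the retry path above: a batch executed through `executemany()` is remembered in `connection._statements` as a plain two-item list `(statements, checksum)` rather than a `Statement` namedtuple, which is why `_rerun_previous_statements()` now branches on `isinstance(statement, list)`. A minimal sketch of how such an entry is replayed and verified, assuming `Aborted` here is `google.api_core.exceptions.Aborted` (the class this module already imports) and that `_compare_checksums()` raises when the checksums diverge:

```python
# Sketch only: mirrors the list branch added to Connection._rerun_previous_statements().
from google.api_core.exceptions import Aborted
from google.rpc.code_pb2 import ABORTED

from google.cloud.spanner_dbapi.checksum import ResultsChecksum, _compare_checksums


def replay_batch(transaction, statements, original_checksum):
    """Re-run a remembered DML batch and verify its results are unchanged."""
    status, row_counts = transaction.batch_update(statements)

    if status.code == ABORTED:
        # the replay itself was aborted; the caller restarts the retry loop
        raise Aborted(status.details)

    retried_checksum = ResultsChecksum()
    retried_checksum.consume_result(row_counts)
    retried_checksum.consume_result(status.code)

    # raises if the underlying data changed between the two attempts
    _compare_checksums(original_checksum, retried_checksum)
```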
diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py
index c5de13b370..dccbf04dc8 100644
--- a/google/cloud/spanner_dbapi/cursor.py
+++ b/google/cloud/spanner_dbapi/cursor.py
@@ -41,6 +41,8 @@
 from google.cloud.spanner_dbapi.utils import PeekIterator
 from google.cloud.spanner_dbapi.utils import StreamedManyResultSets
 
+from google.rpc.code_pb2 import ABORTED, OK
+
 _UNSET_COUNT = -1
 
 ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"])
@@ -156,6 +158,15 @@ def _do_execute_update(self, transaction, sql, params):
 
         return result
 
+    def _do_batch_update(self, transaction, statements, many_result_set):
+        status, res = transaction.batch_update(statements)
+        many_result_set.add_iter(res)
+
+        if status.code == ABORTED:
+            raise Aborted(status.details)
+        elif status.code != OK:
+            raise OperationalError(status.details)
+
     def execute(self, sql, args=None):
         """Prepares and executes a Spanner database operation.
 
@@ -258,9 +269,50 @@ def executemany(self, operation, seq_of_params):
 
         many_result_set = StreamedManyResultSets()
 
-        for params in seq_of_params:
-            self.execute(operation, params)
-            many_result_set.add_iter(self._itr)
+        if classification in (parse_utils.STMT_INSERT, parse_utils.STMT_UPDATING):
+            statements = []
+
+            for params in seq_of_params:
+                sql, params = parse_utils.sql_pyformat_args_to_spanner(
+                    operation, params
+                )
+                statements.append((sql, params, get_param_types(params)))
+
+            if self.connection.autocommit:
+                self.connection.database.run_in_transaction(
+                    self._do_batch_update, statements, many_result_set
+                )
+            else:
+                retried = False
+                while True:
+                    try:
+                        transaction = self.connection.transaction_checkout()
+
+                        res_checksum = ResultsChecksum()
+                        if not retried:
+                            self.connection._statements.append(
+                                (statements, res_checksum)
+                            )
+
+                        status, res = transaction.batch_update(statements)
+                        many_result_set.add_iter(res)
+                        res_checksum.consume_result(res)
+                        res_checksum.consume_result(status.code)
+
+                        if status.code == ABORTED:
+                            self.connection._transaction = None
+                            raise Aborted(status.details)
+                        elif status.code != OK:
+                            raise OperationalError(status.details)
+                        break
+                    except Aborted:
+                        self.connection.retry_transaction()
+                        retried = True
+
+        else:
+            for params in seq_of_params:
+                self.execute(operation, params)
+                many_result_set.add_iter(self._itr)
 
         self._result_set = many_result_set
         self._itr = many_result_set
diff --git a/tests/system/test_system_dbapi.py b/tests/system/test_system_dbapi.py
index 6ca1029ae1..28636a561c 100644
--- a/tests/system/test_system_dbapi.py
+++ b/tests/system/test_system_dbapi.py
@@ -343,20 +343,20 @@ def test_execute_many(self):
         conn = Connection(Config.INSTANCE, self._db)
         cursor = conn.cursor()
 
-        cursor.execute(
+        cursor.executemany(
             """
 INSERT INTO contacts (contact_id, first_name, last_name, email)
-VALUES (1, 'first-name', 'last-name', 'test.email@example.com'),
-       (2, 'first-name2', 'last-name2', 'test.email2@example.com')
-        """
+VALUES (%s, %s, %s, %s)
+        """,
+            [
+                (1, "first-name", "last-name", "test.email@example.com"),
+                (2, "first-name2", "last-name2", "test.email2@example.com"),
+            ],
         )
         conn.commit()
 
         cursor.executemany(
-            """
-SELECT * FROM contacts WHERE contact_id = @a1
-""",
-            ({"a1": 1}, {"a1": 2}),
+            """SELECT * FROM contacts WHERE contact_id = @a1""", ({"a1": 1}, {"a1": 2}),
         )
         res = cursor.fetchall()
         conn.commit()
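As the system test above illustrates, callers batch homogeneous DML by passing one parameterized statement plus a sequence of parameter tuples; `executemany()` now converts each row with `parse_utils.sql_pyformat_args_to_spanner()` and sends the whole set as a single Batch DML call instead of one `execute()` round trip per row. A hedged usage sketch (the instance, database, and table names are placeholders):

```python
from google.cloud.spanner_dbapi import connect

# Placeholder instance/database IDs; any reachable Spanner database works.
connection = connect("my-instance", "my-database")
cursor = connection.cursor()

# Both rows travel in one transaction.batch_update() call.
cursor.executemany(
    "INSERT INTO contacts (contact_id, first_name, last_name, email) "
    "VALUES (%s, %s, %s, %s)",
    [
        (1, "first-name", "last-name", "test.email@example.com"),
        (2, "first-name2", "last-name2", "test.email2@example.com"),
    ],
)
connection.commit()
```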
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True, + ): + connection = connect("test-instance", "test-database") + + connection.autocommit = True + transaction = self._transaction_mock() + cursor = connection.cursor() + + with mock.patch( + "google.cloud.spanner_v1.services.spanner.client.SpannerClient.create_session", + return_value=Session(), + ): + with mock.patch( + "google.cloud.spanner_v1.session.Session.transaction", + return_value=transaction, + ): + cursor.executemany(sql, [(1,), (2,), (3,)]) + + transaction.batch_update.assert_called_once_with( + [ + ("DELETE FROM table WHERE col1 = @a0", {"a0": 1}, {"a0": INT64}), + ("DELETE FROM table WHERE col1 = @a0", {"a0": 2}, {"a0": INT64}), + ("DELETE FROM table WHERE col1 = @a0", {"a0": 3}, {"a0": INT64}), + ] + ) + + def test_executemany_update_batch_autocommit(self): + from google.cloud.spanner_dbapi import connect + from google.cloud.spanner_v1.param_types import INT64, STRING + from google.cloud.spanner_v1.types.spanner import Session + + sql = "UPDATE table SET col1 = %s WHERE col2 = %s" + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True, + ): + connection = connect("test-instance", "test-database") + + connection.autocommit = True + transaction = self._transaction_mock() + cursor = connection.cursor() + + with mock.patch( + "google.cloud.spanner_v1.services.spanner.client.SpannerClient.create_session", + return_value=Session(), + ): + with mock.patch( + "google.cloud.spanner_v1.session.Session.transaction", + return_value=transaction, + ): + cursor.executemany(sql, [(1, "a"), (2, "b"), (3, "c")]) + + transaction.batch_update.assert_called_once_with( + [ + ( + "UPDATE table SET col1 = @a0 WHERE col2 = @a1", + {"a0": 1, "a1": "a"}, + {"a0": INT64, "a1": STRING}, + ), + ( + "UPDATE table SET col1 = @a0 WHERE col2 = @a1", + {"a0": 2, "a1": "b"}, + {"a0": INT64, "a1": STRING}, + ), + ( + "UPDATE table SET col1 = @a0 WHERE col2 = @a1", + {"a0": 3, "a1": "c"}, + {"a0": INT64, "a1": STRING}, + ), + ] + ) + + def test_executemany_insert_batch_non_autocommit(self): + from google.cloud.spanner_dbapi import connect + from google.cloud.spanner_v1.param_types import INT64 + from google.cloud.spanner_v1.types.spanner import Session + + sql = """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (%s, %s, %s, %s)""" + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True, + ): + connection = connect("test-instance", "test-database") + + transaction = self._transaction_mock() + + cursor = connection.cursor() + with mock.patch( + "google.cloud.spanner_v1.services.spanner.client.SpannerClient.create_session", + return_value=Session(), + ): + with mock.patch( + "google.cloud.spanner_v1.session.Session.transaction", + return_value=transaction, + ): + cursor.executemany(sql, [(1, 2, 3, 4), (5, 6, 7, 8)]) + + transaction.batch_update.assert_called_once_with( + [ + ( + """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (@a0, @a1, @a2, @a3)""", + {"a0": 1, "a1": 2, "a2": 3, "a3": 4}, + {"a0": INT64, "a1": INT64, "a2": INT64, "a3": INT64}, + ), + ( + """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (@a0, @a1, @a2, @a3)""", + {"a0": 5, "a1": 6, "a2": 
+                    {"a0": INT64, "a1": INT64, "a2": INT64, "a3": INT64},
+                ),
+            ]
+        )
+
+    def test_executemany_insert_batch_autocommit(self):
+        from google.cloud.spanner_dbapi import connect
+        from google.cloud.spanner_v1.param_types import INT64
+        from google.cloud.spanner_v1.types.spanner import Session
+
+        sql = """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (%s, %s, %s, %s)"""
+
+        with mock.patch(
+            "google.cloud.spanner_v1.instance.Instance.exists", return_value=True
+        ):
+            with mock.patch(
+                "google.cloud.spanner_v1.database.Database.exists", return_value=True,
+            ):
+                connection = connect("test-instance", "test-database")
+
+        connection.autocommit = True
+
+        transaction = self._transaction_mock()
+        transaction.commit = mock.Mock()
+
+        cursor = connection.cursor()
+        with mock.patch(
+            "google.cloud.spanner_v1.services.spanner.client.SpannerClient.create_session",
+            return_value=Session(),
+        ):
+            with mock.patch(
+                "google.cloud.spanner_v1.session.Session.transaction",
+                return_value=transaction,
+            ):
+                cursor.executemany(sql, [(1, 2, 3, 4), (5, 6, 7, 8)])
+
+        transaction.batch_update.assert_called_once_with(
+            [
+                (
+                    """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (@a0, @a1, @a2, @a3)""",
+                    {"a0": 1, "a1": 2, "a2": 3, "a3": 4},
+                    {"a0": INT64, "a1": INT64, "a2": INT64, "a3": INT64},
+                ),
+                (
+                    """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (@a0, @a1, @a2, @a3)""",
+                    {"a0": 5, "a1": 6, "a2": 7, "a3": 8},
+                    {"a0": INT64, "a1": INT64, "a2": INT64, "a3": INT64},
+                ),
+            ]
+        )
+        transaction.commit.assert_called_once()
+
+    def test_executemany_insert_batch_failed(self):
+        from google.cloud.spanner_dbapi import connect
+        from google.cloud.spanner_dbapi.exceptions import OperationalError
+        from google.cloud.spanner_v1.types.spanner import Session
+        from google.rpc.code_pb2 import UNKNOWN
+
+        sql = """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (%s, %s, %s, %s)"""
+        err_details = "Details here"
+
+        with mock.patch(
+            "google.cloud.spanner_v1.instance.Instance.exists", return_value=True
+        ):
+            with mock.patch(
+                "google.cloud.spanner_v1.database.Database.exists", return_value=True,
+            ):
+                connection = connect("test-instance", "test-database")
+
+        connection.autocommit = True
+        cursor = connection.cursor()
+
+        transaction = mock.Mock(committed=False, rolled_back=False)
+        transaction.batch_update = mock.Mock(
+            return_value=(mock.Mock(code=UNKNOWN, details=err_details), [])
+        )
+
+        with mock.patch(
+            "google.cloud.spanner_v1.services.spanner.client.SpannerClient.create_session",
+            return_value=Session(),
+        ):
+            with mock.patch(
+                "google.cloud.spanner_v1.session.Session.transaction",
+                return_value=transaction,
+            ):
+                with self.assertRaisesRegex(OperationalError, err_details):
+                    cursor.executemany(sql, [(1, 2, 3, 4), (5, 6, 7, 8)])
+
+    def test_executemany_insert_batch_aborted(self):
+        from google.cloud.spanner_dbapi import connect
+        from google.cloud.spanner_dbapi.checksum import ResultsChecksum
+        from google.cloud.spanner_v1.param_types import INT64
+        from google.rpc.code_pb2 import ABORTED
+
+        sql = """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (%s, %s, %s, %s)"""
+        err_details = "Aborted details here"
+
+        with mock.patch(
+            "google.cloud.spanner_v1.instance.Instance.exists", return_value=True
+        ):
+            with mock.patch(
+                "google.cloud.spanner_v1.database.Database.exists", return_value=True,
+            ):
+                connection = connect("test-instance", "test-database")
+
+        transaction1 = mock.Mock(committed=False, rolled_back=False)
+        transaction1.batch_update = mock.Mock(
+            side_effect=[(mock.Mock(code=ABORTED, details=err_details), [])]
+        )
+
+        transaction2 = self._transaction_mock()
+
+        connection.transaction_checkout = mock.Mock(
+            side_effect=[transaction1, transaction2]
+        )
+        connection.retry_transaction = mock.Mock()
+
+        cursor = connection.cursor()
+        cursor.executemany(sql, [(1, 2, 3, 4), (5, 6, 7, 8)])
+
+        transaction1.batch_update.assert_called_with(
+            [
+                (
+                    """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (@a0, @a1, @a2, @a3)""",
+                    {"a0": 1, "a1": 2, "a2": 3, "a3": 4},
+                    {"a0": INT64, "a1": INT64, "a2": INT64, "a3": INT64},
+                ),
+                (
+                    """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (@a0, @a1, @a2, @a3)""",
+                    {"a0": 5, "a1": 6, "a2": 7, "a3": 8},
+                    {"a0": INT64, "a1": INT64, "a2": INT64, "a3": INT64},
+                ),
+            ]
+        )
+        transaction2.batch_update.assert_called_with(
+            [
+                (
+                    """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (@a0, @a1, @a2, @a3)""",
+                    {"a0": 1, "a1": 2, "a2": 3, "a3": 4},
+                    {"a0": INT64, "a1": INT64, "a2": INT64, "a3": INT64},
+                ),
+                (
+                    """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (@a0, @a1, @a2, @a3)""",
+                    {"a0": 5, "a1": 6, "a2": 7, "a3": 8},
+                    {"a0": INT64, "a1": INT64, "a2": INT64, "a3": INT64},
+                ),
+            ]
+        )
+        connection.retry_transaction.assert_called_once()
+
+        self.assertEqual(
+            connection._statements[0][0],
+            [
+                (
+                    """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (@a0, @a1, @a2, @a3)""",
+                    {"a0": 1, "a1": 2, "a2": 3, "a3": 4},
+                    {"a0": INT64, "a1": INT64, "a2": INT64, "a3": INT64},
+                ),
+                (
+                    """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (@a0, @a1, @a2, @a3)""",
+                    {"a0": 5, "a1": 6, "a2": 7, "a3": 8},
+                    {"a0": INT64, "a1": INT64, "a2": INT64, "a3": INT64},
+                ),
+            ],
+        )
+        self.assertIsInstance(connection._statements[0][1], ResultsChecksum)
+
     @unittest.skipIf(
         sys.version_info[0] < 3, "Python 2 has an outdated iterator definition"
     )
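The unit tests above also cover the autocommit path: when `connection.autocommit` is set, the batch is wrapped in `database.run_in_transaction(self._do_batch_update, ...)`, so Spanner-side aborts are retried by the client library and a non-OK status surfaces as `OperationalError`. A short sketch of that mode, again with placeholder IDs; non-DML statements (e.g. SELECT) still fall back to the per-row `execute()` loop:

```python
from google.cloud.spanner_dbapi import connect

# Placeholder instance/database IDs.
connection = connect("my-instance", "my-database")
connection.autocommit = True  # batch runs inside database.run_in_transaction()

cursor = connection.cursor()
cursor.executemany(
    "DELETE FROM contacts WHERE contact_id = %s",
    [(1,), (2,), (3,)],
)
```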