From adeacee3cc07260fa9fcd496b3187402f02bf157 Mon Sep 17 00:00:00 2001 From: larkee <31196561+larkee@users.noreply.github.com> Date: Tue, 24 Mar 2020 17:26:10 +1300 Subject: [PATCH] =?UTF-8?q?fix:=20increment=20seqno=20before=20execute=20c?= =?UTF-8?q?alls=20to=20prevent=20InvalidArgument=20=E2=80=A6=20(#19)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: increment seqno before execute calls to prevent InvalidArgument errors after a previous error * make assignments atomic * add and update tests * revert snapshot.py change * formatting Co-authored-by: larkee --- google/cloud/spanner_v1/transaction.py | 18 ++++++++---- tests/unit/test_snapshot.py | 2 ++ tests/unit/test_transaction.py | 38 ++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 6 deletions(-) diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 5a161fd8a6..27c260212e 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -201,6 +201,11 @@ def execute_update( transaction = self._make_txn_selector() api = database.spanner_api + seqno, self._execute_sql_count = ( + self._execute_sql_count, + self._execute_sql_count + 1, + ) + # Query-level options have higher precedence than client-level and # environment-level options default_query_options = database._instance._client._query_options @@ -214,11 +219,9 @@ def execute_update( param_types=param_types, query_mode=query_mode, query_options=query_options, - seqno=self._execute_sql_count, + seqno=seqno, metadata=metadata, ) - - self._execute_sql_count += 1 return response.stats.row_count_exact def batch_update(self, statements): @@ -259,15 +262,18 @@ def batch_update(self, statements): transaction = self._make_txn_selector() api = database.spanner_api + seqno, self._execute_sql_count = ( + self._execute_sql_count, + self._execute_sql_count + 1, + ) + response = api.execute_batch_dml( session=self._session.name, transaction=transaction, statements=parsed, - seqno=self._execute_sql_count, + seqno=seqno, metadata=metadata, ) - - self._execute_sql_count += 1 row_counts = [ result_set.stats.row_count_exact for result_set in response.result_sets ] diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index e29b19d5f1..40ba1c6c5a 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -311,6 +311,8 @@ def test_execute_sql_other_error(self): with self.assertRaises(RuntimeError): list(derived.execute_sql(SQL_QUERY)) + self.assertEqual(derived._execute_sql_count, 1) + def test_execute_sql_w_params_wo_param_types(self): database = _Database() session = _Session(database) diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index dcb6cb95d3..6ae24aedab 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -413,6 +413,19 @@ def test_execute_update_new_transaction(self): def test_execute_update_w_count(self): self._execute_update_helper(count=1) + def test_execute_update_error(self): + database = _Database() + database.spanner_api = self._make_spanner_api() + database.spanner_api.execute_sql.side_effect = RuntimeError() + session = _Session(database) + transaction = self._make_one(session) + transaction._transaction_id = self.TRANSACTION_ID + + with self.assertRaises(RuntimeError): + transaction.execute_update(DML_QUERY) + + self.assertEqual(transaction._execute_sql_count, 1) + def test_execute_update_w_query_options(self): from google.cloud.spanner_v1.proto.spanner_pb2 import ExecuteSqlRequest @@ -513,6 +526,31 @@ def test_batch_update_wo_errors(self): def test_batch_update_w_errors(self): self._batch_update_helper(error_after=2, count=1) + def test_batch_update_error(self): + database = _Database() + api = database.spanner_api = self._make_spanner_api() + api.execute_batch_dml.side_effect = RuntimeError() + session = _Session(database) + transaction = self._make_one(session) + transaction._transaction_id = self.TRANSACTION_ID + + insert_dml = "INSERT INTO table(pkey, desc) VALUES (%pkey, %desc)" + insert_params = {"pkey": 12345, "desc": "DESCRIPTION"} + insert_param_types = {"pkey": "INT64", "desc": "STRING"} + update_dml = 'UPDATE table SET desc = desc + "-amended"' + delete_dml = "DELETE FROM table WHERE desc IS NULL" + + dml_statements = [ + (insert_dml, insert_params, insert_param_types), + update_dml, + delete_dml, + ] + + with self.assertRaises(RuntimeError): + transaction.batch_update(dml_statements) + + self.assertEqual(transaction._execute_sql_count, 1) + def test_context_mgr_success(self): import datetime from google.cloud.spanner_v1.proto.spanner_pb2 import CommitResponse