From 38abc343dcfe500671834828ff389843cee7019e Mon Sep 17 00:00:00 2001 From: larkee Date: Tue, 18 Feb 2020 16:31:59 +1100 Subject: [PATCH 1/5] fix: increment seqno before execute calls to prevent InvalidArgument errors after a previous error --- google/cloud/spanner_v1/snapshot.py | 5 ++--- google/cloud/spanner_v1/transaction.py | 14 ++++++++------ 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index ec7008fb75..d52327bc94 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -225,12 +225,11 @@ def execute_sql( retry=retry, timeout=timeout, ) - - iterator = _restart_on_unavailable(restart) - self._read_request_count += 1 self._execute_sql_count += 1 + iterator = _restart_on_unavailable(restart) + if self._multi_use: return StreamedResultSet(iterator, source=self) else: diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 29a2e5f786..7f27e74988 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -191,6 +191,9 @@ def execute_update(self, dml, params=None, param_types=None, query_mode=None): transaction = self._make_txn_selector() api = database.spanner_api + seqno = self._execute_sql_count + self._execute_sql_count += 1 + response = api.execute_sql( self._session.name, dml, @@ -198,11 +201,9 @@ def execute_update(self, dml, params=None, param_types=None, query_mode=None): params=params_pb, param_types=param_types, query_mode=query_mode, - seqno=self._execute_sql_count, + seqno=seqno, metadata=metadata, ) - - self._execute_sql_count += 1 return response.stats.row_count_exact def batch_update(self, statements): @@ -243,15 +244,16 @@ def batch_update(self, statements): transaction = self._make_txn_selector() api = database.spanner_api + seqno = self._execute_sql_count + self._execute_sql_count += 1 + response = api.execute_batch_dml( session=self._session.name, transaction=transaction, statements=parsed, - seqno=self._execute_sql_count, + seqno=seqno, metadata=metadata, ) - - self._execute_sql_count += 1 row_counts = [ result_set.stats.row_count_exact for result_set in response.result_sets ] From fa3364c13c5106ce578cc53ad7ded2382e35ed2b Mon Sep 17 00:00:00 2001 From: larkee Date: Thu, 27 Feb 2020 10:37:41 +1100 Subject: [PATCH 2/5] make assignments atomic --- google/cloud/spanner_v1/transaction.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 22f8cad8b6..861a4c7baf 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -191,8 +191,10 @@ def execute_update(self, dml, params=None, param_types=None, query_mode=None): transaction = self._make_txn_selector() api = database.spanner_api - seqno = self._execute_sql_count - self._execute_sql_count += 1 + seqno, self._execute_sql_count = ( + self._execute_sql_count, + self._execute_sql_count + 1, + ) response = api.execute_sql( self._session.name, @@ -244,8 +246,10 @@ def batch_update(self, statements): transaction = self._make_txn_selector() api = database.spanner_api - seqno = self._execute_sql_count - self._execute_sql_count += 1 + seqno, self._execute_sql_count = ( + self._execute_sql_count, + self._execute_sql_count + 1, + ) response = api.execute_batch_dml( session=self._session.name, From 82588e4664546a068409f5e2fcefd225fe43ed76 Mon Sep 17 00:00:00 2001 From: larkee Date: Thu, 27 Feb 2020 10:50:49 +1100 Subject: [PATCH 3/5] add and update tests --- tests/unit/test_snapshot.py | 2 ++ tests/unit/test_transaction.py | 39 ++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 883ab73258..e202489942 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -311,6 +311,8 @@ def test_execute_sql_other_error(self): with self.assertRaises(RuntimeError): list(derived.execute_sql(SQL_QUERY)) + self.assertEqual(derived._execute_sql_count, 1) + def test_execute_sql_w_params_wo_param_types(self): database = _Database() session = _Session(database) diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index 9ef13c2ab6..e7c99031c4 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -399,6 +399,19 @@ def test_execute_update_new_transaction(self): def test_execute_update_w_count(self): self._execute_update_helper(count=1) + def test_execute_update_error(self): + database = _Database() + database.spanner_api = self._make_spanner_api() + database.spanner_api.execute_sql.side_effect = RuntimeError() + session = _Session(database) + transaction = self._make_one(session) + transaction._transaction_id = self.TRANSACTION_ID + + with self.assertRaises(RuntimeError): + transaction.execute_update(DML_QUERY) + + self.assertEqual(transaction._execute_sql_count, 1) + def test_batch_update_other_error(self): database = _Database() database.spanner_api = self._make_spanner_api() @@ -492,6 +505,32 @@ def test_batch_update_wo_errors(self): def test_batch_update_w_errors(self): self._batch_update_helper(error_after=2, count=1) + def test_batch_update_error(self): + database = _Database() + api = database.spanner_api = self._make_spanner_api() + api.execute_batch_dml.side_effect = RuntimeError() + session = _Session(database) + transaction = self._make_one(session) + transaction._transaction_id = self.TRANSACTION_ID + + insert_dml = "INSERT INTO table(pkey, desc) VALUES (%pkey, %desc)" + insert_params = {"pkey": 12345, "desc": "DESCRIPTION"} + insert_param_types = {"pkey": "INT64", "desc": "STRING"} + update_dml = 'UPDATE table SET desc = desc + "-amended"' + delete_dml = "DELETE FROM table WHERE desc IS NULL" + + dml_statements = [ + (insert_dml, insert_params, insert_param_types), + update_dml, + delete_dml, + ] + + with self.assertRaises(RuntimeError): + transaction.batch_update(dml_statements) + + self.assertEqual(transaction._execute_sql_count, 1) + + def test_context_mgr_success(self): import datetime from google.cloud.spanner_v1.proto.spanner_pb2 import CommitResponse From c67c8ba7e43ce5e2f9d90a694ca903d8933d58ee Mon Sep 17 00:00:00 2001 From: larkee Date: Thu, 27 Feb 2020 10:51:40 +1100 Subject: [PATCH 4/5] revert snapshot.py change --- google/cloud/spanner_v1/snapshot.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index d52327bc94..ec7008fb75 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -225,11 +225,12 @@ def execute_sql( retry=retry, timeout=timeout, ) - self._read_request_count += 1 - self._execute_sql_count += 1 iterator = _restart_on_unavailable(restart) + self._read_request_count += 1 + self._execute_sql_count += 1 + if self._multi_use: return StreamedResultSet(iterator, source=self) else: From 1bd89836fb0d8e246ed38e834f49bbeac865f6ad Mon Sep 17 00:00:00 2001 From: larkee Date: Thu, 27 Feb 2020 11:26:33 +1100 Subject: [PATCH 5/5] formatting --- tests/unit/test_transaction.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index e7c99031c4..1da588a860 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -530,7 +530,6 @@ def test_batch_update_error(self): self.assertEqual(transaction._execute_sql_count, 1) - def test_context_mgr_success(self): import datetime from google.cloud.spanner_v1.proto.spanner_pb2 import CommitResponse