Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: increment seqno before execute calls to prevent InvalidArgument … #19

Merged
merged 9 commits into from Mar 24, 2020
18 changes: 12 additions & 6 deletions google/cloud/spanner_v1/transaction.py
Expand Up @@ -201,6 +201,11 @@ def execute_update(
transaction = self._make_txn_selector()
api = database.spanner_api

seqno, self._execute_sql_count = (
self._execute_sql_count,
self._execute_sql_count + 1,
)

# Query-level options have higher precedence than client-level and
# environment-level options
default_query_options = database._instance._client._query_options
Expand All @@ -214,11 +219,9 @@ def execute_update(
param_types=param_types,
query_mode=query_mode,
query_options=query_options,
seqno=self._execute_sql_count,
seqno=seqno,
metadata=metadata,
)

self._execute_sql_count += 1
return response.stats.row_count_exact

def batch_update(self, statements):
Expand Down Expand Up @@ -259,15 +262,18 @@ def batch_update(self, statements):
transaction = self._make_txn_selector()
api = database.spanner_api

seqno, self._execute_sql_count = (
self._execute_sql_count,
self._execute_sql_count + 1,
)

response = api.execute_batch_dml(
session=self._session.name,
transaction=transaction,
statements=parsed,
seqno=self._execute_sql_count,
seqno=seqno,
metadata=metadata,
)

self._execute_sql_count += 1
row_counts = [
result_set.stats.row_count_exact for result_set in response.result_sets
]
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_snapshot.py
Expand Up @@ -311,6 +311,8 @@ def test_execute_sql_other_error(self):
with self.assertRaises(RuntimeError):
list(derived.execute_sql(SQL_QUERY))

self.assertEqual(derived._execute_sql_count, 1)

def test_execute_sql_w_params_wo_param_types(self):
database = _Database()
session = _Session(database)
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/test_transaction.py
Expand Up @@ -413,6 +413,19 @@ def test_execute_update_new_transaction(self):
def test_execute_update_w_count(self):
self._execute_update_helper(count=1)

def test_execute_update_error(self):
database = _Database()
database.spanner_api = self._make_spanner_api()
database.spanner_api.execute_sql.side_effect = RuntimeError()
session = _Session(database)
transaction = self._make_one(session)
transaction._transaction_id = self.TRANSACTION_ID

with self.assertRaises(RuntimeError):
transaction.execute_update(DML_QUERY)

self.assertEqual(transaction._execute_sql_count, 1)

def test_execute_update_w_query_options(self):
from google.cloud.spanner_v1.proto.spanner_pb2 import ExecuteSqlRequest

Expand Down Expand Up @@ -513,6 +526,31 @@ def test_batch_update_wo_errors(self):
def test_batch_update_w_errors(self):
self._batch_update_helper(error_after=2, count=1)

def test_batch_update_error(self):
database = _Database()
api = database.spanner_api = self._make_spanner_api()
api.execute_batch_dml.side_effect = RuntimeError()
session = _Session(database)
transaction = self._make_one(session)
transaction._transaction_id = self.TRANSACTION_ID

insert_dml = "INSERT INTO table(pkey, desc) VALUES (%pkey, %desc)"
insert_params = {"pkey": 12345, "desc": "DESCRIPTION"}
insert_param_types = {"pkey": "INT64", "desc": "STRING"}
update_dml = 'UPDATE table SET desc = desc + "-amended"'
delete_dml = "DELETE FROM table WHERE desc IS NULL"

dml_statements = [
(insert_dml, insert_params, insert_param_types),
update_dml,
delete_dml,
]

with self.assertRaises(RuntimeError):
transaction.batch_update(dml_statements)

self.assertEqual(transaction._execute_sql_count, 1)

def test_context_mgr_success(self):
import datetime
from google.cloud.spanner_v1.proto.spanner_pb2 import CommitResponse
Expand Down