Skip to content

Commit

Permalink
fix(dbapi): autocommit enabling fails if no transactions begun (#177)
Browse files Browse the repository at this point in the history
* fix(dbapi): autocommit enabling fails if no transactions begun

* remove unused import

* don't calculate checksums in autocommit mode

* try using dummy WHERE clause

* revert where clause

* unveil error

* fix where clauses

* add print

* don't log

* print failed exceptions

* don't print

* separate insert statements

* don't return

* re-run

* don't pyformat insert args

* args

* re-run

* fix

* fix error in transactions.tests.NonAutocommitTests.test_orm_query_without_autocommit

* fix "already committed" error

* fix for AttributeError: 'tuple' object has no attribute 'items'

* fix

* fix KeyError: 'type'

Co-authored-by: Chris Kleinknecht <libc@google.com>
Co-authored-by: larkee <31196561+larkee@users.noreply.github.com>
Co-authored-by: Alex <7764119+AVaksman@users.noreply.github.com>
  • Loading branch information
4 people committed Dec 30, 2020
1 parent 4ef793c commit e981adb
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 50 deletions.
50 changes: 38 additions & 12 deletions google/cloud/spanner_dbapi/connection.py
Expand Up @@ -22,6 +22,9 @@
from google.cloud import spanner_v1 as spanner
from google.cloud.spanner_v1.session import _get_retry_delay

from google.cloud.spanner_dbapi._helpers import _execute_insert_heterogenous
from google.cloud.spanner_dbapi._helpers import _execute_insert_homogenous
from google.cloud.spanner_dbapi._helpers import parse_insert
from google.cloud.spanner_dbapi.checksum import _compare_checksums
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.cursor import Cursor
Expand Down Expand Up @@ -82,7 +85,7 @@ def autocommit(self, value):
:type value: bool
:param value: New autocommit mode state.
"""
if value and not self._autocommit:
if value and not self._autocommit and self.inside_transaction:
self.commit()

self._autocommit = value
Expand All @@ -96,6 +99,19 @@ def database(self):
"""
return self._database

@property
def inside_transaction(self):
"""Flag: transaction is started.
Returns:
bool: True if transaction begun, False otherwise.
"""
return (
self._transaction
and not self._transaction.committed
and not self._transaction.rolled_back
)

@property
def instance(self):
"""Instance to which this connection relates.
Expand Down Expand Up @@ -191,11 +207,7 @@ def transaction_checkout(self):
:returns: A Cloud Spanner transaction object, ready to use.
"""
if not self.autocommit:
if (
not self._transaction
or self._transaction.committed
or self._transaction.rolled_back
):
if not self.inside_transaction:
self._transaction = self._session_checkout().transaction()
self._transaction.begin()

Expand All @@ -216,11 +228,7 @@ def close(self):
The connection will be unusable from this point forward. If the
connection has an active transaction, it will be rolled back.
"""
if (
self._transaction
and not self._transaction.committed
and not self._transaction.rolled_back
):
if self.inside_transaction:
self._transaction.rollback()

if self._own_pool:
Expand All @@ -235,7 +243,7 @@ def commit(self):
"""
if self._autocommit:
warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2)
elif self._transaction:
elif self.inside_transaction:
try:
self._transaction.commit()
self._release_session()
Expand Down Expand Up @@ -291,6 +299,24 @@ def run_statement(self, statement, retried=False):
if not retried:
self._statements.append(statement)

if statement.is_insert:
parts = parse_insert(statement.sql, statement.params)

if parts.get("homogenous"):
_execute_insert_homogenous(transaction, parts)
return (
iter(()),
ResultsChecksum() if retried else statement.checksum,
)
else:
_execute_insert_heterogenous(
transaction, parts.get("sql_params_list"),
)
return (
iter(()),
ResultsChecksum() if retried else statement.checksum,
)

return (
transaction.execute_sql(
statement.sql, statement.params, param_types=statement.param_types,
Expand Down
29 changes: 21 additions & 8 deletions google/cloud/spanner_dbapi/cursor.py
Expand Up @@ -42,7 +42,7 @@
_UNSET_COUNT = -1

ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"])
Statement = namedtuple("Statement", "sql, params, param_types, checksum")
Statement = namedtuple("Statement", "sql, params, param_types, checksum, is_insert")


class Cursor(object):
Expand Down Expand Up @@ -95,9 +95,9 @@ def description(self):
for field in row_type.fields:
column_info = ColumnInfo(
name=field.name,
type_code=field.type.code,
type_code=field.type_.code,
# Size of the SQL type of the column.
display_size=code_to_display_size.get(field.type.code),
display_size=code_to_display_size.get(field.type_.code),
# Client perceived size of the column.
internal_size=field.ByteSize(),
)
Expand Down Expand Up @@ -172,10 +172,20 @@ def execute(self, sql, args=None):
self.connection.run_prior_DDL_statements()

if not self.connection.autocommit:
sql, params = sql_pyformat_args_to_spanner(sql, args)
if classification == parse_utils.STMT_UPDATING:
sql = parse_utils.ensure_where_clause(sql)

if classification != parse_utils.STMT_INSERT:
sql, args = sql_pyformat_args_to_spanner(sql, args or None)

statement = Statement(
sql, params, get_param_types(params), ResultsChecksum(),
sql,
args,
get_param_types(args or None)
if classification != parse_utils.STMT_INSERT
else {},
ResultsChecksum(),
classification == parse_utils.STMT_INSERT,
)
(self._result_set, self._checksum,) = self.connection.run_statement(
statement
Expand Down Expand Up @@ -233,7 +243,8 @@ def fetchone(self):

try:
res = next(self)
self._checksum.consume_result(res)
if not self.connection.autocommit:
self._checksum.consume_result(res)
return res
except StopIteration:
return
Expand All @@ -250,7 +261,8 @@ def fetchall(self):
res = []
try:
for row in self:
self._checksum.consume_result(row)
if not self.connection.autocommit:
self._checksum.consume_result(row)
res.append(row)
except Aborted:
self._connection.retry_transaction()
Expand Down Expand Up @@ -278,7 +290,8 @@ def fetchmany(self, size=None):
for i in range(size):
try:
res = next(self)
self._checksum.consume_result(res)
if not self.connection.autocommit:
self._checksum.consume_result(res)
items.append(res)
except StopIteration:
break
Expand Down
8 changes: 2 additions & 6 deletions google/cloud/spanner_dbapi/parse_utils.py
Expand Up @@ -523,19 +523,15 @@ def get_param_types(params):
def ensure_where_clause(sql):
"""
Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements.
Raise an error, if the given sql doesn't include it.
Add a dummy WHERE clause if non detected.
:type sql: `str`
:param sql: SQL code to check.
:raises: :class:`ProgrammingError` if the given sql doesn't include a WHERE clause.
"""
if any(isinstance(token, sqlparse.sql.Where) for token in sqlparse.parse(sql)[0]):
return sql

raise ProgrammingError(
"Cloud Spanner requires a WHERE clause when executing DELETE or UPDATE query"
)
return sql + " WHERE 1=1"


def escape_name(name):
Expand Down
69 changes: 51 additions & 18 deletions tests/unit/spanner_dbapi/test_connection.py
Expand Up @@ -15,7 +15,6 @@
"""Cloud Spanner DB-API Connection class unit tests."""

import mock
import sys
import unittest
import warnings

Expand Down Expand Up @@ -51,25 +50,57 @@ def _make_connection(self):
database = instance.database(self.DATABASE)
return Connection(instance, database)

@unittest.skipIf(sys.version_info[0] < 3, "Python 2 patching is outdated")
def test_property_autocommit_setter(self):
from google.cloud.spanner_dbapi import Connection

connection = Connection(self.INSTANCE, self.DATABASE)
def test_autocommit_setter_transaction_not_started(self):
connection = self._make_connection()

with mock.patch(
"google.cloud.spanner_dbapi.connection.Connection.commit"
) as mock_commit:
connection.autocommit = True
mock_commit.assert_called_once_with()
self.assertEqual(connection._autocommit, True)
mock_commit.assert_not_called()
self.assertTrue(connection._autocommit)

with mock.patch(
"google.cloud.spanner_dbapi.connection.Connection.commit"
) as mock_commit:
connection.autocommit = False
mock_commit.assert_not_called()
self.assertEqual(connection._autocommit, False)
self.assertFalse(connection._autocommit)

def test_autocommit_setter_transaction_started(self):
connection = self._make_connection()

with mock.patch(
"google.cloud.spanner_dbapi.connection.Connection.commit"
) as mock_commit:
connection._transaction = mock.Mock(committed=False, rolled_back=False)

connection.autocommit = True
mock_commit.assert_called_once()
self.assertTrue(connection._autocommit)

def test_autocommit_setter_transaction_started_commited_rolled_back(self):
connection = self._make_connection()

with mock.patch(
"google.cloud.spanner_dbapi.connection.Connection.commit"
) as mock_commit:
connection._transaction = mock.Mock(committed=True, rolled_back=False)

connection.autocommit = True
mock_commit.assert_not_called()
self.assertTrue(connection._autocommit)

connection.autocommit = False

with mock.patch(
"google.cloud.spanner_dbapi.connection.Connection.commit"
) as mock_commit:
connection._transaction = mock.Mock(committed=False, rolled_back=True)

connection.autocommit = True
mock_commit.assert_not_called()
self.assertTrue(connection._autocommit)

def test_property_database(self):
from google.cloud.spanner_v1.database import Database
Expand Down Expand Up @@ -166,7 +197,9 @@ def test_commit(self, mock_warn):
connection.commit()
mock_release.assert_not_called()

connection._transaction = mock_transaction = mock.MagicMock()
connection._transaction = mock_transaction = mock.MagicMock(
rolled_back=False, committed=False
)
mock_transaction.commit = mock_commit = mock.MagicMock()

with mock.patch(
Expand Down Expand Up @@ -316,7 +349,7 @@ def test_run_statement_remember_statements(self):

connection = self._make_connection()

statement = Statement(sql, params, param_types, ResultsChecksum(),)
statement = Statement(sql, params, param_types, ResultsChecksum(), False)
with mock.patch(
"google.cloud.spanner_dbapi.connection.Connection.transaction_checkout"
):
Expand All @@ -338,7 +371,7 @@ def test_run_statement_dont_remember_retried_statements(self):

connection = self._make_connection()

statement = Statement(sql, params, param_types, ResultsChecksum(),)
statement = Statement(sql, params, param_types, ResultsChecksum(), False)
with mock.patch(
"google.cloud.spanner_dbapi.connection.Connection.transaction_checkout"
):
Expand All @@ -352,7 +385,7 @@ def test_clear_statements_on_commit(self):
cleared, when the transaction is commited.
"""
connection = self._make_connection()
connection._transaction = mock.Mock()
connection._transaction = mock.Mock(rolled_back=False, committed=False)
connection._statements = [{}, {}]

self.assertEqual(len(connection._statements), 2)
Expand Down Expand Up @@ -390,7 +423,7 @@ def test_retry_transaction(self):
checksum.consume_result(row)
retried_checkum = ResultsChecksum()

statement = Statement("SELECT 1", [], {}, checksum,)
statement = Statement("SELECT 1", [], {}, checksum, False)
connection._statements.append(statement)

with mock.patch(
Expand Down Expand Up @@ -423,7 +456,7 @@ def test_retry_transaction_checksum_mismatch(self):
checksum.consume_result(row)
retried_checkum = ResultsChecksum()

statement = Statement("SELECT 1", [], {}, checksum,)
statement = Statement("SELECT 1", [], {}, checksum, False)
connection._statements.append(statement)

with mock.patch(
Expand Down Expand Up @@ -453,9 +486,9 @@ def test_commit_retry_aborted_statements(self):
cursor._checksum = ResultsChecksum()
cursor._checksum.consume_result(row)

statement = Statement("SELECT 1", [], {}, cursor._checksum,)
statement = Statement("SELECT 1", [], {}, cursor._checksum, False)
connection._statements.append(statement)
connection._transaction = mock.Mock()
connection._transaction = mock.Mock(rolled_back=False, committed=False)

with mock.patch.object(
connection._transaction, "commit", side_effect=(Aborted("Aborted"), None),
Expand Down Expand Up @@ -507,7 +540,7 @@ def test_retry_aborted_retry(self):
cursor._checksum = ResultsChecksum()
cursor._checksum.consume_result(row)

statement = Statement("SELECT 1", [], {}, cursor._checksum,)
statement = Statement("SELECT 1", [], {}, cursor._checksum, False)
connection._statements.append(statement)

metadata_mock = mock.Mock()
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/spanner_dbapi/test_cursor.py
Expand Up @@ -126,7 +126,7 @@ def test_execute_attribute_error(self):
cursor = self._make_one(connection)

with self.assertRaises(AttributeError):
cursor.execute(sql="")
cursor.execute(sql="SELECT 1")

def test_execute_autocommit_off(self):
from google.cloud.spanner_dbapi.utils import PeekIterator
Expand Down Expand Up @@ -531,7 +531,7 @@ def test_fetchone_retry_aborted_statements(self):
cursor._checksum = ResultsChecksum()
cursor._checksum.consume_result(row)

statement = Statement("SELECT 1", [], {}, cursor._checksum,)
statement = Statement("SELECT 1", [], {}, cursor._checksum, False)
connection._statements.append(statement)

with mock.patch(
Expand Down Expand Up @@ -570,7 +570,7 @@ def test_fetchone_retry_aborted_statements_checksums_mismatch(self):
cursor._checksum = ResultsChecksum()
cursor._checksum.consume_result(row)

statement = Statement("SELECT 1", [], {}, cursor._checksum,)
statement = Statement("SELECT 1", [], {}, cursor._checksum, False)
connection._statements.append(statement)

with mock.patch(
Expand Down
4 changes: 1 addition & 3 deletions tests/unit/spanner_dbapi/test_parse_utils.py
Expand Up @@ -391,7 +391,6 @@ def test_get_param_types_none(self):

@unittest.skipIf(skip_condition, skip_message)
def test_ensure_where_clause(self):
from google.cloud.spanner_dbapi.exceptions import ProgrammingError
from google.cloud.spanner_dbapi.parse_utils import ensure_where_clause

cases = (
Expand All @@ -409,8 +408,7 @@ def test_ensure_where_clause(self):

for sql in err_cases:
with self.subTest(sql=sql):
with self.assertRaises(ProgrammingError):
ensure_where_clause(sql)
self.assertEqual(ensure_where_clause(sql), sql + " WHERE 1=1")

@unittest.skipIf(skip_condition, skip_message)
def test_escape_name(self):
Expand Down

0 comments on commit e981adb

Please sign in to comment.