Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(dbapi): autocommit enabling fails if no transactions begun #177

Merged
merged 35 commits into from Dec 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
73f6330
fix(dbapi): autocommit enabling fails if no transactions begun
Nov 30, 2020
632bc20
remove unused import
Nov 30, 2020
9f09bfb
Merge branch 'master' into autocommit_change
c24t Dec 1, 2020
1f8e864
don't calculate checksums in autocommit mode
Dec 3, 2020
0841fef
Merge branch 'autocommit_change' of https://github.com/q-logic/python…
Dec 3, 2020
fcc20ae
try using dummy WHERE clause
Dec 3, 2020
73f91ef
revert where clause
Dec 3, 2020
df9c35f
unveil error
Dec 4, 2020
5f04934
fix where clauses
Dec 4, 2020
6a0ff47
add print
Dec 7, 2020
bb75c34
don't log
Dec 8, 2020
e9e5260
Merge branch 'master' into autocommit_change
c24t Dec 10, 2020
1527a0e
print failed exceptions
Dec 14, 2020
6294e97
Merge branch 'autocommit_change' of https://github.com/q-logic/python…
Dec 14, 2020
dd4e74e
don't print
Dec 14, 2020
092c334
separate insert statements
Dec 15, 2020
5d997fd
don't return
Dec 15, 2020
9e07ec8
re-run
Dec 15, 2020
914930f
don't pyformat insert args
Dec 15, 2020
824d3ff
args
Dec 15, 2020
18e815f
re-run
Dec 15, 2020
7ddd0b5
fix
Dec 15, 2020
96f5df1
Merge branch 'master' into autocommit_change
c24t Dec 15, 2020
705dd1d
Merge branch 'master' into autocommit_change
c24t Dec 17, 2020
51da009
Merge branch 'master' into autocommit_change
larkee Dec 17, 2020
e40ccc6
fix error in transactions.tests.NonAutocommitTests.test_orm_query_wit…
Dec 21, 2020
2b4d830
Merge branch 'autocommit_change' of https://github.com/q-logic/python…
Dec 21, 2020
2124092
fix "already committed" error
Dec 22, 2020
d36becb
Merge branch 'master' into autocommit_change
c24t Dec 22, 2020
6a4e248
fix for AttributeError: 'tuple' object has no attribute 'items'
Dec 23, 2020
9406e68
Merge branch 'autocommit_change' of https://github.com/q-logic/python…
Dec 23, 2020
dbd75ce
fix
Dec 23, 2020
f9306b2
fix KeyError: 'type'
Dec 24, 2020
5fa4dff
Merge branch 'master' into autocommit_change
AVaksman Dec 29, 2020
e50aa83
Merge branch 'master' into autocommit_change
larkee Dec 30, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
50 changes: 38 additions & 12 deletions google/cloud/spanner_dbapi/connection.py
Expand Up @@ -22,6 +22,9 @@
from google.cloud import spanner_v1 as spanner
from google.cloud.spanner_v1.session import _get_retry_delay

from google.cloud.spanner_dbapi._helpers import _execute_insert_heterogenous
from google.cloud.spanner_dbapi._helpers import _execute_insert_homogenous
from google.cloud.spanner_dbapi._helpers import parse_insert
from google.cloud.spanner_dbapi.checksum import _compare_checksums
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.cursor import Cursor
Expand Down Expand Up @@ -82,7 +85,7 @@ def autocommit(self, value):
:type value: bool
:param value: New autocommit mode state.
"""
if value and not self._autocommit:
if value and not self._autocommit and self.inside_transaction:
self.commit()

self._autocommit = value
Expand All @@ -96,6 +99,19 @@ def database(self):
"""
return self._database

@property
def inside_transaction(self):
"""Flag: transaction is started.

Returns:
bool: True if transaction begun, False otherwise.
"""
return (
self._transaction
and not self._transaction.committed
and not self._transaction.rolled_back
)

@property
def instance(self):
"""Instance to which this connection relates.
Expand Down Expand Up @@ -191,11 +207,7 @@ def transaction_checkout(self):
:returns: A Cloud Spanner transaction object, ready to use.
"""
if not self.autocommit:
if (
not self._transaction
or self._transaction.committed
or self._transaction.rolled_back
):
if not self.inside_transaction:
self._transaction = self._session_checkout().transaction()
self._transaction.begin()

Expand All @@ -216,11 +228,7 @@ def close(self):
The connection will be unusable from this point forward. If the
connection has an active transaction, it will be rolled back.
"""
if (
self._transaction
and not self._transaction.committed
and not self._transaction.rolled_back
):
if self.inside_transaction:
self._transaction.rollback()

if self._own_pool:
Expand All @@ -235,7 +243,7 @@ def commit(self):
"""
if self._autocommit:
warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2)
elif self._transaction:
elif self.inside_transaction:
try:
self._transaction.commit()
self._release_session()
Expand Down Expand Up @@ -291,6 +299,24 @@ def run_statement(self, statement, retried=False):
if not retried:
self._statements.append(statement)

if statement.is_insert:
parts = parse_insert(statement.sql, statement.params)

if parts.get("homogenous"):
_execute_insert_homogenous(transaction, parts)
return (
iter(()),
ResultsChecksum() if retried else statement.checksum,
)
else:
_execute_insert_heterogenous(
transaction, parts.get("sql_params_list"),
)
return (
iter(()),
ResultsChecksum() if retried else statement.checksum,
)

return (
transaction.execute_sql(
statement.sql, statement.params, param_types=statement.param_types,
Expand Down
29 changes: 21 additions & 8 deletions google/cloud/spanner_dbapi/cursor.py
Expand Up @@ -42,7 +42,7 @@
_UNSET_COUNT = -1

ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"])
Statement = namedtuple("Statement", "sql, params, param_types, checksum")
Statement = namedtuple("Statement", "sql, params, param_types, checksum, is_insert")


class Cursor(object):
Expand Down Expand Up @@ -95,9 +95,9 @@ def description(self):
for field in row_type.fields:
column_info = ColumnInfo(
name=field.name,
type_code=field.type.code,
type_code=field.type_.code,
# Size of the SQL type of the column.
display_size=code_to_display_size.get(field.type.code),
display_size=code_to_display_size.get(field.type_.code),
# Client perceived size of the column.
internal_size=field.ByteSize(),
)
Expand Down Expand Up @@ -172,10 +172,20 @@ def execute(self, sql, args=None):
self.connection.run_prior_DDL_statements()

if not self.connection.autocommit:
sql, params = sql_pyformat_args_to_spanner(sql, args)
if classification == parse_utils.STMT_UPDATING:
sql = parse_utils.ensure_where_clause(sql)

if classification != parse_utils.STMT_INSERT:
sql, args = sql_pyformat_args_to_spanner(sql, args or None)

statement = Statement(
sql, params, get_param_types(params), ResultsChecksum(),
sql,
args,
get_param_types(args or None)
if classification != parse_utils.STMT_INSERT
else {},
ResultsChecksum(),
classification == parse_utils.STMT_INSERT,
)
(self._result_set, self._checksum,) = self.connection.run_statement(
statement
Expand Down Expand Up @@ -233,7 +243,8 @@ def fetchone(self):

try:
res = next(self)
self._checksum.consume_result(res)
if not self.connection.autocommit:
self._checksum.consume_result(res)
return res
except StopIteration:
return
Expand All @@ -250,7 +261,8 @@ def fetchall(self):
res = []
try:
for row in self:
self._checksum.consume_result(row)
if not self.connection.autocommit:
self._checksum.consume_result(row)
res.append(row)
except Aborted:
self._connection.retry_transaction()
Expand Down Expand Up @@ -278,7 +290,8 @@ def fetchmany(self, size=None):
for i in range(size):
try:
res = next(self)
self._checksum.consume_result(res)
if not self.connection.autocommit:
self._checksum.consume_result(res)
items.append(res)
except StopIteration:
break
Expand Down
8 changes: 2 additions & 6 deletions google/cloud/spanner_dbapi/parse_utils.py
Expand Up @@ -523,19 +523,15 @@ def get_param_types(params):
def ensure_where_clause(sql):
"""
Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements.
Raise an error, if the given sql doesn't include it.
Add a dummy WHERE clause if non detected.

:type sql: `str`
:param sql: SQL code to check.

:raises: :class:`ProgrammingError` if the given sql doesn't include a WHERE clause.
"""
if any(isinstance(token, sqlparse.sql.Where) for token in sqlparse.parse(sql)[0]):
return sql

raise ProgrammingError(
"Cloud Spanner requires a WHERE clause when executing DELETE or UPDATE query"
)
return sql + " WHERE 1=1"


def escape_name(name):
Expand Down
69 changes: 51 additions & 18 deletions tests/unit/spanner_dbapi/test_connection.py
Expand Up @@ -15,7 +15,6 @@
"""Cloud Spanner DB-API Connection class unit tests."""

import mock
import sys
import unittest
import warnings

Expand Down Expand Up @@ -51,25 +50,57 @@ def _make_connection(self):
database = instance.database(self.DATABASE)
return Connection(instance, database)

@unittest.skipIf(sys.version_info[0] < 3, "Python 2 patching is outdated")
def test_property_autocommit_setter(self):
from google.cloud.spanner_dbapi import Connection

connection = Connection(self.INSTANCE, self.DATABASE)
def test_autocommit_setter_transaction_not_started(self):
connection = self._make_connection()

with mock.patch(
"google.cloud.spanner_dbapi.connection.Connection.commit"
) as mock_commit:
connection.autocommit = True
mock_commit.assert_called_once_with()
self.assertEqual(connection._autocommit, True)
mock_commit.assert_not_called()
self.assertTrue(connection._autocommit)

with mock.patch(
"google.cloud.spanner_dbapi.connection.Connection.commit"
) as mock_commit:
connection.autocommit = False
mock_commit.assert_not_called()
self.assertEqual(connection._autocommit, False)
self.assertFalse(connection._autocommit)

def test_autocommit_setter_transaction_started(self):
connection = self._make_connection()

with mock.patch(
"google.cloud.spanner_dbapi.connection.Connection.commit"
) as mock_commit:
connection._transaction = mock.Mock(committed=False, rolled_back=False)

connection.autocommit = True
mock_commit.assert_called_once()
self.assertTrue(connection._autocommit)

def test_autocommit_setter_transaction_started_commited_rolled_back(self):
connection = self._make_connection()

with mock.patch(
"google.cloud.spanner_dbapi.connection.Connection.commit"
) as mock_commit:
connection._transaction = mock.Mock(committed=True, rolled_back=False)

connection.autocommit = True
mock_commit.assert_not_called()
self.assertTrue(connection._autocommit)

connection.autocommit = False

with mock.patch(
"google.cloud.spanner_dbapi.connection.Connection.commit"
) as mock_commit:
connection._transaction = mock.Mock(committed=False, rolled_back=True)

connection.autocommit = True
mock_commit.assert_not_called()
self.assertTrue(connection._autocommit)

def test_property_database(self):
from google.cloud.spanner_v1.database import Database
Expand Down Expand Up @@ -166,7 +197,9 @@ def test_commit(self, mock_warn):
connection.commit()
mock_release.assert_not_called()

connection._transaction = mock_transaction = mock.MagicMock()
connection._transaction = mock_transaction = mock.MagicMock(
rolled_back=False, committed=False
)
mock_transaction.commit = mock_commit = mock.MagicMock()

with mock.patch(
Expand Down Expand Up @@ -316,7 +349,7 @@ def test_run_statement_remember_statements(self):

connection = self._make_connection()

statement = Statement(sql, params, param_types, ResultsChecksum(),)
statement = Statement(sql, params, param_types, ResultsChecksum(), False)
with mock.patch(
"google.cloud.spanner_dbapi.connection.Connection.transaction_checkout"
):
Expand All @@ -338,7 +371,7 @@ def test_run_statement_dont_remember_retried_statements(self):

connection = self._make_connection()

statement = Statement(sql, params, param_types, ResultsChecksum(),)
statement = Statement(sql, params, param_types, ResultsChecksum(), False)
with mock.patch(
"google.cloud.spanner_dbapi.connection.Connection.transaction_checkout"
):
Expand All @@ -352,7 +385,7 @@ def test_clear_statements_on_commit(self):
cleared, when the transaction is commited.
"""
connection = self._make_connection()
connection._transaction = mock.Mock()
connection._transaction = mock.Mock(rolled_back=False, committed=False)
connection._statements = [{}, {}]

self.assertEqual(len(connection._statements), 2)
Expand Down Expand Up @@ -390,7 +423,7 @@ def test_retry_transaction(self):
checksum.consume_result(row)
retried_checkum = ResultsChecksum()

statement = Statement("SELECT 1", [], {}, checksum,)
statement = Statement("SELECT 1", [], {}, checksum, False)
connection._statements.append(statement)

with mock.patch(
Expand Down Expand Up @@ -423,7 +456,7 @@ def test_retry_transaction_checksum_mismatch(self):
checksum.consume_result(row)
retried_checkum = ResultsChecksum()

statement = Statement("SELECT 1", [], {}, checksum,)
statement = Statement("SELECT 1", [], {}, checksum, False)
connection._statements.append(statement)

with mock.patch(
Expand Down Expand Up @@ -453,9 +486,9 @@ def test_commit_retry_aborted_statements(self):
cursor._checksum = ResultsChecksum()
cursor._checksum.consume_result(row)

statement = Statement("SELECT 1", [], {}, cursor._checksum,)
statement = Statement("SELECT 1", [], {}, cursor._checksum, False)
connection._statements.append(statement)
connection._transaction = mock.Mock()
connection._transaction = mock.Mock(rolled_back=False, committed=False)

with mock.patch.object(
connection._transaction, "commit", side_effect=(Aborted("Aborted"), None),
Expand Down Expand Up @@ -507,7 +540,7 @@ def test_retry_aborted_retry(self):
cursor._checksum = ResultsChecksum()
cursor._checksum.consume_result(row)

statement = Statement("SELECT 1", [], {}, cursor._checksum,)
statement = Statement("SELECT 1", [], {}, cursor._checksum, False)
connection._statements.append(statement)

metadata_mock = mock.Mock()
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/spanner_dbapi/test_cursor.py
Expand Up @@ -126,7 +126,7 @@ def test_execute_attribute_error(self):
cursor = self._make_one(connection)

with self.assertRaises(AttributeError):
cursor.execute(sql="")
cursor.execute(sql="SELECT 1")

def test_execute_autocommit_off(self):
from google.cloud.spanner_dbapi.utils import PeekIterator
Expand Down Expand Up @@ -531,7 +531,7 @@ def test_fetchone_retry_aborted_statements(self):
cursor._checksum = ResultsChecksum()
cursor._checksum.consume_result(row)

statement = Statement("SELECT 1", [], {}, cursor._checksum,)
statement = Statement("SELECT 1", [], {}, cursor._checksum, False)
connection._statements.append(statement)

with mock.patch(
Expand Down Expand Up @@ -570,7 +570,7 @@ def test_fetchone_retry_aborted_statements_checksums_mismatch(self):
cursor._checksum = ResultsChecksum()
cursor._checksum.consume_result(row)

statement = Statement("SELECT 1", [], {}, cursor._checksum,)
statement = Statement("SELECT 1", [], {}, cursor._checksum, False)
connection._statements.append(statement)

with mock.patch(
Expand Down
4 changes: 1 addition & 3 deletions tests/unit/spanner_dbapi/test_parse_utils.py
Expand Up @@ -391,7 +391,6 @@ def test_get_param_types_none(self):

@unittest.skipIf(skip_condition, skip_message)
def test_ensure_where_clause(self):
from google.cloud.spanner_dbapi.exceptions import ProgrammingError
from google.cloud.spanner_dbapi.parse_utils import ensure_where_clause

cases = (
Expand All @@ -409,8 +408,7 @@ def test_ensure_where_clause(self):

for sql in err_cases:
with self.subTest(sql=sql):
with self.assertRaises(ProgrammingError):
ensure_where_clause(sql)
self.assertEqual(ensure_where_clause(sql), sql + " WHERE 1=1")

@unittest.skipIf(skip_condition, skip_message)
def test_escape_name(self):
Expand Down