Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: cursor must detect if the parent connection is closed #463

Merged
merged 17 commits on Sep 1, 2020
Merged
54 changes: 30 additions & 24 deletions spanner_dbapi/connection.py
Expand Up @@ -11,26 +11,31 @@
from .cursor import Cursor
from .exceptions import InterfaceError

ColumnDetails = namedtuple('column_details', ['null_ok', 'spanner_type'])
ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"])


class Connection:
def __init__(self, db_handle):
self._dbhandle = db_handle
self._closed = False
self._ddl_statements = []

self.is_closed = False

def cursor(self):
self.__raise_if_already_closed()
self._raise_if_already_closed()

return Cursor(self)

def __raise_if_already_closed(self):
"""
Raise an exception if attempting to use an already closed connection.
def _raise_if_already_closed(self):
"""Raise an exception if this connection is closed.

Helper to check the connection state before
running a SQL/DDL/DML query.

:raises: :class:`InterfaceError` if this connection is closed.
"""
if self._closed:
raise InterfaceError('connection already closed')
if self.is_closed:
raise InterfaceError("connection is already closed")

def __handle_update_ddl(self, ddl_statements):
"""
Expand All @@ -41,24 +46,24 @@ def __handle_update_ddl(self, ddl_statements):
Returns:
google.api_core.operation.Operation.result()
"""
self.__raise_if_already_closed()
self._raise_if_already_closed()
# Synchronously wait on the operation's completion.
return self._dbhandle.update_ddl(ddl_statements).result()

def read_snapshot(self):
self.__raise_if_already_closed()
self._raise_if_already_closed()
return self._dbhandle.snapshot()

def in_transaction(self, fn, *args, **kwargs):
self.__raise_if_already_closed()
self._raise_if_already_closed()
return self._dbhandle.run_in_transaction(fn, *args, **kwargs)

def append_ddl_statement(self, ddl_statement):
self.__raise_if_already_closed()
self._raise_if_already_closed()
self._ddl_statements.append(ddl_statement)

def run_prior_DDL_statements(self):
self.__raise_if_already_closed()
self._raise_if_already_closed()

if not self._ddl_statements:
return
Expand All @@ -69,14 +74,16 @@ def run_prior_DDL_statements(self):
return self.__handle_update_ddl(ddl_statements)

def list_tables(self):
return self.run_sql_in_snapshot("""
return self.run_sql_in_snapshot(
"""
SELECT
t.table_name
FROM
information_schema.tables AS t
WHERE
t.table_catalog = '' and t.table_schema = ''
""")
"""
)

def run_sql_in_snapshot(self, sql, params=None, param_types=None):
# Some SQL e.g. for INFORMATION_SCHEMA cannot be run in read-write transactions
Expand All @@ -89,38 +96,37 @@ def run_sql_in_snapshot(self, sql, params=None, param_types=None):

def get_table_column_schema(self, table_name):
rows = self.run_sql_in_snapshot(
'''SELECT
"""SELECT
COLUMN_NAME, IS_NULLABLE, SPANNER_TYPE
FROM
INFORMATION_SCHEMA.COLUMNS
WHERE
TABLE_SCHEMA = ''
AND
TABLE_NAME = @table_name''',
params={'table_name': table_name},
param_types={'table_name': spanner.param_types.STRING},
TABLE_NAME = @table_name""",
params={"table_name": table_name},
param_types={"table_name": spanner.param_types.STRING},
)

column_details = {}
for column_name, is_nullable, spanner_type in rows:
column_details[column_name] = ColumnDetails(
null_ok=is_nullable == 'YES',
spanner_type=spanner_type,
null_ok=is_nullable == "YES", spanner_type=spanner_type
)
return column_details

def close(self):
self.rollback()
self.__dbhandle = None
self._closed = True
self.is_closed = True

def commit(self):
self.__raise_if_already_closed()
self._raise_if_already_closed()

self.run_prior_DDL_statements()

def rollback(self):
self.__raise_if_already_closed()
self._raise_if_already_closed()

# TODO: to be added.

Expand Down
138 changes: 88 additions & 50 deletions spanner_dbapi/cursor.py
Expand Up @@ -8,11 +8,19 @@
from google.cloud.spanner_v1 import param_types

from .exceptions import (
IntegrityError, InterfaceError, OperationalError, ProgrammingError,
IntegrityError,
InterfaceError,
OperationalError,
ProgrammingError,
)
from .parse_utils import (
STMT_DDL, STMT_INSERT, STMT_NON_UPDATING, classify_stmt,
ensure_where_clause, get_param_types, parse_insert,
STMT_DDL,
STMT_INSERT,
STMT_NON_UPDATING,
classify_stmt,
ensure_where_clause,
get_param_types,
parse_insert,
sql_pyformat_args_to_spanner,
)
from .utils import PeekIterator
Expand Down Expand Up @@ -44,12 +52,9 @@ def __init__(self, connection):
self._res = None
self._row_count = _UNSET_COUNT
self._connection = connection
self._closed = False
self._is_closed = False

# arraysize is a readable and writable property mandated
# by PEP-0249 https://www.python.org/dev/peps/pep-0249/#arraysize
# It determines the results of .fetchmany
self.arraysize = 1
self.arraysize = 1 # the number of rows to fetch at a time with fetchmany()

def execute(self, sql, args=None):
"""
Expand All @@ -64,7 +69,7 @@ def execute(self, sql, args=None):
self._raise_if_already_closed()

if not self._connection:
raise ProgrammingError('Cursor is not connected to the database')
raise ProgrammingError("Cursor is not connected to the database")

self._res = None

Expand All @@ -86,23 +91,22 @@ def execute(self, sql, args=None):
else:
self.__handle_update(sql, args or None)
except (grpc_exceptions.AlreadyExists, grpc_exceptions.FailedPrecondition) as e:
raise IntegrityError(e.details if hasattr(e, 'details') else e)
raise IntegrityError(e.details if hasattr(e, "details") else e)
except grpc_exceptions.InvalidArgument as e:
raise ProgrammingError(e.details if hasattr(e, 'details') else e)
raise ProgrammingError(e.details if hasattr(e, "details") else e)
except grpc_exceptions.InternalServerError as e:
raise OperationalError(e.details if hasattr(e, 'details') else e)
raise OperationalError(e.details if hasattr(e, "details") else e)

def __handle_update(self, sql, params):
self._connection.in_transaction(
self.__do_execute_update,
sql, params,
)
self._connection.in_transaction(self.__do_execute_update, sql, params)

def __do_execute_update(self, transaction, sql, params, param_types=None):
sql = ensure_where_clause(sql)
sql, params = sql_pyformat_args_to_spanner(sql, params)

res = transaction.execute_update(sql, params=params, param_types=get_param_types(params))
res = transaction.execute_update(
sql, params=params, param_types=get_param_types(params)
)
self._itr = None
if type(res) == int:
self._row_count = res
Expand All @@ -125,20 +129,18 @@ def __handle_insert(self, sql, params):
# transaction.execute_sql(sql, params, param_types)
# which invokes more RPCs and is more costly.

if parts.get('homogenous'):
if parts.get("homogenous"):
# The common case of multiple values being passed in
# non-complex pyformat args and need to be uploaded in one RPC.
return self._connection.in_transaction(
self.__do_execute_insert_homogenous,
parts,
self.__do_execute_insert_homogenous, parts
)
else:
# All the other cases that are esoteric and need
# transaction.execute_sql
sql_params_list = parts.get('sql_params_list')
sql_params_list = parts.get("sql_params_list")
return self._connection.in_transaction(
self.__do_execute_insert_heterogenous,
sql_params_list,
self.__do_execute_insert_heterogenous, sql_params_list
)

def __do_execute_insert_heterogenous(self, transaction, sql_params_list):
Expand All @@ -152,17 +154,19 @@ def __do_execute_insert_heterogenous(self, transaction, sql_params_list):

def __do_execute_insert_homogenous(self, transaction, parts):
# Perform an insert in one shot.
table = parts.get('table')
columns = parts.get('columns')
values = parts.get('values')
table = parts.get("table")
columns = parts.get("columns")
values = parts.get("values")
return transaction.insert(table, columns, values)

def __handle_DQL(self, sql, params):
with self._connection.read_snapshot() as snapshot:
# Reference
# https://googleapis.dev/python/spanner/latest/session-api.html#google.cloud.spanner_v1.session.Session.execute_sql
sql, params = sql_pyformat_args_to_spanner(sql, params)
res = snapshot.execute_sql(sql, params=params, param_types=get_param_types(params))
res = snapshot.execute_sql(
sql, params=params, param_types=get_param_types(params)
)
if type(res) == int:
self._row_count = res
self._itr = None
Expand Down Expand Up @@ -216,32 +220,48 @@ def description(self):
def rowcount(self):
return self._row_count

def _raise_if_already_closed(self):
@property
def is_closed(self):
"""The cursor close indicator.

Returns:
bool:
True if this cursor or its parent connection
is closed, False otherwise.
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved
"""
Raise an exception if attempting to use an already closed connection.
return self._is_closed or self._connection.is_closed

def _raise_if_already_closed(self):
"""Raise an exception if this cursor is closed.

Helper to check this cursor's state before running a
SQL/DDL/DML query. If the parent connection is
already closed it also raises an error.

:raises: :class:`InterfaceError` if this cursor is closed.
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved
"""
if self._closed:
raise InterfaceError('cursor already closed')
if self.is_closed:
raise InterfaceError("cursor is already closed")

def close(self):
self.__clear()
self._closed = True
self._is_closed = True

def executemany(self, operation, seq_of_params):
if not self._connection:
raise ProgrammingError('Cursor is not connected to the database')
raise ProgrammingError("Cursor is not connected to the database")

for params in seq_of_params:
self.execute(operation, params)

def __next__(self):
if self._itr is None:
raise ProgrammingError('no results to return')
raise ProgrammingError("no results to return")
return next(self._itr)

def __iter__(self):
if self._itr is None:
raise ProgrammingError('no results to return')
raise ProgrammingError("no results to return")
return self._itr

def fetchone(self):
Expand Down Expand Up @@ -289,10 +309,10 @@ def lastrowid(self):
return None

def setinputsizes(sizes):
raise ProgrammingError('Unimplemented')
raise ProgrammingError("Unimplemented")

def setoutputsize(size, column=None):
raise ProgrammingError('Unimplemented')
raise ProgrammingError("Unimplemented")

def _run_prior_DDL_statements(self):
return self._connection.run_prior_DDL_statements()
Expand All @@ -308,8 +328,16 @@ def get_table_column_schema(self, table_name):


class Column:
def __init__(self, name, type_code, display_size=None, internal_size=None,
precision=None, scale=None, null_ok=False):
def __init__(
self,
name,
type_code,
display_size=None,
internal_size=None,
precision=None,
scale=None,
null_ok=False,
):
self.name = name
self.type_code = type_code
self.display_size = display_size
Expand Down Expand Up @@ -338,14 +366,24 @@ def __getitem__(self, index):
return self.null_ok

def __str__(self):
rstr = ', '.join([field for field in [
"name='%s'" % self.name,
"type_code=%d" % self.type_code,
None if not self.display_size else "display_size=%d" % self.display_size,
None if not self.internal_size else "internal_size=%d" % self.internal_size,
None if not self.precision else "precision='%s'" % self.precision,
None if not self.scale else "scale='%s'" % self.scale,
None if not self.null_ok else "null_ok='%s'" % self.null_ok,
] if field])

return 'Column(%s)' % rstr
rstr = ", ".join(
[
field
for field in [
"name='%s'" % self.name,
"type_code=%d" % self.type_code,
None
if not self.display_size
else "display_size=%d" % self.display_size,
None
if not self.internal_size
else "internal_size=%d" % self.internal_size,
None if not self.precision else "precision='%s'" % self.precision,
None if not self.scale else "scale='%s'" % self.scale,
None if not self.null_ok else "null_ok='%s'" % self.null_ok,
]
if field
]
)

return "Column(%s)" % rstr