Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: cursor must detect if the parent connection is closed #463

Merged
merged 17 commits into from Sep 1, 2020
Merged
149 changes: 94 additions & 55 deletions spanner_dbapi/__init__.py
Expand Up @@ -4,83 +4,122 @@
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

from google.cloud import spanner_v1 as spanner
"""Connection-based DB API for Cloud Spanner."""

from google.cloud import spanner_v1

from .connection import Connection
# These need to be included in the top-level package for PEP-0249 DB API v2.
from .exceptions import (
DatabaseError, DataError, Error, IntegrityError, InterfaceError,
InternalError, NotSupportedError, OperationalError, ProgrammingError,
DatabaseError,
DataError,
Error,
IntegrityError,
InterfaceError,
InternalError,
NotSupportedError,
OperationalError,
ProgrammingError,
Warning,
)
from .parse_utils import get_param_types
from .types import (
BINARY, DATETIME, NUMBER, ROWID, STRING, Binary, Date, DateFromTicks, Time,
TimeFromTicks, Timestamp, TimestampFromTicks,
BINARY,
DATETIME,
NUMBER,
ROWID,
STRING,
Binary,
Date,
DateFromTicks,
Time,
TimeFromTicks,
Timestamp,
TimestampFromTicks,
)
from .version import google_client_info

# Globals that MUST be defined ###
apilevel = "2.0" # Implements the Python Database API specification 2.0 version.
# We accept arguments in the format '%s' aka ANSI C print codes.
# as per https://www.python.org/dev/peps/pep-0249/#paramstyle
paramstyle = 'format'
# Threads may share the module but not connections. This is a paranoid threadsafety level,
# but it is necessary for starters to use when debugging failures. Eventually once transactions
# are working properly, we'll update the threadsafety level.
apilevel = "2.0" # supports DP-API 2.0 level.
paramstyle = "format" # ANSI C printf format codes, e.g. ...WHERE name=%s.

# Threads may share the module, but not connections. This is a paranoid threadsafety
# level, but it is necessary for starters to use when debugging failures.
# Eventually once transactions are working properly, we'll update the
# threadsafety level.
threadsafety = 1


def connect(project=None, instance=None, database=None, credentials_uri=None, user_agent=None):
def connect(instance_id, database_id, project=None, credentials=None, user_agent=None):
"""
Connect to Cloud Spanner.
Create a connection to Cloud Spanner database.

Args:
project: The id of a project that already exists.
instance: The id of an instance that already exists.
database: The name of a database that already exists.
credentials_uri: An optional string specifying where to retrieve the service
account JSON for the credentials to connect to Cloud Spanner.
:type instance_id: :class:`str`
:param instance_id: ID of the instance to connect to.

Returns:
The Connection object associated to the Cloud Spanner instance.
:type database_id: :class:`str`
:param database_id: The name of the database to connect to.

Raises:
Error if it encounters any unexpected inputs.
"""
if not project:
raise Error("'project' is required.")
if not instance:
raise Error("'instance' is required.")
if not database:
raise Error("'database' is required.")
:type project: :class:`str`
:param project: (Optional) The ID of the project which owns the
instances, tables and data. If not provided, will
attempt to determine from the environment.

client_kwargs = {
'project': project,
'client_info': google_client_info(user_agent),
}
if credentials_uri:
client = spanner.Client.from_service_account_json(credentials_uri, **client_kwargs)
else:
client = spanner.Client(**client_kwargs)
:type credentials: :class:`google.auth.credentials.Credentials`
:param credentials: (Optional) The authorization credentials to attach to requests.
These credentials identify this application to the service.
If none are specified, the client will attempt to ascertain
the credentials from the environment.

:rtype: :class:`google.cloud.spanner_dbapi.connection.Connection`
:returns: Connection object associated with the given Cloud Spanner resource.

:raises: :class:`ProgrammingError` in case of given instance/database
doesn't exist.
"""
client = spanner_v1.Client(
project=project,
credentials=credentials,
client_info=google_client_info(user_agent),
)

client_instance = client.instance(instance)
if not client_instance.exists():
raise ProgrammingError("instance '%s' does not exist." % instance)
instance = client.instance(instance_id)
if not instance.exists():
raise ProgrammingError("instance '%s' does not exist." % instance_id)

db = client_instance.database(database, pool=spanner.pool.BurstyPool())
if not db.exists():
raise ProgrammingError("database '%s' does not exist." % database)
database = instance.database(database_id, pool=spanner_v1.pool.BurstyPool())
if not database.exists():
raise ProgrammingError("database '%s' does not exist." % database_id)

return Connection(db)
return Connection(database)


__all__ = [
'DatabaseError', 'DataError', 'Error', 'IntegrityError', 'InterfaceError',
'InternalError', 'NotSupportedError', 'OperationalError', 'ProgrammingError',
'Warning', 'DEFAULT_USER_AGENT', 'apilevel', 'connect', 'paramstyle', 'threadsafety',
'get_param_types',
'Binary', 'Date', 'DateFromTicks', 'Time', 'TimeFromTicks', 'Timestamp',
'TimestampFromTicks',
'BINARY', 'STRING', 'NUMBER', 'DATETIME', 'ROWID', 'TimestampStr',
"DatabaseError",
"DataError",
"Error",
"IntegrityError",
"InterfaceError",
"InternalError",
"NotSupportedError",
"OperationalError",
"ProgrammingError",
"Warning",
"DEFAULT_USER_AGENT",
"apilevel",
"connect",
"paramstyle",
"threadsafety",
"get_param_types",
"Binary",
"Date",
"DateFromTicks",
"Time",
"TimeFromTicks",
"Timestamp",
"TimestampFromTicks",
"BINARY",
"STRING",
"NUMBER",
"DATETIME",
"ROWID",
"TimestampStr",
]
54 changes: 30 additions & 24 deletions spanner_dbapi/connection.py
Expand Up @@ -11,26 +11,31 @@
from .cursor import Cursor
from .exceptions import InterfaceError

ColumnDetails = namedtuple('column_details', ['null_ok', 'spanner_type'])
ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"])


class Connection:
def __init__(self, db_handle):
self._dbhandle = db_handle
self._closed = False
self._ddl_statements = []

self.is_closed = False

def cursor(self):
self.__raise_if_already_closed()
self._raise_if_already_closed()

return Cursor(self)

def __raise_if_already_closed(self):
"""
Raise an exception if attempting to use an already closed connection.
def _raise_if_already_closed(self):
"""Raise an exception if this connection is closed.

Helper to check the connection state before
running a SQL/DDL/DML query.

:raises: :class:`InterfaceError` if this connection is closed.
"""
if self._closed:
raise InterfaceError('connection already closed')
if self.is_closed:
raise InterfaceError("connection is already closed")

def __handle_update_ddl(self, ddl_statements):
"""
Expand All @@ -41,24 +46,24 @@ def __handle_update_ddl(self, ddl_statements):
Returns:
google.api_core.operation.Operation.result()
"""
self.__raise_if_already_closed()
self._raise_if_already_closed()
# Synchronously wait on the operation's completion.
return self._dbhandle.update_ddl(ddl_statements).result()

def read_snapshot(self):
self.__raise_if_already_closed()
self._raise_if_already_closed()
return self._dbhandle.snapshot()

def in_transaction(self, fn, *args, **kwargs):
self.__raise_if_already_closed()
self._raise_if_already_closed()
return self._dbhandle.run_in_transaction(fn, *args, **kwargs)

def append_ddl_statement(self, ddl_statement):
self.__raise_if_already_closed()
self._raise_if_already_closed()
self._ddl_statements.append(ddl_statement)

def run_prior_DDL_statements(self):
self.__raise_if_already_closed()
self._raise_if_already_closed()

if not self._ddl_statements:
return
Expand All @@ -69,14 +74,16 @@ def run_prior_DDL_statements(self):
return self.__handle_update_ddl(ddl_statements)

def list_tables(self):
return self.run_sql_in_snapshot("""
return self.run_sql_in_snapshot(
"""
SELECT
t.table_name
FROM
information_schema.tables AS t
WHERE
t.table_catalog = '' and t.table_schema = ''
""")
"""
)

def run_sql_in_snapshot(self, sql, params=None, param_types=None):
# Some SQL e.g. for INFORMATION_SCHEMA cannot be run in read-write transactions
Expand All @@ -89,38 +96,37 @@ def run_sql_in_snapshot(self, sql, params=None, param_types=None):

def get_table_column_schema(self, table_name):
rows = self.run_sql_in_snapshot(
'''SELECT
"""SELECT
COLUMN_NAME, IS_NULLABLE, SPANNER_TYPE
FROM
INFORMATION_SCHEMA.COLUMNS
WHERE
TABLE_SCHEMA = ''
AND
TABLE_NAME = @table_name''',
params={'table_name': table_name},
param_types={'table_name': spanner.param_types.STRING},
TABLE_NAME = @table_name""",
params={"table_name": table_name},
param_types={"table_name": spanner.param_types.STRING},
)

column_details = {}
for column_name, is_nullable, spanner_type in rows:
column_details[column_name] = ColumnDetails(
null_ok=is_nullable == 'YES',
spanner_type=spanner_type,
null_ok=is_nullable == "YES", spanner_type=spanner_type
)
return column_details

def close(self):
self.rollback()
self.__dbhandle = None
self._closed = True
self.is_closed = True

def commit(self):
self.__raise_if_already_closed()
self._raise_if_already_closed()

self.run_prior_DDL_statements()

def rollback(self):
self.__raise_if_already_closed()
self._raise_if_already_closed()

# TODO: to be added.

Expand Down