diff --git a/spanner_dbapi/connection.py b/spanner_dbapi/connection.py index 1abc68519f..869586e363 100644 --- a/spanner_dbapi/connection.py +++ b/spanner_dbapi/connection.py @@ -17,20 +17,25 @@ class Connection: def __init__(self, db_handle): self._dbhandle = db_handle - self._closed = False self._ddl_statements = [] + self.is_closed = False + def cursor(self): - self.__raise_if_already_closed() + self._raise_if_closed() return Cursor(self) - def __raise_if_already_closed(self): - """ - Raise an exception if attempting to use an already closed connection. + def _raise_if_closed(self): + """Raise an exception if this connection is closed. + + Helper to check the connection state before + running a SQL/DDL/DML query. + + :raises: :class:`InterfaceError` if this connection is closed. """ - if self._closed: - raise InterfaceError("connection already closed") + if self.is_closed: + raise InterfaceError("connection is already closed") def __handle_update_ddl(self, ddl_statements): """ @@ -41,24 +46,24 @@ def __handle_update_ddl(self, ddl_statements): Returns: google.api_core.operation.Operation.result() """ - self.__raise_if_already_closed() + self._raise_if_closed() # Synchronously wait on the operation's completion. return self._dbhandle.update_ddl(ddl_statements).result() def read_snapshot(self): - self.__raise_if_already_closed() + self._raise_if_closed() return self._dbhandle.snapshot() def in_transaction(self, fn, *args, **kwargs): - self.__raise_if_already_closed() + self._raise_if_closed() return self._dbhandle.run_in_transaction(fn, *args, **kwargs) def append_ddl_statement(self, ddl_statement): - self.__raise_if_already_closed() + self._raise_if_closed() self._ddl_statements.append(ddl_statement) def run_prior_DDL_statements(self): - self.__raise_if_already_closed() + self._raise_if_closed() if not self._ddl_statements: return @@ -113,17 +118,21 @@ def get_table_column_schema(self, table_name): return column_details def close(self): + """Close this connection. + + The connection will be unusable from this point forward. + """ self.rollback() self.__dbhandle = None - self._closed = True + self.is_closed = True def commit(self): - self.__raise_if_already_closed() + self._raise_if_closed() self.run_prior_DDL_statements() def rollback(self): - self.__raise_if_already_closed() + self._raise_if_closed() # TODO: to be added. diff --git a/spanner_dbapi/cursor.py b/spanner_dbapi/cursor.py index 4dcc69df77..10e5184ed2 100644 --- a/spanner_dbapi/cursor.py +++ b/spanner_dbapi/cursor.py @@ -4,7 +4,14 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -import google.api_core.exceptions as grpc_exceptions +"""Database cursor API.""" + +from google.api_core.exceptions import ( + AlreadyExists, + FailedPrecondition, + InternalServerError, + InvalidArgument, +) from google.cloud.spanner_v1 import param_types from .exceptions import ( @@ -47,16 +54,21 @@ class Cursor: + """ + Database cursor to manage the context of a fetch operation. + + :type connection: :class:`spanner_dbapi.connection.Connection` + :param connection: Parent connection object for this Cursor. + """ + def __init__(self, connection): self._itr = None self._res = None self._row_count = _UNSET_COUNT self._connection = connection - self._closed = False + self._is_closed = False - # arraysize is a readable and writable property mandated - # by PEP-0249 https://www.python.org/dev/peps/pep-0249/#arraysize - # It determines the results of .fetchmany + # the number of rows to fetch at a time with fetchmany() self.arraysize = 1 def execute(self, sql, args=None): @@ -69,7 +81,7 @@ def execute(self, sql, args=None): Returns: None """ - self._raise_if_already_closed() + self._raise_if_closed() if not self._connection: raise ProgrammingError("Cursor is not connected to the database") @@ -93,14 +105,11 @@ def execute(self, sql, args=None): self.__handle_insert(sql, args or None) else: self.__handle_update(sql, args or None) - except ( - grpc_exceptions.AlreadyExists, - grpc_exceptions.FailedPrecondition, - ) as e: + except (AlreadyExists, FailedPrecondition) as e: raise IntegrityError(e.details if hasattr(e, "details") else e) - except grpc_exceptions.InvalidArgument as e: + except InvalidArgument as e: raise ProgrammingError(e.details if hasattr(e, "details") else e) - except grpc_exceptions.InternalServerError as e: + except InternalServerError as e: raise OperationalError(e.details if hasattr(e, "details") else e) def __handle_update(self, sql, params): @@ -228,16 +237,35 @@ def description(self): def rowcount(self): return self._row_count - def _raise_if_already_closed(self): + @property + def is_closed(self): + """The cursor close indicator. + + :rtype: :class:`bool` + :returns: True if this cursor or it's parent connection is closed, False + otherwise. """ - Raise an exception if attempting to use an already closed connection. + return self._is_closed or self._connection.is_closed + + def _raise_if_closed(self): + """Raise an exception if this cursor is closed. + + Helper to check this cursor's state before running a + SQL/DDL/DML query. If the parent connection is + already closed it also raises an error. + + :raises: :class:`InterfaceError` if this cursor is closed. """ - if self._closed: - raise InterfaceError("cursor already closed") + if self.is_closed: + raise InterfaceError("cursor is already closed") def close(self): + """Close this cursor. + + The cursor will be unusable from this point forward. + """ self.__clear() - self._closed = True + self._is_closed = True def executemany(self, operation, seq_of_params): if not self._connection: @@ -257,7 +285,7 @@ def __iter__(self): return self._itr def fetchone(self): - self._raise_if_already_closed() + self._raise_if_closed() try: return next(self) @@ -265,7 +293,7 @@ def fetchone(self): return None def fetchall(self): - self._raise_if_already_closed() + self._raise_if_closed() return list(self.__iter__()) @@ -282,7 +310,7 @@ def fetchmany(self, size=None): Error if the previous call to .execute*() did not produce any result set or if no call was issued yet. """ - self._raise_if_already_closed() + self._raise_if_closed() if size is None: size = self.arraysize diff --git a/tests/spanner_dbapi/test_connection.py b/tests/spanner_dbapi/test_connection.py new file mode 100644 index 0000000000..ab72f799df --- /dev/null +++ b/tests/spanner_dbapi/test_connection.py @@ -0,0 +1,32 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""Connection() class unit tests.""" + +import unittest +from unittest import mock + +from spanner_dbapi import connect, InterfaceError + + +class TestConnection(unittest.TestCase): + def test_close(self): + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + self.assertFalse(connection.is_closed) + connection.close() + self.assertTrue(connection.is_closed) + + with self.assertRaises(InterfaceError): + connection.cursor() diff --git a/tests/spanner_dbapi/test_cursor.py b/tests/spanner_dbapi/test_cursor.py new file mode 100644 index 0000000000..6bf6bb27e4 --- /dev/null +++ b/tests/spanner_dbapi/test_cursor.py @@ -0,0 +1,54 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""Cursor() class unit tests.""" + +import unittest +from unittest import mock + +from spanner_dbapi import connect, InterfaceError + + +class TestCursor(unittest.TestCase): + def test_close(self): + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + self.assertFalse(cursor.is_closed) + + cursor.close() + + self.assertTrue(cursor.is_closed) + with self.assertRaises(InterfaceError): + cursor.execute("SELECT * FROM database") + + def test_connection_closed(self): + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + self.assertFalse(cursor.is_closed) + + connection.close() + + self.assertTrue(cursor.is_closed) + with self.assertRaises(InterfaceError): + cursor.execute("SELECT * FROM database")