Skip to content

Commit

Permalink
Update type annotations in cursor.py
Browse files Browse the repository at this point in the history
  • Loading branch information
RA80533 committed May 10, 2023
1 parent 8f2353e commit 3644670
Showing 1 changed file with 140 additions and 42 deletions.
182 changes: 140 additions & 42 deletions aioodbc/cursor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Any, Callable, Coroutine, List, Optional, Tuple, TypeVar

import pyodbc

from .log import logger
Expand All @@ -8,6 +10,9 @@
__all__ = ["Cursor"]


_T = TypeVar("_T")


class Cursor:
"""Cursors represent a database cursor (and map to ODBC HSTMTs), which
is used to manage the context of a fetch operation.
Expand All @@ -17,13 +22,23 @@ class Cursor:
the other cursors.
"""

def __init__(self, pyodbc_cursor: pyodbc.Cursor, connection, echo=False):
def __init__(
self,
pyodbc_cursor: pyodbc.Cursor,
connection: Connection,
echo: bool = False,
) -> None:
self._conn = connection
self._impl: pyodbc.Cursor = pyodbc_cursor
self._loop = connection.loop
self._echo: bool = echo

async def _run_operation(self, func, *args, **kwargs):
async def _run_operation(
self,
func: Callable[..., _T],
*args: Any,
**kwargs: Any,
) -> _T:
# execute func in thread pool of attached to cursor connection
if not self._conn:
raise pyodbc.OperationalError("Cursor is closed.")
Expand All @@ -37,17 +52,17 @@ async def _run_operation(self, func, *args, **kwargs):
raise

@property
def echo(self):
def echo(self) -> bool:
"""Return echo mode status."""
return self._echo

@property
def connection(self):
def connection(self) -> Optional[Connection]:
"""Cursors database connection"""
return self._conn

@property
def autocommit(self):
def autocommit(self) -> bool:
"""Show autocommit mode for current database session. True if
connection is in autocommit mode; False otherwse. The default
is False.
Expand All @@ -56,12 +71,12 @@ def autocommit(self):
return self._conn.autocommit

@autocommit.setter
def autocommit(self, value):
def autocommit(self, value: bool) -> None:
assert self._conn is not None # mypy
self._conn.autocommit = value

@property
def rowcount(self):
def rowcount(self) -> int:
"""The number of rows modified by the previous DDL statement.
This is -1 if no SQL has been executed or if the number of rows is
Expand All @@ -73,7 +88,7 @@ def rowcount(self):
return self._impl.rowcount

@property
def description(self):
def description(self) -> Tuple[Tuple[str, Any, int, int, int, int, bool]]:
"""This read-only attribute is a list of 7-item tuples, each
containing (name, type_code, display_size, internal_size, precision,
scale, null_ok).
Expand All @@ -91,23 +106,23 @@ def description(self):
return self._impl.description

@property
def closed(self):
def closed(self) -> bool:
"""Read only property indicates if cursor has been closed"""
return self._conn is None

@property
def arraysize(self):
def arraysize(self) -> int:
"""This read/write attribute specifies the number of rows to fetch
at a time with .fetchmany() . It defaults to 1 meaning to fetch a
single row at a time.
"""
return self._impl.arraysize

@arraysize.setter
def arraysize(self, size):
def arraysize(self, size: int) -> None:
self._impl.arraysize = size

async def close(self):
async def close(self) -> None:
"""Close the cursor now (rather than whenever __del__ is called).
The cursor will be unusable from this point forward; an Error
Expand All @@ -119,7 +134,7 @@ async def close(self):
await self._run_operation(self._impl.close)
self._conn = None

async def execute(self, sql, *params):
async def execute(self, sql: str, *params: Any) -> Cursor:
"""Executes the given operation substituting any markers with
the given parameters.
Expand All @@ -136,7 +151,7 @@ async def execute(self, sql, *params):
await self._run_operation(self._impl.execute, sql, *params)
return self

def executemany(self, sql, *params):
def executemany(self, sql: str, *params: Any) -> Coroutine[Any, Any, None]:
"""Prepare a database query or command and then execute it against
all parameter sequences found in the sequence seq_of_params.
Expand All @@ -157,7 +172,7 @@ async def setoutputsize(self, *args, **kwargs):
"""Does nothing, required by DB API."""
return None

def fetchone(self):
def fetchone(self) -> Coroutine[Any, Any, Optional[pyodbc.Row]]:
"""Returns the next row or None when no more data is available.
A ProgrammingError exception is raised if no SQL has been executed
Expand All @@ -167,7 +182,7 @@ def fetchone(self):
fut = self._run_operation(self._impl.fetchone)
return fut

def fetchall(self):
def fetchall(self) -> Coroutine[Any, Any, List[pyodbc.Row]]:
"""Returns a list of all remaining rows.
Since this reads all rows into memory, it should not be used if
Expand All @@ -181,7 +196,7 @@ def fetchall(self):
fut = self._run_operation(self._impl.fetchall)
return fut

def fetchmany(self, size=0):
def fetchmany(self, size: int = 0) -> Coroutine[Any, Any, List[pyodbc.Row]]:
"""Returns a list of remaining rows, containing no more than size
rows, used to process results in chunks. The list will be empty when
there are no more rows.
Expand All @@ -200,7 +215,7 @@ def fetchmany(self, size=0):
fut = self._run_operation(self._impl.fetchmany, size)
return fut

def nextset(self):
def nextset(self) -> Coroutine[Any, Any, bool]:
"""This method will make the cursor skip to the next available
set, discarding any remaining rows from the current set.
Expand All @@ -214,7 +229,13 @@ def nextset(self):
fut = self._run_operation(self._impl.nextset)
return fut

def tables(self, **kw):
def tables(
self,
table: Optional[str] = None,
catalog: Optional[str] = None,
schema: Optional[str] = None,
tableType: Optional[str] = None,
) -> Coroutine[Any, Any, pyodbc.Cursor]:
"""Creates a result set of tables in the database that match the
given criteria.
Expand All @@ -223,10 +244,22 @@ def tables(self, **kw):
:param schema: the schmea name
:param tableType: one of TABLE, VIEW, SYSTEM TABLE ...
"""
fut = self._run_operation(self._impl.tables, **kw)
fut = self._run_operation(
self._impl.tables,
table=table,
catalog=catalog,
schema=schema,
tableType=tableType,
)
return fut

def columns(self, **kw):
def columns(
self,
table: Optional[str] = None,
catalog: Optional[str] = None,
schema: Optional[str] = None,
column: Optional[str] = None,
) -> Coroutine[Any, Any, pyodbc.Cursor]:
"""Creates a results set of column names in specified tables by
executing the ODBC SQLColumns function. Each row fetched has the
following columns.
Expand All @@ -236,10 +269,23 @@ def columns(self, **kw):
:param schema: the schmea name
:param column: string search pattern for column names.
"""
fut = self._run_operation(self._impl.columns, **kw)
fut = self._run_operation(
self._impl.columns,
table=table,
catalog=catalog,
schema=schema,
column=column,
)
return fut

def statistics(self, table, catalog=None, schema=None, unique=False, quick=True):
def statistics(
self,
table: str,
catalog: Optional[str] = None,
schema: Optional[str] = None,
unique: bool = False,
quick: bool = True,
) -> Coroutine[Any, Any, pyodbc.Cursor]:
"""Creates a results set of statistics about a single table and
the indexes associated with the table by executing SQLStatistics.
Expand All @@ -262,8 +308,12 @@ def statistics(self, table, catalog=None, schema=None, unique=False, quick=True)
return fut

def rowIdColumns(
self, table, catalog=None, schema=None, nullable=True # nopep8
):
self,
table: str,
catalog: Optional[str] = None,
schema: Optional[str] = None,
nullable: bool = True,
) -> Coroutine[Any, Any, pyodbc.Cursor]:
"""Executes SQLSpecialColumns with SQL_BEST_ROWID which creates a
result set of columns that uniquely identify a row
"""
Expand All @@ -277,8 +327,12 @@ def rowIdColumns(
return fut

def rowVerColumns(
self, table, catalog=None, schema=None, nullable=True # nopep8
):
self,
table: str,
catalog: Optional[str] = None,
schema: Optional[str] = None,
nullable: bool = True,
) -> Coroutine[Any, Any, pyodbc.Cursor]:
"""Executes SQLSpecialColumns with SQL_ROWVER which creates a
result set of columns that are automatically updated when any
value in the row is updated.
Expand All @@ -292,68 +346,112 @@ def rowVerColumns(
)
return fut

def primaryKeys(self, table, catalog=None, schema=None): # nopep8
def primaryKeys(
self,
table: str,
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> Coroutine[Any, Any, pyodbc.Cursor]:
"""Creates a result set of column names that make up the primary key
for a table by executing the SQLPrimaryKeys function."""
fut = self._run_operation(
self._impl.primaryKeys, table, catalog=catalog, schema=schema
)
return fut

def foreignKeys(self, *a, **kw): # nopep8
def foreignKeys(
self,
table: Optional[str] = None,
catalog: Optional[str] = None,
schema: Optional[str] = None,
foreignTable: Optional[str] = None,
foreignCatalog: Optional[str] = None,
foreignSchema: Optional[str] = None,
) -> Coroutine[Any, Any, pyodbc.Cursor]:
"""Executes the SQLForeignKeys function and creates a result set
of column names that are foreign keys in the specified table (columns
in the specified table that refer to primary keys in other tables)
or foreign keys in other tables that refer to the primary key in
the specified table.
"""
fut = self._run_operation(self._impl.foreignKeys, *a, **kw)
fut = self._run_operation(
self._impl.foreignKeys,
table=table,
catalog=catalog,
schema=schema,
foreignTable=foreignTable,
foreignCatalog=foreignCatalog,
foreignSchema=foreignSchema,
)
return fut

def getTypeInfo(self, sql_type): # nopep8
def getTypeInfo(
self,
sql_type: Optional[int] = None,
) -> Coroutine[Any, Any, pyodbc.Cursor]:
"""Executes SQLGetTypeInfo a creates a result set with information
about the specified data type or all data types supported by the
ODBC driver if not specified.
"""
fut = self._run_operation(self._impl.getTypeInfo, sql_type)
return fut

def procedures(self, *a, **kw):
def procedures(
self,
procedure: Optional[str] = None,
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> Coroutine[Any, Any, pyodbc.Cursor]:
"""Executes SQLProcedures and creates a result set of information
about the procedures in the data source.
"""
fut = self._run_operation(self._impl.procedures, *a, **kw)
fut = self._run_operation(
self._impl.procedures,
procedure=procedure,
catalog=catalog,
schema=schema,
)
return fut

def procedureColumns(self, *a, **kw): # nopep8
fut = self._run_operation(self._impl.procedureColumns, *a, **kw)
def procedureColumns(
self,
procedure: Optional[str] = None,
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> Coroutine[Any, Any, pyodbc.Cursor]:
fut = self._run_operation(
self._impl.procedureColumns,
procedure=procedure,
catalog=catalog,
schema=schema,
)
return fut

def skip(self, count):
def skip(self, count: int) -> Coroutine[Any, Any, None]:
fut = self._run_operation(self._impl.skip, count)
return fut

def commit(self):
def commit(self) -> Coroutine[Any, Any, None]:
fut = self._run_operation(self._impl.commit)
return fut

def rollback(self):
def rollback(self) -> Coroutine[Any, Any, None]:
fut = self._run_operation(self._impl.rollback)
return fut

def __aiter__(self):
def __aiter__(self) -> Cursor:
return self

async def __anext__(self):
async def __anext__(self) -> pyodbc.Row:
ret = await self.fetchone()
if ret is not None:
return ret
else:
raise StopAsyncIteration

async def __aenter__(self):
async def __aenter__(self) -> Cursor:
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
await self.close()
return

0 comments on commit 3644670

Please sign in to comment.