Skip to content

Commit

Permalink
Apply types on Pool class. (#418)
Browse files Browse the repository at this point in the history
* Add types for pool.

* Modify few make commands to run mypy check.
  • Loading branch information
jettify committed Apr 8, 2023
1 parent d58c7fd commit eea5dfb
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 32 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ jobs:
run: |
make lint
make checkfmt
make mypy
- name: Test
run: |
make ci
make cov
make run_examples
- name: Upload Coverage to Codecov
uses: codecov/codecov-action@v3
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ run_examples:
python examples/example_simple.py
python examples/example_complex_queries.py

ci: cov run_examples
ci: lint checkfmt mypy cov run_examples

checkbuild:
python setup.py sdist bdist_wheel
Expand Down
6 changes: 4 additions & 2 deletions aioodbc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import warnings
from concurrent.futures.thread import ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Optional

from pyodbc import dataSources as _dataSources
Expand Down Expand Up @@ -37,5 +37,7 @@ async def dataSources(
msg = "Explicit loop is deprecated, and has no effect."
warnings.warn(msg, DeprecationWarning, stacklevel=2)
loop = asyncio.get_event_loop()
sources = await loop.run_in_executor(executor, _dataSources)
sources: Dict[str, str] = await loop.run_in_executor(
executor, _dataSources
)
return sources
76 changes: 48 additions & 28 deletions aioodbc/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import asyncio
import collections
import warnings
from typing import Any, Deque, Dict, Optional, Set

from pyodbc import ProgrammingError

Expand All @@ -14,12 +15,18 @@
__all__ = ["create_pool", "Pool"]


class Pool(asyncio.AbstractServer):
class Pool:
"""Connection pool"""

def __init__(
self, minsize, maxsize, echo, pool_recycle, loop=None, **kwargs
):
self,
minsize: int,
maxsize: int,
echo: bool,
pool_recycle: int,
loop: Optional[asyncio.AbstractEventLoop] = None,
**kwargs: Dict[Any, Any],
) -> None:
if minsize < 0:
raise ValueError("minsize should be zero or greater")
if maxsize < minsize:
Expand All @@ -30,50 +37,51 @@ def __init__(
warnings.warn(msg, DeprecationWarning, stacklevel=2)

self._minsize = minsize
self._maxsize = maxsize
self._loop = asyncio.get_event_loop()
self._conn_kwargs = kwargs
self._acquiring = 0
self._recycle = pool_recycle
self._free = collections.deque(maxlen=maxsize)
self._free: Deque[Connection] = collections.deque(maxlen=maxsize)
self._cond = asyncio.Condition()
self._used = set()
self._used: Set[Connection] = set()
self._closing = False
self._closed = False
self._echo = echo

@property
def echo(self):
def echo(self) -> bool:
return self._echo

@property
def minsize(self):
def minsize(self) -> int:
return self._minsize

@property
def maxsize(self):
return self._free.maxlen
def maxsize(self) -> int:
return self._free.maxlen or self._maxsize

@property
def size(self):
def size(self) -> int:
return self.freesize + len(self._used) + self._acquiring

@property
def freesize(self):
def freesize(self) -> int:
return len(self._free)

@property
def closed(self):
def closed(self) -> bool:
return self._closed

async def clear(self):
async def clear(self) -> None:
"""Close all free connections in pool."""
async with self._cond:
while self._free:
conn = self._free.popleft()
await conn.close()
self._cond.notify()

def close(self):
def close(self) -> None:
"""Close pool.
Mark all pool connections to be closed on getting back to pool.
Expand All @@ -83,7 +91,7 @@ def close(self):
return
self._closing = True

async def wait_closed(self):
async def wait_closed(self) -> None:
"""Wait for closing all pool's connections."""

if self._closed:
Expand All @@ -103,12 +111,12 @@ async def wait_closed(self):

self._closed = True

def acquire(self):
def acquire(self) -> _ContextManager[Connection]:
"""Acquire free connection from the pool."""
coro = self._acquire()
return _ContextManager[Connection](coro, self.release)

async def _acquire(self):
async def _acquire(self) -> Connection:
if self._closing:
raise RuntimeError("Cannot acquire connection after closing pool")
async with self._cond:
Expand All @@ -123,7 +131,7 @@ async def _acquire(self):
else:
await self._cond.wait()

async def _fill_free_pool(self, override_min):
async def _fill_free_pool(self, override_min: bool) -> None:
n, free = 0, len(self._free)
while n < free:
conn = self._free[-1]
Expand Down Expand Up @@ -168,11 +176,11 @@ async def _fill_free_pool(self, override_min):
finally:
self._acquiring -= 1

async def _wakeup(self):
async def _wakeup(self) -> None:
async with self._cond:
self._cond.notify()

async def release(self, conn):
async def release(self, conn: Connection) -> None:
"""Release free connection back to the connection pool."""
assert conn in self._used, (conn, self._used)
self._used.remove(conn)
Expand All @@ -183,10 +191,12 @@ async def release(self, conn):
self._free.append(conn)
await self._wakeup()

async def __aenter__(self):
async def __aenter__(self) -> "Pool":
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
async def __aexit__(
self, exc_type: None, exc_val: None, exc_tb: None
) -> None:
self.close()
await self.wait_closed()

Expand All @@ -197,14 +207,19 @@ async def _destroy_pool(pool: "Pool") -> None:


async def _create_pool(
minsize=10, maxsize=10, echo=False, pool_recycle=-1, **kwargs
):
minsize: int = 10,
maxsize: int = 10,
echo: bool = False,
pool_recycle: int = -1,
**kwargs: Dict[Any, Any],
) -> Pool:
pool = Pool(
minsize=minsize,
maxsize=maxsize,
echo=echo,
pool_recycle=pool_recycle,
**kwargs
loop=None,
**kwargs,
)
if minsize > 0:
async with pool._cond:
Expand All @@ -213,8 +228,13 @@ async def _create_pool(


def create_pool(
minsize=10, maxsize=10, echo=False, loop=None, pool_recycle=-1, **kwargs
):
minsize: int = 10,
maxsize: int = 10,
echo: bool = False,
loop: None = None,
pool_recycle: int = -1,
**kwargs: Dict[Any, Any],
) -> _ContextManager[Pool]:
if loop is not None:
msg = "Explicit loop is deprecated, and has no effect."
warnings.warn(msg, DeprecationWarning, stacklevel=2)
Expand All @@ -225,7 +245,7 @@ def create_pool(
maxsize=maxsize,
echo=echo,
pool_recycle=pool_recycle,
**kwargs
**kwargs,
),
_destroy_pool,
)

0 comments on commit eea5dfb

Please sign in to comment.