Skip to content

Commit

Permalink
Fix leave_context_asyncio to handle cancelled asyncio task
Browse files Browse the repository at this point in the history
  • Loading branch information
dkang-quora committed Apr 17, 2024
1 parent c61a61e commit e1a965c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 10 deletions.
4 changes: 2 additions & 2 deletions asynq/contexts.py
Expand Up @@ -93,7 +93,7 @@ def leave_context_asyncio(context):
debug.write("@async: -context: %s" % debug.str(context))

task = asyncio.current_task()
del getattr(task, ASYNCIO_CONTEXT_FIELD)[id(context)] # type: ignore
getattr(task, ASYNCIO_CONTEXT_FIELD, {}).pop(id(context), None) # type: ignore


def pause_contexts_asyncio(task):
Expand All @@ -104,7 +104,7 @@ def pause_contexts_asyncio(task):


def resume_contexts_asyncio(task):
if not getattr(task, ASYNCIO_CONTEXT_ACTIVE_FIELD, False):
if not getattr(task, ASYNCIO_CONTEXT_ACTIVE_FIELD, True):
setattr(task, ASYNCIO_CONTEXT_ACTIVE_FIELD, True)
for ctx in getattr(task, ASYNCIO_CONTEXT_FIELD, {}).values():
ctx.resume()
Expand Down
21 changes: 13 additions & 8 deletions asynq/decorators.py
Expand Up @@ -15,7 +15,7 @@

import asyncio
import inspect
from typing import Any, Coroutine
from typing import Any, Callable, Coroutine, Generic, TypeVar

import qcore.decorators
import qcore.helpers as core_helpers
Expand All @@ -27,6 +27,8 @@

__traceback_hide__ = True

_T = TypeVar("_T")


def lazy(fn):
"""Converts a function into a lazy one - i.e. its call
Expand Down Expand Up @@ -147,10 +149,11 @@ def is_pure_async_fn(self):
return True


class PureAsyncDecorator(qcore.decorators.DecoratorBase):
class PureAsyncDecorator(qcore.decorators.DecoratorBase, Generic[_T]):
binder_cls = PureAsyncDecoratorBinder
fn: Callable[..., _T]

def __init__(self, fn, task_cls, kwargs={}, asyncio_fn=None):
def __init__(self, fn: Callable[..., _T], task_cls, kwargs={}, asyncio_fn=None):
qcore.decorators.DecoratorBase.__init__(self, fn)
self.task_cls = task_cls
self.needs_wrapper = core_inspection.is_cython_or_generator(fn)
Expand Down Expand Up @@ -188,21 +191,23 @@ def _call_pure(self, args, kwargs):
return self.task_cls(result, self.fn, args, kwargs, **self.kwargs)


class AsyncDecoratorBinder(qcore.decorators.DecoratorBinder):
def asynq(self, *args, **kwargs):
class AsyncDecoratorBinder(qcore.decorators.DecoratorBinder, Generic[_T]):
decorator: "AsyncDecorator[_T]"

def asynq(self, *args, **kwargs) -> async_task.AsyncTask[_T]:
if self.instance is None:
return self.decorator.asynq(*args, **kwargs)
else:
return self.decorator.asynq(self.instance, *args, **kwargs)

def asyncio(self, *args, **kwargs) -> Coroutine[Any, Any, Any]:
def asyncio(self, *args, **kwargs) -> Coroutine[Any, Any, _T]:
if self.instance is None:
return self.decorator.asyncio(*args, **kwargs)
else:
return self.decorator.asyncio(self.instance, *args, **kwargs)


class AsyncDecorator(PureAsyncDecorator):
class AsyncDecorator(PureAsyncDecorator[_T]):
binder_cls = AsyncDecoratorBinder

def __init__(self, fn, cls, kwargs={}, asyncio_fn=None):
Expand All @@ -211,7 +216,7 @@ def __init__(self, fn, cls, kwargs={}, asyncio_fn=None):
def is_pure_async_fn(self):
return False

def asynq(self, *args, **kwargs):
def asynq(self, *args: Any, **kwargs: Any) -> async_task.AsyncTask[_T]:
return self._call_pure(args, kwargs)

def name(self):
Expand Down

0 comments on commit e1a965c

Please sign in to comment.