Skip to content

Commit

Permalink
Add asyncio_fn to async_proxy() and DeduplicateDecorator (#132)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkang-quora committed Nov 15, 2023
1 parent fdcd08d commit 2c696bc
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 19 deletions.
18 changes: 9 additions & 9 deletions asynq/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,17 +248,17 @@ def __get__(self, owner, cls):


class AsyncProxyDecorator(AsyncDecorator):
def __init__(self, fn):
def __init__(self, fn, asyncio_fn=None):
# we don't need the task class but still need to pass it to the superclass
AsyncDecorator.__init__(self, fn, None)
AsyncDecorator.__init__(self, fn, None, asyncio_fn=asyncio_fn)

def _call_pure(self, args, kwargs):
return self.fn(*args, **kwargs)


class AsyncAndSyncPairProxyDecorator(AsyncProxyDecorator):
def __init__(self, fn, sync_fn):
AsyncProxyDecorator.__init__(self, fn)
def __init__(self, fn, sync_fn, asyncio_fn=None):
AsyncProxyDecorator.__init__(self, fn, asyncio_fn=asyncio_fn)
self.sync_fn = sync_fn

def __call__(self, *args, **kwargs):
Expand Down Expand Up @@ -297,19 +297,19 @@ def decorate(fn):
return decorate


def async_proxy(pure=False, sync_fn=None):
def async_proxy(pure=False, sync_fn=None, asyncio_fn=None):
if sync_fn is not None:
assert pure is False, "sync_fn=? cannot be used together with pure=True"

def decorate(fn):
if pure:
return fn
if sync_fn is None:
return qcore.decorators.decorate(AsyncProxyDecorator)(fn)
return qcore.decorators.decorate(AsyncProxyDecorator, asyncio_fn)(fn)
else:
return qcore.decorators.decorate(AsyncAndSyncPairProxyDecorator, sync_fn)(
fn
)
return qcore.decorators.decorate(
AsyncAndSyncPairProxyDecorator, sync_fn, asyncio_fn
)(fn)

return decorate

Expand Down
37 changes: 36 additions & 1 deletion asynq/tests/test_asynq_to_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from qcore.asserts import assert_eq

import asynq
from asynq.tools import AsyncTimer
from asynq import ConstFuture
from asynq.tools import AsyncTimer, deduplicate


def test_asyncio():
Expand Down Expand Up @@ -118,3 +119,37 @@ def i():

assert i() == 100
assert asyncio.run(i.asyncio()) == 100


def test_proxy():
async def k(x):
return x + 999

@asynq.async_proxy(asyncio_fn=k)
def j(x):
return ConstFuture(x + 888)

assert j(-100) == 788
assert j.asynq(-200).value() == 688
assert asyncio.run(j.asyncio(-300)) == 699


def test_deduplicate():
@deduplicate()
@asynq.asynq()
def l():
return 3

async def n():
return 4

@deduplicate()
@asynq.asynq(asyncio_fn=n)
def m():
return 3

assert l() == 3
assert asyncio.run(l.asyncio()) == 3

assert m() == 3
assert asyncio.run(m.asyncio()) == 4
27 changes: 18 additions & 9 deletions asynq/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,23 @@
"""

from .contexts import AsyncContext
from .decorators import asynq, async_call, AsyncDecorator, AsyncDecoratorBinder

from qcore import get_original_fn, utime
from qcore.caching import get_args_tuple, get_kwargs_defaults, LRUCache
from qcore.events import EventHook
from qcore.errors import reraise, prepare_for_reraise
from qcore.decorators import decorate
import functools
import inspect
import itertools
import weakref
import threading
import time
import weakref
from typing import Any, Awaitable

from qcore import get_original_fn, utime
from qcore.caching import LRUCache, get_args_tuple, get_kwargs_defaults
from qcore.decorators import decorate
from qcore.errors import prepare_for_reraise, reraise
from qcore.events import EventHook

from .asynq_to_async import is_asyncio_mode
from .contexts import AsyncContext
from .decorators import AsyncDecorator, AsyncDecoratorBinder, async_call, asynq


@asynq()
Expand Down Expand Up @@ -346,7 +349,13 @@ def __init__(self, fn, task_cls, keygetter):
def cache_key(self, args, kwargs):
return self.keygetter(args, kwargs), threading.current_thread(), id(self.fn)

def asyncio(self, *args, **kwargs) -> Awaitable[Any]:
return self.fn.asyncio(*args, **kwargs)

def asynq(self, *args, **kwargs):
if is_asyncio_mode():
return self.fn.asyncio(*args, **kwargs)

cache_key = self.cache_key(args, kwargs)

try:
Expand Down

0 comments on commit 2c696bc

Please sign in to comment.