Skip to content

Commit

Permalink
Fix type issues for pyright 1.1.96
Browse files Browse the repository at this point in the history
  • Loading branch information
dbrattli committed Dec 26, 2020
1 parent f1899c5 commit 746d6b6
Show file tree
Hide file tree
Showing 9 changed files with 141 additions and 124 deletions.
58 changes: 29 additions & 29 deletions aioreactive/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from .observers import AsyncAnonymousObserver, AsyncAwaitableObserver, AsyncIteratorObserver, AsyncNotificationObserver
from .subject import AsyncSingleSubject, AsyncSubject
from .subscription import run
from .types import AsyncObserver, Stream
from .types import AsyncObserver, Projection

TSource = TypeVar("TSource")
TResult = TypeVar("TResult")
Expand Down Expand Up @@ -407,7 +407,7 @@ def as_chained(source: AsyncObservable[TSource]) -> AsyncRx[TSource]:
return AsyncRx(source)


def choose(chooser: Callable[[TSource], Option[TSource]]) -> Stream[TSource, TSource]:
def choose(chooser: Callable[[TSource], Option[TSource]]) -> Projection[TSource, TSource]:
"""Choose.
Applies the given function to each element of the stream and returns
Expand All @@ -427,7 +427,7 @@ def choose(chooser: Callable[[TSource], Option[TSource]]) -> Stream[TSource, TSo
return choose(chooser)


def choose_async(chooser: Callable[[TSource], Awaitable[Option[TSource]]]) -> Stream[TSource, TSource]:
def choose_async(chooser: Callable[[TSource], Awaitable[Option[TSource]]]) -> Projection[TSource, TSource]:
"""Choose async.
Applies the given async function to each element of the stream and
Expand All @@ -447,13 +447,13 @@ def choose_async(chooser: Callable[[TSource], Awaitable[Option[TSource]]]) -> St
return choose_async(chooser)


def combine_latest(other: AsyncObservable[TOther]) -> Stream[TSource, Tuple[TSource, TOther]]:
def combine_latest(other: AsyncObservable[TOther]) -> Projection[TSource, Tuple[TSource, TOther]]:
from .combine import combine_latest

return pipe(combine_latest(other))


def debounce(seconds: float) -> Stream[TSource, TSource]:
def debounce(seconds: float) -> Projection[TSource, TSource]:
"""Debounce source stream.
Ignores values from a source stream which are followed by another
Expand All @@ -475,13 +475,13 @@ def debounce(seconds: float) -> Stream[TSource, TSource]:
return debounce(seconds)


def catch(handler: Callable[[Exception], AsyncObservable[TSource]]) -> Stream[TSource, TSource]:
def catch(handler: Callable[[Exception], AsyncObservable[TSource]]) -> Projection[TSource, TSource]:
from .transform import catch

return catch(handler)


def concat(other: AsyncObservable[TSource]) -> Stream[TSource, TSource]:
def concat(other: AsyncObservable[TSource]) -> Projection[TSource, TSource]:
"""Concatenates an observable sequence with another observable
sequence."""

Expand Down Expand Up @@ -509,7 +509,7 @@ def defer(factory: Callable[[], AsyncObservable[TSource]]) -> AsyncObservable[TS
return defer(factory)


def delay(seconds: float) -> Stream[TSource, TSource]:
def delay(seconds: float) -> Projection[TSource, TSource]:
from .timeshift import delay

return delay(seconds)
Expand Down Expand Up @@ -547,7 +547,7 @@ def filter(predicate: Callable[[TSource], bool]) -> Callable[[AsyncObservable[TS
return _filter(predicate)


def filteri(predicate: Callable[[TSource, int], bool]) -> Stream[TSource, TSource]:
def filteri(predicate: Callable[[TSource, int], bool]) -> Projection[TSource, TSource]:
"""Filter with index.
Filters the elements of an observable sequence based on a predicate
Expand Down Expand Up @@ -594,19 +594,19 @@ def from_iterable(iterable: Iterable[TSource]) -> AsyncObservable[TSource]:
return of_seq(iterable)


def flat_map(mapper: Callable[[TSource], AsyncObservable[TResult]]) -> Stream[TSource, TResult]:
def flat_map(mapper: Callable[[TSource], AsyncObservable[TResult]]) -> Projection[TSource, TResult]:
from .transform import flat_map

return flat_map(mapper)


def flat_mapi(mapper: Callable[[TSource, int], AsyncObservable[TResult]]) -> Stream[TSource, TResult]:
def flat_mapi(mapper: Callable[[TSource, int], AsyncObservable[TResult]]) -> Projection[TSource, TResult]:
from .transform import flat_mapi

return flat_mapi(mapper)


def flat_map_async(mapper: Callable[[TSource], Awaitable[AsyncObservable[TResult]]]) -> Stream[TSource, TResult]:
def flat_map_async(mapper: Callable[[TSource], Awaitable[AsyncObservable[TResult]]]) -> Projection[TSource, TResult]:
"""Flap map async.
Asynchronously projects each element of an observable sequence into
Expand All @@ -626,7 +626,7 @@ def flat_map_async(mapper: Callable[[TSource], Awaitable[AsyncObservable[TResult
return flat_map_async(mapper)


def flat_map_latest_async(mapper: Callable[[TSource], Awaitable[AsyncObservable[TResult]]]) -> Stream[TSource, TResult]:
def flat_map_latest_async(mapper: Callable[[TSource], Awaitable[AsyncObservable[TResult]]]) -> Projection[TSource, TResult]:
"""Flat map latest async.
Asynchronosly transforms the items emitted by an source sequence
Expand Down Expand Up @@ -669,13 +669,13 @@ def interval(seconds: float, period: int) -> AsyncObservable[int]:
return interval(seconds, period)


def map(fn: Callable[[TSource], TResult]) -> Stream[TSource, TResult]:
def map(fn: Callable[[TSource], TResult]) -> Projection[TSource, TResult]:
from .transform import map as _map

return _map(fn)


def map_async(mapper: Callable[[TSource], Awaitable[TResult]]) -> Stream[TSource, TResult]:
def map_async(mapper: Callable[[TSource], Awaitable[TResult]]) -> Projection[TSource, TResult]:
"""Map asynchrnously.
Returns an observable sequence whose elements are the result of
Expand All @@ -686,7 +686,7 @@ def map_async(mapper: Callable[[TSource], Awaitable[TResult]]) -> Stream[TSource
return map_async(mapper)


def mapi_async(mapper: Callable[[TSource, int], Awaitable[TResult]]) -> Stream[TSource, TResult]:
def mapi_async(mapper: Callable[[TSource, int], Awaitable[TResult]]) -> Projection[TSource, TResult]:
"""Returns an observable sequence whose elements are the result of
invoking the async mapper function by incorporating the element's
index on each element of the source."""
Expand All @@ -695,7 +695,7 @@ def mapi_async(mapper: Callable[[TSource, int], Awaitable[TResult]]) -> Stream[T
return mapi_async(mapper)


def mapi(mapper: Callable[[TSource, int], TResult]) -> Stream[TSource, TResult]:
def mapi(mapper: Callable[[TSource, int], TResult]) -> Projection[TSource, TResult]:
"""Returns an observable sequence whose elements are the result of
invoking the mapper function and incorporating the element's index
on each element of the source."""
Expand All @@ -704,7 +704,7 @@ def mapi(mapper: Callable[[TSource, int], TResult]) -> Stream[TSource, TResult]:
return mapi(mapper)


def merge(other: AsyncObservable[TSource]) -> Stream[TSource, TSource]:
def merge(other: AsyncObservable[TSource]) -> Projection[TSource, TSource]:
from .create import of_seq

def _(source: AsyncObservable[TSource]) -> AsyncObservable[TSource]:
Expand All @@ -713,7 +713,7 @@ def _(source: AsyncObservable[TSource]) -> AsyncObservable[TSource]:
return _


def merge_inner(max_concurrent: int = 0) -> Stream[AsyncObservable[TSource], TSource]:
def merge_inner(max_concurrent: int = 0) -> Projection[AsyncObservable[TSource], TSource]:
def _merge_inner(source: AsyncObservable[AsyncObservable[TSource]]) -> AsyncObservable[TSource]:
from .combine import merge_inner

Expand All @@ -740,7 +740,7 @@ def of_async(workflow: Awaitable[TSource]) -> AsyncObservable[TSource]:
return of_async(workflow)


def retry(retry_count: int) -> Stream[TSource, TSource]:
def retry(retry_count: int) -> Projection[TSource, TSource]:
from .transform import retry

return retry(retry_count)
Expand All @@ -763,7 +763,7 @@ def single(value: TSource) -> "AsyncObservable[TSource]":
return single(value)


def skip(count: int) -> Stream[TSource, TSource]:
def skip(count: int) -> Projection[TSource, TSource]:
"""Skip items in the stream.
Bypasses a specified number of elements in an observable sequence
Expand All @@ -780,7 +780,7 @@ def skip(count: int) -> Stream[TSource, TSource]:
return skip(count)


def skip_last(count: int) -> Stream[TSource, TSource]:
def skip_last(count: int) -> Projection[TSource, TSource]:
"""Bypasses a specified number of elements at the end of an
observable sequence.
Expand All @@ -802,7 +802,7 @@ def skip_last(count: int) -> Stream[TSource, TSource]:
return skip_last(count)


def starfilter(predicate: Callable[..., bool]) -> Stream[TSource, Tuple[Any, ...]]:
def starfilter(predicate: Callable[..., bool]) -> Projection[TSource, Tuple[Any, ...]]:
"""Filter and spread the arguments to the predicate.
Filters the elements of an observable sequence based on a predicate.
Expand All @@ -815,7 +815,7 @@ def starfilter(predicate: Callable[..., bool]) -> Stream[TSource, Tuple[Any, ...
return starfilter(predicate)


def starmap(mapper: Callable[..., TResult]) -> Stream[TSource, TResult]:
def starmap(mapper: Callable[..., TResult]) -> Projection[TSource, TResult]:
"""Map and spread the arguments to the mapper.
Returns an observable sequence whose elements are the result of
Expand All @@ -826,13 +826,13 @@ def starmap(mapper: Callable[..., TResult]) -> Stream[TSource, TResult]:
return starmap(mapper)


def switch_latest() -> Stream[AsyncObservable[TSource], TSource]:
def switch_latest() -> Projection[AsyncObservable[TSource], TSource]:
from .transform import switch_latest

return switch_latest


def take(count: int) -> Stream[TSource, TSource]:
def take(count: int) -> Projection[TSource, TSource]:
"""Take the first elements from the stream.
Returns a specified number of contiguous elements from the start of
Expand All @@ -850,7 +850,7 @@ def take(count: int) -> Stream[TSource, TSource]:
return take(count)


def take_last(count: int) -> Stream[TSource, TSource]:
def take_last(count: int) -> Projection[TSource, TSource]:
"""Take last elements from stream.
Returns a specified number of contiguous elements from the end of an
Expand All @@ -868,7 +868,7 @@ def take_last(count: int) -> Stream[TSource, TSource]:
return take_last(count)


def take_until(other: AsyncObservable[TResult]) -> Stream[TSource, TSource]:
def take_until(other: AsyncObservable[TResult]) -> Projection[TSource, TSource]:
"""Take elements until other.
Returns the values from the source observable sequence until the
Expand Down Expand Up @@ -900,7 +900,7 @@ def to_async_iterable(source: AsyncObservable[TSource]) -> AsyncIterable[TSource
return to_async_iterable(source)


def with_latest_from(other: AsyncObservable[TOther]) -> Stream[TSource, Tuple[TSource, TOther]]:
def with_latest_from(other: AsyncObservable[TOther]) -> Projection[TSource, Tuple[TSource, TOther]]:
from .combine import with_latest_from

return with_latest_from(other)
Expand Down
10 changes: 4 additions & 6 deletions aioreactive/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .notification import Notification, OnError, OnNext
from .observables import AsyncAnonymousObservable
from .observers import AsyncAnonymousObserver, AsyncNotificationObserver, auto_detach_observer
from .types import AsyncObservable, AsyncObserver, Stream
from .types import AsyncObservable, AsyncObserver, Projection, Zipper

TSource = TypeVar("TSource")
TResult = TypeVar("TResult")
Expand Down Expand Up @@ -160,7 +160,7 @@ def concat_seq(sources: Iterable[AsyncObservable[TSource]]) -> AsyncObservable[T
)


def combine_latest(other: AsyncObservable[TOther]) -> Stream[TSource, Tuple[TSource, TOther]]:
def combine_latest(other: AsyncObservable[TOther]) -> Projection[TSource, Tuple[TSource, TOther]]:
"""Combine latest values.
Merges the specified observable sequences into one observable
Expand Down Expand Up @@ -241,7 +241,7 @@ async def obv_fn2(n: Notification[TOther]) -> None:
return _combine_latest


def with_latest_from(other: AsyncObservable[TOther]) -> Stream[TSource, Tuple[TSource, TOther]]:
def with_latest_from(other: AsyncObservable[TOther]) -> Zipper[TOther]:
"""[summary]
Merges the specified observable sequences into one observable
Expand Down Expand Up @@ -318,9 +318,7 @@ async def obv_fn2(n: Notification[TOther]) -> None:
return _with_latest_from


def zip_seq(
sequence: Iterable[TOther],
) -> Callable[[AsyncObservable[TSource]], AsyncObservable[Tuple[TSource, TOther]]]:
def zip_seq(sequence: Iterable[TOther]) -> Zipper[TOther]:
def _zip_seq(source: AsyncObservable[TSource]) -> AsyncObservable[Tuple[TSource, TOther]]:
async def subscribe_async(aobv: AsyncObserver[Tuple[TSource, TOther]]) -> AsyncDisposable:
safe_obv, auto_detach = auto_detach_observer(aobv)
Expand Down
24 changes: 10 additions & 14 deletions aioreactive/create.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import logging
from asyncio import Future
from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Optional, Tuple, TypeVar, cast
from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Optional, Tuple, TypeVar

from expression.core import TailCallResult, aiotools, tailrec_async
from expression.core.fn import TailCall
Expand Down Expand Up @@ -35,14 +35,12 @@ def create(subscribe: Callable[[AsyncObserver[TSource]], Awaitable[AsyncDisposab
return AsyncAnonymousObservable(subscribe)


def of_async_worker(
worker: Callable[[AsyncObserver[TSource], CancellationToken], Awaitable[None]]
) -> AsyncObservable[TSource]:
def of_async_worker(worker: Callable[[AsyncObserver[Any], CancellationToken], Awaitable[None]]) -> AsyncObservable[Any]:
"""Create async observable from async worker function"""

log.debug("of_async_worker()")

async def subscribe_async(aobv: AsyncObserver[TSource]) -> AsyncDisposable:
async def subscribe_async(aobv: AsyncObserver[Any]) -> AsyncDisposable:
log.debug("of_async_worker:subscribe_async()")
disposable, token = canceller()
safe_obv = safe_observer(aobv, disposable)
Expand All @@ -66,8 +64,7 @@ async def worker(obv: AsyncObserver[TSource], _: CancellationToken) -> None:
finally:
await obv.aclose()

ret = of_async_worker(worker)
return cast(AsyncObservable[TSource], ret) # NOTE: pyright issue
return of_async_worker(worker)


def of_async_iterable(iterable: AsyncIterable[TSource]) -> AsyncObservable[TSource]:
Expand Down Expand Up @@ -113,20 +110,20 @@ async def subscribe_async(aobv: AsyncObserver[TSource]) -> AsyncDisposable:
return AsyncAnonymousObservable(subscribe_async)


def empty() -> AsyncObservable[TSource]:
def empty() -> AsyncObservable[Any]:
"""Returns an observable sequence with no elements."""

async def subscribe_async(aobv: AsyncObserver[TSource]) -> AsyncDisposable:
async def subscribe_async(aobv: AsyncObserver[Any]) -> AsyncDisposable:
await aobv.aclose()
return AsyncDisposable.empty()

return AsyncAnonymousObservable(subscribe_async)


def never() -> AsyncObservable[TSource]:
def never() -> AsyncObservable[Any]:
"""Returns an empty observable sequence that never completes."""

async def subscribe_async(_: AsyncObserver[TSource]) -> AsyncDisposable:
async def subscribe_async(_: AsyncObserver[Any]) -> AsyncDisposable:
return AsyncDisposable.empty()

return AsyncAnonymousObservable(subscribe_async)
Expand All @@ -136,7 +133,7 @@ def fail(error: Exception) -> AsyncObservable[Any]:
"""Returns the observable sequence that terminates exceptionally
with the specified exception."""

async def worker(obv: AsyncObserver[TSource], _: CancellationToken) -> None:
async def worker(obv: AsyncObserver[Any], _: CancellationToken) -> None:
await obv.athrow(error)

return of_async_worker(worker)
Expand All @@ -161,8 +158,7 @@ async def worker(obv: AsyncObserver[TSource], token: CancellationToken) -> None:

await obv.aclose()

ret = of_async_worker(worker)
return cast(AsyncObservable[TSource], ret) # NOTE: pyright issue
return of_async_worker(worker)


def defer(factory: Callable[[], AsyncObservable[TSource]]) -> AsyncObservable[TSource]:
Expand Down

0 comments on commit 746d6b6

Please sign in to comment.