Skip to content

Commit

Permalink
Refactor for changes to pattern matching
Browse files Browse the repository at this point in the history
  • Loading branch information
dbrattli committed Dec 9, 2020
1 parent 8058d0f commit 6cd5b39
Show file tree
Hide file tree
Showing 10 changed files with 149 additions and 169 deletions.
50 changes: 25 additions & 25 deletions aioreactive/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
MailboxProcessor,
Nothing,
Option,
TailCallResult,
Some,
TailCall,
TailCallResult,
match,
pipe,
tailrec_async,
Expand Down Expand Up @@ -58,7 +58,7 @@ async def subscribe_async(aobv: AsyncObserver[TSource]) -> AsyncDisposable:
key=Key(0),
)

async def worker(inbox: MailboxProcessor[Msg]) -> None:
async def worker(inbox: MailboxProcessor[Msg[TSource]]) -> None:
def obv(key: Key) -> AsyncObserver[TSource]:
async def asend(value: TSource) -> None:
await safe_obv.asend(value)
Expand All @@ -71,19 +71,19 @@ async def aclose() -> None:

return AsyncAnonymousObserver(asend, athrow, aclose)

async def update(msg: Msg, model: Model[TSource]) -> Model[TSource]:
async def update(msg: Msg[TSource], model: Model[TSource]) -> Model[TSource]:
# log.debug("update: %s, model: %s", msg, model)
with match(msg) as m:
for xs in InnerObservableMsg.case(m):
with match(msg) as case:
for xs in case(InnerObservableMsg[TSource]):
if max_concurrent == 0 or len(model.subscriptions) < max_concurrent:
inner = await xs.subscribe_async(obv(model.key))
return model.replace(
subscriptions=model.subscriptions.add(model.key, inner),
key=Key(model.key + 1),
)

return model.replace(queue=model.queue.append(xs))
for key in InnerCompletedMsg.case(m):
lst = FrozenList.singleton(xs)
return model.replace(queue=model.queue.append(lst))
for key in case(InnerCompletedMsg):
subscriptions = model.subscriptions.remove(key)
if len(model.queue):
xs = model.queue[0]
Expand All @@ -100,14 +100,14 @@ async def update(msg: Msg, model: Model[TSource]) -> Model[TSource]:
if model.is_stopped:
await safe_obv.aclose()
return model.replace(subscriptions=map.empty)
while CompletedMsg.case(m):
while case(CompletedMsg):
if not model.subscriptions:
log.debug("merge_inner: closing!")
await safe_obv.aclose()

return model.replace(is_stopped=True)

while m.default():
while case.default():
for dispose in model.subscriptions.values():
await dispose.dispose_async()

Expand Down Expand Up @@ -179,7 +179,7 @@ def _combine_latest(source: AsyncObservable[TSource]) -> AsyncObservable[Tuple[T
async def subscribe_async(aobv: AsyncObserver[Tuple[TSource, TOther]]) -> AsyncDisposable:
safe_obv, auto_detach = auto_detach_observer(aobv)

async def worker(inbox: MailboxProcessor[Msg]) -> None:
async def worker(inbox: MailboxProcessor[Msg[TSource]]) -> None:
@tailrec_async
async def message_loop(
source_value: Option[TSource], other_value: Option[TOther]
Expand All @@ -188,24 +188,24 @@ async def message_loop(

async def get_value(n: Notification[Any]) -> Option[Any]:
with match(n) as m:
for value in OnNext.case(m):
for value in case(OnNext[TSource]):
return Some(value)

for err in OnError.case(m):
for err in case(OnError):
await safe_obv.athrow(err)

while m.default():
await safe_obv.aclose()
return Nothing

m = match(cn)
for value in SourceMsg.case(m):
source_value = await get_value(value)
break
with match(cn) as case:
for value in case(SourceMsg[TSource]):
source_value = await get_value(value)
break

for value in OtherMsg.case(m):
other_value = await get_value(value)
break
for value in case(OtherMsg[TOther]):
other_value = await get_value(value)
break

def binder(s: TSource) -> Option[Tuple[TSource, TOther]]:
def mapper(o: TOther) -> Tuple[TSource, TOther]:
Expand Down Expand Up @@ -260,20 +260,20 @@ def _with_latest_from(source: AsyncObservable[TSource]) -> AsyncObservable[Tuple
async def subscribe_async(aobv: AsyncObserver[Tuple[TSource, TOther]]) -> AsyncDisposable:
safe_obv, auto_detach = auto_detach_observer(aobv)

async def worker(inbox: MailboxProcessor[Msg]) -> None:
async def worker(inbox: MailboxProcessor[Msg[TSource]]) -> None:
@tailrec_async
async def message_loop(latest: Option[TOther]) -> TailCallResult[NoReturn]:
cn = await inbox.receive()

async def get_value(n: Notification[Any]) -> Option[Any]:
with match(n) as m:
for value in OnNext.case(m):
with match(n) as case:
for value in case(OnNext[TSource]):
return Some(value)

for err in OnError.case(m):
for err in case(OnError[TSource]):
await safe_obv.athrow(err)

while m.default():
while case.default():
await safe_obv.aclose()
return Nothing

Expand Down
4 changes: 2 additions & 2 deletions aioreactive/create.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import logging
from asyncio import Future
from typing import AsyncIterable, Awaitable, Callable, Iterable, Optional, Tuple, TypeVar
from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Optional, Tuple, TypeVar

from expression.core import TailCallResult, aiotools, tailrec_async
from expression.core.fn import TailCall
Expand Down Expand Up @@ -131,7 +131,7 @@ async def subscribe_async(_: AsyncObserver[TSource]) -> AsyncDisposable:
return AsyncAnonymousObservable(subscribe_async)


def fail(error: Exception) -> AsyncObservable[TSource]:
def fail(error: Exception) -> AsyncObservable[Any]:
"""Returns the observable sequence that terminates exceptionally
with the specified exception."""

Expand Down
10 changes: 5 additions & 5 deletions aioreactive/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from expression.core import (
MailboxProcessor,
Option,
TailCallResult,
TailCall,
TailCallResult,
aiotools,
compose,
fst,
Expand Down Expand Up @@ -126,19 +126,19 @@ async def message_loop(latest: Notification[TSource]) -> TailCallResult[NoReturn
n = await inbox.receive()

async def get_latest() -> Notification[TSource]:
with match(n) as m:
for x in OnNext.case(m):
with match(n) as case:
for x in case(OnNext[TSource]):
if n == latest:
break
try:
await safe_obv.asend(x)
except Exception as ex:
await safe_obv.athrow(ex)
break
for err in OnError.case(m):
for err in case(OnError[TSource]):
await safe_obv.athrow(err)
break
while m.case(OnCompleted):
while case(OnCompleted):
await safe_obv.aclose()
break

Expand Down
133 changes: 67 additions & 66 deletions aioreactive/msg.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Internal messages used by mailbox processors. Do not import or use.
"""
from abc import abstractclassmethod
from abc import ABC
from dataclasses import dataclass
from typing import Any, Generic, Iterable, NewType, TypeVar
from typing import Any, Iterable, NewType, Type, TypeVar, get_origin

from expression.core import Matcher
from expression.core import SupportsMatch
from expression.system import AsyncDisposable

from .notification import Notification
Expand All @@ -16,123 +16,124 @@
Key = NewType("Key", int)


class Msg:
"""Message base class.
Contains overloads for pattern matching to avoid any type casting
later.
"""

@abstractclassmethod
def case(cls, matcher: Matcher) -> Any:
raise NotImplementedError
class Msg(SupportsMatch[TSource], ABC):
"""Message base class."""


@dataclass
class SourceMsg(Msg, Generic[TSource]):
class SourceMsg(Msg[Notification[TSource]], SupportsMatch[TSource]):
value: Notification[TSource]

@classmethod
def case(cls, matcher: Matcher) -> Iterable[Notification[TSource]]:
"""Helper to cast the match result to correct type."""
return matcher.case(cls)

def __match__(self, pattern: Any) -> Iterable[Notification[TSource]]:
if isinstance(self, pattern):
return [self.value]
origin: Any = get_origin(pattern)
try:
if isinstance(self, origin or pattern):
return [self.value]
except TypeError:
pass
return []


@dataclass
class OtherMsg(Msg, Generic[TOther]):
class OtherMsg(Msg[Notification[TOther]], SupportsMatch[TOther]):
value: Notification[TOther]

@classmethod
def case(cls, matcher: Matcher) -> Iterable[Notification[TOther]]:
"""Helper to cast the match result to correct type."""

return matcher.case(cls)

def __match__(self, pattern: Any) -> Iterable[Notification[TOther]]:
if isinstance(self, pattern):
return [self.value]
origin: Any = get_origin(pattern)
try:
if isinstance(self, origin or pattern):
return [self.value]
except TypeError:
pass
return []


@dataclass
class DisposableMsg(Msg):
class DisposableMsg(Msg[AsyncDisposable], SupportsMatch[AsyncDisposable]):
"""Message containing a diposable."""

disposable: AsyncDisposable

@classmethod
def case(cls, matcher: Matcher) -> Iterable[Notification[AsyncDisposable]]:
"""Helper to cast the match result to correct type."""

return matcher.case(cls)

def __match__(self, pattern: Any) -> Iterable[AsyncDisposable]:
if isinstance(self, pattern):
return [self.disposable]
try:
if isinstance(self, pattern):
return [self.disposable]
except TypeError:
pass
return []


@dataclass
class InnerObservableMsg(Msg, Generic[TSource]):
class InnerObservableMsg(Msg[AsyncObservable[TSource]], SupportsMatch[AsyncObservable[TSource]]):
"""Message containing an inner observable."""

inner_observable: AsyncObservable[TSource]

@classmethod
def case(cls, matcher: Matcher) -> Iterable[AsyncObservable[TSource]]:
"""Helper to cast the match result to correct type."""

return matcher.case(cls)

def __match__(self, pattern: Any) -> Iterable[AsyncObservable[TSource]]:
if isinstance(self, pattern):
return [self.inner_observable]
origin: Any = get_origin(pattern)
try:
if isinstance(self, origin or pattern):
return [self.inner_observable]
except TypeError:
pass
return []


@dataclass
class InnerCompletedMsg(Msg):
class InnerCompletedMsg(Msg[TSource]):
"""Message notifying that the inner observable completed."""

key: Key

@classmethod
def case(cls, matcher: Matcher) -> Iterable[Key]:
"""Helper to cast the match result to correct type."""

return matcher.case(cls)

def __match__(self, pattern: Any) -> Iterable[Key]:
if isinstance(self, pattern):
return [self.key]
origin: Any = get_origin(pattern)
try:
if isinstance(self, origin or pattern):
return [self.key]
except TypeError:
pass
return []


class CompletedMsg(Msg):
class CompletedMsg_(Msg[Any]):
"""Message notifying that the observable sequence completed."""

@classmethod
def case(cls, matcher: Matcher) -> Iterable[bool]:
"""Helper to cast the match result to correct type."""
def __match__(self, pattern: Any) -> Iterable[bool]:
if self is pattern:
return [True]

return matcher.case(cls)
origin: Any = get_origin(pattern)
try:
if isinstance(self, origin or pattern):
return [True]
except TypeError:
pass

return []


CompletedMsg_ = CompletedMsg() # Singleton
CompletedMsg = CompletedMsg_() # Singleton


class DisposeMsg(Msg):
class DisposeMsg_(Msg[None]):
"""Message notifying that the operator got disposed."""

pass
def __match__(self, pattern: Any) -> Iterable[bool]:

if self is pattern:
return [True]

origin: Any = get_origin(pattern)
try:
if isinstance(self, origin or pattern):
return [True]
except TypeError:
pass

return []


DisposeMsg_ = DisposeMsg() # Singleton
DisposeMsg = DisposeMsg_() # Singleton


__all__ = ["Msg", "DisposeMsg", "CompletedMsg", "InnerCompletedMsg", "InnerObservableMsg", "DisposableMsg"]

0 comments on commit 6cd5b39

Please sign in to comment.