Skip to content

Commit

Permalink
avoid an unclosed auto_aiter
Browse files Browse the repository at this point in the history
  • Loading branch information
graingert committed Apr 7, 2024
1 parent 3fd91e4 commit f0e5558
Showing 1 changed file with 29 additions and 6 deletions.
35 changes: 29 additions & 6 deletions src/jinja2/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from .utils import _PassArg
from .utils import pass_eval_context

if t.TYPE_CHECKING:
import typing_extensions as te

V = t.TypeVar("V")


Expand Down Expand Up @@ -67,15 +70,35 @@ async def auto_await(value: t.Union[t.Awaitable["V"], "V"]) -> "V":
return t.cast("V", value)


async def auto_aiter(
class _IterableToAsyncIterableIterator(t.Generic[V]):
def __init__(self, iterator: "t.Iterator[V]"):
self._iterator = iterator

def __aiter__(self) -> "te.Self":
return self

async def __anext__(self) -> V:
try:
return next(self._iterator)
except StopIteration as e:
raise StopAsyncIteration(e.value)


class _IterableToAsyncIterable(t.Generic[V]):
def __init__(self, iterable: "t.Iterable[V]"):
self._iterable = iterable

def __aiter__(self) -> "_IterableToAsyncIterableIterator[V]"
return _IterableToAsyncIterableIterator(iter(self._iterable))


def auto_aiter(
iterable: "t.Union[t.AsyncIterable[V], t.Iterable[V]]",
) -> "t.AsyncIterator[V]":
) -> "t.AsyncIterable[V]":
if hasattr(iterable, "__aiter__"):
async for item in t.cast("t.AsyncIterable[V]", iterable):
yield item
return iterable
else:
for item in iterable:
yield item
return _IterableToAsyncIterable(iterable)


async def auto_to_list(
Expand Down

0 comments on commit f0e5558

Please sign in to comment.