Skip to content

Commit

Permalink
Ensure Client.close() has finished in __aexit__
Browse files Browse the repository at this point in the history
This wraps the closing behavior in a task. Subsequent callers of
.close() now await that same close finishing rather than short
circuiting. This prevents a user-called close outside of __aexit__ from
not finishing before no longer having a running event loop.
  • Loading branch information
mikeshardmind committed May 5, 2024
1 parent 8fd1fd8 commit 88f62d8
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 23 deletions.
33 changes: 19 additions & 14 deletions discord/client.py
Expand Up @@ -287,7 +287,7 @@ def __init__(self, *, intents: Intents, **options: Any) -> None:
self._enable_debug_events: bool = options.pop('enable_debug_events', False)
self._connection: ConnectionState[Self] = self._get_state(intents=intents, **options)
self._connection.shard_count = self.shard_count
self._closed: bool = False
self._closing_task: Optional[asyncio.Task[None]] = None
self._ready: asyncio.Event = MISSING
self._application: Optional[AppInfo] = None
self._connection._get_websocket = self._get_websocket
Expand All @@ -307,7 +307,10 @@ async def __aexit__(
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
if not self.is_closed():
# This avoids double-calling a user-provided .close()
if self._closing_task:
await self._closing_task
else:
await self.close()

# internals
Expand Down Expand Up @@ -726,22 +729,24 @@ async def close(self) -> None:
Closes the connection to Discord.
"""
if self._closed:
return
if self._closing_task:
return await self._closing_task

self._closed = True
async def _close():
await self._connection.close()

await self._connection.close()
if self.ws is not None and self.ws.open:
await self.ws.close(code=1000)

if self.ws is not None and self.ws.open:
await self.ws.close(code=1000)
await self.http.close()

await self.http.close()
if self._ready is not MISSING:
self._ready.clear()

if self._ready is not MISSING:
self._ready.clear()
self.loop = MISSING

self.loop = MISSING
self._closing_task = asyncio.create_task(_close())
await self._closing_task

def clear(self) -> None:
"""Clears the internal state of the bot.
Expand All @@ -750,7 +755,7 @@ def clear(self) -> None:
and :meth:`is_ready` both return ``False`` along with the bot's internal
cache cleared.
"""
self._closed = False
self._closing_task = None
self._ready.clear()
self._connection.clear()
self.http.clear()
Expand Down Expand Up @@ -870,7 +875,7 @@ async def runner():

def is_closed(self) -> bool:
""":class:`bool`: Indicates if the websocket connection is closed."""
return self._closed
return self._closing_task is not None

@property
def activity(self) -> Optional[ActivityTypes]:
Expand Down
21 changes: 12 additions & 9 deletions discord/shard.py
Expand Up @@ -481,18 +481,21 @@ async def close(self) -> None:
Closes the connection to Discord.
"""
if self.is_closed():
return
if self._closing_task:
return await self._closing_task

async def _close():
await self._connection.close()

self._closed = True
await self._connection.close()
to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.__shards.values()]
if to_close:
await asyncio.wait(to_close)

to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.__shards.values()]
if to_close:
await asyncio.wait(to_close)
await self.http.close()
self.__queue.put_nowait(EventItem(EventType.clean_close, None, None))

await self.http.close()
self.__queue.put_nowait(EventItem(EventType.clean_close, None, None))
self._closing_task = asyncio.create_task(_close())
await self._closing_task

async def change_presence(
self,
Expand Down

0 comments on commit 88f62d8

Please sign in to comment.