Skip to content

Commit

Permalink
Ensure Client.close() has finished in __aexit__
Browse files Browse the repository at this point in the history
This wraps the closing behavior in a task. Subsequent callers of
.close() now await that same close finishing rather than short
circuiting. This prevents a user-called close outside of __aexit__ from
not finishing before no longer having a running event loop.
  • Loading branch information
mikeshardmind committed Mar 27, 2024
1 parent ab287e7 commit fce0fc3
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 23 deletions.
33 changes: 19 additions & 14 deletions discord/client.py
Expand Up @@ -284,7 +284,7 @@ def __init__(self, *, intents: Intents, **options: Any) -> None:
self._enable_debug_events: bool = options.pop('enable_debug_events', False)
self._connection: ConnectionState[Self] = self._get_state(intents=intents, **options)
self._connection.shard_count = self.shard_count
self._closed: bool = False
self._closing_task: Optional[asyncio.Task[None]] = None
self._ready: asyncio.Event = MISSING
self._application: Optional[AppInfo] = None
self._connection._get_websocket = self._get_websocket
Expand All @@ -304,7 +304,10 @@ async def __aexit__(
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
if not self.is_closed():
# This avoids double-calling a user-provided .close()
if self._closing_task:
await self._closing_task
else:
await self.close()

# internals
Expand Down Expand Up @@ -724,22 +727,24 @@ async def close(self) -> None:
Closes the connection to Discord.
"""
if self._closed:
return
if self._closing_task:
return await self._closing_task

self._closed = True
async def _close():
await self._connection.close()

await self._connection.close()
if self.ws is not None and self.ws.open:
await self.ws.close(code=1000)

if self.ws is not None and self.ws.open:
await self.ws.close(code=1000)
await self.http.close()

await self.http.close()
if self._ready is not MISSING:
self._ready.clear()

if self._ready is not MISSING:
self._ready.clear()
self.loop = MISSING

self.loop = MISSING
self._closing_task = asyncio.create_task(_close())
await self._closing_task

def clear(self) -> None:
"""Clears the internal state of the bot.
Expand All @@ -748,7 +753,7 @@ def clear(self) -> None:
and :meth:`is_ready` both return ``False`` along with the bot's internal
cache cleared.
"""
self._closed = False
self._closing_task = None
self._ready.clear()
self._connection.clear()
self.http.clear()
Expand Down Expand Up @@ -868,7 +873,7 @@ async def runner():

def is_closed(self) -> bool:
""":class:`bool`: Indicates if the websocket connection is closed."""
return self._closed
return self._closing_task is not None

@property
def activity(self) -> Optional[ActivityTypes]:
Expand Down
21 changes: 12 additions & 9 deletions discord/shard.py
Expand Up @@ -470,18 +470,21 @@ async def close(self) -> None:
Closes the connection to Discord.
"""
if self.is_closed():
return
if self._closing_task:
return await self._closing_task

async def _close():
await self._connection.close()

self._closed = True
await self._connection.close()
to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.__shards.values()]
if to_close:
await asyncio.wait(to_close)

to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.__shards.values()]
if to_close:
await asyncio.wait(to_close)
await self.http.close()
self.__queue.put_nowait(EventItem(EventType.clean_close, None, None))

await self.http.close()
self.__queue.put_nowait(EventItem(EventType.clean_close, None, None))
self._closing_task = asyncio.create_task(_close())
await self._closing_task

async def change_presence(
self,
Expand Down

0 comments on commit fce0fc3

Please sign in to comment.