Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure Client.close() has finished in __aexit__ #9769

Merged
merged 1 commit into from May 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
33 changes: 19 additions & 14 deletions discord/client.py
Expand Up @@ -284,7 +284,7 @@ def __init__(self, *, intents: Intents, **options: Any) -> None:
self._enable_debug_events: bool = options.pop('enable_debug_events', False)
self._connection: ConnectionState[Self] = self._get_state(intents=intents, **options)
self._connection.shard_count = self.shard_count
self._closed: bool = False
self._closing_task: Optional[asyncio.Task[None]] = None
self._ready: asyncio.Event = MISSING
self._application: Optional[AppInfo] = None
self._connection._get_websocket = self._get_websocket
Expand All @@ -304,7 +304,10 @@ async def __aexit__(
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
if not self.is_closed():
# This avoids double-calling a user-provided .close()
if self._closing_task:
await self._closing_task
else:
await self.close()

# internals
Expand Down Expand Up @@ -724,22 +727,24 @@ async def close(self) -> None:

Closes the connection to Discord.
"""
if self._closed:
return
if self._closing_task:
return await self._closing_task

self._closed = True
async def _close():
await self._connection.close()

await self._connection.close()
if self.ws is not None and self.ws.open:
await self.ws.close(code=1000)

if self.ws is not None and self.ws.open:
await self.ws.close(code=1000)
await self.http.close()

await self.http.close()
if self._ready is not MISSING:
self._ready.clear()

if self._ready is not MISSING:
self._ready.clear()
self.loop = MISSING

self.loop = MISSING
self._closing_task = asyncio.create_task(_close())
await self._closing_task

def clear(self) -> None:
"""Clears the internal state of the bot.
Expand All @@ -748,7 +753,7 @@ def clear(self) -> None:
and :meth:`is_ready` both return ``False`` along with the bot's internal
cache cleared.
"""
self._closed = False
self._closing_task = None
self._ready.clear()
self._connection.clear()
self.http.clear()
Expand Down Expand Up @@ -868,7 +873,7 @@ async def runner():

def is_closed(self) -> bool:
""":class:`bool`: Indicates if the websocket connection is closed."""
return self._closed
return self._closing_task is not None

@property
def activity(self) -> Optional[ActivityTypes]:
Expand Down
21 changes: 12 additions & 9 deletions discord/shard.py
Expand Up @@ -470,18 +470,21 @@ async def close(self) -> None:
Closes the connection to Discord.
"""
if self.is_closed():
return
if self._closing_task:
return await self._closing_task

async def _close():
await self._connection.close()

self._closed = True
await self._connection.close()
to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.__shards.values()]
if to_close:
await asyncio.wait(to_close)

to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.__shards.values()]
if to_close:
await asyncio.wait(to_close)
await self.http.close()
self.__queue.put_nowait(EventItem(EventType.clean_close, None, None))

await self.http.close()
self.__queue.put_nowait(EventItem(EventType.clean_close, None, None))
self._closing_task = asyncio.create_task(_close())
await self._closing_task

async def change_presence(
self,
Expand Down