Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

patch bug where Client.close() doesn't complete fully when called by user-code #9768

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 17 additions & 17 deletions discord/client.py
Expand Up @@ -287,7 +287,7 @@ def __init__(self, *, intents: Intents, **options: Any) -> None:
self._enable_debug_events: bool = options.pop('enable_debug_events', False)
self._connection: ConnectionState[Self] = self._get_state(intents=intents, **options)
self._connection.shard_count = self.shard_count
self._closed: bool = False
self._closure: Optional[asyncio.Task[None]] = None
self._ready: asyncio.Event = MISSING
self._application: Optional[AppInfo] = None
self._connection._get_websocket = self._get_websocket
Expand All @@ -307,8 +307,7 @@ async def __aexit__(
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
if not self.is_closed():
await self.close()
await self.close()

# internals

Expand Down Expand Up @@ -653,7 +652,7 @@ async def connect(self, *, reconnect: bool = True) -> None:
'initial': True,
'shard_id': self.shard_id,
}
while not self.is_closed():
while not self._closure:
try:
coro = DiscordWebSocket.from_client(self, **ws_params)
self.ws = await asyncio.wait_for(coro, timeout=60.0)
Expand Down Expand Up @@ -683,7 +682,7 @@ async def connect(self, *, reconnect: bool = True) -> None:
return
raise

if self.is_closed():
if self._closure is not None:
return

# If we get connection reset by peer then try to RESUME
Expand Down Expand Up @@ -721,16 +720,7 @@ async def connect(self, *, reconnect: bool = True) -> None:
session=self.ws.session_id,
)

async def close(self) -> None:
"""|coro|

Closes the connection to Discord.
"""
if self._closed:
return

self._closed = True

async def _close(self) -> None:
await self._connection.close()

if self.ws is not None and self.ws.open:
Expand All @@ -743,14 +733,24 @@ async def close(self) -> None:

self.loop = MISSING

async def close(self) -> None:
"""|coro|

Closes the connection to Discord.
"""
if self._closure is None:
self._closure = self.loop.create_task(self._close())

await self._closure

def clear(self) -> None:
"""Clears the internal state of the bot.

After this, the bot can be considered "re-opened", i.e. :meth:`is_closed`
and :meth:`is_ready` both return ``False`` along with the bot's internal
cache cleared.
"""
self._closed = False
self._closure = None
self._ready.clear()
self._connection.clear()
self.http.clear()
Expand Down Expand Up @@ -870,7 +870,7 @@ async def runner():

def is_closed(self) -> bool:
""":class:`bool`: Indicates if the websocket connection is closed."""
return self._closed
return self._closure is not None and self._closure.done()

@property
def activity(self) -> Optional[ActivityTypes]:
Expand Down
14 changes: 3 additions & 11 deletions discord/shard.py
Expand Up @@ -429,7 +429,7 @@ async def launch_shard(self, gateway: yarl.URL, shard_id: int, *, initial: bool
ret.launch()

async def launch_shards(self) -> None:
if self.is_closed():
if self._closure is not None:
return

if self.shard_count is None:
Expand All @@ -456,7 +456,7 @@ async def connect(self, *, reconnect: bool = True) -> None:
self._reconnect = reconnect
await self.launch_shards()

while not self.is_closed():
while not self._closure:
item = await self.__queue.get()
if item.type == EventType.close:
await self.close()
Expand All @@ -476,15 +476,7 @@ async def connect(self, *, reconnect: bool = True) -> None:
elif item.type == EventType.clean_close:
return

async def close(self) -> None:
"""|coro|

Closes the connection to Discord.
"""
if self.is_closed():
return

self._closed = True
async def _close(self) -> None:
await self._connection.close()

to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.__shards.values()]
Expand Down