Skip to content

Commit

Permalink
Add support for Polls
Browse files Browse the repository at this point in the history
Co-authored-by: owocado <24418520+owocado@users.noreply.github.com>
Co-authored-by: Josh <8677174+bijij@users.noreply.github.com>
Co-authored-by: Trevor Flahardy <75498301+trevorflahardy@users.noreply.github.com>
  • Loading branch information
4 people committed May 10, 2024
1 parent a1206df commit e43bd86
Show file tree
Hide file tree
Showing 19 changed files with 1,097 additions and 1 deletion.
1 change: 1 addition & 0 deletions discord/__init__.py
Expand Up @@ -69,6 +69,7 @@
from .components import *
from .threads import *
from .automod import *
from .poll import *


class VersionInfo(NamedTuple):
Expand Down
14 changes: 14 additions & 0 deletions discord/abc.py
Expand Up @@ -92,6 +92,7 @@
VoiceChannel,
StageChannel,
)
from .poll import Poll
from .threads import Thread
from .ui.view import View
from .types.channel import (
Expand Down Expand Up @@ -1350,6 +1351,7 @@ async def send(
view: View = ...,
suppress_embeds: bool = ...,
silent: bool = ...,
poll: Poll = ...,
) -> Message:
...

Expand All @@ -1370,6 +1372,7 @@ async def send(
view: View = ...,
suppress_embeds: bool = ...,
silent: bool = ...,
poll: Poll = ...,
) -> Message:
...

Expand All @@ -1390,6 +1393,7 @@ async def send(
view: View = ...,
suppress_embeds: bool = ...,
silent: bool = ...,
poll: Poll = ...,
) -> Message:
...

Expand All @@ -1410,6 +1414,7 @@ async def send(
view: View = ...,
suppress_embeds: bool = ...,
silent: bool = ...,
poll: Poll = ...,
) -> Message:
...

Expand All @@ -1431,6 +1436,7 @@ async def send(
view: Optional[View] = None,
suppress_embeds: bool = False,
silent: bool = False,
poll: Optional[Poll] = None,
) -> Message:
"""|coro|
Expand Down Expand Up @@ -1516,6 +1522,10 @@ async def send(
in the UI, but will not actually send a notification.
.. versionadded:: 2.2
poll: :class:`~discord.Poll`
The poll to send with this message.
.. versionadded:: 2.4
Raises
--------
Expand Down Expand Up @@ -1582,13 +1592,17 @@ async def send(
stickers=sticker_ids,
view=view,
flags=flags,
poll=poll,
) as params:
data = await state.http.send_message(channel.id, params=params)

ret = state.create_message(channel=channel, data=data)
if view and not view.is_finished():
state.store_view(view, ret.id)

if poll:
poll._update(ret)

if delete_after is not None:
await ret.delete(delay=delete_after)
return ret
Expand Down
26 changes: 26 additions & 0 deletions discord/client.py
Expand Up @@ -107,6 +107,7 @@
RawThreadMembersUpdate,
RawThreadUpdateEvent,
RawTypingEvent,
RawPollVoteActionEvent,
)
from .reaction import Reaction
from .role import Role
Expand All @@ -116,6 +117,7 @@
from .ui.item import Item
from .voice_client import VoiceProtocol
from .audit_logs import AuditLogEntry
from .poll import PollAnswer


# fmt: off
Expand Down Expand Up @@ -1815,6 +1817,30 @@ async def wait_for(
) -> Tuple[Member, VoiceState, VoiceState]:
...

# Polls

@overload
async def wait_for(
self,
event: Literal['poll_vote_add', 'poll_vote_remove'],
/,
*,
check: Optional[Callable[[Union[User, Member], PollAnswer], bool]] = None,
timeout: Optional[float] = None,
) -> Tuple[Union[User, Member], PollAnswer]:
...

@overload
async def wait_for(
self,
event: Literal['raw_poll_vote_add', 'raw_poll_vote_remove'],
/,
*,
check: Optional[Callable[[RawPollVoteActionEvent], bool]] = None,
timeout: Optional[float] = None,
) -> RawPollVoteActionEvent:
...

# Commands

@overload
Expand Down
5 changes: 5 additions & 0 deletions discord/enums.py
Expand Up @@ -73,6 +73,7 @@
'SKUType',
'EntitlementType',
'EntitlementOwnerType',
'PollLayoutType',
)


Expand Down Expand Up @@ -818,6 +819,10 @@ class EntitlementOwnerType(Enum):
user = 2


class PollLayoutType(Enum):
default = 1


def create_unknown_value(cls: Type[E], val: Any) -> E:
value_cls = cls._enum_value_cls_ # type: ignore # This is narrowed below
name = f'unknown_{val}'
Expand Down
17 changes: 17 additions & 0 deletions discord/ext/commands/context.py
Expand Up @@ -50,6 +50,7 @@
from discord.message import MessageReference, PartialMessage
from discord.ui import View
from discord.types.interactions import ApplicationCommandInteractionData
from discord.poll import Poll

from .cog import Cog
from .core import Command
Expand Down Expand Up @@ -641,6 +642,7 @@ async def reply(
suppress_embeds: bool = ...,
ephemeral: bool = ...,
silent: bool = ...,
poll: Poll = ...,
) -> Message:
...

Expand All @@ -662,6 +664,7 @@ async def reply(
suppress_embeds: bool = ...,
ephemeral: bool = ...,
silent: bool = ...,
poll: Poll = ...,
) -> Message:
...

Expand All @@ -683,6 +686,7 @@ async def reply(
suppress_embeds: bool = ...,
ephemeral: bool = ...,
silent: bool = ...,
poll: Poll = ...,
) -> Message:
...

Expand All @@ -704,6 +708,7 @@ async def reply(
suppress_embeds: bool = ...,
ephemeral: bool = ...,
silent: bool = ...,
poll: Poll = ...,
) -> Message:
...

Expand Down Expand Up @@ -826,6 +831,7 @@ async def send(
suppress_embeds: bool = ...,
ephemeral: bool = ...,
silent: bool = ...,
poll: Poll = ...,
) -> Message:
...

Expand All @@ -847,6 +853,7 @@ async def send(
suppress_embeds: bool = ...,
ephemeral: bool = ...,
silent: bool = ...,
poll: Poll = ...,
) -> Message:
...

Expand All @@ -868,6 +875,7 @@ async def send(
suppress_embeds: bool = ...,
ephemeral: bool = ...,
silent: bool = ...,
poll: Poll = ...,
) -> Message:
...

Expand All @@ -889,6 +897,7 @@ async def send(
suppress_embeds: bool = ...,
ephemeral: bool = ...,
silent: bool = ...,
poll: Poll = ...,
) -> Message:
...

Expand All @@ -911,6 +920,7 @@ async def send(
suppress_embeds: bool = False,
ephemeral: bool = False,
silent: bool = False,
poll: Poll = MISSING,
) -> Message:
"""|coro|
Expand Down Expand Up @@ -1000,6 +1010,11 @@ async def send(
.. versionadded:: 2.2
poll: :class:`~discord.Poll`
The poll to send with this message.
.. versionadded:: 2.4
Raises
--------
~discord.HTTPException
Expand Down Expand Up @@ -1037,6 +1052,7 @@ async def send(
view=view,
suppress_embeds=suppress_embeds,
silent=silent,
poll=poll,
) # type: ignore # The overloads don't support Optional but the implementation does

# Convert the kwargs from None to MISSING to appease the remaining implementations
Expand All @@ -1052,6 +1068,7 @@ async def send(
'suppress_embeds': suppress_embeds,
'ephemeral': ephemeral,
'silent': silent,
'poll': poll,
}

if self.interaction.response.is_done():
Expand Down
51 changes: 51 additions & 0 deletions discord/flags.py
Expand Up @@ -1257,6 +1257,57 @@ def auto_moderation_execution(self):
"""
return 1 << 21

@alias_flag_value
def polls(self):
""":class:`bool`: Whether guild and direct messages poll related events are enabled.
This is a shortcut to set or get both :attr:`guild_polls` and :attr:`dm_polls`.
This corresponds to the following events:
- :func:`on_poll_vote_add` (both guilds and DMs)
- :func:`on_poll_vote_remove` (both guilds and DMs)
- :func:`on_raw_poll_vote_add` (both guilds and DMs)
- :func:`on_raw_poll_vote_remove` (both guilds and DMs)
.. versionadded:: 2.4
"""
return (1 << 24) | (1 << 25)

@flag_value
def guild_polls(self):
""":class:`bool`: Whether guild poll related events are enabled.
See also :attr:`dm_polls` and :attr:`polls`.
This corresponds to the following events:
- :func:`on_poll_vote_add` (only for guilds)
- :func:`on_poll_vote_remove` (only for guilds)
- :func:`on_raw_poll_vote_add` (only for guilds)
- :func:`on_raw_poll_vote_remove` (only for guilds)
.. versionadded:: 2.4
"""
return 1 << 24

@flag_value
def dm_polls(self):
""":class:`bool`: Whether direct messages poll related events are enabled.
See also :attr:`guild_polls` and :attr:`polls`.
This corresponds to the following events:
- :func:`on_poll_vote_add` (only for DMs)
- :func:`on_poll_vote_remove` (only for DMs)
- :func:`on_raw_poll_vote_add` (only for DMs)
- :func:`on_raw_poll_vote_remove` (only for DMs)
.. versionadded:: 2.4
"""
return 1 << 25


@fill_with_flags()
class MemberCacheFlags(BaseFlags):
Expand Down
43 changes: 43 additions & 0 deletions discord/http.py
Expand Up @@ -68,6 +68,7 @@
from .embeds import Embed
from .message import Attachment
from .flags import MessageFlags
from .poll import Poll

from .types import (
appinfo,
Expand All @@ -91,6 +92,7 @@
sticker,
welcome_screen,
sku,
poll,
)
from .types.snowflake import Snowflake, SnowflakeList

Expand Down Expand Up @@ -154,6 +156,7 @@ def handle_message_parameters(
thread_name: str = MISSING,
channel_payload: Dict[str, Any] = MISSING,
applied_tags: Optional[SnowflakeList] = MISSING,
poll: Optional[Poll] = MISSING,
) -> MultipartParameters:
if files is not MISSING and file is not MISSING:
raise TypeError('Cannot mix file and files keyword arguments.')
Expand Down Expand Up @@ -256,6 +259,9 @@ def handle_message_parameters(
}
payload.update(channel_payload)

if poll not in (MISSING, None):
payload['poll'] = poll._to_dict() # type: ignore

multipart = []
if files:
multipart.append({'name': 'payload_json', 'value': utils._to_json(payload)})
Expand Down Expand Up @@ -2513,6 +2519,43 @@ def edit_application_info(self, *, reason: Optional[str], payload: Any) -> Respo
payload = {k: v for k, v in payload.items() if k in valid_keys}
return self.request(Route('PATCH', '/applications/@me'), json=payload, reason=reason)

def get_poll_answer_voters(
self,
channel_id: Snowflake,
message_id: Snowflake,
answer_id: Snowflake,
after: Optional[Snowflake] = None,
limit: Optional[int] = None,
) -> Response[poll.PollAnswerVoters]:
params = {}

if after:
params['after'] = int(after)

if limit is not None:
params['limit'] = limit

return self.request(
Route(
'GET',
'/channels/{channel_id}/polls/{message_id}/answers/{answer_id}',
channel_id=channel_id,
message_id=message_id,
answer_id=answer_id,
),
params=params,
)

def end_poll(self, channel_id: Snowflake, message_id: Snowflake) -> Response[message.Message]:
return self.request(
Route(
'POST',
'/channels/{channel_id}/polls/{message_id}/expire',
channel_id=channel_id,
message_id=message_id,
)
)

async def get_gateway(self, *, encoding: str = 'json', zlib: bool = True) -> str:
try:
data = await self.request(Route('GET', '/gateway'))
Expand Down

0 comments on commit e43bd86

Please sign in to comment.