Skip to content

Commit

Permalink
Allow setting AppCommandContext and AppInstallationType on the tree
Browse files Browse the repository at this point in the history
  • Loading branch information
Rapptz committed Mar 28, 2024
1 parent 31c74bd commit ff07ad0
Show file tree
Hide file tree
Showing 6 changed files with 290 additions and 41 deletions.
1 change: 1 addition & 0 deletions discord/app_commands/__init__.py
Expand Up @@ -16,5 +16,6 @@
from .namespace import *
from .transformers import *
from .translator import *
from .installs import *
from . import checks as checks
from .checks import Cooldown as Cooldown
66 changes: 34 additions & 32 deletions discord/app_commands/commands.py
Expand Up @@ -49,7 +49,7 @@
from copy import copy as shallow_copy

from ..enums import AppCommandOptionType, AppCommandType, ChannelType, Locale
from ..flags import AppCommandContext, AppInstallationType
from .installs import AppCommandContext, AppInstallationType
from .models import Choice
from .transformers import annotation_to_parameter, CommandParameter, NoneType
from .errors import AppCommandError, CheckFailure, CommandInvokeError, CommandSignatureMismatch, CommandAlreadyRegistered
Expand All @@ -66,6 +66,8 @@
from ..abc import Snowflake
from .namespace import Namespace
from .models import ChoiceT
from .tree import CommandTree
from .._types import ClientT

# Generally, these two libraries are supposed to be separate from each other.
# However, for type hinting purposes it's unfortunately necessary for one to
Expand Down Expand Up @@ -744,8 +746,8 @@ def _copy_with(

return copy

async def get_translated_payload(self, translator: Translator) -> Dict[str, Any]:
base = self.to_dict()
async def get_translated_payload(self, tree: CommandTree[ClientT], translator: Translator) -> Dict[str, Any]:
base = self.to_dict(tree)
name_localizations: Dict[str, str] = {}
description_localizations: Dict[str, str] = {}

Expand All @@ -771,7 +773,7 @@ async def get_translated_payload(self, translator: Translator) -> Dict[str, Any]
]
return base

def to_dict(self) -> Dict[str, Any]:
def to_dict(self, tree: CommandTree[ClientT]) -> Dict[str, Any]:
# If we have a parent then our type is a subcommand
# Otherwise, the type falls back to the specific command type (e.g. slash command or context menu)
option_type = AppCommandType.chat_input.value if self.parent is None else AppCommandOptionType.subcommand.value
Expand All @@ -786,8 +788,8 @@ def to_dict(self) -> Dict[str, Any]:
base['nsfw'] = self.nsfw
base['dm_permission'] = not self.guild_only
base['default_member_permissions'] = None if self.default_permissions is None else self.default_permissions.value
base['contexts'] = self.allowed_contexts.to_array() if self.allowed_contexts is not None else None
base['integration_types'] = self.allowed_installs.to_array() if self.allowed_installs is not None else None
base['contexts'] = tree.allowed_contexts._merge_to_array(self.allowed_contexts)
base['integration_types'] = tree.allowed_installs._merge_to_array(self.allowed_installs)

return base

Expand Down Expand Up @@ -1277,8 +1279,8 @@ def qualified_name(self) -> str:
""":class:`str`: Returns the fully qualified command name."""
return self.name

async def get_translated_payload(self, translator: Translator) -> Dict[str, Any]:
base = self.to_dict()
async def get_translated_payload(self, tree: CommandTree[ClientT], translator: Translator) -> Dict[str, Any]:
base = self.to_dict(tree)
context = TranslationContext(location=TranslationContextLocation.command_name, data=self)
if self._locale_name:
name_localizations: Dict[str, str] = {}
Expand All @@ -1290,13 +1292,13 @@ async def get_translated_payload(self, translator: Translator) -> Dict[str, Any]
base['name_localizations'] = name_localizations
return base

def to_dict(self) -> Dict[str, Any]:
def to_dict(self, tree: CommandTree[ClientT]) -> Dict[str, Any]:
return {
'name': self.name,
'type': self.type.value,
'dm_permission': not self.guild_only,
'contexts': self.allowed_contexts.to_array() if self.allowed_contexts is not None else None,
'integration_types': self.allowed_installs.to_array() if self.allowed_installs is not None else None,
'contexts': tree.allowed_contexts._merge_to_array(self.allowed_contexts),
'integration_types': tree.allowed_installs._merge_to_array(self.allowed_installs),
'default_member_permissions': None if self.default_permissions is None else self.default_permissions.value,
'nsfw': self.nsfw,
}
Expand Down Expand Up @@ -1711,8 +1713,8 @@ def _copy_with(

return copy

async def get_translated_payload(self, translator: Translator) -> Dict[str, Any]:
base = self.to_dict()
async def get_translated_payload(self, tree: CommandTree[ClientT], translator: Translator) -> Dict[str, Any]:
base = self.to_dict(tree)
name_localizations: Dict[str, str] = {}
description_localizations: Dict[str, str] = {}

Expand All @@ -1732,26 +1734,26 @@ async def get_translated_payload(self, translator: Translator) -> Dict[str, Any]

base['name_localizations'] = name_localizations
base['description_localizations'] = description_localizations
base['options'] = [await child.get_translated_payload(translator) for child in self._children.values()]
base['options'] = [await child.get_translated_payload(tree, translator) for child in self._children.values()]
return base

def to_dict(self) -> Dict[str, Any]:
def to_dict(self, tree: CommandTree[ClientT]) -> Dict[str, Any]:
# If this has a parent command then it's part of a subcommand group
# Otherwise, it's just a regular command
option_type = 1 if self.parent is None else AppCommandOptionType.subcommand_group.value
base: Dict[str, Any] = {
'name': self.name,
'description': self.description,
'type': option_type,
'options': [child.to_dict() for child in self._children.values()],
'options': [child.to_dict(tree) for child in self._children.values()],
}

if self.parent is None:
base['nsfw'] = self.nsfw
base['dm_permission'] = not self.guild_only
base['default_member_permissions'] = None if self.default_permissions is None else self.default_permissions.value
base['contexts'] = self.allowed_contexts.to_array() if self.allowed_contexts is not None else None
base['integration_types'] = self.allowed_installs.to_array() if self.allowed_installs is not None else None
base['contexts'] = tree.allowed_contexts._merge_to_array(self.allowed_contexts)
base['integration_types'] = tree.allowed_installs._merge_to_array(self.allowed_installs)

return base

Expand Down Expand Up @@ -2501,12 +2503,12 @@ async def my_guild_only_command(interaction: discord.Interaction) -> None:
def inner(f: T) -> T:
if isinstance(f, (Command, Group, ContextMenu)):
f.guild_only = True
allowed_contexts = f.allowed_contexts or AppCommandContext.none()
allowed_contexts = f.allowed_contexts or AppCommandContext()
f.allowed_contexts = allowed_contexts
else:
f.__discord_app_commands_guild_only__ = True # type: ignore # Runtime attribute assignment

allowed_contexts = getattr(f, '__discord_app_commands_contexts__', None) or AppCommandContext.none()
allowed_contexts = getattr(f, '__discord_app_commands_contexts__', None) or AppCommandContext()
f.__discord_app_commands_contexts__ = allowed_contexts # type: ignore # Runtime attribute assignment

allowed_contexts.guild = True
Expand Down Expand Up @@ -2545,10 +2547,10 @@ async def my_private_channel_only_command(interaction: discord.Interaction) -> N
def inner(f: T) -> T:
if isinstance(f, (Command, Group, ContextMenu)):
f.guild_only = False
allowed_contexts = f.allowed_contexts or AppCommandContext.none()
allowed_contexts = f.allowed_contexts or AppCommandContext()
f.allowed_contexts = allowed_contexts
else:
allowed_contexts = getattr(f, '__discord_app_commands_contexts__', None) or AppCommandContext.none()
allowed_contexts = getattr(f, '__discord_app_commands_contexts__', None) or AppCommandContext()
f.__discord_app_commands_contexts__ = allowed_contexts # type: ignore # Runtime attribute assignment

allowed_contexts.private_channel = True
Expand Down Expand Up @@ -2587,10 +2589,10 @@ async def my_dm_only_command(interaction: discord.Interaction) -> None:
def inner(f: T) -> T:
if isinstance(f, (Command, Group, ContextMenu)):
f.guild_only = False
allowed_contexts = f.allowed_contexts or AppCommandContext.none()
allowed_contexts = f.allowed_contexts or AppCommandContext()
f.allowed_contexts = allowed_contexts
else:
allowed_contexts = getattr(f, '__discord_app_commands_contexts__', None) or AppCommandContext.none()
allowed_contexts = getattr(f, '__discord_app_commands_contexts__', None) or AppCommandContext()
f.__discord_app_commands_contexts__ = allowed_contexts # type: ignore # Runtime attribute assignment

allowed_contexts.dm_channel = True
Expand Down Expand Up @@ -2628,10 +2630,10 @@ async def my_command(interaction: discord.Interaction) -> None:
def inner(f: T) -> T:
if isinstance(f, (Command, Group, ContextMenu)):
f.guild_only = False
allowed_contexts = f.allowed_contexts or AppCommandContext.none()
allowed_contexts = f.allowed_contexts or AppCommandContext()
f.allowed_contexts = allowed_contexts
else:
allowed_contexts = getattr(f, '__discord_app_commands_contexts__', None) or AppCommandContext.none()
allowed_contexts = getattr(f, '__discord_app_commands_contexts__', None) or AppCommandContext()
f.__discord_app_commands_contexts__ = allowed_contexts # type: ignore # Runtime attribute assignment

if guilds is not MISSING:
Expand Down Expand Up @@ -2668,10 +2670,10 @@ async def my_guild_install_command(interaction: discord.Interaction) -> None:

def inner(f: T) -> T:
if isinstance(f, (Command, Group, ContextMenu)):
allowed_installs = f.allowed_installs or AppInstallationType.none()
allowed_installs = f.allowed_installs or AppInstallationType()
f.allowed_installs = allowed_installs
else:
allowed_installs = getattr(f, '__discord_app_commands_installation_types__', None) or AppInstallationType.none()
allowed_installs = getattr(f, '__discord_app_commands_installation_types__', None) or AppInstallationType()
f.__discord_app_commands_installation_types__ = allowed_installs # type: ignore # Runtime attribute assignment

allowed_installs.guild = True
Expand Down Expand Up @@ -2706,10 +2708,10 @@ async def my_user_install_command(interaction: discord.Interaction) -> None:

def inner(f: T) -> T:
if isinstance(f, (Command, Group, ContextMenu)):
allowed_installs = f.allowed_installs or AppInstallationType.none()
allowed_installs = f.allowed_installs or AppInstallationType()
f.allowed_installs = allowed_installs
else:
allowed_installs = getattr(f, '__discord_app_commands_installation_types__', None) or AppInstallationType.none()
allowed_installs = getattr(f, '__discord_app_commands_installation_types__', None) or AppInstallationType()
f.__discord_app_commands_installation_types__ = allowed_installs # type: ignore # Runtime attribute assignment

allowed_installs.user = True
Expand Down Expand Up @@ -2748,10 +2750,10 @@ async def my_command(interaction: discord.Interaction) -> None:

def inner(f: T) -> T:
if isinstance(f, (Command, Group, ContextMenu)):
allowed_installs = f.allowed_installs or AppInstallationType.none()
allowed_installs = f.allowed_installs or AppInstallationType()
f.allowed_installs = allowed_installs
else:
allowed_installs = getattr(f, '__discord_app_commands_installation_types__', None) or AppInstallationType.none()
allowed_installs = getattr(f, '__discord_app_commands_installation_types__', None) or AppInstallationType()
f.__discord_app_commands_installation_types__ = allowed_installs # type: ignore # Runtime attribute assignment

if guilds is not MISSING:
Expand Down

0 comments on commit ff07ad0

Please sign in to comment.