Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flags positional (consume rest) #9805

Merged
merged 14 commits into from May 5, 2024
37 changes: 37 additions & 0 deletions discord/ext/commands/flags.py
Expand Up @@ -79,6 +79,10 @@ class Flag:
description: :class:`str`
The description of the flag. Shown for hybrid commands when they're
used as application commands.
positional: :class:`bool`
Whether the flag is positional or not. There can only be one positional flag.
.. versionadded:: 2.4
"""

name: str = MISSING
Expand All @@ -89,6 +93,7 @@ class Flag:
max_args: int = MISSING
override: bool = MISSING
description: str = MISSING
positional: bool = MISSING
cast_to_dict: bool = False

@property
Expand All @@ -109,6 +114,7 @@ def flag(
override: bool = MISSING,
converter: Any = MISSING,
description: str = MISSING,
positional: bool = MISSING,
) -> Any:
"""Override default functionality and parameters of the underlying :class:`FlagConverter`
class attributes.
Expand Down Expand Up @@ -136,6 +142,10 @@ class attributes.
description: :class:`str`
The description of the flag. Shown for hybrid commands when they're
used as application commands.
positional: :class:`bool`
Whether the flag is positional or not. There can only be one positional flag.
.. versionadded:: 2.4
"""
return Flag(
name=name,
Expand All @@ -145,6 +155,7 @@ class attributes.
override=override,
annotation=converter,
description=description,
positional=positional,
)


Expand All @@ -171,6 +182,7 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
flags: Dict[str, Flag] = {}
cache: Dict[str, Any] = {}
names: Set[str] = set()
positional: Optional[Flag] = None
for name, annotation in annotations.items():
flag = namespace.pop(name, MISSING)
if isinstance(flag, Flag):
Expand All @@ -183,6 +195,11 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
if flag.name is MISSING:
flag.name = name

if flag.positional:
if positional is not None:
raise TypeError(f"{flag.name!r} positional flag conflicts with {positional.name!r} flag.")
positional = flag

annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache)

if flag.default is MISSING and hasattr(annotation, '__commands_is_flag__') and annotation._can_be_constructible():
Expand Down Expand Up @@ -270,6 +287,7 @@ class FlagsMeta(type):
__commands_flag_case_insensitive__: bool
__commands_flag_delimiter__: str
__commands_flag_prefix__: str
__commands_flag_positional__: Optional[Flag]

def __new__(
cls,
Expand Down Expand Up @@ -324,9 +342,13 @@ def __new__(
delimiter = attrs.setdefault('__commands_flag_delimiter__', ':')
prefix = attrs.setdefault('__commands_flag_prefix__', '')

positional: Optional[Flag] = None
for flag_name, flag in get_flags(attrs, global_ns, local_ns).items():
flags[flag_name] = flag
aliases.update({alias_name: flag_name for alias_name in flag.aliases})
if flag.positional:
positional = flag
attrs['__commands_flag_positional__'] = positional

forbidden = set(delimiter).union(prefix)
for flag_name in flags:
Expand Down Expand Up @@ -500,10 +522,25 @@ def parse_flags(cls, argument: str, *, ignore_extra: bool = True) -> Dict[str, L
result: Dict[str, List[str]] = {}
flags = cls.__commands_flags__
aliases = cls.__commands_flag_aliases__
positional_flag = cls.__commands_flag_positional__
last_position = 0
last_flag: Optional[Flag] = None

case_insensitive = cls.__commands_flag_case_insensitive__

if positional_flag is not None:
match = cls.__commands_flag_regex__.search(argument)
if match is not None:
begin, end = match.span(0)
value = argument[:begin].strip()
else:
value = argument.strip()
last_position = len(argument)

if value:
name = positional_flag.name.casefold() if case_insensitive else positional_flag.name
result[name] = [value]

for match in cls.__commands_flag_regex__.finditer(argument):
begin, end = match.span(0)
key = match.group('flag')
Expand Down