# Source code for discord_tools.flags

"""
The MIT License (MIT)

Copyright (c) 2024-present Developer Anonymous

Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""

from __future__ import annotations

from dataclasses import dataclass
import re
from typing import TYPE_CHECKING, Any, TypeVar, Union

from discord.utils import maybe_coroutine, MISSING
from discord.ext.commands import (
    FlagConverter,
    Flag as BaseFlag,
    MissingFlagArgument,
    TooManyArguments,
    Context,
    Bot,
    AutoShardedBot,
    MissingRequiredFlag,
    TooManyFlags,
)
from discord.ext.commands.flags import convert_flag

# Bot type parameter for ``Context``; used by the ``Context[BotT]`` annotation
# of ``ImplicitBoolFlagConverter.convert``.
BotT = TypeVar("BotT", bound=Union[Bot, AutoShardedBot], covariant=True)

if TYPE_CHECKING:
    # ``Self`` only appears in annotations (the return type of ``convert``),
    # which ``from __future__ import annotations`` leaves unevaluated, so the
    # import is only needed by type checkers.
    from typing_extensions import Self

# Public names exported by ``from discord_tools.flags import *``.
__all__ = (
    "Flag",
    "flag",
    "ImplicitBoolFlagConverter",
)


@dataclass
class Flag(BaseFlag):
    """A flag parameter of a :class:`FlagConverter` that may also be *implicit*.

    Instances are normally produced by the :func:`flag` helper, although using
    it is not required. These cannot be constructed manually.

    .. versionadded:: 1.0

    Attributes
    ----------
    name: :class:`str`
        The name of the flag.
    aliases: List[:class:`str`]
        Alternative names the flag also answers to.
    attribute: :class:`str`
        The class attribute this flag is bound to.
    default: Any
        The default value of the flag, if one is available.
    annotation: Any
        The evaluated annotation (converter) of the flag.
    max_args: :class:`int`
        The maximum number of arguments the flag accepts. A negative value
        means there is no limit.
    override: :class:`bool`
        Whether a repeated flag replaces the value given earlier.
    description: :class:`str`
        Description of the flag, shown when a hybrid command is used as an
        application command.
    positional: :class:`bool`
        Whether the flag is positional. Only one flag may be positional.
    implicit: :class:`bool`
        Whether the flag is implicit: its mere presence makes the value
        ``True``, and its absence leaves it ``False``.

        .. warning::

            Only usable on subclasses of :class:`ImplicitBoolFlagConverter`.

        .. note::

            Setting this to ``True`` replaces ``default`` with ``False`` and
            ``annotation`` with ``bool``.
    """

    implicit: bool = MISSING

    def __post_init__(self):
        # Only an explicit ``True`` (not merely a truthy value) switches the
        # flag into presence-only boolean mode.
        if self.implicit is not True:
            return
        self.annotation = bool
        self.default = False
def flag(
    *,
    name: str = MISSING,
    aliases: list[str] = MISSING,
    default: Any = MISSING,
    max_args: int = MISSING,
    override: bool = MISSING,
    converter: Any = MISSING,
    description: str = MISSING,
    positional: bool = MISSING,
    implicit: bool = MISSING,
) -> Any:
    """Customise how a :class:`FlagConverter` class attribute is parsed.

    .. versionadded:: 1.0

    Parameters
    ------------
    name: :class:`str`
        The flag name. Defaults to the attribute name when not given.
    aliases: List[:class:`str`]
        Alternative names for the flag. No aliases are set when not given.
    default: Any
        The default value. Either a plain value or a callable taking the
        :class:`Context` as its only parameter. When not given, the value
        assigned to the attribute is used.
    max_args: :class:`int`
        The maximum number of arguments the flag accepts. A negative value
        means there is no limit. By default this depends on the annotation.
    override: :class:`bool`
        Whether a repeated flag replaces the earlier value. By default this
        depends on the annotation.
    converter: Any
        The converter for this flag. It replaces the annotation at runtime,
        transparently to type checkers.
    description: :class:`str`
        Description of the flag, shown when a hybrid command is used as an
        application command.
    positional: :class:`bool`
        Whether the flag is positional. Only one flag may be positional.
    implicit: :class:`bool`
        Whether the flag is implicit: its mere presence makes the value
        ``True``, and its absence leaves it ``False``.

        .. warning::

            Only usable on subclasses of :class:`ImplicitBoolFlagConverter`.

        .. note::

            Setting this to ``True`` replaces ``default`` with ``False`` and
            ``converter`` with ``bool``.

    Returns
    -------
    Any
        A :class:`Flag` when ``implicit`` is passed, otherwise a plain
        :class:`discord.ext.commands.Flag`.
    """
    common = dict(
        name=name,
        aliases=aliases,
        default=default,
        max_args=max_args,
        override=override,
        annotation=converter,
        description=description,
        positional=positional,
    )
    if implicit is not MISSING:
        return Flag(implicit=implicit, **common)
    # No implicit behaviour requested: hand back discord.py's own flag type.
    return BaseFlag(**common)
def _find_flag_matches(
    argument: str,
    flags: dict[str, Flag | BaseFlag],
    aliases: dict[str, str],
    *,
    prefix: str,
    delimiter: str,
    case_insensitive: bool,
) -> list[tuple[re.Match[str], Flag | BaseFlag | None]]:
    """Find every flag occurrence in ``argument``, paired with the flag it names.

    ``flags`` and ``aliases`` must already be casefolded when
    ``case_insensitive`` is true. The delimiter is optional in the search
    pattern so that implicit flags can be written bare; a delimiter-less match
    of any non-implicit flag is treated as ordinary text and dropped.
    """
    # Longest names first, so a name that is a prefix of another one cannot
    # win the alternation and cut the longer name short.
    keys = sorted(
        (re.escape(key) for key in (*flags, *aliases)), key=len, reverse=True
    )
    if not keys:
        # An empty alternation would match the empty string at every position.
        return []

    alternation = "|".join(keys)
    # The *whole* delimiter is optional (not just its last character), so
    # multi-character delimiters still allow bare implicit flags.
    pattern = re.compile(
        f"(?:{re.escape(prefix)})(?P<flag>{alternation})"
        f"(?P<delimiter>{re.escape(delimiter)})?",
        flags=re.IGNORECASE if case_insensitive else 0,
    )

    found: list[tuple[re.Match[str], Flag | BaseFlag | None]] = []
    for match in pattern.finditer(argument):
        key = match.group("flag")
        if case_insensitive:
            key = key.casefold()
        key = aliases.get(key, key)
        flag = flags.get(key)
        # Only the flag being introduced decides whether a missing delimiter
        # is acceptable: implicit flags take no value, so they may be bare.
        if (
            delimiter
            and not match.group("delimiter")
            and not getattr(flag, "implicit", False)
        ):
            continue
        found.append((match, flag))
    return found


def _record_flag_value(
    result: dict[str, list[str]],
    flag: Flag | BaseFlag,
    value: str,
    case_insensitive: bool,
) -> None:
    """Append the raw value of one occurrence of ``flag`` to ``result``.

    Implicit flags always record ``"1"`` and ignore ``value``; any other flag
    must have been given a non-empty ``value``.

    Raises
    ------
    MissingFlagArgument
        ``flag`` is not implicit and ``value`` is empty.
    """
    if getattr(flag, "implicit", False):
        value = "1"
    elif not value:
        raise MissingFlagArgument(flag)
    name = flag.name.casefold() if case_insensitive else flag.name
    result.setdefault(name, []).append(value)


class ImplicitBoolFlagConverter(FlagConverter):
    """A custom :class:`discord.ext.commands.FlagConverter` subclass that allows
    boolean flags to not have a value.

    .. versionadded:: 1.0

    For example:

    .. code-block:: python3

        from discord_tools import ImplicitBoolFlagConverter, flag

        class MyFlags(ImplicitBoolFlagConverter):
            some_flag = flag(implicit=True)
    """

    if TYPE_CHECKING:
        __commands_flags__: dict[str, Flag | BaseFlag]

    @classmethod
    def parse_flags(
        cls, argument: str, *, ignore_extra: bool = True
    ) -> dict[str, list[str]]:
        """Split ``argument`` into raw string values keyed by flag name.

        A flag occurrence is ``prefix`` + name + ``delimiter``; for implicit
        flags the delimiter may be omitted, and their recorded value is always
        ``"1"``. A non-implicit flag name written without the delimiter is
        treated as plain text.

        Parameters
        ----------
        argument: :class:`str`
            The raw text passed to the converter.
        ignore_extra: :class:`bool`
            When ``False``, leftover text that belongs to no flag raises
            :exc:`TooManyArguments`.

        Returns
        -------
        Dict[:class:`str`, List[:class:`str`]]
            Raw values per flag name (casefolded when the converter is
            case-insensitive), in the order they appeared.

        Raises
        ------
        MissingFlagArgument
            A non-implicit flag was given without a value.
        TooManyArguments
            No flag was matched, text remained, and ``ignore_extra`` is ``False``.
        """
        result: dict[str, list[str]] = {}
        flags = cls.__commands_flags__
        aliases = cls.__commands_flag_aliases__
        positional_flag = cls.__commands_flag_positional__
        case_insensitive = cls.__commands_flag_case_insensitive__

        if case_insensitive:
            flags = {key.casefold(): value for key, value in flags.items()}
            aliases = {
                key.casefold(): value.casefold() for key, value in aliases.items()
            }

        matches = _find_flag_matches(
            argument,
            flags,
            aliases,
            prefix=cls.__commands_flag_prefix__,
            delimiter=cls.__commands_flag_delimiter__,
            case_insensitive=case_insensitive,
        )

        last_position = 0
        last_flag: Flag | BaseFlag | None = None

        if positional_flag is not None:
            if matches:
                # Everything before the first flag belongs to the positional flag.
                value = argument[: matches[0][0].start()].strip()
            else:
                value = argument.strip()
                # The positional flag consumed the whole argument.
                last_position = len(argument)
            if value:
                name = (
                    positional_flag.name.casefold()
                    if case_insensitive
                    else positional_flag.name
                )
                result[name] = [value]

        for match, flag in matches:
            begin, end = match.span()
            if last_position and last_flag is not None:
                # The previous flag's value runs up to just before this flag.
                value = argument[last_position : begin - 1].lstrip()
                _record_flag_value(result, last_flag, value, case_insensitive)
            last_position = end
            last_flag = flag

        value = argument[last_position:].strip()
        if last_flag is not None:
            _record_flag_value(result, last_flag, value, case_insensitive)
        elif value and not ignore_extra:
            raise TooManyArguments(f"Too many arguments passed to {cls.__name__}")

        return result

    @classmethod
    async def convert(cls, ctx: Context[BotT], argument: str) -> Self:
        """Parse ``argument`` and build an instance with every flag attribute set.

        Missing optional flags get their default (awaiting/calling it when it
        is callable). When ``override`` is set and a flag was repeated more
        than ``max_args`` times, only the last ``max_args`` values are kept.

        Raises
        ------
        MissingRequiredFlag
            A flag without a default was not supplied.
        TooManyFlags
            A flag was repeated more than ``max_args`` times without ``override``.
        """
        ignore_extra = True
        if (
            ctx.command is not None
            and ctx.current_parameter is not None
            and ctx.current_parameter.kind == ctx.current_parameter.KEYWORD_ONLY
        ):
            ignore_extra = ctx.command.ignore_extra

        arguments = cls.parse_flags(argument, ignore_extra=ignore_extra)
        instance = cls.__new__(cls)
        for name, flag in cls.__commands_flags__.items():
            try:
                values = arguments[name]
            except KeyError:
                if flag.required:
                    raise MissingRequiredFlag(flag)
                if callable(flag.default):
                    default = await maybe_coroutine(flag.default, ctx)
                else:
                    default = flag.default
                setattr(instance, flag.attribute, default)
                continue

            if flag.max_args > 0 and len(values) > flag.max_args:
                if not flag.override:
                    raise TooManyFlags(flag, values)
                # Keep only the most recent occurrences (previously this slice
                # was assigned to an unused name, so the first value won).
                values = values[-flag.max_args :]

            if flag.max_args == 1:
                value = await convert_flag(ctx, values[0], flag)
                setattr(instance, flag.attribute, value)
                continue

            converted = [await convert_flag(ctx, value, flag) for value in values]
            if flag.cast_to_dict:
                converted = dict(converted)
            setattr(instance, flag.attribute, converted)

        return instance