diff --git a/changes/1002.deprecation.md b/changes/1002.deprecation.md new file mode 100644 index 0000000000..240d7a839a --- /dev/null +++ b/changes/1002.deprecation.md @@ -0,0 +1 @@ +Deprecate `RESTClientImpl.build_action_row` in favour of `RESTClientImpl.build_message_action_row`. diff --git a/changes/1002.feature.md b/changes/1002.feature.md new file mode 100644 index 0000000000..888c73d716 --- /dev/null +++ b/changes/1002.feature.md @@ -0,0 +1,2 @@ +Implement modal interactions. +- Additionally, it is now guaranteed (typing-wise) that top level components will be an action row diff --git a/hikari/__init__.py b/hikari/__init__.py index 357cf84776..10b95e050a 100644 --- a/hikari/__init__.py +++ b/hikari/__init__.py @@ -71,6 +71,7 @@ from hikari.colors import * from hikari.colours import * from hikari.commands import * +from hikari.components import * from hikari.embeds import * from hikari.emojis import * from hikari.errors import * @@ -105,6 +106,7 @@ from hikari.interactions.base_interactions import * from hikari.interactions.command_interactions import * from hikari.interactions.component_interactions import * +from hikari.interactions.modal_interactions import * from hikari.invites import * from hikari.iterators import * from hikari.locales import * diff --git a/hikari/__init__.pyi b/hikari/__init__.pyi index 252cf7f8ce..3d59a211ff 100644 --- a/hikari/__init__.pyi +++ b/hikari/__init__.pyi @@ -44,6 +44,7 @@ from hikari.channels import * from hikari.colors import * from hikari.colours import * from hikari.commands import * +from hikari.components import * from hikari.embeds import * from hikari.emojis import * from hikari.errors import * @@ -78,6 +79,7 @@ from hikari.intents import * from hikari.interactions.base_interactions import * from hikari.interactions.command_interactions import * from hikari.interactions.component_interactions import * +from hikari.interactions.modal_interactions import * from hikari.invites import * from hikari.iterators import * from hikari.locales import * diff --git a/hikari/api/entity_factory.py b/hikari/api/entity_factory.py index 361efbc30c..2c21303fed 100644 --- a/hikari/api/entity_factory.py +++ b/hikari/api/entity_factory.py @@ -53,6 +53,7 @@ from hikari.interactions import base_interactions from hikari.interactions import command_interactions from hikari.interactions import component_interactions + from hikari.interactions import modal_interactions from hikari.internal import data_binding @@ -1298,6 +1299,21 @@ def deserialize_command_interaction( The deserialized command interaction object. """ + @abc.abstractmethod + def deserialize_modal_interaction(self, payload: data_binding.JSONObject) -> modal_interactions.ModalInteraction: + """Parse a raw payload from Discord into a modal interaction object. + + Parameters + ---------- + payload : hikari.internal.data_binding.JSONObject + The JSON payload to deserialize. + + Returns + ------- + hikari.interactions.modal_interactions.ModalInteraction + The deserialized modal interaction object. + """ + @abc.abstractmethod def deserialize_interaction(self, payload: data_binding.JSONObject) -> base_interactions.PartialInteraction: """Parse a raw payload from Discord into a interaction object. diff --git a/hikari/api/interaction_server.py b/hikari/api/interaction_server.py index 8019d4cc47..7e122e4701 100644 --- a/hikari/api/interaction_server.py +++ b/hikari/api/interaction_server.py @@ -34,6 +34,7 @@ from hikari.interactions import base_interactions from hikari.interactions import command_interactions from hikari.interactions import component_interactions + from hikari.interactions import modal_interactions _InteractionT_co = typing.TypeVar("_InteractionT_co", bound=base_interactions.PartialInteraction, covariant=True) _ResponseT_co = typing.TypeVar("_ResponseT_co", bound=special_endpoints.InteractionResponseBuilder, covariant=True) @@ -41,6 +42,10 @@ special_endpoints.InteractionDeferredBuilder, special_endpoints.InteractionMessageBuilder, ] + _ModalOrMessageResponseBuilder = typing.Union[ + _MessageResponseBuilderT, + special_endpoints.InteractionModalBuilder, + ] ListenerT = typing.Callable[["_InteractionT_co"], typing.Awaitable["_ResponseT_co"]] @@ -142,14 +147,14 @@ async def on_interaction(self, body: bytes, signature: bytes, timestamp: bytes) @abc.abstractmethod def get_listener( self, interaction_type: typing.Type[command_interactions.CommandInteraction], / - ) -> typing.Optional[ListenerT[command_interactions.CommandInteraction, _MessageResponseBuilderT]]: + ) -> typing.Optional[ListenerT[command_interactions.CommandInteraction, _ModalOrMessageResponseBuilder]]: ... @typing.overload @abc.abstractmethod def get_listener( self, interaction_type: typing.Type[component_interactions.ComponentInteraction], / - ) -> typing.Optional[ListenerT[component_interactions.ComponentInteraction, _MessageResponseBuilderT]]: + ) -> typing.Optional[ListenerT[component_interactions.ComponentInteraction, _ModalOrMessageResponseBuilder]]: ... @typing.overload @@ -161,6 +166,13 @@ def get_listener( ]: ... + @typing.overload + @abc.abstractmethod + def get_listener( + self, interaction_type: typing.Type[modal_interactions.ModalInteraction], / + ) -> typing.Optional[ListenerT[modal_interactions.ModalInteraction, _MessageResponseBuilderT]]: + ... + @typing.overload @abc.abstractmethod def get_listener( @@ -191,7 +203,7 @@ def get_listener( def set_listener( self, interaction_type: typing.Type[command_interactions.CommandInteraction], - listener: typing.Optional[ListenerT[command_interactions.CommandInteraction, _MessageResponseBuilderT]], + listener: typing.Optional[ListenerT[command_interactions.CommandInteraction, _ModalOrMessageResponseBuilder]], /, *, replace: bool = False, @@ -203,7 +215,9 @@ def set_listener( def set_listener( self, interaction_type: typing.Type[component_interactions.ComponentInteraction], - listener: typing.Optional[ListenerT[component_interactions.ComponentInteraction, _MessageResponseBuilderT]], + listener: typing.Optional[ + ListenerT[component_interactions.ComponentInteraction, _ModalOrMessageResponseBuilder] + ], /, *, replace: bool = False, @@ -224,6 +238,18 @@ def set_listener( ) -> None: ... + @typing.overload + @abc.abstractmethod + def set_listener( + self, + interaction_type: typing.Type[modal_interactions.ModalInteraction], + listener: typing.Optional[ListenerT[modal_interactions.ModalInteraction, _MessageResponseBuilderT]], + /, + *, + replace: bool = False, + ) -> None: + ... + @abc.abstractmethod def set_listener( self, diff --git a/hikari/api/rest.py b/hikari/api/rest.py index 00e6e043ad..fa45fb8c5a 100644 --- a/hikari/api/rest.py +++ b/hikari/api/rest.py @@ -7977,6 +7977,23 @@ def interaction_message_builder( The interaction message response builder object. """ + @abc.abstractmethod + def interaction_modal_builder(self, title: str, custom_id: str) -> special_endpoints.InteractionModalBuilder: + """Create a builder for a modal interaction response. + + Parameters + ---------- + title : builtins.str + The title that will show up in the modal. + custom_id : builtins.str + Developer set custom ID used for identifying interactions with this modal. + + Returns + ------- + hikari.api.special_endpoints.InteractionModalBuilder + The interaction modal response builder object. + """ + @abc.abstractmethod async def fetch_interaction_response( self, application: snowflakes.SnowflakeishOr[guilds.PartialApplication], token: str @@ -8378,13 +8395,59 @@ async def create_autocomplete_response( If an internal error occurs on Discord while handling the request. """ + async def create_modal_response( + self, + interaction: snowflakes.SnowflakeishOr[base_interactions.PartialInteraction], + token: str, + *, + title: str, + custom_id: str, + component: undefined.UndefinedOr[special_endpoints.ComponentBuilder] = undefined.UNDEFINED, + components: undefined.UndefinedOr[typing.Sequence[special_endpoints.ComponentBuilder]] = undefined.UNDEFINED, + ) -> None: + """Create a response by sending a modal. + + Parameters + ---------- + interaction : hikari.snowflakes.SnowflakeishOr[hikari.interactions.base_interactions.PartialInteraction] + Object or ID of the interaction this response is for. + token : builtins.str + The command interaction's token. + title : str + The title that will show up in the modal. + custom_id : str + Developer set custom ID used for identifying interactions with this modal. + + Other Parameters + ---------------- + component : hikari.undefined.UndefinedOr[typing.Sequence[special_endpoints.ComponentBuilder]] + A component builders to send in this modal. + components : hikari.undefined.UndefinedOr[typing.Sequence[special_endpoints.ComponentBuilder]] + A sequence of component builders to send in this modal. + + Raises + ------ + ValueError + If both `component` and `components` are specified or if none are specified. + """ + + @abc.abstractmethod + def build_message_action_row(self) -> special_endpoints.MessageActionRowBuilder: + """Build a message action row message component for use in message create and REST calls. + + Returns + ------- + hikari.api.special_endpoints.MessageActionRowBuilder + The initialised action row builder. + """ + @abc.abstractmethod - def build_action_row(self) -> special_endpoints.ActionRowBuilder: - """Build an action row message component for use in message create and REST calls. + def build_modal_action_row(self) -> special_endpoints.ModalActionRowBuilder: + """Build an action row modal component for use in interactions and REST calls. Returns ------- - hikari.api.special_endpoints.ActionRowBuilder + hikari.api.special_endpoints.ModalActionRowBuilder The initialised action row builder. """ diff --git a/hikari/api/special_endpoints.py b/hikari/api/special_endpoints.py index 87c3fc88ef..24bb97a841 100644 --- a/hikari/api/special_endpoints.py +++ b/hikari/api/special_endpoints.py @@ -24,7 +24,6 @@ from __future__ import annotations __all__: typing.Sequence[str] = ( - "ActionRowBuilder", "ButtonBuilder", "CommandBuilder", "SlashCommandBuilder", @@ -40,6 +39,10 @@ "LinkButtonBuilder", "SelectMenuBuilder", "SelectOptionBuilder", + "TextInputBuilder", + "InteractionModalBuilder", + "MessageActionRowBuilder", + "ModalActionRowBuilder", ) import abc @@ -53,6 +56,7 @@ from hikari import channels from hikari import colors from hikari import commands + from hikari import components as components_ from hikari import embeds as embeds_ from hikari import emojis from hikari import files @@ -185,13 +189,7 @@ class GuildBuilder(abc.ABC): @property @abc.abstractmethod def name(self) -> str: - """Name of the guild to create. - - Returns - ------- - builtins.str - The guild name. - """ + """Name of the guild to create.""" @property @abc.abstractmethod @@ -199,11 +197,6 @@ def default_message_notifications(self) -> undefined.UndefinedOr[guilds.GuildMes """Default message notification level that can be overwritten. If not overridden, this will use the Discord default level. - - Returns - ------- - hikari.undefined.UndefinedOr[hikari.guilds.GuildMessageNotificationsLevel] - The default message notification level, if overwritten. """ # noqa: D401 - Imperative mood @default_message_notifications.setter @@ -218,11 +211,6 @@ def explicit_content_filter_level(self) -> undefined.UndefinedOr[guilds.GuildExp """Explicit content filter level that can be overwritten. If not overridden, this will use the Discord default level. - - Returns - ------- - hikari.undefined.UndefinedOr[hikari.guilds.GuildExplicitContentFilterLevel] - The explicit content filter level, if overwritten. """ @explicit_content_filter_level.setter @@ -237,11 +225,6 @@ def icon(self) -> undefined.UndefinedOr[files.Resourceish]: """Guild icon to use that can be overwritten. If not overridden, the guild will not have an icon. - - Returns - ------- - hikari.undefined.UndefinedOr[hikari.files.Resourceish] - The guild icon to use, if overwritten. """ @icon.setter @@ -254,11 +237,6 @@ def verification_level(self) -> undefined.UndefinedOr[typing.Union[guilds.GuildV """Verification level required to join the guild that can be overwritten. If not overridden, the guild will use the default verification level for - - Returns - ------- - hikari.undefined.UndefinedOr[typing.Union[hikari.guilds.GuildVerificationLevel, builtins.int]] - The verification level required to join the guild, if overwritten. """ @verification_level.setter @@ -558,13 +536,7 @@ class InteractionResponseBuilder(abc.ABC): @property @abc.abstractmethod def type(self) -> typing.Union[int, base_interactions.ResponseType]: - """Return the type of this response. - - Returns - ------- - typing.Union[builtins.int, hikari.interactions.base_interactions.ResponseType] - The type of response this is. - """ + """Type of this response.""" @abc.abstractmethod def build( @@ -593,13 +565,7 @@ class InteractionDeferredBuilder(InteractionResponseBuilder, abc.ABC): @property @abc.abstractmethod def type(self) -> base_interactions.DeferredResponseTypesT: - """Return the type of this response. - - Returns - ------- - hikari.interactions.base_interactions.DeferredResponseTypesT - The type of response this is. - """ + """Type of this response.""" @property @abc.abstractmethod @@ -609,12 +575,6 @@ def flags(self) -> typing.Union[undefined.UndefinedType, int, messages.MessageFl !!! note As of writing the only message flag which can be set here is `hikari.messages.MessageFlag.EPHEMERAL`. - - Returns - ------- - typing.Union[hikari.undefined.UndefinedType, builtins.int, hikari.messages.MessageFlag] - The message flags this response should have if set else - `hikari.undefined.UNDEFINED`. """ @abc.abstractmethod @@ -644,7 +604,7 @@ class InteractionAutocompleteBuilder(InteractionResponseBuilder, abc.ABC): @property @abc.abstractmethod def choices(self) -> typing.Sequence[commands.CommandChoice]: - """Return autocomplete choices.""" + """Autocomplete choices.""" @abc.abstractmethod def set_choices(self: _T, choices: typing.Sequence[commands.CommandChoice], /) -> _T: @@ -672,13 +632,7 @@ class InteractionMessageBuilder(InteractionResponseBuilder, abc.ABC): @property @abc.abstractmethod def type(self) -> base_interactions.MessageResponseTypesT: - """Return the type of this response. - - Returns - ------- - hikari.interactions.base_interactions.MessageResponseTypesT - The type of response this is. - """ + """Type of this response.""" # Extendable fields @@ -702,13 +656,7 @@ def embeds(self) -> undefined.UndefinedOr[typing.Sequence[embeds_.Embed]]: @property @abc.abstractmethod def content(self) -> undefined.UndefinedOr[str]: - """Response's message content. - - Returns - ------- - hikari.undefined.UndefinedOr[builtins.str] - The response's message content, if set. - """ + """Response's message content.""" @property @abc.abstractmethod @@ -718,37 +666,17 @@ def flags(self) -> typing.Union[undefined.UndefinedType, int, messages.MessageFl !!! note As of writing the only message flag which can be set here is `hikari.messages.MessageFlag.EPHEMERAL`. - - Returns - ------- - typing.Union[hikari.undefined.UndefinedType, builtins.int, hikari.messages.MessageFlag] - The message flags this response should have if set else - `hikari.undefined.UNDEFINED`. """ @property @abc.abstractmethod def is_tts(self) -> undefined.UndefinedOr[bool]: - """Whether this response's content should be treated as text-to-speech. - - Returns - ------- - builtins.bool - Whether this response's content should be treated as text-to-speech. - If left as `hikari.undefined.UNDEFINED` then this will be disabled. - """ + """Whether this response's content should be treated as text-to-speech.""" @property @abc.abstractmethod def mentions_everyone(self) -> undefined.UndefinedOr[bool]: - """Whether @everyone and @here mentions should be enabled for this response. - - Returns - ------- - hikari.undefined.UndefinedOr[builtins.bool] - Whether @everyone mentions should be enabled for this response. - If left as `hikari.undefined.UNDEFINED` then they will be disabled. - """ + """Whether @everyone and @here mentions should be enabled for this response.""" @property @abc.abstractmethod @@ -947,6 +875,67 @@ def set_user_mentions( """ # noqa: E501 - Line too long +class InteractionModalBuilder(InteractionResponseBuilder, abc.ABC): + """Interface of an interaction modal response builder used within REST servers. + + This can be returned by the listener registered to + `hikari.api.interaction_server.InteractionServer` as a response to the interaction + create. + """ + + __slots__: typing.Sequence[str] = () + + @property + @abc.abstractmethod + def type(self) -> typing.Literal[base_interactions.ResponseType.MODAL]: + """Type of this response.""" + + @property + @abc.abstractmethod + def title(self) -> str: + """Title that will show up in the modal.""" + + @property + @abc.abstractmethod + def custom_id(self) -> str: + """Developer set custom ID used for identifying interactions with this modal.""" + + @property + @abc.abstractmethod + def components(self) -> undefined.UndefinedOr[typing.Sequence[ComponentBuilder]]: + """Sequence of component builders to send in this modal.""" + + @abc.abstractmethod + def set_title(self: _T, title: str, /) -> _T: + """Set the title that will show up in the modal. + + Parameters + ---------- + title : builtins.str + The title that will show up in the modal. + """ + + @abc.abstractmethod + def set_custom_id(self: _T, custom_id: str, /) -> _T: + """Set the developer set custom ID used for identifying interactions with this modal. + + Parameters + ---------- + custom_id : builtins.str + The developer set custom ID used for identifying interactions with this modal. + """ + + @abc.abstractmethod + def add_component(self: _T, component: ComponentBuilder, /) -> _T: + """Add a component to this modal. + + Parameters + ---------- + component : ComponentBuilder + The component builder to add to this modal. + """ + + class CommandBuilder(abc.ABC): """Interface of a command builder used when bulk creating commands over REST.""" @@ -970,24 +959,12 @@ def name(self) -> str: @property @abc.abstractmethod def type(self) -> commands.CommandType: - """Return the type of this command. - - Returns - ------- - hikari.commands.CommandType - The type of this command. - """ + """Type of this command.""" @property @abc.abstractmethod def id(self) -> undefined.UndefinedOr[snowflakes.Snowflake]: - """ID of this command. - - Returns - ------- - hikari.undefined.UndefinedOr[hikari.snowflakes.Snowflake] - The ID of this command if set. - """ + """ID of this command.""" @property @abc.abstractmethod @@ -1153,16 +1130,11 @@ class SlashCommandBuilder(CommandBuilder): @property @abc.abstractmethod def description(self) -> str: - """Return the description to set for this command. + """Description to set for this command. !!! warning This should be inclusively between 1-100 characters in length. - - Returns - ------- - builtins.str - The description to set for this command. - """ + """ # noqa: D401 - Imperative mood @property @abc.abstractmethod @@ -1193,13 +1165,7 @@ def set_description_localizations( @property @abc.abstractmethod def options(self) -> typing.Sequence[commands.CommandOption]: - """Sequence of up to 25 of the options set for this command. - - Returns - ------- - typing.Sequence[hikari.commands.CommandOption] - A sequence of up to 25 of the options set for this command. - """ + """Sequence of up to 25 of the options set for this command.""" @abc.abstractmethod def add_option(self: _T, option: commands.CommandOption) -> _T: @@ -1318,26 +1284,13 @@ class ButtonBuilder(ComponentBuilder, abc.ABC, typing.Generic[_ContainerT]): @property @abc.abstractmethod - def style(self) -> typing.Union[messages.ButtonStyle, int]: - """Button's style. - - Returns - ------- - typing.Union[builtins.int, hikari.messages.ButtonStyle] - The button's style. - """ + def style(self) -> typing.Union[components_.ButtonStyle, int]: + """Button's style.""" @property @abc.abstractmethod def emoji(self) -> typing.Union[snowflakes.Snowflakeish, emojis.Emoji, str, undefined.UndefinedType]: - """Emoji which should appear on this button. - - Returns - ------- - typing.Union[hikari.snowflakes.Snowflakeish, hikari.emojis.Emoji, builtins.str, hikari.undefined.UndefinedType] - Object or ID or raw string of the emoji which should be displayed - on this button if set. - """ + """Emoji which should appear on this button.""" @property @abc.abstractmethod @@ -1347,11 +1300,6 @@ def label(self) -> undefined.UndefinedOr[str]: !!! note The text label to that should appear on this button. This may be up to 80 characters long. - - Returns - ------- - hikari.undefined.UndefinedOr[builtins.str] - Text label which should appear on this button. """ @property @@ -1361,11 +1309,6 @@ def is_disabled(self) -> bool: !!! note Defaults to `builtins.False`. - - Returns - ------- - builtins.bool - Whether the button should be marked as disabled. """ @abc.abstractmethod @@ -1439,13 +1382,7 @@ class LinkButtonBuilder(ButtonBuilder[_ContainerT], abc.ABC): @property @abc.abstractmethod def url(self) -> str: - """Url this button should link to when pressed. - - Returns - ------- - builtins.str - Url this button should link to when pressed. - """ + """Url this button should link to when pressed.""" class InteractiveButtonBuilder(ButtonBuilder[_ContainerT], abc.ABC): @@ -1456,13 +1393,7 @@ class InteractiveButtonBuilder(ButtonBuilder[_ContainerT], abc.ABC): @property @abc.abstractmethod def custom_id(self) -> str: - """Developer set custom ID used for identifying interactions with this button. - - Returns - ------- - builtins.str - Developer set custom ID used for identifying interactions with this button. - """ + """Developer set custom ID used for identifying interactions with this button.""" class SelectOptionBuilder(ComponentBuilder, abc.ABC, typing.Generic[_SelectMenuBuilderT]): @@ -1473,47 +1404,22 @@ class SelectOptionBuilder(ComponentBuilder, abc.ABC, typing.Generic[_SelectMenuB @property @abc.abstractmethod def label(self) -> str: - """User-facing name of the option, max 100 characters. - - Returns - ------- - builtins.str - User-facing name of the option. - """ + """User-facing name of the option, max 100 characters.""" @property @abc.abstractmethod def value(self) -> str: - """Developer-defined value of the option, max 100 characters. - - Returns - ------- - builtins.str - Developer-defined value of the option. - """ + """Developer-defined value of the option, max 100 characters.""" @property @abc.abstractmethod def description(self) -> undefined.UndefinedOr[str]: - """Return the description of the option, max 100 characters. - - Returns - ------- - hikari.undefined.UndefinedOr[builtins.str] - Description of the option, if set. - """ + """Description of the option, max 100 characters.""" # noqa: D401 - Imperative mood @property @abc.abstractmethod def emoji(self) -> typing.Union[snowflakes.Snowflakeish, emojis.Emoji, str, undefined.UndefinedType]: - """Emoji which should appear on this option. - - Returns - ------- - typing.Union[hikari.snowflakes.Snowflakeish, hikari.emojis.Emoji, builtins.str, hikari.undefined.UndefinedType] - Object or ID or raw string of the emoji which should be displayed - on this option if set. - """ + """Emoji which should appear on this option.""" @property @abc.abstractmethod @@ -1521,11 +1427,6 @@ def is_default(self) -> bool: """Whether this option should be marked as selected by default. Defaults to `builtins.False`. - - Returns - ------- - builtins.bool - Whether this option should be marked as selected by default. """ @abc.abstractmethod @@ -1598,13 +1499,7 @@ class SelectMenuBuilder(ComponentBuilder, abc.ABC, typing.Generic[_ContainerT]): @property @abc.abstractmethod def custom_id(self) -> str: - """Developer set custom ID used for identifying interactions with this menu. - - Returns - ------- - builtins.str - Developer set custom ID used for identifying interactions with this menu. - """ + """Developer set custom ID used for identifying interactions with this menu.""" @property @abc.abstractmethod @@ -1613,34 +1508,17 @@ def is_disabled(self) -> bool: !!! note Defaults to `builtins.False`. - - Returns - ------- - builtins.bool - Whether the select menu should be marked as disabled. """ @property @abc.abstractmethod def options(self: _SelectMenuBuilderT) -> typing.Sequence[SelectOptionBuilder[_SelectMenuBuilderT]]: - """Sequence of the options set for this select menu. - - Returns - ------- - typing.Sequence[SelectOptionBuilder[Self]] - Sequence of the options set for this select menu. - """ + """Sequence of the options set for this select menu.""" @property @abc.abstractmethod def placeholder(self) -> undefined.UndefinedOr[str]: - """Return the placeholder text to display when no options are selected. - - Returns - ------- - hikari.undefined.UndefinedOr[builtins.str] - Placeholder text to display when no options are selected, if defined. - """ + """Placeholder text to display when no options are selected.""" # noqa: D401 - Imperative mood @property @abc.abstractmethod @@ -1650,11 +1528,6 @@ def min_values(self) -> int: Defaults to 1. Must be less than or equal to `SelectMenuBuilder.max_values` and greater than or equal to 0. - - Returns - ------- - builtins.str - Minimum number of options which must be chosen. """ @property @@ -1665,11 +1538,6 @@ def max_values(self) -> int: Defaults to 1. Must be greater than or equal to `SelectMenuBuilder.min_values` and less than or equal to 25. - - Returns - ------- - builtins.str - Maximum number of options which can be chosen. """ @abc.abstractmethod @@ -1772,22 +1640,197 @@ def add_to_container(self) -> _ContainerT: """ -class ActionRowBuilder(ComponentBuilder, abc.ABC): - """Builder class for action row components.""" +class TextInputBuilder(ComponentBuilder, abc.ABC, typing.Generic[_ContainerT]): + """Builder class for text inputs components.""" __slots__: typing.Sequence[str] = () @property @abc.abstractmethod - def components(self) -> typing.Sequence[ComponentBuilder]: - """Sequence of the component builders registered within this action row. + def custom_id(self) -> str: + """Developer set custom ID used for identifying this text input. + + !!! note + This custom_id is never used in component interaction events. + It is meant to be used purely for resolving components modal interactions. + """ + + @property + @abc.abstractmethod + def label(self) -> str: + """Label above this text input.""" + + @property + @abc.abstractmethod + def style(self) -> components_.TextInputStyle: + """Style to use for the text input.""" + + @property + @abc.abstractmethod + def placeholder(self) -> undefined.UndefinedOr[str]: + """Placeholder text for when the text input is empty.""" # noqa: D401 - Imperative mood + + @property + @abc.abstractmethod + def value(self) -> undefined.UndefinedOr[str]: + """Pre-filled text that will be sent if the user does not write anything.""" + + @property + @abc.abstractmethod + def required(self) -> undefined.UndefinedOr[bool]: + """Whether this text input is required to be filled-in.""" + + @property + @abc.abstractmethod + def min_length(self) -> undefined.UndefinedOr[int]: + """Minimum length the text should have.""" + + @property + @abc.abstractmethod + def max_length(self) -> undefined.UndefinedOr[int]: + """Maximum length the text should have.""" + + @abc.abstractmethod + def set_style(self: _T, style: typing.Union[components_.TextInputStyle, int], /) -> _T: + """Set the style to use for the text input. + + Parameters + ---------- + style : typing.Union[hikari.modal_interactions.TextInputStyle, int] + Style to use for the text input. + + Returns + ------- + TextInputBuilder + The builder object to enable chained calls. + """ + + @abc.abstractmethod + def set_custom_id(self: _T, custom_id: str, /) -> _T: + """Set the developer set custom ID used for identifying this text input. + + Parameters + ---------- + custom_id : builtins.str + Developer set custom ID used for identifying this text input. + + Returns + ------- + TextInputBuilder + The builder object to enable chained calls. + """ + + @abc.abstractmethod + def set_label(self: _T, label: str, /) -> _T: + """Set the label above this text input. + + Parameters + ---------- + label : builtins.str + Label above this text input. + + Returns + ------- + TextInputBuilder + The builder object to enable chained calls. + """ + + @abc.abstractmethod + def set_placeholder(self: _T, placeholder: str, /) -> _T: + """Set the placeholder text for when the text input is empty. + + Parameters + ---------- + placeholder : builtins.str: + Placeholder text that will disappear when the user types anything. + + Returns + ------- + TextInputBuilder + The builder object to enable chained calls. + """ + + @abc.abstractmethod + def set_value(self: _T, value: str, /) -> _T: + """Pre-filled text that will be sent if the user does not write anything. + + Parameters + ---------- + value : builtins.str + Pre-filled text that will be sent if the user does not write anything. Returns ------- - typing.Sequence[ComponentBuilder] - Sequence of the component builders registered within this action row. + TextInputBuilder + The builder object to enable chained calls. """ + @abc.abstractmethod + def set_required(self: _T, required: bool, /) -> _T: + """Set whether this text input is required to be filled-in. + + Parameters + ---------- + required : builtins.bool + Whether this text input is required to be filled-in. + + Returns + ------- + TextInputBuilder + The builder object to enable chained calls. + """ + + @abc.abstractmethod + def set_min_length(self: _T, min_length: int, /) -> _T: + """Set the minimum length the text should have. + + Parameters + ---------- + min_length : builtins.int + The minimum length the text should have. + + Returns + ------- + TextInputBuilder + The builder object to enable chained calls. + """ + + @abc.abstractmethod + def set_max_length(self: _T, max_length: int, /) -> _T: + """Set the maximum length the text should have. + + Parameters + ---------- + max_length : builtins.int + The maximum length the text should have. + + Returns + ------- + TextInputBuilder + The builder object to enable chained calls. + """ + + @abc.abstractmethod + def add_to_container(self) -> _ContainerT: + """Finalise this builder by adding it to its parent container component. + + Returns + ------- + _ContainerT + The parent container component builder. + """ + + +class MessageActionRowBuilder(ComponentBuilder, abc.ABC): + """Builder class for action row components.""" + + __slots__: typing.Sequence[str] = () + + @property + @abc.abstractmethod + def components(self) -> typing.Sequence[ComponentBuilder]: + """Sequence of the component builders registered within this action row.""" + @abc.abstractmethod def add_component( self: _T, @@ -1815,25 +1858,27 @@ def add_component( @typing.overload @abc.abstractmethod def add_button( - self: _T, style: messages.InteractiveButtonTypesT, custom_id: str, / + self: _T, style: components_.InteractiveButtonTypesT, custom_id: str, / ) -> InteractiveButtonBuilder[_T]: ... @typing.overload @abc.abstractmethod - def add_button(self: _T, style: typing.Literal[messages.ButtonStyle.LINK, 5], url: str, /) -> LinkButtonBuilder[_T]: + def add_button( + self: _T, style: typing.Literal[components_.ButtonStyle.LINK, 5], url: str, / + ) -> LinkButtonBuilder[_T]: ... @typing.overload @abc.abstractmethod def add_button( - self: _T, style: typing.Union[int, messages.ButtonStyle], url_or_custom_id: str, / + self: _T, style: typing.Union[int, components_.ButtonStyle], url_or_custom_id: str, / ) -> typing.Union[LinkButtonBuilder[_T], InteractiveButtonBuilder[_T]]: ... @abc.abstractmethod def add_button( - self: _T, style: typing.Union[int, messages.ButtonStyle], url_or_custom_id: str, / + self: _T, style: typing.Union[int, components_.ButtonStyle], url_or_custom_id: str, / ) -> typing.Union[LinkButtonBuilder[_T], InteractiveButtonBuilder[_T]]: """Add a button component to this action row builder. @@ -1853,7 +1898,7 @@ def add_button( typing.Union[LinkButtonBuilder[Self], InteractiveButtonBuilder[Self]] Button builder object. `ButtonBuilder.add_to_container` should be called to finalise the - button. + component. """ @abc.abstractmethod @@ -1871,5 +1916,63 @@ def add_select_menu(self: _T, custom_id: str, /) -> SelectMenuBuilder[_T]: SelectMenuBuilder[Self] Select menu builder object. `SelectMenuBuilder.add_to_container` should be called to finalise the - button. + component. + """ + + +class ModalActionRowBuilder(ComponentBuilder, abc.ABC): + """Builder class for modal action row components.""" + + __slots__: typing.Sequence[str] = () + + @property + @abc.abstractmethod + def components(self) -> typing.Sequence[ComponentBuilder]: + """Sequence of the component builders registered within this action row.""" + + @abc.abstractmethod + def add_component( + self: _T, + component: ComponentBuilder, + /, + ) -> _T: + """Add a component to this action row builder. + + !!! warning + It is generally better to use `ActionRowBuilder.add_button` + and `ActionRowBuilder.add_select_menu` to add your + component to the builder. Those methods utilize this one. + + Parameters + ---------- + component : ComponentBuilder + The component builder to add to the action row. + + Returns + ------- + ActionRowBuilder + The builder object to enable chained calls. + """ + + @abc.abstractmethod + def add_text_input( + self: _T, + custom_id: str, + label: str, + ) -> TextInputBuilder[_T]: + """Add a text input component to this action row builder. + + Parameters + ---------- + custom_id : builtins.str + Developer set custom ID used for identifying this text input. + label : builtins.str + Label above this text input. + + Returns + ------- + TextInputBuilder[Self] + Text input builder object. + `TextInputBuilder.add_to_container` should be called to finalise the + component. """ diff --git a/hikari/components.py b/hikari/components.py new file mode 100644 index 0000000000..d21af5f69d --- /dev/null +++ b/hikari/components.py @@ -0,0 +1,320 @@ +# -*- coding: utf-8 -*- +# cython: language_level=3 +# Copyright (c) 2020 Nekokatt +# Copyright (c) 2021-present davfsa +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +"""Application and entities that are used to describe components on Discord.""" + +from __future__ import annotations + +__all__: typing.Sequence[str] = ( + "ComponentType", + "PartialComponent", + "ActionRowComponent", + "ButtonStyle", + "ButtonComponent", + "SelectMenuOption", + "SelectMenuComponent", + "TextInputStyle", + "TextInputComponent", + "InteractiveButtonTypes", + "InteractiveButtonTypesT", + "MessageComponentTypesT", + "ModalComponentTypesT", + "MessageActionRowComponent", + "ModalActionRowComponent", +) + +import typing + +import attr + +from hikari import emojis +from hikari.internal import enums + + +@typing.final +class ComponentType(int, enums.Enum): + """Types of components found within Discord.""" + + ACTION_ROW = 1 + """A non-interactive container component for other types of components. + + !!! note + As this is a container component it can never be contained within another + component and therefore will always be top-level. + + !!! note + As of writing this can only contain one component type. + """ + + BUTTON = 2 + """A button component. + + !!! note + This cannot be top-level and must be within a container component such + as `ComponentType.ACTION_ROW`. + """ + + SELECT_MENU = 3 + """A select menu component. + + !!! note + This cannot be top-level and must be within a container component such + as `ComponentType.ACTION_ROW`. + """ + + TEXT_INPUT = 4 + """A text input component. + + !! note + This component may only be used inside a modal container. + + !!! note + This cannot be top-level and must be within a container component such + as `ComponentType.ACTION_ROW`. + """ + + +@typing.final +class ButtonStyle(int, enums.Enum): + """Enum of the available button styles. + + More information, such as how these look, can be found at + https://discord.com/developers/docs/interactions/message-components#buttons-button-styles + """ + + PRIMARY = 1 + """A blurple "call to action" button.""" + + SECONDARY = 2 + """A grey neutral button.""" + + SUCCESS = 3 + """A green button.""" + + DANGER = 4 + """A red button (usually indicates a destructive action).""" + + LINK = 5 + """A grey button which navigates to a URL. + + !!! warning + Unlike the other button styles, clicking this one will not trigger an + interaction and custom_id shouldn't be included for this style. + """ + + +@typing.final +class TextInputStyle(int, enums.Enum): + """A text input style.""" + + SHORT = 1 + """Intended for short single-line text.""" + + PARAGRAPH = 2 + """Intended for much longer inputs.""" + + +@attr.define(kw_only=True, weakref_slot=False) +class PartialComponent: + """Base class for all component entities.""" + + type: typing.Union[ComponentType, int] = attr.field() + """The type of component this is.""" + + +AllowedComponentsT = typing.TypeVar("AllowedComponentsT", bound="PartialComponent") + + +@attr.define(weakref_slot=False) +class ActionRowComponent(typing.Generic[AllowedComponentsT], PartialComponent): + """Represents a row of components.""" + + components: typing.Sequence[AllowedComponentsT] = attr.field() + """Sequence of the components contained within this row.""" + + @typing.overload + def __getitem__(self, index: int, /) -> PartialComponent: + ... + + @typing.overload + def __getitem__(self, slice_: slice, /) -> typing.Sequence[AllowedComponentsT]: + ... + + def __getitem__( + self, index_or_slice: typing.Union[int, slice], / + ) -> typing.Union[PartialComponent, typing.Sequence[AllowedComponentsT]]: + return self.components[index_or_slice] + + def __iter__(self) -> typing.Iterator[AllowedComponentsT]: + return iter(self.components) + + def __len__(self) -> int: + return len(self.components) + + +@attr.define(hash=True, kw_only=True, weakref_slot=False) +class ButtonComponent(PartialComponent): + """Represents a button component.""" + + style: typing.Union[ButtonStyle, int] = attr.field(eq=False) + """The button's style.""" + + label: typing.Optional[str] = attr.field(eq=False) + """Text label which appears on the button.""" + + emoji: typing.Optional[emojis.Emoji] = attr.field(eq=False) + """Custom or unicode emoji which appears on the button.""" + + custom_id: typing.Optional[str] = attr.field(hash=True) + """Developer defined identifier for this button (will be <= 100 characters). + + !!! note + This is required for the following button styles: + + * `ButtonStyle.PRIMARY` + * `ButtonStyle.SECONDARY` + * `ButtonStyle.SUCCESS` + * `ButtonStyle.DANGER` + """ + + url: typing.Optional[str] = attr.field(eq=False) + """Url for `ButtonStyle.LINK` style buttons.""" + + is_disabled: bool = attr.field(eq=False) + """Whether the button is disabled.""" + + +@attr.define(kw_only=True, weakref_slot=False) +class SelectMenuOption: + """Represents an option for a `SelectMenuComponent`.""" + + label: str = attr.field() + """User-facing name of the option, max 100 characters.""" + + value: str = attr.field() + """Dev-defined value of the option, max 100 characters.""" + + description: typing.Optional[str] = attr.field() + """Optional description of the option, max 100 characters.""" + + emoji: typing.Optional[emojis.Emoji] = attr.field(eq=False) + """Custom or unicode emoji which appears on the button.""" + + is_default: bool = attr.field() + """Whether this option will be selected by default.""" + + +@attr.define(hash=True, kw_only=True, weakref_slot=False) +class SelectMenuComponent(PartialComponent): + """Represents a select menu component.""" + + custom_id: str = attr.field(hash=True) + """Developer defined identifier for this menu (will be <= 100 characters).""" + + options: typing.Sequence[SelectMenuOption] = attr.field(eq=False) + """Sequence of up to 25 of the options set for this menu.""" + + placeholder: typing.Optional[str] = attr.field(eq=False) + """Custom placeholder text shown if nothing is selected, max 100 characters.""" + + min_values: int = attr.field(eq=False) + """The minimum amount of options which must be chosen for this menu. + + This will be greater than or equal to 0 and will be less than or equal to + `SelectMenuComponent.max_values`. + """ + + max_values: int = attr.field(eq=False) + """The minimum amount of options which can be chosen for this menu. + + This will be less than or equal to 25 and will be greater than or equal to + `SelectMenuComponent.min_values`. + """ + + is_disabled: bool = attr.field(eq=False) + """Whether the select menu is disabled.""" + + +@attr.define(kw_only=True, weakref_slot=False) +class TextInputComponent(PartialComponent): + """Represents a text input component.""" + + custom_id: str = attr.field(repr=True) + """Developer set custom ID used for identifying interactions with this modal.""" + + value: str = attr.field(repr=True) + """Value provided for this text input.""" + + +InteractiveButtonTypesT = typing.Union[ + typing.Literal[ButtonStyle.PRIMARY], + typing.Literal[1], + typing.Literal[ButtonStyle.SECONDARY], + typing.Literal[2], + typing.Literal[ButtonStyle.SUCCESS], + typing.Literal[3], + typing.Literal[ButtonStyle.DANGER], + typing.Literal[4], +] +"""Type hints of the `ButtonStyle` values which are valid for interactive buttons. + +The following values are valid for this: + +* `ButtonStyle.PRIMARY`/`1` +* `ButtonStyle.SECONDARY`/`2` +* `ButtonStyle.SUCCESS`/`3` +* `ButtonStyle.DANGER`/`4` +""" + +InteractiveButtonTypes: typing.AbstractSet[InteractiveButtonTypesT] = frozenset( + [ButtonStyle.PRIMARY, ButtonStyle.SECONDARY, ButtonStyle.SUCCESS, ButtonStyle.DANGER] +) +"""Set of the `ButtonType`s which are valid for interactive buttons. + +The following values are included in this: + +* `ButtonStyle.PRIMARY` +* `ButtonStyle.SECONDARY` +* `ButtonStyle.SUCCESS` +* `ButtonStyle.DANGER` +""" + +MessageComponentTypesT = typing.Union[ButtonComponent, SelectMenuComponent] +"""Type hint of the `PartialComponent`s that be contained in a `MessageActionRowComponent`. + +The following values are valid for this: + +* `ButtonComponent` +* `SelectMenuComponent` +""" +ModalComponentTypesT = TextInputComponent +"""Type hint of the `PartialComponent`s that be contained in a `ModalActionRowComponent`. + +The following values are valid for this: + +* `TextInputComponent` +""" + +MessageActionRowComponent = ActionRowComponent[MessageComponentTypesT] +"""A message action row component.""" +ModalActionRowComponent = ActionRowComponent[ModalComponentTypesT] +"""A modal action row component.""" diff --git a/hikari/impl/entity_factory.py b/hikari/impl/entity_factory.py index cb36a2fa4c..043c899a6b 100644 --- a/hikari/impl/entity_factory.py +++ b/hikari/impl/entity_factory.py @@ -37,6 +37,7 @@ from hikari import channels as channel_models from hikari import colors as color_models from hikari import commands +from hikari import components as component_models from hikari import embeds as embed_models from hikari import emojis as emoji_models from hikari import errors @@ -61,6 +62,7 @@ from hikari.interactions import base_interactions from hikari.interactions import command_interactions from hikari.interactions import component_interactions +from hikari.interactions import modal_interactions from hikari.internal import attr_extensions from hikari.internal import data_binding from hikari.internal import time @@ -429,7 +431,8 @@ class EntityFactoryImpl(entity_factory.EntityFactory): "_audit_log_entry_converters", "_audit_log_event_mapping", "_command_mapping", - "_component_type_mapping", + "_message_component_type_mapping", + "_modal_component_type_mapping", "_dm_channel_type_mapping", "_guild_channel_type_mapping", "_thread_channel_type_mapping", @@ -500,10 +503,16 @@ def __init__(self, app: traits.RESTAware) -> None: commands.CommandType.USER: self.deserialize_context_menu_command, commands.CommandType.MESSAGE: self.deserialize_context_menu_command, } - self._component_type_mapping = { - message_models.ComponentType.ACTION_ROW: self._deserialize_action_row, - message_models.ComponentType.BUTTON: self._deserialize_button, - message_models.ComponentType.SELECT_MENU: self._deserialize_select_menu, + self._message_component_type_mapping: typing.Dict[ + int, typing.Callable[[data_binding.JSONObject], component_models.MessageComponentTypesT] + ] = { + component_models.ComponentType.BUTTON: self._deserialize_button, + component_models.ComponentType.SELECT_MENU: self._deserialize_select_menu, + } + self._modal_component_type_mapping: typing.Dict[ + int, typing.Callable[[data_binding.JSONObject], component_models.ModalComponentTypesT] + ] = { + component_models.ComponentType.TEXT_INPUT: self._deserialize_text_input, } self._dm_channel_type_mapping = { channel_models.ChannelType.DM: self.deserialize_dm, @@ -527,6 +536,7 @@ def __init__(self, app: traits.RESTAware) -> None: base_interactions.InteractionType.APPLICATION_COMMAND: self.deserialize_command_interaction, base_interactions.InteractionType.MESSAGE_COMPONENT: self.deserialize_component_interaction, base_interactions.InteractionType.AUTOCOMPLETE: self.deserialize_autocomplete_interaction, + base_interactions.InteractionType.MODAL_SUBMIT: self.deserialize_modal_interaction, } self._scheduled_event_type_mapping = { scheduled_events_models.ScheduledEventType.STAGE_INSTANCE: self.deserialize_scheduled_stage_event, @@ -2374,7 +2384,7 @@ def deserialize_command_interaction( options=options, resolved=resolved, target_id=target_id, - app_permissions=permission_models.Permissions(app_perms) if app_perms is not None else None, + app_permissions=permission_models.Permissions(app_perms) if app_perms else None, ) def deserialize_autocomplete_interaction( @@ -2418,6 +2428,48 @@ def deserialize_autocomplete_interaction( guild_locale=locales.Locale(payload["guild_locale"]) if "guild_locale" in payload else None, ) + def deserialize_modal_interaction(self, payload: data_binding.JSONObject) -> modal_interactions.ModalInteraction: + data_payload = payload["data"] + + guild_id: typing.Optional[snowflakes.Snowflake] = None + if raw_guild_id := payload.get("guild_id"): + guild_id = snowflakes.Snowflake(raw_guild_id) + + member: typing.Optional[base_interactions.InteractionMember] + if member_payload := payload.get("member"): + assert guild_id is not None + member = self._deserialize_interaction_member(member_payload, guild_id=guild_id) + # See https://github.com/discord/discord-api-docs/pull/2568 + user = member.user + + else: + member = None + user = self.deserialize_user(payload["user"]) + + message: typing.Optional[message_models.Message] = None + if message_payload := payload.get("message"): + message = self.deserialize_message(message_payload) + + app_perms = payload.get("app_permissions") + return modal_interactions.ModalInteraction( + app=self._app, + application_id=snowflakes.Snowflake(payload["application_id"]), + id=snowflakes.Snowflake(payload["id"]), + type=base_interactions.InteractionType(payload["type"]), + guild_id=guild_id, + app_permissions=permission_models.Permissions(app_perms) if app_perms else None, + guild_locale=locales.Locale(payload["guild_locale"]) if "guild_locale" in payload else None, + locale=locales.Locale(payload["locale"]), + channel_id=snowflakes.Snowflake(payload["channel_id"]), + member=member, + user=user, + token=payload["token"], + version=payload["version"], + custom_id=data_payload["custom_id"], + components=self._deserialize_components(data_payload["components"], self._modal_component_type_mapping), + message=message, + ) + def deserialize_interaction(self, payload: data_binding.JSONObject) -> base_interactions.PartialInteraction: interaction_type = base_interactions.InteractionType(payload["type"]) @@ -2495,11 +2547,11 @@ def deserialize_component_interaction( values=data_payload.get("values") or (), version=payload["version"], custom_id=data_payload["custom_id"], - component_type=message_models.ComponentType(data_payload["component_type"]), + component_type=component_models.ComponentType(data_payload["component_type"]), message=self.deserialize_message(payload["message"]), locale=locales.Locale(payload["locale"]), guild_locale=locales.Locale(payload["guild_locale"]) if "guild_locale" in payload else None, - app_permissions=permission_models.Permissions(app_perms) if app_perms is not None else None, + app_permissions=permission_models.Permissions(app_perms) if app_perms else None, ) ################## @@ -2551,21 +2603,65 @@ def deserialize_guild_sticker(self, payload: data_binding.JSONObject) -> sticker user=self.deserialize_user(payload["user"]) if "user" in payload else None, ) - ################## - # MESSAGE MODELS # - ################## + #################### + # COMPONENT MODELS # + #################### - def _deserialize_action_row(self, payload: data_binding.JSONObject) -> message_models.ActionRowComponent: - components = data_binding.cast_variants_array(self._deserialize_component, payload["components"]) - return message_models.ActionRowComponent( - type=message_models.ComponentType(payload["type"]), components=components - ) + @typing.overload + def _deserialize_components( + self, + payloads: data_binding.JSONArray, + mapping: typing.Dict[int, typing.Callable[[data_binding.JSONObject], component_models.MessageComponentTypesT]], + ) -> typing.List[component_models.MessageActionRowComponent]: + ... + + @typing.overload + def _deserialize_components( + self, + payloads: data_binding.JSONArray, + mapping: typing.Dict[int, typing.Callable[[data_binding.JSONObject], component_models.ModalComponentTypesT]], + ) -> typing.List[component_models.ModalActionRowComponent]: + ... + + def _deserialize_components( + self, + payloads: data_binding.JSONArray, + mapping: typing.Dict[int, typing.Callable[[data_binding.JSONObject], typing.Any]], + ) -> typing.List[component_models.ActionRowComponent[typing.Any]]: + top_level_components = [] + + for payload in payloads: + top_level_component_type = component_models.ComponentType(payload["type"]) + + if top_level_component_type != component_models.ComponentType.ACTION_ROW: + _LOGGER.debug("Unknown top-level message component type %s", top_level_component_type) + continue - def _deserialize_button(self, payload: data_binding.JSONObject) -> message_models.ButtonComponent: + components = [] + + for component_payload in payload["components"]: + component_type = component_models.ComponentType(component_payload["type"]) + + if (deserializer := mapping.get(component_type)) is None: + _LOGGER.debug("Unknown component type %s", component_type) + continue + + components.append(deserializer(component_payload)) + + if components: + # If we somehow get a top-level component full of unknown components, ignore the top-level + # component all-together + top_level_components.append( + component_models.ActionRowComponent(type=top_level_component_type, components=components) + ) + + return top_level_components + + def _deserialize_button(self, payload: data_binding.JSONObject) -> component_models.ButtonComponent: emoji_payload = payload.get("emoji") - return message_models.ButtonComponent( - type=message_models.ComponentType(payload["type"]), - style=message_models.ButtonStyle(payload["style"]), + return component_models.ButtonComponent( + type=component_models.ComponentType(payload["type"]), + style=component_models.ButtonStyle(payload["style"]), label=payload.get("label"), emoji=self.deserialize_emoji(emoji_payload) if emoji_payload else None, custom_id=payload.get("custom_id"), @@ -2573,15 +2669,15 @@ def _deserialize_button(self, payload: data_binding.JSONObject) -> message_model is_disabled=payload.get("disabled", False), ) - def _deserialize_select_menu(self, payload: data_binding.JSONObject) -> message_models.SelectMenuComponent: - options: typing.List[message_models.SelectMenuOption] = [] + def _deserialize_select_menu(self, payload: data_binding.JSONObject) -> component_models.SelectMenuComponent: + options: typing.List[component_models.SelectMenuOption] = [] for option_payload in payload["options"]: emoji = None if emoji_payload := option_payload.get("emoji"): emoji = self.deserialize_emoji(emoji_payload) options.append( - message_models.SelectMenuOption( + component_models.SelectMenuOption( label=option_payload["label"], value=option_payload["value"], description=option_payload.get("description"), @@ -2590,8 +2686,8 @@ def _deserialize_select_menu(self, payload: data_binding.JSONObject) -> message_ ) ) - return message_models.SelectMenuComponent( - type=message_models.ComponentType(payload["type"]), + return component_models.SelectMenuComponent( + type=component_models.ComponentType(payload["type"]), custom_id=payload["custom_id"], options=options, placeholder=payload.get("placeholder"), @@ -2600,14 +2696,16 @@ def _deserialize_select_menu(self, payload: data_binding.JSONObject) -> message_ is_disabled=payload.get("disabled", False), ) - def _deserialize_component(self, payload: data_binding.JSONObject) -> message_models.PartialComponent: - component_type = message_models.ComponentType(payload["type"]) - - if deserialize := self._component_type_mapping.get(component_type): - return deserialize(payload) + def _deserialize_text_input(self, payload: data_binding.JSONObject) -> component_models.TextInputComponent: + return component_models.TextInputComponent( + type=component_models.ComponentType(payload["type"]), + custom_id=payload["custom_id"], + value=payload["value"], + ) - _LOGGER.debug("Unknown component type %s", component_type) - raise errors.UnrecognisedEntityError(f"Unrecognised component type {component_type}") + ################## + # MESSAGE MODELS # + ################## def _deserialize_message_activity(self, payload: data_binding.JSONObject) -> message_models.MessageActivity: return message_models.MessageActivity( @@ -2745,9 +2843,9 @@ def deserialize_partial_message( # noqa: C901 - Too complex if interaction_payload := payload.get("interaction"): interaction = self._deserialize_message_interaction(interaction_payload) - components: undefined.UndefinedOr[typing.List[message_models.PartialComponent]] = undefined.UNDEFINED + components: undefined.UndefinedOr[typing.List[component_models.MessageActionRowComponent]] = undefined.UNDEFINED if component_payloads := payload.get("components"): - components = data_binding.cast_variants_array(self._deserialize_component, component_payloads) + components = self._deserialize_components(component_payloads, self._message_component_type_mapping) channel_mentions: undefined.UndefinedOr[ typing.Dict[snowflakes.Snowflake, channel_models.PartialChannel] @@ -2847,8 +2945,9 @@ def deserialize_message(self, payload: data_binding.JSONObject) -> message_model if interaction_payload := payload.get("interaction"): interaction = self._deserialize_message_interaction(interaction_payload) + components: typing.List[component_models.MessageActionRowComponent] if component_payloads := payload.get("components"): - components = data_binding.cast_variants_array(self._deserialize_component, component_payloads) + components = self._deserialize_components(component_payloads, self._message_component_type_mapping) else: components = [] diff --git a/hikari/impl/interaction_server.py b/hikari/impl/interaction_server.py index 60d1e6a0d5..f0e4708c98 100644 --- a/hikari/impl/interaction_server.py +++ b/hikari/impl/interaction_server.py @@ -56,12 +56,17 @@ from hikari.api import rest as rest_api from hikari.interactions import command_interactions from hikari.interactions import component_interactions + from hikari.interactions import modal_interactions _InteractionT_co = typing.TypeVar("_InteractionT_co", bound=base_interactions.PartialInteraction, covariant=True) _MessageResponseBuilderT = typing.Union[ special_endpoints.InteractionDeferredBuilder, special_endpoints.InteractionMessageBuilder, ] + _ModalOrMessageResponseBuilderT = typing.Union[ + _MessageResponseBuilderT, + special_endpoints.InteractionModalBuilder, + ] _LOGGER: typing.Final[logging.Logger] = logging.getLogger("hikari.interaction_server") @@ -567,7 +572,7 @@ async def start( def get_listener( self, interaction_type: typing.Type[command_interactions.CommandInteraction], / ) -> typing.Optional[ - interaction_server.ListenerT[command_interactions.CommandInteraction, _MessageResponseBuilderT] + interaction_server.ListenerT[command_interactions.CommandInteraction, _ModalOrMessageResponseBuilderT] ]: ... @@ -575,7 +580,7 @@ def get_listener( def get_listener( self, interaction_type: typing.Type[component_interactions.ComponentInteraction], / ) -> typing.Optional[ - interaction_server.ListenerT[component_interactions.ComponentInteraction, _MessageResponseBuilderT] + interaction_server.ListenerT[component_interactions.ComponentInteraction, _ModalOrMessageResponseBuilderT] ]: ... @@ -589,6 +594,12 @@ def get_listener( ]: ... + @typing.overload + def get_listener( + self, interaction_type: typing.Type[modal_interactions.ModalInteraction], / + ) -> typing.Optional[interaction_server.ListenerT[modal_interactions.ModalInteraction, _MessageResponseBuilderT]]: + ... + @typing.overload def get_listener( self, interaction_type: typing.Type[_InteractionT_co], / @@ -605,7 +616,7 @@ def set_listener( self, interaction_type: typing.Type[command_interactions.CommandInteraction], listener: typing.Optional[ - interaction_server.ListenerT[command_interactions.CommandInteraction, _MessageResponseBuilderT] + interaction_server.ListenerT[command_interactions.CommandInteraction, _ModalOrMessageResponseBuilderT] ], /, *, @@ -618,7 +629,7 @@ def set_listener( self, interaction_type: typing.Type[component_interactions.ComponentInteraction], listener: typing.Optional[ - interaction_server.ListenerT[component_interactions.ComponentInteraction, _MessageResponseBuilderT] + interaction_server.ListenerT[component_interactions.ComponentInteraction, _ModalOrMessageResponseBuilderT] ], /, *, @@ -641,6 +652,19 @@ def set_listener( ) -> None: ... + @typing.overload + def set_listener( + self, + interaction_type: typing.Type[modal_interactions.ModalInteraction], + listener: typing.Optional[ + interaction_server.ListenerT[modal_interactions.ModalInteraction, _MessageResponseBuilderT] + ], + /, + *, + replace: bool = False, + ) -> None: + ... + def set_listener( self, interaction_type: typing.Type[_InteractionT_co], diff --git a/hikari/impl/rest.py b/hikari/impl/rest.py index 059e1b11a6..074b5c8233 100644 --- a/hikari/impl/rest.py +++ b/hikari/impl/rest.py @@ -75,6 +75,7 @@ from hikari.impl import special_endpoints as special_endpoints_impl from hikari.interactions import base_interactions from hikari.internal import data_binding +from hikari.internal import deprecation from hikari.internal import mentions from hikari.internal import net from hikari.internal import routes @@ -3645,6 +3646,9 @@ def interaction_message_builder( ) -> special_endpoints.InteractionMessageBuilder: return special_endpoints_impl.InteractionMessageBuilder(type=type_) + def interaction_modal_builder(self, title: str, custom_id: str) -> special_endpoints.InteractionModalBuilder: + return special_endpoints_impl.InteractionModalBuilder(title=title, custom_id=custom_id) + async def fetch_interaction_response( self, application: snowflakes.SnowflakeishOr[guilds.PartialApplication], token: str ) -> messages_.Message: @@ -3778,8 +3782,57 @@ async def create_autocomplete_response( body.put("data", data) await self._request(route, json=body, no_auth=True) - def build_action_row(self) -> special_endpoints.ActionRowBuilder: - return special_endpoints_impl.ActionRowBuilder() + async def create_modal_response( + self, + interaction: snowflakes.SnowflakeishOr[base_interactions.PartialInteraction], + token: str, + *, + title: str, + custom_id: str, + component: undefined.UndefinedOr[special_endpoints.ComponentBuilder] = undefined.UNDEFINED, + components: undefined.UndefinedOr[typing.Sequence[special_endpoints.ComponentBuilder]] = undefined.UNDEFINED, + ) -> None: + if undefined.all_undefined(component, components) or not undefined.any_undefined(component, components): + raise ValueError("Must specify exactly only one of 'component' or 'components'") + + route = routes.POST_INTERACTION_RESPONSE.compile(interaction=interaction, token=token) + + body = data_binding.JSONObjectBuilder() + body.put("type", base_interactions.ResponseType.MODAL) + + data = data_binding.JSONObjectBuilder() + data.put("title", title) + data.put("custom_id", custom_id) + + if component: + components = (component,) + + data.put_array("components", components, conversion=lambda c: c.build()) + + body.put("data", data) + + await self._request(route, json=body, no_auth=True) + + def build_action_row(self) -> special_endpoints.MessageActionRowBuilder: + """Build a message action row message component for use in message create and REST calls. + + Returns + ------- + hikari.api.special_endpoints.MessageActionRowBuilder + The initialised action row builder. + """ + deprecation.warn_deprecated( + "build_action_row", + removal_version="2.0.0.dev115", + additional_info="Use 'build_message_action_row' parameter instead", + ) + return special_endpoints_impl.MessageActionRowBuilder() + + def build_message_action_row(self) -> special_endpoints.MessageActionRowBuilder: + return special_endpoints_impl.MessageActionRowBuilder() + + def build_modal_action_row(self) -> special_endpoints.ModalActionRowBuilder: + return special_endpoints_impl.ModalActionRowBuilder() async def fetch_scheduled_event( self, diff --git a/hikari/impl/rest_bot.py b/hikari/impl/rest_bot.py index abf63330a8..79c8e6566e 100644 --- a/hikari/impl/rest_bot.py +++ b/hikari/impl/rest_bot.py @@ -53,12 +53,17 @@ from hikari.interactions import base_interactions from hikari.interactions import command_interactions from hikari.interactions import component_interactions + from hikari.interactions import modal_interactions _InteractionT_co = typing.TypeVar("_InteractionT_co", bound=base_interactions.PartialInteraction, covariant=True) _MessageResponseBuilderT = typing.Union[ special_endpoints.InteractionDeferredBuilder, special_endpoints.InteractionMessageBuilder, ] + _ModalOrMessageResponseBuilderT = typing.Union[ + _MessageResponseBuilderT, + special_endpoints.InteractionModalBuilder, + ] _LOGGER: typing.Final[logging.Logger] = logging.getLogger("hikari.rest_bot") @@ -648,7 +653,7 @@ async def start( def get_listener( self, interaction_type: typing.Type[command_interactions.CommandInteraction], / ) -> typing.Optional[ - interaction_server_.ListenerT[command_interactions.CommandInteraction, _MessageResponseBuilderT] + interaction_server_.ListenerT[command_interactions.CommandInteraction, _ModalOrMessageResponseBuilderT] ]: ... @@ -656,7 +661,7 @@ def get_listener( def get_listener( self, interaction_type: typing.Type[component_interactions.ComponentInteraction], / ) -> typing.Optional[ - interaction_server_.ListenerT[component_interactions.ComponentInteraction, _MessageResponseBuilderT] + interaction_server_.ListenerT[component_interactions.ComponentInteraction, _ModalOrMessageResponseBuilderT] ]: ... @@ -670,6 +675,12 @@ def get_listener( ]: ... + @typing.overload + def get_listener( + self, interaction_type: typing.Type[modal_interactions.ModalInteraction], / + ) -> typing.Optional[interaction_server_.ListenerT[modal_interactions.ModalInteraction, _MessageResponseBuilderT]]: + ... + @typing.overload def get_listener( self, interaction_type: typing.Type[_InteractionT_co], / @@ -686,7 +697,7 @@ def set_listener( self, interaction_type: typing.Type[command_interactions.CommandInteraction], listener: typing.Optional[ - interaction_server_.ListenerT[command_interactions.CommandInteraction, _MessageResponseBuilderT] + interaction_server_.ListenerT[command_interactions.CommandInteraction, _ModalOrMessageResponseBuilderT] ], /, *, @@ -699,7 +710,7 @@ def set_listener( self, interaction_type: typing.Type[component_interactions.ComponentInteraction], listener: typing.Optional[ - interaction_server_.ListenerT[component_interactions.ComponentInteraction, _MessageResponseBuilderT] + interaction_server_.ListenerT[component_interactions.ComponentInteraction, _ModalOrMessageResponseBuilderT] ], /, *, @@ -722,6 +733,19 @@ def set_listener( ) -> None: ... + @typing.overload + def set_listener( + self, + interaction_type: typing.Type[modal_interactions.ModalInteraction], + listener: typing.Optional[ + interaction_server_.ListenerT[modal_interactions.ModalInteraction, _MessageResponseBuilderT] + ], + /, + *, + replace: bool = False, + ) -> None: + ... + def set_listener( self, interaction_type: typing.Type[_InteractionT_co], diff --git a/hikari/impl/special_endpoints.py b/hikari/impl/special_endpoints.py index 0bc9ca8b7d..f499021f5c 100644 --- a/hikari/impl/special_endpoints.py +++ b/hikari/impl/special_endpoints.py @@ -27,7 +27,6 @@ from __future__ import annotations __all__: typing.Sequence[str] = ( - "ActionRowBuilder", "CommandBuilder", "SlashCommandBuilder", "ContextMenuCommandBuilder", @@ -39,6 +38,10 @@ "InteractiveButtonBuilder", "LinkButtonBuilder", "SelectMenuBuilder", + "TextInputBuilder", + "InteractionModalBuilder", + "MessageActionRowBuilder", + "ModalActionRowBuilder", ) import asyncio @@ -48,6 +51,7 @@ from hikari import channels from hikari import commands +from hikari import components as component_models from hikari import emojis from hikari import errors from hikari import files @@ -88,10 +92,13 @@ _InteractionAutocompleteBuilderT = typing.TypeVar( "_InteractionAutocompleteBuilderT", bound="InteractionAutocompleteBuilder" ) - _ActionRowBuilderT = typing.TypeVar("_ActionRowBuilderT", bound="ActionRowBuilder") + _InteractionModalBuilderT = typing.TypeVar("_InteractionModalBuilderT", bound="InteractionModalBuilder") + _MessageActionRowBuilderT = typing.TypeVar("_MessageActionRowBuilderT", bound="MessageActionRowBuilder") + _ModalActionRowBuilderT = typing.TypeVar("_ModalActionRowBuilderT", bound="ModalActionRowBuilder") _ButtonBuilderT = typing.TypeVar("_ButtonBuilderT", bound="_ButtonBuilder[typing.Any]") _SelectOptionBuilderT = typing.TypeVar("_SelectOptionBuilderT", bound="_SelectOptionBuilder[typing.Any]") _SelectMenuBuilderT = typing.TypeVar("_SelectMenuBuilderT", bound="SelectMenuBuilder[typing.Any]") + _TextInputBuilderT = typing.TypeVar("_TextInputBuilderT", bound="TextInputBuilder[typing.Any]") class _RequestCallSig(typing.Protocol): async def __call__( @@ -1185,6 +1192,55 @@ def build( return {"type": self._type, "data": data}, final_attachments +@attr.define(kw_only=False, weakref_slot=False) +class InteractionModalBuilder(special_endpoints.InteractionModalBuilder): + """Standard implementation of `hikari.api.special_endpoints.InteractionModalBuilder`.""" + + _title: str = attr.field() + _custom_id: str = attr.field() + _components: typing.List[special_endpoints.ComponentBuilder] = attr.field(factory=list) + + @property + def type(self) -> typing.Literal[base_interactions.ResponseType.MODAL]: + return base_interactions.ResponseType.MODAL + + @property + def title(self) -> str: + return self._title + + @property + def custom_id(self) -> str: + return self._custom_id + + @property + def components(self) -> typing.Sequence[special_endpoints.ComponentBuilder]: + return self._components + + def set_title(self: _InteractionModalBuilderT, title: str, /) -> _InteractionModalBuilderT: + self._title = title + return self + + def set_custom_id(self: _InteractionModalBuilderT, custom_id: str, /) -> _InteractionModalBuilderT: + self._custom_id = custom_id + return self + + def add_component( + self: _InteractionModalBuilderT, component: special_endpoints.ComponentBuilder, / + ) -> _InteractionModalBuilderT: + self._components.append(component) + return self + + def build( + self, entity_factory: entity_factory_.EntityFactory, / + ) -> typing.Tuple[typing.MutableMapping[str, typing.Any], typing.Sequence[files.Resource[files.AsyncReader]]]: + data = data_binding.JSONObjectBuilder() + data.put("title", self._title) + data.put("custom_id", self._custom_id) + data.put_array("components", self._components, conversion=lambda component: component.build()) + + return {"type": self.type, "data": data}, () + + @attr.define(kw_only=False, weakref_slot=False) class CommandBuilder(special_endpoints.CommandBuilder): """Standard implementation of `hikari.api.special_endpoints.CommandBuilder`.""" @@ -1406,7 +1462,7 @@ def _build_emoji( @attr.define(kw_only=True, weakref_slot=False) class _ButtonBuilder(special_endpoints.ButtonBuilder[_ContainerProtoT]): _container: _ContainerProtoT = attr.field() - _style: typing.Union[int, messages.ButtonStyle] = attr.field() + _style: typing.Union[int, component_models.ButtonStyle] = attr.field() _custom_id: undefined.UndefinedOr[str] = attr.field(default=undefined.UNDEFINED) _url: undefined.UndefinedOr[str] = attr.field(default=undefined.UNDEFINED) _emoji: typing.Union[snowflakes.Snowflakeish, emojis.Emoji, str, undefined.UndefinedType] = attr.field( @@ -1418,7 +1474,7 @@ class _ButtonBuilder(special_endpoints.ButtonBuilder[_ContainerProtoT]): _is_disabled: bool = attr.field(default=False) @property - def style(self) -> typing.Union[int, messages.ButtonStyle]: + def style(self) -> typing.Union[int, component_models.ButtonStyle]: return self._style @property @@ -1457,7 +1513,7 @@ def add_to_container(self) -> _ContainerProtoT: def build(self) -> typing.MutableMapping[str, typing.Any]: data = data_binding.JSONObjectBuilder() - data["type"] = messages.ComponentType.BUTTON + data["type"] = component_models.ComponentType.BUTTON data["style"] = self._style data["disabled"] = self._is_disabled data.put("label", self._label) @@ -1646,7 +1702,7 @@ def add_to_container(self) -> _ContainerProtoT: def build(self) -> typing.MutableMapping[str, typing.Any]: data = data_binding.JSONObjectBuilder() - data["type"] = messages.ComponentType.SELECT_MENU + data["type"] = component_models.ComponentType.SELECT_MENU data["custom_id"] = self._custom_id data["options"] = [option.build() for option in self._options] data.put("placeholder", self._placeholder) @@ -1656,18 +1712,120 @@ def build(self) -> typing.MutableMapping[str, typing.Any]: return data +@attr_extensions.with_copy +@attr.define(kw_only=True, weakref_slot=False) +class TextInputBuilder(special_endpoints.TextInputBuilder[_ContainerProtoT]): + """Standard implementation of `hikari.api.special_endpoints.TextInputBuilder`.""" + + _container: _ContainerProtoT = attr.field() + _custom_id: str = attr.field() + _label: str = attr.field() + + _style: component_models.TextInputStyle = attr.field(default=component_models.TextInputStyle.SHORT) + _placeholder: undefined.UndefinedOr[str] = attr.field(default=undefined.UNDEFINED, kw_only=True) + _value: undefined.UndefinedOr[str] = attr.field(default=undefined.UNDEFINED, kw_only=True) + _required: undefined.UndefinedOr[bool] = attr.field(default=undefined.UNDEFINED, kw_only=True) + _min_length: undefined.UndefinedOr[int] = attr.field(default=undefined.UNDEFINED, kw_only=True) + _max_length: undefined.UndefinedOr[int] = attr.field(default=undefined.UNDEFINED, kw_only=True) + + @property + def custom_id(self) -> str: + return self._custom_id + + @property + def label(self) -> str: + return self._label + + @property + def style(self) -> component_models.TextInputStyle: + return self._style + + @property + def placeholder(self) -> undefined.UndefinedOr[str]: + return self._placeholder + + @property + def value(self) -> undefined.UndefinedOr[str]: + return self._value + + @property + def required(self) -> undefined.UndefinedOr[bool]: + return self._required + + @property + def min_length(self) -> undefined.UndefinedOr[int]: + return self._min_length + + @property + def max_length(self) -> undefined.UndefinedOr[int]: + return self._max_length + + def set_style( + self: _TextInputBuilderT, style: typing.Union[component_models.TextInputStyle, int], / + ) -> _TextInputBuilderT: + self._style = component_models.TextInputStyle(style) + return self + + def set_custom_id(self: _TextInputBuilderT, custom_id: str, /) -> _TextInputBuilderT: + self._custom_id = custom_id + return self + + def set_label(self: _TextInputBuilderT, label: str, /) -> _TextInputBuilderT: + self._label = label + return self + + def set_placeholder(self: _TextInputBuilderT, placeholder: str, /) -> _TextInputBuilderT: + self._placeholder = placeholder + return self + + def set_value(self: _TextInputBuilderT, value: str, /) -> _TextInputBuilderT: + self._value = value + return self + + def set_required(self: _TextInputBuilderT, required: bool, /) -> _TextInputBuilderT: + self._required = required + return self + + def set_min_length(self: _TextInputBuilderT, min_length: int, /) -> _TextInputBuilderT: + self._min_length = min_length + return self + + def set_max_length(self: _TextInputBuilderT, max_length: int, /) -> _TextInputBuilderT: + self._max_length = max_length + return self + + def add_to_container(self) -> _ContainerProtoT: + self._container.add_component(self) + return self._container + + def build(self) -> typing.MutableMapping[str, typing.Any]: + data = data_binding.JSONObjectBuilder() + + data["type"] = component_models.ComponentType.TEXT_INPUT + data["style"] = self._style + data["custom_id"] = self._custom_id + data["label"] = self._label + data.put("placeholder", self._placeholder) + data.put("value", self._value) + data.put("required", self._required) + data.put("min_length", self._min_length) + data.put("max_length", self._max_length) + + return data + + @attr.define(kw_only=True, weakref_slot=False) -class ActionRowBuilder(special_endpoints.ActionRowBuilder): +class MessageActionRowBuilder(special_endpoints.MessageActionRowBuilder): """Standard implementation of `hikari.api.special_endpoints.ActionRowBuilder`.""" _components: typing.List[special_endpoints.ComponentBuilder] = attr.field(factory=list) - _stored_type: typing.Optional[messages.ComponentType] = attr.field(default=None) + _stored_type: typing.Optional[component_models.ComponentType] = attr.field(default=None) @property def components(self) -> typing.Sequence[special_endpoints.ComponentBuilder]: return self._components.copy() - def _assert_can_add_type(self, type_: messages.ComponentType, /) -> None: + def _assert_can_add_type(self, type_: component_models.ComponentType, /) -> None: if self._stored_type is not None and self._stored_type != type_: raise ValueError( f"{type_} component type cannot be added to a container which already holds {self._stored_type}" @@ -1675,54 +1833,102 @@ def _assert_can_add_type(self, type_: messages.ComponentType, /) -> None: self._stored_type = type_ - def add_component(self: _ActionRowBuilderT, component: special_endpoints.ComponentBuilder, /) -> _ActionRowBuilderT: + def add_component( + self: _MessageActionRowBuilderT, component: special_endpoints.ComponentBuilder, / + ) -> _MessageActionRowBuilderT: self._components.append(component) return self @typing.overload def add_button( - self: _ActionRowBuilderT, style: messages.InteractiveButtonTypesT, custom_id: str, / - ) -> special_endpoints.InteractiveButtonBuilder[_ActionRowBuilderT]: + self: _MessageActionRowBuilderT, style: component_models.InteractiveButtonTypesT, custom_id: str, / + ) -> special_endpoints.InteractiveButtonBuilder[_MessageActionRowBuilderT]: ... @typing.overload def add_button( - self: _ActionRowBuilderT, - style: typing.Literal[messages.ButtonStyle.LINK, 5], + self: _MessageActionRowBuilderT, + style: typing.Literal[component_models.ButtonStyle.LINK, 5], url: str, /, - ) -> special_endpoints.LinkButtonBuilder[_ActionRowBuilderT]: + ) -> special_endpoints.LinkButtonBuilder[_MessageActionRowBuilderT]: ... @typing.overload def add_button( - self: _ActionRowBuilderT, style: typing.Union[int, messages.ButtonStyle], url_or_custom_id: str, / + self: _MessageActionRowBuilderT, + style: typing.Union[int, component_models.ButtonStyle], + url_or_custom_id: str, + /, ) -> typing.Union[ - special_endpoints.LinkButtonBuilder[_ActionRowBuilderT], - special_endpoints.InteractiveButtonBuilder[_ActionRowBuilderT], + special_endpoints.LinkButtonBuilder[_MessageActionRowBuilderT], + special_endpoints.InteractiveButtonBuilder[_MessageActionRowBuilderT], ]: ... def add_button( - self: _ActionRowBuilderT, style: typing.Union[int, messages.ButtonStyle], url_or_custom_id: str, / + self: _MessageActionRowBuilderT, + style: typing.Union[int, component_models.ButtonStyle], + url_or_custom_id: str, + /, ) -> typing.Union[ - special_endpoints.LinkButtonBuilder[_ActionRowBuilderT], - special_endpoints.InteractiveButtonBuilder[_ActionRowBuilderT], + special_endpoints.LinkButtonBuilder[_MessageActionRowBuilderT], + special_endpoints.InteractiveButtonBuilder[_MessageActionRowBuilderT], ]: - self._assert_can_add_type(messages.ComponentType.BUTTON) - if style in messages.InteractiveButtonTypes: + self._assert_can_add_type(component_models.ComponentType.BUTTON) + if style in component_models.InteractiveButtonTypes: return InteractiveButtonBuilder(container=self, style=style, custom_id=url_or_custom_id) return LinkButtonBuilder(container=self, style=style, url=url_or_custom_id) def add_select_menu( - self: _ActionRowBuilderT, custom_id: str, / - ) -> special_endpoints.SelectMenuBuilder[_ActionRowBuilderT]: - self._assert_can_add_type(messages.ComponentType.SELECT_MENU) + self: _MessageActionRowBuilderT, custom_id: str, / + ) -> special_endpoints.SelectMenuBuilder[_MessageActionRowBuilderT]: + self._assert_can_add_type(component_models.ComponentType.SELECT_MENU) return SelectMenuBuilder(container=self, custom_id=custom_id) def build(self) -> typing.MutableMapping[str, typing.Any]: return { - "type": messages.ComponentType.ACTION_ROW, + "type": component_models.ComponentType.ACTION_ROW, + "components": [component.build() for component in self._components], + } + + +@attr.define(kw_only=True, weakref_slot=False) +class ModalActionRowBuilder(special_endpoints.ModalActionRowBuilder): + """Standard implementation of `hikari.api.special_endpoints.ActionRowBuilder`.""" + + _components: typing.List[special_endpoints.ComponentBuilder] = attr.field(factory=list) + _stored_type: typing.Optional[component_models.ComponentType] = attr.field(default=None) + + @property + def components(self) -> typing.Sequence[special_endpoints.ComponentBuilder]: + return self._components.copy() + + def _assert_can_add_type(self, type_: component_models.ComponentType, /) -> None: + if self._stored_type is not None and self._stored_type != type_: + raise ValueError( + f"{type_} component type cannot be added to a container which already holds {self._stored_type}" + ) + + self._stored_type = type_ + + def add_component( + self: _ModalActionRowBuilderT, component: special_endpoints.ComponentBuilder, / + ) -> _ModalActionRowBuilderT: + self._components.append(component) + return self + + def add_text_input( + self: _ModalActionRowBuilderT, + custom_id: str, + label: str, + ) -> special_endpoints.TextInputBuilder[_ModalActionRowBuilderT]: + self._assert_can_add_type(component_models.ComponentType.TEXT_INPUT) + return TextInputBuilder(container=self, custom_id=custom_id, label=label) + + def build(self) -> typing.MutableMapping[str, typing.Any]: + return { + "type": component_models.ComponentType.ACTION_ROW, "components": [component.build() for component in self._components], } diff --git a/hikari/interactions/__init__.py b/hikari/interactions/__init__.py index 2896af1bc4..214d53a491 100644 --- a/hikari/interactions/__init__.py +++ b/hikari/interactions/__init__.py @@ -26,3 +26,4 @@ from hikari.interactions.base_interactions import * from hikari.interactions.command_interactions import * from hikari.interactions.component_interactions import * +from hikari.interactions.modal_interactions import * diff --git a/hikari/interactions/__init__.pyi b/hikari/interactions/__init__.pyi index 51d1630723..0bba436d86 100644 --- a/hikari/interactions/__init__.pyi +++ b/hikari/interactions/__init__.pyi @@ -4,3 +4,4 @@ from hikari.interactions.base_interactions import * from hikari.interactions.command_interactions import * from hikari.interactions.component_interactions import * +from hikari.interactions.modal_interactions import * diff --git a/hikari/interactions/base_interactions.py b/hikari/interactions/base_interactions.py index a7aadfe30b..5fa2c37e6e 100644 --- a/hikari/interactions/base_interactions.py +++ b/hikari/interactions/base_interactions.py @@ -32,6 +32,7 @@ "MESSAGE_RESPONSE_TYPES", "MessageResponseTypesT", "PartialInteraction", + "ModalResponseMixin", "ResponseType", ) @@ -74,6 +75,9 @@ class InteractionType(int, enums.Enum): AUTOCOMPLETE = 4 """An interaction triggered by a user typing in a slash command option.""" + MODAL_SUBMIT = 5 + """An interaction triggered by a user submitting a modal.""" + @typing.final class ResponseType(int, enums.Enum): @@ -126,6 +130,14 @@ class ResponseType(int, enums.Enum): * `InteractionType.AUTOCOMPLETE` """ + MODAL = 9 + """An immediate interaction response with instructions to display a modal. + + This is valid for the following interaction types: + + * `InteractionType.MODAL_SUBMIT` + """ + MESSAGE_RESPONSE_TYPES: typing.Final[typing.AbstractSet[MessageResponseTypesT]] = frozenset( [ResponseType.MESSAGE_CREATE, ResponseType.MESSAGE_UPDATE] @@ -563,6 +575,66 @@ async def delete_initial_response(self) -> None: await self.app.rest.delete_interaction_response(self.application_id, self.token) +class ModalResponseMixin(PartialInteraction): + """Mixin' class for all interaction types which can be responded to with a modal.""" + + __slots__: typing.Sequence[str] = () + + async def create_modal_response( + self, + title: str, + custom_id: str, + component: undefined.UndefinedOr[special_endpoints.ComponentBuilder] = undefined.UNDEFINED, + components: undefined.UndefinedOr[typing.Sequence[special_endpoints.ComponentBuilder]] = undefined.UNDEFINED, + ) -> None: + """Create a response by sending a modal. + + Parameters + ---------- + title : str + The title that will show up in the modal. + custom_id : str + Developer set custom ID used for identifying interactions with this modal. + + Other Parameters + ---------------- + component : hikari.undefined.UndefinedOr[typing.Sequence[special_endpoints.ComponentBuilder]] + A component builders to send in this modal. + components : hikari.undefined.UndefinedOr[typing.Sequence[special_endpoints.ComponentBuilder]] + A sequence of component builders to send in this modal. + + Raises + ------ + ValueError + If both `component` and `components` are specified or if none are specified. + """ + await self.app.rest.create_modal_response( + self.id, + self.token, + title=title, + custom_id=custom_id, + component=component, + components=components, + ) + + def build_modal_response(self, title: str, custom_id: str) -> special_endpoints.InteractionModalBuilder: + """Create a builder for a modal interaction response. + + Parameters + ---------- + title : builtins.str + The title that will show up in the modal. + custom_id : builtins.str + Developer set custom ID used for identifying interactions with this modal. + + Returns + ------- + hikari.api.special_endpoints.InteractionModalBuilder + The interaction modal response builder object. + """ + return self.app.rest.interaction_modal_builder(title=title, custom_id=custom_id) + + @attr.define(hash=True, kw_only=True, weakref_slot=False) class InteractionMember(guilds.Member): """Model of the member who triggered an interaction. diff --git a/hikari/interactions/command_interactions.py b/hikari/interactions/command_interactions.py index 80f7db8352..dd713e6785 100644 --- a/hikari/interactions/command_interactions.py +++ b/hikari/interactions/command_interactions.py @@ -361,7 +361,11 @@ def get_guild(self) -> typing.Optional[guilds.GatewayGuild]: @attr_extensions.with_copy @attr.define(hash=True, kw_only=True, weakref_slot=False) -class CommandInteraction(BaseCommandInteraction, base_interactions.MessageResponseMixin[CommandResponseTypesT]): +class CommandInteraction( + BaseCommandInteraction, + base_interactions.MessageResponseMixin[CommandResponseTypesT], + base_interactions.ModalResponseMixin, +): """Represents a command interaction on Discord.""" app_permissions: typing.Optional[permissions_.Permissions] = attr.field(eq=False, hash=False, repr=False) diff --git a/hikari/interactions/component_interactions.py b/hikari/interactions/component_interactions.py index 656668dd2d..d85d18fd14 100644 --- a/hikari/interactions/component_interactions.py +++ b/hikari/interactions/component_interactions.py @@ -34,6 +34,7 @@ from hikari.interactions import base_interactions if typing.TYPE_CHECKING: + from hikari import components as components_ from hikari import guilds from hikari import locales from hikari import messages @@ -83,13 +84,16 @@ @attr.define(hash=True, weakref_slot=False) -class ComponentInteraction(base_interactions.MessageResponseMixin[ComponentResponseTypesT]): +class ComponentInteraction( + base_interactions.MessageResponseMixin[ComponentResponseTypesT], + base_interactions.ModalResponseMixin, +): """Represents a component interaction on Discord.""" channel_id: snowflakes.Snowflake = attr.field(eq=False) """ID of the channel this interaction was triggered in.""" - component_type: typing.Union[messages.ComponentType, int] = attr.field(eq=False) + component_type: typing.Union[components_.ComponentType, int] = attr.field(eq=False) """The type of component which triggers this interaction. !!! note @@ -319,48 +323,3 @@ def get_guild(self) -> typing.Optional[guilds.GatewayGuild]: return self.app.cache.get_guild(self.guild_id) return None - - async def fetch_parent_message(self) -> messages.Message: - """Fetch the message which this interaction was triggered on. - - Returns - ------- - hikari.messages.Message - The requested message. - - Raises - ------ - builtins.ValueError - If `token` is not available. - hikari.errors.UnauthorizedError - If you are unauthorized to make the request (invalid/missing token). - hikari.errors.NotFoundError - If the webhook is not found or the webhook's message wasn't found. - hikari.errors.RateLimitTooLongError - Raised in the event that a rate limit occurs that is - longer than `max_rate_limit` when making a request. - hikari.errors.RateLimitedError - Usually, Hikari will handle and retry on hitting - rate-limits automatically. This includes most bucket-specific - rate-limits and global rate-limits. In some rare edge cases, - however, Discord implements other undocumented rules for - rate-limiting, such as limits per attribute. These cannot be - detected or handled normally by Hikari due to their undocumented - nature, and will trigger this exception if they occur. - hikari.errors.InternalServerError - If an internal error occurs on Discord while handling the request. - """ - return await self.fetch_message(self.message.id) - - def get_parent_message(self) -> typing.Optional[messages.PartialMessage]: - """Get the message which this interaction was triggered on from the cache. - - Returns - ------- - typing.Optional[hikari.messages.Message] - The object of the message found in the cache or `builtins.None`. - """ - if isinstance(self.app, traits.CacheAware): - return self.app.cache.get_message(self.message.id) - - return None diff --git a/hikari/interactions/modal_interactions.py b/hikari/interactions/modal_interactions.py new file mode 100644 index 0000000000..54735b31de --- /dev/null +++ b/hikari/interactions/modal_interactions.py @@ -0,0 +1,279 @@ +# -*- coding: utf-8 -*- +# cython: language_level=3 +# Copyright (c) 2020 Nekokatt +# Copyright (c) 2021-present davfsa +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +"""Models and enums used for Discord's Modals interaction flow.""" + +from __future__ import annotations + +__all__: typing.List[str] = [ + "ModalResponseTypesT", + "ModalInteraction", + "ModalInteraction", +] + +import typing + +import attr + +from hikari import channels +from hikari import guilds +from hikari import messages +from hikari import permissions +from hikari import snowflakes +from hikari import traits +from hikari.interactions import base_interactions +from hikari.internal import attr_extensions + +if typing.TYPE_CHECKING: + from hikari import components as components_ + from hikari import users as _users + from hikari.api import special_endpoints + +ModalResponseTypesT = typing.Literal[ + base_interactions.ResponseType.MESSAGE_CREATE, + 4, + base_interactions.ResponseType.DEFERRED_MESSAGE_CREATE, + 5, + base_interactions.ResponseType.MESSAGE_UPDATE, + 7, + base_interactions.ResponseType.DEFERRED_MESSAGE_UPDATE, + 6, +] +"""Type-hint of the response types which are valid for a modal interaction. + +The following types are valid for this: + +* `hikari.interactions.base_interactions.ResponseType.MESSAGE_CREATE`/`4` +* `hikari.interactions.base_interactions.ResponseType.DEFERRED_MESSAGE_CREATE`/`5` +* `hikari.interactions.base_interactions.ResponseType.MESSAGE_UPDATE`/`7` +* `hikari.interactions.base_interactions.ResponseType.DEFERRED_MESSAGE_UPDATE`/`6` +""" + + +@attr_extensions.with_copy +@attr.define(hash=True, kw_only=True, weakref_slot=False) +class ModalInteraction(base_interactions.MessageResponseMixin[ModalResponseTypesT]): + """Represents a modal interaction on Discord.""" + + channel_id: snowflakes.Snowflake = attr.field(eq=False, hash=False, repr=True) + """ID of the channel this modal interaction event was triggered in.""" + + custom_id: str = attr.field(eq=False, hash=False, repr=True) + """The custom id of the modal.""" + + guild_id: typing.Optional[snowflakes.Snowflake] = attr.field(eq=False, hash=False, repr=True) + """ID of the guild this modal interaction event was triggered in. + + This will be `builtins.None` for modal interactions triggered in DMs. + """ + + guild_locale: typing.Optional[str] = attr.field(eq=False, hash=False, repr=True) + """The preferred language of the guild this modal interaction was triggered in. + + This will be `builtins.None` for modal interactions triggered in DMs. + + !!! note + This value can usually only be changed if `COMMUNITY` is in `hikari.guilds.Guild.features` + for the guild and will otherwise default to `en-US`. + """ + + message: typing.Optional[messages.Message] = attr.field(eq=False, repr=False) + """The message whose component triggered the modal. + + This will be None if the modal was a response to a command. + """ + + member: typing.Optional[base_interactions.InteractionMember] = attr.field(eq=False, hash=False, repr=True) + """The member who triggered this modal interaction. + + This will be `builtins.None` for modal interactions triggered in DMs. + + !!! note + This member object comes with the extra field `permissions` which + contains the member's permissions in the current channel. + """ + + user: _users.User = attr.field(eq=False, hash=False, repr=True) + """The user who triggered this modal interaction.""" + + locale: str = attr.field(eq=False, hash=False, repr=True) + """The selected language of the user who triggered this modal interaction.""" + + app_permissions: typing.Optional[permissions.Permissions] = attr.field(eq=False, hash=False, repr=False) + """Permissions the bot has in this interaction's channel if it's in a guild.""" + + components: typing.Sequence[components_.ModalActionRowComponent] = attr.field(eq=False, hash=False, repr=True) + """Components in the modal.""" + + async def fetch_channel(self) -> channels.TextableChannel: + """Fetch the guild channel this interaction was triggered in. + + Returns + ------- + hikari.channels.TextableChannel + The requested partial channel derived object of the channel this + interaction was triggered in. + + Raises + ------ + hikari.errors.UnauthorizedError + If you are unauthorized to make the request (invalid/missing token). + hikari.errors.ForbiddenError + If you are missing the `READ_MESSAGES` permission in the channel. + hikari.errors.NotFoundError + If the channel is not found. + hikari.errors.RateLimitTooLongError + Raised in the event that a rate limit occurs that is + longer than `max_rate_limit` when making a request. + hikari.errors.RateLimitTooLongError + Raised in the event that a rate limit occurs that is + longer than `max_rate_limit` when making a request. + hikari.errors.RateLimitedError + Usually, Hikari will handle and retry on hitting + rate-limits automatically. This includes most bucket-specific + rate-limits and global rate-limits. In some rare edge cases, + however, Discord implements other undocumented rules for + rate-limiting, such as limits per attribute. These cannot be + detected or handled normally by Hikari due to their undocumented + nature, and will trigger this exception if they occur. + hikari.errors.InternalServerError + If an internal error occurs on Discord while handling the request. + """ + channel = await self.app.rest.fetch_channel(self.channel_id) + assert isinstance(channel, channels.TextableChannel) + return channel + + def get_channel(self) -> typing.Optional[channels.TextableGuildChannel]: + """Get the guild channel this interaction was triggered in from the cache. + + !!! note + This will always return `builtins.None` for interactions triggered + in a DM channel. + + Returns + ------- + typing.Optional[hikari.channels.TextableGuildChannel] + The object of the guild channel that was found in the cache or + `builtins.None`. + """ + if isinstance(self.app, traits.CacheAware): + channel = self.app.cache.get_guild_channel(self.channel_id) + assert channel is None or isinstance(channel, channels.TextableGuildChannel) + return channel + + return None + + async def fetch_guild(self) -> typing.Optional[guilds.RESTGuild]: + """Fetch the guild this interaction happened in. + + Returns + ------- + typing.Optional[hikari.guilds.RESTGuild] + Object of the guild this interaction happened in or `builtins.None` + if this occurred within a DM channel. + + Raises + ------ + hikari.errors.ForbiddenError + If you are not part of the guild. + hikari.errors.NotFoundError + If the guild is not found. + hikari.errors.UnauthorizedError + If you are unauthorized to make the request (invalid/missing token). + hikari.errors.RateLimitTooLongError + Raised in the event that a rate limit occurs that is + longer than `max_rate_limit` when making a request. + hikari.errors.RateLimitedError + Usually, Hikari will handle and retry on hitting + rate-limits automatically. This includes most bucket-specific + rate-limits and global rate-limits. In some rare edge cases, + however, Discord implements other undocumented rules for + rate-limiting, such as limits per attribute. These cannot be + detected or handled normally by Hikari due to their undocumented + nature, and will trigger this exception if they occur. + hikari.errors.InternalServerError + If an internal error occurs on Discord while handling the request. + """ + if not self.guild_id: + return None + + return await self.app.rest.fetch_guild(self.guild_id) + + def get_guild(self) -> typing.Optional[guilds.GatewayGuild]: + """Get the object of the guild this interaction was triggered in from the cache. + + Returns + ------- + typing.Optional[hikari.guilds.GatewayGuild] + The object of the guild if found, else `builtins.None`. + """ + if self.guild_id and isinstance(self.app, traits.CacheAware): + return self.app.cache.get_guild(self.guild_id) + + return None + + def build_response(self) -> special_endpoints.InteractionMessageBuilder: + """Get a message response builder for use in the REST server flow. + + !!! note + For interactions received over the gateway + `ModalInteraction.create_initial_response` should be used to set + the interaction response message. + + Examples + -------- + ```py + async def handle_modal_interaction(interaction: ModalInteraction) -> InteractionMessageBuilder: + return ( + interaction + .build_response() + .add_embed(Embed(description="Hi there")) + .set_content("Konnichiwa") + ) + ``` + + Returns + ------- + hikari.api.special_endpoints.InteractionMessageBuilder + Interaction message response builder object. + """ + return self.app.rest.interaction_message_builder(base_interactions.ResponseType.MESSAGE_CREATE) + + def build_deferred_response(self) -> special_endpoints.InteractionDeferredBuilder: + """Get a deferred message response builder for use in the REST server flow. + + !!! note + For interactions received over the gateway + `ModalInteraction.create_initial_response` should be used to set + the interaction response message. + + !!! note + Unlike `hikari.api.special_endpoints.InteractionMessageBuilder`, + the result of this call can be returned as is without any modifications + being made to it. + + Returns + ------- + hikari.api.special_endpoints.InteractionDeferredBuilder + Deferred interaction message response builder object. + """ + return self.app.rest.interaction_deferred_builder(base_interactions.ResponseType.DEFERRED_MESSAGE_CREATE) diff --git a/hikari/internal/cache.py b/hikari/internal/cache.py index 4b80d3e7c5..c71ccb6f0f 100644 --- a/hikari/internal/cache.py +++ b/hikari/internal/cache.py @@ -69,6 +69,7 @@ if typing.TYPE_CHECKING: from hikari import applications from hikari import channels as channels_ + from hikari import components as components_ from hikari import traits from hikari import users as users_ from hikari.interactions import base_interactions @@ -753,7 +754,7 @@ class MessageData(BaseData[messages.Message]): referenced_message: typing.Optional[RefCell[MessageData]] = attr.field() interaction: typing.Optional[MessageInteractionData] = attr.field() application_id: typing.Optional[snowflakes.Snowflake] = attr.field() - components: typing.Tuple[messages.PartialComponent, ...] = attr.field() + components: typing.Tuple[components_.MessageActionRowComponent, ...] = attr.field() @classmethod def build_from_entity( diff --git a/hikari/messages.py b/hikari/messages.py index 2235572583..09bc534f3d 100644 --- a/hikari/messages.py +++ b/hikari/messages.py @@ -35,21 +35,13 @@ "MessageReference", "PartialMessage", "Message", - "ActionRowComponent", - "ButtonComponent", - "ButtonStyle", - "SelectMenuOption", - "SelectMenuComponent", - "InteractiveButtonTypes", - "InteractiveButtonTypesT", - "ComponentType", - "PartialComponent", ) import typing import attr +from hikari import components as component_models from hikari import files from hikari import guilds from hikari import snowflakes @@ -375,234 +367,6 @@ class MessageInteraction: """Object of the user who invoked this interaction.""" -@typing.final -class ComponentType(int, enums.Enum): - """Types of components found within Discord.""" - - ACTION_ROW = 1 - """A non-interactive container component for other types of components. - - !!! note - As this is a container component it can never be contained within another - component and therefore will always be top-level. - - !!! note - As of writing this can only contain one component type. - """ - - BUTTON = 2 - """A button component. - - !!! note - This cannot be top-level and must be within a container component such - as `ComponentType.ACTION_ROW`. - """ - - SELECT_MENU = 3 - """A select menu component. - - !!! note - This cannot be top-level and must be within a container component such - as `ComponentType.ACTION_ROW`. - """ - - -@typing.final -class ButtonStyle(int, enums.Enum): - """Enum of the available button styles. - - More information, such as how these look, can be found at - https://discord.com/developers/docs/interactions/message-components#buttons-button-styles - """ - - PRIMARY = 1 - """A blurple "call to action" button.""" - - SECONDARY = 2 - """A grey neutral button.""" - - SUCCESS = 3 - """A green button.""" - - DANGER = 4 - """A red button (usually indicates a destructive action).""" - - LINK = 5 - """A grey button which navigates to a URL. - - !!! warning - Unlike the other button styles, clicking this one will not trigger an - interaction and custom_id shouldn't be included for this style. - """ - - -InteractiveButtonTypesT = typing.Union[ - typing.Literal[ButtonStyle.PRIMARY], - typing.Literal[1], - typing.Literal[ButtonStyle.SECONDARY], - typing.Literal[2], - typing.Literal[ButtonStyle.SUCCESS], - typing.Literal[3], - typing.Literal[ButtonStyle.DANGER], - typing.Literal[4], -] -"""Type hints of the `ButtonStyle` values which are valid for interactive buttons. - -The following values are valid for this: - -* `ButtonStyle.PRIMARY`/`1` -* `ButtonStyle.SECONDARY`/`2` -* `ButtonStyle.SUCCESS`/`3` -* `ButtonStyle.DANGER`/`4` -""" - -InteractiveButtonTypes: typing.AbstractSet[InteractiveButtonTypesT] = frozenset( - [ButtonStyle.PRIMARY, ButtonStyle.SECONDARY, ButtonStyle.SUCCESS, ButtonStyle.DANGER] -) -"""Set of the `ButtonType`s which are valid for interactive buttons. - -The following values are included in this: - -* `ButtonStyle.PRIMARY` -* `ButtonStyle.SECONDARY` -* `ButtonStyle.SUCCESS` -* `ButtonStyle.DANGER` -""" - - -@attr.define(kw_only=True, weakref_slot=False) -class PartialComponent: - """Base class for all component entities.""" - - type: typing.Union[ComponentType, int] = attr.field() - """The type of component this is.""" - - -@attr.define(hash=True, kw_only=True, weakref_slot=False) -class ButtonComponent(PartialComponent): - """Represents a message button component. - - !!! note - This is an embedded component and will only ever be found within - top-level container components such as `ActionRowComponent`. - """ - - style: typing.Union[ButtonStyle, int] = attr.field(eq=False) - """The button's style.""" - - label: typing.Optional[str] = attr.field(eq=False) - """Text label which appears on the button.""" - - emoji: typing.Optional[emojis_.Emoji] = attr.field(eq=False) - """Custom or unicode emoji which appears on the button.""" - - custom_id: typing.Optional[str] = attr.field(hash=True) - """Developer defined identifier for this button (will be <= 100 characters). - - !!! note - This is required for the following button styles: - - * `ButtonStyle.PRIMARY` - * `ButtonStyle.SECONDARY` - * `ButtonStyle.SUCCESS` - * `ButtonStyle.DANGER` - """ - - url: typing.Optional[str] = attr.field(eq=False) - """Url for `ButtonStyle.LINK` style buttons.""" - - is_disabled: bool = attr.field(eq=False) - """Whether the button is disabled.""" - - -@attr.define(kw_only=True, weakref_slot=False) -class SelectMenuOption: - """Represents an option for a `SelectMenuComponent`.""" - - label: str = attr.field() - """User-facing name of the option, max 100 characters.""" - - value: str = attr.field() - """Dev-defined value of the option, max 100 characters.""" - - description: typing.Optional[str] = attr.field() - """Optional description of the option, max 100 characters.""" - - emoji: typing.Optional[emojis_.Emoji] = attr.field(eq=False) - """Custom or unicode emoji which appears on the button.""" - - is_default: bool = attr.field() - """Whether this option will be selected by default.""" - - -@attr.define(hash=True, kw_only=True, weakref_slot=False) -class SelectMenuComponent(PartialComponent): - """Represents a message button component. - - !!! note - This is an embedded component and will only ever be found within - top-level container components such as `ActionRowComponent`. - """ - - custom_id: str = attr.field(hash=True) - """Developer defined identifier for this menu (will be <= 100 characters).""" - - options: typing.Sequence[SelectMenuOption] = attr.field(eq=False) - """Sequence of up to 25 of the options set for this menu.""" - - placeholder: typing.Optional[str] = attr.field(eq=False) - """Custom placeholder text shown if nothing is selected, max 100 characters.""" - - min_values: int = attr.field(eq=False) - """The minimum amount of options which must be chosen for this menu. - - This will be greater than or equal to 0 and will be less than or equal to - `SelectMenuComponent.max_values`. - """ - - max_values: int = attr.field(eq=False) - """The minimum amount of options which can be chosen for this menu. - - This will be less than or equal to 25 and will be greater than or equal to - `SelectMenuComponent.min_values`. - """ - - is_disabled: bool = attr.field(eq=False) - """Whether the select menu is disabled.""" - - -@attr.define(weakref_slot=False) -class ActionRowComponent(PartialComponent): - """Represents a row of components attached to a message. - - !!! note - This is a top-level container component and will never be found within - another component. - """ - - components: typing.Sequence[PartialComponent] = attr.field() - """Sequence of the components contained within this row.""" - - @typing.overload - def __getitem__(self, index: int, /) -> PartialComponent: - ... - - @typing.overload - def __getitem__(self, slice_: slice, /) -> typing.Sequence[PartialComponent]: - ... - - def __getitem__( - self, index_or_slice: typing.Union[int, slice], / - ) -> typing.Union[PartialComponent, typing.Sequence[PartialComponent]]: - return self.components[index_or_slice] - - def __iter__(self) -> typing.Iterator[PartialComponent]: - return iter(self.components) - - def __len__(self) -> int: - return len(self.components) - - def _map_cache_maybe_discover( ids: typing.Iterable[snowflakes.Snowflake], cache_call: typing.Callable[[snowflakes.Snowflake], typing.Optional[_T]], @@ -803,7 +567,9 @@ class PartialMessage(snowflakes.Unique): This will only be provided for interaction messages. """ - components: undefined.UndefinedOr[typing.Sequence[PartialComponent]] = attr.field(hash=False, eq=False, repr=False) + components: undefined.UndefinedOr[typing.Sequence[component_models.MessageActionRowComponent]] = attr.field( + hash=False, eq=False, repr=False + ) """Sequence of the components attached to this message.""" @property @@ -1646,5 +1412,7 @@ class Message(PartialMessage): This will only be provided for interaction messages. """ - components: typing.Sequence[PartialComponent] = attr.field(hash=False, eq=False, repr=False) + components: typing.Sequence[component_models.MessageActionRowComponent] = attr.field( + hash=False, eq=False, repr=False + ) """Sequence of the components attached to this message.""" diff --git a/tests/hikari/impl/test_entity_factory.py b/tests/hikari/impl/test_entity_factory.py index 9a55ace590..ca558ff0e2 100644 --- a/tests/hikari/impl/test_entity_factory.py +++ b/tests/hikari/impl/test_entity_factory.py @@ -30,6 +30,7 @@ from hikari import channels as channel_models from hikari import colors as color_models from hikari import commands +from hikari import components as component_models from hikari import embeds as embed_models from hikari import emojis as emoji_models from hikari import errors @@ -53,6 +54,7 @@ from hikari.interactions import base_interactions from hikari.interactions import command_interactions from hikari.interactions import component_interactions +from hikari.interactions import modal_interactions from tests.hikari import hikari_test_helpers @@ -4527,7 +4529,7 @@ def test_deserialize_component_interaction( assert interaction.token == "unique_interaction_token" assert interaction.version == 1 assert interaction.channel_id == 345626669114982999 - assert interaction.component_type is message_models.ComponentType.BUTTON + assert interaction.component_type is component_models.ComponentType.BUTTON assert interaction.custom_id == "click_one" assert interaction.guild_id == 290926798626357999 assert interaction.message == entity_factory_impl.deserialize_message(message_payload) @@ -4569,6 +4571,83 @@ def test_deserialize_component_interaction_with_undefined_fields( assert interaction.app_permissions is None assert isinstance(interaction, component_interactions.ComponentInteraction) + @pytest.fixture() + def modal_interaction_payload(self, interaction_member_payload, message_payload): + return { + "version": 1, + "type": 5, + "token": "unique_interaction_token", + "message": message_payload, + "member": interaction_member_payload, + "id": "846462639134605312", + "guild_id": "290926798626357999", + "data": { + "custom_id": "modaltest", + "components": [ + {"type": 1, "components": [{"value": "Wumpus", "type": 4, "custom_id": "name"}]}, + {"type": 1, "components": [{"value": "Longer Text", "type": 4, "custom_id": "about"}]}, + ], + }, + "channel_id": "345626669114982999", + "application_id": "290926444748734465", + "locale": "en-US", + "guild_locale": "es-ES", + } + + def test_deserialize_modal_interaction( + self, + entity_factory_impl, + mock_app, + modal_interaction_payload, + interaction_member_payload, + message_payload, + ): + interaction = entity_factory_impl.deserialize_modal_interaction(modal_interaction_payload) + assert interaction.app is mock_app + assert interaction.id == 846462639134605312 + assert interaction.application_id == 290926444748734465 + assert interaction.type is base_interactions.InteractionType.MODAL_SUBMIT + assert interaction.token == "unique_interaction_token" + assert interaction.version == 1 + assert interaction.channel_id == 345626669114982999 + assert interaction.guild_id == 290926798626357999 + assert interaction.message == entity_factory_impl.deserialize_message(message_payload) + assert interaction.member == entity_factory_impl._deserialize_interaction_member( + interaction_member_payload, guild_id=290926798626357999 + ) + assert interaction.user is interaction.member.user + assert isinstance(interaction, modal_interactions.ModalInteraction) + + short_action_row = interaction.components[0] + assert isinstance(short_action_row, component_models.ActionRowComponent) + short_text_input = short_action_row.components[0] + assert isinstance(short_text_input, component_models.TextInputComponent) + assert short_text_input.value == "Wumpus" + assert short_text_input.type == component_models.ComponentType.TEXT_INPUT + assert short_text_input.custom_id == "name" + + def test_deserialize_modal_interaction_with_user( + self, + entity_factory_impl, + modal_interaction_payload, + user_payload, + ): + modal_interaction_payload["member"] = None + modal_interaction_payload["user"] = user_payload + + interaction = entity_factory_impl.deserialize_modal_interaction(modal_interaction_payload) + assert interaction.user.id == 115590097100865541 + + def test_deserialize_modal_interaction_with_unrecognized_component( + self, + entity_factory_impl, + modal_interaction_payload, + ): + modal_interaction_payload["data"]["components"] = [{"type": 0}] + + interaction = entity_factory_impl.deserialize_modal_interaction(modal_interaction_payload) + assert len(interaction.components) == 0 + ################## # STICKER MODELS # ################## @@ -5004,27 +5083,14 @@ def test_max_age_when_zero(self, entity_factory_impl, invite_with_metadata_paylo invite_with_metadata_payload["max_age"] = 0 assert entity_factory_impl.deserialize_invite_with_metadata(invite_with_metadata_payload).max_age is None - ################## - # MESSAGE MODELS # - ################## + #################### + # COMPONENT MODELS # + #################### @pytest.fixture() def action_row_payload(self, button_payload): return {"type": 1, "components": [button_payload]} - def test__deserialize_action_row(self, entity_factory_impl, action_row_payload, button_payload): - action_row = entity_factory_impl._deserialize_action_row(action_row_payload) - - assert action_row.type is message_models.ComponentType.ACTION_ROW - assert action_row.components == [entity_factory_impl._deserialize_component(button_payload)] - - def test__deserialize_action_row_handles_unknown_component_type(self, entity_factory_impl): - action_row = entity_factory_impl._deserialize_action_row( - {"type": 1, "components": [{"type": "9494949"}, {"type": "9239292"}]} - ) - - assert action_row.components == [] - @pytest.fixture() def button_payload(self, custom_emoji_payload): return { @@ -5040,8 +5106,8 @@ def button_payload(self, custom_emoji_payload): def test_deserialize__deserialize_button(self, entity_factory_impl, button_payload, custom_emoji_payload): button = entity_factory_impl._deserialize_button(button_payload) - assert button.type is message_models.ComponentType.BUTTON - assert button.style is message_models.ButtonStyle.PRIMARY + assert button.type is component_models.ComponentType.BUTTON + assert button.style is component_models.ButtonStyle.PRIMARY assert button.label == "Click me!" assert button.emoji == entity_factory_impl.deserialize_emoji(custom_emoji_payload) assert button.custom_id == "click_one" @@ -5053,8 +5119,8 @@ def test_deserialize__deserialize_button_with_unset_fields( ): button = entity_factory_impl._deserialize_button({"type": 2, "style": 5}) - assert button.type is message_models.ComponentType.BUTTON - assert button.style is message_models.ButtonStyle.LINK + assert button.type is component_models.ComponentType.BUTTON + assert button.style is component_models.ButtonStyle.LINK assert button.label is None assert button.emoji is None assert button.custom_id is None @@ -5084,7 +5150,7 @@ def select_menu_payload(self, custom_emoji_payload): def test__deserialize_select_menu(self, entity_factory_impl, select_menu_payload, custom_emoji_payload): menu = entity_factory_impl._deserialize_select_menu(select_menu_payload) - assert menu.type is message_models.ComponentType.SELECT_MENU + assert menu.type is component_models.ComponentType.SELECT_MENU assert menu.custom_id == "Not an ID" # SelectMenuOption @@ -5095,7 +5161,7 @@ def test__deserialize_select_menu(self, entity_factory_impl, select_menu_payload assert option.description == "queen" assert option.emoji == entity_factory_impl.deserialize_emoji(custom_emoji_payload) assert option.is_default is True - assert isinstance(option, message_models.SelectMenuOption) + assert isinstance(option, component_models.SelectMenuOption) assert menu.placeholder == "Imagine a place" assert menu.min_values == 69 @@ -5124,28 +5190,52 @@ def test__deserialize_select_menu_partial(self, entity_factory_impl): assert menu.is_disabled is False @pytest.mark.parametrize( - ("type_", "fn"), + ("type_", "fn", "mapping"), [ - (1, "_deserialize_action_row"), - (2, "_deserialize_button"), - (3, "_deserialize_select_menu"), + (2, "_deserialize_button", "_message_component_type_mapping"), + (3, "_deserialize_select_menu", "_message_component_type_mapping"), + (4, "_deserialize_text_input", "_modal_component_type_mapping"), ], ) - def test__deserialize_component(self, mock_app, type_, fn): - payload = {"type": type_} + def test__deserialize_components(self, mock_app, type_, fn, mapping): + component_payload = {"type": type_} + payload = [{"type": 1, "components": [component_payload]}] with mock.patch.object(entity_factory.EntityFactoryImpl, fn) as expected_fn: # We need to instantiate it after the mock so that the functions that are stored in the dicts # are the ones we mock entity_factory_impl = entity_factory.EntityFactoryImpl(app=mock_app) - assert entity_factory_impl._deserialize_component(payload) is expected_fn.return_value + components = entity_factory_impl._deserialize_components(payload, getattr(entity_factory_impl, mapping)) - expected_fn.assert_called_once_with(payload) + expected_fn.assert_called_once_with(component_payload) + action_row = components[0] + assert isinstance(action_row, component_models.ActionRowComponent) + assert action_row.components[0] is expected_fn.return_value - def test__deserialize_component_handles_unknown_type(self, entity_factory_impl): - with pytest.raises(errors.UnrecognisedEntityError): - entity_factory_impl._deserialize_component({"type": -9434994}) + def test__deserialize_components_handles_unknown_top_component_type(self, entity_factory_impl): + components = entity_factory_impl._deserialize_components( + [ + # Unknown top-level component + {"type": -9434994}, + { + # Known top-level component + "type": 1, + "components": [ + # Unknown components + {"type": 1}, + {"type": 1000000}, + ], + }, + ], + {}, + ) + + assert components == [] + + ################## + # MESSAGE MODELS # + ################## @pytest.fixture() def partial_application_payload(self): @@ -5381,7 +5471,9 @@ def test_deserialize_partial_message( assert partial_message.interaction.user == entity_factory_impl.deserialize_user(user_payload) assert isinstance(partial_message.interaction, message_models.MessageInteraction) - assert partial_message.components == [entity_factory_impl._deserialize_component(action_row_payload)] + assert partial_message.components == entity_factory_impl._deserialize_components( + [action_row_payload], entity_factory_impl._message_component_type_mapping + ) def test_deserialize_partial_message_with_partial_fields(self, entity_factory_impl, message_payload): message_payload["content"] = "" @@ -5561,7 +5653,9 @@ def test_deserialize_message( assert message.interaction.user == entity_factory_impl.deserialize_user(user_payload) assert isinstance(message.interaction, message_models.MessageInteraction) - assert message.components == [entity_factory_impl._deserialize_component(action_row_payload)] + assert message.components == entity_factory_impl._deserialize_components( + [action_row_payload], entity_factory_impl._message_component_type_mapping + ) def test_deserialize_message_with_unset_sub_fields(self, entity_factory_impl, message_payload): del message_payload["application"]["cover_image"] diff --git a/tests/hikari/impl/test_rest.py b/tests/hikari/impl/test_rest.py index 8c1a270836..2af8b2f925 100644 --- a/tests/hikari/impl/test_rest.py +++ b/tests/hikari/impl/test_rest.py @@ -1199,11 +1199,23 @@ def test_context_menu_command_command_builder(self, rest_client): assert result.type == commands.CommandType.MESSAGE def test_build_action_row(self, rest_client): - with mock.patch.object(special_endpoints, "ActionRowBuilder") as action_row_builder: + with mock.patch.object(special_endpoints, "MessageActionRowBuilder") as action_row_builder: assert rest_client.build_action_row() is action_row_builder.return_value action_row_builder.assert_called_once_with() + def test_build_message_action_row(self, rest_client): + with mock.patch.object(special_endpoints, "MessageActionRowBuilder") as action_row_builder: + assert rest_client.build_message_action_row() is action_row_builder.return_value + + action_row_builder.assert_called_once_with() + + def test_build_modal_action_row(self, rest_client): + with mock.patch.object(special_endpoints, "ModalActionRowBuilder") as action_row_builder: + assert rest_client.build_modal_action_row() is action_row_builder.return_value + + action_row_builder.assert_called_once_with() + def test__build_message_payload_with_undefined_args(self, rest_client): with mock.patch.object( mentions, "generate_allowed_mentions", return_value={"allowed_mentions": 1} @@ -1634,6 +1646,21 @@ def test_interaction_message_builder(self, rest_client): assert result.type == 4 assert isinstance(result, special_endpoints.InteractionMessageBuilder) + def test_interaction_modal_builder(self, rest_client): + result = rest_client.interaction_modal_builder("title", "custom") + result.add_component( + special_endpoints.ModalActionRowBuilder().add_text_input("idd", "labell").add_to_container() + ) + + assert result.type == 9 + assert isinstance(result, special_endpoints.InteractionModalBuilder) + + def test_interaction_modal_builder_with_components(self, rest_client): + result = rest_client.interaction_modal_builder("title", "custom") + + assert result.type == 9 + assert isinstance(result, special_endpoints.InteractionModalBuilder) + def test_fetch_scheduled_event_users(self, rest_client: rest.RESTClientImpl): with mock.patch.object(special_endpoints, "ScheduledEventUserIterator") as iterator_cls: iterator = rest_client.fetch_scheduled_event_users( @@ -5879,11 +5906,11 @@ async def test_delete_interaction_response(self, rest_client): rest_client._request.assert_awaited_once_with(expected_route, no_auth=True) async def test_create_autocomplete_response(self, rest_client): - expected_route = routes.POST_INTERACTION_RESPONSE.compile(interaction=1235431, token="dissssnake") + expected_route = routes.POST_INTERACTION_RESPONSE.compile(interaction=1235431, token="snek") rest_client._request = mock.AsyncMock() choices = [commands.CommandChoice(name="a", value="b"), commands.CommandChoice(name="foo", value="bar")] - await rest_client.create_autocomplete_response(StubModel(1235431), "dissssnake", choices) + await rest_client.create_autocomplete_response(StubModel(1235431), "snek", choices) rest_client._request.assert_awaited_once_with( expected_route, @@ -5891,6 +5918,48 @@ async def test_create_autocomplete_response(self, rest_client): no_auth=True, ) + async def test_create_modal_response(self, rest_client): + expected_route = routes.POST_INTERACTION_RESPONSE.compile(interaction=1235431, token="snek") + rest_client._request = mock.AsyncMock() + component = mock.Mock() + + await rest_client.create_modal_response( + StubModel(1235431), "snek", title="title", custom_id="idd", component=component + ) + + rest_client._request.assert_awaited_once_with( + expected_route, + json={ + "type": 9, + "data": {"title": "title", "custom_id": "idd", "components": [component.build.return_value]}, + }, + no_auth=True, + ) + + async def test_create_modal_response_with_plural_args(self, rest_client): + expected_route = routes.POST_INTERACTION_RESPONSE.compile(interaction=1235431, token="snek") + rest_client._request = mock.AsyncMock() + component = mock.Mock() + + await rest_client.create_modal_response( + StubModel(1235431), "snek", title="title", custom_id="idd", components=[component] + ) + + rest_client._request.assert_awaited_once_with( + expected_route, + json={ + "type": 9, + "data": {"title": "title", "custom_id": "idd", "components": [component.build.return_value]}, + }, + no_auth=True, + ) + + async def test_create_modal_response_when_both_component_and_components_passed(self, rest_client): + with pytest.raises(ValueError, match="Must specify exactly only one of 'component' or 'components'"): + await rest_client.create_modal_response( + StubModel(1235431), "snek", title="title", custom_id="idd", component="not none", components=[] + ) + async def test_fetch_scheduled_event(self, rest_client: rest.RESTClientImpl): expected_route = routes.GET_GUILD_SCHEDULED_EVENT.compile(guild=453123, scheduled_event=222332323) rest_client._request = mock.AsyncMock(return_value={"id": "4949494949"}) diff --git a/tests/hikari/impl/test_special_endpoints.py b/tests/hikari/impl/test_special_endpoints.py index cf010e93be..6454c121a4 100644 --- a/tests/hikari/impl/test_special_endpoints.py +++ b/tests/hikari/impl/test_special_endpoints.py @@ -24,6 +24,7 @@ import pytest from hikari import commands +from hikari import components from hikari import emojis from hikari import files from hikari import locales @@ -912,6 +913,40 @@ def test_build_handles_cleared_attachments(self): assert attachments == [] +class TestInteractionModalBuilder: + def test_type_property(self): + builder = special_endpoints.InteractionModalBuilder("title", "custom_id") + assert builder.type == 9 + + def test_title_property(self): + builder = special_endpoints.InteractionModalBuilder("title", "custom_id").set_title("title2") + assert builder.title == "title2" + + def test_custom_id_property(self): + builder = special_endpoints.InteractionModalBuilder("title", "custom_id").set_custom_id("better_custom_id") + assert builder.custom_id == "better_custom_id" + + def test_components_property(self): + component = mock.Mock() + builder = special_endpoints.InteractionModalBuilder("title", "custom_id").add_component(component) + assert builder.components == [component] + + def test_build(self): + component = mock.Mock() + builder = special_endpoints.InteractionModalBuilder("title", "custom_id").add_component(component) + + result, attachments = builder.build(mock.Mock()) + assert result == { + "type": 9, + "data": { + "title": "title", + "custom_id": "custom_id", + "components": [component.build.return_value], + }, + } + assert attachments == () + + class TestSlashCommandBuilder: def test_description_property(self): builder = special_endpoints.SlashCommandBuilder("ok", "NO") @@ -1166,7 +1201,7 @@ class Test_ButtonBuilder: def button(self): return special_endpoints._ButtonBuilder( container=mock.Mock(), - style=messages.ButtonStyle.DANGER, + style=components.ButtonStyle.DANGER, custom_id="sfdasdasd", url="hi there", emoji=543123, @@ -1177,7 +1212,7 @@ def button(self): ) def test_style_property(self, button): - assert button.style is messages.ButtonStyle.DANGER + assert button.style is components.ButtonStyle.DANGER def test_emoji_property(self, button): assert button.emoji == 543123 @@ -1219,7 +1254,7 @@ def test_set_is_disabled(self, button): def test_build(self): result = special_endpoints._ButtonBuilder( container=object(), - style=messages.ButtonStyle.DANGER, + style=components.ButtonStyle.DANGER, url=undefined.UNDEFINED, emoji_id=undefined.UNDEFINED, emoji_name="emoji_name", @@ -1229,8 +1264,8 @@ def test_build(self): ).build() assert result == { - "type": messages.ComponentType.BUTTON, - "style": messages.ButtonStyle.DANGER, + "type": components.ComponentType.BUTTON, + "style": components.ButtonStyle.DANGER, "emoji": {"name": "emoji_name"}, "label": "no u", "custom_id": "ooga booga", @@ -1240,7 +1275,7 @@ def test_build(self): def test_build_without_optional_fields(self): result = special_endpoints._ButtonBuilder( container=object(), - style=messages.ButtonStyle.LINK, + style=components.ButtonStyle.LINK, url="OK", emoji_id="123321", emoji_name=undefined.UNDEFINED, @@ -1250,8 +1285,8 @@ def test_build_without_optional_fields(self): ).build() assert result == { - "type": messages.ComponentType.BUTTON, - "style": messages.ButtonStyle.LINK, + "type": components.ComponentType.BUTTON, + "style": components.ButtonStyle.LINK, "emoji": {"id": "123321"}, "disabled": False, "url": "OK", @@ -1261,7 +1296,7 @@ def test_add_to_container(self): mock_container = mock.Mock() button = special_endpoints._ButtonBuilder( container=mock_container, - style=messages.ButtonStyle.DANGER, + style=components.ButtonStyle.DANGER, url=undefined.UNDEFINED, emoji_id=undefined.UNDEFINED, emoji_name="emoji_name", @@ -1279,7 +1314,7 @@ class TestLinkButtonBuilder: def test_url_property(self): button = special_endpoints.LinkButtonBuilder( container=object(), - style=messages.ButtonStyle.DANGER, + style=components.ButtonStyle.DANGER, url="hihihihi", emoji_id=undefined.UNDEFINED, emoji_name="emoji_name", @@ -1295,7 +1330,7 @@ class TestInteractiveButtonBuilder: def test_custom_id_property(self): button = special_endpoints.InteractiveButtonBuilder( container=object(), - style=messages.ButtonStyle.DANGER, + style=components.ButtonStyle.DANGER, url="hihihihi", emoji_id=undefined.UNDEFINED, emoji_name="emoji_name", @@ -1442,7 +1477,7 @@ def test_build(self): result = special_endpoints.SelectMenuBuilder(container=object(), custom_id="o2o2o2").build() assert result == { - "type": messages.ComponentType.SELECT_MENU, + "type": components.ComponentType.SELECT_MENU, "custom_id": "o2o2o2", "options": [], "disabled": False, @@ -1462,7 +1497,7 @@ def test_build_partial(self): ) assert result == { - "type": messages.ComponentType.SELECT_MENU, + "type": components.ComponentType.SELECT_MENU, "custom_id": "o2o2o2", "options": [{"hi": "OK"}], "placeholder": "hi", @@ -1472,30 +1507,117 @@ def test_build_partial(self): } -class TestActionRowBuilder: +class TestTextInput: + @pytest.fixture() + def text_input(self): + return special_endpoints.TextInputBuilder( + container=mock.Mock(), + custom_id="o2o2o2", + label="label", + ) + + def test_set_style(self, text_input): + assert text_input.set_style(components.TextInputStyle.PARAGRAPH) is text_input + assert text_input.style == components.TextInputStyle.PARAGRAPH + + def test_set_custom_id(self, text_input): + assert text_input.set_custom_id("custooom") is text_input + assert text_input.custom_id == "custooom" + + def test_set_label(self, text_input): + assert text_input.set_label("labeeeel") is text_input + assert text_input.label == "labeeeel" + + def test_set_placeholder(self, text_input): + assert text_input.set_placeholder("place") is text_input + assert text_input.placeholder == "place" + + def test_set_required(self, text_input): + assert text_input.set_required(True) is text_input + assert text_input.required is True + + def test_set_value(self, text_input): + assert text_input.set_value("valueeeee") is text_input + assert text_input.value == "valueeeee" + + def test_set_min_length_(self, text_input): + assert text_input.set_min_length(10) is text_input + assert text_input.min_length == 10 + + def test_set_max_length(self, text_input): + assert text_input.set_max_length(250) is text_input + assert text_input.max_length == 250 + + def test_add_to_container(self, text_input): + assert text_input.add_to_container() is text_input._container + text_input._container.add_component.assert_called_once_with(text_input) + + def test_build(self): + result = special_endpoints.TextInputBuilder( + container=object(), + custom_id="o2o2o2", + label="label", + ).build() + + assert result == { + "type": components.ComponentType.TEXT_INPUT, + "style": 1, + "custom_id": "o2o2o2", + "label": "label", + } + + def test_build_partial(self): + result = ( + special_endpoints.TextInputBuilder( + container=object(), + custom_id="o2o2o2", + label="label", + ) + .set_placeholder("placeholder") + .set_value("value") + .set_required(False) + .set_min_length(10) + .set_max_length(250) + .build() + ) + + assert result == { + "type": components.ComponentType.TEXT_INPUT, + "style": 1, + "custom_id": "o2o2o2", + "label": "label", + "placeholder": "placeholder", + "value": "value", + "required": False, + "min_length": 10, + "max_length": 250, + } + + +class TestMessageActionRowBuilder: def test_components_property(self): mock_component = object() - row = special_endpoints.ActionRowBuilder().add_component(mock_component) + row = special_endpoints.MessageActionRowBuilder().add_component(mock_component) assert row.components == [mock_component] def test_add_button_for_interactive(self): - row = special_endpoints.ActionRowBuilder() - button = row.add_button(messages.ButtonStyle.DANGER, "go home") + row = special_endpoints.MessageActionRowBuilder() + button = row.add_button(components.ButtonStyle.DANGER, "go home") button.add_to_container() assert row.components == [button] def test_add_button_for_link(self): - row = special_endpoints.ActionRowBuilder() - button = row.add_button(messages.ButtonStyle.LINK, "go home") + row = special_endpoints.MessageActionRowBuilder() + button = row.add_button(components.ButtonStyle.LINK, "go home") button.add_to_container() assert row.components == [button] def test_add_select_menu(self): - row = special_endpoints.ActionRowBuilder() + row = special_endpoints.MessageActionRowBuilder() menu = row.add_select_menu("hihihi") menu.add_to_container() @@ -1506,14 +1628,24 @@ def test_build(self): mock_component_1 = mock.Mock() mock_component_2 = mock.Mock() - row = special_endpoints.ActionRowBuilder() + row = special_endpoints.MessageActionRowBuilder() row._components = [mock_component_1, mock_component_2] result = row.build() assert result == { - "type": messages.ComponentType.ACTION_ROW, + "type": components.ComponentType.ACTION_ROW, "components": [mock_component_1.build.return_value, mock_component_2.build.return_value], } mock_component_1.build.assert_called_once_with() mock_component_2.build.assert_called_once_with() + + +class TestModalActionRow: + def test_add_text_input(self): + row = special_endpoints.ModalActionRowBuilder() + menu = row.add_text_input("hihihi", "label") + + menu.add_to_container() + + assert row.components == [menu] diff --git a/tests/hikari/interactions/test_base_interactions.py b/tests/hikari/interactions/test_base_interactions.py index d1697404c1..6aa2e712d6 100644 --- a/tests/hikari/interactions/test_base_interactions.py +++ b/tests/hikari/interactions/test_base_interactions.py @@ -196,3 +196,36 @@ async def test_delete_initial_response(self, mock_message_response_mixin, mock_a await mock_message_response_mixin.delete_initial_response() mock_app.rest.delete_interaction_response.assert_awaited_once_with(651231, "399393939doodsodso") + + +class TestModalResponseMixin: + @pytest.fixture() + def mock_modal_response_mixin(self, mock_app): + return base_interactions.ModalResponseMixin( + app=mock_app, + id=34123, + application_id=651231, + type=base_interactions.InteractionType.APPLICATION_COMMAND, + token="399393939doodsodso", + version=3122312, + ) + + @pytest.mark.asyncio() + async def test_create_modal_response(self, mock_modal_response_mixin, mock_app): + await mock_modal_response_mixin.create_modal_response("title", "custom_id", None, []) + + mock_app.rest.create_modal_response.assert_awaited_once_with( + 34123, + "399393939doodsodso", + title="title", + custom_id="custom_id", + component=None, + components=[], + ) + + def test_build_response(self, mock_modal_response_mixin, mock_app): + mock_app.rest.interaction_modal_builder = mock.Mock() + builder = mock_modal_response_mixin.build_modal_response("title", "custom_id") + + assert builder is mock_app.rest.interaction_modal_builder.return_value + mock_app.rest.interaction_modal_builder.assert_called_once_with(title="title", custom_id="custom_id") diff --git a/tests/hikari/interactions/test_component_interactions.py b/tests/hikari/interactions/test_component_interactions.py index 160e082121..b4427b92c8 100644 --- a/tests/hikari/interactions/test_component_interactions.py +++ b/tests/hikari/interactions/test_component_interactions.py @@ -27,7 +27,6 @@ from hikari import traits from hikari.interactions import base_interactions from hikari.interactions import component_interactions -from tests.hikari import hikari_test_helpers @pytest.fixture() @@ -144,28 +143,3 @@ def test_get_guild_when_cacheless(self, mock_component_interaction, mock_app): assert mock_component_interaction.get_guild() is None mock_app.cache.get_guild.assert_not_called() - - @pytest.mark.asyncio() - async def test_fetch_parent_message(self): - stub_interaction = hikari_test_helpers.mock_class_namespace( - component_interactions.ComponentInteraction, fetch_message=mock.AsyncMock(), init_=False - )() - stub_interaction.message = mock.Mock(id=3421) - - assert await stub_interaction.fetch_parent_message() is stub_interaction.fetch_message.return_value - - stub_interaction.fetch_message.assert_awaited_once_with(3421) - - def test_get_parent_message(self, mock_component_interaction, mock_app): - mock_component_interaction.message = mock.Mock(id=321655) - - assert mock_component_interaction.get_parent_message() is mock_app.cache.get_message.return_value - - mock_app.cache.get_message.assert_called_once_with(321655) - - def test_get_parent_message_when_cacheless(self, mock_component_interaction, mock_app): - mock_component_interaction.app = mock.Mock(traits.RESTAware) - - assert mock_component_interaction.get_parent_message() is None - - mock_app.cache.get_message.assert_not_called() diff --git a/tests/hikari/interactions/test_modal_interactions.py b/tests/hikari/interactions/test_modal_interactions.py new file mode 100644 index 0000000000..a539792640 --- /dev/null +++ b/tests/hikari/interactions/test_modal_interactions.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2020 Nekokatt +# Copyright (c) 2021-present davfsa +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +import mock +import pytest + +from hikari import channels +from hikari import components +from hikari import snowflakes +from hikari import traits +from hikari.impl import special_endpoints +from hikari.interactions import base_interactions +from hikari.interactions import modal_interactions + + +@pytest.fixture() +def mock_app(): + return mock.Mock(rest=mock.AsyncMock()) + + +class TestModalInteraction: + @pytest.fixture() + def mock_modal_interaction(self, mock_app): + return modal_interactions.ModalInteraction( + app=mock_app, + id=snowflakes.Snowflake(2312312), + type=base_interactions.InteractionType.APPLICATION_COMMAND, + channel_id=snowflakes.Snowflake(3123123), + guild_id=snowflakes.Snowflake(5412231), + member=object(), + user=object(), + token="httptptptptptptptp", + version=1, + application_id=snowflakes.Snowflake(43123), + custom_id="OKOKOK", + message=object(), + locale="es-ES", + guild_locale="en-US", + app_permissions=543123, + components=special_endpoints.ModalActionRowBuilder( + components=[ + components.TextInputComponent( + type=components.ComponentType.TEXT_INPUT, custom_id="le id", value="le value" + ) + ], + ), + ) + + def test_build_response(self, mock_modal_interaction, mock_app): + mock_app.rest.interaction_message_builder = mock.Mock() + response = mock_modal_interaction.build_response() + + assert response is mock_app.rest.interaction_message_builder.return_value + mock_app.rest.interaction_message_builder.assert_called_once() + + def test_build_deferred_response(self, mock_modal_interaction, mock_app): + mock_app.rest.interaction_deferred_builder = mock.Mock() + response = mock_modal_interaction.build_deferred_response() + + assert response is mock_app.rest.interaction_deferred_builder.return_value + mock_app.rest.interaction_deferred_builder.assert_called_once() + + @pytest.mark.asyncio() + async def test_fetch_channel(self, mock_modal_interaction, mock_app): + mock_app.rest.fetch_channel.return_value = mock.Mock(channels.TextableChannel) + + assert await mock_modal_interaction.fetch_channel() is mock_app.rest.fetch_channel.return_value + + mock_app.rest.fetch_channel.assert_awaited_once_with(3123123) + + def test_get_channel(self, mock_modal_interaction, mock_app): + mock_app.cache.get_guild_channel.return_value = mock.Mock(channels.GuildTextChannel) + + assert mock_modal_interaction.get_channel() is mock_app.cache.get_guild_channel.return_value + + mock_app.cache.get_guild_channel.assert_called_once_with(3123123) + + def test_get_channel_without_cache(self, mock_modal_interaction): + mock_modal_interaction.app = mock.Mock(traits.RESTAware) + + assert mock_modal_interaction.get_channel() is None + + @pytest.mark.asyncio() + async def test_fetch_guild(self, mock_modal_interaction, mock_app): + mock_modal_interaction.guild_id = 43123123 + + assert await mock_modal_interaction.fetch_guild() is mock_app.rest.fetch_guild.return_value + + mock_app.rest.fetch_guild.assert_awaited_once_with(43123123) + + @pytest.mark.asyncio() + async def test_fetch_guild_for_dm_interaction(self, mock_modal_interaction, mock_app): + mock_modal_interaction.guild_id = None + + assert await mock_modal_interaction.fetch_guild() is None + + mock_app.rest.fetch_guild.assert_not_called() + + def test_get_guild(self, mock_modal_interaction, mock_app): + mock_modal_interaction.guild_id = 874356 + + assert mock_modal_interaction.get_guild() is mock_app.cache.get_guild.return_value + + mock_app.cache.get_guild.assert_called_once_with(874356) + + def test_get_guild_for_dm_interaction(self, mock_modal_interaction, mock_app): + mock_modal_interaction.guild_id = None + + assert mock_modal_interaction.get_guild() is None + + mock_app.cache.get_guild.assert_not_called() + + def test_get_guild_when_cacheless(self, mock_modal_interaction, mock_app): + mock_modal_interaction.guild_id = 321123 + mock_modal_interaction.app = mock.Mock(traits.RESTAware) + + assert mock_modal_interaction.get_guild() is None + + mock_app.cache.get_guild.assert_not_called() diff --git a/tests/hikari/test_components.py b/tests/hikari/test_components.py new file mode 100644 index 0000000000..4e734a0253 --- /dev/null +++ b/tests/hikari/test_components.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2020 Nekokatt +# Copyright (c) 2021-present davfsa +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +from hikari import components + + +class TestActionRowComponent: + def test_getitem_operator_with_index(self): + mock_component = object() + row = components.ActionRowComponent(type=1, components=[object(), mock_component, object()]) + + assert row[1] is mock_component + + def test_getitem_operator_with_slice(self): + mock_component_1 = object() + mock_component_2 = object() + row = components.ActionRowComponent(type=1, components=[object(), mock_component_1, object(), mock_component_2]) + + assert row[1:4:2] == [mock_component_1, mock_component_2] + + def test_iter_operator(self): + mock_component_1 = object() + mock_component_2 = object() + row = components.ActionRowComponent(type=1, components=[mock_component_1, mock_component_2]) + + assert list(row) == [mock_component_1, mock_component_2] + + def test_len_operator(self): + row = components.ActionRowComponent(type=1, components=[object(), object()]) + + assert len(row) == 2 diff --git a/tests/hikari/test_messages.py b/tests/hikari/test_messages.py index 1aa1a1c3f2..db45b30d88 100644 --- a/tests/hikari/test_messages.py +++ b/tests/hikari/test_messages.py @@ -87,33 +87,6 @@ def test_make_cover_image_url_when_hash_is_not_none(self, message_application): ) -class TestActionRowComponent: - def test_getitem_operator_with_index(self): - mock_component = object() - row = messages.ActionRowComponent(type=1, components=[object(), mock_component, object()]) - - assert row[1] is mock_component - - def test_getitem_operator_with_slice(self): - mock_component_1 = object() - mock_component_2 = object() - row = messages.ActionRowComponent(type=1, components=[object(), mock_component_1, object(), mock_component_2]) - - assert row[1:4:2] == [mock_component_1, mock_component_2] - - def test_iter_operator(self): - mock_component_1 = object() - mock_component_2 = object() - row = messages.ActionRowComponent(type=1, components=[mock_component_1, mock_component_2]) - - assert list(row) == [mock_component_1, mock_component_2] - - def test_len_operator(self): - row = messages.ActionRowComponent(type=1, components=[object(), object()]) - - assert len(row) == 2 - - @pytest.fixture() def message(): return messages.Message(