diff --git a/bot/meta/interactions/__init__.py b/bot/meta/interactions/__init__.py index e7e68854..660b5a93 100644 --- a/bot/meta/interactions/__init__.py +++ b/bot/meta/interactions/__init__.py @@ -1,3 +1,4 @@ from . import enums -from .interactions import _component_interaction_factory, Interaction, ComponentInteraction +from .interactions import _component_interaction_factory, Interaction, ComponentInteraction, ModalResponse from .components import * +from .modals import * diff --git a/bot/meta/interactions/components.py b/bot/meta/interactions/components.py index a651e241..c7788546 100644 --- a/bot/meta/interactions/components.py +++ b/bot/meta/interactions/components.py @@ -1,15 +1,23 @@ import asyncio import uuid +import json from .enums import ButtonStyle, InteractionType class MessageComponent: _type = None + interaction_type = InteractionType.MESSAGE_COMPONENT def __init_(self, *args, **kwargs): self.message = None + def to_dict(self): + raise NotImplementedError + + def to_json(self): + return json.dumps(self.to_dict()) + class ActionRow(MessageComponent): _type = 1 @@ -26,13 +34,15 @@ class ActionRow(MessageComponent): class AwaitableComponent: + interaction_type: InteractionType = None + async def wait_for(self, timeout=None, check=None): from meta import client def _check(interaction): valid = True print(interaction.custom_id) - valid = valid and interaction.interaction_type == InteractionType.MESSAGE_COMPONENT + valid = valid and interaction.interaction_type == self.interaction_type valid = valid and interaction.custom_id == self.custom_id valid = valid and (check is None or check(interaction)) return valid diff --git a/bot/meta/interactions/enums.py b/bot/meta/interactions/enums.py index c08974be..3d1c8c6c 100644 --- a/bot/meta/interactions/enums.py +++ b/bot/meta/interactions/enums.py @@ -24,5 +24,17 @@ class ButtonStyle(IntEnum): LINK = 5 +class TextInputStyle(IntEnum): + SHORT = 1 + PARAGRAPH = 2 + + class InteractionCallback(IntEnum): + PONG = 1 + CHANNEL_MESSAGE_WITH_SOURCE = 4 + DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE = 5 DEFERRED_UPDATE_MESSAGE = 6 + UPDATE_MESSAGE = 7 + APPLICATION_COMMAND_AUTOCOMPLETE_RESULT = 8 + MODAL = 9 + diff --git a/bot/meta/interactions/interactions.py b/bot/meta/interactions/interactions.py index 01f5d993..f51df5ed 100644 --- a/bot/meta/interactions/interactions.py +++ b/bot/meta/interactions/interactions.py @@ -9,11 +9,23 @@ class Interaction: '_state' ) - async def callback_deferred(self): - return await self._state.http.interaction_callback(self.id, self.token, InteractionCallback.DEFERRED_UPDATE_MESSAGE) + async def response_deferred(self): + return await self._state.http.interaction_callback( + self.id, + self.token, + InteractionCallback.DEFERRED_UPDATE_MESSAGE + ) + + async def response_modal(self, modal): + return await self._state.http.interaction_callback( + self.id, + self.token, + InteractionCallback.MODAL, + modal.to_dict() + ) def ack(self): - asyncio.create_task(self.callback_deferred()) + asyncio.create_task(self.response_deferred()) class ComponentInteraction(Interaction): @@ -51,6 +63,44 @@ class Selection(ComponentInteraction): self.values = data['data']['values'] +class ModalResponse(Interaction): + __slots__ = ( + 'message', + 'user', + '_state' + 'id', + 'token', + 'application_id', + 'custom_id', + 'values' + ) + + interaction_type = InteractionType.MODAL_SUBMIT + + def __init__(self, message, user, data, state): + self.message = message + self.user = user + + self._state = state + + self._from_data(data) + + def _from_data(self, data): + self.id = data['id'] + self.token = data['token'] + self.application_id = data['application_id'] + + component_data = data['data'] + + self.custom_id = component_data.get('custom_id', None) + + values = {} + for row in component_data['components']: + for component in row['components']: + values[component['custom_id']] = component['value'] + self.values = values + + def _component_interaction_factory(data): component_type = data['data']['component_type'] diff --git a/bot/meta/interactions/modals.py b/bot/meta/interactions/modals.py new file mode 100644 index 00000000..56bcfd8e --- /dev/null +++ b/bot/meta/interactions/modals.py @@ -0,0 +1,54 @@ +import uuid + +from .enums import TextInputStyle, InteractionType +from .components import AwaitableComponent + + +class Modal(AwaitableComponent): + interaction_type = InteractionType.MODAL_SUBMIT + + def __init__(self, title, *components, custom_id=None): + self.custom_id = custom_id or str(uuid.uuid4()) + + self.title = title + self.components = components + + def to_dict(self): + data = { + 'title': self.title, + 'custom_id': self.custom_id, + 'components': [comp.to_dict() for comp in self.components] + } + return data + + +class TextInput: + _type = 4 + + def __init__( + self, + label, placeholder=None, value=None, required=False, + style=TextInputStyle.SHORT, min_length=None, max_length=None, + custom_id=None + ): + self.custom_id = custom_id or str(uuid.uuid4()) + + self.label = label + self.placeholder = placeholder + self.value = value + self.required = required + self.style = style + self.min_length = min_length + self.max_length = max_length + + def to_dict(self): + data = { + 'type': self._type, + 'custom_id': self.custom_id, + 'style': int(self.style), + 'label': self.label, + } + for key in ('min_length', 'max_length', 'required', 'value', 'placeholder'): + if (value := getattr(self, key)) is not None: + data[key] = value + return data diff --git a/bot/meta/patches.py b/bot/meta/patches.py index 14a5741e..f692c4da 100644 --- a/bot/meta/patches.py +++ b/bot/meta/patches.py @@ -3,19 +3,29 @@ Temporary patches for the discord.py library to support new features of the disc """ import logging +from json import JSONEncoder + from discord.state import ConnectionState from discord.http import Route, HTTPClient from discord.abc import Messageable -from discord.utils import InvalidArgument, _get_as_snowflake +from discord.utils import InvalidArgument, _get_as_snowflake, to_json from discord import File, AllowedMentions, Member, User, Message -from .interactions import _component_interaction_factory +from .interactions import _component_interaction_factory, ModalResponse from .interactions.enums import InteractionType log = logging.getLogger(__name__) +def _default(self, obj): + return getattr(obj.__class__, "to_json", _default.default)(obj) + + +_default.default = JSONEncoder().default +JSONEncoder.default = _default + + def send_message(self, channel_id, content, *, tts=False, embeds=None, nonce=None, allowed_mentions=None, message_reference=None, components=None): r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id) @@ -39,12 +49,59 @@ def send_message(self, channel_id, content, *, tts=False, embeds=None, if message_reference: payload['message_reference'] = message_reference - if components is not None: + if components: payload['components'] = components return self.request(r, json=payload) +def send_files( + self, + channel_id, *, + files, + content=None, tts=False, embed=None, embeds=None, nonce=None, allowed_mentions=None, message_reference=None, + components=None +): + r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id) + form = [] + + payload = {'tts': tts} + if content: + payload['content'] = content + if embed: + payload['embed'] = embed + if embeds: + payload['embeds'] = embeds + if nonce: + payload['nonce'] = nonce + if allowed_mentions: + payload['allowed_mentions'] = allowed_mentions + if message_reference: + payload['message_reference'] = message_reference + if components: + payload['components'] = components + + form.append({'name': 'payload_json', 'value': to_json(payload)}) + if len(files) == 1: + file = files[0] + form.append({ + 'name': 'file', + 'value': file.fp, + 'filename': file.filename, + 'content_type': 'application/octet-stream' + }) + else: + for index, file in enumerate(files): + form.append({ + 'name': 'file%s' % index, + 'value': file.fp, + 'filename': file.filename, + 'content_type': 'application/octet-stream' + }) + + return self.request(r, form=form, files=files) + + def interaction_callback(self, interaction_id, interaction_token, callback_type, callback_data=None): r = Route( 'POST', @@ -62,7 +119,16 @@ def interaction_callback(self, interaction_id, interaction_token, callback_type, return self.request(r, json=payload) +def edit_message(self, channel_id, message_id, components=None, **fields): + r = Route('PATCH', '/channels/{channel_id}/messages/{message_id}', channel_id=channel_id, message_id=message_id) + if components is not None: + fields['components'] = [comp.to_dict() for comp in components] + return self.request(r, json=fields) + + +HTTPClient.send_files = send_files HTTPClient.send_message = send_message +HTTPClient.edit_message = edit_message HTTPClient.interaction_callback = interaction_callback @@ -114,7 +180,7 @@ async def send(self, content=None, *, tts=False, embed=None, embeds=None, file=N try: data = await state.http.send_files(channel.id, files=[file], allowed_mentions=allowed_mentions, content=content, tts=tts, embed=embed, nonce=nonce, - message_reference=reference) + message_reference=reference, components=components) finally: file.close() @@ -126,8 +192,8 @@ async def send(self, content=None, *, tts=False, embed=None, embeds=None, file=N try: data = await state.http.send_files(channel.id, files=files, content=content, tts=tts, - embed=embed, nonce=nonce, allowed_mentions=allowed_mentions, - message_reference=reference) + embeds=embeds, nonce=nonce, allowed_mentions=allowed_mentions, + message_reference=reference, components=components) finally: for f in files: f.close() @@ -181,8 +247,11 @@ def parse_interaction_create(self, data): 'INTERACTION_CREATE recieved unhandled message component interaction type: %s', data['data']['component_type'] ) + elif data['type'] == InteractionType.MODAL_SUBMIT: + interaction = ModalResponse(message, user, data, self) else: log.debug('INTERACTION_CREATE recieved unhandled interaction type: %s', data['type']) + log.debug(data) interaction = None if interaction: