(interactions): Basic support for modals.

This commit is contained in:
2022-04-19 11:34:34 +03:00
parent e73302d21f
commit 035a295962
6 changed files with 207 additions and 11 deletions

View File

@@ -1,3 +1,4 @@
from . import enums from . import enums
from .interactions import _component_interaction_factory, Interaction, ComponentInteraction from .interactions import _component_interaction_factory, Interaction, ComponentInteraction, ModalResponse
from .components import * from .components import *
from .modals import *

View File

@@ -1,15 +1,23 @@
import asyncio import asyncio
import uuid import uuid
import json
from .enums import ButtonStyle, InteractionType from .enums import ButtonStyle, InteractionType
class MessageComponent: class MessageComponent:
_type = None _type = None
interaction_type = InteractionType.MESSAGE_COMPONENT
def __init_(self, *args, **kwargs): def __init_(self, *args, **kwargs):
self.message = None self.message = None
def to_dict(self):
raise NotImplementedError
def to_json(self):
return json.dumps(self.to_dict())
class ActionRow(MessageComponent): class ActionRow(MessageComponent):
_type = 1 _type = 1
@@ -26,13 +34,15 @@ class ActionRow(MessageComponent):
class AwaitableComponent: class AwaitableComponent:
interaction_type: InteractionType = None
async def wait_for(self, timeout=None, check=None): async def wait_for(self, timeout=None, check=None):
from meta import client from meta import client
def _check(interaction): def _check(interaction):
valid = True valid = True
print(interaction.custom_id) print(interaction.custom_id)
valid = valid and interaction.interaction_type == InteractionType.MESSAGE_COMPONENT valid = valid and interaction.interaction_type == self.interaction_type
valid = valid and interaction.custom_id == self.custom_id valid = valid and interaction.custom_id == self.custom_id
valid = valid and (check is None or check(interaction)) valid = valid and (check is None or check(interaction))
return valid return valid

View File

@@ -24,5 +24,17 @@ class ButtonStyle(IntEnum):
LINK = 5 LINK = 5
class TextInputStyle(IntEnum):
SHORT = 1
PARAGRAPH = 2
class InteractionCallback(IntEnum): class InteractionCallback(IntEnum):
PONG = 1
CHANNEL_MESSAGE_WITH_SOURCE = 4
DEFERRED_CHANNEL_MESSAGE_WITH_SOURCE = 5
DEFERRED_UPDATE_MESSAGE = 6 DEFERRED_UPDATE_MESSAGE = 6
UPDATE_MESSAGE = 7
APPLICATION_COMMAND_AUTOCOMPLETE_RESULT = 8
MODAL = 9

View File

@@ -9,11 +9,23 @@ class Interaction:
'_state' '_state'
) )
async def callback_deferred(self): async def response_deferred(self):
return await self._state.http.interaction_callback(self.id, self.token, InteractionCallback.DEFERRED_UPDATE_MESSAGE) return await self._state.http.interaction_callback(
self.id,
self.token,
InteractionCallback.DEFERRED_UPDATE_MESSAGE
)
async def response_modal(self, modal):
return await self._state.http.interaction_callback(
self.id,
self.token,
InteractionCallback.MODAL,
modal.to_dict()
)
def ack(self): def ack(self):
asyncio.create_task(self.callback_deferred()) asyncio.create_task(self.response_deferred())
class ComponentInteraction(Interaction): class ComponentInteraction(Interaction):
@@ -51,6 +63,44 @@ class Selection(ComponentInteraction):
self.values = data['data']['values'] self.values = data['data']['values']
class ModalResponse(Interaction):
__slots__ = (
'message',
'user',
'_state'
'id',
'token',
'application_id',
'custom_id',
'values'
)
interaction_type = InteractionType.MODAL_SUBMIT
def __init__(self, message, user, data, state):
self.message = message
self.user = user
self._state = state
self._from_data(data)
def _from_data(self, data):
self.id = data['id']
self.token = data['token']
self.application_id = data['application_id']
component_data = data['data']
self.custom_id = component_data.get('custom_id', None)
values = {}
for row in component_data['components']:
for component in row['components']:
values[component['custom_id']] = component['value']
self.values = values
def _component_interaction_factory(data): def _component_interaction_factory(data):
component_type = data['data']['component_type'] component_type = data['data']['component_type']

View File

@@ -0,0 +1,54 @@
import uuid
from .enums import TextInputStyle, InteractionType
from .components import AwaitableComponent
class Modal(AwaitableComponent):
interaction_type = InteractionType.MODAL_SUBMIT
def __init__(self, title, *components, custom_id=None):
self.custom_id = custom_id or str(uuid.uuid4())
self.title = title
self.components = components
def to_dict(self):
data = {
'title': self.title,
'custom_id': self.custom_id,
'components': [comp.to_dict() for comp in self.components]
}
return data
class TextInput:
_type = 4
def __init__(
self,
label, placeholder=None, value=None, required=False,
style=TextInputStyle.SHORT, min_length=None, max_length=None,
custom_id=None
):
self.custom_id = custom_id or str(uuid.uuid4())
self.label = label
self.placeholder = placeholder
self.value = value
self.required = required
self.style = style
self.min_length = min_length
self.max_length = max_length
def to_dict(self):
data = {
'type': self._type,
'custom_id': self.custom_id,
'style': int(self.style),
'label': self.label,
}
for key in ('min_length', 'max_length', 'required', 'value', 'placeholder'):
if (value := getattr(self, key)) is not None:
data[key] = value
return data

View File

@@ -3,19 +3,29 @@ Temporary patches for the discord.py library to support new features of the disc
""" """
import logging import logging
from json import JSONEncoder
from discord.state import ConnectionState from discord.state import ConnectionState
from discord.http import Route, HTTPClient from discord.http import Route, HTTPClient
from discord.abc import Messageable from discord.abc import Messageable
from discord.utils import InvalidArgument, _get_as_snowflake from discord.utils import InvalidArgument, _get_as_snowflake, to_json
from discord import File, AllowedMentions, Member, User, Message from discord import File, AllowedMentions, Member, User, Message
from .interactions import _component_interaction_factory from .interactions import _component_interaction_factory, ModalResponse
from .interactions.enums import InteractionType from .interactions.enums import InteractionType
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def _default(self, obj):
return getattr(obj.__class__, "to_json", _default.default)(obj)
_default.default = JSONEncoder().default
JSONEncoder.default = _default
def send_message(self, channel_id, content, *, tts=False, embeds=None, def send_message(self, channel_id, content, *, tts=False, embeds=None,
nonce=None, allowed_mentions=None, message_reference=None, components=None): nonce=None, allowed_mentions=None, message_reference=None, components=None):
r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id) r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id)
@@ -39,12 +49,59 @@ def send_message(self, channel_id, content, *, tts=False, embeds=None,
if message_reference: if message_reference:
payload['message_reference'] = message_reference payload['message_reference'] = message_reference
if components is not None: if components:
payload['components'] = components payload['components'] = components
return self.request(r, json=payload) return self.request(r, json=payload)
def send_files(
self,
channel_id, *,
files,
content=None, tts=False, embed=None, embeds=None, nonce=None, allowed_mentions=None, message_reference=None,
components=None
):
r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id)
form = []
payload = {'tts': tts}
if content:
payload['content'] = content
if embed:
payload['embed'] = embed
if embeds:
payload['embeds'] = embeds
if nonce:
payload['nonce'] = nonce
if allowed_mentions:
payload['allowed_mentions'] = allowed_mentions
if message_reference:
payload['message_reference'] = message_reference
if components:
payload['components'] = components
form.append({'name': 'payload_json', 'value': to_json(payload)})
if len(files) == 1:
file = files[0]
form.append({
'name': 'file',
'value': file.fp,
'filename': file.filename,
'content_type': 'application/octet-stream'
})
else:
for index, file in enumerate(files):
form.append({
'name': 'file%s' % index,
'value': file.fp,
'filename': file.filename,
'content_type': 'application/octet-stream'
})
return self.request(r, form=form, files=files)
def interaction_callback(self, interaction_id, interaction_token, callback_type, callback_data=None): def interaction_callback(self, interaction_id, interaction_token, callback_type, callback_data=None):
r = Route( r = Route(
'POST', 'POST',
@@ -62,7 +119,16 @@ def interaction_callback(self, interaction_id, interaction_token, callback_type,
return self.request(r, json=payload) return self.request(r, json=payload)
def edit_message(self, channel_id, message_id, components=None, **fields):
r = Route('PATCH', '/channels/{channel_id}/messages/{message_id}', channel_id=channel_id, message_id=message_id)
if components is not None:
fields['components'] = [comp.to_dict() for comp in components]
return self.request(r, json=fields)
HTTPClient.send_files = send_files
HTTPClient.send_message = send_message HTTPClient.send_message = send_message
HTTPClient.edit_message = edit_message
HTTPClient.interaction_callback = interaction_callback HTTPClient.interaction_callback = interaction_callback
@@ -114,7 +180,7 @@ async def send(self, content=None, *, tts=False, embed=None, embeds=None, file=N
try: try:
data = await state.http.send_files(channel.id, files=[file], allowed_mentions=allowed_mentions, data = await state.http.send_files(channel.id, files=[file], allowed_mentions=allowed_mentions,
content=content, tts=tts, embed=embed, nonce=nonce, content=content, tts=tts, embed=embed, nonce=nonce,
message_reference=reference) message_reference=reference, components=components)
finally: finally:
file.close() file.close()
@@ -126,8 +192,8 @@ async def send(self, content=None, *, tts=False, embed=None, embeds=None, file=N
try: try:
data = await state.http.send_files(channel.id, files=files, content=content, tts=tts, data = await state.http.send_files(channel.id, files=files, content=content, tts=tts,
embed=embed, nonce=nonce, allowed_mentions=allowed_mentions, embeds=embeds, nonce=nonce, allowed_mentions=allowed_mentions,
message_reference=reference) message_reference=reference, components=components)
finally: finally:
for f in files: for f in files:
f.close() f.close()
@@ -181,8 +247,11 @@ def parse_interaction_create(self, data):
'INTERACTION_CREATE recieved unhandled message component interaction type: %s', 'INTERACTION_CREATE recieved unhandled message component interaction type: %s',
data['data']['component_type'] data['data']['component_type']
) )
elif data['type'] == InteractionType.MODAL_SUBMIT:
interaction = ModalResponse(message, user, data, self)
else: else:
log.debug('INTERACTION_CREATE recieved unhandled interaction type: %s', data['type']) log.debug('INTERACTION_CREATE recieved unhandled interaction type: %s', data['type'])
log.debug(data)
interaction = None interaction = None
if interaction: if interaction: