feature (interactions): Basic button support.

This commit is contained in:
2022-04-18 12:53:17 +03:00
parent f753271403
commit cc7c988007
8 changed files with 299 additions and 7 deletions

View File

@@ -1,7 +1,5 @@
from . import data # noqa
from . import patches
from .module import module
from .lion import Lion
from . import blacklists

View File

@@ -1,4 +1,8 @@
from .logger import log, logger
from . import interactions
from . import patches
from .client import client
from .config import conf
from .args import args

View File

@@ -1,6 +1,7 @@
from discord import Intents
from cmdClient.cmdClient import cmdClient
from . import patches
from .config import conf
from .sharding import shard_number, shard_count
from LionContext import LionContext

View File

@@ -0,0 +1,3 @@
from . import enums
from .interactions import _component_interaction_factory, Interaction, ComponentInteraction
from .components import *

View File

@@ -0,0 +1,125 @@
import asyncio
import uuid
from .enums import ButtonStyle, InteractionType
"""
Notes:
When interaction is sent, add message info
Add wait_for to Button and SelectMenu
wait_for_interaction for generic
listen=True for the listenables, register with a listener
Need a deregister then as well
send(..., components=[ActionRow(Button(...))])
Automatically ack interaction? DEFERRED_UPDATE_MESSAGE
async def Button.wait_for(timeout=None, ack=False)
Blocks until the button is pressed. Returns a ButtonPress (Interaction).
def MessageComponent.add_callback(timeout)
Adds an async callback function to the Component.
Construct the response independent of the original component.
Original component has a convenience wait_for that runs wait_for_interaction(custom_id=self.custom_id)...
The callback? Just add a wait_for
"""
class MessageComponent:
_type = None
def __init_(self, *args, **kwargs):
self.message = None
def listen(self):
...
def close(self):
...
class ActionRow(MessageComponent):
_type = 1
def __init__(self, *components):
self.components = components
def to_dict(self):
data = {
"type": self._type,
"components": [comp.to_dict() for comp in self.components]
}
return data
class Button(MessageComponent):
_type = 2
def __init__(self, label, style=ButtonStyle.PRIMARY, custom_id=None, url=None, emoji=None, disabled=False):
if style == ButtonStyle.LINK:
if url is None:
raise ValueError("Link buttons must have a url")
custom_id = None
elif custom_id is None:
custom_id = uuid.uuid4()
self.label = label
self.style = style
self.custom_id = custom_id
self.url = url
self.emoji = emoji
self.disabled = disabled
def to_dict(self):
data = {
"type": self._type,
"label": self.label,
"style": int(self.style)
}
if self.style == ButtonStyle.LINK:
data['url'] = self.url
else:
data['custom_id'] = self.custom_id
if self.emoji is not None:
# TODO: This only supports PartialEmoji, not Emoji
data['emoji'] = self.emoji.to_dict()
return data
async def wait_for_press(self, timeout=None, check=None):
from meta import client
def _check(interaction):
valid = True
print(interaction.custom_id)
valid = valid and interaction.interaction_type == InteractionType.MESSAGE_COMPONENT
valid = valid and interaction.custom_id == self.custom_id
valid = valid and (check is None or check(interaction))
return valid
return await client.wait_for('interaction_create', timeout=timeout, check=_check)
def on_press(self, timeout=None, repeat=True, pass_args=(), pass_kwargs={}):
def wrapper(func):
async def wrapped():
while True:
try:
button_press = await self.wait_for_press(timeout=timeout)
except asyncio.TimeoutError:
break
asyncio.create_task(func(button_press, *pass_args, **pass_kwargs))
if not repeat:
break
future = asyncio.create_task(wrapped())
return future
return wrapper
class SelectMenu(MessageComponent):
_type = 3
# MessageComponent listener
live_components = {}

View File

@@ -0,0 +1,28 @@
from enum import IntEnum
class InteractionType(IntEnum):
PING = 1
APPLICATION_COMMAND = 2
MESSAGE_COMPONENT = 3
APPLICATION_COMMAND_AUTOCOMPLETE = 4
MODAL_SUBMIT = 5
class ComponentType(IntEnum):
ACTIONROW = 1
BUTTON = 2
SELECTMENU = 3
TEXTINPUT = 4
class ButtonStyle(IntEnum):
PRIMARY = 1
SECONDARY = 2
SUCCESS = 3
DANGER = 4
LINK = 5
class InteractionCallback(IntEnum):
DEFERRED_UPDATE_MESSAGE = 6

View File

@@ -0,0 +1,52 @@
import asyncio
from .enums import ComponentType, InteractionType, InteractionCallback
class Interaction:
__slots__ = (
'id',
'token',
'_state'
)
async def callback_deferred(self):
return await self._state.http.interaction_callback(self.id, self.token, InteractionCallback.DEFERRED_UPDATE_MESSAGE)
def ack(self):
asyncio.create_task(self.callback_deferred())
class ComponentInteraction(Interaction):
interaction_type = InteractionType.MESSAGE_COMPONENT
# TODO: Slots
def __init__(self, message, user, data, state):
self.message = message
self.user = user
self._state = state
self._from_data(data)
def _from_data(self, data):
self.id = data['id']
self.token = data['token']
self.application_id = data['application_id']
component_data = data['data']
self.component_type = ComponentType(component_data['component_type'])
self.custom_id = component_data.get('custom_id', None)
class ButtonPress(ComponentInteraction):
__slots__ = ()
def _component_interaction_factory(data):
component_type = data['data']['component_type']
if component_type == ComponentType.BUTTON:
return ButtonPress
else:
return None

View File

@@ -1,14 +1,23 @@
"""
Temporary patches for the discord.py library to support new features of the discord API.
"""
import logging
from discord.state import ConnectionState
from discord.http import Route, HTTPClient
from discord.abc import Messageable
from discord.utils import InvalidArgument
from discord import File, AllowedMentions
from discord.utils import InvalidArgument, _get_as_snowflake
from discord import File, AllowedMentions, Member, User, Message
from .interactions import _component_interaction_factory
from .interactions.enums import InteractionType
log = logging.getLogger(__name__)
def send_message(self, channel_id, content, *, tts=False, embeds=None,
nonce=None, allowed_mentions=None, message_reference=None):
nonce=None, allowed_mentions=None, message_reference=None, components=None):
r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id)
payload = {}
@@ -30,16 +39,37 @@ def send_message(self, channel_id, content, *, tts=False, embeds=None,
if message_reference:
payload['message_reference'] = message_reference
if components is not None:
payload['components'] = components
return self.request(r, json=payload)
def interaction_callback(self, interaction_id, interaction_token, callback_type, callback_data=None):
r = Route(
'POST',
'/interactions/{interaction_id}/{interaction_token}/callback',
interaction_id=interaction_id,
interaction_token=interaction_token
)
payload = {}
payload['type'] = int(callback_type)
if callback_data:
payload['data'] = callback_data
return self.request(r, json=payload)
HTTPClient.send_message = send_message
HTTPClient.interaction_callback = interaction_callback
async def send(self, content=None, *, tts=False, embed=None, embeds=None, file=None,
files=None, delete_after=None, nonce=None,
allowed_mentions=None, reference=None,
mention_author=None):
mention_author=None, components=None):
channel = await self._get_channel()
state = self._state
@@ -53,6 +83,9 @@ async def send(self, content=None, *, tts=False, embed=None, embeds=None, file=N
if embeds is not None:
embeds = [embed.to_dict() for embed in embeds]
if components is not None:
components = [comp.to_dict() for comp in components]
if allowed_mentions is not None:
if state.allowed_mentions is not None:
allowed_mentions = state.allowed_mentions.merge(allowed_mentions).to_dict()
@@ -101,7 +134,7 @@ async def send(self, content=None, *, tts=False, embed=None, embeds=None, file=N
else:
data = await state.http.send_message(channel.id, content, tts=tts, embeds=embeds,
nonce=nonce, allowed_mentions=allowed_mentions,
message_reference=reference)
message_reference=reference, components=components)
ret = state.create_message(channel=channel, data=data)
if delete_after is not None:
@@ -109,3 +142,51 @@ async def send(self, content=None, *, tts=False, embed=None, embeds=None, file=N
return ret
Messageable.send = send
def parse_interaction_create(self, data):
self.dispatch('raw_interaction_create', data)
if (guild_id := data.get('guild_id', None)):
guild = self._get_guild(int(guild_id))
if guild is None:
log.debug('INTERACTION_CREATE referencing an unknown guild ID: %s. Discarding.', guild_id)
return
else:
guild = None
if (member_data := data.get('member', None)) is not None:
# Construct member
# TODO: Theoretical reliance on cached guild
user = Member(data=member_data, guild=guild, state=self)
else:
# Assume user
user = self.get_user(_get_as_snowflake(data['user'], 'id')) or User(data=data['user'], state=self)
message = self._get_message(_get_as_snowflake(data['message'], 'id'))
if not message:
message_data = data['message']
channel, _ = self._get_guild_channel(message_data)
message = Message(data=message_data, channel=channel, state=self)
if self._messages is not None:
self._messages.append(message)
interaction = None
if data['type'] == InteractionType.MESSAGE_COMPONENT:
interaction_class = _component_interaction_factory(data)
if interaction_class:
interaction = interaction_class(message, user, data, self)
else:
log.debug(
'INTERACTION_CREATE recieved unhandled message component interaction type: %s',
data['data']['component_type']
)
else:
log.debug('INTERACTION_CREATE recieved unhandled interaction type: %s', data['type'])
interaction = None
if interaction:
self.dispatch('interaction_create', interaction)
ConnectionState.parse_interaction_create = parse_interaction_create