feature (interactions): Basic button support.
This commit is contained in:
@@ -1,7 +1,5 @@
|
|||||||
from . import data # noqa
|
from . import data # noqa
|
||||||
|
|
||||||
from . import patches
|
|
||||||
|
|
||||||
from .module import module
|
from .module import module
|
||||||
from .lion import Lion
|
from .lion import Lion
|
||||||
from . import blacklists
|
from . import blacklists
|
||||||
|
|||||||
@@ -1,4 +1,8 @@
|
|||||||
from .logger import log, logger
|
from .logger import log, logger
|
||||||
|
|
||||||
|
from . import interactions
|
||||||
|
from . import patches
|
||||||
|
|
||||||
from .client import client
|
from .client import client
|
||||||
from .config import conf
|
from .config import conf
|
||||||
from .args import args
|
from .args import args
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from discord import Intents
|
from discord import Intents
|
||||||
from cmdClient.cmdClient import cmdClient
|
from cmdClient.cmdClient import cmdClient
|
||||||
|
|
||||||
|
from . import patches
|
||||||
from .config import conf
|
from .config import conf
|
||||||
from .sharding import shard_number, shard_count
|
from .sharding import shard_number, shard_count
|
||||||
from LionContext import LionContext
|
from LionContext import LionContext
|
||||||
|
|||||||
3
bot/meta/interactions/__init__.py
Normal file
3
bot/meta/interactions/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from . import enums
|
||||||
|
from .interactions import _component_interaction_factory, Interaction, ComponentInteraction
|
||||||
|
from .components import *
|
||||||
125
bot/meta/interactions/components.py
Normal file
125
bot/meta/interactions/components.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from .enums import ButtonStyle, InteractionType
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Notes:
|
||||||
|
When interaction is sent, add message info
|
||||||
|
Add wait_for to Button and SelectMenu
|
||||||
|
wait_for_interaction for generic
|
||||||
|
listen=True for the listenables, register with a listener
|
||||||
|
Need a deregister then as well
|
||||||
|
|
||||||
|
send(..., components=[ActionRow(Button(...))])
|
||||||
|
|
||||||
|
Automatically ack interaction? DEFERRED_UPDATE_MESSAGE
|
||||||
|
|
||||||
|
async def Button.wait_for(timeout=None, ack=False)
|
||||||
|
Blocks until the button is pressed. Returns a ButtonPress (Interaction).
|
||||||
|
def MessageComponent.add_callback(timeout)
|
||||||
|
Adds an async callback function to the Component.
|
||||||
|
|
||||||
|
Construct the response independent of the original component.
|
||||||
|
Original component has a convenience wait_for that runs wait_for_interaction(custom_id=self.custom_id)...
|
||||||
|
The callback? Just add a wait_for
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class MessageComponent:
|
||||||
|
_type = None
|
||||||
|
|
||||||
|
def __init_(self, *args, **kwargs):
|
||||||
|
self.message = None
|
||||||
|
|
||||||
|
def listen(self):
|
||||||
|
...
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class ActionRow(MessageComponent):
|
||||||
|
_type = 1
|
||||||
|
|
||||||
|
def __init__(self, *components):
|
||||||
|
self.components = components
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
data = {
|
||||||
|
"type": self._type,
|
||||||
|
"components": [comp.to_dict() for comp in self.components]
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class Button(MessageComponent):
|
||||||
|
_type = 2
|
||||||
|
|
||||||
|
def __init__(self, label, style=ButtonStyle.PRIMARY, custom_id=None, url=None, emoji=None, disabled=False):
|
||||||
|
if style == ButtonStyle.LINK:
|
||||||
|
if url is None:
|
||||||
|
raise ValueError("Link buttons must have a url")
|
||||||
|
custom_id = None
|
||||||
|
elif custom_id is None:
|
||||||
|
custom_id = uuid.uuid4()
|
||||||
|
|
||||||
|
self.label = label
|
||||||
|
self.style = style
|
||||||
|
self.custom_id = custom_id
|
||||||
|
self.url = url
|
||||||
|
|
||||||
|
self.emoji = emoji
|
||||||
|
self.disabled = disabled
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
data = {
|
||||||
|
"type": self._type,
|
||||||
|
"label": self.label,
|
||||||
|
"style": int(self.style)
|
||||||
|
}
|
||||||
|
if self.style == ButtonStyle.LINK:
|
||||||
|
data['url'] = self.url
|
||||||
|
else:
|
||||||
|
data['custom_id'] = self.custom_id
|
||||||
|
if self.emoji is not None:
|
||||||
|
# TODO: This only supports PartialEmoji, not Emoji
|
||||||
|
data['emoji'] = self.emoji.to_dict()
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def wait_for_press(self, timeout=None, check=None):
|
||||||
|
from meta import client
|
||||||
|
|
||||||
|
def _check(interaction):
|
||||||
|
valid = True
|
||||||
|
print(interaction.custom_id)
|
||||||
|
valid = valid and interaction.interaction_type == InteractionType.MESSAGE_COMPONENT
|
||||||
|
valid = valid and interaction.custom_id == self.custom_id
|
||||||
|
valid = valid and (check is None or check(interaction))
|
||||||
|
return valid
|
||||||
|
|
||||||
|
return await client.wait_for('interaction_create', timeout=timeout, check=_check)
|
||||||
|
|
||||||
|
def on_press(self, timeout=None, repeat=True, pass_args=(), pass_kwargs={}):
|
||||||
|
def wrapper(func):
|
||||||
|
async def wrapped():
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
button_press = await self.wait_for_press(timeout=timeout)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
break
|
||||||
|
asyncio.create_task(func(button_press, *pass_args, **pass_kwargs))
|
||||||
|
if not repeat:
|
||||||
|
break
|
||||||
|
future = asyncio.create_task(wrapped())
|
||||||
|
return future
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class SelectMenu(MessageComponent):
|
||||||
|
_type = 3
|
||||||
|
|
||||||
|
|
||||||
|
# MessageComponent listener
|
||||||
|
live_components = {}
|
||||||
28
bot/meta/interactions/enums.py
Normal file
28
bot/meta/interactions/enums.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
from enum import IntEnum
|
||||||
|
|
||||||
|
|
||||||
|
class InteractionType(IntEnum):
|
||||||
|
PING = 1
|
||||||
|
APPLICATION_COMMAND = 2
|
||||||
|
MESSAGE_COMPONENT = 3
|
||||||
|
APPLICATION_COMMAND_AUTOCOMPLETE = 4
|
||||||
|
MODAL_SUBMIT = 5
|
||||||
|
|
||||||
|
|
||||||
|
class ComponentType(IntEnum):
|
||||||
|
ACTIONROW = 1
|
||||||
|
BUTTON = 2
|
||||||
|
SELECTMENU = 3
|
||||||
|
TEXTINPUT = 4
|
||||||
|
|
||||||
|
|
||||||
|
class ButtonStyle(IntEnum):
|
||||||
|
PRIMARY = 1
|
||||||
|
SECONDARY = 2
|
||||||
|
SUCCESS = 3
|
||||||
|
DANGER = 4
|
||||||
|
LINK = 5
|
||||||
|
|
||||||
|
|
||||||
|
class InteractionCallback(IntEnum):
|
||||||
|
DEFERRED_UPDATE_MESSAGE = 6
|
||||||
52
bot/meta/interactions/interactions.py
Normal file
52
bot/meta/interactions/interactions.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
import asyncio
|
||||||
|
from .enums import ComponentType, InteractionType, InteractionCallback
|
||||||
|
|
||||||
|
|
||||||
|
class Interaction:
|
||||||
|
__slots__ = (
|
||||||
|
'id',
|
||||||
|
'token',
|
||||||
|
'_state'
|
||||||
|
)
|
||||||
|
|
||||||
|
async def callback_deferred(self):
|
||||||
|
return await self._state.http.interaction_callback(self.id, self.token, InteractionCallback.DEFERRED_UPDATE_MESSAGE)
|
||||||
|
|
||||||
|
def ack(self):
|
||||||
|
asyncio.create_task(self.callback_deferred())
|
||||||
|
|
||||||
|
|
||||||
|
class ComponentInteraction(Interaction):
|
||||||
|
interaction_type = InteractionType.MESSAGE_COMPONENT
|
||||||
|
# TODO: Slots
|
||||||
|
|
||||||
|
def __init__(self, message, user, data, state):
|
||||||
|
self.message = message
|
||||||
|
self.user = user
|
||||||
|
|
||||||
|
self._state = state
|
||||||
|
|
||||||
|
self._from_data(data)
|
||||||
|
|
||||||
|
def _from_data(self, data):
|
||||||
|
self.id = data['id']
|
||||||
|
self.token = data['token']
|
||||||
|
self.application_id = data['application_id']
|
||||||
|
|
||||||
|
component_data = data['data']
|
||||||
|
|
||||||
|
self.component_type = ComponentType(component_data['component_type'])
|
||||||
|
self.custom_id = component_data.get('custom_id', None)
|
||||||
|
|
||||||
|
|
||||||
|
class ButtonPress(ComponentInteraction):
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
|
|
||||||
|
def _component_interaction_factory(data):
|
||||||
|
component_type = data['data']['component_type']
|
||||||
|
|
||||||
|
if component_type == ComponentType.BUTTON:
|
||||||
|
return ButtonPress
|
||||||
|
else:
|
||||||
|
return None
|
||||||
@@ -1,14 +1,23 @@
|
|||||||
"""
|
"""
|
||||||
Temporary patches for the discord.py library to support new features of the discord API.
|
Temporary patches for the discord.py library to support new features of the discord API.
|
||||||
"""
|
"""
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from discord.state import ConnectionState
|
||||||
from discord.http import Route, HTTPClient
|
from discord.http import Route, HTTPClient
|
||||||
from discord.abc import Messageable
|
from discord.abc import Messageable
|
||||||
from discord.utils import InvalidArgument
|
from discord.utils import InvalidArgument, _get_as_snowflake
|
||||||
from discord import File, AllowedMentions
|
from discord import File, AllowedMentions, Member, User, Message
|
||||||
|
|
||||||
|
from .interactions import _component_interaction_factory
|
||||||
|
from .interactions.enums import InteractionType
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def send_message(self, channel_id, content, *, tts=False, embeds=None,
|
def send_message(self, channel_id, content, *, tts=False, embeds=None,
|
||||||
nonce=None, allowed_mentions=None, message_reference=None):
|
nonce=None, allowed_mentions=None, message_reference=None, components=None):
|
||||||
r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id)
|
r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id)
|
||||||
payload = {}
|
payload = {}
|
||||||
|
|
||||||
@@ -30,16 +39,37 @@ def send_message(self, channel_id, content, *, tts=False, embeds=None,
|
|||||||
if message_reference:
|
if message_reference:
|
||||||
payload['message_reference'] = message_reference
|
payload['message_reference'] = message_reference
|
||||||
|
|
||||||
|
if components is not None:
|
||||||
|
payload['components'] = components
|
||||||
|
|
||||||
|
return self.request(r, json=payload)
|
||||||
|
|
||||||
|
|
||||||
|
def interaction_callback(self, interaction_id, interaction_token, callback_type, callback_data=None):
|
||||||
|
r = Route(
|
||||||
|
'POST',
|
||||||
|
'/interactions/{interaction_id}/{interaction_token}/callback',
|
||||||
|
interaction_id=interaction_id,
|
||||||
|
interaction_token=interaction_token
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = {}
|
||||||
|
|
||||||
|
payload['type'] = int(callback_type)
|
||||||
|
if callback_data:
|
||||||
|
payload['data'] = callback_data
|
||||||
|
|
||||||
return self.request(r, json=payload)
|
return self.request(r, json=payload)
|
||||||
|
|
||||||
|
|
||||||
HTTPClient.send_message = send_message
|
HTTPClient.send_message = send_message
|
||||||
|
HTTPClient.interaction_callback = interaction_callback
|
||||||
|
|
||||||
|
|
||||||
async def send(self, content=None, *, tts=False, embed=None, embeds=None, file=None,
|
async def send(self, content=None, *, tts=False, embed=None, embeds=None, file=None,
|
||||||
files=None, delete_after=None, nonce=None,
|
files=None, delete_after=None, nonce=None,
|
||||||
allowed_mentions=None, reference=None,
|
allowed_mentions=None, reference=None,
|
||||||
mention_author=None):
|
mention_author=None, components=None):
|
||||||
|
|
||||||
channel = await self._get_channel()
|
channel = await self._get_channel()
|
||||||
state = self._state
|
state = self._state
|
||||||
@@ -53,6 +83,9 @@ async def send(self, content=None, *, tts=False, embed=None, embeds=None, file=N
|
|||||||
if embeds is not None:
|
if embeds is not None:
|
||||||
embeds = [embed.to_dict() for embed in embeds]
|
embeds = [embed.to_dict() for embed in embeds]
|
||||||
|
|
||||||
|
if components is not None:
|
||||||
|
components = [comp.to_dict() for comp in components]
|
||||||
|
|
||||||
if allowed_mentions is not None:
|
if allowed_mentions is not None:
|
||||||
if state.allowed_mentions is not None:
|
if state.allowed_mentions is not None:
|
||||||
allowed_mentions = state.allowed_mentions.merge(allowed_mentions).to_dict()
|
allowed_mentions = state.allowed_mentions.merge(allowed_mentions).to_dict()
|
||||||
@@ -101,7 +134,7 @@ async def send(self, content=None, *, tts=False, embed=None, embeds=None, file=N
|
|||||||
else:
|
else:
|
||||||
data = await state.http.send_message(channel.id, content, tts=tts, embeds=embeds,
|
data = await state.http.send_message(channel.id, content, tts=tts, embeds=embeds,
|
||||||
nonce=nonce, allowed_mentions=allowed_mentions,
|
nonce=nonce, allowed_mentions=allowed_mentions,
|
||||||
message_reference=reference)
|
message_reference=reference, components=components)
|
||||||
|
|
||||||
ret = state.create_message(channel=channel, data=data)
|
ret = state.create_message(channel=channel, data=data)
|
||||||
if delete_after is not None:
|
if delete_after is not None:
|
||||||
@@ -109,3 +142,51 @@ async def send(self, content=None, *, tts=False, embed=None, embeds=None, file=N
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
Messageable.send = send
|
Messageable.send = send
|
||||||
|
|
||||||
|
|
||||||
|
def parse_interaction_create(self, data):
|
||||||
|
self.dispatch('raw_interaction_create', data)
|
||||||
|
|
||||||
|
if (guild_id := data.get('guild_id', None)):
|
||||||
|
guild = self._get_guild(int(guild_id))
|
||||||
|
if guild is None:
|
||||||
|
log.debug('INTERACTION_CREATE referencing an unknown guild ID: %s. Discarding.', guild_id)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
guild = None
|
||||||
|
|
||||||
|
if (member_data := data.get('member', None)) is not None:
|
||||||
|
# Construct member
|
||||||
|
# TODO: Theoretical reliance on cached guild
|
||||||
|
user = Member(data=member_data, guild=guild, state=self)
|
||||||
|
else:
|
||||||
|
# Assume user
|
||||||
|
user = self.get_user(_get_as_snowflake(data['user'], 'id')) or User(data=data['user'], state=self)
|
||||||
|
|
||||||
|
message = self._get_message(_get_as_snowflake(data['message'], 'id'))
|
||||||
|
if not message:
|
||||||
|
message_data = data['message']
|
||||||
|
channel, _ = self._get_guild_channel(message_data)
|
||||||
|
message = Message(data=message_data, channel=channel, state=self)
|
||||||
|
if self._messages is not None:
|
||||||
|
self._messages.append(message)
|
||||||
|
|
||||||
|
interaction = None
|
||||||
|
if data['type'] == InteractionType.MESSAGE_COMPONENT:
|
||||||
|
interaction_class = _component_interaction_factory(data)
|
||||||
|
if interaction_class:
|
||||||
|
interaction = interaction_class(message, user, data, self)
|
||||||
|
else:
|
||||||
|
log.debug(
|
||||||
|
'INTERACTION_CREATE recieved unhandled message component interaction type: %s',
|
||||||
|
data['data']['component_type']
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
log.debug('INTERACTION_CREATE recieved unhandled interaction type: %s', data['type'])
|
||||||
|
interaction = None
|
||||||
|
|
||||||
|
if interaction:
|
||||||
|
self.dispatch('interaction_create', interaction)
|
||||||
|
|
||||||
|
|
||||||
|
ConnectionState.parse_interaction_create = parse_interaction_create
|
||||||
Reference in New Issue
Block a user