Compare commits
15 Commits
11381f8e80
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| 407938bf43 | |||
| 5a435b6a5d | |||
| b74c7cda48 | |||
| daa9eb671b | |||
| cf363fd738 | |||
| 9a0d4090f5 | |||
| d05dc81667 | |||
| a4dd540f44 | |||
| 63e5dd1796 | |||
| daa370e09f | |||
| 58c0873987 | |||
| 5de3fd77bf | |||
| 873def8456 | |||
| c3c8baa4b2 | |||
| 850c5d7abb |
12
.gitmodules
vendored
Normal file
12
.gitmodules
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
[submodule "src/modules/voicefix"]
|
||||
path = src/modules/voicefix
|
||||
url = https://github.com/Intery/StudyLion-voicefix.git
|
||||
[submodule "src/modules/streamalerts"]
|
||||
path = src/modules/streamalerts
|
||||
url = https://github.com/Intery/StudyLion-streamalerts.git
|
||||
[submodule "src/modules/messagelogger"]
|
||||
path = src/modules/messagelogger
|
||||
url = https://git.thewisewolf.dev/HoloTech/discord-messagelogger-plugin.git
|
||||
[submodule "src/modules/voicelog"]
|
||||
path = src/modules/voicelog
|
||||
url = https://git.thewisewolf.dev/HoloTech/voicelog-plugin.git
|
||||
@@ -50,6 +50,59 @@ CREATE TABLE channel_links(
|
||||
PRIMARY KEY (linkid, channelid)
|
||||
);
|
||||
|
||||
-- }}}
|
||||
|
||||
-- Message Logging {{{
|
||||
|
||||
CREATE TABLE logging_guilds(
|
||||
guildid BIGINT PRIMARY KEY,
|
||||
webhook_url TEXT,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
_timestamp TIMESTAMPTZ DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE TRIGGER logging_guilds_timestamp BEFORE UPDATE ON logging_guilds
|
||||
FOR EACH ROW EXECUTE FUNCTION update_timestamp_column();
|
||||
|
||||
|
||||
CREATE TABLE logged_messages(
|
||||
messageid BIGINT PRIMARY KEY,
|
||||
guildid BIGINT NOT NULL REFERENCES logging_guilds ON DELETE CASCADE,
|
||||
channelid BIGINT NOT NULL,
|
||||
userid BIGINT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL,
|
||||
deleted_at TIMESTAMPTZ,
|
||||
_timestamp TIMESTAMPTZ DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE TRIGGER logged_messages_timestamp BEFORE UPDATE ON logged_messages
|
||||
FOR EACH ROW EXECUTE FUNCTION update_timestamp_column();
|
||||
|
||||
CREATE TABLE message_states(
|
||||
stateid INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY,
|
||||
messageid BIGINT NOT NULL REFERENCES logged_messages ON DELETE CASCADE,
|
||||
content TEXT NOT NULL,
|
||||
embeds_raw TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX message_states_messageid ON message_states (messageid);
|
||||
|
||||
CREATE TABLE logged_attachments(
|
||||
attachment_id BIGINT PRIMARY KEY,
|
||||
proxy_url TEXT NOT NULL,
|
||||
url TEXT NOT NULL,
|
||||
content_type TEXT NOT NULL,
|
||||
filesize INTEGER NOT NULL,
|
||||
filename TEXT NOT NULL,
|
||||
permalink TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE TABLE logged_messages_attachments(
|
||||
stateid INTEGER NOT NULL REFERENCES message_states(stateid) ON DELETE CASCADE,
|
||||
attachment_id BIGINT NOT NULL REFERENCES logged_attachments(attachment_id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX logged_messages_attachments_stateid ON logged_messages_attachments (stateid);
|
||||
|
||||
-- }}}
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
aiohttp==3.7.4.post0
|
||||
cachetools==4.2.2
|
||||
configparser==5.0.2
|
||||
aiohttp
|
||||
cachetools
|
||||
configparser
|
||||
discord.py [voice]
|
||||
iso8601==0.1.16
|
||||
iso8601
|
||||
psycopg[pool]
|
||||
pytz==2021.1
|
||||
pytz
|
||||
twitchAPI
|
||||
|
||||
3
src/babel/__init__.py
Normal file
3
src/babel/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .translator import SOURCE_LOCALE, LeoBabel, LocalBabel, LazyStr, ctx_locale, ctx_translator
|
||||
|
||||
babel = LocalBabel('babel')
|
||||
81
src/babel/enums.py
Normal file
81
src/babel/enums.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from enum import Enum
|
||||
from . import babel
|
||||
|
||||
_p = babel._p
|
||||
|
||||
|
||||
class LocaleMap(Enum):
|
||||
american_english = 'en-US'
|
||||
british_english = 'en-GB'
|
||||
bulgarian = 'bg'
|
||||
chinese = 'zh-CN'
|
||||
taiwan_chinese = 'zh-TW'
|
||||
croatian = 'hr'
|
||||
czech = 'cs'
|
||||
danish = 'da'
|
||||
dutch = 'nl'
|
||||
finnish = 'fi'
|
||||
french = 'fr'
|
||||
german = 'de'
|
||||
greek = 'el'
|
||||
hindi = 'hi'
|
||||
hungarian = 'hu'
|
||||
italian = 'it'
|
||||
japanese = 'ja'
|
||||
korean = 'ko'
|
||||
lithuanian = 'lt'
|
||||
norwegian = 'no'
|
||||
polish = 'pl'
|
||||
brazil_portuguese = 'pt-BR'
|
||||
romanian = 'ro'
|
||||
russian = 'ru'
|
||||
spain_spanish = 'es-ES'
|
||||
swedish = 'sv-SE'
|
||||
thai = 'th'
|
||||
turkish = 'tr'
|
||||
ukrainian = 'uk'
|
||||
vietnamese = 'vi'
|
||||
hebrew = 'he-IL'
|
||||
|
||||
|
||||
# Original Discord names
|
||||
locale_names = {
|
||||
'id': (_p('localenames|locale:id', "Indonesian"), "Bahasa Indonesia"),
|
||||
'da': (_p('localenames|locale:da', "Danish"), "Dansk"),
|
||||
'de': (_p('localenames|locale:de', "German"), "Deutsch"),
|
||||
'en-GB': (_p('localenames|locale:en-GB', "English, UK"), "English, UK"),
|
||||
'en-US': (_p('localenames|locale:en-US', "English, US"), "English, US"),
|
||||
'es-ES': (_p('localenames|locale:es-ES', "Spanish"), "Español"),
|
||||
'fr': (_p('localenames|locale:fr', "French"), "Français"),
|
||||
'hr': (_p('localenames|locale:hr', "Croatian"), "Hrvatski"),
|
||||
'it': (_p('localenames|locale:it', "Italian"), "Italiano"),
|
||||
'lt': (_p('localenames|locale:lt', "Lithuanian"), "Lietuviškai"),
|
||||
'hu': (_p('localenames|locale:hu', "Hungarian"), "Magyar"),
|
||||
'nl': (_p('localenames|locale:nl', "Dutch"), "Nederlands"),
|
||||
'no': (_p('localenames|locale:no', "Norwegian"), "Norsk"),
|
||||
'pl': (_p('localenames|locale:pl', "Polish"), "Polski"),
|
||||
'pt-BR': (_p('localenames|locale:pt-BR', "Portuguese, Brazilian"), "Português do Brasil"),
|
||||
'ro': (_p('localenames|locale:ro', "Romanian, Romania"), "Română"),
|
||||
'fi': (_p('localenames|locale:fi', "Finnish"), "Suomi"),
|
||||
'sv-SE': (_p('localenames|locale:sv-SE', "Swedish"), "Svenska"),
|
||||
'vi': (_p('localenames|locale:vi', "Vietnamese"), "Tiếng Việt"),
|
||||
'tr': (_p('localenames|locale:tr', "Turkish"), "Türkçe"),
|
||||
'cs': (_p('localenames|locale:cs', "Czech"), "Čeština"),
|
||||
'el': (_p('localenames|locale:el', "Greek"), "Ελληνικά"),
|
||||
'bg': (_p('localenames|locale:bg', "Bulgarian"), "български"),
|
||||
'ru': (_p('localenames|locale:ru', "Russian"), "Pусский"),
|
||||
'uk': (_p('localenames|locale:uk', "Ukrainian"), "Українська"),
|
||||
'hi': (_p('localenames|locale:hi', "Hindi"), "हिन्दी"),
|
||||
'th': (_p('localenames|locale:th', "Thai"), "ไทย"),
|
||||
'zh-CN': (_p('localenames|locale:zh-CN', "Chinese, China"), "中文"),
|
||||
'ja': (_p('localenames|locale:ja', "Japanese"), "日本語"),
|
||||
'zh-TW': (_p('localenames|locale:zh-TW', "Chinese, Taiwan"), "繁體中文"),
|
||||
'ko': (_p('localenames|locale:ko', "Korean"), "한국어"),
|
||||
}
|
||||
|
||||
# More names for languages not supported by Discord
|
||||
locale_names |= {
|
||||
'he': (_p('localenames|locale:he', "Hebrew"), "Hebrew"),
|
||||
'he-IL': (_p('localenames|locale:he-IL', "Hebrew"), "Hebrew"),
|
||||
'ceaser': (_p('localenames|locale:test', "Test Language"), "dfbtfs"),
|
||||
}
|
||||
108
src/babel/translator.py
Normal file
108
src/babel/translator.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from typing import Optional
|
||||
import logging
|
||||
from contextvars import ContextVar
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
|
||||
import gettext
|
||||
|
||||
from discord.app_commands import Translator, locale_str
|
||||
from discord.enums import Locale
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SOURCE_LOCALE = 'en_GB'
|
||||
ctx_locale: ContextVar[str] = ContextVar('locale', default=SOURCE_LOCALE)
|
||||
ctx_translator: ContextVar['LeoBabel'] = ContextVar('translator', default=None) # type: ignore
|
||||
|
||||
null = gettext.NullTranslations()
|
||||
|
||||
|
||||
class LeoBabel(Translator):
|
||||
def __init__(self):
|
||||
self.supported_locales = {loc.name for loc in Locale}
|
||||
self.supported_domains = {}
|
||||
self.translators = defaultdict(dict) # locale -> domain -> GNUTranslator
|
||||
|
||||
async def load(self):
|
||||
pass
|
||||
|
||||
async def unload(self):
|
||||
self.translators.clear()
|
||||
|
||||
def get_translator(self, locale: Optional[str], domain):
|
||||
return null
|
||||
|
||||
def t(self, lazystr, locale=None):
|
||||
return lazystr._translate_with(null)
|
||||
|
||||
async def translate(self, string: locale_str, locale: Locale, context):
|
||||
if not isinstance(string, LazyStr):
|
||||
return string
|
||||
else:
|
||||
return string.message
|
||||
|
||||
ctx_translator.set(LeoBabel())
|
||||
|
||||
class Method(Enum):
|
||||
GETTEXT = 'gettext'
|
||||
NGETTEXT = 'ngettext'
|
||||
PGETTEXT = 'pgettext'
|
||||
NPGETTEXT = 'npgettext'
|
||||
|
||||
|
||||
class LocalBabel:
|
||||
def __init__(self, domain):
|
||||
self.domain = domain
|
||||
|
||||
@property
|
||||
def methods(self):
|
||||
return (self._, self._n, self._p, self._np)
|
||||
|
||||
def _(self, message):
|
||||
return LazyStr(Method.GETTEXT, message, domain=self.domain)
|
||||
|
||||
def _n(self, singular, plural, n):
|
||||
return LazyStr(Method.NGETTEXT, singular, plural, n, domain=self.domain)
|
||||
|
||||
def _p(self, context, message):
|
||||
return LazyStr(Method.PGETTEXT, context, message, domain=self.domain)
|
||||
|
||||
def _np(self, context, singular, plural, n):
|
||||
return LazyStr(Method.NPGETTEXT, context, singular, plural, n, domain=self.domain)
|
||||
|
||||
|
||||
class LazyStr(locale_str):
|
||||
__slots__ = ('method', 'args', 'domain', 'locale')
|
||||
|
||||
def __init__(self, method, *args, locale=None, domain=None):
|
||||
self.method = method
|
||||
self.args = args
|
||||
self.domain = domain
|
||||
self.locale = locale
|
||||
|
||||
@property
|
||||
def message(self):
|
||||
return self._translate_with(null)
|
||||
|
||||
@property
|
||||
def extras(self):
|
||||
return {'locale': self.locale, 'domain': self.domain}
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
|
||||
def _translate_with(self, translator: gettext.GNUTranslations):
|
||||
method = getattr(translator, self.method.value)
|
||||
return method(*self.args)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}({self.method}, {self.args!r}, locale={self.locale}, domain={self.domain})'
|
||||
|
||||
def __eq__(self, obj: object) -> bool:
|
||||
return isinstance(obj, locale_str) and self.message == obj.message
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.args)
|
||||
20
src/babel/utils.py
Normal file
20
src/babel/utils.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from .translator import ctx_translator
|
||||
from . import babel
|
||||
|
||||
_, _p, _np = babel._, babel._p, babel._np
|
||||
|
||||
|
||||
MONTHS = _p(
|
||||
'utils|months',
|
||||
"January,February,March,April,May,June,July,August,September,October,November,December"
|
||||
)
|
||||
|
||||
SHORT_MONTHS = _p(
|
||||
'utils|short_months',
|
||||
"Jan,Feb,Mar,Apr,May,Jun,Jul,Aug,Sep,Oct,Nov,Dec"
|
||||
)
|
||||
|
||||
|
||||
def local_month(month, short=False):
|
||||
string = MONTHS if not short else SHORT_MONTHS
|
||||
return ctx_translator.get().t(string).split(',')[month-1]
|
||||
@@ -54,7 +54,7 @@ async def main():
|
||||
intents = discord.Intents.all()
|
||||
intents.members = True
|
||||
intents.message_content = True
|
||||
intents.presences = False
|
||||
intents.presences = True
|
||||
|
||||
async with db.open():
|
||||
version = await db.version()
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from babel import LocalBabel
|
||||
|
||||
babel = LocalBabel('core')
|
||||
|
||||
async def setup(bot):
|
||||
from .cog import CoreCog
|
||||
|
||||
227
src/core/setting_types.py
Normal file
227
src/core/setting_types.py
Normal file
@@ -0,0 +1,227 @@
|
||||
"""
|
||||
Additional abstract setting types useful for StudyLion settings.
|
||||
"""
|
||||
from typing import Optional
|
||||
import json
|
||||
import traceback
|
||||
|
||||
import discord
|
||||
from discord.enums import TextStyle
|
||||
|
||||
from settings.base import ParentID
|
||||
from settings.setting_types import IntegerSetting, StringSetting
|
||||
from meta import conf
|
||||
from meta.errors import UserInputError
|
||||
from babel.translator import ctx_translator
|
||||
from utils.lib import MessageArgs
|
||||
|
||||
from . import babel
|
||||
|
||||
_p = babel._p
|
||||
|
||||
|
||||
class MessageSetting(StringSetting):
|
||||
"""
|
||||
Typed Setting ABC representing a message sent to Discord.
|
||||
|
||||
Data is a json-formatted string dict with at least one of the fields 'content', 'embed', 'embeds'
|
||||
Value is the corresponding dictionary
|
||||
"""
|
||||
# TODO: Extend to support format keys
|
||||
|
||||
_accepts = _p(
|
||||
'settype:message|accepts',
|
||||
"JSON formatted raw message data"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def download_attachment(attached: discord.Attachment):
|
||||
"""
|
||||
Download a discord.Attachment with some basic filetype and file size validation.
|
||||
"""
|
||||
t = ctx_translator.get().t
|
||||
|
||||
error = None
|
||||
decoded = None
|
||||
if attached.content_type and not ('json' in attached.content_type):
|
||||
error = t(_p(
|
||||
'settype:message|download|error:not_json',
|
||||
"The attached message data is not a JSON file!"
|
||||
))
|
||||
elif attached.size > 10000:
|
||||
error = t(_p(
|
||||
'settype:message|download|error:size',
|
||||
"The attached message data is too large!"
|
||||
))
|
||||
else:
|
||||
content = await attached.read()
|
||||
try:
|
||||
decoded = content.decode('UTF-8')
|
||||
except UnicodeDecodeError:
|
||||
error = t(_p(
|
||||
'settype:message|download|error:decoding',
|
||||
"Could not decode the message data. Please ensure it is saved with the `UTF-8` encoding."
|
||||
))
|
||||
|
||||
if error is not None:
|
||||
raise UserInputError(error)
|
||||
else:
|
||||
return decoded
|
||||
|
||||
@classmethod
|
||||
def value_to_args(cls, parent_id: ParentID, value: dict, **kwargs) -> MessageArgs:
|
||||
if not value:
|
||||
return None
|
||||
|
||||
args = {}
|
||||
args['content'] = value.get('content', "")
|
||||
if 'embed' in value:
|
||||
embed = discord.Embed.from_dict(value['embed'])
|
||||
args['embed'] = embed
|
||||
if 'embeds' in value:
|
||||
embeds = []
|
||||
for embed_data in value['embeds']:
|
||||
embeds.append(discord.Embed.from_dict(embed_data))
|
||||
args['embeds'] = embeds
|
||||
return MessageArgs(**args)
|
||||
|
||||
@classmethod
|
||||
def _data_from_value(cls, parent_id: ParentID, value: Optional[dict], **kwargs):
|
||||
if value and any(value.get(key, None) for key in ('content', 'embed', 'embeds')):
|
||||
data = json.dumps(value)
|
||||
else:
|
||||
data = None
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def _data_to_value(cls, parent_id: ParentID, data: Optional[str], **kwargs):
|
||||
if data:
|
||||
value = json.loads(data)
|
||||
else:
|
||||
value = None
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
async def _parse_string(cls, parent_id: ParentID, string: str, **kwargs):
|
||||
"""
|
||||
Provided user string can be downright random.
|
||||
|
||||
If it isn't json-formatted, treat it as the content of the message.
|
||||
If it is, do basic checking on the length and embeds.
|
||||
"""
|
||||
string = string.strip()
|
||||
if not string or string.lower() == 'none':
|
||||
return None
|
||||
|
||||
t = ctx_translator.get().t
|
||||
|
||||
error_tip = t(_p(
|
||||
'settype:message|error_suffix',
|
||||
"You can view, test, and fix your embed using the online [embed builder]({link})."
|
||||
)).format(
|
||||
link="https://glitchii.github.io/embedbuilder/?editor=json"
|
||||
)
|
||||
|
||||
if string.startswith('{') and string.endswith('}'):
|
||||
# Assume the string is a json-formatted message dict
|
||||
try:
|
||||
value = json.loads(string)
|
||||
except json.JSONDecodeError as err:
|
||||
error = t(_p(
|
||||
'settype:message|error:invalid_json',
|
||||
"The provided message data was not a valid JSON document!\n"
|
||||
"`{error}`"
|
||||
)).format(error=str(err))
|
||||
raise UserInputError(error + '\n' + error_tip)
|
||||
|
||||
if not isinstance(value, dict) or not any(value.get(key, None) for key in ('content', 'embed', 'embeds')):
|
||||
error = t(_p(
|
||||
'settype:message|error:json_missing_keys',
|
||||
"Message data must be a JSON object with at least one of the following fields: "
|
||||
"`content`, `embed`, `embeds`"
|
||||
))
|
||||
raise UserInputError(error + '\n' + error_tip)
|
||||
|
||||
embed_data = value.get('embed', None)
|
||||
if not isinstance(embed_data, dict):
|
||||
error = t(_p(
|
||||
'settype:message|error:json_embed_type',
|
||||
"`embed` field must be a valid JSON object."
|
||||
))
|
||||
raise UserInputError(error + '\n' + error_tip)
|
||||
|
||||
embeds_data = value.get('embeds', [])
|
||||
if not isinstance(embeds_data, list):
|
||||
error = t(_p(
|
||||
'settype:message|error:json_embeds_type',
|
||||
"`embeds` field must be a list."
|
||||
))
|
||||
raise UserInputError(error + '\n' + error_tip)
|
||||
|
||||
if embed_data and embeds_data:
|
||||
error = t(_p(
|
||||
'settype:message|error:json_embed_embeds',
|
||||
"Message data cannot include both `embed` and `embeds`."
|
||||
))
|
||||
raise UserInputError(error + '\n' + error_tip)
|
||||
|
||||
content_data = value.get('content', "")
|
||||
if not isinstance(content_data, str):
|
||||
error = t(_p(
|
||||
'settype:message|error:json_content_type',
|
||||
"`content` field must be a string."
|
||||
))
|
||||
raise UserInputError(error + '\n' + error_tip)
|
||||
|
||||
# Validate embeds, which is the most likely place for something to go wrong
|
||||
embeds = [embed_data] if embed_data else embeds_data
|
||||
try:
|
||||
for embed in embeds:
|
||||
discord.Embed.from_dict(embed)
|
||||
except Exception as e:
|
||||
# from_dict may raise a range of possible exceptions.
|
||||
raw_error = ''.join(
|
||||
traceback.TracebackException.from_exception(e).format_exception_only()
|
||||
)
|
||||
error = t(_p(
|
||||
'ui:settype:message|error:embed_conversion',
|
||||
"Could not parse the message embed data.\n"
|
||||
"**Error:** `{exception}`"
|
||||
)).format(exception=raw_error)
|
||||
raise UserInputError(error + '\n' + error_tip)
|
||||
|
||||
# At this point, the message will at least successfully convert into MessageArgs
|
||||
# There are numerous ways it could still be invalid, e.g. invalid urls, or too-long fields
|
||||
# or the total message content being too long, or too many fields, etc
|
||||
# This will need to be caught in anything which displays a message parsed from user data.
|
||||
else:
|
||||
# Either the string is not json formatted, or the formatting is broken
|
||||
# Assume the string is a content message
|
||||
value = {
|
||||
'content': string
|
||||
}
|
||||
return json.dumps(value)
|
||||
|
||||
@classmethod
|
||||
def _format_data(cls, parent_id: ParentID, data: Optional[str], **kwargs):
|
||||
if not data:
|
||||
return None
|
||||
|
||||
value = cls._data_to_value(parent_id, data, **kwargs)
|
||||
content = value.get('content', "")
|
||||
if 'embed' in value or 'embeds' in value or len(content) > 100:
|
||||
t = ctx_translator.get().t
|
||||
formatted = t(_p(
|
||||
'settype:message|format:too_long',
|
||||
"Too long to display! See Preview."
|
||||
))
|
||||
else:
|
||||
formatted = content
|
||||
|
||||
return formatted
|
||||
|
||||
@property
|
||||
def input_field(self):
|
||||
field = super().input_field
|
||||
field.style = TextStyle.long
|
||||
return field
|
||||
@@ -3,7 +3,7 @@ from typing import Optional
|
||||
|
||||
from psycopg import AsyncCursor, sql
|
||||
from psycopg.abc import Query, Params
|
||||
from psycopg._encodings import pgconn_encoding
|
||||
from psycopg._encodings import conn_encoding
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -15,7 +15,7 @@ class AsyncLoggingCursor(AsyncCursor):
|
||||
elif isinstance(query, (sql.SQL, sql.Composed)):
|
||||
msg = query.as_string(self)
|
||||
elif isinstance(query, bytes):
|
||||
msg = query.decode(pgconn_encoding(self._conn.pgconn), 'replace')
|
||||
msg = query.decode(conn_encoding(self._conn.pgconn), 'replace')
|
||||
else:
|
||||
msg = repr(query)
|
||||
return msg
|
||||
|
||||
@@ -12,6 +12,7 @@ from aiohttp import ClientSession
|
||||
|
||||
from data import Database
|
||||
from utils.lib import tabulate
|
||||
from babel.translator import LeoBabel
|
||||
|
||||
from .config import Conf
|
||||
from .logger import logging_context, log_context, log_action_stack, log_wrap, set_logging_context
|
||||
@@ -43,6 +44,7 @@ class LionBot(Bot):
|
||||
self.shardname = shardname
|
||||
# self.appdata = appdata
|
||||
self.config = config
|
||||
self.translator = LeoBabel()
|
||||
|
||||
self.system_monitor = SystemMonitor()
|
||||
self.monitor = ComponentMonitor('LionBot', self._monitor_status)
|
||||
@@ -189,7 +191,7 @@ class LionBot(Bot):
|
||||
# TODO: Some of these could have more user-feedback
|
||||
logger.debug(f"Handling command error for {ctx}: {exception}")
|
||||
if isinstance(ctx.command, HybridCommand) and ctx.command.app_command:
|
||||
cmd_str = ctx.command.app_command.to_dict()
|
||||
cmd_str = ctx.command.app_command.to_dict(self.tree)
|
||||
else:
|
||||
cmd_str = str(ctx.command)
|
||||
try:
|
||||
|
||||
@@ -131,7 +131,7 @@ class LionTree(CommandTree):
|
||||
return
|
||||
|
||||
set_logging_context(action=f"Run {command.qualified_name}")
|
||||
logger.debug(f"Running command '{command.qualified_name}': {command.to_dict()}")
|
||||
logger.debug(f"Running command '{command.qualified_name}': {command.to_dict(self)}")
|
||||
try:
|
||||
await command._invoke_with_namespace(interaction, namespace)
|
||||
except AppCommandError as e:
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
this_package = 'modules'
|
||||
this_package = "modules"
|
||||
|
||||
active = [
|
||||
'.sysadmin',
|
||||
'.voicefix',
|
||||
]
|
||||
active = [".sysadmin", ".voicefix", ".messagelogger", ".voicelog"]
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
|
||||
1
src/modules/messagelogger
Submodule
1
src/modules/messagelogger
Submodule
Submodule src/modules/messagelogger added at 166e310f96
1
src/modules/voicefix
Submodule
1
src/modules/voicefix
Submodule
Submodule src/modules/voicefix added at 70d089f5de
@@ -1,7 +0,0 @@
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def setup(bot):
|
||||
from .cog import VoiceFixCog
|
||||
await bot.add_cog(VoiceFixCog(bot))
|
||||
@@ -1,449 +0,0 @@
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
import asyncio
|
||||
from cachetools import FIFOCache
|
||||
|
||||
import discord
|
||||
from discord.abc import GuildChannel
|
||||
from discord.ext import commands as cmds
|
||||
from discord import app_commands as appcmds
|
||||
|
||||
from meta import LionBot, LionCog, LionContext
|
||||
from meta.errors import ResponseTimedOut, SafeCancellation, UserInputError
|
||||
from utils.ui import Confirm
|
||||
|
||||
from . import logger
|
||||
from .data import LinkData
|
||||
|
||||
|
||||
async def prepare_attachments(attachments: list[discord.Attachment]):
|
||||
results = []
|
||||
for attach in attachments:
|
||||
try:
|
||||
as_file = await attach.to_file(spoiler=attach.is_spoiler())
|
||||
results.append(as_file)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
return results
|
||||
|
||||
|
||||
async def prepare_embeds(message: discord.Message):
|
||||
embeds = [embed for embed in message.embeds if embed.type == 'rich']
|
||||
if message.reference:
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.dark_gray(),
|
||||
description=f"Reply to {message.reference.jump_url}"
|
||||
)
|
||||
embeds.append(embed)
|
||||
return embeds
|
||||
|
||||
|
||||
|
||||
class VoiceFixCog(LionCog):
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
self.data = bot.db.load_registry(LinkData())
|
||||
|
||||
# Map of linkids to list of channelids
|
||||
self.link_channels = {}
|
||||
|
||||
# Map of channelids to linkids
|
||||
self.channel_links = {}
|
||||
|
||||
# Map of channelids to initialised discord.Webhook
|
||||
self.hooks = {}
|
||||
|
||||
# Map of messageid to list of (channelid, webhookmsg) pairs, for updates
|
||||
self.message_cache = FIFOCache(maxsize=200)
|
||||
# webhook msgid -> orig msgid
|
||||
self.wmessages = FIFOCache(maxsize=600)
|
||||
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def cog_load(self):
|
||||
await self.data.init()
|
||||
|
||||
await self.reload_links()
|
||||
|
||||
async def reload_links(self):
|
||||
records = await self.data.channel_links.select_where()
|
||||
channel_links = defaultdict(set)
|
||||
link_channels = defaultdict(set)
|
||||
|
||||
for record in records:
|
||||
linkid = record['linkid']
|
||||
channelid = record['channelid']
|
||||
|
||||
channel_links[channelid].add(linkid)
|
||||
link_channels[linkid].add(channelid)
|
||||
|
||||
channelids = list(channel_links.keys())
|
||||
if channelids:
|
||||
await self.data.LinkHook.fetch_where(channelid=channelids)
|
||||
for channelid in channelids:
|
||||
# Will hit cache, so don't need any more data queries
|
||||
await self.fetch_webhook_for(channelid)
|
||||
|
||||
self.channel_links = {cid: tuple(linkids) for cid, linkids in channel_links.items()}
|
||||
self.link_channels = {lid: tuple(cids) for lid, cids in link_channels.items()}
|
||||
|
||||
logger.info(
|
||||
f"Loaded '{len(link_channels)}' channel links with '{len(self.channel_links)}' linked channels."
|
||||
)
|
||||
|
||||
@LionCog.listener('on_message')
|
||||
async def on_message(self, message: discord.Message):
|
||||
# Don't need this because everything except explicit messages are webhooks now
|
||||
# if self.bot.user and (message.author.id == self.bot.user.id):
|
||||
# return
|
||||
if message.webhook_id:
|
||||
return
|
||||
|
||||
async with self.lock:
|
||||
sent = []
|
||||
linkids = self.channel_links.get(message.channel.id, ())
|
||||
if linkids:
|
||||
for linkid in linkids:
|
||||
for channelid in self.link_channels[linkid]:
|
||||
if channelid != message.channel.id:
|
||||
if message.attachments:
|
||||
files = await prepare_attachments(message.attachments)
|
||||
else:
|
||||
files = []
|
||||
|
||||
hook = self.hooks[channelid]
|
||||
avatar = message.author.avatar or message.author.default_avatar
|
||||
msg = await hook.send(
|
||||
content=message.content,
|
||||
wait=True,
|
||||
username=message.author.display_name,
|
||||
avatar_url=avatar.url,
|
||||
embeds=await prepare_embeds(message),
|
||||
files=files,
|
||||
allowed_mentions=discord.AllowedMentions.none()
|
||||
)
|
||||
sent.append((channelid, msg))
|
||||
self.wmessages[msg.id] = message.id
|
||||
if sent:
|
||||
# For easier lookup
|
||||
self.wmessages[message.id] = message.id
|
||||
sent.append((message.channel.id, message))
|
||||
|
||||
self.message_cache[message.id] = sent
|
||||
logger.info(f"Forwarded message {message.id}")
|
||||
|
||||
|
||||
@LionCog.listener('on_message_edit')
|
||||
async def on_message_edit(self, before, after):
|
||||
async with self.lock:
|
||||
cached_sent = self.message_cache.pop(before.id, ())
|
||||
new_sent = []
|
||||
for cid, msg in cached_sent:
|
||||
try:
|
||||
if msg.id != before.id:
|
||||
msg = await msg.edit(
|
||||
content=after.content,
|
||||
embeds=await prepare_embeds(after),
|
||||
)
|
||||
new_sent.append((cid, msg))
|
||||
except discord.NotFound:
|
||||
pass
|
||||
if new_sent:
|
||||
self.message_cache[after.id] = new_sent
|
||||
|
||||
@LionCog.listener('on_message_delete')
|
||||
async def on_message_delete(self, message):
|
||||
async with self.lock:
|
||||
origid = self.wmessages.get(message.id, None)
|
||||
if origid:
|
||||
cached_sent = self.message_cache.pop(origid, ())
|
||||
for _, msg in cached_sent:
|
||||
try:
|
||||
if msg.id != message.id:
|
||||
await msg.delete()
|
||||
except discord.NotFound:
|
||||
pass
|
||||
|
||||
@LionCog.listener('on_reaction_add')
|
||||
async def on_reaction_add(self, reaction: discord.Reaction, user: discord.User):
|
||||
async with self.lock:
|
||||
message = reaction.message
|
||||
emoji = reaction.emoji
|
||||
origid = self.wmessages.get(message.id, None)
|
||||
if origid and reaction.count == 1:
|
||||
cached_sent = self.message_cache.get(origid, ())
|
||||
for _, msg in cached_sent:
|
||||
# TODO: Would be better to have a Message and check the reactions
|
||||
try:
|
||||
if msg.id != message.id:
|
||||
await msg.add_reaction(emoji)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
async def fetch_webhook_for(self, channelid) -> discord.Webhook:
|
||||
hook = self.hooks.get(channelid, None)
|
||||
if hook is None:
|
||||
row = await self.data.LinkHook.fetch(channelid)
|
||||
if row is None:
|
||||
channel = self.bot.get_channel(channelid)
|
||||
if channel is None:
|
||||
raise ValueError("Cannot find channel to create hook.")
|
||||
hook = await channel.create_webhook(name="LabRat Channel Link")
|
||||
await self.data.LinkHook.create(
|
||||
channelid=channelid,
|
||||
webhookid=hook.id,
|
||||
token=hook.token,
|
||||
)
|
||||
else:
|
||||
hook = discord.Webhook.partial(row.webhookid, row.token, client=self.bot)
|
||||
self.hooks[channelid] = hook
|
||||
return hook
|
||||
|
||||
@cmds.hybrid_group(
|
||||
name='linker',
|
||||
description="Base command group for the channel linker"
|
||||
)
|
||||
@appcmds.default_permissions(manage_channels=True)
|
||||
async def linker_group(self, ctx: LionContext):
|
||||
...
|
||||
|
||||
@linker_group.command(
|
||||
name='link',
|
||||
description="Create a new link, or add a channel to an existing link."
|
||||
)
|
||||
@appcmds.describe(
|
||||
name="Name of the new or existing channel link.",
|
||||
channel1="First channel to add to the link.",
|
||||
channel2="Second channel to add to the link.",
|
||||
channel3="Third channel to add to the link.",
|
||||
channel4="Fourth channel to add to the link.",
|
||||
channel5="Fifth channel to add to the link.",
|
||||
channelid="Optionally add a channel by id (for e.g. cross-server links).",
|
||||
)
|
||||
async def linker_link(self, ctx: LionContext,
|
||||
name: str,
|
||||
channel1: Optional[discord.TextChannel | discord.VoiceChannel] = None,
|
||||
channel2: Optional[discord.TextChannel | discord.VoiceChannel] = None,
|
||||
channel3: Optional[discord.TextChannel | discord.VoiceChannel] = None,
|
||||
channel4: Optional[discord.TextChannel | discord.VoiceChannel] = None,
|
||||
channel5: Optional[discord.TextChannel | discord.VoiceChannel] = None,
|
||||
channelid: Optional[str] = None,
|
||||
):
|
||||
if not ctx.interaction:
|
||||
return
|
||||
await ctx.interaction.response.defer(thinking=True)
|
||||
|
||||
# Check if link 'name' already exists, create if not
|
||||
existing = await self.data.Link.fetch_where()
|
||||
link_row = next((row for row in existing if row.name.lower() == name.lower()), None)
|
||||
if link_row is None:
|
||||
# Create
|
||||
link_row = await self.data.Link.create(name=name)
|
||||
link_channels = set()
|
||||
created = True
|
||||
else:
|
||||
records = await self.data.channel_links.select_where(linkid=link_row.linkid)
|
||||
link_channels = {record['channelid'] for record in records}
|
||||
created = False
|
||||
|
||||
# Create webhooks and webhook rows on channels if required
|
||||
maybe_channels = [
|
||||
channel1, channel2, channel3, channel4, channel5,
|
||||
]
|
||||
if channelid and channelid.isdigit():
|
||||
channel = self.bot.get_channel(int(channelid))
|
||||
maybe_channels.append(channel)
|
||||
|
||||
channels = [channel for channel in maybe_channels if channel]
|
||||
for channel in channels:
|
||||
await self.fetch_webhook_for(channel.id)
|
||||
|
||||
# Insert or update the links
|
||||
for channel in channels:
|
||||
if channel.id not in link_channels:
|
||||
await self.data.channel_links.insert(linkid=link_row.linkid, channelid=channel.id)
|
||||
|
||||
await self.reload_links()
|
||||
|
||||
if created:
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.brand_green(),
|
||||
title="Link Created",
|
||||
description=(
|
||||
"Created the link **{name}** and linked channels:\n{channels}"
|
||||
).format(name=name, channels=', '.join(channel.mention for channel in channels))
|
||||
)
|
||||
else:
|
||||
channelids = self.link_channels[link_row.linkid]
|
||||
channelstr = ', '.join(f"<#{cid}>" for cid in channelids)
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.brand_green(),
|
||||
title="Channels Linked",
|
||||
description=(
|
||||
"Updated the link **{name}** to link the following channels:\n{channelstr}"
|
||||
).format(name=link_row.name, channelstr=channelstr)
|
||||
)
|
||||
await ctx.reply(embed=embed)
|
||||
|
||||
@linker_group.command(
|
||||
name='unlink',
|
||||
description="Destroy a link, or remove a channel from a link."
|
||||
)
|
||||
@appcmds.describe(
|
||||
name="Name of the link to destroy",
|
||||
channel="Channel to remove from the link.",
|
||||
)
|
||||
async def linker_unlink(self, ctx: LionContext,
|
||||
name: str, channel: Optional[GuildChannel] = None):
|
||||
if not ctx.interaction:
|
||||
return
|
||||
# Get the link, error if it doesn't exist
|
||||
existing = await self.data.Link.fetch_where()
|
||||
link_row = next((row for row in existing if row.name.lower() == name.lower()), None)
|
||||
if link_row is None:
|
||||
raise UserInputError(
|
||||
f"Link **{name}** doesn't exist!"
|
||||
)
|
||||
|
||||
link_channelids = self.link_channels.get(link_row.linkid, ())
|
||||
|
||||
if channel is not None:
|
||||
# If channel was given, remove channel from link and ack
|
||||
if channel.id not in link_channelids:
|
||||
raise UserInputError(
|
||||
f"{channel.mention} is not linked in **{link_row.name}**!"
|
||||
)
|
||||
await self.data.channel_links.delete_where(channelid=channel.id, linkid=link_row.linkid)
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.brand_green(),
|
||||
title="Channel Unlinked",
|
||||
description=f"{channel.mention} has been removed from **{link_row.name}**."
|
||||
)
|
||||
else:
|
||||
# Otherwise, confirm link destroy, delete link row, and ack
|
||||
channels = ', '.join(f"<#{cid}>" for cid in link_channelids)
|
||||
confirm = Confirm(
|
||||
f"Are you sure you want to remove the link **{link_row.name}**?\nLinked channels: {channels}",
|
||||
ctx.author.id,
|
||||
)
|
||||
confirm.embed.colour = discord.Colour.red()
|
||||
try:
|
||||
result = await confirm.ask(ctx.interaction)
|
||||
except ResponseTimedOut:
|
||||
result = False
|
||||
if not result:
|
||||
raise SafeCancellation
|
||||
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.brand_green(),
|
||||
title="Link removed",
|
||||
description=f"Link **{link_row.name}** removed, the following channels were unlinked:\n{channels}"
|
||||
)
|
||||
await link_row.delete()
|
||||
|
||||
await self.reload_links()
|
||||
await ctx.reply(embed=embed)
|
||||
|
||||
@linker_link.autocomplete('name')
|
||||
async def _acmpl_link_name(self, interaction: discord.Interaction, partial: str):
|
||||
"""
|
||||
Autocomplete an existing link.
|
||||
"""
|
||||
existing = await self.data.Link.fetch_where()
|
||||
names = [row.name for row in existing]
|
||||
matching = [row.name for row in existing if partial.lower() in row.name.lower()]
|
||||
if not matching:
|
||||
choice = appcmds.Choice(
|
||||
name=f"Create a new link '{partial}'",
|
||||
value=partial
|
||||
)
|
||||
choices = [choice]
|
||||
else:
|
||||
choices = [
|
||||
appcmds.Choice(
|
||||
name=f"Link {name}",
|
||||
value=name
|
||||
)
|
||||
for name in matching
|
||||
]
|
||||
return choices
|
||||
|
||||
@linker_unlink.autocomplete('name')
|
||||
async def _acmpl_unlink_name(self, interaction: discord.Interaction, partial: str):
|
||||
"""
|
||||
Autocomplete an existing link.
|
||||
"""
|
||||
existing = await self.data.Link.fetch_where()
|
||||
matching = [row.name for row in existing if partial.lower() in row.name.lower()]
|
||||
if not matching:
|
||||
choice = appcmds.Choice(
|
||||
name=f"No existing links matching '{partial}'",
|
||||
value=partial
|
||||
)
|
||||
choices = [choice]
|
||||
else:
|
||||
choices = [
|
||||
appcmds.Choice(
|
||||
name=f"Link {name}",
|
||||
value=name
|
||||
)
|
||||
for name in matching
|
||||
]
|
||||
return choices
|
||||
|
||||
@linker_group.command(
|
||||
name='links',
|
||||
description="Display the existing channel links."
|
||||
)
|
||||
async def linker_links(self, ctx: LionContext):
|
||||
if not ctx.interaction:
|
||||
return
|
||||
await ctx.interaction.response.defer(thinking=True)
|
||||
|
||||
links = await self.data.Link.fetch_where()
|
||||
|
||||
if not links:
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.light_grey(),
|
||||
title="No channel links have been set up!",
|
||||
description="Create a new link and add channels with {linker}".format(
|
||||
linker=self.bot.core.mention_cmd('linker link')
|
||||
)
|
||||
)
|
||||
else:
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.brand_green(),
|
||||
title=f"Channel Links in {ctx.guild.name}",
|
||||
)
|
||||
for link in links:
|
||||
channelids = self.link_channels.get(link.linkid, ())
|
||||
channelstr = ', '.join(f"<#{cid}>" for cid in channelids)
|
||||
embed.add_field(
|
||||
name=f"Link **{link.name}**",
|
||||
value=channelstr,
|
||||
inline=False
|
||||
)
|
||||
# TODO: May want paging if over 25 links....
|
||||
await ctx.reply(embed=embed)
|
||||
|
||||
@linker_group.command(
|
||||
name="webhook",
|
||||
description='Manually configure the webhook for a given channel.'
|
||||
)
|
||||
async def linker_webhook(self, ctx: LionContext, channel: discord.abc.GuildChannel, webhook: str):
|
||||
if not ctx.interaction:
|
||||
return
|
||||
|
||||
hook = discord.Webhook.from_url(webhook, client=self.bot)
|
||||
existing = await self.data.LionHook.fetch(channel.id)
|
||||
if existing:
|
||||
await existing.update(webhookid=hook.id, token=hook.token)
|
||||
else:
|
||||
await self.data.LinkHook.create(
|
||||
channelid=channel.id,
|
||||
webhookid=hook.id,
|
||||
token=hook.token,
|
||||
)
|
||||
self.hooks[channel.id] = hook
|
||||
await ctx.reply(f"Webhook for {channel.mention} updated!")
|
||||
@@ -1,39 +0,0 @@
|
||||
from data import Registry, RowModel, Table
|
||||
from data.columns import Integer, Bool, Timestamp, String
|
||||
|
||||
|
||||
class LinkData(Registry):
|
||||
class Link(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE links(
|
||||
linkid SERIAL PRIMARY KEY,
|
||||
name TEXT
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'links'
|
||||
_cache_ = {}
|
||||
|
||||
linkid = Integer(primary=True)
|
||||
name = String()
|
||||
|
||||
|
||||
channel_links = Table('channel_links')
|
||||
|
||||
class LinkHook(RowModel):
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE channel_webhooks(
|
||||
channelid BIGINT PRIMARY KEY,
|
||||
webhookid BIGINT NOT NULL,
|
||||
token TEXT NOT NULL
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'channel_webhooks'
|
||||
_cache_ = {}
|
||||
|
||||
channelid = Integer(primary=True)
|
||||
webhookid = Integer()
|
||||
token = String()
|
||||
1
src/modules/voicelog
Submodule
1
src/modules/voicelog
Submodule
Submodule src/modules/voicelog added at 175e96c7e2
7
src/settings/__init__.py
Normal file
7
src/settings/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from babel.translator import LocalBabel
|
||||
babel = LocalBabel('settings_base')
|
||||
|
||||
from .data import ModelData, ListData
|
||||
from .base import BaseSetting
|
||||
from .ui import SettingWidget, InteractiveSetting
|
||||
from .groups import SettingDotDict, SettingGroup, ModelSettings, ModelSetting
|
||||
166
src/settings/base.py
Normal file
166
src/settings/base.py
Normal file
@@ -0,0 +1,166 @@
|
||||
from typing import Generic, TypeVar, Type, Optional, overload
|
||||
|
||||
|
||||
"""
|
||||
Setting metclass?
|
||||
Parse setting docstring to generate default info?
|
||||
Or just put it in the decorator we are already using
|
||||
"""
|
||||
|
||||
|
||||
# Typing using Generic[parent_id_type, data_type, value_type]
|
||||
# value generic, could be Union[?, UNSET]
|
||||
ParentID = TypeVar('ParentID')
|
||||
SettingData = TypeVar('SettingData')
|
||||
SettingValue = TypeVar('SettingValue')
|
||||
|
||||
T = TypeVar('T', bound='BaseSetting')
|
||||
|
||||
|
||||
class BaseSetting(Generic[ParentID, SettingData, SettingValue]):
|
||||
"""
|
||||
Abstract base class describing a stored configuration setting.
|
||||
A setting consists of logic to load the setting from storage,
|
||||
present it in a readable form, understand user entered values,
|
||||
and write it again in storage.
|
||||
Additionally, the setting has attributes attached describing
|
||||
the setting in a user-friendly manner for display purposes.
|
||||
"""
|
||||
setting_id: str # Unique source identifier for the setting
|
||||
|
||||
_default: Optional[SettingData] = None # Default data value for the setting
|
||||
|
||||
def __init__(self, parent_id: ParentID, data: Optional[SettingData], **kwargs):
|
||||
self.parent_id = parent_id
|
||||
self._data = data
|
||||
self.kwargs = kwargs
|
||||
|
||||
# Instance generation
|
||||
@classmethod
|
||||
async def get(cls: Type[T], parent_id: ParentID, **kwargs) -> T:
|
||||
"""
|
||||
Return a setting instance initialised from the stored value, associated with the given parent id.
|
||||
"""
|
||||
data = await cls._reader(parent_id, **kwargs)
|
||||
return cls(parent_id, data, **kwargs)
|
||||
|
||||
# Main interface
|
||||
@property
|
||||
def data(self) -> Optional[SettingData]:
|
||||
"""
|
||||
Retrieves the current internal setting data if it is set, otherwise the default data
|
||||
"""
|
||||
return self._data if self._data is not None else self.default
|
||||
|
||||
@data.setter
|
||||
def data(self, new_data: Optional[SettingData]):
|
||||
"""
|
||||
Sets the internal raw data.
|
||||
Does not write the changes.
|
||||
"""
|
||||
self._data = new_data
|
||||
|
||||
@property
|
||||
def default(self) -> Optional[SettingData]:
|
||||
"""
|
||||
Retrieves the default value for this setting.
|
||||
Settings should override this if the default depends on the object id.
|
||||
"""
|
||||
return self._default
|
||||
|
||||
@property
|
||||
def value(self) -> SettingValue: # Actually optional *if* _default is None
|
||||
"""
|
||||
Context-aware object or objects associated with the setting.
|
||||
"""
|
||||
return self._data_to_value(self.parent_id, self.data) # type: ignore
|
||||
|
||||
@value.setter
|
||||
def value(self, new_value: Optional[SettingValue]):
|
||||
"""
|
||||
Setter which reads the discord-aware object and converts it to data.
|
||||
Does not write the new value.
|
||||
"""
|
||||
self._data = self._data_from_value(self.parent_id, new_value)
|
||||
|
||||
async def write(self, **kwargs) -> None:
|
||||
"""
|
||||
Write current data to the database.
|
||||
For settings which override this,
|
||||
ensure you handle deletion of values when internal data is None.
|
||||
"""
|
||||
await self._writer(self.parent_id, self._data, **kwargs)
|
||||
|
||||
# Raw converters
|
||||
@overload
|
||||
@classmethod
|
||||
def _data_from_value(cls: Type[T], parent_id: ParentID, value: SettingValue, **kwargs) -> SettingData:
|
||||
...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def _data_from_value(cls: Type[T], parent_id: ParentID, value: None, **kwargs) -> None:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def _data_from_value(
|
||||
cls: Type[T], parent_id: ParentID, value: Optional[SettingValue], **kwargs
|
||||
) -> Optional[SettingData]:
|
||||
"""
|
||||
Convert a high-level setting value to internal data.
|
||||
Must be overridden by the setting.
|
||||
Be aware of UNSET values, these should always pass through as None
|
||||
to provide an unsetting interface.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def _data_to_value(cls: Type[T], parent_id: ParentID, data: SettingData, **kwargs) -> SettingValue:
|
||||
...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def _data_to_value(cls: Type[T], parent_id: ParentID, data: None, **kwargs) -> None:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def _data_to_value(
|
||||
cls: Type[T], parent_id: ParentID, data: Optional[SettingData], **kwargs
|
||||
) -> Optional[SettingValue]:
|
||||
"""
|
||||
Convert internal data to high-level setting value.
|
||||
Must be overriden by the setting.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
# Database access
|
||||
@classmethod
|
||||
async def _reader(cls: Type[T], parent_id: ParentID, **kwargs) -> Optional[SettingData]:
|
||||
"""
|
||||
Retrieve the setting data associated with the given parent_id.
|
||||
May be None if the setting is not set.
|
||||
Must be overridden by the setting.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
async def _writer(cls: Type[T], parent_id: ParentID, data: Optional[SettingData], **kwargs) -> None:
|
||||
"""
|
||||
Write provided setting data to storage.
|
||||
Must be overridden by the setting unless the `write` method is overridden.
|
||||
If the data is None, the setting is UNSET and should be deleted.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
async def setup(cls, bot):
|
||||
"""
|
||||
Initialisation task to be executed during client initialisation.
|
||||
May be used for e.g. populating a cache or required client setup.
|
||||
|
||||
Main application must execute the initialisation task before the setting is used.
|
||||
Further, the task must always be executable, if the setting is loaded.
|
||||
Conditional initialisation should go in the relevant module's init tasks.
|
||||
"""
|
||||
return None
|
||||
233
src/settings/data.py
Normal file
233
src/settings/data.py
Normal file
@@ -0,0 +1,233 @@
|
||||
from typing import Type
|
||||
import json
|
||||
|
||||
from data import RowModel, Table, ORDER
|
||||
from meta.logger import log_wrap, set_logging_context
|
||||
|
||||
|
||||
class ModelData:
|
||||
"""
|
||||
Mixin for settings stored in a single row and column of a Model.
|
||||
Assumes that the parent_id is the identity key of the Model.
|
||||
|
||||
This does not create a reference to the Row.
|
||||
"""
|
||||
# Table storing the desired data
|
||||
_model: Type[RowModel]
|
||||
|
||||
# Column with the desired data
|
||||
_column: str
|
||||
|
||||
# Whether to create a row if not found
|
||||
_create_row = False
|
||||
|
||||
# High level data cache to use, leave as None to disable cache.
|
||||
_cache = None # Map[id -> value]
|
||||
|
||||
@classmethod
|
||||
def _read_from_row(cls, parent_id, row, **kwargs):
|
||||
data = row[cls._column]
|
||||
|
||||
if cls._cache is not None:
|
||||
cls._cache[parent_id] = data
|
||||
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
async def _reader(cls, parent_id, use_cache=True, **kwargs):
|
||||
"""
|
||||
Read in the requested column associated to the parent id.
|
||||
"""
|
||||
if cls._cache is not None and parent_id in cls._cache and use_cache:
|
||||
return cls._cache[parent_id]
|
||||
|
||||
model = cls._model
|
||||
if cls._create_row:
|
||||
row = await model.fetch_or_create(parent_id)
|
||||
else:
|
||||
row = await model.fetch(parent_id)
|
||||
data = row[cls._column] if row else None
|
||||
|
||||
if cls._cache is not None:
|
||||
cls._cache[parent_id] = data
|
||||
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
async def _writer(cls, parent_id, data, **kwargs):
|
||||
"""
|
||||
Write the provided entry to the table.
|
||||
This does *not* create the row if it does not exist.
|
||||
It only updates.
|
||||
"""
|
||||
# TODO: Better way of getting the key?
|
||||
# TODO: Transaction
|
||||
if not isinstance(parent_id, tuple):
|
||||
parent_id = (parent_id, )
|
||||
model = cls._model
|
||||
rows = await model.table.update_where(
|
||||
**model._dict_from_id(parent_id)
|
||||
).set(
|
||||
**{cls._column: data}
|
||||
)
|
||||
# If we didn't update any rows, create a new row
|
||||
if not rows:
|
||||
await model.fetch_or_create(**model._dict_from_id(parent_id), **{cls._column: data})
|
||||
|
||||
if cls._cache is not None:
|
||||
cls._cache[parent_id] = data
|
||||
|
||||
|
||||
class ListData:
|
||||
"""
|
||||
Mixin for list types implemented on a Table.
|
||||
Implements a reader and writer.
|
||||
This assumes the list is the only data stored in the table,
|
||||
and removes list entries by deleting rows.
|
||||
"""
|
||||
setting_id: str
|
||||
|
||||
# Table storing the setting data
|
||||
_table_interface: Table
|
||||
|
||||
# Name of the column storing the id
|
||||
_id_column: str
|
||||
|
||||
# Name of the column storing the data to read
|
||||
_data_column: str
|
||||
|
||||
# Name of column storing the order index to use, if any. Assumed to be Serial on writing.
|
||||
_order_column: str
|
||||
_order_type: ORDER = ORDER.ASC
|
||||
|
||||
# High level data cache to use, set to None to disable cache.
|
||||
_cache = None # Map[id -> value]
|
||||
|
||||
@classmethod
|
||||
@log_wrap(isolate=True)
|
||||
async def _reader(cls, parent_id, use_cache=True, **kwargs):
|
||||
"""
|
||||
Read in all entries associated to the given id.
|
||||
"""
|
||||
set_logging_context(action="Read cls.setting_id")
|
||||
if cls._cache is not None and parent_id in cls._cache and use_cache:
|
||||
return cls._cache[parent_id]
|
||||
|
||||
table = cls._table_interface # type: Table
|
||||
query = table.select_where(**{cls._id_column: parent_id}).select(cls._data_column)
|
||||
if cls._order_column:
|
||||
query.order_by(cls._order_column, direction=cls._order_type)
|
||||
|
||||
rows = await query
|
||||
data = [row[cls._data_column] for row in rows]
|
||||
|
||||
if cls._cache is not None:
|
||||
cls._cache[parent_id] = data
|
||||
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
@log_wrap(isolate=True)
|
||||
async def _writer(cls, id, data, add_only=False, remove_only=False, **kwargs):
|
||||
"""
|
||||
Write the provided list to storage.
|
||||
"""
|
||||
set_logging_context(action="Write cls.setting_id")
|
||||
table = cls._table_interface
|
||||
async with table.connector.connection() as conn:
|
||||
table.connector.conn = conn
|
||||
async with conn.transaction():
|
||||
# Handle None input as an empty list
|
||||
if data is None:
|
||||
data = []
|
||||
|
||||
current = await cls._reader(id, use_cache=False, **kwargs)
|
||||
if not cls._order_column and (add_only or remove_only):
|
||||
to_insert = [item for item in data if item not in current] if not remove_only else []
|
||||
to_remove = data if remove_only else (
|
||||
[item for item in current if item not in data] if not add_only else []
|
||||
)
|
||||
|
||||
# Handle required deletions
|
||||
if to_remove:
|
||||
params = {
|
||||
cls._id_column: id,
|
||||
cls._data_column: to_remove
|
||||
}
|
||||
await table.delete_where(**params)
|
||||
|
||||
# Handle required insertions
|
||||
if to_insert:
|
||||
columns = (cls._id_column, cls._data_column)
|
||||
values = [(id, value) for value in to_insert]
|
||||
await table.insert_many(columns, *values)
|
||||
|
||||
if cls._cache is not None:
|
||||
new_current = [item for item in current + to_insert if item not in to_remove]
|
||||
cls._cache[id] = new_current
|
||||
else:
|
||||
# Remove all and add all to preserve order
|
||||
delete_params = {cls._id_column: id}
|
||||
await table.delete_where(**delete_params)
|
||||
|
||||
if data:
|
||||
columns = (cls._id_column, cls._data_column)
|
||||
values = [(id, value) for value in data]
|
||||
await table.insert_many(columns, *values)
|
||||
|
||||
if cls._cache is not None:
|
||||
cls._cache[id] = data
|
||||
|
||||
|
||||
class KeyValueData:
|
||||
"""
|
||||
Mixin for settings implemented in a Key-Value table.
|
||||
The underlying table should have a Unique constraint on the `(_id_column, _key_column)` pair.
|
||||
"""
|
||||
_table_interface: Table
|
||||
|
||||
_id_column: str
|
||||
|
||||
_key_column: str
|
||||
|
||||
_value_column: str
|
||||
|
||||
_key: str
|
||||
|
||||
@classmethod
|
||||
async def _reader(cls, id, **kwargs):
|
||||
params = {
|
||||
cls._id_column: id,
|
||||
cls._key_column: cls._key
|
||||
}
|
||||
|
||||
row = await cls._table_interface.select_one_where(**params).select(cls._value_column)
|
||||
data = row[cls._value_column] if row else None
|
||||
|
||||
if data is not None:
|
||||
data = json.loads(data)
|
||||
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
async def _writer(cls, id, data, **kwargs):
|
||||
params = {
|
||||
cls._id_column: id,
|
||||
cls._key_column: cls._key
|
||||
}
|
||||
if data is not None:
|
||||
values = {
|
||||
cls._value_column: json.dumps(data)
|
||||
}
|
||||
rows = await cls._table_interface.update_where(**params).set(**values)
|
||||
if not rows:
|
||||
await cls._table_interface.insert_many(
|
||||
(cls._id_column, cls._key_column, cls._value_column),
|
||||
(id, cls._key, json.dumps(data))
|
||||
)
|
||||
else:
|
||||
await cls._table_interface.delete_where(**params)
|
||||
|
||||
|
||||
# class UserInputError(SafeCancellation):
|
||||
# pass
|
||||
204
src/settings/groups.py
Normal file
204
src/settings/groups.py
Normal file
@@ -0,0 +1,204 @@
|
||||
from typing import Generic, Type, TypeVar, Optional, overload
|
||||
|
||||
from data import RowModel
|
||||
|
||||
from .data import ModelData
|
||||
from .ui import InteractiveSetting
|
||||
from .base import BaseSetting
|
||||
|
||||
from utils.lib import tabulate
|
||||
|
||||
|
||||
T = TypeVar('T', bound=InteractiveSetting)
|
||||
|
||||
|
||||
class SettingDotDict(Generic[T], dict[str, Type[T]]):
|
||||
"""
|
||||
Dictionary structure allowing simple dot access to items.
|
||||
"""
|
||||
__getattr__ = dict.__getitem__ # type: ignore
|
||||
__setattr__ = dict.__setitem__ # type: ignore
|
||||
__delattr__ = dict.__delitem__ # type: ignore
|
||||
|
||||
|
||||
class SettingGroup:
|
||||
"""
|
||||
A SettingGroup is a collection of settings under one name.
|
||||
"""
|
||||
__initial_settings__: list[Type[InteractiveSetting]] = []
|
||||
|
||||
_title: Optional[str] = None
|
||||
_description: Optional[str] = None
|
||||
|
||||
def __init_subclass__(cls, title: Optional[str] = None):
|
||||
cls._title = title or cls._title
|
||||
cls._description = cls._description or cls.__doc__
|
||||
|
||||
settings: list[Type[InteractiveSetting]] = []
|
||||
for item in cls.__dict__.values():
|
||||
if isinstance(item, type) and issubclass(item, InteractiveSetting):
|
||||
settings.append(item)
|
||||
cls.__initial_settings__ = settings
|
||||
|
||||
def __init_settings__(self):
|
||||
settings = SettingDotDict()
|
||||
for setting in self.__initial_settings__:
|
||||
settings[setting.__name__] = setting
|
||||
return settings
|
||||
|
||||
def __init__(self, title=None, description=None) -> None:
|
||||
self.title: str = title or self._title or self.__class__.__name__
|
||||
self.description: str = description or self._description or ""
|
||||
self.settings: SettingDotDict[InteractiveSetting] = self.__init_settings__()
|
||||
|
||||
def attach(self, cls: Type[T], name: Optional[str] = None):
|
||||
name = name or cls.setting_id
|
||||
self.settings[name] = cls
|
||||
return cls
|
||||
|
||||
def detach(self, cls):
|
||||
return self.settings.pop(cls.__name__, None)
|
||||
|
||||
def update(self, smap):
|
||||
self.settings.update(smap.settings)
|
||||
|
||||
def reduce(self, *keys):
|
||||
for key in keys:
|
||||
self.settings.pop(key, None)
|
||||
return
|
||||
|
||||
async def make_setting_table(self, parent_id, **kwargs):
|
||||
"""
|
||||
Convenience method for generating a rendered setting table.
|
||||
"""
|
||||
rows = []
|
||||
for setting in self.settings.values():
|
||||
if not setting._virtual:
|
||||
set = await setting.get(parent_id, **kwargs)
|
||||
name = set.display_name
|
||||
value = str(set.formatted)
|
||||
rows.append((name, value, set.hover_desc))
|
||||
table_rows = tabulate(
|
||||
*rows,
|
||||
row_format="[`{invis}{key:<{pad}}{colon}`](https://lionbot.org \"{field[2]}\")\t{value}"
|
||||
)
|
||||
return '\n'.join(table_rows)
|
||||
|
||||
|
||||
class ModelSetting(ModelData, BaseSetting):
|
||||
...
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
"""
|
||||
A ModelConfig provides a central point of configuration for any object described by a single Model.
|
||||
|
||||
An instance of a ModelConfig represents configuration for a single object
|
||||
(given by a single row of the corresponding Model).
|
||||
|
||||
The ModelConfig also supports registration of non-model configuration,
|
||||
to support associated settings (e.g. list-settings) for the object.
|
||||
|
||||
This is an ABC, and must be subclassed for each object-type.
|
||||
"""
|
||||
settings: SettingDotDict
|
||||
_model_settings: set
|
||||
model: Type[RowModel]
|
||||
|
||||
def __init__(self, parent_id, row, **kwargs):
|
||||
self.parent_id = parent_id
|
||||
self.row = row
|
||||
self.kwargs = kwargs
|
||||
|
||||
@classmethod
|
||||
def register_setting(cls, setting_cls):
|
||||
"""
|
||||
Decorator to register a non-model setting as part of the object configuration.
|
||||
|
||||
The setting class may be re-accessed through the `settings` class attr.
|
||||
|
||||
Subclasses may provide alternative access pathways to key non-model settings.
|
||||
"""
|
||||
cls.settings[setting_cls.setting_id] = setting_cls
|
||||
return setting_cls
|
||||
|
||||
@classmethod
|
||||
def register_model_setting(cls, model_setting_cls):
|
||||
"""
|
||||
Decorator to register a model setting as part of the object configuration.
|
||||
|
||||
The setting class may be accessed through the `settings` class attr.
|
||||
|
||||
A fresh setting instance may also be retrieved (using cached data)
|
||||
through the `get` instance method.
|
||||
|
||||
Subclasses are recommended to provide model settings as properties
|
||||
for simplified access and type checking.
|
||||
"""
|
||||
cls._model_settings.add(model_setting_cls.setting_id)
|
||||
return cls.register_setting(model_setting_cls)
|
||||
|
||||
def get(self, setting_id):
|
||||
"""
|
||||
Retrieve a freshly initialised copy of the given model-setting.
|
||||
|
||||
The given `setting_id` must have been previously registered through `register_model_setting`.
|
||||
This uses cached data, and so is not guaranteed to be up-to-date.
|
||||
"""
|
||||
if setting_id not in self._model_settings:
|
||||
# TODO: Log
|
||||
raise ValueError
|
||||
setting_cls = self.settings[setting_id]
|
||||
data = setting_cls._read_from_row(self.parent_id, self.row, **self.kwargs)
|
||||
return setting_cls(self.parent_id, data, **self.kwargs)
|
||||
|
||||
|
||||
class ModelSettings:
|
||||
"""
|
||||
A ModelSettings instance aggregates multiple `ModelSetting` instances
|
||||
bound to the same parent id on a single Model.
|
||||
|
||||
This enables a single point of access
|
||||
for settings of a given Model,
|
||||
with support for caching or deriving as needed.
|
||||
|
||||
This is an abstract base class,
|
||||
and should be subclassed to define the contained settings.
|
||||
"""
|
||||
_settings: SettingDotDict = SettingDotDict()
|
||||
model: Type[RowModel]
|
||||
|
||||
def __init__(self, parent_id, row, **kwargs):
|
||||
self.parent_id = parent_id
|
||||
self.row = row
|
||||
self.kwargs = kwargs
|
||||
|
||||
@classmethod
|
||||
async def fetch(cls, *parent_id, **kwargs):
|
||||
"""
|
||||
Load an instance of this ModelSetting with the given parent_id
|
||||
and setting keyword arguments.
|
||||
"""
|
||||
row = await cls.model.fetch_or_create(*parent_id)
|
||||
return cls(parent_id, row, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def attach(self, setting_cls):
|
||||
"""
|
||||
Decorator to attach the given setting class to this modelsetting.
|
||||
"""
|
||||
# This violates the interface principle, use structured typing instead?
|
||||
if not (issubclass(setting_cls, BaseSetting) and issubclass(setting_cls, ModelData)):
|
||||
raise ValueError(
|
||||
f"The provided setting class must be `ModelSetting`, not {setting_cls.__class__.__name__}."
|
||||
)
|
||||
self._settings[setting_cls.setting_id] = setting_cls
|
||||
return setting_cls
|
||||
|
||||
def get(self, setting_id):
|
||||
setting_cls = self._settings.get(setting_id)
|
||||
data = setting_cls._read_from_row(self.parent_id, self.row, **self.kwargs)
|
||||
return setting_cls(self.parent_id, data, **self.kwargs)
|
||||
|
||||
def __getitem__(self, setting_id):
|
||||
return self.get(setting_id)
|
||||
13
src/settings/mock.py
Normal file
13
src/settings/mock.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import discord
|
||||
from discord import app_commands
|
||||
|
||||
|
||||
class LocalString:
|
||||
def __init__(self, string):
|
||||
self.string = string
|
||||
|
||||
def as_string(self):
|
||||
return self.string
|
||||
|
||||
|
||||
_ = LocalString
|
||||
1393
src/settings/setting_types.py
Normal file
1393
src/settings/setting_types.py
Normal file
File diff suppressed because it is too large
Load Diff
512
src/settings/ui.py
Normal file
512
src/settings/ui.py
Normal file
@@ -0,0 +1,512 @@
|
||||
from typing import Optional, Callable, Any, Dict, Coroutine, Generic, TypeVar, List
|
||||
import asyncio
|
||||
from contextvars import copy_context
|
||||
|
||||
import discord
|
||||
from discord import ui
|
||||
from discord.ui.button import ButtonStyle, Button, button
|
||||
from discord.ui.modal import Modal
|
||||
from discord.ui.text_input import TextInput
|
||||
from meta.errors import UserInputError
|
||||
|
||||
from utils.lib import tabulate, recover_context
|
||||
from utils.ui import FastModal
|
||||
from meta.config import conf
|
||||
from meta.context import ctx_bot
|
||||
from babel.translator import ctx_translator, LazyStr
|
||||
|
||||
from .base import BaseSetting, ParentID, SettingData, SettingValue
|
||||
from . import babel
|
||||
|
||||
_p = babel._p
|
||||
|
||||
|
||||
ST = TypeVar('ST', bound='InteractiveSetting')
|
||||
|
||||
|
||||
class SettingModal(FastModal):
|
||||
input_field: TextInput = TextInput(label="Edit Setting")
|
||||
|
||||
def update_field(self, new_field):
|
||||
self.remove_item(self.input_field)
|
||||
self.add_item(new_field)
|
||||
self.input_field = new_field
|
||||
|
||||
|
||||
class SettingWidget(Generic[ST], ui.View):
|
||||
# TODO: Permission restrictions and callback!
|
||||
# Context variables for permitted user(s)? Subclass ui.View with PermittedView?
|
||||
# Don't need to descend permissions to Modal
|
||||
# Maybe combine with timeout manager
|
||||
|
||||
def __init__(self, setting: ST, auto_write=True, **kwargs):
|
||||
self.setting = setting
|
||||
self.update_children()
|
||||
super().__init__(**kwargs)
|
||||
self.auto_write = auto_write
|
||||
|
||||
self._interaction: Optional[discord.Interaction] = None
|
||||
self._modal: Optional[SettingModal] = None
|
||||
self._exports: List[ui.Item] = self.make_exports()
|
||||
|
||||
self._context = copy_context()
|
||||
|
||||
def update_children(self):
|
||||
"""
|
||||
Method called before base View initialisation.
|
||||
Allows updating the children components (usually explicitly defined callbacks),
|
||||
before Item instantiation.
|
||||
"""
|
||||
pass
|
||||
|
||||
def order_children(self, *children):
|
||||
"""
|
||||
Helper method to set and order the children using bound methods.
|
||||
"""
|
||||
child_map = {child.__name__: child for child in self.__view_children_items__}
|
||||
self.__view_children_items__ = [child_map[child.__name__] for child in children]
|
||||
|
||||
def update_child(self, child, new_args):
|
||||
args = getattr(child, '__discord_ui_model_kwargs__')
|
||||
args |= new_args
|
||||
|
||||
def make_exports(self):
|
||||
"""
|
||||
Called post-instantiation to populate self._exports.
|
||||
"""
|
||||
return self.children
|
||||
|
||||
def refresh(self):
|
||||
"""
|
||||
Update widget components from current setting data, if applicable.
|
||||
E.g. to update the default entry in a select list after a choice has been made,
|
||||
or update button colours.
|
||||
This does not trigger a discord ui update,
|
||||
that is the responsibility of the interaction handler.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def show(self, interaction: discord.Interaction, key: Any = None, override=False, **kwargs):
|
||||
"""
|
||||
Complete standard setting widget UI flow for this setting.
|
||||
The SettingWidget components may be attached to other messages as needed,
|
||||
and they may be triggered individually,
|
||||
but this coroutine defines the standard interface.
|
||||
Intended for use by any interaction which wants to "open the setting".
|
||||
|
||||
Extra keyword arguments are passed directly to the interaction reply (for e.g. ephemeral).
|
||||
"""
|
||||
if key is None:
|
||||
# By default, only have one widget listener per interaction.
|
||||
key = ('widget', interaction.id)
|
||||
|
||||
# If there is already a widget listening on this key, respect override
|
||||
if self.setting.get_listener(key) and not override:
|
||||
# Refuse to spawn another widget
|
||||
return
|
||||
|
||||
async def update_callback(new_data):
|
||||
self.setting.data = new_data
|
||||
await interaction.edit_original_response(embed=self.setting.embed, view=self, **kwargs)
|
||||
|
||||
self.setting.register_callback(key)(update_callback)
|
||||
await interaction.response.send_message(embed=self.setting.embed, view=self, **kwargs)
|
||||
await self.wait()
|
||||
try:
|
||||
# Try and detach the view, since we aren't handling events anymore.
|
||||
await interaction.edit_original_response(view=None)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
self.setting.deregister_callback(key)
|
||||
|
||||
def attach(self, group_view: ui.View):
|
||||
"""
|
||||
Attach this setting widget to a view representing several settings.
|
||||
"""
|
||||
for item in self._exports:
|
||||
group_view.add_item(item)
|
||||
|
||||
@button(style=ButtonStyle.secondary, label="Edit", row=4)
|
||||
async def edit_button(self, interaction: discord.Interaction, button: ui.Button):
|
||||
"""
|
||||
Spawn a simple edit modal,
|
||||
populated with `setting.input_field`.
|
||||
"""
|
||||
recover_context(self._context)
|
||||
# Spawn the setting modal
|
||||
await interaction.response.send_modal(self.modal)
|
||||
|
||||
@button(style=ButtonStyle.danger, label="Reset", row=4)
|
||||
async def reset_button(self, interaction: discord.Interaction, button: Button):
|
||||
recover_context(self._context)
|
||||
await interaction.response.defer(thinking=True, ephemeral=True)
|
||||
await self.setting.interactive_set(None, interaction)
|
||||
|
||||
@property
|
||||
def modal(self) -> Modal:
|
||||
"""
|
||||
Build a Modal dialogue for updating the setting.
|
||||
Refreshes (and re-attaches) the input field each time this is called.
|
||||
"""
|
||||
if self._modal is not None:
|
||||
self._modal.update_field(self.setting.input_field)
|
||||
return self._modal
|
||||
|
||||
# TODO: Attach shared timeouts to the modal
|
||||
self._modal = modal = SettingModal(
|
||||
title=f"Edit {self.setting.display_name}",
|
||||
)
|
||||
modal.update_field(self.setting.input_field)
|
||||
|
||||
@modal.submit_callback()
|
||||
async def edit_submit(interaction: discord.Interaction):
|
||||
# TODO: Catch and handle UserInputError
|
||||
await interaction.response.defer(thinking=True, ephemeral=True)
|
||||
data = await self.setting._parse_string(self.setting.parent_id, modal.input_field.value)
|
||||
await self.setting.interactive_set(data, interaction)
|
||||
|
||||
return modal
|
||||
|
||||
|
||||
class InteractiveSetting(BaseSetting[ParentID, SettingData, SettingValue]):
|
||||
__slots__ = ('_widget',)
|
||||
|
||||
# Configuration interface descriptions
|
||||
_display_name: LazyStr # User readable name of the setting
|
||||
_desc: LazyStr # User readable brief description of the setting
|
||||
_long_desc: LazyStr # User readable long description of the setting
|
||||
_accepts: LazyStr # User readable description of the acceptable values
|
||||
_set_cmd: str = None
|
||||
_notset_str: LazyStr = _p('setting|formatted|notset', "Not Set")
|
||||
_virtual: bool = False # Whether the setting should be hidden from tables and dashboards
|
||||
_required: bool = False
|
||||
|
||||
Widget = SettingWidget
|
||||
|
||||
# A list of callback coroutines to call when the setting updates
|
||||
# This can be used globally to refresh state when the setting updates,
|
||||
# Or locallly to e.g. refresh an active widget.
|
||||
# The callbacks are called on write, so they may be bypassed by direct use of _writer!
|
||||
_listeners_: Dict[Any, Callable[[Optional[SettingData]], Coroutine[Any, Any, None]]] = {}
|
||||
|
||||
# Optional client event to dispatch when theis setting has been written
|
||||
# Event handlers should be of the form Callable[ParentID, SettingData]
|
||||
_event: Optional[str] = None
|
||||
|
||||
# Interaction ward that should be validated via interaction_check
|
||||
_write_ward: Optional[Callable[[discord.Interaction], Coroutine[Any, Any, bool]]] = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._widget: Optional[SettingWidget] = None
|
||||
|
||||
@property
|
||||
def long_desc(self):
|
||||
t = ctx_translator.get().t
|
||||
bot = ctx_bot.get()
|
||||
return t(self._long_desc).format(
|
||||
bot=bot,
|
||||
cmds=bot.core.mention_cache
|
||||
)
|
||||
|
||||
@property
|
||||
def display_name(self):
|
||||
t = ctx_translator.get().t
|
||||
return t(self._display_name)
|
||||
|
||||
@property
|
||||
def desc(self):
|
||||
t = ctx_translator.get().t
|
||||
return t(self._desc)
|
||||
|
||||
@property
|
||||
def accepts(self):
|
||||
t = ctx_translator.get().t
|
||||
return t(self._accepts)
|
||||
|
||||
async def write(self, **kwargs) -> None:
|
||||
await super().write(**kwargs)
|
||||
self.dispatch_update()
|
||||
for listener in self._listeners_.values():
|
||||
asyncio.create_task(listener(self.data))
|
||||
|
||||
def dispatch_update(self):
|
||||
"""
|
||||
Dispatch a client event along `self._event`, if set.
|
||||
|
||||
Override to modify the target event handler arguments.
|
||||
By default, event handlers should be of the form:
|
||||
Callable[[ParentID, SettingData], Coroutine[Any, Any, None]]
|
||||
"""
|
||||
if self._event is not None and (bot := ctx_bot.get()) is not None:
|
||||
bot.dispatch(self._event, self.parent_id, self)
|
||||
|
||||
def get_listener(self, key):
|
||||
return self._listeners_.get(key, None)
|
||||
|
||||
@classmethod
|
||||
def register_callback(cls, name=None):
|
||||
def wrapped(coro):
|
||||
cls._listeners_[name or coro.__name__] = coro
|
||||
return coro
|
||||
return wrapped
|
||||
|
||||
@classmethod
|
||||
def deregister_callback(cls, name):
|
||||
cls._listeners_.pop(name, None)
|
||||
|
||||
@property
|
||||
def update_message(self):
|
||||
"""
|
||||
Response message sent when the setting has successfully been updated.
|
||||
Should generally be one line.
|
||||
"""
|
||||
if self.data is None:
|
||||
return "Setting reset!"
|
||||
else:
|
||||
return f"Setting Updated! New value: {self.formatted}"
|
||||
|
||||
@property
|
||||
def hover_desc(self):
|
||||
"""
|
||||
This no longer works since Discord changed the hover rules.
|
||||
|
||||
return '\n'.join((
|
||||
self.display_name,
|
||||
'=' * len(self.display_name),
|
||||
self.desc,
|
||||
f"\nAccepts: {self.accepts}"
|
||||
))
|
||||
"""
|
||||
return self.desc
|
||||
|
||||
async def update_response(self, interaction: discord.Interaction, message: Optional[str] = None, **kwargs):
|
||||
"""
|
||||
Respond to an interaction which triggered a setting update.
|
||||
Usually just wraps `update_message` in an embed and sends it back.
|
||||
Passes any extra `kwargs` to the message creation method.
|
||||
"""
|
||||
embed = discord.Embed(
|
||||
description=f"{str(conf.emojis.tick)} {message or self.update_message}",
|
||||
colour=discord.Color.green()
|
||||
)
|
||||
if interaction.response.is_done():
|
||||
await interaction.edit_original_response(embed=embed, **kwargs)
|
||||
else:
|
||||
await interaction.response.send_message(embed=embed, **kwargs)
|
||||
|
||||
async def interactive_set(self, new_data: Optional[SettingData], interaction: discord.Interaction, **kwargs):
|
||||
self.data = new_data
|
||||
await self.write()
|
||||
await self.update_response(interaction, **kwargs)
|
||||
|
||||
async def format_in(self, bot, **kwargs):
|
||||
"""
|
||||
Formatted version of the setting given an asynchronous context with client.
|
||||
"""
|
||||
return self.formatted
|
||||
|
||||
@property
|
||||
def embed_field(self):
|
||||
"""
|
||||
Returns a {name, value} pair for use in an Embed field.
|
||||
"""
|
||||
name = self.display_name
|
||||
value = f"{self.long_desc}\n{self.desc_table}"
|
||||
if len(value) > 1024:
|
||||
t = ctx_translator.get().t
|
||||
desc_table = '\n'.join(
|
||||
tabulate(
|
||||
*self._desc_table(
|
||||
show_value=t(_p(
|
||||
'setting|embed_field|too_long',
|
||||
"Too long to display here!"
|
||||
))
|
||||
)
|
||||
)
|
||||
)
|
||||
value = f"{self.long_desc}\n{desc_table}"
|
||||
if len(value) > 1024:
|
||||
# Forcibly trim
|
||||
value = value[:1020] + '...'
|
||||
return {'name': name, 'value': value}
|
||||
|
||||
@property
|
||||
def set_str(self):
|
||||
if self._set_cmd is not None:
|
||||
bot = ctx_bot.get()
|
||||
if bot:
|
||||
return bot.core.mention_cmd(self._set_cmd)
|
||||
else:
|
||||
return f"`/{self._set_cmd}`"
|
||||
|
||||
@property
|
||||
def notset_str(self):
|
||||
t = ctx_translator.get().t
|
||||
return t(self._notset_str)
|
||||
|
||||
@property
|
||||
def embed(self):
|
||||
"""
|
||||
Returns a full embed describing this setting.
|
||||
"""
|
||||
t = ctx_translator.get().t
|
||||
embed = discord.Embed(
|
||||
title=t(_p(
|
||||
'setting|summary_embed|title',
|
||||
"Configuration options for `{name}`"
|
||||
)).format(name=self.display_name),
|
||||
)
|
||||
embed.description = "{}\n{}".format(self.long_desc.format(self=self), self.desc_table)
|
||||
return embed
|
||||
|
||||
def _desc_table(self, show_value: Optional[str] = None) -> list[tuple[str, str]]:
|
||||
t = ctx_translator.get().t
|
||||
lines = []
|
||||
|
||||
# Currently line
|
||||
lines.append((
|
||||
t(_p('setting|summary_table|field:currently|key', "Currently")),
|
||||
show_value or (self.formatted or self.notset_str)
|
||||
))
|
||||
|
||||
# Default line
|
||||
if (default := self.default) is not None:
|
||||
lines.append((
|
||||
t(_p('setting|summary_table|field:default|key', "By Default")),
|
||||
self._format_data(self.parent_id, default) or 'None'
|
||||
))
|
||||
|
||||
# Set using line
|
||||
if (set_str := self.set_str) is not None:
|
||||
lines.append((
|
||||
t(_p('setting|summary_table|field:set|key', "Set Using")),
|
||||
set_str
|
||||
))
|
||||
return lines
|
||||
|
||||
@property
|
||||
def desc_table(self) -> str:
|
||||
return '\n'.join(tabulate(*self._desc_table()))
|
||||
|
||||
@property
|
||||
def input_field(self) -> TextInput:
|
||||
"""
|
||||
TextInput field used for string-based setting modification.
|
||||
May be added to external modal for grouped setting editing.
|
||||
This property is not persistent, and creates a new field each time.
|
||||
"""
|
||||
return TextInput(
|
||||
label=self.display_name,
|
||||
placeholder=self.accepts,
|
||||
default=self.input_formatted[:4000] if self.input_formatted else None,
|
||||
required=self._required
|
||||
)
|
||||
|
||||
@property
|
||||
def widget(self):
|
||||
"""
|
||||
Returns the Discord UI View associated with the current setting.
|
||||
"""
|
||||
if self._widget is None:
|
||||
self._widget = self.Widget(self)
|
||||
return self._widget
|
||||
|
||||
@classmethod
|
||||
def set_widget(cls, WidgetCls):
|
||||
"""
|
||||
Convenience decorator to create the widget class for this setting.
|
||||
"""
|
||||
cls.Widget = WidgetCls
|
||||
return WidgetCls
|
||||
|
||||
@property
|
||||
def formatted(self):
|
||||
"""
|
||||
Default user-readable form of the setting.
|
||||
Should be a short single line.
|
||||
"""
|
||||
return self._format_data(self.parent_id, self.data, **self.kwargs) or self.notset_str
|
||||
|
||||
@property
|
||||
def input_formatted(self) -> str:
|
||||
"""
|
||||
Format the current value as a default value for an input field.
|
||||
Returned string must be acceptable through parse_string.
|
||||
Does not take into account defaults.
|
||||
"""
|
||||
if self._data is not None:
|
||||
return str(self._data)
|
||||
else:
|
||||
return ""
|
||||
|
||||
@property
|
||||
def summary(self):
|
||||
"""
|
||||
Formatted summary of the data.
|
||||
May be implemented in `_format_data(..., summary=True, ...)` or overidden.
|
||||
"""
|
||||
return self._format_data(self.parent_id, self.data, summary=True, **self.kwargs)
|
||||
|
||||
@classmethod
|
||||
async def from_string(cls, parent_id, userstr: str, **kwargs):
|
||||
"""
|
||||
Return a setting instance initialised from a parsed user string.
|
||||
"""
|
||||
data = await cls._parse_string(parent_id, userstr, **kwargs)
|
||||
return cls(parent_id, data, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def from_value(cls, parent_id, value, **kwargs):
|
||||
await cls._check_value(parent_id, value, **kwargs)
|
||||
data = cls._data_from_value(parent_id, value, **kwargs)
|
||||
return cls(parent_id, data, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def _parse_string(cls, parent_id, string: str, **kwargs) -> Optional[SettingData]:
|
||||
"""
|
||||
Parse user provided string (usually from a TextInput) into raw setting data.
|
||||
Must be overriden by the setting if the setting is user-configurable.
|
||||
Returns None if the setting was unset.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _format_data(cls, parent_id, data, **kwargs):
|
||||
"""
|
||||
Convert raw setting data into a formatted user-readable string,
|
||||
representing the current value.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
async def _check_value(cls, parent_id, value, **kwargs):
|
||||
"""
|
||||
Check the provided value is valid.
|
||||
|
||||
Many setting update methods now provide Discord objects instead of raw data or user strings.
|
||||
This method may be used for value-checking such a value.
|
||||
|
||||
Raises UserInputError if the value fails validation.
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
async def interaction_check(cls, parent_id, interaction: discord.Interaction, **kwargs):
|
||||
if cls._write_ward is not None and not await cls._write_ward(interaction):
|
||||
# TODO: Combine the check system so we can do customised errors here
|
||||
t = ctx_translator.get().t
|
||||
raise UserInputError(t(_p(
|
||||
'setting|interaction_check|error',
|
||||
"You do not have sufficient permissions to do this!"
|
||||
)))
|
||||
|
||||
|
||||
"""
|
||||
command callback for set command?
|
||||
autocomplete for set command?
|
||||
|
||||
Might be better in a ConfigSetting subclass.
|
||||
But also mix into the base setting types.
|
||||
"""
|
||||
@@ -845,3 +845,35 @@ def write_records(records: list[dict[str, Any]], stream: StringIO):
|
||||
for record in records:
|
||||
stream.write(','.join(map(str, record.values())))
|
||||
stream.write('\n')
|
||||
|
||||
|
||||
parse_dur_exps = [
|
||||
(
|
||||
r"(?P<value>\d+)\s*(?:(d)|(day))",
|
||||
60 * 60 * 24,
|
||||
),
|
||||
(
|
||||
r"(?P<value>\d+)\s*(?:(h)|(hour))",
|
||||
60 * 60
|
||||
),
|
||||
(
|
||||
r"(?P<value>\d+)\s*(?:(m)|(min))",
|
||||
60
|
||||
),
|
||||
(
|
||||
r"(?P<value>\d+)\s*(?:(s)|(sec))",
|
||||
1
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def parse_duration(string: str) -> Optional[int]:
|
||||
seconds = 0
|
||||
found = False
|
||||
for expr, multiplier in parse_dur_exps:
|
||||
match = re.search(expr, string, flags=re.IGNORECASE)
|
||||
if match:
|
||||
found = True
|
||||
seconds += int(match.group('value')) * multiplier
|
||||
|
||||
return seconds if found else None
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from babel.translator import LocalBabel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
util_babel = LocalBabel('utils')
|
||||
|
||||
from .hooked import *
|
||||
from .leo import *
|
||||
from .micros import *
|
||||
from .msgeditor import *
|
||||
|
||||
1070
src/utils/ui/msgeditor.py
Normal file
1070
src/utils/ui/msgeditor.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user