generated from HoloTech/discord-bot-template
Initial commit
This commit is contained in:
149
.gitignore
vendored
Normal file
149
.gitignore
vendored
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
src/modules/test/*
|
||||||
|
|
||||||
|
pending-rewrite/
|
||||||
|
logs/*
|
||||||
|
notes/*
|
||||||
|
tmp/*
|
||||||
|
output/*
|
||||||
|
locales/domains
|
||||||
|
|
||||||
|
.idea/*
|
||||||
|
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
9
.gitmodules
vendored
Normal file
9
.gitmodules
vendored
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
[submodule "src/modules/voicefix"]
|
||||||
|
path = src/modules/voicefix
|
||||||
|
url = git@github.com:Intery/StudyLion-voicefix.git
|
||||||
|
[submodule "src/modules/streamalerts"]
|
||||||
|
path = src/modules/streamalerts
|
||||||
|
url = git@github.com:Intery/StudyLion-streamalerts.git
|
||||||
|
[submodule "src/data"]
|
||||||
|
path = src/data
|
||||||
|
url = https://git.thewisewolf.dev/HoloTech/psqlmapper.git
|
||||||
1
config/.gitignore
vendored
Normal file
1
config/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
*.conf
|
||||||
0
data/.gitignore
vendored
Normal file
0
data/.gitignore
vendored
Normal file
42
data/schema.sql
Normal file
42
data/schema.sql
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
BEGIN;
|
||||||
|
|
||||||
|
-- Metadata {{{
|
||||||
|
CREATE TABLE version_history(
|
||||||
|
component TEXT NOT NULL,
|
||||||
|
from_version INTEGER NOT NULL,
|
||||||
|
to_version INTEGER NOT NULL,
|
||||||
|
author TEXT NOT NULL,
|
||||||
|
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||||
|
);
|
||||||
|
INSERT INTO version_history (component, from_version, to_version, author) VALUES ('ROOT', 0, 1, 'Initial Creation');
|
||||||
|
|
||||||
|
|
||||||
|
CREATE OR REPLACE FUNCTION update_timestamp_column()
|
||||||
|
RETURNS TRIGGER AS $$
|
||||||
|
BEGIN
|
||||||
|
NEW._timestamp = (now() at time zone 'utc');
|
||||||
|
RETURN NEW;
|
||||||
|
END;
|
||||||
|
$$ language 'plpgsql';
|
||||||
|
-- }}}
|
||||||
|
|
||||||
|
-- App metadata {{{
|
||||||
|
|
||||||
|
CREATE TABLE app_config(
|
||||||
|
appname TEXT PRIMARY KEY,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE bot_config(
|
||||||
|
appname TEXT PRIMARY KEY REFERENCES app_config(appname) ON DELETE CASCADE,
|
||||||
|
sponsor_prompt TEXT,
|
||||||
|
sponsor_message TEXT,
|
||||||
|
default_skin TEXT
|
||||||
|
);
|
||||||
|
-- }}}
|
||||||
|
|
||||||
|
-- TODO: Profile data
|
||||||
|
|
||||||
|
|
||||||
|
COMMIT;
|
||||||
|
-- vim: set fdm=marker:
|
||||||
9
example-config/emojis.conf
Normal file
9
example-config/emojis.conf
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
[EMOJIS]
|
||||||
|
|
||||||
|
tick = :✅:
|
||||||
|
clock = :⏱️:
|
||||||
|
warning = :⚠️:
|
||||||
|
config = :⚙️:
|
||||||
|
stats = :📊:
|
||||||
|
utility = :⏱️:
|
||||||
|
cancel = :❌:
|
||||||
27
example-config/example-bot.conf
Normal file
27
example-config/example-bot.conf
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
[BOT]
|
||||||
|
prefix = t!
|
||||||
|
|
||||||
|
admins = 413668234269818890
|
||||||
|
|
||||||
|
admin_guilds = 1265249490063851571
|
||||||
|
|
||||||
|
shard_count = 1
|
||||||
|
|
||||||
|
ALSO_READ = config/emojis.conf, config/secrets.conf
|
||||||
|
|
||||||
|
[LOGGING]
|
||||||
|
log_file = bot.log
|
||||||
|
|
||||||
|
general_log = https://discord.com/api/webhooks/1409394313552593009/5SB_zbzyPa_ccshoe3ePGjCnbT9s6mPfCfpY8P7bL_Zn6vNkeF4CFFbAFEykHZlZl7e8
|
||||||
|
error_log = %(general_log)s
|
||||||
|
critical_log = %(general_log)s
|
||||||
|
warning_log = %(general_log)s
|
||||||
|
warning_prefix = **WARNING**
|
||||||
|
error_prefix = **ERROR**
|
||||||
|
critical_prefix = ***CRITICAL***
|
||||||
|
|
||||||
|
[LOGGING_LEVELS]
|
||||||
|
root = DEBUG
|
||||||
|
discord = INFO
|
||||||
|
discord.http = INFO
|
||||||
|
discord.gateway = INFO
|
||||||
6
example-config/example-secrets.conf
Normal file
6
example-config/example-secrets.conf
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
[BOT]
|
||||||
|
token =
|
||||||
|
|
||||||
|
[DATA]
|
||||||
|
args = dbname=
|
||||||
|
appid =
|
||||||
7
requirements.txt
Normal file
7
requirements.txt
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
aiohttp
|
||||||
|
cachetools
|
||||||
|
configparser
|
||||||
|
discord.py [voice]
|
||||||
|
iso8601
|
||||||
|
psycopg[pool]
|
||||||
|
pytz
|
||||||
12
scripts/start_bot.py
Executable file
12
scripts/start_bot.py
Executable file
@@ -0,0 +1,12 @@
|
|||||||
|
# !/bin/python3
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join(os.getcwd()))
|
||||||
|
sys.path.insert(0, os.path.join(os.getcwd(), "src"))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
from bot import _main
|
||||||
|
_main()
|
||||||
35
scripts/start_debug.py
Executable file
35
scripts/start_debug.py
Executable file
@@ -0,0 +1,35 @@
|
|||||||
|
# !/bin/python3
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import tracemalloc
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join(os.getcwd()))
|
||||||
|
sys.path.insert(0, os.path.join(os.getcwd(), "src"))
|
||||||
|
|
||||||
|
tracemalloc.start()
|
||||||
|
|
||||||
|
|
||||||
|
def loop_exception_handler(loop, context):
|
||||||
|
print(context)
|
||||||
|
task: asyncio.Task = context.get('task', None)
|
||||||
|
if task is not None:
|
||||||
|
addendum = f"<Task name='{task.get_name()}' stack='{task.get_stack()}'>"
|
||||||
|
message = context.get('message', '')
|
||||||
|
context['message'] = ' '.join((message, addendum))
|
||||||
|
loop.default_exception_handler(context)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
loop.set_exception_handler(loop_exception_handler)
|
||||||
|
loop.set_debug(enabled=True)
|
||||||
|
|
||||||
|
from bot import _main
|
||||||
|
_main()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
3
src/babel/__init__.py
Normal file
3
src/babel/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .translator import SOURCE_LOCALE, LeoBabel, LocalBabel, LazyStr, ctx_locale, ctx_translator
|
||||||
|
|
||||||
|
babel = LocalBabel('babel')
|
||||||
81
src/babel/enums.py
Normal file
81
src/babel/enums.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from . import babel
|
||||||
|
|
||||||
|
_p = babel._p
|
||||||
|
|
||||||
|
|
||||||
|
class LocaleMap(Enum):
|
||||||
|
american_english = 'en-US'
|
||||||
|
british_english = 'en-GB'
|
||||||
|
bulgarian = 'bg'
|
||||||
|
chinese = 'zh-CN'
|
||||||
|
taiwan_chinese = 'zh-TW'
|
||||||
|
croatian = 'hr'
|
||||||
|
czech = 'cs'
|
||||||
|
danish = 'da'
|
||||||
|
dutch = 'nl'
|
||||||
|
finnish = 'fi'
|
||||||
|
french = 'fr'
|
||||||
|
german = 'de'
|
||||||
|
greek = 'el'
|
||||||
|
hindi = 'hi'
|
||||||
|
hungarian = 'hu'
|
||||||
|
italian = 'it'
|
||||||
|
japanese = 'ja'
|
||||||
|
korean = 'ko'
|
||||||
|
lithuanian = 'lt'
|
||||||
|
norwegian = 'no'
|
||||||
|
polish = 'pl'
|
||||||
|
brazil_portuguese = 'pt-BR'
|
||||||
|
romanian = 'ro'
|
||||||
|
russian = 'ru'
|
||||||
|
spain_spanish = 'es-ES'
|
||||||
|
swedish = 'sv-SE'
|
||||||
|
thai = 'th'
|
||||||
|
turkish = 'tr'
|
||||||
|
ukrainian = 'uk'
|
||||||
|
vietnamese = 'vi'
|
||||||
|
hebrew = 'he-IL'
|
||||||
|
|
||||||
|
|
||||||
|
# Original Discord names
|
||||||
|
locale_names = {
|
||||||
|
'id': (_p('localenames|locale:id', "Indonesian"), "Bahasa Indonesia"),
|
||||||
|
'da': (_p('localenames|locale:da', "Danish"), "Dansk"),
|
||||||
|
'de': (_p('localenames|locale:de', "German"), "Deutsch"),
|
||||||
|
'en-GB': (_p('localenames|locale:en-GB', "English, UK"), "English, UK"),
|
||||||
|
'en-US': (_p('localenames|locale:en-US', "English, US"), "English, US"),
|
||||||
|
'es-ES': (_p('localenames|locale:es-ES', "Spanish"), "Español"),
|
||||||
|
'fr': (_p('localenames|locale:fr', "French"), "Français"),
|
||||||
|
'hr': (_p('localenames|locale:hr', "Croatian"), "Hrvatski"),
|
||||||
|
'it': (_p('localenames|locale:it', "Italian"), "Italiano"),
|
||||||
|
'lt': (_p('localenames|locale:lt', "Lithuanian"), "Lietuviškai"),
|
||||||
|
'hu': (_p('localenames|locale:hu', "Hungarian"), "Magyar"),
|
||||||
|
'nl': (_p('localenames|locale:nl', "Dutch"), "Nederlands"),
|
||||||
|
'no': (_p('localenames|locale:no', "Norwegian"), "Norsk"),
|
||||||
|
'pl': (_p('localenames|locale:pl', "Polish"), "Polski"),
|
||||||
|
'pt-BR': (_p('localenames|locale:pt-BR', "Portuguese, Brazilian"), "Português do Brasil"),
|
||||||
|
'ro': (_p('localenames|locale:ro', "Romanian, Romania"), "Română"),
|
||||||
|
'fi': (_p('localenames|locale:fi', "Finnish"), "Suomi"),
|
||||||
|
'sv-SE': (_p('localenames|locale:sv-SE', "Swedish"), "Svenska"),
|
||||||
|
'vi': (_p('localenames|locale:vi', "Vietnamese"), "Tiếng Việt"),
|
||||||
|
'tr': (_p('localenames|locale:tr', "Turkish"), "Türkçe"),
|
||||||
|
'cs': (_p('localenames|locale:cs', "Czech"), "Čeština"),
|
||||||
|
'el': (_p('localenames|locale:el', "Greek"), "Ελληνικά"),
|
||||||
|
'bg': (_p('localenames|locale:bg', "Bulgarian"), "български"),
|
||||||
|
'ru': (_p('localenames|locale:ru', "Russian"), "Pусский"),
|
||||||
|
'uk': (_p('localenames|locale:uk', "Ukrainian"), "Українська"),
|
||||||
|
'hi': (_p('localenames|locale:hi', "Hindi"), "हिन्दी"),
|
||||||
|
'th': (_p('localenames|locale:th', "Thai"), "ไทย"),
|
||||||
|
'zh-CN': (_p('localenames|locale:zh-CN', "Chinese, China"), "中文"),
|
||||||
|
'ja': (_p('localenames|locale:ja', "Japanese"), "日本語"),
|
||||||
|
'zh-TW': (_p('localenames|locale:zh-TW', "Chinese, Taiwan"), "繁體中文"),
|
||||||
|
'ko': (_p('localenames|locale:ko', "Korean"), "한국어"),
|
||||||
|
}
|
||||||
|
|
||||||
|
# More names for languages not supported by Discord
|
||||||
|
locale_names |= {
|
||||||
|
'he': (_p('localenames|locale:he', "Hebrew"), "Hebrew"),
|
||||||
|
'he-IL': (_p('localenames|locale:he-IL', "Hebrew"), "Hebrew"),
|
||||||
|
'ceaser': (_p('localenames|locale:test', "Test Language"), "dfbtfs"),
|
||||||
|
}
|
||||||
108
src/babel/translator.py
Normal file
108
src/babel/translator.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
from typing import Optional
|
||||||
|
import logging
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from collections import defaultdict
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
import gettext
|
||||||
|
|
||||||
|
from discord.app_commands import Translator, locale_str
|
||||||
|
from discord.enums import Locale
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
SOURCE_LOCALE = 'en_GB'
|
||||||
|
ctx_locale: ContextVar[str] = ContextVar('locale', default=SOURCE_LOCALE)
|
||||||
|
ctx_translator: ContextVar['LeoBabel'] = ContextVar('translator', default=None) # type: ignore
|
||||||
|
|
||||||
|
null = gettext.NullTranslations()
|
||||||
|
|
||||||
|
|
||||||
|
class LeoBabel(Translator):
|
||||||
|
def __init__(self):
|
||||||
|
self.supported_locales = {loc.name for loc in Locale}
|
||||||
|
self.supported_domains = {}
|
||||||
|
self.translators = defaultdict(dict) # locale -> domain -> GNUTranslator
|
||||||
|
|
||||||
|
async def load(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def unload(self):
|
||||||
|
self.translators.clear()
|
||||||
|
|
||||||
|
def get_translator(self, locale: Optional[str], domain):
|
||||||
|
return null
|
||||||
|
|
||||||
|
def t(self, lazystr, locale=None):
|
||||||
|
return lazystr._translate_with(null)
|
||||||
|
|
||||||
|
async def translate(self, string: locale_str, locale: Locale, context):
|
||||||
|
if not isinstance(string, LazyStr):
|
||||||
|
return string
|
||||||
|
else:
|
||||||
|
return string.message
|
||||||
|
|
||||||
|
ctx_translator.set(LeoBabel())
|
||||||
|
|
||||||
|
class Method(Enum):
|
||||||
|
GETTEXT = 'gettext'
|
||||||
|
NGETTEXT = 'ngettext'
|
||||||
|
PGETTEXT = 'pgettext'
|
||||||
|
NPGETTEXT = 'npgettext'
|
||||||
|
|
||||||
|
|
||||||
|
class LocalBabel:
|
||||||
|
def __init__(self, domain):
|
||||||
|
self.domain = domain
|
||||||
|
|
||||||
|
@property
|
||||||
|
def methods(self):
|
||||||
|
return (self._, self._n, self._p, self._np)
|
||||||
|
|
||||||
|
def _(self, message):
|
||||||
|
return LazyStr(Method.GETTEXT, message, domain=self.domain)
|
||||||
|
|
||||||
|
def _n(self, singular, plural, n):
|
||||||
|
return LazyStr(Method.NGETTEXT, singular, plural, n, domain=self.domain)
|
||||||
|
|
||||||
|
def _p(self, context, message):
|
||||||
|
return LazyStr(Method.PGETTEXT, context, message, domain=self.domain)
|
||||||
|
|
||||||
|
def _np(self, context, singular, plural, n):
|
||||||
|
return LazyStr(Method.NPGETTEXT, context, singular, plural, n, domain=self.domain)
|
||||||
|
|
||||||
|
|
||||||
|
class LazyStr(locale_str):
|
||||||
|
__slots__ = ('method', 'args', 'domain', 'locale')
|
||||||
|
|
||||||
|
def __init__(self, method, *args, locale=None, domain=None):
|
||||||
|
self.method = method
|
||||||
|
self.args = args
|
||||||
|
self.domain = domain
|
||||||
|
self.locale = locale
|
||||||
|
|
||||||
|
@property
|
||||||
|
def message(self):
|
||||||
|
return self._translate_with(null)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def extras(self):
|
||||||
|
return {'locale': self.locale, 'domain': self.domain}
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.message
|
||||||
|
|
||||||
|
def _translate_with(self, translator: gettext.GNUTranslations):
|
||||||
|
method = getattr(translator, self.method.value)
|
||||||
|
return method(*self.args)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f'{self.__class__.__name__}({self.method}, {self.args!r}, locale={self.locale}, domain={self.domain})'
|
||||||
|
|
||||||
|
def __eq__(self, obj: object) -> bool:
|
||||||
|
return isinstance(obj, locale_str) and self.message == obj.message
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
return hash(self.args)
|
||||||
20
src/babel/utils.py
Normal file
20
src/babel/utils.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from .translator import ctx_translator
|
||||||
|
from . import babel
|
||||||
|
|
||||||
|
_, _p, _np = babel._, babel._p, babel._np
|
||||||
|
|
||||||
|
|
||||||
|
MONTHS = _p(
|
||||||
|
'utils|months',
|
||||||
|
"January,February,March,April,May,June,July,August,September,October,November,December"
|
||||||
|
)
|
||||||
|
|
||||||
|
SHORT_MONTHS = _p(
|
||||||
|
'utils|short_months',
|
||||||
|
"Jan,Feb,Mar,Apr,May,Jun,Jul,Aug,Sep,Oct,Nov,Dec"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def local_month(month, short=False):
|
||||||
|
string = MONTHS if not short else SHORT_MONTHS
|
||||||
|
return ctx_translator.get().t(string).split(',')[month-1]
|
||||||
107
src/bot.py
Normal file
107
src/bot.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
|
from meta import LionBot, conf, sharding, appname
|
||||||
|
from meta.app import shardname
|
||||||
|
from meta.logger import log_context, log_action_stack, setup_main_logger
|
||||||
|
from meta.context import ctx_bot
|
||||||
|
from meta.monitor import ComponentMonitor, StatusLevel, ComponentStatus
|
||||||
|
|
||||||
|
from data import Database
|
||||||
|
|
||||||
|
|
||||||
|
for name in conf.config.options('LOGGING_LEVELS', no_defaults=True):
|
||||||
|
logging.getLogger(name).setLevel(conf.logging_levels[name])
|
||||||
|
|
||||||
|
|
||||||
|
logging_queue = setup_main_logger()
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
db = Database(conf.data['args'])
|
||||||
|
|
||||||
|
|
||||||
|
async def _data_monitor() -> ComponentStatus:
|
||||||
|
"""
|
||||||
|
Component monitor callback for the database.
|
||||||
|
"""
|
||||||
|
data = {
|
||||||
|
'stats': str(db.pool.get_stats())
|
||||||
|
}
|
||||||
|
if not db.pool._opened:
|
||||||
|
level = StatusLevel.WAITING
|
||||||
|
info = "(WAITING) Database Pool is not opened."
|
||||||
|
elif db.pool._closed:
|
||||||
|
level = StatusLevel.ERRORED
|
||||||
|
info = "(ERROR) Database Pool is closed."
|
||||||
|
else:
|
||||||
|
level = StatusLevel.OKAY
|
||||||
|
info = "(OK) Database Pool statistics: {stats}"
|
||||||
|
return ComponentStatus(level, info, info, data)
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
log_action_stack.set(("Initialising",))
|
||||||
|
logger.info("Initialising StudyLion")
|
||||||
|
|
||||||
|
intents = discord.Intents.all()
|
||||||
|
intents.members = True
|
||||||
|
intents.message_content = True
|
||||||
|
intents.presences = False
|
||||||
|
|
||||||
|
async with db.open():
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with LionBot(
|
||||||
|
command_prefix=conf.bot.get('prefix', '!!'),
|
||||||
|
intents=intents,
|
||||||
|
appname=appname,
|
||||||
|
shardname=shardname,
|
||||||
|
db=db,
|
||||||
|
config=conf,
|
||||||
|
initial_extensions=[
|
||||||
|
'core',
|
||||||
|
'modules',
|
||||||
|
],
|
||||||
|
web_client=session,
|
||||||
|
testing_guilds=conf.bot.getintlist('admin_guilds'),
|
||||||
|
shard_id=sharding.shard_number,
|
||||||
|
shard_count=sharding.shard_count,
|
||||||
|
help_command=None,
|
||||||
|
proxy=conf.bot.get('proxy', None),
|
||||||
|
chunk_guilds_at_startup=True,
|
||||||
|
) as lionbot:
|
||||||
|
ctx_bot.set(lionbot)
|
||||||
|
lionbot.system_monitor.add_component(
|
||||||
|
ComponentMonitor('Database', _data_monitor)
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
log_context.set(f"APP: {appname}")
|
||||||
|
logger.info("StudyLion initialised, starting!", extra={'action': 'Starting'})
|
||||||
|
await lionbot.start(conf.bot['TOKEN'])
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
log_context.set(f"APP: {appname}")
|
||||||
|
logger.info("StudyLion closed, shutting down.", extra={'action': "Shutting Down"}, exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _main():
|
||||||
|
from signal import SIGINT, SIGTERM
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
main_task = asyncio.ensure_future(main())
|
||||||
|
for signal in [SIGINT, SIGTERM]:
|
||||||
|
loop.add_signal_handler(signal, main_task.cancel)
|
||||||
|
try:
|
||||||
|
loop.run_until_complete(main_task)
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
logging.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
_main()
|
||||||
26
src/botdata.py
Normal file
26
src/botdata.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from data import Registry, RowModel, Table
|
||||||
|
from data.columns import String, Timestamp, Integer, Bool
|
||||||
|
|
||||||
|
|
||||||
|
class VersionHistory(RowModel):
|
||||||
|
"""
|
||||||
|
CREATE TABLE version_history(
|
||||||
|
component TEXT NOT NULL,
|
||||||
|
from_version INTEGER NOT NULL,
|
||||||
|
to_version INTEGER NOT NULL,
|
||||||
|
author TEXT NOT NULL,
|
||||||
|
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
_tablename_ = 'version_history'
|
||||||
|
_cache_ = {}
|
||||||
|
|
||||||
|
component = String()
|
||||||
|
from_version = Integer()
|
||||||
|
to_version = Integer()
|
||||||
|
author = String()
|
||||||
|
_timestamp = Timestamp()
|
||||||
|
|
||||||
|
|
||||||
|
class BotData(Registry):
|
||||||
|
version_history = VersionHistory.table
|
||||||
7
src/constants.py
Normal file
7
src/constants.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
CONFIG_FILE = "config/bot.conf"
|
||||||
|
|
||||||
|
HINT_ICON = "https://projects.iamcal.com/emoji-data/img-apple-64/1f4a1.png"
|
||||||
|
|
||||||
|
SCHEMA_VERSIONS = {
|
||||||
|
'ROOT': 1,
|
||||||
|
}
|
||||||
8
src/core/__init__.py
Normal file
8
src/core/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
from babel import LocalBabel
|
||||||
|
|
||||||
|
babel = LocalBabel('core')
|
||||||
|
|
||||||
|
async def setup(bot):
|
||||||
|
from .cog import CoreCog
|
||||||
|
|
||||||
|
await bot.add_cog(CoreCog(bot))
|
||||||
76
src/core/cog.py
Normal file
76
src/core/cog.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
from collections import defaultdict
|
||||||
|
from weakref import WeakValueDictionary
|
||||||
|
|
||||||
|
import discord
|
||||||
|
import discord.app_commands as appcmd
|
||||||
|
|
||||||
|
from meta import LionBot, LionCog, LionContext
|
||||||
|
from meta.app import shardname, appname
|
||||||
|
from meta.logger import log_wrap
|
||||||
|
from utils.lib import utc_now
|
||||||
|
|
||||||
|
from .data import CoreData
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class keydefaultdict(defaultdict):
|
||||||
|
def __missing__(self, key):
|
||||||
|
if self.default_factory is None:
|
||||||
|
raise KeyError(key)
|
||||||
|
else:
|
||||||
|
ret = self[key] = self.default_factory(key)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class CoreCog(LionCog):
|
||||||
|
def __init__(self, bot: LionBot):
|
||||||
|
self.bot = bot
|
||||||
|
self.data = CoreData()
|
||||||
|
bot.db.load_registry(self.data)
|
||||||
|
|
||||||
|
self.app_config: Optional[CoreData.AppConfig] = None
|
||||||
|
self.bot_config: Optional[CoreData.BotConfig] = None
|
||||||
|
|
||||||
|
self.app_cmd_cache: list[discord.app_commands.AppCommand] = []
|
||||||
|
self.cmd_name_cache: dict[str, discord.app_commands.AppCommand] = {}
|
||||||
|
self.mention_cache: dict[str, str] = keydefaultdict(self.mention_cmd)
|
||||||
|
|
||||||
|
async def cog_load(self):
|
||||||
|
# Fetch (and possibly create) core data rows.
|
||||||
|
self.app_config = await self.data.AppConfig.fetch_or_create(appname)
|
||||||
|
self.bot_config = await self.data.BotConfig.fetch_or_create(appname)
|
||||||
|
|
||||||
|
# Load the app command cache
|
||||||
|
await self.reload_appcmd_cache()
|
||||||
|
|
||||||
|
async def reload_appcmd_cache(self):
|
||||||
|
for guildid in self.bot.testing_guilds:
|
||||||
|
self.app_cmd_cache += await self.bot.tree.fetch_commands(guild=discord.Object(guildid))
|
||||||
|
self.app_cmd_cache += await self.bot.tree.fetch_commands()
|
||||||
|
self.cmd_name_cache = {cmd.name: cmd for cmd in self.app_cmd_cache}
|
||||||
|
self.mention_cache = self._mention_cache_from(self.app_cmd_cache)
|
||||||
|
|
||||||
|
def _mention_cache_from(self, cmds: list[appcmd.AppCommand | appcmd.AppCommandGroup]):
|
||||||
|
cache = keydefaultdict(self.mention_cmd)
|
||||||
|
for cmd in cmds:
|
||||||
|
cache[cmd.qualified_name if isinstance(cmd, appcmd.AppCommandGroup) else cmd.name] = cmd.mention
|
||||||
|
subcommands = [option for option in cmd.options if isinstance(option, appcmd.AppCommandGroup)]
|
||||||
|
if subcommands:
|
||||||
|
subcache = self._mention_cache_from(subcommands)
|
||||||
|
cache |= subcache
|
||||||
|
return cache
|
||||||
|
|
||||||
|
def mention_cmd(self, name: str):
|
||||||
|
"""
|
||||||
|
Create an application command mention for the given names.
|
||||||
|
|
||||||
|
If not found in cache, creates a 'fake' mention with an invalid id.
|
||||||
|
"""
|
||||||
|
if name in self.mention_cache:
|
||||||
|
mention = self.mention_cache[name]
|
||||||
|
else:
|
||||||
|
mention = f"</{name}:1110834049204891730>"
|
||||||
|
return mention
|
||||||
45
src/core/data.py
Normal file
45
src/core/data.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from itertools import chain
|
||||||
|
from psycopg import sql
|
||||||
|
from cachetools import TTLCache
|
||||||
|
import discord
|
||||||
|
|
||||||
|
from meta import conf
|
||||||
|
from meta.logger import log_wrap
|
||||||
|
from data import Table, Registry, Column, RowModel, RegisterEnum
|
||||||
|
from data.models import WeakCache
|
||||||
|
from data.columns import Integer, String, Bool, Timestamp
|
||||||
|
|
||||||
|
|
||||||
|
class CoreData(Registry, name="core"):
|
||||||
|
class AppConfig(RowModel):
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE app_config(
|
||||||
|
appname TEXT PRIMARY KEY,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
_tablename_ = 'app_config'
|
||||||
|
|
||||||
|
appname = String(primary=True)
|
||||||
|
created_at = Timestamp()
|
||||||
|
|
||||||
|
class BotConfig(RowModel):
|
||||||
|
"""
|
||||||
|
Schema
|
||||||
|
------
|
||||||
|
CREATE TABLE bot_config(
|
||||||
|
appname TEXT PRIMARY KEY REFERENCES app_config(appname) ON DELETE CASCADE,
|
||||||
|
sponsor_prompt TEXT,
|
||||||
|
sponsor_message TEXT,
|
||||||
|
default_skin TEXT
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
_tablename_ = 'bot_config'
|
||||||
|
|
||||||
|
appname = String(primary=True)
|
||||||
|
default_skin = String()
|
||||||
|
sponsor_prompt = String()
|
||||||
|
sponsor_message = String()
|
||||||
227
src/core/setting_types.py
Normal file
227
src/core/setting_types.py
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
"""
|
||||||
|
Additional abstract setting types useful for StudyLion settings.
|
||||||
|
"""
|
||||||
|
from typing import Optional
|
||||||
|
import json
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.enums import TextStyle
|
||||||
|
|
||||||
|
from settings.base import ParentID
|
||||||
|
from settings.setting_types import IntegerSetting, StringSetting
|
||||||
|
from meta import conf
|
||||||
|
from meta.errors import UserInputError
|
||||||
|
from babel.translator import ctx_translator
|
||||||
|
from utils.lib import MessageArgs
|
||||||
|
|
||||||
|
from . import babel
|
||||||
|
|
||||||
|
_p = babel._p
|
||||||
|
|
||||||
|
|
||||||
|
class MessageSetting(StringSetting):
|
||||||
|
"""
|
||||||
|
Typed Setting ABC representing a message sent to Discord.
|
||||||
|
|
||||||
|
Data is a json-formatted string dict with at least one of the fields 'content', 'embed', 'embeds'
|
||||||
|
Value is the corresponding dictionary
|
||||||
|
"""
|
||||||
|
# TODO: Extend to support format keys
|
||||||
|
|
||||||
|
_accepts = _p(
|
||||||
|
'settype:message|accepts',
|
||||||
|
"JSON formatted raw message data"
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def download_attachment(attached: discord.Attachment):
|
||||||
|
"""
|
||||||
|
Download a discord.Attachment with some basic filetype and file size validation.
|
||||||
|
"""
|
||||||
|
t = ctx_translator.get().t
|
||||||
|
|
||||||
|
error = None
|
||||||
|
decoded = None
|
||||||
|
if attached.content_type and not ('json' in attached.content_type):
|
||||||
|
error = t(_p(
|
||||||
|
'settype:message|download|error:not_json',
|
||||||
|
"The attached message data is not a JSON file!"
|
||||||
|
))
|
||||||
|
elif attached.size > 10000:
|
||||||
|
error = t(_p(
|
||||||
|
'settype:message|download|error:size',
|
||||||
|
"The attached message data is too large!"
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
content = await attached.read()
|
||||||
|
try:
|
||||||
|
decoded = content.decode('UTF-8')
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
error = t(_p(
|
||||||
|
'settype:message|download|error:decoding',
|
||||||
|
"Could not decode the message data. Please ensure it is saved with the `UTF-8` encoding."
|
||||||
|
))
|
||||||
|
|
||||||
|
if error is not None:
|
||||||
|
raise UserInputError(error)
|
||||||
|
else:
|
||||||
|
return decoded
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def value_to_args(cls, parent_id: ParentID, value: dict, **kwargs) -> MessageArgs:
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
|
||||||
|
args = {}
|
||||||
|
args['content'] = value.get('content', "")
|
||||||
|
if 'embed' in value:
|
||||||
|
embed = discord.Embed.from_dict(value['embed'])
|
||||||
|
args['embed'] = embed
|
||||||
|
if 'embeds' in value:
|
||||||
|
embeds = []
|
||||||
|
for embed_data in value['embeds']:
|
||||||
|
embeds.append(discord.Embed.from_dict(embed_data))
|
||||||
|
args['embeds'] = embeds
|
||||||
|
return MessageArgs(**args)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _data_from_value(cls, parent_id: ParentID, value: Optional[dict], **kwargs):
|
||||||
|
if value and any(value.get(key, None) for key in ('content', 'embed', 'embeds')):
|
||||||
|
data = json.dumps(value)
|
||||||
|
else:
|
||||||
|
data = None
|
||||||
|
return data
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _data_to_value(cls, parent_id: ParentID, data: Optional[str], **kwargs):
|
||||||
|
if data:
|
||||||
|
value = json.loads(data)
|
||||||
|
else:
|
||||||
|
value = None
|
||||||
|
return value
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _parse_string(cls, parent_id: ParentID, string: str, **kwargs):
|
||||||
|
"""
|
||||||
|
Provided user string can be downright random.
|
||||||
|
|
||||||
|
If it isn't json-formatted, treat it as the content of the message.
|
||||||
|
If it is, do basic checking on the length and embeds.
|
||||||
|
"""
|
||||||
|
string = string.strip()
|
||||||
|
if not string or string.lower() == 'none':
|
||||||
|
return None
|
||||||
|
|
||||||
|
t = ctx_translator.get().t
|
||||||
|
|
||||||
|
error_tip = t(_p(
|
||||||
|
'settype:message|error_suffix',
|
||||||
|
"You can view, test, and fix your embed using the online [embed builder]({link})."
|
||||||
|
)).format(
|
||||||
|
link="https://glitchii.github.io/embedbuilder/?editor=json"
|
||||||
|
)
|
||||||
|
|
||||||
|
if string.startswith('{') and string.endswith('}'):
|
||||||
|
# Assume the string is a json-formatted message dict
|
||||||
|
try:
|
||||||
|
value = json.loads(string)
|
||||||
|
except json.JSONDecodeError as err:
|
||||||
|
error = t(_p(
|
||||||
|
'settype:message|error:invalid_json',
|
||||||
|
"The provided message data was not a valid JSON document!\n"
|
||||||
|
"`{error}`"
|
||||||
|
)).format(error=str(err))
|
||||||
|
raise UserInputError(error + '\n' + error_tip)
|
||||||
|
|
||||||
|
if not isinstance(value, dict) or not any(value.get(key, None) for key in ('content', 'embed', 'embeds')):
|
||||||
|
error = t(_p(
|
||||||
|
'settype:message|error:json_missing_keys',
|
||||||
|
"Message data must be a JSON object with at least one of the following fields: "
|
||||||
|
"`content`, `embed`, `embeds`"
|
||||||
|
))
|
||||||
|
raise UserInputError(error + '\n' + error_tip)
|
||||||
|
|
||||||
|
embed_data = value.get('embed', None)
|
||||||
|
if not isinstance(embed_data, dict):
|
||||||
|
error = t(_p(
|
||||||
|
'settype:message|error:json_embed_type',
|
||||||
|
"`embed` field must be a valid JSON object."
|
||||||
|
))
|
||||||
|
raise UserInputError(error + '\n' + error_tip)
|
||||||
|
|
||||||
|
embeds_data = value.get('embeds', [])
|
||||||
|
if not isinstance(embeds_data, list):
|
||||||
|
error = t(_p(
|
||||||
|
'settype:message|error:json_embeds_type',
|
||||||
|
"`embeds` field must be a list."
|
||||||
|
))
|
||||||
|
raise UserInputError(error + '\n' + error_tip)
|
||||||
|
|
||||||
|
if embed_data and embeds_data:
|
||||||
|
error = t(_p(
|
||||||
|
'settype:message|error:json_embed_embeds',
|
||||||
|
"Message data cannot include both `embed` and `embeds`."
|
||||||
|
))
|
||||||
|
raise UserInputError(error + '\n' + error_tip)
|
||||||
|
|
||||||
|
content_data = value.get('content', "")
|
||||||
|
if not isinstance(content_data, str):
|
||||||
|
error = t(_p(
|
||||||
|
'settype:message|error:json_content_type',
|
||||||
|
"`content` field must be a string."
|
||||||
|
))
|
||||||
|
raise UserInputError(error + '\n' + error_tip)
|
||||||
|
|
||||||
|
# Validate embeds, which is the most likely place for something to go wrong
|
||||||
|
embeds = [embed_data] if embed_data else embeds_data
|
||||||
|
try:
|
||||||
|
for embed in embeds:
|
||||||
|
discord.Embed.from_dict(embed)
|
||||||
|
except Exception as e:
|
||||||
|
# from_dict may raise a range of possible exceptions.
|
||||||
|
raw_error = ''.join(
|
||||||
|
traceback.TracebackException.from_exception(e).format_exception_only()
|
||||||
|
)
|
||||||
|
error = t(_p(
|
||||||
|
'ui:settype:message|error:embed_conversion',
|
||||||
|
"Could not parse the message embed data.\n"
|
||||||
|
"**Error:** `{exception}`"
|
||||||
|
)).format(exception=raw_error)
|
||||||
|
raise UserInputError(error + '\n' + error_tip)
|
||||||
|
|
||||||
|
# At this point, the message will at least successfully convert into MessageArgs
|
||||||
|
# There are numerous ways it could still be invalid, e.g. invalid urls, or too-long fields
|
||||||
|
# or the total message content being too long, or too many fields, etc
|
||||||
|
# This will need to be caught in anything which displays a message parsed from user data.
|
||||||
|
else:
|
||||||
|
# Either the string is not json formatted, or the formatting is broken
|
||||||
|
# Assume the string is a content message
|
||||||
|
value = {
|
||||||
|
'content': string
|
||||||
|
}
|
||||||
|
return json.dumps(value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _format_data(cls, parent_id: ParentID, data: Optional[str], **kwargs):
|
||||||
|
if not data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
value = cls._data_to_value(parent_id, data, **kwargs)
|
||||||
|
content = value.get('content', "")
|
||||||
|
if 'embed' in value or 'embeds' in value or len(content) > 100:
|
||||||
|
t = ctx_translator.get().t
|
||||||
|
formatted = t(_p(
|
||||||
|
'settype:message|format:too_long',
|
||||||
|
"Too long to display! See Preview."
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
formatted = content
|
||||||
|
|
||||||
|
return formatted
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_field(self):
|
||||||
|
field = super().input_field
|
||||||
|
field.style = TextStyle.long
|
||||||
|
return field
|
||||||
1
src/data
Submodule
1
src/data
Submodule
Submodule src/data added at cfdfe0eb50
373
src/meta/LionBot.py
Normal file
373
src/meta/LionBot.py
Normal file
@@ -0,0 +1,373 @@
|
|||||||
|
from typing import List, Literal, LiteralString, Optional, TYPE_CHECKING, overload
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
from weakref import WeakValueDictionary
|
||||||
|
|
||||||
|
from constants import SCHEMA_VERSIONS
|
||||||
|
import discord
|
||||||
|
from discord.utils import MISSING
|
||||||
|
from discord.ext.commands import Bot, Cog, HybridCommand, HybridCommandError
|
||||||
|
from discord.ext.commands.errors import CommandInvokeError, CheckFailure
|
||||||
|
from discord.app_commands.errors import CommandInvokeError as appCommandInvokeError, TransformerError
|
||||||
|
from aiohttp import ClientSession
|
||||||
|
|
||||||
|
from data import Database, ORDER
|
||||||
|
from utils.lib import tabulate
|
||||||
|
from babel.translator import LeoBabel
|
||||||
|
from botdata import BotData, VersionHistory
|
||||||
|
|
||||||
|
from .config import Conf
|
||||||
|
from .logger import logging_context, log_context, log_action_stack, log_wrap, set_logging_context
|
||||||
|
from .context import context
|
||||||
|
from .LionContext import LionContext
|
||||||
|
from .LionTree import LionTree
|
||||||
|
from .errors import HandledException, SafeCancellation
|
||||||
|
from .monitor import SystemMonitor, ComponentMonitor, StatusLevel, ComponentStatus
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.cog import CoreCog
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LionBot(Bot):
|
||||||
|
def __init__(
|
||||||
|
self, *args, appname: str, shardname: str, db: Database, config: Conf,
|
||||||
|
initial_extensions: List[str], web_client: ClientSession,
|
||||||
|
testing_guilds: List[int] = [], **kwargs
|
||||||
|
):
|
||||||
|
kwargs.setdefault('tree_cls', LionTree)
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.web_client = web_client
|
||||||
|
self.testing_guilds = testing_guilds
|
||||||
|
self.initial_extensions = initial_extensions
|
||||||
|
self.db = db
|
||||||
|
self.appname = appname
|
||||||
|
self.shardname = shardname
|
||||||
|
# self.appdata = appdata
|
||||||
|
self.data: BotData = db.load_registry(BotData())
|
||||||
|
self.config = config
|
||||||
|
self.translator = LeoBabel()
|
||||||
|
|
||||||
|
self.system_monitor = SystemMonitor()
|
||||||
|
self.monitor = ComponentMonitor('LionBot', self._monitor_status)
|
||||||
|
self.system_monitor.add_component(self.monitor)
|
||||||
|
|
||||||
|
self._locks = WeakValueDictionary()
|
||||||
|
self._running_events = set()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dbconn(self):
|
||||||
|
return self.db
|
||||||
|
|
||||||
|
@property
|
||||||
|
def core(self):
|
||||||
|
return self.get_cog('CoreCog')
|
||||||
|
|
||||||
|
async def _monitor_status(self):
|
||||||
|
if self.is_closed():
|
||||||
|
level = StatusLevel.ERRORED
|
||||||
|
info = "(ERROR) Websocket is closed"
|
||||||
|
data = {}
|
||||||
|
elif self.is_ws_ratelimited():
|
||||||
|
level = StatusLevel.WAITING
|
||||||
|
info = "(WAITING) Websocket is ratelimited"
|
||||||
|
data = {}
|
||||||
|
elif not self.is_ready():
|
||||||
|
level = StatusLevel.STARTING
|
||||||
|
info = "(STARTING) Not yet ready"
|
||||||
|
data = {}
|
||||||
|
else:
|
||||||
|
level = StatusLevel.OKAY
|
||||||
|
info = (
|
||||||
|
"(OK) "
|
||||||
|
"Logged in with {guild_count} guilds, "
|
||||||
|
", websocket latency {latency}, and {events} running events."
|
||||||
|
)
|
||||||
|
data = {
|
||||||
|
'guild_count': len(self.guilds),
|
||||||
|
'latency': self.latency,
|
||||||
|
'events': len(self._running_events),
|
||||||
|
}
|
||||||
|
return ComponentStatus(level, info, info, data)
|
||||||
|
|
||||||
|
async def setup_hook(self) -> None:
|
||||||
|
log_context.set(f"APP: {self.application_id}")
|
||||||
|
|
||||||
|
for extension in self.initial_extensions:
|
||||||
|
await self.load_extension(extension)
|
||||||
|
|
||||||
|
for guildid in self.testing_guilds:
|
||||||
|
guild = discord.Object(guildid)
|
||||||
|
if not self.shard_count or (self.shard_id == ((guildid >> 22) % self.shard_count)):
|
||||||
|
self.tree.copy_global_to(guild=guild)
|
||||||
|
await self.tree.sync(guild=guild)
|
||||||
|
|
||||||
|
# To make the type checker happy about fetching cogs by name
|
||||||
|
# TODO: Move this to stubs at some point
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_cog(self, name: Literal['CoreCog']) -> 'CoreCog':
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_cog(self, name: str) -> Optional[Cog]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_cog(self, name: str) -> Optional[Cog]:
|
||||||
|
return super().get_cog(name)
|
||||||
|
|
||||||
|
async def add_cog(self, cog: Cog, **kwargs):
|
||||||
|
sup = super()
|
||||||
|
@log_wrap(action=f"Attach {cog.__cog_name__}")
|
||||||
|
async def wrapper():
|
||||||
|
logger.info(f"Attaching Cog {cog.__cog_name__}")
|
||||||
|
await sup.add_cog(cog, **kwargs)
|
||||||
|
logger.debug(f"Attached Cog {cog.__cog_name__} with no errors.")
|
||||||
|
await wrapper()
|
||||||
|
|
||||||
|
async def load_extension(self, name, *, package=None, **kwargs):
|
||||||
|
sup = super()
|
||||||
|
@log_wrap(action=f"Load {name.strip('.')}")
|
||||||
|
async def wrapper():
|
||||||
|
logger.info(f"Loading extension {name} in package {package}.")
|
||||||
|
await sup.load_extension(name, package=package, **kwargs)
|
||||||
|
logger.debug(f"Loaded extension {name} in package {package}.")
|
||||||
|
await wrapper()
|
||||||
|
|
||||||
|
async def start(self, token: str, *, reconnect: bool = True):
|
||||||
|
await self.data.init()
|
||||||
|
for component, req in SCHEMA_VERSIONS.items():
|
||||||
|
await self.version_check(component, req)
|
||||||
|
|
||||||
|
with logging_context(action="Login"):
|
||||||
|
start_task = asyncio.create_task(self.login(token))
|
||||||
|
await start_task
|
||||||
|
|
||||||
|
with logging_context(stack=("Running",)):
|
||||||
|
run_task = asyncio.create_task(self.connect(reconnect=reconnect))
|
||||||
|
await run_task
|
||||||
|
|
||||||
|
async def version_check(self, component: str, req_version: int):
|
||||||
|
# Query the database to confirm that the given component is listed with the given version.
|
||||||
|
# Typically done upon loading a component
|
||||||
|
rows = await VersionHistory.fetch_where(component=component).order_by('_timestamp', ORDER.DESC).limit(1)
|
||||||
|
|
||||||
|
version = rows[0].to_version if rows else 0
|
||||||
|
|
||||||
|
if version != req_version:
|
||||||
|
raise ValueError(f"Component {component} failed version check. Has version '{version}', required version '{req_version}'")
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
"Component %s passed version check with version %s",
|
||||||
|
component,
|
||||||
|
version
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def dispatch(self, event_name: str, *args, **kwargs):
|
||||||
|
with logging_context(action=f"Dispatch {event_name}"):
|
||||||
|
super().dispatch(event_name, *args, **kwargs)
|
||||||
|
|
||||||
|
def _schedule_event(self, coro, event_name, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Extends client._schedule_event to keep a persistent
|
||||||
|
background task store.
|
||||||
|
"""
|
||||||
|
task = super()._schedule_event(coro, event_name, *args, **kwargs)
|
||||||
|
self._running_events.add(task)
|
||||||
|
task.add_done_callback(lambda fut: self._running_events.discard(fut))
|
||||||
|
|
||||||
|
def idlock(self, snowflakeid):
|
||||||
|
lock = self._locks.get(snowflakeid, None)
|
||||||
|
if lock is None:
|
||||||
|
lock = self._locks[snowflakeid] = asyncio.Lock()
|
||||||
|
return lock
|
||||||
|
|
||||||
|
async def on_ready(self):
|
||||||
|
logger.info(
|
||||||
|
f"Logged in as {self.application.name}\n"
|
||||||
|
f"Application id {self.application.id}\n"
|
||||||
|
f"Shard Talk identifier {self.shardname}\n"
|
||||||
|
"------------------------------\n"
|
||||||
|
f"Enabled Modules: {', '.join(self.extensions.keys())}\n"
|
||||||
|
f"Loaded Cogs: {', '.join(self.cogs.keys())}\n"
|
||||||
|
f"Registered Data: {', '.join(self.db.registries.keys())}\n"
|
||||||
|
f"Listening for {sum(1 for _ in self.walk_commands())} commands\n"
|
||||||
|
"------------------------------\n"
|
||||||
|
f"Logged in to {len(self.guilds)} guilds on shard {self.shard_id} of {self.shard_count}\n"
|
||||||
|
"Ready to take commands!\n",
|
||||||
|
extra={'action': 'Ready'}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_context(self, origin, /, *, cls=MISSING):
|
||||||
|
if cls is MISSING:
|
||||||
|
cls = LionContext
|
||||||
|
ctx = await super().get_context(origin, cls=cls)
|
||||||
|
context.set(ctx)
|
||||||
|
return ctx
|
||||||
|
|
||||||
|
async def on_command(self, ctx: LionContext):
|
||||||
|
logger.info(
|
||||||
|
f"Executing command '{ctx.command.qualified_name}' "
|
||||||
|
f"(from module '{ctx.cog.qualified_name if ctx.cog else 'None'}') "
|
||||||
|
f"with interaction: {ctx.interaction.data if ctx.interaction else None}",
|
||||||
|
extra={'with_ctx': True}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_command_error(self, ctx, exception):
|
||||||
|
# TODO: Some of these could have more user-feedback
|
||||||
|
logger.debug(f"Handling command error for {ctx}: {exception}")
|
||||||
|
if isinstance(ctx.command, HybridCommand) and ctx.command.app_command:
|
||||||
|
cmd_str = ctx.command.app_command.to_dict(self.tree)
|
||||||
|
else:
|
||||||
|
cmd_str = str(ctx.command)
|
||||||
|
try:
|
||||||
|
raise exception
|
||||||
|
except (HybridCommandError, CommandInvokeError, appCommandInvokeError):
|
||||||
|
try:
|
||||||
|
if isinstance(exception.original, (HybridCommandError, CommandInvokeError, appCommandInvokeError)):
|
||||||
|
original = exception.original.original
|
||||||
|
raise original
|
||||||
|
else:
|
||||||
|
original = exception.original
|
||||||
|
raise original
|
||||||
|
except HandledException:
|
||||||
|
pass
|
||||||
|
except TransformerError as e:
|
||||||
|
msg = str(e)
|
||||||
|
if msg:
|
||||||
|
try:
|
||||||
|
await ctx.error_reply(msg)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
logger.debug(
|
||||||
|
f"Caught a transformer error: {repr(e)}",
|
||||||
|
extra={'action': 'BotError', 'with_ctx': True}
|
||||||
|
)
|
||||||
|
except SafeCancellation:
|
||||||
|
if original.msg:
|
||||||
|
try:
|
||||||
|
await ctx.error_reply(original.msg)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
logger.debug(
|
||||||
|
f"Caught a safe cancellation: {original.details}",
|
||||||
|
extra={'action': 'BotError', 'with_ctx': True}
|
||||||
|
)
|
||||||
|
except discord.Forbidden:
|
||||||
|
# Unknown uncaught Forbidden
|
||||||
|
try:
|
||||||
|
# Attempt a general error reply
|
||||||
|
await ctx.reply("I don't have enough channel or server permissions to complete that command here!")
|
||||||
|
except Exception:
|
||||||
|
# We can't send anything at all. Exit quietly, but log.
|
||||||
|
logger.warning(
|
||||||
|
f"Caught an unhandled 'Forbidden' while executing: {cmd_str}",
|
||||||
|
exc_info=True,
|
||||||
|
extra={'action': 'BotError', 'with_ctx': True}
|
||||||
|
)
|
||||||
|
except discord.HTTPException:
|
||||||
|
logger.error(
|
||||||
|
f"Caught an unhandled 'HTTPException' while executing: {cmd_str}",
|
||||||
|
exc_info=True,
|
||||||
|
extra={'action': 'BotError', 'with_ctx': True}
|
||||||
|
)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(
|
||||||
|
f"Caught an unknown CommandInvokeError while executing: {cmd_str}",
|
||||||
|
extra={'action': 'BotError', 'with_ctx': True}
|
||||||
|
)
|
||||||
|
|
||||||
|
error_embed = discord.Embed(
|
||||||
|
title="Something went wrong!",
|
||||||
|
colour=discord.Colour.dark_red()
|
||||||
|
)
|
||||||
|
error_embed.description = (
|
||||||
|
"An unexpected error occurred while processing your command!\n"
|
||||||
|
"Our development team has been notified, and the issue will be addressed soon.\n"
|
||||||
|
)
|
||||||
|
details = {}
|
||||||
|
details['error'] = f"`{repr(e)}`"
|
||||||
|
if ctx.interaction:
|
||||||
|
details['interactionid'] = f"`{ctx.interaction.id}`"
|
||||||
|
if ctx.command:
|
||||||
|
details['cmd'] = f"`{ctx.command.qualified_name}`"
|
||||||
|
if ctx.author:
|
||||||
|
details['author'] = f"`{ctx.author.id}` -- `{ctx.author}`"
|
||||||
|
if ctx.guild:
|
||||||
|
details['guild'] = f"`{ctx.guild.id}` -- `{ctx.guild.name}`"
|
||||||
|
details['my_guild_perms'] = f"`{ctx.guild.me.guild_permissions.value}`"
|
||||||
|
if ctx.author:
|
||||||
|
ownerstr = ' (owner)' if ctx.author.id == ctx.guild.owner_id else ''
|
||||||
|
details['author_guild_perms'] = f"`{ctx.author.guild_permissions.value}{ownerstr}`"
|
||||||
|
if ctx.channel.type is discord.enums.ChannelType.private:
|
||||||
|
details['channel'] = "`Direct Message`"
|
||||||
|
elif ctx.channel:
|
||||||
|
details['channel'] = f"`{ctx.channel.id}` -- `{ctx.channel.name}`"
|
||||||
|
details['my_channel_perms'] = f"`{ctx.channel.permissions_for(ctx.guild.me).value}`"
|
||||||
|
if ctx.author:
|
||||||
|
details['author_channel_perms'] = f"`{ctx.channel.permissions_for(ctx.author).value}`"
|
||||||
|
details['shard'] = f"`{self.shardname}`"
|
||||||
|
details['log_stack'] = f"`{log_action_stack.get()}`"
|
||||||
|
|
||||||
|
table = '\n'.join(tabulate(*details.items()))
|
||||||
|
error_embed.add_field(name='Details', value=table)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await ctx.error_reply(embed=error_embed)
|
||||||
|
except discord.HTTPException:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
exception.original = HandledException(exception.original)
|
||||||
|
except CheckFailure as e:
|
||||||
|
logger.debug(
|
||||||
|
f"Command failed check: {e}: {e.args}",
|
||||||
|
extra={'action': 'BotError', 'with_ctx': True}
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await ctx.error_reply(str(e))
|
||||||
|
except discord.HTTPException:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
# Completely unknown exception outside of command invocation!
|
||||||
|
# Something is very wrong here, don't attempt user interaction.
|
||||||
|
logger.exception(
|
||||||
|
f"Caught an unknown top-level exception while executing: {cmd_str}",
|
||||||
|
extra={'action': 'BotError', 'with_ctx': True}
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_command(self, command):
|
||||||
|
if not hasattr(command, '_placeholder_group_'):
|
||||||
|
super().add_command(command)
|
||||||
|
|
||||||
|
def request_chunking_for(self, guild):
|
||||||
|
if not guild.chunked:
|
||||||
|
return asyncio.create_task(
|
||||||
|
self._connection.chunk_guild(guild, wait=False, cache=True),
|
||||||
|
name=f"Background chunkreq for {guild.id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_interaction(self, interaction: discord.Interaction):
|
||||||
|
"""
|
||||||
|
Adds the interaction author to guild cache if appropriate.
|
||||||
|
|
||||||
|
This gets run a little bit late, so it is possible the interaction gets handled
|
||||||
|
without the author being in case.
|
||||||
|
"""
|
||||||
|
guild = interaction.guild
|
||||||
|
user = interaction.user
|
||||||
|
if guild is not None and user is not None and isinstance(user, discord.Member):
|
||||||
|
if not guild.get_member(user.id):
|
||||||
|
guild._add_member(user)
|
||||||
|
if guild is not None and not guild.chunked:
|
||||||
|
# Getting an interaction in the guild is a good enough reason to request chunking
|
||||||
|
logger.info(
|
||||||
|
f"Unchunked guild <gid: {guild.id}> requesting chunking after interaction."
|
||||||
|
)
|
||||||
|
self.request_chunking_for(guild)
|
||||||
58
src/meta/LionCog.py
Normal file
58
src/meta/LionCog.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from discord.ext.commands import Cog
|
||||||
|
from discord.ext import commands as cmds
|
||||||
|
|
||||||
|
|
||||||
|
class LionCog(Cog):
|
||||||
|
# A set of other cogs that this cog depends on
|
||||||
|
depends_on: set['LionCog'] = set()
|
||||||
|
_placeholder_groups_: set[str]
|
||||||
|
|
||||||
|
def __init_subclass__(cls, **kwargs):
|
||||||
|
super().__init_subclass__(**kwargs)
|
||||||
|
|
||||||
|
cls._placeholder_groups_ = set()
|
||||||
|
|
||||||
|
for base in reversed(cls.__mro__):
|
||||||
|
for elem, value in base.__dict__.items():
|
||||||
|
if isinstance(value, cmds.HybridGroup) and hasattr(value, '_placeholder_group_'):
|
||||||
|
cls._placeholder_groups_.add(value.name)
|
||||||
|
|
||||||
|
def __new__(cls, *args: Any, **kwargs: Any):
|
||||||
|
# Patch to ensure no placeholder groups are in the command list
|
||||||
|
self = super().__new__(cls)
|
||||||
|
self.__cog_commands__ = [
|
||||||
|
command for command in self.__cog_commands__ if command.name not in cls._placeholder_groups_
|
||||||
|
]
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def _inject(self, bot, *args, **kwargs):
|
||||||
|
if self.depends_on:
|
||||||
|
not_found = {cogname for cogname in self.depends_on if not bot.get_cog(cogname)}
|
||||||
|
raise ValueError(f"Could not load cog '{self.__class__.__name__}', dependencies missing: {not_found}")
|
||||||
|
|
||||||
|
return await super()._inject(bot, *args, *kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def placeholder_group(cls, group: cmds.HybridGroup):
|
||||||
|
group._placeholder_group_ = True
|
||||||
|
return group
|
||||||
|
|
||||||
|
def crossload_group(self, placeholder_group: cmds.HybridGroup, target_group: cmds.HybridGroup):
|
||||||
|
"""
|
||||||
|
Crossload a placeholder group's commands into the target group
|
||||||
|
"""
|
||||||
|
if not isinstance(placeholder_group, cmds.HybridGroup) or not isinstance(target_group, cmds.HybridGroup):
|
||||||
|
raise ValueError("Placeholder and target groups my be HypridGroups.")
|
||||||
|
if placeholder_group.name not in self._placeholder_groups_:
|
||||||
|
raise ValueError("Placeholder group was not registered! Stopping to avoid duplicates.")
|
||||||
|
if target_group.app_command is None:
|
||||||
|
raise ValueError("Target group has no app_command to crossload into.")
|
||||||
|
|
||||||
|
for command in placeholder_group.commands:
|
||||||
|
placeholder_group.remove_command(command.name)
|
||||||
|
target_group.remove_command(command.name)
|
||||||
|
acmd = command.app_command._copy_with(parent=target_group.app_command, binding=self)
|
||||||
|
command.app_command = acmd
|
||||||
|
target_group.add_command(command)
|
||||||
195
src/meta/LionContext.py
Normal file
195
src/meta/LionContext.py
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
import types
|
||||||
|
import logging
|
||||||
|
from collections import namedtuple
|
||||||
|
from typing import Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.enums import ChannelType
|
||||||
|
from discord.ext.commands import Context
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .LionBot import LionBot
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Stuff that might be useful to implement (see cmdClient):
|
||||||
|
sent_messages cache
|
||||||
|
tasks cache
|
||||||
|
error reply
|
||||||
|
usage
|
||||||
|
interaction cache
|
||||||
|
View cache?
|
||||||
|
setting access
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
FlatContext = namedtuple(
|
||||||
|
'FlatContext',
|
||||||
|
('message',
|
||||||
|
'interaction',
|
||||||
|
'guild',
|
||||||
|
'author',
|
||||||
|
'channel',
|
||||||
|
'alias',
|
||||||
|
'prefix',
|
||||||
|
'failed')
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LionContext(Context['LionBot']):
|
||||||
|
"""
|
||||||
|
Represents the context a command is invoked under.
|
||||||
|
|
||||||
|
Extends Context to add Lion-specific methods and attributes.
|
||||||
|
Also adds several contextual wrapped utilities for simpler user during command invocation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
parts = {}
|
||||||
|
if self.interaction is not None:
|
||||||
|
parts['iid'] = self.interaction.id
|
||||||
|
parts['itype'] = f"\"{self.interaction.type.name}\""
|
||||||
|
if self.message is not None:
|
||||||
|
parts['mid'] = self.message.id
|
||||||
|
if self.author is not None:
|
||||||
|
parts['uid'] = self.author.id
|
||||||
|
parts['uname'] = f"\"{self.author.name}\""
|
||||||
|
if self.channel is not None:
|
||||||
|
parts['cid'] = self.channel.id
|
||||||
|
if self.channel.type is ChannelType.private:
|
||||||
|
parts['cname'] = f"\"{self.channel.recipient}\""
|
||||||
|
else:
|
||||||
|
parts['cname'] = f"\"{self.channel.name}\""
|
||||||
|
if self.guild is not None:
|
||||||
|
parts['gid'] = self.guild.id
|
||||||
|
parts['gname'] = f"\"{self.guild.name}\""
|
||||||
|
if self.command is not None:
|
||||||
|
parts['cmd'] = f"\"{self.command.qualified_name}\""
|
||||||
|
if self.invoked_with is not None:
|
||||||
|
parts['alias'] = f"\"{self.invoked_with}\""
|
||||||
|
if self.command_failed:
|
||||||
|
parts['failed'] = self.command_failed
|
||||||
|
|
||||||
|
return "<LionContext: {}>".format(
|
||||||
|
' '.join(f"{name}={value}" for name, value in parts.items())
|
||||||
|
)
|
||||||
|
|
||||||
|
def flatten(self):
|
||||||
|
"""Flat pure-data context information, for caching and logging."""
|
||||||
|
return FlatContext(
|
||||||
|
self.message.id,
|
||||||
|
self.interaction.id if self.interaction is not None else None,
|
||||||
|
self.guild.id if self.guild is not None else None,
|
||||||
|
self.author.id if self.author is not None else None,
|
||||||
|
self.channel.id if self.channel is not None else None,
|
||||||
|
self.invoked_with,
|
||||||
|
self.prefix,
|
||||||
|
self.command_failed
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def util(cls, util_func):
|
||||||
|
"""
|
||||||
|
Decorator to make a utility function available as a Context instance method.
|
||||||
|
"""
|
||||||
|
setattr(cls, util_func.__name__, util_func)
|
||||||
|
logger.debug(f"Attached context utility function: {util_func.__name__}")
|
||||||
|
return util_func
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def wrappable_util(cls, util_func):
|
||||||
|
"""
|
||||||
|
Decorator to add a Wrappable utility function as a Context instance method.
|
||||||
|
"""
|
||||||
|
wrapped = Wrappable(util_func)
|
||||||
|
setattr(cls, util_func.__name__, wrapped)
|
||||||
|
logger.debug(f"Attached wrappable context utility function: {util_func.__name__}")
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
async def error_reply(self, content: Optional[str] = None, **kwargs):
|
||||||
|
if content and 'embed' not in kwargs:
|
||||||
|
embed = discord.Embed(
|
||||||
|
colour=discord.Colour.red(),
|
||||||
|
description=content
|
||||||
|
)
|
||||||
|
kwargs['embed'] = embed
|
||||||
|
content = None
|
||||||
|
|
||||||
|
# Expect this may be run in highly unusual circumstances.
|
||||||
|
# This should never error, or at least handle all errors.
|
||||||
|
if self.interaction:
|
||||||
|
kwargs.setdefault('ephemeral', True)
|
||||||
|
try:
|
||||||
|
await self.reply(content=content, **kwargs)
|
||||||
|
except discord.HTTPException:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Unknown exception in 'error_reply'.",
|
||||||
|
extra={'action': 'error_reply', 'ctx': repr(self), 'with_ctx': True}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Wrappable:
|
||||||
|
__slots__ = ('_func', 'wrappers')
|
||||||
|
|
||||||
|
def __init__(self, func):
|
||||||
|
self._func = func
|
||||||
|
self.wrappers = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def __name__(self):
|
||||||
|
return self._func.__name__
|
||||||
|
|
||||||
|
def add_wrapper(self, func, name=None):
|
||||||
|
self.wrappers = self.wrappers or {}
|
||||||
|
name = name or func.__name__
|
||||||
|
self.wrappers[name] = func
|
||||||
|
logger.debug(
|
||||||
|
f"Added wrapper '{name}' to Wrappable '{self._func.__name__}'.",
|
||||||
|
extra={'action': "Wrap Util"}
|
||||||
|
)
|
||||||
|
|
||||||
|
def remove_wrapper(self, name):
|
||||||
|
if not self.wrappers or name not in self.wrappers:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot remove non-existent wrapper '{name}' from Wrappable '{self._func.__name__}'"
|
||||||
|
)
|
||||||
|
self.wrappers.pop(name)
|
||||||
|
logger.debug(
|
||||||
|
f"Removed wrapper '{name}' from Wrappable '{self._func.__name__}'.",
|
||||||
|
extra={'action': "Unwrap Util"}
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
if self.wrappers:
|
||||||
|
return self._wrapped(iter(self.wrappers.values()))(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return self._func(*args, **kwargs)
|
||||||
|
|
||||||
|
def _wrapped(self, iter_wraps):
|
||||||
|
next_wrap = next(iter_wraps, None)
|
||||||
|
if next_wrap:
|
||||||
|
def _func(*args, **kwargs):
|
||||||
|
return next_wrap(self._wrapped(iter_wraps), *args, **kwargs)
|
||||||
|
else:
|
||||||
|
_func = self._func
|
||||||
|
return _func
|
||||||
|
|
||||||
|
def __get__(self, instance, cls=None):
|
||||||
|
if instance is None:
|
||||||
|
return self
|
||||||
|
else:
|
||||||
|
return types.MethodType(self, instance)
|
||||||
|
|
||||||
|
|
||||||
|
LionContext.reply = Wrappable(LionContext.reply)
|
||||||
|
|
||||||
|
|
||||||
|
# @LionContext.reply.add_wrapper
|
||||||
|
# async def think(func, ctx, *args, **kwargs):
|
||||||
|
# await ctx.channel.send("thinking")
|
||||||
|
# await func(ctx, *args, **kwargs)
|
||||||
148
src/meta/LionTree.py
Normal file
148
src/meta/LionTree.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord import Interaction
|
||||||
|
from discord.app_commands import CommandTree
|
||||||
|
from discord.app_commands.errors import AppCommandError, CommandInvokeError
|
||||||
|
from discord.enums import InteractionType
|
||||||
|
from discord.app_commands.namespace import Namespace
|
||||||
|
|
||||||
|
from utils.lib import tabulate
|
||||||
|
|
||||||
|
from .logger import logging_context, set_logging_context, log_wrap, log_action_stack
|
||||||
|
from .errors import SafeCancellation
|
||||||
|
from .config import conf
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LionTree(CommandTree):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._call_tasks = set()
|
||||||
|
|
||||||
|
async def on_error(self, interaction: discord.Interaction, error) -> None:
|
||||||
|
try:
|
||||||
|
if isinstance(error, CommandInvokeError):
|
||||||
|
raise error.original
|
||||||
|
else:
|
||||||
|
raise error
|
||||||
|
except SafeCancellation:
|
||||||
|
# Assume this has already been handled
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
logger.exception(f"Unhandled exception in interaction: {interaction}", extra={'action': 'TreeError'})
|
||||||
|
if interaction.type is not InteractionType.autocomplete:
|
||||||
|
embed = self.bugsplat(interaction, error)
|
||||||
|
await self.error_reply(interaction, embed)
|
||||||
|
|
||||||
|
async def error_reply(self, interaction, embed):
|
||||||
|
if not interaction.is_expired():
|
||||||
|
try:
|
||||||
|
if interaction.response.is_done():
|
||||||
|
await interaction.followup.send(embed=embed, ephemeral=True)
|
||||||
|
else:
|
||||||
|
await interaction.response.send_message(embed=embed, ephemeral=True)
|
||||||
|
except discord.HTTPException:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def bugsplat(self, interaction, e):
|
||||||
|
error_embed = discord.Embed(title="Something went wrong!", colour=discord.Colour.red())
|
||||||
|
error_embed.description = (
|
||||||
|
"An unexpected error occurred during this interaction!\n"
|
||||||
|
"Our development team has been notified, and the issue will be addressed soon.\n"
|
||||||
|
)
|
||||||
|
details = {}
|
||||||
|
details['error'] = f"`{repr(e)}`"
|
||||||
|
details['interactionid'] = f"`{interaction.id}`"
|
||||||
|
details['interactiontype'] = f"`{interaction.type}`"
|
||||||
|
if interaction.command:
|
||||||
|
details['cmd'] = f"`{interaction.command.qualified_name}`"
|
||||||
|
if interaction.user:
|
||||||
|
details['user'] = f"`{interaction.user.id}` -- `{interaction.user}`"
|
||||||
|
if interaction.guild:
|
||||||
|
details['guild'] = f"`{interaction.guild.id}` -- `{interaction.guild.name}`"
|
||||||
|
details['my_guild_perms'] = f"`{interaction.guild.me.guild_permissions.value}`"
|
||||||
|
if interaction.user:
|
||||||
|
ownerstr = ' (owner)' if interaction.user.id == interaction.guild.owner_id else ''
|
||||||
|
details['user_guild_perms'] = f"`{interaction.user.guild_permissions.value}{ownerstr}`"
|
||||||
|
if interaction.channel.type is discord.enums.ChannelType.private:
|
||||||
|
details['channel'] = "`Direct Message`"
|
||||||
|
elif interaction.channel:
|
||||||
|
details['channel'] = f"`{interaction.channel.id}` -- `{interaction.channel.name}`"
|
||||||
|
details['my_channel_perms'] = f"`{interaction.channel.permissions_for(interaction.guild.me).value}`"
|
||||||
|
if interaction.user:
|
||||||
|
details['user_channel_perms'] = f"`{interaction.channel.permissions_for(interaction.user).value}`"
|
||||||
|
details['shard'] = f"`{interaction.client.shardname}`"
|
||||||
|
details['log_stack'] = f"`{log_action_stack.get()}`"
|
||||||
|
|
||||||
|
table = '\n'.join(tabulate(*details.items()))
|
||||||
|
error_embed.add_field(name='Details', value=table)
|
||||||
|
return error_embed
|
||||||
|
|
||||||
|
def _from_interaction(self, interaction: Interaction) -> None:
|
||||||
|
@log_wrap(context=f"iid: {interaction.id}", isolate=False)
|
||||||
|
async def wrapper():
|
||||||
|
try:
|
||||||
|
await self._call(interaction)
|
||||||
|
except AppCommandError as e:
|
||||||
|
await self._dispatch_error(interaction, e)
|
||||||
|
|
||||||
|
task = self.client.loop.create_task(wrapper(), name='CommandTree-invoker')
|
||||||
|
self._call_tasks.add(task)
|
||||||
|
task.add_done_callback(lambda fut: self._call_tasks.discard(fut))
|
||||||
|
|
||||||
|
async def _call(self, interaction):
|
||||||
|
if not await self.interaction_check(interaction):
|
||||||
|
interaction.command_failed = True
|
||||||
|
return
|
||||||
|
|
||||||
|
data = interaction.data # type: ignore
|
||||||
|
type = data.get('type', 1)
|
||||||
|
if type != 1:
|
||||||
|
# Context menu command...
|
||||||
|
await self._call_context_menu(interaction, data, type)
|
||||||
|
return
|
||||||
|
|
||||||
|
command, options = self._get_app_command_options(data)
|
||||||
|
|
||||||
|
# Pre-fill the cached slot to prevent re-computation
|
||||||
|
interaction._cs_command = command
|
||||||
|
|
||||||
|
# At this point options refers to the arguments of the command
|
||||||
|
# and command refers to the class type we care about
|
||||||
|
namespace = Namespace(interaction, data.get('resolved', {}), options)
|
||||||
|
|
||||||
|
# Same pre-fill as above
|
||||||
|
interaction._cs_namespace = namespace
|
||||||
|
|
||||||
|
# Auto complete handles the namespace differently... so at this point this is where we decide where that is.
|
||||||
|
if interaction.type is InteractionType.autocomplete:
|
||||||
|
set_logging_context(action=f"Acmp {command.qualified_name}")
|
||||||
|
focused = next((opt['name'] for opt in options if opt.get('focused')), None)
|
||||||
|
if focused is None:
|
||||||
|
raise AppCommandError(
|
||||||
|
'This should not happen, but there is no focused element. This is a Discord bug.'
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await command._invoke_autocomplete(interaction, focused, namespace)
|
||||||
|
except Exception as e:
|
||||||
|
await self.on_error(interaction, e)
|
||||||
|
return
|
||||||
|
|
||||||
|
set_logging_context(action=f"Run {command.qualified_name}")
|
||||||
|
logger.debug(f"Running command '{command.qualified_name}': {command.to_dict(self)}")
|
||||||
|
try:
|
||||||
|
await command._invoke_with_namespace(interaction, namespace)
|
||||||
|
except AppCommandError as e:
|
||||||
|
interaction.command_failed = True
|
||||||
|
await command._invoke_error_handlers(interaction, e)
|
||||||
|
await self.on_error(interaction, e)
|
||||||
|
else:
|
||||||
|
if not interaction.command_failed:
|
||||||
|
self.client.dispatch('app_command_completion', interaction, command)
|
||||||
|
finally:
|
||||||
|
if interaction.command_failed:
|
||||||
|
logger.debug("Command completed with errors.")
|
||||||
|
else:
|
||||||
|
logger.debug("Command completed without errors.")
|
||||||
15
src/meta/__init__.py
Normal file
15
src/meta/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
from .LionBot import LionBot
|
||||||
|
from .LionCog import LionCog
|
||||||
|
from .LionContext import LionContext
|
||||||
|
from .LionTree import LionTree
|
||||||
|
|
||||||
|
from .logger import logging_context, log_wrap, log_action_stack, log_context, log_app
|
||||||
|
from .config import conf, configEmoji
|
||||||
|
from .args import args
|
||||||
|
from .app import appname, appname_from_shard, shard_from_appname
|
||||||
|
from .errors import HandledException, UserInputError, ResponseTimedOut, SafeCancellation, UserCancelled
|
||||||
|
from .context import context, ctx_bot
|
||||||
|
|
||||||
|
from . import sharding
|
||||||
|
from . import logger
|
||||||
|
from . import app
|
||||||
32
src/meta/app.py
Normal file
32
src/meta/app.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""
|
||||||
|
appname: str
|
||||||
|
The base identifer for this application.
|
||||||
|
This identifies which services the app offers.
|
||||||
|
shardname: str
|
||||||
|
The specific name of the running application.
|
||||||
|
Only one process should be connecteded with a given appname.
|
||||||
|
For the bot apps, usually specifies the shard id and shard number.
|
||||||
|
"""
|
||||||
|
# TODO: Find a better schema for these. We use appname for shard_talk, do we need it for data?
|
||||||
|
|
||||||
|
from . import sharding, conf
|
||||||
|
from .logger import log_app
|
||||||
|
from .args import args
|
||||||
|
|
||||||
|
|
||||||
|
appname = conf.data['appid']
|
||||||
|
appid = appname # backwards compatibility
|
||||||
|
|
||||||
|
|
||||||
|
def appname_from_shard(shardid):
|
||||||
|
appname = f"{conf.data['appid']}_{sharding.shard_count:02}_{shardid:02}"
|
||||||
|
return appname
|
||||||
|
|
||||||
|
|
||||||
|
def shard_from_appname(appname: str):
|
||||||
|
return int(appname.rsplit('_', maxsplit=1)[-1])
|
||||||
|
|
||||||
|
|
||||||
|
shardname = appname_from_shard(sharding.shard_number)
|
||||||
|
|
||||||
|
log_app.set(shardname)
|
||||||
35
src/meta/args.py
Normal file
35
src/meta/args.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
import argparse
|
||||||
|
|
||||||
|
from constants import CONFIG_FILE
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# Parsed commandline arguments
|
||||||
|
# ------------------------------
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
'--conf',
|
||||||
|
dest='config',
|
||||||
|
default=CONFIG_FILE,
|
||||||
|
help="Path to configuration file."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--shard',
|
||||||
|
dest='shard',
|
||||||
|
default=None,
|
||||||
|
type=int,
|
||||||
|
help="Shard number to run, if applicable."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--host',
|
||||||
|
dest='host',
|
||||||
|
default='127.0.0.1',
|
||||||
|
help="IP address to run the app listener on."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--port',
|
||||||
|
dest='port',
|
||||||
|
default='5001',
|
||||||
|
help="Port to run the app listener on."
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
146
src/meta/config.py
Normal file
146
src/meta/config.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
from discord import PartialEmoji
|
||||||
|
import configparser as cfgp
|
||||||
|
|
||||||
|
from .args import args
|
||||||
|
|
||||||
|
shard_number = args.shard
|
||||||
|
|
||||||
|
class configEmoji(PartialEmoji):
|
||||||
|
__slots__ = ('fallback',)
|
||||||
|
|
||||||
|
def __init__(self, *args, fallback=None, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.fallback = fallback
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_str(cls, emojistr: str):
|
||||||
|
"""
|
||||||
|
Parses emoji strings of one of the following forms
|
||||||
|
`<a:name:id> or fallback`
|
||||||
|
`<:name:id> or fallback`
|
||||||
|
`<a:name:id>`
|
||||||
|
`<:name:id>`
|
||||||
|
"""
|
||||||
|
splits = emojistr.rsplit(' or ', maxsplit=1)
|
||||||
|
|
||||||
|
fallback = splits[1] if len(splits) > 1 else None
|
||||||
|
emojistr = splits[0].strip('<> ')
|
||||||
|
animated, name, id = emojistr.split(':')
|
||||||
|
return cls(
|
||||||
|
name=name,
|
||||||
|
fallback=PartialEmoji(name=fallback) if fallback is not None else None,
|
||||||
|
animated=bool(animated),
|
||||||
|
id=int(id) if id else None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MapDotProxy:
|
||||||
|
"""
|
||||||
|
Allows dot access to an underlying Mappable object.
|
||||||
|
"""
|
||||||
|
__slots__ = ("_map", "_converter")
|
||||||
|
|
||||||
|
def __init__(self, mappable, converter=None):
|
||||||
|
self._map = mappable
|
||||||
|
self._converter = converter
|
||||||
|
|
||||||
|
def __getattribute__(self, key):
|
||||||
|
_map = object.__getattribute__(self, '_map')
|
||||||
|
if key == '_map':
|
||||||
|
return _map
|
||||||
|
if key in _map:
|
||||||
|
_converter = object.__getattribute__(self, '_converter')
|
||||||
|
if _converter:
|
||||||
|
return _converter(_map[key])
|
||||||
|
else:
|
||||||
|
return _map[key]
|
||||||
|
else:
|
||||||
|
return object.__getattribute__(_map, key)
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return self._map.__getitem__(key)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigParser(cfgp.ConfigParser):
|
||||||
|
"""
|
||||||
|
Extension of base ConfigParser allowing optional
|
||||||
|
section option retrieval without defaults.
|
||||||
|
"""
|
||||||
|
def options(self, section, no_defaults=False, **kwargs):
|
||||||
|
if no_defaults:
|
||||||
|
try:
|
||||||
|
return list(self._sections[section].keys())
|
||||||
|
except KeyError:
|
||||||
|
raise cfgp.NoSectionError(section)
|
||||||
|
else:
|
||||||
|
return super().options(section, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class Conf:
|
||||||
|
def __init__(self, configfile, section_name="DEFAULT"):
|
||||||
|
self.configfile = configfile
|
||||||
|
|
||||||
|
self.config = ConfigParser(
|
||||||
|
converters={
|
||||||
|
"intlist": self._getintlist,
|
||||||
|
"list": self._getlist,
|
||||||
|
"emoji": configEmoji.from_str,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(configfile) as conff:
|
||||||
|
# Opening with read_file mainly to ensure the file exists
|
||||||
|
self.config.read_file(conff)
|
||||||
|
|
||||||
|
self.section_name = section_name if section_name in self.config else 'DEFAULT'
|
||||||
|
|
||||||
|
self.default = self.config["DEFAULT"]
|
||||||
|
self.section = MapDotProxy(self.config[self.section_name])
|
||||||
|
self.bot = self.section
|
||||||
|
|
||||||
|
# Config file recursion, read in configuration files specified in every "ALSO_READ" key.
|
||||||
|
more_to_read = self.section.getlist("ALSO_READ", [])
|
||||||
|
read = set()
|
||||||
|
while more_to_read:
|
||||||
|
to_read = more_to_read.pop(0)
|
||||||
|
read.add(to_read)
|
||||||
|
self.config.read(to_read)
|
||||||
|
new_paths = [path for path in self.section.getlist("ALSO_READ", [])
|
||||||
|
if path not in read and path not in more_to_read]
|
||||||
|
more_to_read.extend(new_paths)
|
||||||
|
|
||||||
|
self.emojis = MapDotProxy(
|
||||||
|
self.config['EMOJIS'] if 'EMOJIS' in self.config else self.section,
|
||||||
|
converter=configEmoji.from_str
|
||||||
|
)
|
||||||
|
|
||||||
|
global conf
|
||||||
|
conf = self
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return self.section[key].strip()
|
||||||
|
|
||||||
|
def __getattr__(self, section):
|
||||||
|
name = section.upper()
|
||||||
|
shard_name = f"{name}-{shard_number}"
|
||||||
|
if shard_name in self.config:
|
||||||
|
return self.config[shard_name]
|
||||||
|
else:
|
||||||
|
return self.config[name]
|
||||||
|
|
||||||
|
def get(self, name, fallback=None):
|
||||||
|
result = self.section.get(name, fallback)
|
||||||
|
return result.strip() if result else result
|
||||||
|
|
||||||
|
def _getintlist(self, value):
|
||||||
|
return [int(item.strip()) for item in value.split(',')]
|
||||||
|
|
||||||
|
def _getlist(self, value):
|
||||||
|
return [item.strip() for item in value.split(',')]
|
||||||
|
|
||||||
|
def write(self):
|
||||||
|
with open(self.configfile, 'w') as conffile:
|
||||||
|
self.config.write(conffile)
|
||||||
|
|
||||||
|
|
||||||
|
conf = Conf(args.config, 'BOT')
|
||||||
20
src/meta/context.py
Normal file
20
src/meta/context.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
"""
|
||||||
|
Namespace for various global context variables.
|
||||||
|
Allows asyncio callbacks to accurately retrieve information about the current state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
from contextvars import ContextVar
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .LionBot import LionBot
|
||||||
|
from .LionContext import LionContext
|
||||||
|
|
||||||
|
|
||||||
|
# Contains the current command context, if applicable
|
||||||
|
context: ContextVar[Optional['LionContext']] = ContextVar('context', default=None)
|
||||||
|
|
||||||
|
# Contains the current LionBot instance
|
||||||
|
ctx_bot: ContextVar[Optional['LionBot']] = ContextVar('bot', default=None)
|
||||||
64
src/meta/errors.py
Normal file
64
src/meta/errors.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
from typing import Optional
|
||||||
|
from string import Template
|
||||||
|
|
||||||
|
|
||||||
|
class SafeCancellation(Exception):
|
||||||
|
"""
|
||||||
|
Raised to safely cancel execution of the current operation.
|
||||||
|
|
||||||
|
If not caught, is expected to be propagated to the Tree and safely ignored there.
|
||||||
|
If a `msg` is provided, a context-aware error handler should catch and send the message to the user.
|
||||||
|
The error handler should then set the `msg` to None, to avoid double handling.
|
||||||
|
Debugging information should go in `details`, to be logged by a top-level error handler.
|
||||||
|
"""
|
||||||
|
default_message = ""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def msg(self):
|
||||||
|
return self._msg if self._msg is not None else self.default_message
|
||||||
|
|
||||||
|
def __init__(self, _msg: Optional[str] = None, details: Optional[str] = None, **kwargs):
|
||||||
|
self._msg: Optional[str] = _msg
|
||||||
|
self.details: str = details if details is not None else self.msg
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class UserInputError(SafeCancellation):
|
||||||
|
"""
|
||||||
|
A SafeCancellation induced from unparseable user input.
|
||||||
|
"""
|
||||||
|
default_message = "Could not understand your input."
|
||||||
|
|
||||||
|
@property
|
||||||
|
def msg(self):
|
||||||
|
return Template(self._msg).substitute(**self.info) if self._msg is not None else self.default_message
|
||||||
|
|
||||||
|
def __init__(self, _msg: Optional[str] = None, info: dict[str, str] = {}, **kwargs):
|
||||||
|
self.info = info
|
||||||
|
super().__init__(_msg, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class UserCancelled(SafeCancellation):
|
||||||
|
"""
|
||||||
|
A SafeCancellation induced from manual user cancellation.
|
||||||
|
|
||||||
|
Usually silent.
|
||||||
|
"""
|
||||||
|
default_msg = None
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseTimedOut(SafeCancellation):
|
||||||
|
"""
|
||||||
|
A SafeCancellation induced from a user interaction time-out.
|
||||||
|
"""
|
||||||
|
default_msg = "Session timed out waiting for input."
|
||||||
|
|
||||||
|
|
||||||
|
class HandledException(SafeCancellation):
|
||||||
|
"""
|
||||||
|
Sentinel class to indicate to error handlers that this exception has been handled.
|
||||||
|
Required because discord.ext breaks the exception stack, so we can't just catch the error in a lower handler.
|
||||||
|
"""
|
||||||
|
def __init__(self, exc=None, **kwargs):
|
||||||
|
self.exc = exc
|
||||||
|
super().__init__(**kwargs)
|
||||||
468
src/meta/logger.py
Normal file
468
src/meta/logger.py
Normal file
@@ -0,0 +1,468 @@
|
|||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
from typing import List, Optional
|
||||||
|
from logging.handlers import QueueListener, QueueHandler
|
||||||
|
import queue
|
||||||
|
import multiprocessing
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from io import StringIO
|
||||||
|
from functools import wraps
|
||||||
|
from contextvars import ContextVar
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord import Webhook, File
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from .config import conf
|
||||||
|
from . import sharding
|
||||||
|
from .context import context
|
||||||
|
from utils.lib import utc_now
|
||||||
|
from utils.ratelimits import Bucket, BucketOverFull, BucketFull
|
||||||
|
|
||||||
|
|
||||||
|
log_logger = logging.getLogger(__name__)
|
||||||
|
log_logger.propagate = False
|
||||||
|
|
||||||
|
|
||||||
|
log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT')
|
||||||
|
log_action_stack: ContextVar[tuple[str, ...]] = ContextVar('logging_action_stack', default=())
|
||||||
|
log_app: ContextVar[str] = ContextVar('logging_shard', default="SHARD {:03}".format(sharding.shard_number))
|
||||||
|
|
||||||
|
def set_logging_context(
|
||||||
|
context: Optional[str] = None,
|
||||||
|
action: Optional[str] = None,
|
||||||
|
stack: Optional[tuple[str, ...]] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Statically set the logging context variables to the given values.
|
||||||
|
|
||||||
|
If `action` is given, pushes it onto the `log_action_stack`.
|
||||||
|
"""
|
||||||
|
if context is not None:
|
||||||
|
log_context.set(context)
|
||||||
|
if action is not None or stack is not None:
|
||||||
|
astack = log_action_stack.get()
|
||||||
|
newstack = stack if stack is not None else astack
|
||||||
|
if action is not None:
|
||||||
|
newstack = (*newstack, action)
|
||||||
|
log_action_stack.set(newstack)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def logging_context(context=None, action=None, stack=None):
|
||||||
|
"""
|
||||||
|
Context manager for executing a block of code in a given logging context.
|
||||||
|
|
||||||
|
This context manager should only be used around synchronous code.
|
||||||
|
This is because async code *may* get cancelled or externally garbage collected,
|
||||||
|
in which case the finally block will be executed in the wrong context.
|
||||||
|
See https://github.com/python/cpython/issues/93740
|
||||||
|
This can be refactored nicely if this gets merged:
|
||||||
|
https://github.com/python/cpython/pull/99634
|
||||||
|
|
||||||
|
(It will not necessarily break on async code,
|
||||||
|
if the async code can be guaranteed to clean up in its own context.)
|
||||||
|
"""
|
||||||
|
if context is not None:
|
||||||
|
oldcontext = log_context.get()
|
||||||
|
log_context.set(context)
|
||||||
|
if action is not None or stack is not None:
|
||||||
|
astack = log_action_stack.get()
|
||||||
|
newstack = stack if stack is not None else astack
|
||||||
|
if action is not None:
|
||||||
|
newstack = (*newstack, action)
|
||||||
|
log_action_stack.set(newstack)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
if context is not None:
|
||||||
|
log_context.set(oldcontext)
|
||||||
|
if stack is not None or action is not None:
|
||||||
|
log_action_stack.set(astack)
|
||||||
|
|
||||||
|
|
||||||
|
def with_log_ctx(isolate=True, **kwargs):
|
||||||
|
"""
|
||||||
|
Execute a coroutine inside a given logging context.
|
||||||
|
|
||||||
|
If `isolate` is true, ensures that context does not leak
|
||||||
|
outside the coroutine.
|
||||||
|
|
||||||
|
If `isolate` is false, just statically set the context,
|
||||||
|
which will leak unless the coroutine is
|
||||||
|
called in an externally copied context.
|
||||||
|
"""
|
||||||
|
def decorator(func):
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapped(*w_args, **w_kwargs):
|
||||||
|
if isolate:
|
||||||
|
with logging_context(**kwargs):
|
||||||
|
# Task creation will synchronously copy the context
|
||||||
|
# This is gc safe
|
||||||
|
name = kwargs.get('action', f"log-wrapped-{func.__name__}")
|
||||||
|
task = asyncio.create_task(func(*w_args, **w_kwargs), name=name)
|
||||||
|
return await task
|
||||||
|
else:
|
||||||
|
# This will leak context changes
|
||||||
|
set_logging_context(**kwargs)
|
||||||
|
return await func(*w_args, **w_kwargs)
|
||||||
|
return wrapped
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
# For backwards compatibility
|
||||||
|
log_wrap = with_log_ctx
|
||||||
|
|
||||||
|
|
||||||
|
def persist_task(task_collection: set):
|
||||||
|
"""
|
||||||
|
Coroutine decorator that ensures the coroutine is scheduled as a task
|
||||||
|
and added to the given task_collection for strong reference
|
||||||
|
when it is called.
|
||||||
|
|
||||||
|
This is just a hack to handle discord.py events potentially
|
||||||
|
being unexpectedly garbage collected.
|
||||||
|
|
||||||
|
Since this also implicitly schedules the coroutine as a task when it is called,
|
||||||
|
the coroutine will also be run inside an isolated context.
|
||||||
|
"""
|
||||||
|
def decorator(coro):
|
||||||
|
@wraps(coro)
|
||||||
|
async def wrapped(*w_args, **w_kwargs):
|
||||||
|
name = f"persisted-{coro.__name__}"
|
||||||
|
task = asyncio.create_task(coro(*w_args, **w_kwargs), name=name)
|
||||||
|
task_collection.add(task)
|
||||||
|
task.add_done_callback(lambda f: task_collection.discard(f))
|
||||||
|
await task
|
||||||
|
|
||||||
|
|
||||||
|
RESET_SEQ = "\033[0m"
|
||||||
|
COLOR_SEQ = "\033[3%dm"
|
||||||
|
BOLD_SEQ = "\033[1m"
|
||||||
|
"]]]"
|
||||||
|
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
|
||||||
|
|
||||||
|
|
||||||
|
def colour_escape(fmt: str) -> str:
|
||||||
|
cmap = {
|
||||||
|
'%(black)': COLOR_SEQ % BLACK,
|
||||||
|
'%(red)': COLOR_SEQ % RED,
|
||||||
|
'%(green)': COLOR_SEQ % GREEN,
|
||||||
|
'%(yellow)': COLOR_SEQ % YELLOW,
|
||||||
|
'%(blue)': COLOR_SEQ % BLUE,
|
||||||
|
'%(magenta)': COLOR_SEQ % MAGENTA,
|
||||||
|
'%(cyan)': COLOR_SEQ % CYAN,
|
||||||
|
'%(white)': COLOR_SEQ % WHITE,
|
||||||
|
'%(reset)': RESET_SEQ,
|
||||||
|
'%(bold)': BOLD_SEQ,
|
||||||
|
}
|
||||||
|
for key, value in cmap.items():
|
||||||
|
fmt = fmt.replace(key, value)
|
||||||
|
return fmt
|
||||||
|
|
||||||
|
|
||||||
|
log_format = ('%(green)%(asctime)-19s%(reset)|%(red)%(levelname)-8s%(reset)|' +
|
||||||
|
'%(cyan)%(app)-15s%(reset)|' +
|
||||||
|
'%(cyan)%(context)-24s%(reset)|' +
|
||||||
|
'%(cyan)%(actionstr)-22s%(reset)|' +
|
||||||
|
' %(bold)%(cyan)%(name)s:%(reset)' +
|
||||||
|
' %(white)%(message)s%(ctxstr)s%(reset)')
|
||||||
|
log_format = colour_escape(log_format)
|
||||||
|
|
||||||
|
|
||||||
|
# Setup the logger
|
||||||
|
logger = logging.getLogger()
|
||||||
|
log_fmt = logging.Formatter(
|
||||||
|
fmt=log_format,
|
||||||
|
# datefmt='%Y-%m-%d %H:%M:%S'
|
||||||
|
)
|
||||||
|
logger.setLevel(logging.NOTSET)
|
||||||
|
|
||||||
|
|
||||||
|
class LessThanFilter(logging.Filter):
|
||||||
|
def __init__(self, exclusive_maximum, name=""):
|
||||||
|
super(LessThanFilter, self).__init__(name)
|
||||||
|
self.max_level = exclusive_maximum
|
||||||
|
|
||||||
|
def filter(self, record):
|
||||||
|
# non-zero return means we log this message
|
||||||
|
return 1 if record.levelno < self.max_level else 0
|
||||||
|
|
||||||
|
class ExactLevelFilter(logging.Filter):
|
||||||
|
def __init__(self, target_level, name=""):
|
||||||
|
super().__init__(name)
|
||||||
|
self.target_level = target_level
|
||||||
|
|
||||||
|
def filter(self, record):
|
||||||
|
return (record.levelno == self.target_level)
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadFilter(logging.Filter):
|
||||||
|
def __init__(self, thread_name):
|
||||||
|
super().__init__("")
|
||||||
|
self.thread = thread_name
|
||||||
|
|
||||||
|
def filter(self, record):
|
||||||
|
# non-zero return means we log this message
|
||||||
|
return 1 if record.threadName == self.thread else 0
|
||||||
|
|
||||||
|
|
||||||
|
class ContextInjection(logging.Filter):
|
||||||
|
def filter(self, record):
|
||||||
|
# These guards are to allow override through _extra
|
||||||
|
# And to ensure the injection is idempotent
|
||||||
|
if not hasattr(record, 'context'):
|
||||||
|
record.context = log_context.get()
|
||||||
|
|
||||||
|
if not hasattr(record, 'actionstr'):
|
||||||
|
action_stack = log_action_stack.get()
|
||||||
|
if hasattr(record, 'action'):
|
||||||
|
action_stack = (*action_stack, record.action)
|
||||||
|
if action_stack:
|
||||||
|
record.actionstr = ' ➔ '.join(action_stack)
|
||||||
|
else:
|
||||||
|
record.actionstr = "Unknown Action"
|
||||||
|
|
||||||
|
if not hasattr(record, 'app'):
|
||||||
|
record.app = log_app.get()
|
||||||
|
|
||||||
|
if not hasattr(record, 'ctx'):
|
||||||
|
if ctx := context.get():
|
||||||
|
record.ctx = repr(ctx)
|
||||||
|
else:
|
||||||
|
record.ctx = None
|
||||||
|
|
||||||
|
if getattr(record, 'with_ctx', False) and record.ctx:
|
||||||
|
record.ctxstr = '\n' + record.ctx
|
||||||
|
else:
|
||||||
|
record.ctxstr = ""
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
logging_handler_out = logging.StreamHandler(sys.stdout)
|
||||||
|
logging_handler_out.setLevel(logging.DEBUG)
|
||||||
|
logging_handler_out.setFormatter(log_fmt)
|
||||||
|
logging_handler_out.addFilter(ContextInjection())
|
||||||
|
logger.addHandler(logging_handler_out)
|
||||||
|
log_logger.addHandler(logging_handler_out)
|
||||||
|
|
||||||
|
logging_handler_err = logging.StreamHandler(sys.stderr)
|
||||||
|
logging_handler_err.setLevel(logging.WARNING)
|
||||||
|
logging_handler_err.setFormatter(log_fmt)
|
||||||
|
logging_handler_err.addFilter(ContextInjection())
|
||||||
|
logger.addHandler(logging_handler_err)
|
||||||
|
log_logger.addHandler(logging_handler_err)
|
||||||
|
|
||||||
|
|
||||||
|
class LocalQueueHandler(QueueHandler):
|
||||||
|
def _emit(self, record: logging.LogRecord) -> None:
|
||||||
|
# Removed the call to self.prepare(), handle task cancellation
|
||||||
|
try:
|
||||||
|
self.enqueue(record)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
self.handleError(record)
|
||||||
|
|
||||||
|
|
||||||
|
class WebHookHandler(logging.StreamHandler):
|
||||||
|
def __init__(self, webhook_url, prefix="", batch=True, loop=None):
|
||||||
|
super().__init__()
|
||||||
|
self.webhook_url = webhook_url
|
||||||
|
self.prefix = prefix
|
||||||
|
self.batched = ""
|
||||||
|
self.batch = batch
|
||||||
|
self.loop = loop
|
||||||
|
self.batch_delay = 10
|
||||||
|
self.batch_task = None
|
||||||
|
self.last_batched = None
|
||||||
|
self.waiting = []
|
||||||
|
|
||||||
|
self.bucket = Bucket(20, 40)
|
||||||
|
self.ignored = 0
|
||||||
|
|
||||||
|
self.session = None
|
||||||
|
self.webhook = None
|
||||||
|
|
||||||
|
def get_loop(self):
|
||||||
|
if self.loop is None:
|
||||||
|
self.loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(self.loop)
|
||||||
|
return self.loop
|
||||||
|
|
||||||
|
def emit(self, record):
|
||||||
|
self.format(record)
|
||||||
|
self.get_loop().call_soon_threadsafe(self._post, record)
|
||||||
|
|
||||||
|
def _post(self, record):
|
||||||
|
if self.session is None:
|
||||||
|
self.setup()
|
||||||
|
asyncio.create_task(self.post(record))
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
self.session = aiohttp.ClientSession()
|
||||||
|
self.webhook = Webhook.from_url(self.webhook_url, session=self.session)
|
||||||
|
|
||||||
|
async def post(self, record):
|
||||||
|
if record.context == 'Webhook Logger':
|
||||||
|
# Don't livelog livelog errors
|
||||||
|
# Otherwise we recurse and Cloudflare hates us
|
||||||
|
return
|
||||||
|
log_context.set("Webhook Logger")
|
||||||
|
log_action_stack.set(("Logging",))
|
||||||
|
log_app.set(record.app)
|
||||||
|
|
||||||
|
try:
|
||||||
|
timestamp = utc_now().strftime("%d/%m/%Y, %H:%M:%S")
|
||||||
|
header = f"[{record.asctime}][{record.levelname}][{record.app}][{record.actionstr}] <{record.context}>"
|
||||||
|
context = f"\n# Context: {record.ctx}" if record.ctx else ""
|
||||||
|
message = f"{header}\n{record.msg}{context}"
|
||||||
|
|
||||||
|
if len(message) > 1900:
|
||||||
|
as_file = True
|
||||||
|
else:
|
||||||
|
as_file = False
|
||||||
|
message = "```md\n{}\n```".format(message)
|
||||||
|
|
||||||
|
# Post the log message(s)
|
||||||
|
if self.batch:
|
||||||
|
if len(message) > 1500:
|
||||||
|
await self._send_batched_now()
|
||||||
|
await self._send(message, as_file=as_file)
|
||||||
|
else:
|
||||||
|
self.batched += message
|
||||||
|
if len(self.batched) + len(message) > 1500:
|
||||||
|
await self._send_batched_now()
|
||||||
|
else:
|
||||||
|
asyncio.create_task(self._schedule_batched())
|
||||||
|
else:
|
||||||
|
await self._send(message, as_file=as_file)
|
||||||
|
except Exception as ex:
|
||||||
|
print(f"Unexpected error occurred while logging to webhook: {repr(ex)}", file=sys.stderr)
|
||||||
|
|
||||||
|
async def _schedule_batched(self):
|
||||||
|
if self.batch_task is not None and not (self.batch_task.done() or self.batch_task.cancelled()):
|
||||||
|
# noop, don't reschedule if it is already scheduled
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
self.batch_task = asyncio.create_task(asyncio.sleep(self.batch_delay))
|
||||||
|
await self.batch_task
|
||||||
|
await self._send_batched()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
return
|
||||||
|
except Exception as ex:
|
||||||
|
print(f"Unexpected error occurred while scheduling batched webhook log: {repr(ex)}", file=sys.stderr)
|
||||||
|
|
||||||
|
async def _send_batched_now(self):
|
||||||
|
if self.batch_task is not None and not self.batch_task.done():
|
||||||
|
self.batch_task.cancel()
|
||||||
|
self.last_batched = None
|
||||||
|
await self._send_batched()
|
||||||
|
|
||||||
|
async def _send_batched(self):
|
||||||
|
if self.batched:
|
||||||
|
batched = self.batched
|
||||||
|
self.batched = ""
|
||||||
|
await self._send(batched)
|
||||||
|
|
||||||
|
async def _send(self, message, as_file=False):
|
||||||
|
try:
|
||||||
|
self.bucket.request()
|
||||||
|
except BucketOverFull:
|
||||||
|
# Silently ignore
|
||||||
|
self.ignored += 1
|
||||||
|
return
|
||||||
|
except BucketFull:
|
||||||
|
logger.warning(
|
||||||
|
"Can't keep up! "
|
||||||
|
f"Ignoring records on live-logger {self.webhook.id}."
|
||||||
|
)
|
||||||
|
self.ignored += 1
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
if self.ignored > 0:
|
||||||
|
logger.warning(
|
||||||
|
"Can't keep up! "
|
||||||
|
f"{self.ignored} live logging records on webhook {self.webhook.id} skipped, continuing."
|
||||||
|
)
|
||||||
|
self.ignored = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
if as_file or len(message) > 1900:
|
||||||
|
with StringIO(message) as fp:
|
||||||
|
fp.seek(0)
|
||||||
|
await self.webhook.send(
|
||||||
|
f"{self.prefix}\n`{message.splitlines()[0]}`",
|
||||||
|
file=File(fp, filename="logs.md"),
|
||||||
|
username=log_app.get()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await self.webhook.send(self.prefix + '\n' + message, username=log_app.get())
|
||||||
|
except discord.HTTPException:
|
||||||
|
logger.exception(
|
||||||
|
"Live logger errored. Slowing down live logger."
|
||||||
|
)
|
||||||
|
self.bucket.fill()
|
||||||
|
|
||||||
|
|
||||||
|
handlers = []
|
||||||
|
if webhook := conf.logging['general_log']:
|
||||||
|
handler = WebHookHandler(webhook, batch=True)
|
||||||
|
handlers.append(handler)
|
||||||
|
|
||||||
|
if webhook := conf.logging['warning_log']:
|
||||||
|
handler = WebHookHandler(webhook, prefix=conf.logging['warning_prefix'], batch=True)
|
||||||
|
handler.addFilter(ExactLevelFilter(logging.WARNING))
|
||||||
|
handler.setLevel(logging.WARNING)
|
||||||
|
handlers.append(handler)
|
||||||
|
|
||||||
|
if webhook := conf.logging['error_log']:
|
||||||
|
handler = WebHookHandler(webhook, prefix=conf.logging['error_prefix'], batch=True)
|
||||||
|
handler.setLevel(logging.ERROR)
|
||||||
|
handlers.append(handler)
|
||||||
|
|
||||||
|
if webhook := conf.logging['critical_log']:
|
||||||
|
handler = WebHookHandler(webhook, prefix=conf.logging['critical_prefix'], batch=False)
|
||||||
|
handler.setLevel(logging.CRITICAL)
|
||||||
|
handlers.append(handler)
|
||||||
|
|
||||||
|
|
||||||
|
def make_queue_handler(queue):
|
||||||
|
qhandler = QueueHandler(queue)
|
||||||
|
qhandler.setLevel(logging.INFO)
|
||||||
|
qhandler.addFilter(ContextInjection())
|
||||||
|
return qhandler
|
||||||
|
|
||||||
|
|
||||||
|
def setup_main_logger(multiprocess=False):
|
||||||
|
q = multiprocessing.Queue() if multiprocess else queue.SimpleQueue()
|
||||||
|
if handlers:
|
||||||
|
# First create a separate loop to run the handlers on
|
||||||
|
import threading
|
||||||
|
|
||||||
|
def run_loop(loop):
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
try:
|
||||||
|
loop.run_forever()
|
||||||
|
finally:
|
||||||
|
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
loop_thread = threading.Thread(target=lambda: run_loop(loop))
|
||||||
|
loop_thread.daemon = True
|
||||||
|
loop_thread.start()
|
||||||
|
|
||||||
|
for handler in handlers:
|
||||||
|
handler.loop = loop
|
||||||
|
|
||||||
|
qhandler = make_queue_handler(q)
|
||||||
|
# qhandler.addFilter(ThreadFilter('MainThread'))
|
||||||
|
logger.addHandler(qhandler)
|
||||||
|
|
||||||
|
listener = QueueListener(
|
||||||
|
q, *handlers, respect_handler_level=True
|
||||||
|
)
|
||||||
|
listener.start()
|
||||||
|
return q
|
||||||
139
src/meta/monitor.py
Normal file
139
src/meta/monitor.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
from enum import IntEnum
|
||||||
|
from collections import deque, ChainMap
|
||||||
|
import datetime as dt
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class StatusLevel(IntEnum):
|
||||||
|
ERRORED = -2
|
||||||
|
UNSURE = -1
|
||||||
|
WAITING = 0
|
||||||
|
STARTING = 1
|
||||||
|
OKAY = 2
|
||||||
|
|
||||||
|
@property
|
||||||
|
def symbol(self):
|
||||||
|
return symbols[self]
|
||||||
|
|
||||||
|
|
||||||
|
symbols = {
|
||||||
|
StatusLevel.ERRORED: '🟥',
|
||||||
|
StatusLevel.UNSURE: '🟧',
|
||||||
|
StatusLevel.WAITING: '⬜',
|
||||||
|
StatusLevel.STARTING: '🟫',
|
||||||
|
StatusLevel.OKAY: '🟩',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ComponentStatus:
|
||||||
|
def __init__(self, level: StatusLevel, short_formatstr: str, long_formatstr: str, data: dict = {}):
|
||||||
|
self.level = level
|
||||||
|
self.short_formatstr = short_formatstr
|
||||||
|
self.long_formatstr = long_formatstr
|
||||||
|
self.data = data
|
||||||
|
self.created_at = dt.datetime.now(tz=dt.timezone.utc)
|
||||||
|
|
||||||
|
def format_args(self):
|
||||||
|
extra = {
|
||||||
|
'created_at': self.created_at,
|
||||||
|
'level': self.level,
|
||||||
|
'symbol': self.level.symbol,
|
||||||
|
}
|
||||||
|
return ChainMap(extra, self.data)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def short(self):
|
||||||
|
return self.short_formatstr.format(**self.format_args())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def long(self):
|
||||||
|
return self.long_formatstr.format(**self.format_args())
|
||||||
|
|
||||||
|
|
||||||
|
class ComponentMonitor:
|
||||||
|
_name = None
|
||||||
|
|
||||||
|
def __init__(self, name=None, callback=None):
|
||||||
|
self._callback = callback
|
||||||
|
self.name = name or self._name
|
||||||
|
if not self.name:
|
||||||
|
raise ValueError("ComponentMonitor must have a name")
|
||||||
|
|
||||||
|
async def _make_status(self, *args, **kwargs):
|
||||||
|
if self._callback is not None:
|
||||||
|
return await self._callback(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def status(self) -> ComponentStatus:
|
||||||
|
try:
|
||||||
|
status = await self._make_status()
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(
|
||||||
|
f"Status callback for component '{self.name}' failed. This should not happen."
|
||||||
|
)
|
||||||
|
status = ComponentStatus(
|
||||||
|
level=StatusLevel.UNSURE,
|
||||||
|
short_formatstr="Status callback for '{name}' failed with error '{error}'",
|
||||||
|
long_formatstr="Status callback for '{name}' failed with error '{error}'",
|
||||||
|
data={
|
||||||
|
'name': self.name,
|
||||||
|
'error': repr(e)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return status
|
||||||
|
|
||||||
|
|
||||||
|
class SystemMonitor:
|
||||||
|
def __init__(self):
|
||||||
|
self.components = {}
|
||||||
|
self.recent = deque(maxlen=10)
|
||||||
|
|
||||||
|
def add_component(self, component: ComponentMonitor):
|
||||||
|
self.components[component.name] = component
|
||||||
|
return component
|
||||||
|
|
||||||
|
async def request(self):
|
||||||
|
"""
|
||||||
|
Request status from each component.
|
||||||
|
"""
|
||||||
|
tasks = {
|
||||||
|
name: asyncio.create_task(comp.status())
|
||||||
|
for name, comp in self.components.items()
|
||||||
|
}
|
||||||
|
await asyncio.gather(*tasks.values())
|
||||||
|
status = {
|
||||||
|
name: await fut for name, fut in tasks.items()
|
||||||
|
}
|
||||||
|
self.recent.append(status)
|
||||||
|
return status
|
||||||
|
|
||||||
|
async def _format_summary(self, status_dict: dict[str, ComponentStatus]):
|
||||||
|
"""
|
||||||
|
Format a one line summary from a status dict.
|
||||||
|
"""
|
||||||
|
freq = {level: 0 for level in StatusLevel}
|
||||||
|
for status in status_dict.values():
|
||||||
|
freq[status.level] += 1
|
||||||
|
|
||||||
|
summary = '\t'.join(f"{level.symbol} {count}" for level, count in freq.items() if count)
|
||||||
|
return summary
|
||||||
|
|
||||||
|
async def _format_overview(self, status_dict: dict[str, ComponentStatus]):
|
||||||
|
"""
|
||||||
|
Format an overview (one line per component) from a status dict.
|
||||||
|
"""
|
||||||
|
lines = []
|
||||||
|
for name, status in status_dict.items():
|
||||||
|
lines.append(f"{status.level.symbol} {name}: {status.short}")
|
||||||
|
summary = await self._format_summary(status_dict)
|
||||||
|
return '\n'.join((summary, *lines))
|
||||||
|
|
||||||
|
async def get_summary(self):
|
||||||
|
return await self._format_summary(await self.request())
|
||||||
|
|
||||||
|
async def get_overview(self):
|
||||||
|
return await self._format_overview(await self.request())
|
||||||
35
src/meta/sharding.py
Normal file
35
src/meta/sharding.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
from .args import args
|
||||||
|
from .config import conf
|
||||||
|
|
||||||
|
from psycopg import sql
|
||||||
|
from data.conditions import Condition, Joiner
|
||||||
|
|
||||||
|
|
||||||
|
shard_number = args.shard or 0
|
||||||
|
|
||||||
|
shard_count = conf.bot.getint('shard_count', 1)
|
||||||
|
|
||||||
|
sharded = (shard_count > 0)
|
||||||
|
|
||||||
|
|
||||||
|
def SHARDID(shard_id: int, guild_column: str = 'guildid', shard_count: int = shard_count) -> Condition:
|
||||||
|
"""
|
||||||
|
Condition constructor for filtering by shard id.
|
||||||
|
|
||||||
|
Example Usage
|
||||||
|
-------------
|
||||||
|
Query.where(_shard_condition('guildid', 10, 1))
|
||||||
|
"""
|
||||||
|
return Condition(
|
||||||
|
sql.SQL("({guildid} >> 22) %% {shard_count}").format(
|
||||||
|
guildid=sql.Identifier(guild_column),
|
||||||
|
shard_count=sql.Literal(shard_count)
|
||||||
|
),
|
||||||
|
Joiner.EQUALS,
|
||||||
|
sql.Placeholder(),
|
||||||
|
(shard_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Pre-built Condition for filtering by current shard.
|
||||||
|
THIS_SHARD = SHARDID(shard_number)
|
||||||
10
src/modules/__init__.py
Normal file
10
src/modules/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
this_package = 'modules'
|
||||||
|
|
||||||
|
active = [
|
||||||
|
'.sysadmin',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def setup(bot):
|
||||||
|
for ext in active:
|
||||||
|
await bot.load_extension(ext, package=this_package)
|
||||||
5
src/modules/sysadmin/__init__.py
Normal file
5
src/modules/sysadmin/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
|
||||||
|
async def setup(bot):
|
||||||
|
from .exec_cog import Exec
|
||||||
|
|
||||||
|
await bot.add_cog(Exec(bot))
|
||||||
336
src/modules/sysadmin/exec_cog.py
Normal file
336
src/modules/sysadmin/exec_cog.py
Normal file
@@ -0,0 +1,336 @@
|
|||||||
|
import io
|
||||||
|
import ast
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
import asyncio
|
||||||
|
import traceback
|
||||||
|
import builtins
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
from io import StringIO
|
||||||
|
|
||||||
|
from typing import Callable, Any, Optional
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands
|
||||||
|
from discord.ext.commands.errors import CheckFailure
|
||||||
|
from discord.ui import TextInput, View
|
||||||
|
from discord.ui.button import button
|
||||||
|
import discord.app_commands as appcmd
|
||||||
|
|
||||||
|
from meta.logger import logging_context, log_wrap
|
||||||
|
from meta import conf
|
||||||
|
from meta.context import context, ctx_bot
|
||||||
|
from meta.LionContext import LionContext
|
||||||
|
from meta.LionCog import LionCog
|
||||||
|
from meta.LionBot import LionBot
|
||||||
|
|
||||||
|
from utils.ui import FastModal, input
|
||||||
|
|
||||||
|
from wards import sys_admin
|
||||||
|
|
||||||
|
|
||||||
|
def _(arg): return arg
|
||||||
|
|
||||||
|
def _p(ctx, arg): return arg
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class ExecModal(FastModal, title="Execute"):
|
||||||
|
code: TextInput = TextInput(
|
||||||
|
label="Code to execute",
|
||||||
|
style=discord.TextStyle.long,
|
||||||
|
required=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExecStyle(Enum):
|
||||||
|
EXEC = 'exec'
|
||||||
|
EVAL = 'eval'
|
||||||
|
|
||||||
|
|
||||||
|
class ExecUI(View):
|
||||||
|
def __init__(self, ctx, code=None, style=ExecStyle.EXEC, ephemeral=True) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.ctx: LionContext = ctx
|
||||||
|
self.interaction: Optional[discord.Interaction] = ctx.interaction
|
||||||
|
self.code: Optional[str] = code
|
||||||
|
self.style: ExecStyle = style
|
||||||
|
self.ephemeral: bool = ephemeral
|
||||||
|
|
||||||
|
self._modal: Optional[ExecModal] = None
|
||||||
|
self._msg: Optional[discord.Message] = None
|
||||||
|
|
||||||
|
async def interaction_check(self, interaction: discord.Interaction):
|
||||||
|
"""Only allow the original author to use this View"""
|
||||||
|
if interaction.user.id != self.ctx.author.id:
|
||||||
|
await interaction.response.send_message(
|
||||||
|
("You cannot use this interface!"),
|
||||||
|
ephemeral=True
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
if self.code is None:
|
||||||
|
if (interaction := self.interaction) is not None:
|
||||||
|
self.interaction = None
|
||||||
|
await interaction.response.send_modal(self.get_modal())
|
||||||
|
await self.wait()
|
||||||
|
else:
|
||||||
|
# Complain
|
||||||
|
# TODO: error_reply
|
||||||
|
await self.ctx.reply("Pls give code.")
|
||||||
|
else:
|
||||||
|
await self.interaction.response.defer(thinking=True, ephemeral=self.ephemeral)
|
||||||
|
await self.compile()
|
||||||
|
await self.wait()
|
||||||
|
|
||||||
|
@button(label="Recompile")
|
||||||
|
async def recompile_button(self, interaction, butt):
|
||||||
|
# Interaction response with modal
|
||||||
|
await interaction.response.send_modal(self.get_modal())
|
||||||
|
|
||||||
|
@button(label="Show Source")
|
||||||
|
async def source_button(self, interaction, butt):
|
||||||
|
if len(self.code) > 1900:
|
||||||
|
# Send as file
|
||||||
|
with StringIO(self.code) as fp:
|
||||||
|
fp.seek(0)
|
||||||
|
file = discord.File(fp, filename="source.py")
|
||||||
|
await interaction.response.send_message(file=file, ephemeral=True)
|
||||||
|
else:
|
||||||
|
# Send as message
|
||||||
|
await interaction.response.send_message(
|
||||||
|
content=f"```py\n{self.code}```",
|
||||||
|
ephemeral=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_modal(self) -> ExecModal:
|
||||||
|
modal = ExecModal()
|
||||||
|
|
||||||
|
@modal.submit_callback()
|
||||||
|
async def exec_submit(interaction: discord.Interaction):
|
||||||
|
if self.interaction is None:
|
||||||
|
self.interaction = interaction
|
||||||
|
await interaction.response.defer(thinking=True)
|
||||||
|
else:
|
||||||
|
await interaction.response.defer()
|
||||||
|
|
||||||
|
# Set code
|
||||||
|
self.code = modal.code.value
|
||||||
|
|
||||||
|
# Call compile
|
||||||
|
await self.compile()
|
||||||
|
|
||||||
|
return modal
|
||||||
|
|
||||||
|
def get_modal(self):
|
||||||
|
self._modal = self.create_modal()
|
||||||
|
self._modal.code.default = self.code
|
||||||
|
return self._modal
|
||||||
|
|
||||||
|
async def compile(self):
|
||||||
|
# Call _async
|
||||||
|
result = await _async(self.code, style=self.style.value)
|
||||||
|
|
||||||
|
# Display output
|
||||||
|
await self.show_output(result)
|
||||||
|
|
||||||
|
async def show_output(self, output):
|
||||||
|
# Format output
|
||||||
|
# If output message exists and not ephemeral, edit
|
||||||
|
# Otherwise, send message, add buttons
|
||||||
|
if len(output) > 1900:
|
||||||
|
# Send as file
|
||||||
|
with StringIO(output) as fp:
|
||||||
|
fp.seek(0)
|
||||||
|
args = {
|
||||||
|
'content': None,
|
||||||
|
'attachments': [discord.File(fp, filename="output.md")]
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
args = {
|
||||||
|
'content': f"```md\n{output}```",
|
||||||
|
'attachments': []
|
||||||
|
}
|
||||||
|
|
||||||
|
if self._msg is None:
|
||||||
|
if self.interaction is not None:
|
||||||
|
msg = await self.interaction.edit_original_response(**args, view=self)
|
||||||
|
else:
|
||||||
|
# Send new message
|
||||||
|
if args['content'] is None:
|
||||||
|
args['file'] = args.pop('attachments')[0]
|
||||||
|
msg = await self.ctx.reply(**args, ephemeral=self.ephemeral, view=self)
|
||||||
|
|
||||||
|
if not self.ephemeral:
|
||||||
|
self._msg = msg
|
||||||
|
else:
|
||||||
|
if self.interaction is not None:
|
||||||
|
await self.interaction.edit_original_response(**args, view=self)
|
||||||
|
else:
|
||||||
|
# Edit message
|
||||||
|
await self._msg.edit(**args)
|
||||||
|
|
||||||
|
|
||||||
|
def mk_print(fp: io.StringIO) -> Callable[..., None]:
|
||||||
|
def _print(*args, file: Any = fp, **kwargs):
|
||||||
|
return print(*args, file=file, **kwargs)
|
||||||
|
return _print
|
||||||
|
|
||||||
|
|
||||||
|
def mk_status_printer(bot, printer):
|
||||||
|
async def _status(details=False):
|
||||||
|
if details:
|
||||||
|
status = await bot.system_monitor.get_overview()
|
||||||
|
else:
|
||||||
|
status = await bot.system_monitor.get_summary()
|
||||||
|
printer(status)
|
||||||
|
return status
|
||||||
|
return _status
|
||||||
|
|
||||||
|
|
||||||
|
@log_wrap(action="Code Exec")
|
||||||
|
async def _async(to_eval: str, style='exec'):
|
||||||
|
newline = '\n' * ('\n' in to_eval)
|
||||||
|
logger.info(
|
||||||
|
f"Exec code with {style}: {newline}{to_eval}"
|
||||||
|
)
|
||||||
|
|
||||||
|
output = io.StringIO()
|
||||||
|
_print = mk_print(output)
|
||||||
|
|
||||||
|
scope: dict[str, Any] = dict(sys.modules)
|
||||||
|
scope['__builtins__'] = builtins
|
||||||
|
scope.update(builtins.__dict__)
|
||||||
|
scope['ctx'] = ctx = context.get()
|
||||||
|
scope['bot'] = ctx_bot.get()
|
||||||
|
scope['print'] = _print # type: ignore
|
||||||
|
scope['print_status'] = mk_status_printer(scope['bot'], _print)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if ctx and ctx.message:
|
||||||
|
source_str = f"<msg: {ctx.message.id}>"
|
||||||
|
elif ctx and ctx.interaction:
|
||||||
|
source_str = f"<iid: {ctx.interaction.id}>"
|
||||||
|
else:
|
||||||
|
source_str = "Unknown async"
|
||||||
|
|
||||||
|
code = compile(
|
||||||
|
to_eval,
|
||||||
|
source_str,
|
||||||
|
style,
|
||||||
|
ast.PyCF_ALLOW_TOP_LEVEL_AWAIT
|
||||||
|
)
|
||||||
|
func = types.FunctionType(code, scope)
|
||||||
|
|
||||||
|
ret = func()
|
||||||
|
if inspect.iscoroutine(ret):
|
||||||
|
ret = await ret
|
||||||
|
if ret is not None:
|
||||||
|
_print(repr(ret))
|
||||||
|
except Exception:
|
||||||
|
_, exc, tb = sys.exc_info()
|
||||||
|
_print("".join(traceback.format_tb(tb)))
|
||||||
|
_print(f"{type(exc).__name__}: {exc}")
|
||||||
|
|
||||||
|
result = output.getvalue().strip()
|
||||||
|
newline = '\n' * ('\n' in result)
|
||||||
|
logger.info(
|
||||||
|
f"Exec complete, output: {newline}{result}"
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class Exec(LionCog):
|
||||||
|
guild_ids = conf.bot.getintlist('admin_guilds')
|
||||||
|
|
||||||
|
def __init__(self, bot: LionBot):
|
||||||
|
self.bot = bot
|
||||||
|
|
||||||
|
async def cog_check(self, ctx: LionContext) -> bool: # type: ignore
|
||||||
|
passed = await sys_admin(ctx.bot, ctx.author.id)
|
||||||
|
if passed:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
raise CheckFailure(
|
||||||
|
"You must be a bot owner to do this!"
|
||||||
|
)
|
||||||
|
|
||||||
|
@commands.hybrid_command(
|
||||||
|
name=_('async'),
|
||||||
|
description=_("Execute arbitrary code with Exec")
|
||||||
|
)
|
||||||
|
@appcmd.describe(
|
||||||
|
string="Code to execute.",
|
||||||
|
)
|
||||||
|
async def async_cmd(self, ctx: LionContext,
|
||||||
|
string: Optional[str] = None,
|
||||||
|
):
|
||||||
|
await ExecUI(ctx, string, ExecStyle.EXEC, ephemeral=False).run()
|
||||||
|
|
||||||
|
@commands.hybrid_command(
|
||||||
|
name=_('reload'),
|
||||||
|
description=_("Reload a given LionBot extension. Launches an ExecUI.")
|
||||||
|
)
|
||||||
|
@appcmd.describe(
|
||||||
|
extension=_("Name of the extension to reload. See autocomplete for options."),
|
||||||
|
force=_("Whether to force an extension reload even if it doesn't exist.")
|
||||||
|
)
|
||||||
|
@appcmd.guilds(*guild_ids)
|
||||||
|
async def reload_cmd(self, ctx: LionContext, extension: str, force: Optional[bool] = False):
|
||||||
|
"""
|
||||||
|
This is essentially just a friendly wrapper to reload an extension.
|
||||||
|
It is equivalent to running "await bot.reload_extension(extension)" in eval,
|
||||||
|
with a slightly nicer interface through the autocomplete and error handling.
|
||||||
|
"""
|
||||||
|
exists = (extension in self.bot.extensions)
|
||||||
|
if not (force or exists):
|
||||||
|
embed = discord.Embed(description=f"Unknown extension {extension}", colour=discord.Colour.red())
|
||||||
|
await ctx.reply(embed=embed)
|
||||||
|
else:
|
||||||
|
# Uses an ExecUI to simplify error handling and re-execution
|
||||||
|
if exists:
|
||||||
|
string = f"await bot.reload_extension('{extension}')"
|
||||||
|
else:
|
||||||
|
string = f"await bot.load_extension('{extension}')"
|
||||||
|
await ExecUI(ctx, string, ExecStyle.EVAL).run()
|
||||||
|
|
||||||
|
@reload_cmd.autocomplete('extension')
|
||||||
|
async def reload_extension_acmpl(self, interaction: discord.Interaction, partial: str):
|
||||||
|
keys = set(self.bot.extensions.keys())
|
||||||
|
results = [
|
||||||
|
appcmd.Choice(name=key, value=key)
|
||||||
|
for key in keys
|
||||||
|
if partial.lower() in key.lower()
|
||||||
|
]
|
||||||
|
if not results:
|
||||||
|
results = [
|
||||||
|
appcmd.Choice(name=f"No extensions found matching {partial}", value="None")
|
||||||
|
]
|
||||||
|
return results[:25]
|
||||||
|
|
||||||
|
@commands.hybrid_command(
|
||||||
|
name=_('shutdown'),
|
||||||
|
description=_("Shutdown (or restart) the client.")
|
||||||
|
)
|
||||||
|
@appcmd.guilds(*guild_ids)
|
||||||
|
async def shutdown_cmd(self, ctx: LionContext):
|
||||||
|
"""
|
||||||
|
Shutdown the client.
|
||||||
|
Maybe do something friendly here?
|
||||||
|
"""
|
||||||
|
logger.info("Shutting down on admin request.")
|
||||||
|
await ctx.reply(
|
||||||
|
embed=discord.Embed(
|
||||||
|
description=f"Understood {ctx.author.mention}, cleaning up and shutting down!",
|
||||||
|
colour=discord.Colour.orange()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await self.bot.close()
|
||||||
7
src/settings/__init__.py
Normal file
7
src/settings/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from babel.translator import LocalBabel
|
||||||
|
babel = LocalBabel('settings_base')
|
||||||
|
|
||||||
|
from .data import ModelData, ListData
|
||||||
|
from .base import BaseSetting
|
||||||
|
from .ui import SettingWidget, InteractiveSetting
|
||||||
|
from .groups import SettingDotDict, SettingGroup, ModelSettings, ModelSetting
|
||||||
166
src/settings/base.py
Normal file
166
src/settings/base.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
from typing import Generic, TypeVar, Type, Optional, overload
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Setting metclass?
|
||||||
|
Parse setting docstring to generate default info?
|
||||||
|
Or just put it in the decorator we are already using
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# Typing using Generic[parent_id_type, data_type, value_type]
|
||||||
|
# value generic, could be Union[?, UNSET]
|
||||||
|
ParentID = TypeVar('ParentID')
|
||||||
|
SettingData = TypeVar('SettingData')
|
||||||
|
SettingValue = TypeVar('SettingValue')
|
||||||
|
|
||||||
|
T = TypeVar('T', bound='BaseSetting')
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSetting(Generic[ParentID, SettingData, SettingValue]):
|
||||||
|
"""
|
||||||
|
Abstract base class describing a stored configuration setting.
|
||||||
|
A setting consists of logic to load the setting from storage,
|
||||||
|
present it in a readable form, understand user entered values,
|
||||||
|
and write it again in storage.
|
||||||
|
Additionally, the setting has attributes attached describing
|
||||||
|
the setting in a user-friendly manner for display purposes.
|
||||||
|
"""
|
||||||
|
setting_id: str # Unique source identifier for the setting
|
||||||
|
|
||||||
|
_default: Optional[SettingData] = None # Default data value for the setting
|
||||||
|
|
||||||
|
def __init__(self, parent_id: ParentID, data: Optional[SettingData], **kwargs):
|
||||||
|
self.parent_id = parent_id
|
||||||
|
self._data = data
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
# Instance generation
|
||||||
|
@classmethod
|
||||||
|
async def get(cls: Type[T], parent_id: ParentID, **kwargs) -> T:
|
||||||
|
"""
|
||||||
|
Return a setting instance initialised from the stored value, associated with the given parent id.
|
||||||
|
"""
|
||||||
|
data = await cls._reader(parent_id, **kwargs)
|
||||||
|
return cls(parent_id, data, **kwargs)
|
||||||
|
|
||||||
|
# Main interface
|
||||||
|
@property
|
||||||
|
def data(self) -> Optional[SettingData]:
|
||||||
|
"""
|
||||||
|
Retrieves the current internal setting data if it is set, otherwise the default data
|
||||||
|
"""
|
||||||
|
return self._data if self._data is not None else self.default
|
||||||
|
|
||||||
|
@data.setter
|
||||||
|
def data(self, new_data: Optional[SettingData]):
|
||||||
|
"""
|
||||||
|
Sets the internal raw data.
|
||||||
|
Does not write the changes.
|
||||||
|
"""
|
||||||
|
self._data = new_data
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default(self) -> Optional[SettingData]:
|
||||||
|
"""
|
||||||
|
Retrieves the default value for this setting.
|
||||||
|
Settings should override this if the default depends on the object id.
|
||||||
|
"""
|
||||||
|
return self._default
|
||||||
|
|
||||||
|
@property
|
||||||
|
def value(self) -> SettingValue: # Actually optional *if* _default is None
|
||||||
|
"""
|
||||||
|
Context-aware object or objects associated with the setting.
|
||||||
|
"""
|
||||||
|
return self._data_to_value(self.parent_id, self.data) # type: ignore
|
||||||
|
|
||||||
|
@value.setter
|
||||||
|
def value(self, new_value: Optional[SettingValue]):
|
||||||
|
"""
|
||||||
|
Setter which reads the discord-aware object and converts it to data.
|
||||||
|
Does not write the new value.
|
||||||
|
"""
|
||||||
|
self._data = self._data_from_value(self.parent_id, new_value)
|
||||||
|
|
||||||
|
async def write(self, **kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Write current data to the database.
|
||||||
|
For settings which override this,
|
||||||
|
ensure you handle deletion of values when internal data is None.
|
||||||
|
"""
|
||||||
|
await self._writer(self.parent_id, self._data, **kwargs)
|
||||||
|
|
||||||
|
# Raw converters
|
||||||
|
@overload
|
||||||
|
@classmethod
|
||||||
|
def _data_from_value(cls: Type[T], parent_id: ParentID, value: SettingValue, **kwargs) -> SettingData:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@classmethod
|
||||||
|
def _data_from_value(cls: Type[T], parent_id: ParentID, value: None, **kwargs) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _data_from_value(
|
||||||
|
cls: Type[T], parent_id: ParentID, value: Optional[SettingValue], **kwargs
|
||||||
|
) -> Optional[SettingData]:
|
||||||
|
"""
|
||||||
|
Convert a high-level setting value to internal data.
|
||||||
|
Must be overridden by the setting.
|
||||||
|
Be aware of UNSET values, these should always pass through as None
|
||||||
|
to provide an unsetting interface.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@classmethod
|
||||||
|
def _data_to_value(cls: Type[T], parent_id: ParentID, data: SettingData, **kwargs) -> SettingValue:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@classmethod
|
||||||
|
def _data_to_value(cls: Type[T], parent_id: ParentID, data: None, **kwargs) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _data_to_value(
|
||||||
|
cls: Type[T], parent_id: ParentID, data: Optional[SettingData], **kwargs
|
||||||
|
) -> Optional[SettingValue]:
|
||||||
|
"""
|
||||||
|
Convert internal data to high-level setting value.
|
||||||
|
Must be overriden by the setting.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
# Database access
|
||||||
|
@classmethod
|
||||||
|
async def _reader(cls: Type[T], parent_id: ParentID, **kwargs) -> Optional[SettingData]:
|
||||||
|
"""
|
||||||
|
Retrieve the setting data associated with the given parent_id.
|
||||||
|
May be None if the setting is not set.
|
||||||
|
Must be overridden by the setting.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _writer(cls: Type[T], parent_id: ParentID, data: Optional[SettingData], **kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Write provided setting data to storage.
|
||||||
|
Must be overridden by the setting unless the `write` method is overridden.
|
||||||
|
If the data is None, the setting is UNSET and should be deleted.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def setup(cls, bot):
|
||||||
|
"""
|
||||||
|
Initialisation task to be executed during client initialisation.
|
||||||
|
May be used for e.g. populating a cache or required client setup.
|
||||||
|
|
||||||
|
Main application must execute the initialisation task before the setting is used.
|
||||||
|
Further, the task must always be executable, if the setting is loaded.
|
||||||
|
Conditional initialisation should go in the relevant module's init tasks.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
233
src/settings/data.py
Normal file
233
src/settings/data.py
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
from typing import Type
|
||||||
|
import json
|
||||||
|
|
||||||
|
from data import RowModel, Table, ORDER
|
||||||
|
from meta.logger import log_wrap, set_logging_context
|
||||||
|
|
||||||
|
|
||||||
|
class ModelData:
|
||||||
|
"""
|
||||||
|
Mixin for settings stored in a single row and column of a Model.
|
||||||
|
Assumes that the parent_id is the identity key of the Model.
|
||||||
|
|
||||||
|
This does not create a reference to the Row.
|
||||||
|
"""
|
||||||
|
# Table storing the desired data
|
||||||
|
_model: Type[RowModel]
|
||||||
|
|
||||||
|
# Column with the desired data
|
||||||
|
_column: str
|
||||||
|
|
||||||
|
# Whether to create a row if not found
|
||||||
|
_create_row = False
|
||||||
|
|
||||||
|
# High level data cache to use, leave as None to disable cache.
|
||||||
|
_cache = None # Map[id -> value]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _read_from_row(cls, parent_id, row, **kwargs):
|
||||||
|
data = row[cls._column]
|
||||||
|
|
||||||
|
if cls._cache is not None:
|
||||||
|
cls._cache[parent_id] = data
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _reader(cls, parent_id, use_cache=True, **kwargs):
|
||||||
|
"""
|
||||||
|
Read in the requested column associated to the parent id.
|
||||||
|
"""
|
||||||
|
if cls._cache is not None and parent_id in cls._cache and use_cache:
|
||||||
|
return cls._cache[parent_id]
|
||||||
|
|
||||||
|
model = cls._model
|
||||||
|
if cls._create_row:
|
||||||
|
row = await model.fetch_or_create(parent_id)
|
||||||
|
else:
|
||||||
|
row = await model.fetch(parent_id)
|
||||||
|
data = row[cls._column] if row else None
|
||||||
|
|
||||||
|
if cls._cache is not None:
|
||||||
|
cls._cache[parent_id] = data
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _writer(cls, parent_id, data, **kwargs):
|
||||||
|
"""
|
||||||
|
Write the provided entry to the table.
|
||||||
|
This does *not* create the row if it does not exist.
|
||||||
|
It only updates.
|
||||||
|
"""
|
||||||
|
# TODO: Better way of getting the key?
|
||||||
|
# TODO: Transaction
|
||||||
|
if not isinstance(parent_id, tuple):
|
||||||
|
parent_id = (parent_id, )
|
||||||
|
model = cls._model
|
||||||
|
rows = await model.table.update_where(
|
||||||
|
**model._dict_from_id(parent_id)
|
||||||
|
).set(
|
||||||
|
**{cls._column: data}
|
||||||
|
)
|
||||||
|
# If we didn't update any rows, create a new row
|
||||||
|
if not rows:
|
||||||
|
await model.fetch_or_create(**model._dict_from_id(parent_id), **{cls._column: data})
|
||||||
|
|
||||||
|
if cls._cache is not None:
|
||||||
|
cls._cache[parent_id] = data
|
||||||
|
|
||||||
|
|
||||||
|
class ListData:
|
||||||
|
"""
|
||||||
|
Mixin for list types implemented on a Table.
|
||||||
|
Implements a reader and writer.
|
||||||
|
This assumes the list is the only data stored in the table,
|
||||||
|
and removes list entries by deleting rows.
|
||||||
|
"""
|
||||||
|
setting_id: str
|
||||||
|
|
||||||
|
# Table storing the setting data
|
||||||
|
_table_interface: Table
|
||||||
|
|
||||||
|
# Name of the column storing the id
|
||||||
|
_id_column: str
|
||||||
|
|
||||||
|
# Name of the column storing the data to read
|
||||||
|
_data_column: str
|
||||||
|
|
||||||
|
# Name of column storing the order index to use, if any. Assumed to be Serial on writing.
|
||||||
|
_order_column: str
|
||||||
|
_order_type: ORDER = ORDER.ASC
|
||||||
|
|
||||||
|
# High level data cache to use, set to None to disable cache.
|
||||||
|
_cache = None # Map[id -> value]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@log_wrap(isolate=True)
|
||||||
|
async def _reader(cls, parent_id, use_cache=True, **kwargs):
|
||||||
|
"""
|
||||||
|
Read in all entries associated to the given id.
|
||||||
|
"""
|
||||||
|
set_logging_context(action="Read cls.setting_id")
|
||||||
|
if cls._cache is not None and parent_id in cls._cache and use_cache:
|
||||||
|
return cls._cache[parent_id]
|
||||||
|
|
||||||
|
table = cls._table_interface # type: Table
|
||||||
|
query = table.select_where(**{cls._id_column: parent_id}).select(cls._data_column)
|
||||||
|
if cls._order_column:
|
||||||
|
query.order_by(cls._order_column, direction=cls._order_type)
|
||||||
|
|
||||||
|
rows = await query
|
||||||
|
data = [row[cls._data_column] for row in rows]
|
||||||
|
|
||||||
|
if cls._cache is not None:
|
||||||
|
cls._cache[parent_id] = data
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@log_wrap(isolate=True)
|
||||||
|
async def _writer(cls, id, data, add_only=False, remove_only=False, **kwargs):
|
||||||
|
"""
|
||||||
|
Write the provided list to storage.
|
||||||
|
"""
|
||||||
|
set_logging_context(action="Write cls.setting_id")
|
||||||
|
table = cls._table_interface
|
||||||
|
async with table.connector.connection() as conn:
|
||||||
|
table.connector.conn = conn
|
||||||
|
async with conn.transaction():
|
||||||
|
# Handle None input as an empty list
|
||||||
|
if data is None:
|
||||||
|
data = []
|
||||||
|
|
||||||
|
current = await cls._reader(id, use_cache=False, **kwargs)
|
||||||
|
if not cls._order_column and (add_only or remove_only):
|
||||||
|
to_insert = [item for item in data if item not in current] if not remove_only else []
|
||||||
|
to_remove = data if remove_only else (
|
||||||
|
[item for item in current if item not in data] if not add_only else []
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle required deletions
|
||||||
|
if to_remove:
|
||||||
|
params = {
|
||||||
|
cls._id_column: id,
|
||||||
|
cls._data_column: to_remove
|
||||||
|
}
|
||||||
|
await table.delete_where(**params)
|
||||||
|
|
||||||
|
# Handle required insertions
|
||||||
|
if to_insert:
|
||||||
|
columns = (cls._id_column, cls._data_column)
|
||||||
|
values = [(id, value) for value in to_insert]
|
||||||
|
await table.insert_many(columns, *values)
|
||||||
|
|
||||||
|
if cls._cache is not None:
|
||||||
|
new_current = [item for item in current + to_insert if item not in to_remove]
|
||||||
|
cls._cache[id] = new_current
|
||||||
|
else:
|
||||||
|
# Remove all and add all to preserve order
|
||||||
|
delete_params = {cls._id_column: id}
|
||||||
|
await table.delete_where(**delete_params)
|
||||||
|
|
||||||
|
if data:
|
||||||
|
columns = (cls._id_column, cls._data_column)
|
||||||
|
values = [(id, value) for value in data]
|
||||||
|
await table.insert_many(columns, *values)
|
||||||
|
|
||||||
|
if cls._cache is not None:
|
||||||
|
cls._cache[id] = data
|
||||||
|
|
||||||
|
|
||||||
|
class KeyValueData:
|
||||||
|
"""
|
||||||
|
Mixin for settings implemented in a Key-Value table.
|
||||||
|
The underlying table should have a Unique constraint on the `(_id_column, _key_column)` pair.
|
||||||
|
"""
|
||||||
|
_table_interface: Table
|
||||||
|
|
||||||
|
_id_column: str
|
||||||
|
|
||||||
|
_key_column: str
|
||||||
|
|
||||||
|
_value_column: str
|
||||||
|
|
||||||
|
_key: str
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _reader(cls, id, **kwargs):
|
||||||
|
params = {
|
||||||
|
cls._id_column: id,
|
||||||
|
cls._key_column: cls._key
|
||||||
|
}
|
||||||
|
|
||||||
|
row = await cls._table_interface.select_one_where(**params).select(cls._value_column)
|
||||||
|
data = row[cls._value_column] if row else None
|
||||||
|
|
||||||
|
if data is not None:
|
||||||
|
data = json.loads(data)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _writer(cls, id, data, **kwargs):
|
||||||
|
params = {
|
||||||
|
cls._id_column: id,
|
||||||
|
cls._key_column: cls._key
|
||||||
|
}
|
||||||
|
if data is not None:
|
||||||
|
values = {
|
||||||
|
cls._value_column: json.dumps(data)
|
||||||
|
}
|
||||||
|
rows = await cls._table_interface.update_where(**params).set(**values)
|
||||||
|
if not rows:
|
||||||
|
await cls._table_interface.insert_many(
|
||||||
|
(cls._id_column, cls._key_column, cls._value_column),
|
||||||
|
(id, cls._key, json.dumps(data))
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await cls._table_interface.delete_where(**params)
|
||||||
|
|
||||||
|
|
||||||
|
# class UserInputError(SafeCancellation):
|
||||||
|
# pass
|
||||||
204
src/settings/groups.py
Normal file
204
src/settings/groups.py
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
from typing import Generic, Type, TypeVar, Optional, overload
|
||||||
|
|
||||||
|
from data import RowModel
|
||||||
|
|
||||||
|
from .data import ModelData
|
||||||
|
from .ui import InteractiveSetting
|
||||||
|
from .base import BaseSetting
|
||||||
|
|
||||||
|
from utils.lib import tabulate
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar('T', bound=InteractiveSetting)
|
||||||
|
|
||||||
|
|
||||||
|
class SettingDotDict(Generic[T], dict[str, Type[T]]):
|
||||||
|
"""
|
||||||
|
Dictionary structure allowing simple dot access to items.
|
||||||
|
"""
|
||||||
|
__getattr__ = dict.__getitem__ # type: ignore
|
||||||
|
__setattr__ = dict.__setitem__ # type: ignore
|
||||||
|
__delattr__ = dict.__delitem__ # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class SettingGroup:
|
||||||
|
"""
|
||||||
|
A SettingGroup is a collection of settings under one name.
|
||||||
|
"""
|
||||||
|
__initial_settings__: list[Type[InteractiveSetting]] = []
|
||||||
|
|
||||||
|
_title: Optional[str] = None
|
||||||
|
_description: Optional[str] = None
|
||||||
|
|
||||||
|
def __init_subclass__(cls, title: Optional[str] = None):
|
||||||
|
cls._title = title or cls._title
|
||||||
|
cls._description = cls._description or cls.__doc__
|
||||||
|
|
||||||
|
settings: list[Type[InteractiveSetting]] = []
|
||||||
|
for item in cls.__dict__.values():
|
||||||
|
if isinstance(item, type) and issubclass(item, InteractiveSetting):
|
||||||
|
settings.append(item)
|
||||||
|
cls.__initial_settings__ = settings
|
||||||
|
|
||||||
|
def __init_settings__(self):
|
||||||
|
settings = SettingDotDict()
|
||||||
|
for setting in self.__initial_settings__:
|
||||||
|
settings[setting.__name__] = setting
|
||||||
|
return settings
|
||||||
|
|
||||||
|
def __init__(self, title=None, description=None) -> None:
|
||||||
|
self.title: str = title or self._title or self.__class__.__name__
|
||||||
|
self.description: str = description or self._description or ""
|
||||||
|
self.settings: SettingDotDict[InteractiveSetting] = self.__init_settings__()
|
||||||
|
|
||||||
|
def attach(self, cls: Type[T], name: Optional[str] = None):
|
||||||
|
name = name or cls.setting_id
|
||||||
|
self.settings[name] = cls
|
||||||
|
return cls
|
||||||
|
|
||||||
|
def detach(self, cls):
|
||||||
|
return self.settings.pop(cls.__name__, None)
|
||||||
|
|
||||||
|
def update(self, smap):
|
||||||
|
self.settings.update(smap.settings)
|
||||||
|
|
||||||
|
def reduce(self, *keys):
|
||||||
|
for key in keys:
|
||||||
|
self.settings.pop(key, None)
|
||||||
|
return
|
||||||
|
|
||||||
|
async def make_setting_table(self, parent_id, **kwargs):
|
||||||
|
"""
|
||||||
|
Convenience method for generating a rendered setting table.
|
||||||
|
"""
|
||||||
|
rows = []
|
||||||
|
for setting in self.settings.values():
|
||||||
|
if not setting._virtual:
|
||||||
|
set = await setting.get(parent_id, **kwargs)
|
||||||
|
name = set.display_name
|
||||||
|
value = str(set.formatted)
|
||||||
|
rows.append((name, value, set.hover_desc))
|
||||||
|
table_rows = tabulate(
|
||||||
|
*rows,
|
||||||
|
row_format="[`{invis}{key:<{pad}}{colon}`](https://lionbot.org \"{field[2]}\")\t{value}"
|
||||||
|
)
|
||||||
|
return '\n'.join(table_rows)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSetting(ModelData, BaseSetting):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class ModelConfig:
|
||||||
|
"""
|
||||||
|
A ModelConfig provides a central point of configuration for any object described by a single Model.
|
||||||
|
|
||||||
|
An instance of a ModelConfig represents configuration for a single object
|
||||||
|
(given by a single row of the corresponding Model).
|
||||||
|
|
||||||
|
The ModelConfig also supports registration of non-model configuration,
|
||||||
|
to support associated settings (e.g. list-settings) for the object.
|
||||||
|
|
||||||
|
This is an ABC, and must be subclassed for each object-type.
|
||||||
|
"""
|
||||||
|
settings: SettingDotDict
|
||||||
|
_model_settings: set
|
||||||
|
model: Type[RowModel]
|
||||||
|
|
||||||
|
def __init__(self, parent_id, row, **kwargs):
|
||||||
|
self.parent_id = parent_id
|
||||||
|
self.row = row
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_setting(cls, setting_cls):
|
||||||
|
"""
|
||||||
|
Decorator to register a non-model setting as part of the object configuration.
|
||||||
|
|
||||||
|
The setting class may be re-accessed through the `settings` class attr.
|
||||||
|
|
||||||
|
Subclasses may provide alternative access pathways to key non-model settings.
|
||||||
|
"""
|
||||||
|
cls.settings[setting_cls.setting_id] = setting_cls
|
||||||
|
return setting_cls
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_model_setting(cls, model_setting_cls):
|
||||||
|
"""
|
||||||
|
Decorator to register a model setting as part of the object configuration.
|
||||||
|
|
||||||
|
The setting class may be accessed through the `settings` class attr.
|
||||||
|
|
||||||
|
A fresh setting instance may also be retrieved (using cached data)
|
||||||
|
through the `get` instance method.
|
||||||
|
|
||||||
|
Subclasses are recommended to provide model settings as properties
|
||||||
|
for simplified access and type checking.
|
||||||
|
"""
|
||||||
|
cls._model_settings.add(model_setting_cls.setting_id)
|
||||||
|
return cls.register_setting(model_setting_cls)
|
||||||
|
|
||||||
|
def get(self, setting_id):
|
||||||
|
"""
|
||||||
|
Retrieve a freshly initialised copy of the given model-setting.
|
||||||
|
|
||||||
|
The given `setting_id` must have been previously registered through `register_model_setting`.
|
||||||
|
This uses cached data, and so is not guaranteed to be up-to-date.
|
||||||
|
"""
|
||||||
|
if setting_id not in self._model_settings:
|
||||||
|
# TODO: Log
|
||||||
|
raise ValueError
|
||||||
|
setting_cls = self.settings[setting_id]
|
||||||
|
data = setting_cls._read_from_row(self.parent_id, self.row, **self.kwargs)
|
||||||
|
return setting_cls(self.parent_id, data, **self.kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSettings:
|
||||||
|
"""
|
||||||
|
A ModelSettings instance aggregates multiple `ModelSetting` instances
|
||||||
|
bound to the same parent id on a single Model.
|
||||||
|
|
||||||
|
This enables a single point of access
|
||||||
|
for settings of a given Model,
|
||||||
|
with support for caching or deriving as needed.
|
||||||
|
|
||||||
|
This is an abstract base class,
|
||||||
|
and should be subclassed to define the contained settings.
|
||||||
|
"""
|
||||||
|
_settings: SettingDotDict = SettingDotDict()
|
||||||
|
model: Type[RowModel]
|
||||||
|
|
||||||
|
def __init__(self, parent_id, row, **kwargs):
|
||||||
|
self.parent_id = parent_id
|
||||||
|
self.row = row
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def fetch(cls, *parent_id, **kwargs):
|
||||||
|
"""
|
||||||
|
Load an instance of this ModelSetting with the given parent_id
|
||||||
|
and setting keyword arguments.
|
||||||
|
"""
|
||||||
|
row = await cls.model.fetch_or_create(*parent_id)
|
||||||
|
return cls(parent_id, row, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def attach(self, setting_cls):
|
||||||
|
"""
|
||||||
|
Decorator to attach the given setting class to this modelsetting.
|
||||||
|
"""
|
||||||
|
# This violates the interface principle, use structured typing instead?
|
||||||
|
if not (issubclass(setting_cls, BaseSetting) and issubclass(setting_cls, ModelData)):
|
||||||
|
raise ValueError(
|
||||||
|
f"The provided setting class must be `ModelSetting`, not {setting_cls.__class__.__name__}."
|
||||||
|
)
|
||||||
|
self._settings[setting_cls.setting_id] = setting_cls
|
||||||
|
return setting_cls
|
||||||
|
|
||||||
|
def get(self, setting_id):
|
||||||
|
setting_cls = self._settings.get(setting_id)
|
||||||
|
data = setting_cls._read_from_row(self.parent_id, self.row, **self.kwargs)
|
||||||
|
return setting_cls(self.parent_id, data, **self.kwargs)
|
||||||
|
|
||||||
|
def __getitem__(self, setting_id):
|
||||||
|
return self.get(setting_id)
|
||||||
13
src/settings/mock.py
Normal file
13
src/settings/mock.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
import discord
|
||||||
|
from discord import app_commands
|
||||||
|
|
||||||
|
|
||||||
|
class LocalString:
|
||||||
|
def __init__(self, string):
|
||||||
|
self.string = string
|
||||||
|
|
||||||
|
def as_string(self):
|
||||||
|
return self.string
|
||||||
|
|
||||||
|
|
||||||
|
_ = LocalString
|
||||||
1393
src/settings/setting_types.py
Normal file
1393
src/settings/setting_types.py
Normal file
File diff suppressed because it is too large
Load Diff
512
src/settings/ui.py
Normal file
512
src/settings/ui.py
Normal file
@@ -0,0 +1,512 @@
|
|||||||
|
from typing import Optional, Callable, Any, Dict, Coroutine, Generic, TypeVar, List
|
||||||
|
import asyncio
|
||||||
|
from contextvars import copy_context
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord import ui
|
||||||
|
from discord.ui.button import ButtonStyle, Button, button
|
||||||
|
from discord.ui.modal import Modal
|
||||||
|
from discord.ui.text_input import TextInput
|
||||||
|
from meta.errors import UserInputError
|
||||||
|
|
||||||
|
from utils.lib import tabulate, recover_context
|
||||||
|
from utils.ui import FastModal
|
||||||
|
from meta.config import conf
|
||||||
|
from meta.context import ctx_bot
|
||||||
|
from babel.translator import ctx_translator, LazyStr
|
||||||
|
|
||||||
|
from .base import BaseSetting, ParentID, SettingData, SettingValue
|
||||||
|
from . import babel
|
||||||
|
|
||||||
|
_p = babel._p
|
||||||
|
|
||||||
|
|
||||||
|
ST = TypeVar('ST', bound='InteractiveSetting')
|
||||||
|
|
||||||
|
|
||||||
|
class SettingModal(FastModal):
|
||||||
|
input_field: TextInput = TextInput(label="Edit Setting")
|
||||||
|
|
||||||
|
def update_field(self, new_field):
|
||||||
|
self.remove_item(self.input_field)
|
||||||
|
self.add_item(new_field)
|
||||||
|
self.input_field = new_field
|
||||||
|
|
||||||
|
|
||||||
|
class SettingWidget(Generic[ST], ui.View):
|
||||||
|
# TODO: Permission restrictions and callback!
|
||||||
|
# Context variables for permitted user(s)? Subclass ui.View with PermittedView?
|
||||||
|
# Don't need to descend permissions to Modal
|
||||||
|
# Maybe combine with timeout manager
|
||||||
|
|
||||||
|
def __init__(self, setting: ST, auto_write=True, **kwargs):
|
||||||
|
self.setting = setting
|
||||||
|
self.update_children()
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.auto_write = auto_write
|
||||||
|
|
||||||
|
self._interaction: Optional[discord.Interaction] = None
|
||||||
|
self._modal: Optional[SettingModal] = None
|
||||||
|
self._exports: List[ui.Item] = self.make_exports()
|
||||||
|
|
||||||
|
self._context = copy_context()
|
||||||
|
|
||||||
|
def update_children(self):
|
||||||
|
"""
|
||||||
|
Method called before base View initialisation.
|
||||||
|
Allows updating the children components (usually explicitly defined callbacks),
|
||||||
|
before Item instantiation.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def order_children(self, *children):
|
||||||
|
"""
|
||||||
|
Helper method to set and order the children using bound methods.
|
||||||
|
"""
|
||||||
|
child_map = {child.__name__: child for child in self.__view_children_items__}
|
||||||
|
self.__view_children_items__ = [child_map[child.__name__] for child in children]
|
||||||
|
|
||||||
|
def update_child(self, child, new_args):
|
||||||
|
args = getattr(child, '__discord_ui_model_kwargs__')
|
||||||
|
args |= new_args
|
||||||
|
|
||||||
|
def make_exports(self):
|
||||||
|
"""
|
||||||
|
Called post-instantiation to populate self._exports.
|
||||||
|
"""
|
||||||
|
return self.children
|
||||||
|
|
||||||
|
def refresh(self):
|
||||||
|
"""
|
||||||
|
Update widget components from current setting data, if applicable.
|
||||||
|
E.g. to update the default entry in a select list after a choice has been made,
|
||||||
|
or update button colours.
|
||||||
|
This does not trigger a discord ui update,
|
||||||
|
that is the responsibility of the interaction handler.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def show(self, interaction: discord.Interaction, key: Any = None, override=False, **kwargs):
|
||||||
|
"""
|
||||||
|
Complete standard setting widget UI flow for this setting.
|
||||||
|
The SettingWidget components may be attached to other messages as needed,
|
||||||
|
and they may be triggered individually,
|
||||||
|
but this coroutine defines the standard interface.
|
||||||
|
Intended for use by any interaction which wants to "open the setting".
|
||||||
|
|
||||||
|
Extra keyword arguments are passed directly to the interaction reply (for e.g. ephemeral).
|
||||||
|
"""
|
||||||
|
if key is None:
|
||||||
|
# By default, only have one widget listener per interaction.
|
||||||
|
key = ('widget', interaction.id)
|
||||||
|
|
||||||
|
# If there is already a widget listening on this key, respect override
|
||||||
|
if self.setting.get_listener(key) and not override:
|
||||||
|
# Refuse to spawn another widget
|
||||||
|
return
|
||||||
|
|
||||||
|
async def update_callback(new_data):
|
||||||
|
self.setting.data = new_data
|
||||||
|
await interaction.edit_original_response(embed=self.setting.embed, view=self, **kwargs)
|
||||||
|
|
||||||
|
self.setting.register_callback(key)(update_callback)
|
||||||
|
await interaction.response.send_message(embed=self.setting.embed, view=self, **kwargs)
|
||||||
|
await self.wait()
|
||||||
|
try:
|
||||||
|
# Try and detach the view, since we aren't handling events anymore.
|
||||||
|
await interaction.edit_original_response(view=None)
|
||||||
|
except discord.HTTPException:
|
||||||
|
pass
|
||||||
|
self.setting.deregister_callback(key)
|
||||||
|
|
||||||
|
def attach(self, group_view: ui.View):
|
||||||
|
"""
|
||||||
|
Attach this setting widget to a view representing several settings.
|
||||||
|
"""
|
||||||
|
for item in self._exports:
|
||||||
|
group_view.add_item(item)
|
||||||
|
|
||||||
|
@button(style=ButtonStyle.secondary, label="Edit", row=4)
|
||||||
|
async def edit_button(self, interaction: discord.Interaction, button: ui.Button):
|
||||||
|
"""
|
||||||
|
Spawn a simple edit modal,
|
||||||
|
populated with `setting.input_field`.
|
||||||
|
"""
|
||||||
|
recover_context(self._context)
|
||||||
|
# Spawn the setting modal
|
||||||
|
await interaction.response.send_modal(self.modal)
|
||||||
|
|
||||||
|
@button(style=ButtonStyle.danger, label="Reset", row=4)
|
||||||
|
async def reset_button(self, interaction: discord.Interaction, button: Button):
|
||||||
|
recover_context(self._context)
|
||||||
|
await interaction.response.defer(thinking=True, ephemeral=True)
|
||||||
|
await self.setting.interactive_set(None, interaction)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def modal(self) -> Modal:
|
||||||
|
"""
|
||||||
|
Build a Modal dialogue for updating the setting.
|
||||||
|
Refreshes (and re-attaches) the input field each time this is called.
|
||||||
|
"""
|
||||||
|
if self._modal is not None:
|
||||||
|
self._modal.update_field(self.setting.input_field)
|
||||||
|
return self._modal
|
||||||
|
|
||||||
|
# TODO: Attach shared timeouts to the modal
|
||||||
|
self._modal = modal = SettingModal(
|
||||||
|
title=f"Edit {self.setting.display_name}",
|
||||||
|
)
|
||||||
|
modal.update_field(self.setting.input_field)
|
||||||
|
|
||||||
|
@modal.submit_callback()
|
||||||
|
async def edit_submit(interaction: discord.Interaction):
|
||||||
|
# TODO: Catch and handle UserInputError
|
||||||
|
await interaction.response.defer(thinking=True, ephemeral=True)
|
||||||
|
data = await self.setting._parse_string(self.setting.parent_id, modal.input_field.value)
|
||||||
|
await self.setting.interactive_set(data, interaction)
|
||||||
|
|
||||||
|
return modal
|
||||||
|
|
||||||
|
|
||||||
|
class InteractiveSetting(BaseSetting[ParentID, SettingData, SettingValue]):
|
||||||
|
__slots__ = ('_widget',)
|
||||||
|
|
||||||
|
# Configuration interface descriptions
|
||||||
|
_display_name: LazyStr # User readable name of the setting
|
||||||
|
_desc: LazyStr # User readable brief description of the setting
|
||||||
|
_long_desc: LazyStr # User readable long description of the setting
|
||||||
|
_accepts: LazyStr # User readable description of the acceptable values
|
||||||
|
_set_cmd: str = None
|
||||||
|
_notset_str: LazyStr = _p('setting|formatted|notset', "Not Set")
|
||||||
|
_virtual: bool = False # Whether the setting should be hidden from tables and dashboards
|
||||||
|
_required: bool = False
|
||||||
|
|
||||||
|
Widget = SettingWidget
|
||||||
|
|
||||||
|
# A list of callback coroutines to call when the setting updates
|
||||||
|
# This can be used globally to refresh state when the setting updates,
|
||||||
|
# Or locallly to e.g. refresh an active widget.
|
||||||
|
# The callbacks are called on write, so they may be bypassed by direct use of _writer!
|
||||||
|
_listeners_: Dict[Any, Callable[[Optional[SettingData]], Coroutine[Any, Any, None]]] = {}
|
||||||
|
|
||||||
|
# Optional client event to dispatch when theis setting has been written
|
||||||
|
# Event handlers should be of the form Callable[ParentID, SettingData]
|
||||||
|
_event: Optional[str] = None
|
||||||
|
|
||||||
|
# Interaction ward that should be validated via interaction_check
|
||||||
|
_write_ward: Optional[Callable[[discord.Interaction], Coroutine[Any, Any, bool]]] = None
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
self._widget: Optional[SettingWidget] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def long_desc(self):
|
||||||
|
t = ctx_translator.get().t
|
||||||
|
bot = ctx_bot.get()
|
||||||
|
return t(self._long_desc).format(
|
||||||
|
bot=bot,
|
||||||
|
cmds=bot.core.mention_cache
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def display_name(self):
|
||||||
|
t = ctx_translator.get().t
|
||||||
|
return t(self._display_name)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def desc(self):
|
||||||
|
t = ctx_translator.get().t
|
||||||
|
return t(self._desc)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def accepts(self):
|
||||||
|
t = ctx_translator.get().t
|
||||||
|
return t(self._accepts)
|
||||||
|
|
||||||
|
async def write(self, **kwargs) -> None:
|
||||||
|
await super().write(**kwargs)
|
||||||
|
self.dispatch_update()
|
||||||
|
for listener in self._listeners_.values():
|
||||||
|
asyncio.create_task(listener(self.data))
|
||||||
|
|
||||||
|
def dispatch_update(self):
|
||||||
|
"""
|
||||||
|
Dispatch a client event along `self._event`, if set.
|
||||||
|
|
||||||
|
Override to modify the target event handler arguments.
|
||||||
|
By default, event handlers should be of the form:
|
||||||
|
Callable[[ParentID, SettingData], Coroutine[Any, Any, None]]
|
||||||
|
"""
|
||||||
|
if self._event is not None and (bot := ctx_bot.get()) is not None:
|
||||||
|
bot.dispatch(self._event, self.parent_id, self)
|
||||||
|
|
||||||
|
def get_listener(self, key):
|
||||||
|
return self._listeners_.get(key, None)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_callback(cls, name=None):
|
||||||
|
def wrapped(coro):
|
||||||
|
cls._listeners_[name or coro.__name__] = coro
|
||||||
|
return coro
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def deregister_callback(cls, name):
|
||||||
|
cls._listeners_.pop(name, None)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def update_message(self):
|
||||||
|
"""
|
||||||
|
Response message sent when the setting has successfully been updated.
|
||||||
|
Should generally be one line.
|
||||||
|
"""
|
||||||
|
if self.data is None:
|
||||||
|
return "Setting reset!"
|
||||||
|
else:
|
||||||
|
return f"Setting Updated! New value: {self.formatted}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hover_desc(self):
|
||||||
|
"""
|
||||||
|
This no longer works since Discord changed the hover rules.
|
||||||
|
|
||||||
|
return '\n'.join((
|
||||||
|
self.display_name,
|
||||||
|
'=' * len(self.display_name),
|
||||||
|
self.desc,
|
||||||
|
f"\nAccepts: {self.accepts}"
|
||||||
|
))
|
||||||
|
"""
|
||||||
|
return self.desc
|
||||||
|
|
||||||
|
async def update_response(self, interaction: discord.Interaction, message: Optional[str] = None, **kwargs):
|
||||||
|
"""
|
||||||
|
Respond to an interaction which triggered a setting update.
|
||||||
|
Usually just wraps `update_message` in an embed and sends it back.
|
||||||
|
Passes any extra `kwargs` to the message creation method.
|
||||||
|
"""
|
||||||
|
embed = discord.Embed(
|
||||||
|
description=f"{str(conf.emojis.tick)} {message or self.update_message}",
|
||||||
|
colour=discord.Color.green()
|
||||||
|
)
|
||||||
|
if interaction.response.is_done():
|
||||||
|
await interaction.edit_original_response(embed=embed, **kwargs)
|
||||||
|
else:
|
||||||
|
await interaction.response.send_message(embed=embed, **kwargs)
|
||||||
|
|
||||||
|
async def interactive_set(self, new_data: Optional[SettingData], interaction: discord.Interaction, **kwargs):
|
||||||
|
self.data = new_data
|
||||||
|
await self.write()
|
||||||
|
await self.update_response(interaction, **kwargs)
|
||||||
|
|
||||||
|
async def format_in(self, bot, **kwargs):
|
||||||
|
"""
|
||||||
|
Formatted version of the setting given an asynchronous context with client.
|
||||||
|
"""
|
||||||
|
return self.formatted
|
||||||
|
|
||||||
|
@property
|
||||||
|
def embed_field(self):
|
||||||
|
"""
|
||||||
|
Returns a {name, value} pair for use in an Embed field.
|
||||||
|
"""
|
||||||
|
name = self.display_name
|
||||||
|
value = f"{self.long_desc}\n{self.desc_table}"
|
||||||
|
if len(value) > 1024:
|
||||||
|
t = ctx_translator.get().t
|
||||||
|
desc_table = '\n'.join(
|
||||||
|
tabulate(
|
||||||
|
*self._desc_table(
|
||||||
|
show_value=t(_p(
|
||||||
|
'setting|embed_field|too_long',
|
||||||
|
"Too long to display here!"
|
||||||
|
))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
value = f"{self.long_desc}\n{desc_table}"
|
||||||
|
if len(value) > 1024:
|
||||||
|
# Forcibly trim
|
||||||
|
value = value[:1020] + '...'
|
||||||
|
return {'name': name, 'value': value}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def set_str(self):
|
||||||
|
if self._set_cmd is not None:
|
||||||
|
bot = ctx_bot.get()
|
||||||
|
if bot:
|
||||||
|
return bot.core.mention_cmd(self._set_cmd)
|
||||||
|
else:
|
||||||
|
return f"`/{self._set_cmd}`"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def notset_str(self):
|
||||||
|
t = ctx_translator.get().t
|
||||||
|
return t(self._notset_str)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def embed(self):
|
||||||
|
"""
|
||||||
|
Returns a full embed describing this setting.
|
||||||
|
"""
|
||||||
|
t = ctx_translator.get().t
|
||||||
|
embed = discord.Embed(
|
||||||
|
title=t(_p(
|
||||||
|
'setting|summary_embed|title',
|
||||||
|
"Configuration options for `{name}`"
|
||||||
|
)).format(name=self.display_name),
|
||||||
|
)
|
||||||
|
embed.description = "{}\n{}".format(self.long_desc.format(self=self), self.desc_table)
|
||||||
|
return embed
|
||||||
|
|
||||||
|
def _desc_table(self, show_value: Optional[str] = None) -> list[tuple[str, str]]:
|
||||||
|
t = ctx_translator.get().t
|
||||||
|
lines = []
|
||||||
|
|
||||||
|
# Currently line
|
||||||
|
lines.append((
|
||||||
|
t(_p('setting|summary_table|field:currently|key', "Currently")),
|
||||||
|
show_value or (self.formatted or self.notset_str)
|
||||||
|
))
|
||||||
|
|
||||||
|
# Default line
|
||||||
|
if (default := self.default) is not None:
|
||||||
|
lines.append((
|
||||||
|
t(_p('setting|summary_table|field:default|key', "By Default")),
|
||||||
|
self._format_data(self.parent_id, default) or 'None'
|
||||||
|
))
|
||||||
|
|
||||||
|
# Set using line
|
||||||
|
if (set_str := self.set_str) is not None:
|
||||||
|
lines.append((
|
||||||
|
t(_p('setting|summary_table|field:set|key', "Set Using")),
|
||||||
|
set_str
|
||||||
|
))
|
||||||
|
return lines
|
||||||
|
|
||||||
|
@property
|
||||||
|
def desc_table(self) -> str:
|
||||||
|
return '\n'.join(tabulate(*self._desc_table()))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_field(self) -> TextInput:
|
||||||
|
"""
|
||||||
|
TextInput field used for string-based setting modification.
|
||||||
|
May be added to external modal for grouped setting editing.
|
||||||
|
This property is not persistent, and creates a new field each time.
|
||||||
|
"""
|
||||||
|
return TextInput(
|
||||||
|
label=self.display_name,
|
||||||
|
placeholder=self.accepts,
|
||||||
|
default=self.input_formatted[:4000] if self.input_formatted else None,
|
||||||
|
required=self._required
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def widget(self):
|
||||||
|
"""
|
||||||
|
Returns the Discord UI View associated with the current setting.
|
||||||
|
"""
|
||||||
|
if self._widget is None:
|
||||||
|
self._widget = self.Widget(self)
|
||||||
|
return self._widget
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_widget(cls, WidgetCls):
|
||||||
|
"""
|
||||||
|
Convenience decorator to create the widget class for this setting.
|
||||||
|
"""
|
||||||
|
cls.Widget = WidgetCls
|
||||||
|
return WidgetCls
|
||||||
|
|
||||||
|
@property
|
||||||
|
def formatted(self):
|
||||||
|
"""
|
||||||
|
Default user-readable form of the setting.
|
||||||
|
Should be a short single line.
|
||||||
|
"""
|
||||||
|
return self._format_data(self.parent_id, self.data, **self.kwargs) or self.notset_str
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_formatted(self) -> str:
|
||||||
|
"""
|
||||||
|
Format the current value as a default value for an input field.
|
||||||
|
Returned string must be acceptable through parse_string.
|
||||||
|
Does not take into account defaults.
|
||||||
|
"""
|
||||||
|
if self._data is not None:
|
||||||
|
return str(self._data)
|
||||||
|
else:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def summary(self):
|
||||||
|
"""
|
||||||
|
Formatted summary of the data.
|
||||||
|
May be implemented in `_format_data(..., summary=True, ...)` or overidden.
|
||||||
|
"""
|
||||||
|
return self._format_data(self.parent_id, self.data, summary=True, **self.kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def from_string(cls, parent_id, userstr: str, **kwargs):
|
||||||
|
"""
|
||||||
|
Return a setting instance initialised from a parsed user string.
|
||||||
|
"""
|
||||||
|
data = await cls._parse_string(parent_id, userstr, **kwargs)
|
||||||
|
return cls(parent_id, data, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def from_value(cls, parent_id, value, **kwargs):
|
||||||
|
await cls._check_value(parent_id, value, **kwargs)
|
||||||
|
data = cls._data_from_value(parent_id, value, **kwargs)
|
||||||
|
return cls(parent_id, data, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _parse_string(cls, parent_id, string: str, **kwargs) -> Optional[SettingData]:
|
||||||
|
"""
|
||||||
|
Parse user provided string (usually from a TextInput) into raw setting data.
|
||||||
|
Must be overriden by the setting if the setting is user-configurable.
|
||||||
|
Returns None if the setting was unset.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _format_data(cls, parent_id, data, **kwargs):
|
||||||
|
"""
|
||||||
|
Convert raw setting data into a formatted user-readable string,
|
||||||
|
representing the current value.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _check_value(cls, parent_id, value, **kwargs):
|
||||||
|
"""
|
||||||
|
Check the provided value is valid.
|
||||||
|
|
||||||
|
Many setting update methods now provide Discord objects instead of raw data or user strings.
|
||||||
|
This method may be used for value-checking such a value.
|
||||||
|
|
||||||
|
Raises UserInputError if the value fails validation.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def interaction_check(cls, parent_id, interaction: discord.Interaction, **kwargs):
|
||||||
|
if cls._write_ward is not None and not await cls._write_ward(interaction):
|
||||||
|
# TODO: Combine the check system so we can do customised errors here
|
||||||
|
t = ctx_translator.get().t
|
||||||
|
raise UserInputError(t(_p(
|
||||||
|
'setting|interaction_check|error',
|
||||||
|
"You do not have sufficient permissions to do this!"
|
||||||
|
)))
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
command callback for set command?
|
||||||
|
autocomplete for set command?
|
||||||
|
|
||||||
|
Might be better in a ConfigSetting subclass.
|
||||||
|
But also mix into the base setting types.
|
||||||
|
"""
|
||||||
0
src/utils/__init__.py
Normal file
0
src/utils/__init__.py
Normal file
97
src/utils/ansi.py
Normal file
97
src/utils/ansi.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
"""
|
||||||
|
Minimal library for making Discord Ansi colour codes.
|
||||||
|
"""
|
||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
|
|
||||||
|
PREFIX = u'\u001b'
|
||||||
|
|
||||||
|
|
||||||
|
class TextColour(StrEnum):
|
||||||
|
Gray = '30'
|
||||||
|
Red = '31'
|
||||||
|
Green = '32'
|
||||||
|
Yellow = '33'
|
||||||
|
Blue = '34'
|
||||||
|
Pink = '35'
|
||||||
|
Cyan = '36'
|
||||||
|
White = '37'
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return AnsiColour(fg=self).as_str()
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
return AnsiColour(fg=self)
|
||||||
|
|
||||||
|
|
||||||
|
class BgColour(StrEnum):
|
||||||
|
FireflyDarkBlue = '40'
|
||||||
|
Orange = '41'
|
||||||
|
MarbleBlue = '42'
|
||||||
|
GrayTurq = '43'
|
||||||
|
Gray = '44'
|
||||||
|
Indigo = '45'
|
||||||
|
LightGray = '46'
|
||||||
|
White = '47'
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return AnsiColour(bg=self).as_str()
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
return AnsiColour(bg=self)
|
||||||
|
|
||||||
|
|
||||||
|
class Format(StrEnum):
|
||||||
|
NORMAL = '0'
|
||||||
|
BOLD = '1'
|
||||||
|
UNDERLINE = '4'
|
||||||
|
NOOP = '9'
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return AnsiColour(self).as_str()
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
return AnsiColour(self)
|
||||||
|
|
||||||
|
|
||||||
|
class AnsiColour:
|
||||||
|
def __init__(self, *flags, fg=None, bg=None):
|
||||||
|
self.text_colour = fg
|
||||||
|
self.background_colour = bg
|
||||||
|
self.reset = (Format.NORMAL in flags)
|
||||||
|
self._flags = set(flags)
|
||||||
|
self._flags.discard(Format.NORMAL)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def flags(self):
|
||||||
|
return (*((Format.NORMAL,) if self.reset else ()), *self._flags)
|
||||||
|
|
||||||
|
def as_str(self):
|
||||||
|
parts = []
|
||||||
|
if self.reset:
|
||||||
|
parts.append(Format.NORMAL)
|
||||||
|
elif not self.flags:
|
||||||
|
parts.append(Format.NOOP)
|
||||||
|
|
||||||
|
parts.extend(self._flags)
|
||||||
|
|
||||||
|
for c in (self.text_colour, self.background_colour):
|
||||||
|
if c is not None:
|
||||||
|
parts.append(c)
|
||||||
|
|
||||||
|
partstr = ';'.join(part.value for part in parts)
|
||||||
|
return f"{PREFIX}[{partstr}m" # ]
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.as_str()
|
||||||
|
|
||||||
|
def __add__(self, obj: 'AnsiColour'):
|
||||||
|
text_colour = obj.text_colour or self.text_colour
|
||||||
|
background_colour = obj.background_colour or self.background_colour
|
||||||
|
flags = (*self.flags, *obj.flags)
|
||||||
|
return AnsiColour(*flags, fg=text_colour, bg=background_colour)
|
||||||
|
|
||||||
|
|
||||||
|
RESET = AnsiColour(Format.NORMAL)
|
||||||
|
BOLD = AnsiColour(Format.BOLD)
|
||||||
|
UNDERLINE = AnsiColour(Format.UNDERLINE)
|
||||||
165
src/utils/data.py
Normal file
165
src/utils/data.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
"""
|
||||||
|
Some useful pre-built Conditions for data queries.
|
||||||
|
"""
|
||||||
|
from typing import Optional, Any
|
||||||
|
from itertools import chain
|
||||||
|
|
||||||
|
from psycopg import sql
|
||||||
|
from data.conditions import Condition, Joiner
|
||||||
|
from data.columns import ColumnExpr
|
||||||
|
from data.base import Expression
|
||||||
|
from constants import MAX_COINS
|
||||||
|
|
||||||
|
|
||||||
|
def MULTIVALUE_IN(columns: tuple[str, ...], *data: tuple[Any, ...]) -> Condition:
|
||||||
|
"""
|
||||||
|
Condition constructor for filtering by multiple column equalities.
|
||||||
|
|
||||||
|
Example Usage
|
||||||
|
-------------
|
||||||
|
Query.where(MULTIVALUE_IN(('guildid', 'userid'), (1, 2), (3, 4)))
|
||||||
|
"""
|
||||||
|
if not data:
|
||||||
|
raise ValueError("Cannot create empty multivalue condition.")
|
||||||
|
left = sql.SQL("({})").format(
|
||||||
|
sql.SQL(', ').join(
|
||||||
|
sql.Identifier(key)
|
||||||
|
for key in columns
|
||||||
|
)
|
||||||
|
)
|
||||||
|
right_item = sql.SQL('({})').format(
|
||||||
|
sql.SQL(', ').join(
|
||||||
|
sql.Placeholder()
|
||||||
|
for _ in columns
|
||||||
|
)
|
||||||
|
)
|
||||||
|
right = sql.SQL("({})").format(
|
||||||
|
sql.SQL(', ').join(
|
||||||
|
right_item
|
||||||
|
for _ in data
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return Condition(
|
||||||
|
left,
|
||||||
|
Joiner.IN,
|
||||||
|
right,
|
||||||
|
chain(*data)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def MEMBERS(*memberids: tuple[int, int], guild_column='guildid', user_column='userid') -> Condition:
|
||||||
|
"""
|
||||||
|
Condition constructor for filtering member tables by guild and user id simultaneously.
|
||||||
|
|
||||||
|
Example Usage
|
||||||
|
-------------
|
||||||
|
Query.where(MEMBERS((1234,12), (5678,34)))
|
||||||
|
"""
|
||||||
|
if not memberids:
|
||||||
|
raise ValueError("Cannot create a condition with no members")
|
||||||
|
return Condition(
|
||||||
|
sql.SQL("({guildid}, {userid})").format(
|
||||||
|
guildid=sql.Identifier(guild_column),
|
||||||
|
userid=sql.Identifier(user_column)
|
||||||
|
),
|
||||||
|
Joiner.IN,
|
||||||
|
sql.SQL("({})").format(
|
||||||
|
sql.SQL(', ').join(
|
||||||
|
sql.SQL("({}, {})").format(
|
||||||
|
sql.Placeholder(),
|
||||||
|
sql.Placeholder()
|
||||||
|
) for _ in memberids
|
||||||
|
)
|
||||||
|
),
|
||||||
|
chain(*memberids)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def as_duration(expr: Expression) -> ColumnExpr:
|
||||||
|
"""
|
||||||
|
Convert an integer expression into a duration expression.
|
||||||
|
"""
|
||||||
|
expr_expr, expr_values = expr.as_tuple()
|
||||||
|
return ColumnExpr(
|
||||||
|
sql.SQL("({} * interval '1 second')").format(expr_expr),
|
||||||
|
expr_values
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TemporaryTable(Expression):
|
||||||
|
"""
|
||||||
|
Create a temporary table expression to be used in From or With clauses.
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
```
|
||||||
|
tmp_table = TemporaryTable('_col1', '_col2', name='data')
|
||||||
|
tmp_table.values((1, 2), (3, 4))
|
||||||
|
|
||||||
|
real_table.update_where(col1=tmp_table['_col1']).set(col2=tmp_table['_col2']).from_(tmp_table)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *columns: str, name: str = '_t', types: Optional[tuple[str, ...]] = None):
|
||||||
|
self.name = name
|
||||||
|
self.columns = columns
|
||||||
|
self.types = types
|
||||||
|
if types and len(types) != len(columns):
|
||||||
|
raise ValueError("Number of types does not much number of columns!")
|
||||||
|
|
||||||
|
self._table_columns = {
|
||||||
|
col: ColumnExpr(sql.Identifier(name, col))
|
||||||
|
for col in columns
|
||||||
|
}
|
||||||
|
|
||||||
|
self.values = []
|
||||||
|
|
||||||
|
def __getitem__(self, key) -> sql.Identifier:
|
||||||
|
return self._table_columns[key]
|
||||||
|
|
||||||
|
def as_tuple(self):
|
||||||
|
"""
|
||||||
|
(VALUES {})
|
||||||
|
AS
|
||||||
|
name (col1, col2)
|
||||||
|
"""
|
||||||
|
if not self.values:
|
||||||
|
raise ValueError("Cannot flatten CTE with no values.")
|
||||||
|
|
||||||
|
single_value = sql.SQL("({})").format(sql.SQL(", ").join(sql.Placeholder() for _ in self.columns))
|
||||||
|
if self.types:
|
||||||
|
first_value = sql.SQL("({})").format(
|
||||||
|
sql.SQL(", ").join(
|
||||||
|
sql.SQL("{}::{}").format(sql.Placeholder(), sql.SQL(cast))
|
||||||
|
for cast in self.types
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
first_value = single_value
|
||||||
|
|
||||||
|
value_placeholder = sql.SQL("(VALUES {})").format(
|
||||||
|
sql.SQL(", ").join(
|
||||||
|
(first_value, *(single_value for _ in self.values[1:]))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
expr = sql.SQL("{values} AS {name} ({columns})").format(
|
||||||
|
values=value_placeholder,
|
||||||
|
name=sql.Identifier(self.name),
|
||||||
|
columns=sql.SQL(", ").join(sql.Identifier(col) for col in self.columns)
|
||||||
|
)
|
||||||
|
values = chain(*self.values)
|
||||||
|
return (expr, values)
|
||||||
|
|
||||||
|
def set_values(self, *data):
|
||||||
|
self.values = data
|
||||||
|
|
||||||
|
|
||||||
|
def SAFECOINS(expr: Expression) -> Expression:
|
||||||
|
expr_expr, expr_values = expr.as_tuple()
|
||||||
|
return ColumnExpr(
|
||||||
|
sql.SQL("LEAST({}, {})").format(
|
||||||
|
expr_expr,
|
||||||
|
sql.Literal(MAX_COINS)
|
||||||
|
),
|
||||||
|
expr_values
|
||||||
|
)
|
||||||
879
src/utils/lib.py
Normal file
879
src/utils/lib.py
Normal file
@@ -0,0 +1,879 @@
|
|||||||
|
from io import StringIO
|
||||||
|
from typing import NamedTuple, Optional, Sequence, Union, overload, List, Any
|
||||||
|
import collections
|
||||||
|
import datetime
|
||||||
|
import datetime as dt
|
||||||
|
import iso8601 # type: ignore
|
||||||
|
import pytz
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
from contextvars import Context
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.partial_emoji import _EmojiTag
|
||||||
|
from discord import Embed, File, GuildSticker, StickerItem, AllowedMentions, Message, MessageReference, PartialMessage
|
||||||
|
from discord.ui import View
|
||||||
|
|
||||||
|
from meta.errors import UserInputError
|
||||||
|
|
||||||
|
|
||||||
|
multiselect_regex = re.compile(
|
||||||
|
r"^([0-9, -]+)$",
|
||||||
|
re.DOTALL | re.IGNORECASE | re.VERBOSE
|
||||||
|
)
|
||||||
|
tick = '✅'
|
||||||
|
cross = '❌'
|
||||||
|
|
||||||
|
MISSING = object()
|
||||||
|
|
||||||
|
|
||||||
|
class MessageArgs:
|
||||||
|
"""
|
||||||
|
Utility class for storing message creation and editing arguments.
|
||||||
|
"""
|
||||||
|
# TODO: Overrides for mutually exclusive arguments, see Messageable.send
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
content: Optional[str] = ...,
|
||||||
|
*,
|
||||||
|
tts: bool = ...,
|
||||||
|
embed: Embed = ...,
|
||||||
|
file: File = ...,
|
||||||
|
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
|
||||||
|
delete_after: float = ...,
|
||||||
|
nonce: Union[str, int] = ...,
|
||||||
|
allowed_mentions: AllowedMentions = ...,
|
||||||
|
reference: Union[Message, MessageReference, PartialMessage] = ...,
|
||||||
|
mention_author: bool = ...,
|
||||||
|
view: View = ...,
|
||||||
|
suppress_embeds: bool = ...,
|
||||||
|
) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
content: Optional[str] = ...,
|
||||||
|
*,
|
||||||
|
tts: bool = ...,
|
||||||
|
embed: Embed = ...,
|
||||||
|
files: Sequence[File] = ...,
|
||||||
|
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
|
||||||
|
delete_after: float = ...,
|
||||||
|
nonce: Union[str, int] = ...,
|
||||||
|
allowed_mentions: AllowedMentions = ...,
|
||||||
|
reference: Union[Message, MessageReference, PartialMessage] = ...,
|
||||||
|
mention_author: bool = ...,
|
||||||
|
view: View = ...,
|
||||||
|
suppress_embeds: bool = ...,
|
||||||
|
) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
content: Optional[str] = ...,
|
||||||
|
*,
|
||||||
|
tts: bool = ...,
|
||||||
|
embeds: Sequence[Embed] = ...,
|
||||||
|
file: File = ...,
|
||||||
|
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
|
||||||
|
delete_after: float = ...,
|
||||||
|
nonce: Union[str, int] = ...,
|
||||||
|
allowed_mentions: AllowedMentions = ...,
|
||||||
|
reference: Union[Message, MessageReference, PartialMessage] = ...,
|
||||||
|
mention_author: bool = ...,
|
||||||
|
view: View = ...,
|
||||||
|
suppress_embeds: bool = ...,
|
||||||
|
) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
content: Optional[str] = ...,
|
||||||
|
*,
|
||||||
|
tts: bool = ...,
|
||||||
|
embeds: Sequence[Embed] = ...,
|
||||||
|
files: Sequence[File] = ...,
|
||||||
|
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
|
||||||
|
delete_after: float = ...,
|
||||||
|
nonce: Union[str, int] = ...,
|
||||||
|
allowed_mentions: AllowedMentions = ...,
|
||||||
|
reference: Union[Message, MessageReference, PartialMessage] = ...,
|
||||||
|
mention_author: bool = ...,
|
||||||
|
view: View = ...,
|
||||||
|
suppress_embeds: bool = ...,
|
||||||
|
) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def send_args(self) -> dict:
|
||||||
|
if self.kwargs.get('view', MISSING) is None:
|
||||||
|
kwargs = self.kwargs.copy()
|
||||||
|
kwargs.pop('view')
|
||||||
|
else:
|
||||||
|
kwargs = self.kwargs
|
||||||
|
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def edit_args(self) -> dict:
|
||||||
|
args = {}
|
||||||
|
kept = (
|
||||||
|
'content', 'embed', 'embeds', 'delete_after', 'allowed_mentions', 'view'
|
||||||
|
)
|
||||||
|
for k in kept:
|
||||||
|
if k in self.kwargs:
|
||||||
|
args[k] = self.kwargs[k]
|
||||||
|
|
||||||
|
if 'file' in self.kwargs:
|
||||||
|
args['attachments'] = [self.kwargs['file']]
|
||||||
|
|
||||||
|
if 'files' in self.kwargs:
|
||||||
|
args['attachments'] = self.kwargs['files']
|
||||||
|
|
||||||
|
if 'suppress_embeds' in self.kwargs:
|
||||||
|
args['suppress'] = self.kwargs['suppress_embeds']
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def tabulate(
|
||||||
|
*fields: tuple[str, str],
|
||||||
|
row_format: str = "`{invis}{key:<{pad}}{colon}`\t{value}",
|
||||||
|
sub_format: str = "`{invis:<{pad}}{colon}`\t{value}",
|
||||||
|
colon: str = ':',
|
||||||
|
invis: str = "",
|
||||||
|
**args
|
||||||
|
) -> list[str]:
|
||||||
|
"""
|
||||||
|
Turns a list of (property, value) pairs into
|
||||||
|
a pretty string with one `prop: value` pair each line,
|
||||||
|
padded so that the colons in each line are lined up.
|
||||||
|
Use `\\r\\n` in a value to break the line with padding.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
fields: List[tuple[str, str]]
|
||||||
|
List of (key, value) pairs.
|
||||||
|
row_format: str
|
||||||
|
The format string used to format each row.
|
||||||
|
sub_format: str
|
||||||
|
The format string used to format each subline in a row.
|
||||||
|
colon: str
|
||||||
|
The colon character used.
|
||||||
|
invis: str
|
||||||
|
The invisible character used (to avoid Discord stripping the string).
|
||||||
|
|
||||||
|
Returns: List[str]
|
||||||
|
The list of resulting table rows.
|
||||||
|
Each row corresponds to one (key, value) pair from fields.
|
||||||
|
"""
|
||||||
|
max_len = max(len(field[0]) for field in fields)
|
||||||
|
|
||||||
|
rows = []
|
||||||
|
for field in fields:
|
||||||
|
key = field[0]
|
||||||
|
value = field[1]
|
||||||
|
lines = value.split('\r\n')
|
||||||
|
|
||||||
|
row_line = row_format.format(
|
||||||
|
invis=invis,
|
||||||
|
key=key,
|
||||||
|
pad=max_len,
|
||||||
|
colon=colon,
|
||||||
|
value=lines[0],
|
||||||
|
field=field,
|
||||||
|
**args
|
||||||
|
)
|
||||||
|
if len(lines) > 1:
|
||||||
|
row_lines = [row_line]
|
||||||
|
for line in lines[1:]:
|
||||||
|
sub_line = sub_format.format(
|
||||||
|
invis=invis,
|
||||||
|
pad=max_len + len(colon),
|
||||||
|
colon=colon,
|
||||||
|
value=line,
|
||||||
|
**args
|
||||||
|
)
|
||||||
|
row_lines.append(sub_line)
|
||||||
|
row_line = '\n'.join(row_lines)
|
||||||
|
rows.append(row_line)
|
||||||
|
return rows
|
||||||
|
|
||||||
|
|
||||||
|
def paginate_list(item_list: list[str], block_length=20, style="markdown", title=None) -> list[str]:
|
||||||
|
"""
|
||||||
|
Create pretty codeblock pages from a list of strings.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
item_list: List[str]
|
||||||
|
List of strings to paginate.
|
||||||
|
block_length: int
|
||||||
|
Maximum number of strings per page.
|
||||||
|
style: str
|
||||||
|
Codeblock style to use.
|
||||||
|
Title formatting assumes the `markdown` style, and numbered lists work well with this.
|
||||||
|
However, `markdown` sometimes messes up formatting in the list.
|
||||||
|
title: str
|
||||||
|
Optional title to add to the top of each page.
|
||||||
|
|
||||||
|
Returns: List[str]
|
||||||
|
List of pages, each formatted into a codeblock,
|
||||||
|
and containing at most `block_length` of the provided strings.
|
||||||
|
"""
|
||||||
|
lines = ["{0:<5}{1:<5}".format("{}.".format(i + 1), str(line)) for i, line in enumerate(item_list)]
|
||||||
|
page_blocks = [lines[i:i + block_length] for i in range(0, len(lines), block_length)]
|
||||||
|
pages = []
|
||||||
|
for i, block in enumerate(page_blocks):
|
||||||
|
pagenum = "Page {}/{}".format(i + 1, len(page_blocks))
|
||||||
|
if title:
|
||||||
|
header = "{} ({})".format(title, pagenum) if len(page_blocks) > 1 else title
|
||||||
|
else:
|
||||||
|
header = pagenum
|
||||||
|
header_line = "=" * len(header)
|
||||||
|
full_header = "{}\n{}\n".format(header, header_line) if len(page_blocks) > 1 or title else ""
|
||||||
|
pages.append("```{}\n{}{}```".format(style, full_header, "\n".join(block)))
|
||||||
|
return pages
|
||||||
|
|
||||||
|
|
||||||
|
def split_text(text: str, blocksize=2000, code=True, syntax="", maxheight=50) -> list[str]:
|
||||||
|
"""
|
||||||
|
Break the text into blocks of maximum length blocksize
|
||||||
|
If possible, break across nearby newlines. Otherwise just break at blocksize chars
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
text: str
|
||||||
|
Text to break into blocks.
|
||||||
|
blocksize: int
|
||||||
|
Maximum character length for each block.
|
||||||
|
code: bool
|
||||||
|
Whether to wrap each block in codeblocks (these are counted in the blocksize).
|
||||||
|
syntax: str
|
||||||
|
The markdown formatting language to use for the codeblocks, if applicable.
|
||||||
|
maxheight: int
|
||||||
|
The maximum number of lines in each block
|
||||||
|
|
||||||
|
Returns: List[str]
|
||||||
|
List of blocks,
|
||||||
|
each containing at most `block_size` characters,
|
||||||
|
of height at most `maxheight`.
|
||||||
|
"""
|
||||||
|
# Adjust blocksize to account for the codeblocks if required
|
||||||
|
blocksize = blocksize - 8 - len(syntax) if code else blocksize
|
||||||
|
|
||||||
|
# Build the blocks
|
||||||
|
blocks = []
|
||||||
|
while True:
|
||||||
|
# If the remaining text is already small enough, append it
|
||||||
|
if len(text) <= blocksize:
|
||||||
|
blocks.append(text)
|
||||||
|
break
|
||||||
|
text = text.strip('\n')
|
||||||
|
|
||||||
|
# Find the last newline in the prototype block
|
||||||
|
split_on = text[0:blocksize].rfind('\n')
|
||||||
|
split_on = blocksize if split_on < blocksize // 5 else split_on
|
||||||
|
|
||||||
|
# Add the block and truncate the text
|
||||||
|
blocks.append(text[0:split_on])
|
||||||
|
text = text[split_on:]
|
||||||
|
|
||||||
|
# Add the codeblock ticks and the code syntax header, if required
|
||||||
|
if code:
|
||||||
|
blocks = ["```{}\n{}\n```".format(syntax, block) for block in blocks]
|
||||||
|
|
||||||
|
return blocks
|
||||||
|
|
||||||
|
|
||||||
|
def strfdelta(delta: datetime.timedelta, sec=False, minutes=True, short=False) -> str:
|
||||||
|
"""
|
||||||
|
Convert a datetime.timedelta object into an easily readable duration string.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
delta: datetime.timedelta
|
||||||
|
The timedelta object to convert into a readable string.
|
||||||
|
sec: bool
|
||||||
|
Whether to include the seconds from the timedelta object in the string.
|
||||||
|
minutes: bool
|
||||||
|
Whether to include the minutes from the timedelta object in the string.
|
||||||
|
short: bool
|
||||||
|
Whether to abbreviate the units of time ("hour" to "h", "minute" to "m", "second" to "s").
|
||||||
|
|
||||||
|
Returns: str
|
||||||
|
A string containing a time from the datetime.timedelta object, in a readable format.
|
||||||
|
Time units will be abbreviated if short was set to True.
|
||||||
|
"""
|
||||||
|
output = [[delta.days, 'd' if short else ' day'],
|
||||||
|
[delta.seconds // 3600, 'h' if short else ' hour']]
|
||||||
|
if minutes:
|
||||||
|
output.append([delta.seconds // 60 % 60, 'm' if short else ' minute'])
|
||||||
|
if sec:
|
||||||
|
output.append([delta.seconds % 60, 's' if short else ' second'])
|
||||||
|
for i in range(len(output)):
|
||||||
|
if output[i][0] != 1 and not short:
|
||||||
|
output[i][1] += 's' # type: ignore
|
||||||
|
reply_msg = []
|
||||||
|
if output[0][0] != 0:
|
||||||
|
reply_msg.append("{}{} ".format(output[0][0], output[0][1]))
|
||||||
|
if output[0][0] != 0 or output[1][0] != 0 or len(output) == 2:
|
||||||
|
reply_msg.append("{}{} ".format(output[1][0], output[1][1]))
|
||||||
|
for i in range(2, len(output) - 1):
|
||||||
|
reply_msg.append("{}{} ".format(output[i][0], output[i][1]))
|
||||||
|
if not short and reply_msg:
|
||||||
|
reply_msg.append("and ")
|
||||||
|
reply_msg.append("{}{}".format(output[-1][0], output[-1][1]))
|
||||||
|
return "".join(reply_msg)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_dur(time_str: str) -> int:
|
||||||
|
"""
|
||||||
|
Parses a user provided time duration string into a timedelta object.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
time_str: str
|
||||||
|
The time string to parse. String can include days, hours, minutes, and seconds.
|
||||||
|
|
||||||
|
Returns: int
|
||||||
|
The number of seconds the duration represents.
|
||||||
|
"""
|
||||||
|
funcs = {'d': lambda x: x * 24 * 60 * 60,
|
||||||
|
'h': lambda x: x * 60 * 60,
|
||||||
|
'm': lambda x: x * 60,
|
||||||
|
's': lambda x: x}
|
||||||
|
time_str = time_str.strip(" ,")
|
||||||
|
found = re.findall(r'(\d+)\s?(\w+?)', time_str)
|
||||||
|
seconds = 0
|
||||||
|
for bit in found:
|
||||||
|
if bit[1] in funcs:
|
||||||
|
seconds += funcs[bit[1]](int(bit[0]))
|
||||||
|
return seconds
|
||||||
|
|
||||||
|
|
||||||
|
def strfdur(duration: int, short=True, show_days=False) -> str:
|
||||||
|
"""
|
||||||
|
Convert a duration given in seconds to a number of hours, minutes, and seconds.
|
||||||
|
"""
|
||||||
|
days = duration // (3600 * 24) if show_days else 0
|
||||||
|
hours = duration // 3600
|
||||||
|
if days:
|
||||||
|
hours %= 24
|
||||||
|
minutes = duration // 60 % 60
|
||||||
|
seconds = duration % 60
|
||||||
|
|
||||||
|
parts = []
|
||||||
|
if days:
|
||||||
|
unit = 'd' if short else (' days' if days != 1 else ' day')
|
||||||
|
parts.append('{}{}'.format(days, unit))
|
||||||
|
if hours:
|
||||||
|
unit = 'h' if short else (' hours' if hours != 1 else ' hour')
|
||||||
|
parts.append('{}{}'.format(hours, unit))
|
||||||
|
if minutes:
|
||||||
|
unit = 'm' if short else (' minutes' if minutes != 1 else ' minute')
|
||||||
|
parts.append('{}{}'.format(minutes, unit))
|
||||||
|
if seconds or duration == 0:
|
||||||
|
unit = 's' if short else (' seconds' if seconds != 1 else ' second')
|
||||||
|
parts.append('{}{}'.format(seconds, unit))
|
||||||
|
|
||||||
|
if short:
|
||||||
|
return ' '.join(parts)
|
||||||
|
else:
|
||||||
|
return ', '.join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def substitute_ranges(ranges_str: str, max_match=20, max_range=1000, separator=',') -> str:
|
||||||
|
"""
|
||||||
|
Substitutes a user provided list of numbers and ranges,
|
||||||
|
and replaces the ranges by the corresponding list of numbers.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
ranges_str: str
|
||||||
|
The string to ranges in.
|
||||||
|
max_match: int
|
||||||
|
The maximum number of ranges to replace.
|
||||||
|
Any ranges exceeding this will be ignored.
|
||||||
|
max_range: int
|
||||||
|
The maximum length of range to replace.
|
||||||
|
Attempting to replace a range longer than this will raise a `ValueError`.
|
||||||
|
"""
|
||||||
|
def _repl(match):
|
||||||
|
n1 = int(match.group(1))
|
||||||
|
n2 = int(match.group(2))
|
||||||
|
if n2 - n1 > max_range:
|
||||||
|
# TODO: Upgrade to SafeCancellation
|
||||||
|
raise ValueError("Provided range is too large!")
|
||||||
|
return separator.join(str(i) for i in range(n1, n2 + 1))
|
||||||
|
|
||||||
|
return re.sub(r'(\d+)\s*-\s*(\d+)', _repl, ranges_str, max_match)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_ranges(ranges_str: str, ignore_errors=False, separator=',', **kwargs) -> list[int]:
|
||||||
|
"""
|
||||||
|
Parses a user provided range string into a list of numbers.
|
||||||
|
Extra keyword arguments are transparently passed to the underlying parser `substitute_ranges`.
|
||||||
|
"""
|
||||||
|
substituted = substitute_ranges(ranges_str, separator=separator, **kwargs)
|
||||||
|
_numbers = (item.strip() for item in substituted.split(','))
|
||||||
|
numbers = [item for item in _numbers if item]
|
||||||
|
integers = [int(item) for item in numbers if item.isdigit()]
|
||||||
|
|
||||||
|
if not ignore_errors and len(integers) != len(numbers):
|
||||||
|
# TODO: Upgrade to SafeCancellation
|
||||||
|
raise ValueError(
|
||||||
|
"Couldn't parse the provided selection!\n"
|
||||||
|
"Please provide comma separated numbers and ranges, e.g. `1, 5, 6-9`."
|
||||||
|
)
|
||||||
|
|
||||||
|
return integers
|
||||||
|
|
||||||
|
|
||||||
|
def msg_string(msg: discord.Message, mask_link=False, line_break=False, tz=None, clean=True) -> str:
|
||||||
|
"""
|
||||||
|
Format a message into a string with various information, such as:
|
||||||
|
the timestamp of the message, author, message content, and attachments.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
msg: Message
|
||||||
|
The message to format.
|
||||||
|
mask_link: bool
|
||||||
|
Whether to mask the URLs of any attachments.
|
||||||
|
line_break: bool
|
||||||
|
Whether a line break should be used in the string.
|
||||||
|
tz: Timezone
|
||||||
|
The timezone to use in the formatted message.
|
||||||
|
clean: bool
|
||||||
|
Whether to use the clean content of the original message.
|
||||||
|
|
||||||
|
Returns: str
|
||||||
|
A formatted string containing various information:
|
||||||
|
User timezone, message author, message content, attachments
|
||||||
|
"""
|
||||||
|
timestr = "%I:%M %p, %d/%m/%Y"
|
||||||
|
if tz:
|
||||||
|
time = iso8601.parse_date(msg.created_at.isoformat()).astimezone(tz).strftime(timestr)
|
||||||
|
else:
|
||||||
|
time = msg.created_at.strftime(timestr)
|
||||||
|
user = str(msg.author)
|
||||||
|
attach_list = [attach.proxy_url for attach in msg.attachments if attach.proxy_url]
|
||||||
|
if mask_link:
|
||||||
|
attach_list = ["[Link]({})".format(url) for url in attach_list]
|
||||||
|
attachments = "\nAttachments: {}".format(", ".join(attach_list)) if attach_list else ""
|
||||||
|
return "`[{time}]` **{user}:** {line_break}{message} {attachments}".format(
|
||||||
|
time=time,
|
||||||
|
user=user,
|
||||||
|
line_break="\n" if line_break else "",
|
||||||
|
message=msg.clean_content if clean else msg.content,
|
||||||
|
attachments=attachments
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def convdatestring(datestring: str) -> datetime.timedelta:
|
||||||
|
"""
|
||||||
|
Convert a date string into a datetime.timedelta object.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
datestring: str
|
||||||
|
The string to convert to a datetime.timedelta object.
|
||||||
|
|
||||||
|
Returns: datetime.timedelta
|
||||||
|
A datetime.timedelta object formed from the string provided.
|
||||||
|
"""
|
||||||
|
datestring = datestring.strip(' ,')
|
||||||
|
datearray = []
|
||||||
|
funcs = {'d': lambda x: x * 24 * 60 * 60,
|
||||||
|
'h': lambda x: x * 60 * 60,
|
||||||
|
'm': lambda x: x * 60,
|
||||||
|
's': lambda x: x}
|
||||||
|
currentnumber = ''
|
||||||
|
for char in datestring:
|
||||||
|
if char.isdigit():
|
||||||
|
currentnumber += char
|
||||||
|
else:
|
||||||
|
if currentnumber == '':
|
||||||
|
continue
|
||||||
|
datearray.append((int(currentnumber), char))
|
||||||
|
currentnumber = ''
|
||||||
|
seconds = 0
|
||||||
|
if currentnumber:
|
||||||
|
seconds += int(currentnumber)
|
||||||
|
for i in datearray:
|
||||||
|
if i[1] in funcs:
|
||||||
|
seconds += funcs[i[1]](i[0])
|
||||||
|
return datetime.timedelta(seconds=seconds)
|
||||||
|
|
||||||
|
|
||||||
|
class _rawChannel(discord.abc.Messageable):
|
||||||
|
"""
|
||||||
|
Raw messageable class representing an arbitrary channel,
|
||||||
|
not necessarially seen by the gateway.
|
||||||
|
"""
|
||||||
|
def __init__(self, state, id):
|
||||||
|
self._state = state
|
||||||
|
self.id = id
|
||||||
|
|
||||||
|
async def _get_channel(self):
|
||||||
|
return discord.Object(self.id)
|
||||||
|
|
||||||
|
|
||||||
|
async def mail(client: discord.Client, channelid: int, **msg_args) -> discord.Message:
|
||||||
|
"""
|
||||||
|
Mails a message to a channelid which may be invisible to the gateway.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
client: discord.Client
|
||||||
|
The client to use for mailing.
|
||||||
|
Must at least have static authentication and have a valid `_connection`.
|
||||||
|
channelid: int
|
||||||
|
The channel id to mail to.
|
||||||
|
msg_args: Any
|
||||||
|
Message keyword arguments which are passed transparently to `_rawChannel.send(...)`.
|
||||||
|
"""
|
||||||
|
# Create the raw channel
|
||||||
|
channel = _rawChannel(client._connection, channelid)
|
||||||
|
return await channel.send(**msg_args)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbedField(NamedTuple):
|
||||||
|
name: str
|
||||||
|
value: str
|
||||||
|
inline: Optional[bool] = True
|
||||||
|
|
||||||
|
|
||||||
|
def emb_add_fields(embed: discord.Embed, emb_fields: list[tuple[str, str, bool]]):
|
||||||
|
"""
|
||||||
|
Append embed fields to an embed.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
embed: discord.Embed
|
||||||
|
The embed to add the field to.
|
||||||
|
emb_fields: tuple
|
||||||
|
The values to add to a field.
|
||||||
|
name: str
|
||||||
|
The name of the field.
|
||||||
|
value: str
|
||||||
|
The value of the field.
|
||||||
|
inline: bool
|
||||||
|
Whether the embed field should be inline or not.
|
||||||
|
"""
|
||||||
|
for field in emb_fields:
|
||||||
|
embed.add_field(name=str(field[0]), value=str(field[1]), inline=bool(field[2]))
|
||||||
|
|
||||||
|
|
||||||
|
def join_list(string: list[str], nfs=False) -> str:
|
||||||
|
"""
|
||||||
|
Join a list together, separated with commas, plus add "and" to the beginning of the last value.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
string: list
|
||||||
|
The list to join together.
|
||||||
|
nfs: bool
|
||||||
|
(no fullstops)
|
||||||
|
Whether to exclude fullstops/periods from the output messages.
|
||||||
|
If not provided, fullstops will be appended to the output.
|
||||||
|
"""
|
||||||
|
# TODO: Probably not useful with localisation
|
||||||
|
if len(string) > 1:
|
||||||
|
return "{}{} and {}{}".format((", ").join(string[:-1]),
|
||||||
|
"," if len(string) > 2 else "", string[-1], "" if nfs else ".")
|
||||||
|
else:
|
||||||
|
return "{}{}".format("".join(string), "" if nfs else ".")
|
||||||
|
|
||||||
|
|
||||||
|
def shard_of(shard_count: int, guildid: int) -> int:
|
||||||
|
"""
|
||||||
|
Calculate the shard number of a given guild.
|
||||||
|
"""
|
||||||
|
return (guildid >> 22) % shard_count if shard_count and shard_count > 0 else 0
|
||||||
|
|
||||||
|
|
||||||
|
def jumpto(guildid: int, channeldid: int, messageid: int) -> str:
|
||||||
|
"""
|
||||||
|
Build a jump link for a message given its location.
|
||||||
|
"""
|
||||||
|
return 'https://discord.com/channels/{}/{}/{}'.format(
|
||||||
|
guildid,
|
||||||
|
channeldid,
|
||||||
|
messageid
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def utc_now() -> datetime.datetime:
|
||||||
|
"""
|
||||||
|
Return the current timezone-aware utc timestamp.
|
||||||
|
"""
|
||||||
|
return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
def multiple_replace(string: str, rep_dict: dict[str, str]) -> str:
|
||||||
|
if rep_dict:
|
||||||
|
pattern = re.compile(
|
||||||
|
"|".join([re.escape(k) for k in sorted(rep_dict, key=len, reverse=True)]),
|
||||||
|
flags=re.DOTALL
|
||||||
|
)
|
||||||
|
return pattern.sub(lambda x: str(rep_dict[x.group(0)]), string)
|
||||||
|
else:
|
||||||
|
return string
|
||||||
|
|
||||||
|
|
||||||
|
def recover_context(context: Context):
|
||||||
|
for var in context:
|
||||||
|
var.set(context[var])
|
||||||
|
|
||||||
|
|
||||||
|
def parse_ids(idstr: str) -> List[int]:
|
||||||
|
"""
|
||||||
|
Parse a provided comma separated string of maybe-mentions, maybe-ids, into a list of integer ids.
|
||||||
|
|
||||||
|
Object agnostic, so all mention tokens are stripped.
|
||||||
|
Raises UserInputError if an id is invalid,
|
||||||
|
setting `orig` and `item` info fields.
|
||||||
|
"""
|
||||||
|
from meta.errors import UserInputError
|
||||||
|
|
||||||
|
# Extract ids from string
|
||||||
|
splititer = (split.strip('<@!#&>, ') for split in idstr.split(','))
|
||||||
|
splits = [split for split in splititer if split]
|
||||||
|
|
||||||
|
# Check they are integers
|
||||||
|
if (not_id := next((split for split in splits if not split.isdigit()), None)) is not None:
|
||||||
|
raise UserInputError("Could not extract an id from `$item`!", {'orig': idstr, 'item': not_id})
|
||||||
|
|
||||||
|
# Cast to integer and return
|
||||||
|
return list(map(int, splits))
|
||||||
|
|
||||||
|
|
||||||
|
def error_embed(error, **kwargs) -> discord.Embed:
|
||||||
|
embed = discord.Embed(
|
||||||
|
colour=discord.Colour.brand_red(),
|
||||||
|
description=error,
|
||||||
|
timestamp=utc_now()
|
||||||
|
)
|
||||||
|
return embed
|
||||||
|
|
||||||
|
|
||||||
|
class DotDict(dict):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.__dict__ = self
|
||||||
|
|
||||||
|
|
||||||
|
class Timezoned:
|
||||||
|
"""
|
||||||
|
ABC mixin for objects with a set timezone.
|
||||||
|
|
||||||
|
Provides several useful localised properties.
|
||||||
|
"""
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def timezone(self) -> pytz.timezone:
|
||||||
|
"""
|
||||||
|
Must be implemented by the deriving class!
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def now(self):
|
||||||
|
"""
|
||||||
|
Return the current time localised to the object's timezone.
|
||||||
|
"""
|
||||||
|
return datetime.datetime.now(tz=self.timezone)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def today(self):
|
||||||
|
"""
|
||||||
|
Return the start of the day localised to the object's timezone.
|
||||||
|
"""
|
||||||
|
now = self.now
|
||||||
|
return now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def week_start(self):
|
||||||
|
"""
|
||||||
|
Return the start of the week in the object's timezone
|
||||||
|
"""
|
||||||
|
today = self.today
|
||||||
|
return today - datetime.timedelta(days=today.weekday())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def month_start(self):
|
||||||
|
"""
|
||||||
|
Return the start of the current month in the object's timezone
|
||||||
|
"""
|
||||||
|
today = self.today
|
||||||
|
return today.replace(day=1)
|
||||||
|
|
||||||
|
|
||||||
|
def replace_multiple(format_string, mapping):
|
||||||
|
"""
|
||||||
|
Subsistutes the keys from the format_dict with their corresponding values.
|
||||||
|
|
||||||
|
Substitution is non-chained, and done in a single pass via regex.
|
||||||
|
"""
|
||||||
|
if not mapping:
|
||||||
|
raise ValueError("Empty mapping passed.")
|
||||||
|
|
||||||
|
keys = list(mapping.keys())
|
||||||
|
pattern = '|'.join(f"({key})" for key in keys)
|
||||||
|
string = re.sub(pattern, lambda match: str(mapping[keys[match.lastindex - 1]]), format_string)
|
||||||
|
return string
|
||||||
|
|
||||||
|
|
||||||
|
def emojikey(emoji: discord.Emoji | discord.PartialEmoji | str):
|
||||||
|
"""
|
||||||
|
Produces a distinguishing key for an Emoji or PartialEmoji.
|
||||||
|
|
||||||
|
Equality checks using this key should act as expected.
|
||||||
|
"""
|
||||||
|
if isinstance(emoji, _EmojiTag):
|
||||||
|
if emoji.id:
|
||||||
|
key = str(emoji.id)
|
||||||
|
else:
|
||||||
|
key = str(emoji.name)
|
||||||
|
else:
|
||||||
|
key = str(emoji)
|
||||||
|
|
||||||
|
return key
|
||||||
|
|
||||||
|
def recurse_map(func, obj, loc=[]):
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
for k, v in obj.items():
|
||||||
|
loc.append(k)
|
||||||
|
obj[k] = recurse_map(func, v, loc)
|
||||||
|
loc.pop()
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
for i, item in enumerate(obj):
|
||||||
|
loc.append(i)
|
||||||
|
obj[i] = recurse_map(func, item)
|
||||||
|
loc.pop()
|
||||||
|
else:
|
||||||
|
obj = func(loc, obj)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
async def check_dm(user: discord.User | discord.Member) -> bool:
|
||||||
|
"""
|
||||||
|
Check whether we can direct message the given user.
|
||||||
|
|
||||||
|
Assumes the client is initialised.
|
||||||
|
This uses an always-failing HTTP request,
|
||||||
|
so we need to be very very very careful that this is not used frequently.
|
||||||
|
Optimally only at the explicit behest of the user
|
||||||
|
(i.e. during a user instigated interaction).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await user.send('')
|
||||||
|
except discord.Forbidden:
|
||||||
|
return False
|
||||||
|
except discord.HTTPException:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def command_lengths(tree) -> dict[str, int]:
|
||||||
|
cmds = tree.get_commands()
|
||||||
|
payloads = [
|
||||||
|
await cmd.get_translated_payload(tree.translator)
|
||||||
|
for cmd in cmds
|
||||||
|
]
|
||||||
|
lens = {}
|
||||||
|
for command in payloads:
|
||||||
|
name = command['name']
|
||||||
|
crumbs = {}
|
||||||
|
cmd_len = lens[name] = _recurse_length(command, crumbs, (name,))
|
||||||
|
if name == 'configure' or cmd_len > 4000:
|
||||||
|
print(f"'{name}' over 4000. Breadcrumb Trail follows:")
|
||||||
|
lines = []
|
||||||
|
for loc, val in crumbs.items():
|
||||||
|
locstr = '.'.join(loc)
|
||||||
|
lines.append(f"{locstr}: {val}")
|
||||||
|
print('\n'.join(lines))
|
||||||
|
print(json.dumps(command, indent=2))
|
||||||
|
return lens
|
||||||
|
|
||||||
|
def _recurse_length(payload, breadcrumbs={}, header=()) -> int:
|
||||||
|
total = 0
|
||||||
|
total_header = (*header, '')
|
||||||
|
breadcrumbs[total_header] = 0
|
||||||
|
|
||||||
|
if isinstance(payload, dict):
|
||||||
|
# Read strings that count towards command length
|
||||||
|
# String length is length of longest localisation, including default.
|
||||||
|
for key in ('name', 'description', 'value'):
|
||||||
|
if key in payload:
|
||||||
|
value = payload[key]
|
||||||
|
if isinstance(value, str):
|
||||||
|
values = (value, *payload.get(key + '_localizations', {}).values())
|
||||||
|
maxlen = max(map(len, values))
|
||||||
|
total += maxlen
|
||||||
|
breadcrumbs[(*header, key)] = maxlen
|
||||||
|
|
||||||
|
for key, value in payload.items():
|
||||||
|
loc = (*header, key)
|
||||||
|
total += _recurse_length(value, breadcrumbs, loc)
|
||||||
|
elif isinstance(payload, list):
|
||||||
|
for i, item in enumerate(payload):
|
||||||
|
if isinstance(item, dict) and 'name' in item:
|
||||||
|
loc = (*header, f"{i}<{item['name']}>")
|
||||||
|
else:
|
||||||
|
loc = (*header, str(i))
|
||||||
|
total += _recurse_length(item, breadcrumbs, loc)
|
||||||
|
|
||||||
|
if total:
|
||||||
|
breadcrumbs[total_header] = total
|
||||||
|
else:
|
||||||
|
breadcrumbs.pop(total_header)
|
||||||
|
|
||||||
|
return total
|
||||||
|
|
||||||
|
def write_records(records: list[dict[str, Any]], stream: StringIO):
|
||||||
|
if records:
|
||||||
|
keys = records[0].keys()
|
||||||
|
stream.write(','.join(keys))
|
||||||
|
stream.write('\n')
|
||||||
|
for record in records:
|
||||||
|
stream.write(','.join(map(str, record.values())))
|
||||||
|
stream.write('\n')
|
||||||
|
|
||||||
|
|
||||||
|
parse_dur_exps = [
|
||||||
|
(
|
||||||
|
r"(?P<value>\d+)\s*(?:(d)|(day))",
|
||||||
|
60 * 60 * 24,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"(?P<value>\d+)\s*(?:(h)|(hour))",
|
||||||
|
60 * 60
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"(?P<value>\d+)\s*(?:(m)|(min))",
|
||||||
|
60
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"(?P<value>\d+)\s*(?:(s)|(sec))",
|
||||||
|
1
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_duration(string: str) -> Optional[int]:
|
||||||
|
seconds = 0
|
||||||
|
found = False
|
||||||
|
for expr, multiplier in parse_dur_exps:
|
||||||
|
match = re.search(expr, string, flags=re.IGNORECASE)
|
||||||
|
if match:
|
||||||
|
found = True
|
||||||
|
seconds += int(match.group('value')) * multiplier
|
||||||
|
|
||||||
|
return seconds if found else None
|
||||||
191
src/utils/monitor.py
Normal file
191
src/utils/monitor.py
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
import asyncio
|
||||||
|
import bisect
|
||||||
|
import logging
|
||||||
|
from typing import TypeVar, Generic, Optional, Callable, Coroutine, Any
|
||||||
|
|
||||||
|
from .lib import utc_now
|
||||||
|
from .ratelimits import Bucket
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
Taskid = TypeVar('Taskid')
|
||||||
|
|
||||||
|
|
||||||
|
class TaskMonitor(Generic[Taskid]):
|
||||||
|
"""
|
||||||
|
Base class for a task monitor.
|
||||||
|
|
||||||
|
Stores tasks as a time-sorted list of taskids.
|
||||||
|
Subclasses may override `run_task` to implement an executor.
|
||||||
|
|
||||||
|
Adding or removing a single task has O(n) performance.
|
||||||
|
To bulk update tasks, instead use `schedule_tasks`.
|
||||||
|
|
||||||
|
Each taskid must be unique and hashable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, executor=None, bucket: Optional[Bucket] = None):
|
||||||
|
# Ratelimit bucket to enforce maximum execution rate
|
||||||
|
self._bucket = bucket
|
||||||
|
|
||||||
|
self.executor: Optional[Callable[[Taskid], Coroutine[Any, Any, None]]] = executor
|
||||||
|
|
||||||
|
self._wakeup: asyncio.Event = asyncio.Event()
|
||||||
|
self._monitor_task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
|
# Task data
|
||||||
|
self._tasklist: list[Taskid] = []
|
||||||
|
self._taskmap: dict[Taskid, int] = {} # taskid -> timestamp
|
||||||
|
|
||||||
|
# Running map ensures we keep a reference to the running task
|
||||||
|
# And allows simpler external cancellation if required
|
||||||
|
self._running: dict[Taskid, asyncio.Future] = {}
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (
|
||||||
|
"<"
|
||||||
|
f"{self.__class__.__name__}"
|
||||||
|
f" tasklist={len(self._tasklist)}"
|
||||||
|
f" taskmap={len(self._taskmap)}"
|
||||||
|
f" wakeup={self._wakeup.is_set()}"
|
||||||
|
f" bucket={self._bucket}"
|
||||||
|
f" running={len(self._running)}"
|
||||||
|
f" task={self._monitor_task}"
|
||||||
|
f">"
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_tasks(self, *tasks: tuple[Taskid, int]) -> None:
|
||||||
|
"""
|
||||||
|
Similar to `schedule_tasks`, but wipe and reset the tasklist.
|
||||||
|
"""
|
||||||
|
self._taskmap = {tid: time for tid, time in tasks}
|
||||||
|
self._tasklist = list(sorted(self._taskmap.keys(), key=lambda tid: -1 * self._taskmap[tid]))
|
||||||
|
self._wakeup.set()
|
||||||
|
|
||||||
|
def schedule_tasks(self, *tasks: tuple[Taskid, int]) -> None:
|
||||||
|
"""
|
||||||
|
Schedule the given tasks.
|
||||||
|
|
||||||
|
Rather than repeatedly inserting tasks,
|
||||||
|
where the O(log n) insort is dominated by the O(n) list insertion,
|
||||||
|
we build an entirely new list, and always wake up the loop.
|
||||||
|
"""
|
||||||
|
self._taskmap |= {tid: time for tid, time in tasks}
|
||||||
|
self._tasklist = list(sorted(self._taskmap.keys(), key=lambda tid: -1 * self._taskmap[tid]))
|
||||||
|
self._wakeup.set()
|
||||||
|
|
||||||
|
def schedule_task(self, taskid: Taskid, timestamp: int) -> None:
|
||||||
|
"""
|
||||||
|
Insert the provided task into the tasklist.
|
||||||
|
If the new task has a lower timestamp than the next task, wakes up the monitor loop.
|
||||||
|
"""
|
||||||
|
if self._tasklist:
|
||||||
|
nextid = self._tasklist[-1]
|
||||||
|
wake = self._taskmap[nextid] >= timestamp
|
||||||
|
wake = wake or taskid == nextid
|
||||||
|
else:
|
||||||
|
wake = True
|
||||||
|
if taskid in self._taskmap:
|
||||||
|
self._tasklist.remove(taskid)
|
||||||
|
self._taskmap[taskid] = timestamp
|
||||||
|
bisect.insort_left(self._tasklist, taskid, key=lambda t: -1 * self._taskmap[t])
|
||||||
|
if wake:
|
||||||
|
self._wakeup.set()
|
||||||
|
|
||||||
|
def cancel_tasks(self, *taskids: Taskid) -> None:
|
||||||
|
"""
|
||||||
|
Remove all tasks with the given taskids from the tasklist.
|
||||||
|
If the next task has this taskid, wake up the monitor loop.
|
||||||
|
"""
|
||||||
|
taskids = set(taskids)
|
||||||
|
wake = (self._tasklist and self._tasklist[-1] in taskids)
|
||||||
|
self._tasklist = [tid for tid in self._tasklist if tid not in taskids]
|
||||||
|
for tid in taskids:
|
||||||
|
self._taskmap.pop(tid, None)
|
||||||
|
if wake:
|
||||||
|
self._wakeup.set()
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
if self._monitor_task and not self._monitor_task.done():
|
||||||
|
self._monitor_task.cancel()
|
||||||
|
# Start the monitor
|
||||||
|
self._monitor_task = asyncio.create_task(self.monitor())
|
||||||
|
return self._monitor_task
|
||||||
|
|
||||||
|
async def monitor(self):
|
||||||
|
"""
|
||||||
|
Start the monitor.
|
||||||
|
Executes the tasks in `self.tasks` at the specified time.
|
||||||
|
|
||||||
|
This will shield task execution from cancellation
|
||||||
|
to avoid partial states.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
self._wakeup.clear()
|
||||||
|
if not self._tasklist:
|
||||||
|
# No tasks left, just sleep until wakeup
|
||||||
|
await self._wakeup.wait()
|
||||||
|
else:
|
||||||
|
# Get the next task, sleep until wakeup or it is ready to run
|
||||||
|
nextid = self._tasklist[-1]
|
||||||
|
nexttime = self._taskmap[nextid]
|
||||||
|
sleep_for = nexttime - utc_now().timestamp()
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._wakeup.wait(), timeout=sleep_for)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# Ready to run the task
|
||||||
|
self._tasklist.pop()
|
||||||
|
self._taskmap.pop(nextid, None)
|
||||||
|
self._running[nextid] = asyncio.ensure_future(self._run(nextid))
|
||||||
|
else:
|
||||||
|
# Wakeup task fired, loop again
|
||||||
|
continue
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Log closure and wait for remaining tasks
|
||||||
|
# A second cancellation will also cancel the tasks
|
||||||
|
logger.debug(
|
||||||
|
f"Task Monitor {self.__class__.__name__} cancelled with {len(self._tasklist)} tasks remaining. "
|
||||||
|
f"Waiting for {len(self._running)} running tasks to complete."
|
||||||
|
)
|
||||||
|
await asyncio.gather(*self._running.values(), return_exceptions=True)
|
||||||
|
|
||||||
|
async def _run(self, taskid: Taskid) -> None:
|
||||||
|
# Execute the task, respecting the ratelimit bucket
|
||||||
|
if self._bucket is not None:
|
||||||
|
# IMPLEMENTATION NOTE:
|
||||||
|
# Bucket.wait() should guarantee not more than n tasks/second are run
|
||||||
|
# and that a request directly afterwards will _not_ raise BucketFull
|
||||||
|
# Make sure that only one waiter is actually waiting on its sleep task
|
||||||
|
# The other waiters should be waiting on a lock around the sleep task
|
||||||
|
# Waiters are executed in wait-order, so if we only let a single waiter in
|
||||||
|
# we shouldn't get collisions.
|
||||||
|
# Furthermore, make sure we do _not_ pass back to the event loop after waiting
|
||||||
|
# Or we will lose thread-safety for BucketFull
|
||||||
|
await self._bucket.wait()
|
||||||
|
fut = asyncio.create_task(self.run_task(taskid))
|
||||||
|
try:
|
||||||
|
await asyncio.shield(fut)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
# Protect the monitor loop from any other exceptions
|
||||||
|
logger.exception(
|
||||||
|
f"Ignoring exception in task monitor {self.__class__.__name__} while "
|
||||||
|
f"executing <taskid: {taskid}>"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
self._running.pop(taskid)
|
||||||
|
|
||||||
|
async def run_task(self, taskid: Taskid):
|
||||||
|
"""
|
||||||
|
Execute the task with the given taskid.
|
||||||
|
|
||||||
|
Default implementation executes `self.executor` if it exists,
|
||||||
|
otherwise raises NotImplementedError.
|
||||||
|
"""
|
||||||
|
if self.executor is not None:
|
||||||
|
await self.executor(taskid)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
173
src/utils/ratelimits.py
Normal file
173
src/utils/ratelimits.py
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from meta.errors import SafeCancellation
|
||||||
|
|
||||||
|
from cachetools import TTLCache
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class BucketFull(Exception):
|
||||||
|
"""
|
||||||
|
Throw when a requested Bucket is already full
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BucketOverFull(BucketFull):
|
||||||
|
"""
|
||||||
|
Throw when a requested Bucket is overfull
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Bucket:
|
||||||
|
__slots__ = ('max_level', 'empty_time', 'leak_rate', '_level', '_last_checked', '_last_full', '_wait_lock')
|
||||||
|
|
||||||
|
def __init__(self, max_level, empty_time):
|
||||||
|
self.max_level = max_level
|
||||||
|
self.empty_time = empty_time
|
||||||
|
self.leak_rate = max_level / empty_time
|
||||||
|
|
||||||
|
self._level = 0
|
||||||
|
self._last_checked = time.monotonic()
|
||||||
|
|
||||||
|
self._last_full = False
|
||||||
|
self._wait_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def full(self) -> bool:
|
||||||
|
"""
|
||||||
|
Return whether the bucket is 'full',
|
||||||
|
that is, whether an immediate request against the bucket will raise `BucketFull`.
|
||||||
|
"""
|
||||||
|
self._leak()
|
||||||
|
return self._level + 1 > self.max_level
|
||||||
|
|
||||||
|
@property
|
||||||
|
def overfull(self):
|
||||||
|
self._leak()
|
||||||
|
return self._level > self.max_level
|
||||||
|
|
||||||
|
@property
|
||||||
|
def delay(self):
|
||||||
|
self._leak()
|
||||||
|
if self._level + 1 > self.max_level:
|
||||||
|
delay = (self._level + 1 - self.max_level) * self.leak_rate
|
||||||
|
else:
|
||||||
|
delay = 0
|
||||||
|
return delay
|
||||||
|
|
||||||
|
def _leak(self):
|
||||||
|
if self._level:
|
||||||
|
elapsed = time.monotonic() - self._last_checked
|
||||||
|
self._level = max(0, self._level - (elapsed * self.leak_rate))
|
||||||
|
|
||||||
|
self._last_checked = time.monotonic()
|
||||||
|
|
||||||
|
def request(self):
|
||||||
|
self._leak()
|
||||||
|
if self._level > self.max_level:
|
||||||
|
raise BucketOverFull
|
||||||
|
elif self._level == self.max_level:
|
||||||
|
self._level += 1
|
||||||
|
if self._last_full:
|
||||||
|
raise BucketOverFull
|
||||||
|
else:
|
||||||
|
self._last_full = True
|
||||||
|
raise BucketFull
|
||||||
|
else:
|
||||||
|
self._last_full = False
|
||||||
|
self._level += 1
|
||||||
|
|
||||||
|
def fill(self):
|
||||||
|
self._leak()
|
||||||
|
self._level = max(self._level, self.max_level + 1)
|
||||||
|
|
||||||
|
async def wait(self):
|
||||||
|
"""
|
||||||
|
Wait until the bucket has room.
|
||||||
|
|
||||||
|
Guarantees that a `request` directly afterwards will not raise `BucketFull`.
|
||||||
|
"""
|
||||||
|
# Wrapped in a lock so that waiters are correctly handled in wait-order
|
||||||
|
# Otherwise multiple waiters will have the same delay,
|
||||||
|
# and race for the wakeup after sleep.
|
||||||
|
# Also avoids short-circuiting in the 0 delay case, which would not correctly handle wait-order
|
||||||
|
async with self._wait_lock:
|
||||||
|
# We do this in a loop in case asyncio.sleep throws us out early,
|
||||||
|
# or a synchronous request overflows the bucket while we are waiting.
|
||||||
|
while self.full:
|
||||||
|
await asyncio.sleep(self.delay)
|
||||||
|
|
||||||
|
async def wrapped(self, coro):
|
||||||
|
await self.wait()
|
||||||
|
self.request()
|
||||||
|
await coro
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimit:
|
||||||
|
def __init__(self, max_level, empty_time, error=None, cache=TTLCache(1000, 60 * 60)):
|
||||||
|
self.max_level = max_level
|
||||||
|
self.empty_time = empty_time
|
||||||
|
|
||||||
|
self.error = error or "Too many requests, please slow down!"
|
||||||
|
self.buckets = cache
|
||||||
|
|
||||||
|
def request_for(self, key):
|
||||||
|
if not (bucket := self.buckets.get(key, None)):
|
||||||
|
bucket = self.buckets[key] = Bucket(self.max_level, self.empty_time)
|
||||||
|
|
||||||
|
try:
|
||||||
|
bucket.request()
|
||||||
|
except BucketOverFull:
|
||||||
|
raise SafeCancellation(details="Bucket overflow")
|
||||||
|
except BucketFull:
|
||||||
|
raise SafeCancellation(self.error, details="Bucket full")
|
||||||
|
|
||||||
|
def ward(self, member=True, key=None):
|
||||||
|
"""
|
||||||
|
Command ratelimit decorator.
|
||||||
|
"""
|
||||||
|
key = key or ((lambda ctx: (ctx.guild.id, ctx.author.id)) if member else (lambda ctx: ctx.author.id))
|
||||||
|
|
||||||
|
def decorator(func):
|
||||||
|
async def wrapper(ctx, *args, **kwargs):
|
||||||
|
self.request_for(key(ctx))
|
||||||
|
return await func(ctx, *args, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
async def limit_concurrency(aws, limit):
|
||||||
|
"""
|
||||||
|
Run provided awaitables concurrently,
|
||||||
|
ensuring that no more than `limit` are running at once.
|
||||||
|
"""
|
||||||
|
aws = iter(aws)
|
||||||
|
aws_ended = False
|
||||||
|
pending = set()
|
||||||
|
count = 0
|
||||||
|
logger.debug("Starting limited concurrency executor")
|
||||||
|
|
||||||
|
while pending or not aws_ended:
|
||||||
|
while len(pending) < limit and not aws_ended:
|
||||||
|
aw = next(aws, None)
|
||||||
|
if aw is None:
|
||||||
|
aws_ended = True
|
||||||
|
else:
|
||||||
|
pending.add(asyncio.create_task(aw))
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
if not pending:
|
||||||
|
break
|
||||||
|
|
||||||
|
done, pending = await asyncio.wait(
|
||||||
|
pending, return_when=asyncio.FIRST_COMPLETED
|
||||||
|
)
|
||||||
|
while done:
|
||||||
|
yield done.pop()
|
||||||
|
logger.debug(f"Completed {count} tasks")
|
||||||
12
src/utils/ui/__init__.py
Normal file
12
src/utils/ui/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from babel.translator import LocalBabel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
util_babel = LocalBabel('utils')
|
||||||
|
|
||||||
|
from .hooked import *
|
||||||
|
from .leo import *
|
||||||
|
from .micros import *
|
||||||
|
from .msgeditor import *
|
||||||
59
src/utils/ui/hooked.py
Normal file
59
src/utils/ui/hooked.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ui.item import Item
|
||||||
|
from discord.ui.button import Button
|
||||||
|
|
||||||
|
from .leo import LeoUI
|
||||||
|
|
||||||
|
__all__ = (
|
||||||
|
'HookedItem',
|
||||||
|
'AButton',
|
||||||
|
'AsComponents'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HookedItem:
|
||||||
|
"""
|
||||||
|
Mixin for Item classes allowing an instance to be used as a callback decorator.
|
||||||
|
"""
|
||||||
|
def __init__(self, *args, pass_kwargs={}, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.pass_kwargs = pass_kwargs
|
||||||
|
|
||||||
|
def __call__(self, coro):
|
||||||
|
async def wrapped(interaction, **kwargs):
|
||||||
|
return await coro(interaction, self, **(self.pass_kwargs | kwargs))
|
||||||
|
self.callback = wrapped
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class AButton(HookedItem, Button):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class AsComponents(LeoUI):
|
||||||
|
"""
|
||||||
|
Simple container class to accept a number of Items and turn them into an attachable View.
|
||||||
|
"""
|
||||||
|
def __init__(self, *items, pass_kwargs={}, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.pass_kwargs = pass_kwargs
|
||||||
|
|
||||||
|
for item in items:
|
||||||
|
self.add_item(item)
|
||||||
|
|
||||||
|
async def _scheduled_task(self, item: Item, interaction: discord.Interaction):
|
||||||
|
try:
|
||||||
|
item._refresh_state(interaction, interaction.data) # type: ignore
|
||||||
|
|
||||||
|
allow = await self.interaction_check(interaction)
|
||||||
|
if not allow:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.timeout:
|
||||||
|
self.__timeout_expiry = time.monotonic() + self.timeout
|
||||||
|
|
||||||
|
await item.callback(interaction, **self.pass_kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
return await self.on_error(interaction, e, item)
|
||||||
485
src/utils/ui/leo.py
Normal file
485
src/utils/ui/leo.py
Normal file
@@ -0,0 +1,485 @@
|
|||||||
|
from typing import List, Optional, Any, Dict
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from contextvars import copy_context, Context
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ui import Modal, View, Item
|
||||||
|
|
||||||
|
from meta.logger import log_action_stack, logging_context
|
||||||
|
from meta.errors import SafeCancellation
|
||||||
|
|
||||||
|
from . import logger
|
||||||
|
from ..lib import MessageArgs, error_embed
|
||||||
|
|
||||||
|
__all__ = (
|
||||||
|
'LeoUI',
|
||||||
|
'MessageUI',
|
||||||
|
'LeoModal',
|
||||||
|
'error_handler_for'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LeoUI(View):
|
||||||
|
"""
|
||||||
|
View subclass for small-scale user interfaces.
|
||||||
|
|
||||||
|
While a 'View' provides an interface for managing a collection of components,
|
||||||
|
a `LeoUI` may also manage a message, and potentially slave Views or UIs.
|
||||||
|
The `LeoUI` also exposes more advanced cleanup and timeout methods,
|
||||||
|
and preserves the context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, ui_name=None, context=None, **kwargs) -> None:
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
if context is None:
|
||||||
|
self._context = copy_context()
|
||||||
|
else:
|
||||||
|
self._context = context
|
||||||
|
|
||||||
|
self._name = ui_name or self.__class__.__name__
|
||||||
|
self._context.run(log_action_stack.set, [*self._context[log_action_stack], self._name])
|
||||||
|
|
||||||
|
# List of slaved views to stop when this view stops
|
||||||
|
self._slaves: List[View] = []
|
||||||
|
|
||||||
|
# TODO: Replace this with a substitutable ViewLayout class
|
||||||
|
self._layout: Optional[tuple[tuple[Item, ...], ...]] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _stopped(self) -> asyncio.Future:
|
||||||
|
"""
|
||||||
|
Return an future indicating whether the View has finished interacting.
|
||||||
|
|
||||||
|
Currently exposes a hidden attribute of the underlying View.
|
||||||
|
May be reimplemented in future.
|
||||||
|
"""
|
||||||
|
return self._View__stopped
|
||||||
|
|
||||||
|
def to_components(self) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Extending component generator to apply the set _layout, if it exists.
|
||||||
|
"""
|
||||||
|
if self._layout is not None:
|
||||||
|
# Alternative rendering using layout
|
||||||
|
components = []
|
||||||
|
for i, row in enumerate(self._layout):
|
||||||
|
# Skip empty rows
|
||||||
|
if not row:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Since we aren't relying on ViewWeights, manually check width here
|
||||||
|
if sum(item.width for item in row) > 5:
|
||||||
|
raise ValueError(f"Row {i} of custom {self.__class__.__name__} is too wide!")
|
||||||
|
|
||||||
|
# Create the component dict for this row
|
||||||
|
components.append({
|
||||||
|
'type': 1,
|
||||||
|
'components': [item.to_component_dict() for item in row]
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
components = super().to_components()
|
||||||
|
|
||||||
|
return components
|
||||||
|
|
||||||
|
def set_layout(self, *rows: tuple[Item, ...]) -> None:
|
||||||
|
"""
|
||||||
|
Set the layout of the rendered View as a matrix of items,
|
||||||
|
or more precisely, a list of action rows.
|
||||||
|
|
||||||
|
This acts independently of the existing sorting with `_ViewWeights`,
|
||||||
|
and overrides the sorting if applied.
|
||||||
|
"""
|
||||||
|
self._layout = rows
|
||||||
|
|
||||||
|
async def cleanup(self):
|
||||||
|
"""
|
||||||
|
Coroutine to run when timeing out, stopping, or cancelling.
|
||||||
|
Generally cleans up any open resources, and removes any leftover components.
|
||||||
|
"""
|
||||||
|
logging.debug(f"{self!r} running default cleanup.", extra={'action': 'cleanup'})
|
||||||
|
return None
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""
|
||||||
|
Extends View.stop() to also stop all the slave views.
|
||||||
|
Note that stopping is idempotent, so it is okay if close() also calls stop().
|
||||||
|
"""
|
||||||
|
for slave in self._slaves:
|
||||||
|
slave.stop()
|
||||||
|
super().stop()
|
||||||
|
|
||||||
|
async def close(self, msg=None):
|
||||||
|
self.stop()
|
||||||
|
await self.cleanup()
|
||||||
|
|
||||||
|
async def pre_timeout(self):
|
||||||
|
"""
|
||||||
|
Task to execute before actually timing out.
|
||||||
|
This may cancel the timeout by refreshing or rescheduling it.
|
||||||
|
(E.g. to ask the user whether they want to keep going.)
|
||||||
|
|
||||||
|
Default implementation does nothing.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def on_timeout(self):
|
||||||
|
"""
|
||||||
|
Task to execute after timeout is complete.
|
||||||
|
Default implementation calls cleanup.
|
||||||
|
"""
|
||||||
|
await self.cleanup()
|
||||||
|
|
||||||
|
async def __dispatch_timeout(self):
|
||||||
|
"""
|
||||||
|
This essentially extends View._dispatch_timeout,
|
||||||
|
to include a pre_timeout task
|
||||||
|
which may optionally refresh and hence cancel the timeout.
|
||||||
|
"""
|
||||||
|
if self._View__stopped.done():
|
||||||
|
# We are already stopped, nothing to do
|
||||||
|
return
|
||||||
|
|
||||||
|
with logging_context(action='Timeout'):
|
||||||
|
try:
|
||||||
|
await self.pre_timeout()
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Unhandled error caught while dispatching timeout for {self!r}.",
|
||||||
|
extra={'with_ctx': True, 'action': 'Error'}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if we still need to timeout
|
||||||
|
if self.timeout is None:
|
||||||
|
# The timeout was removed entirely, silently walk away
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._View__stopped.done():
|
||||||
|
# We stopped while waiting for the pre timeout.
|
||||||
|
# Or maybe another thread timed us out
|
||||||
|
# Either way, we are done here
|
||||||
|
return
|
||||||
|
|
||||||
|
now = time.monotonic()
|
||||||
|
if self._View__timeout_expiry is not None and now < self._View__timeout_expiry:
|
||||||
|
# The timeout was extended, make sure the timeout task is running then fade away
|
||||||
|
if self._View__timeout_task is None or self._View__timeout_task.done():
|
||||||
|
self._View__timeout_task = asyncio.create_task(self._View__timeout_task_impl())
|
||||||
|
else:
|
||||||
|
# Actually timeout, and call the post-timeout task for cleanup.
|
||||||
|
self._really_timeout()
|
||||||
|
await self.on_timeout()
|
||||||
|
|
||||||
|
def _dispatch_timeout(self):
|
||||||
|
"""
|
||||||
|
Overriding timeout method completely, to support interactive flow during timeout,
|
||||||
|
and optional refreshing of the timeout.
|
||||||
|
"""
|
||||||
|
return self._context.run(asyncio.create_task, self.__dispatch_timeout())
|
||||||
|
|
||||||
|
def _really_timeout(self):
|
||||||
|
"""
|
||||||
|
Actuallly times out the View.
|
||||||
|
This copies View._dispatch_timeout, apart from the `on_timeout` dispatch,
|
||||||
|
which is now handled by `__dispatch_timeout`.
|
||||||
|
"""
|
||||||
|
if self._View__stopped.done():
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._View__cancel_callback:
|
||||||
|
self._View__cancel_callback(self)
|
||||||
|
self._View__cancel_callback = None
|
||||||
|
|
||||||
|
self._View__stopped.set_result(True)
|
||||||
|
|
||||||
|
def _dispatch_item(self, *args, **kwargs):
|
||||||
|
"""Extending event dispatch to run in the instantiation context."""
|
||||||
|
return self._context.run(super()._dispatch_item, *args, **kwargs)
|
||||||
|
|
||||||
|
async def on_error(self, interaction: discord.Interaction, error: Exception, item: Item):
|
||||||
|
"""
|
||||||
|
Default LeoUI error handle.
|
||||||
|
This may be tail extended by subclasses to preserve the exception stack.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
raise error
|
||||||
|
except SafeCancellation as e:
|
||||||
|
if e.msg and not interaction.is_expired():
|
||||||
|
try:
|
||||||
|
if interaction.response.is_done():
|
||||||
|
await interaction.followup.send(
|
||||||
|
embed=error_embed(e.msg),
|
||||||
|
ephemeral=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await interaction.response.send_message(
|
||||||
|
embed=error_embed(e.msg),
|
||||||
|
ephemeral=True
|
||||||
|
)
|
||||||
|
except discord.HTTPException:
|
||||||
|
pass
|
||||||
|
logger.debug(
|
||||||
|
f"Caught a safe cancellation from LeoUI: {e.details}",
|
||||||
|
extra={'action': 'Cancel'}
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
f"Unhandled interaction exception occurred in item {item!r} of LeoUI {self!r} from interaction: "
|
||||||
|
f"{interaction.data}",
|
||||||
|
extra={'with_ctx': True, 'action': 'UIError'}
|
||||||
|
)
|
||||||
|
# Explicitly handle the bugsplat ourselves
|
||||||
|
splat = interaction.client.tree.bugsplat(interaction, error)
|
||||||
|
await interaction.client.tree.error_reply(interaction, splat)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageUI(LeoUI):
|
||||||
|
"""
|
||||||
|
Simple single-message LeoUI, intended as a framework for UIs
|
||||||
|
attached to a single interaction response.
|
||||||
|
|
||||||
|
UIs may also be sent as regular messages by using `send(channel)` instead of `run(interaction)`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, callerid: Optional[int] = None, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
# ----- UI state -----
|
||||||
|
# User ID of the original caller (e.g. command author).
|
||||||
|
# Mainly used for interaction usage checks and logging
|
||||||
|
self._callerid = callerid
|
||||||
|
|
||||||
|
# Original interaction, if this UI is sent as an interaction response
|
||||||
|
self._original: discord.Interaction = None
|
||||||
|
|
||||||
|
# Message holding the UI, when the UI is sent attached to a followup
|
||||||
|
self._message: discord.Message = None
|
||||||
|
|
||||||
|
# Refresh lock, to avoid cache collisions on refresh
|
||||||
|
self._refresh_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def channel(self):
|
||||||
|
if self._original is not None:
|
||||||
|
return self._original.channel
|
||||||
|
else:
|
||||||
|
return self._message.channel
|
||||||
|
|
||||||
|
# ----- UI API -----
|
||||||
|
async def run(self, interaction: discord.Interaction, **kwargs):
|
||||||
|
"""
|
||||||
|
Run the UI as a response or followup to the given interaction.
|
||||||
|
|
||||||
|
Should be extended if more complex run mechanics are needed
|
||||||
|
(e.g. registering listeners or setting up caches).
|
||||||
|
"""
|
||||||
|
await self.draw(interaction, **kwargs)
|
||||||
|
|
||||||
|
async def refresh(self, *args, thinking: Optional[discord.Interaction] = None, **kwargs):
|
||||||
|
"""
|
||||||
|
Reload and redraw this UI.
|
||||||
|
|
||||||
|
Primarily a hook-method for use by parents and other controllers.
|
||||||
|
Performs a full data and reload and refresh (maintaining UI state, e.g. page n).
|
||||||
|
"""
|
||||||
|
async with self._refresh_lock:
|
||||||
|
# Reload data
|
||||||
|
await self.reload()
|
||||||
|
# Redraw UI message
|
||||||
|
await self.redraw(thinking=thinking)
|
||||||
|
|
||||||
|
async def quit(self):
|
||||||
|
"""
|
||||||
|
Quit the UI.
|
||||||
|
|
||||||
|
This usually involves removing the original message,
|
||||||
|
and stopping or closing the underlying View.
|
||||||
|
"""
|
||||||
|
for child in self._slaves:
|
||||||
|
# TODO: Better to use duck typing or interface typing
|
||||||
|
if isinstance(child, MessageUI) and not child.is_finished():
|
||||||
|
asyncio.create_task(child.quit())
|
||||||
|
try:
|
||||||
|
if self._original is not None and not self._original.is_expired():
|
||||||
|
await self._original.delete_original_response()
|
||||||
|
self._original = None
|
||||||
|
if self._message is not None:
|
||||||
|
await self._message.delete()
|
||||||
|
self._message = None
|
||||||
|
except discord.HTTPException:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Note close() also runs cleanup and stop
|
||||||
|
await self.close()
|
||||||
|
|
||||||
|
# ----- UI Flow -----
|
||||||
|
async def interaction_check(self, interaction: discord.Interaction):
|
||||||
|
"""
|
||||||
|
Check the given interaction is authorised to use this UI.
|
||||||
|
|
||||||
|
Default implementation simply checks that the interaction is
|
||||||
|
from the original caller.
|
||||||
|
Extend for more complex logic.
|
||||||
|
"""
|
||||||
|
return interaction.user.id == self._callerid
|
||||||
|
|
||||||
|
async def make_message(self) -> MessageArgs:
|
||||||
|
"""
|
||||||
|
Create the UI message body, depening on the current state.
|
||||||
|
|
||||||
|
Called upon each redraw.
|
||||||
|
Should handle caching if message construction is for some reason intensive.
|
||||||
|
|
||||||
|
Must be implemented by concrete UI subclasses.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def refresh_layout(self):
|
||||||
|
"""
|
||||||
|
Asynchronously refresh the message components,
|
||||||
|
and explicitly set the message component layout.
|
||||||
|
|
||||||
|
Called just before redrawing, before `make_message`.
|
||||||
|
|
||||||
|
Must be implemented by concrete UI subclasses.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def reload(self):
|
||||||
|
"""
|
||||||
|
Reload and recompute the underlying data for this UI.
|
||||||
|
|
||||||
|
Must be implemented by concrete UI subclasses.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def draw(self, interaction, force_followup=False, **kwargs):
|
||||||
|
"""
|
||||||
|
Send the UI as a response or followup to the given interaction.
|
||||||
|
|
||||||
|
If the interaction has been responded to, or `force_followup` is set,
|
||||||
|
creates a followup message instead of a response to the interaction.
|
||||||
|
"""
|
||||||
|
# Initial data loading
|
||||||
|
await self.reload()
|
||||||
|
# Set the UI layout
|
||||||
|
await self.refresh_layout()
|
||||||
|
# Fetch message arguments
|
||||||
|
args = await self.make_message()
|
||||||
|
|
||||||
|
as_followup = force_followup or interaction.response.is_done()
|
||||||
|
if as_followup:
|
||||||
|
self._message = await interaction.followup.send(**args.send_args, **kwargs, view=self)
|
||||||
|
else:
|
||||||
|
self._original = interaction
|
||||||
|
await interaction.response.send_message(**args.send_args, **kwargs, view=self)
|
||||||
|
|
||||||
|
async def send(self, channel: discord.abc.Messageable, **kwargs):
|
||||||
|
"""
|
||||||
|
Alternative to draw() which uses a discord.abc.Messageable.
|
||||||
|
"""
|
||||||
|
await self.reload()
|
||||||
|
await self.refresh_layout()
|
||||||
|
args = await self.make_message()
|
||||||
|
self._message = await channel.send(**args.send_args, view=self)
|
||||||
|
|
||||||
|
async def _redraw(self, args):
|
||||||
|
if self._original and not self._original.is_expired():
|
||||||
|
await self._original.edit_original_response(**args.edit_args, view=self)
|
||||||
|
elif self._message:
|
||||||
|
await self._message.edit(**args.edit_args, view=self)
|
||||||
|
else:
|
||||||
|
# Interaction expired or already closed. Quietly cleanup.
|
||||||
|
await self.close()
|
||||||
|
|
||||||
|
async def redraw(self, thinking: Optional[discord.Interaction] = None):
|
||||||
|
"""
|
||||||
|
Update the output message for this UI.
|
||||||
|
|
||||||
|
If a thinking interaction is provided, deletes the response while redrawing.
|
||||||
|
"""
|
||||||
|
await self.refresh_layout()
|
||||||
|
args = await self.make_message()
|
||||||
|
|
||||||
|
if thinking is not None and not thinking.is_expired() and thinking.response.is_done():
|
||||||
|
asyncio.create_task(thinking.delete_original_response())
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._redraw(args)
|
||||||
|
except discord.HTTPException as e:
|
||||||
|
# Unknown communication error, nothing we can reliably do. Exit quietly.
|
||||||
|
logger.warning(
|
||||||
|
f"Unexpected UI redraw failure occurred in {self}: {repr(e)}",
|
||||||
|
)
|
||||||
|
await self.close()
|
||||||
|
|
||||||
|
async def cleanup(self):
|
||||||
|
"""
|
||||||
|
Remove message components from interaction response, if possible.
|
||||||
|
|
||||||
|
Extend to remove listeners or clean up caches.
|
||||||
|
`cleanup` is always called when the UI is exiting,
|
||||||
|
through timeout or user-driven closure.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if self._original is not None and not self._original.is_expired():
|
||||||
|
await self._original.edit_original_response(view=None)
|
||||||
|
self._original = None
|
||||||
|
if self._message is not None:
|
||||||
|
await self._message.edit(view=None)
|
||||||
|
self._message = None
|
||||||
|
except discord.HTTPException:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class LeoModal(Modal):
|
||||||
|
"""
|
||||||
|
Context-aware Modal class.
|
||||||
|
"""
|
||||||
|
def __init__(self, *args, context: Optional[Context] = None, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
if context is None:
|
||||||
|
self._context = copy_context()
|
||||||
|
else:
|
||||||
|
self._context = context
|
||||||
|
self._context.run(log_action_stack.set, [*self._context[log_action_stack], self.__class__.__name__])
|
||||||
|
|
||||||
|
def _dispatch_submit(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Extending event dispatch to run in the instantiation context.
|
||||||
|
"""
|
||||||
|
return self._context.run(super()._dispatch_submit, *args, **kwargs)
|
||||||
|
|
||||||
|
def _dispatch_item(self, *args, **kwargs):
|
||||||
|
"""Extending event dispatch to run in the instantiation context."""
|
||||||
|
return self._context.run(super()._dispatch_item, *args, **kwargs)
|
||||||
|
|
||||||
|
async def on_error(self, interaction: discord.Interaction, error: Exception, *args):
|
||||||
|
"""
|
||||||
|
Default LeoModal error handle.
|
||||||
|
This may be tail extended by subclasses to preserve the exception stack.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
raise error
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
f"Unhandled interaction exception occurred in {self!r}. Interaction: {interaction.data}",
|
||||||
|
extra={'with_ctx': True, 'action': 'ModalError'}
|
||||||
|
)
|
||||||
|
# Explicitly handle the bugsplat ourselves
|
||||||
|
splat = interaction.client.tree.bugsplat(interaction, error)
|
||||||
|
await interaction.client.tree.error_reply(interaction, splat)
|
||||||
|
|
||||||
|
|
||||||
|
def error_handler_for(exc):
|
||||||
|
def wrapper(coro):
|
||||||
|
coro._ui_error_handler_for_ = exc
|
||||||
|
return coro
|
||||||
|
return wrapper
|
||||||
329
src/utils/ui/micros.py
Normal file
329
src/utils/ui/micros.py
Normal file
@@ -0,0 +1,329 @@
|
|||||||
|
from typing import List, Coroutine, Optional, Any, Type, TypeVar, Callable, Dict
|
||||||
|
import functools
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ui import TextInput
|
||||||
|
from discord.ui.button import button
|
||||||
|
|
||||||
|
from meta.logger import logging_context
|
||||||
|
from meta.errors import ResponseTimedOut
|
||||||
|
|
||||||
|
from .leo import LeoModal, LeoUI
|
||||||
|
|
||||||
|
__all__ = (
|
||||||
|
'FastModal',
|
||||||
|
'ModalRetryUI',
|
||||||
|
'Confirm',
|
||||||
|
'input',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FastModal(LeoModal):
|
||||||
|
__class_error_handlers__ = []
|
||||||
|
|
||||||
|
def __init_subclass__(cls, **kwargs) -> None:
|
||||||
|
super().__init_subclass__(**kwargs)
|
||||||
|
error_handlers = {}
|
||||||
|
for base in reversed(cls.__mro__):
|
||||||
|
for name, member in base.__dict__.items():
|
||||||
|
if hasattr(member, '_ui_error_handler_for_'):
|
||||||
|
error_handlers[name] = member
|
||||||
|
|
||||||
|
cls.__class_error_handlers__ = list(error_handlers.values())
|
||||||
|
|
||||||
|
def __init__error_handlers__(self):
|
||||||
|
handlers = {}
|
||||||
|
for handler in self.__class_error_handlers__:
|
||||||
|
handlers[handler._ui_error_handler_for_] = functools.partial(handler, self)
|
||||||
|
return handlers
|
||||||
|
|
||||||
|
def __init__(self, *items: TextInput, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
for item in items:
|
||||||
|
self.add_item(item)
|
||||||
|
self._result: asyncio.Future[discord.Interaction] = asyncio.get_event_loop().create_future()
|
||||||
|
self._waiters: List[Callable[[discord.Interaction], Coroutine]] = []
|
||||||
|
self._error_handlers = self.__init__error_handlers__()
|
||||||
|
|
||||||
|
def error_handler(self, exception):
|
||||||
|
def wrapper(coro):
|
||||||
|
self._error_handlers[exception] = coro
|
||||||
|
return coro
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
async def wait_for(self, check=None, timeout=None):
|
||||||
|
# Wait for _result or timeout
|
||||||
|
# If we timeout, or the view times out, raise TimeoutError
|
||||||
|
# Otherwise, return the Interaction
|
||||||
|
# This allows multiple listeners and callbacks to wait on
|
||||||
|
while True:
|
||||||
|
result = await asyncio.wait_for(asyncio.shield(self._result), timeout=timeout)
|
||||||
|
if check is not None:
|
||||||
|
if not check(result):
|
||||||
|
continue
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def on_timeout(self):
|
||||||
|
self._result.set_exception(asyncio.TimeoutError)
|
||||||
|
|
||||||
|
def submit_callback(self, timeout=None, check=None, once=False, pass_args=(), pass_kwargs={}):
|
||||||
|
def wrapper(coro):
|
||||||
|
async def wrapped_callback(interaction):
|
||||||
|
with logging_context(action=coro.__name__):
|
||||||
|
if check is not None:
|
||||||
|
if not check(interaction):
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
await coro(interaction, *pass_args, **pass_kwargs)
|
||||||
|
except Exception:
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
if once:
|
||||||
|
self._waiters.remove(wrapped_callback)
|
||||||
|
self._waiters.append(wrapped_callback)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
async def on_error(self, interaction: discord.Interaction, error: Exception, *args):
|
||||||
|
try:
|
||||||
|
# First let our error handlers have a go
|
||||||
|
# If there is no handler for this error, or the handlers themselves error,
|
||||||
|
# drop to the superclass error handler implementation.
|
||||||
|
try:
|
||||||
|
raise error
|
||||||
|
except tuple(self._error_handlers.keys()) as e:
|
||||||
|
# If an error handler is registered for this exception, run it.
|
||||||
|
for cls, handler in self._error_handlers.items():
|
||||||
|
if isinstance(e, cls):
|
||||||
|
await handler(interaction, e)
|
||||||
|
except Exception as error:
|
||||||
|
await super().on_error(interaction, error)
|
||||||
|
|
||||||
|
async def on_submit(self, interaction):
|
||||||
|
print("On submit")
|
||||||
|
old_result = self._result
|
||||||
|
self._result = asyncio.get_event_loop().create_future()
|
||||||
|
old_result.set_result(interaction)
|
||||||
|
|
||||||
|
tasks = []
|
||||||
|
for waiter in self._waiters:
|
||||||
|
task = asyncio.create_task(
|
||||||
|
waiter(interaction),
|
||||||
|
name=f"leo-ui-fastmodal-{self.id}-callback-{waiter.__name__}"
|
||||||
|
)
|
||||||
|
tasks.append(task)
|
||||||
|
if tasks:
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
|
||||||
|
async def input(
|
||||||
|
interaction: discord.Interaction,
|
||||||
|
title: str,
|
||||||
|
question: Optional[str] = None,
|
||||||
|
field: Optional[TextInput] = None,
|
||||||
|
timeout=180,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[discord.Interaction, str]:
|
||||||
|
"""
|
||||||
|
Spawn a modal to accept input.
|
||||||
|
Returns an (interaction, value) pair, with interaction not yet responded to.
|
||||||
|
May raise asyncio.TimeoutError if the view times out.
|
||||||
|
"""
|
||||||
|
if field is None:
|
||||||
|
field = TextInput(
|
||||||
|
label=kwargs.get('label', question),
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
modal = FastModal(
|
||||||
|
field,
|
||||||
|
title=title,
|
||||||
|
timeout=timeout
|
||||||
|
)
|
||||||
|
await interaction.response.send_modal(modal)
|
||||||
|
interaction = await modal.wait_for()
|
||||||
|
return (interaction, field.value)
|
||||||
|
|
||||||
|
|
||||||
|
class ModalRetryUI(LeoUI):
|
||||||
|
def __init__(self, modal: FastModal, message, label: Optional[str] = None, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.modal = modal
|
||||||
|
self.item_values = {item: item.value for item in modal.children if isinstance(item, TextInput)}
|
||||||
|
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
self._interaction = None
|
||||||
|
|
||||||
|
if label is not None:
|
||||||
|
self.retry_button.label = label
|
||||||
|
|
||||||
|
@property
|
||||||
|
def embed(self):
|
||||||
|
return discord.Embed(
|
||||||
|
title="Uh-Oh!",
|
||||||
|
description=self.message,
|
||||||
|
colour=discord.Colour.red()
|
||||||
|
)
|
||||||
|
|
||||||
|
async def respond_to(self, interaction):
|
||||||
|
self._interaction = interaction
|
||||||
|
if interaction.response.is_done():
|
||||||
|
await interaction.followup.send(embed=self.embed, ephemeral=True, view=self)
|
||||||
|
else:
|
||||||
|
await interaction.response.send_message(embed=self.embed, ephemeral=True, view=self)
|
||||||
|
|
||||||
|
@button(label="Retry")
|
||||||
|
async def retry_button(self, interaction, butt):
|
||||||
|
# Setting these here so they don't update in the meantime
|
||||||
|
for item, value in self.item_values.items():
|
||||||
|
item.default = value
|
||||||
|
if self._interaction is not None:
|
||||||
|
await self._interaction.delete_original_response()
|
||||||
|
self._interaction = None
|
||||||
|
await interaction.response.send_modal(self.modal)
|
||||||
|
await self.close()
|
||||||
|
|
||||||
|
|
||||||
|
class Confirm(LeoUI):
|
||||||
|
"""
|
||||||
|
Micro UI class implementing a confirmation question.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
confirm_msg: str
|
||||||
|
The confirmation question to ask from the user.
|
||||||
|
This is set as the description of the `embed` property.
|
||||||
|
The `embed` may be further modified if required.
|
||||||
|
permitted_id: Optional[int]
|
||||||
|
The user id allowed to access this interaction.
|
||||||
|
Other users will recieve an access denied error message.
|
||||||
|
defer: bool
|
||||||
|
Whether to defer the interaction response while handling the button.
|
||||||
|
It may be useful to set this to `False` to obtain manual control
|
||||||
|
over the interaction response flow (e.g. to send a modal or ephemeral message).
|
||||||
|
The button press interaction may be accessed through `Confirm.interaction`.
|
||||||
|
Default: True
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
```
|
||||||
|
confirm = Confirm("Are you sure?", ctx.author.id)
|
||||||
|
confirm.embed.colour = discord.Colour.red()
|
||||||
|
confirm.confirm_button.label = "Yes I am sure"
|
||||||
|
confirm.cancel_button.label = "No I am not sure"
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await confirm.ask(ctx.interaction, ephemeral=True)
|
||||||
|
except ResultTimedOut:
|
||||||
|
return
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
confirm_msg: str,
|
||||||
|
permitted_id: Optional[int] = None,
|
||||||
|
defer: bool = True,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.confirm_msg = confirm_msg
|
||||||
|
self.permitted_id = permitted_id
|
||||||
|
self.defer = defer
|
||||||
|
|
||||||
|
self._embed: Optional[discord.Embed] = None
|
||||||
|
self._result: asyncio.Future[bool] = asyncio.Future()
|
||||||
|
|
||||||
|
# Indicates whether we should delete the message or the interaction response
|
||||||
|
self._is_followup: bool = False
|
||||||
|
self._original: Optional[discord.Interaction] = None
|
||||||
|
self._message: Optional[discord.Message] = None
|
||||||
|
|
||||||
|
async def interaction_check(self, interaction: discord.Interaction):
|
||||||
|
return (self.permitted_id is None) or interaction.user.id == self.permitted_id
|
||||||
|
|
||||||
|
async def on_timeout(self):
|
||||||
|
# Propagate timeout to result Future
|
||||||
|
self._result.set_exception(ResponseTimedOut)
|
||||||
|
await self.cleanup()
|
||||||
|
|
||||||
|
async def cleanup(self):
|
||||||
|
"""
|
||||||
|
Cleanup the confirmation prompt by deleting it, if possible.
|
||||||
|
Ignores any Discord errors that occur during the process.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if self._is_followup and self._message:
|
||||||
|
await self._message.delete()
|
||||||
|
elif not self._is_followup and self._original and not self._original.is_expired():
|
||||||
|
await self._original.delete_original_response()
|
||||||
|
except discord.HTTPException:
|
||||||
|
# A user probably already deleted the message
|
||||||
|
# Anything could have happened, just ignore.
|
||||||
|
pass
|
||||||
|
|
||||||
|
@button(label="Confirm")
|
||||||
|
async def confirm_button(self, interaction: discord.Interaction, press):
|
||||||
|
if self.defer:
|
||||||
|
await interaction.response.defer()
|
||||||
|
self._result.set_result(True)
|
||||||
|
await self.close()
|
||||||
|
|
||||||
|
@button(label="Cancel")
|
||||||
|
async def cancel_button(self, interaction: discord.Interaction, press):
|
||||||
|
if self.defer:
|
||||||
|
await interaction.response.defer()
|
||||||
|
self._result.set_result(False)
|
||||||
|
await self.close()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def embed(self):
|
||||||
|
"""
|
||||||
|
Confirmation embed shown to the user.
|
||||||
|
This is cached, and may be modifed directly through the usual EmbedProxy API,
|
||||||
|
or explicitly overwritten.
|
||||||
|
"""
|
||||||
|
if self._embed is None:
|
||||||
|
self._embed = discord.Embed(
|
||||||
|
colour=discord.Colour.orange(),
|
||||||
|
description=self.confirm_msg
|
||||||
|
)
|
||||||
|
return self._embed
|
||||||
|
|
||||||
|
@embed.setter
|
||||||
|
def embed(self, value):
|
||||||
|
self._embed = value
|
||||||
|
|
||||||
|
async def ask(self, interaction: discord.Interaction, ephemeral=False, **kwargs):
|
||||||
|
"""
|
||||||
|
Send this confirmation prompt in response to the provided interaction.
|
||||||
|
Extra keyword arguments are passed to `Interaction.response.send_message`
|
||||||
|
or `Interaction.send_followup`, depending on whether
|
||||||
|
the provided interaction has already been responded to.
|
||||||
|
|
||||||
|
The `epehemeral` argument is handled specially,
|
||||||
|
since the question message can only be deleted through `Interaction.delete_original_response`.
|
||||||
|
|
||||||
|
Waits on and returns the internal `result` Future.
|
||||||
|
|
||||||
|
Returns: bool
|
||||||
|
True if the user pressed the confirm button.
|
||||||
|
False if the user pressed the cancel button.
|
||||||
|
Raises:
|
||||||
|
ResponseTimedOut:
|
||||||
|
If the user does not respond before the UI times out.
|
||||||
|
"""
|
||||||
|
self._original = interaction
|
||||||
|
if interaction.response.is_done():
|
||||||
|
# Interaction already responded to, send a follow up
|
||||||
|
if ephemeral:
|
||||||
|
raise ValueError("Cannot send an ephemeral response to a used interaction.")
|
||||||
|
self._message = await interaction.followup.send(embed=self.embed, **kwargs, view=self)
|
||||||
|
self._is_followup = True
|
||||||
|
else:
|
||||||
|
await interaction.response.send_message(
|
||||||
|
embed=self.embed, ephemeral=ephemeral, **kwargs, view=self
|
||||||
|
)
|
||||||
|
self._is_followup = False
|
||||||
|
return await self._result
|
||||||
|
|
||||||
|
# TODO: Selector MicroUI for displaying options (<= 25)
|
||||||
1070
src/utils/ui/msgeditor.py
Normal file
1070
src/utils/ui/msgeditor.py
Normal file
File diff suppressed because it is too large
Load Diff
9
src/wards.py
Normal file
9
src/wards.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
from meta import LionBot
|
||||||
|
|
||||||
|
# Raw checks, return True/False depending on whether they pass
|
||||||
|
async def sys_admin(bot: LionBot, userid: int):
|
||||||
|
"""
|
||||||
|
Checks whether the context author is listed in the configuration file as a bot admin.
|
||||||
|
"""
|
||||||
|
admins = bot.config.bot.getintlist('admins')
|
||||||
|
return userid in admins
|
||||||
Reference in New Issue
Block a user