Compare commits
10 Commits
c3c8baa4b2
...
230ea88d23
| Author | SHA1 | Date | |
|---|---|---|---|
| 230ea88d23 | |||
| 42af454864 | |||
| 9a0d4090f5 | |||
| d05dc81667 | |||
| a4dd540f44 | |||
| 63e5dd1796 | |||
| daa370e09f | |||
| 58c0873987 | |||
| 5de3fd77bf | |||
| 873def8456 |
9
.gitmodules
vendored
Normal file
9
.gitmodules
vendored
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
[submodule "src/modules/voicefix"]
|
||||||
|
path = src/modules/voicefix
|
||||||
|
url = git@github.com:Intery/StudyLion-voicefix.git
|
||||||
|
[submodule "src/modules/streamalerts"]
|
||||||
|
path = src/modules/streamalerts
|
||||||
|
url = git@github.com:Intery/StudyLion-streamalerts.git
|
||||||
|
[submodule "src/data"]
|
||||||
|
path = src/data
|
||||||
|
url = https://git.thewisewolf.dev/HoloTech/psqlmapper.git
|
||||||
9
config/emojis.conf
Normal file
9
config/emojis.conf
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
[EMOJIS]
|
||||||
|
|
||||||
|
tick = :✅:
|
||||||
|
clock = :⏱️:
|
||||||
|
warning = :⚠️:
|
||||||
|
config = :⚙️:
|
||||||
|
stats = :📊:
|
||||||
|
utility = :⏱️:
|
||||||
|
cancel = :❌:
|
||||||
@@ -1,10 +1,12 @@
|
|||||||
-- Metadata {{{
|
-- Metadata {{{
|
||||||
CREATE TABLE VersionHistory(
|
CREATE TABLE version_history(
|
||||||
version INTEGER NOT NULL,
|
component TEXT NOT NULL,
|
||||||
time TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
from_version INTEGER NOT NULL,
|
||||||
author TEXT
|
to_version INTEGER NOT NULL,
|
||||||
|
author TEXT NOT NULL,
|
||||||
|
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
);
|
);
|
||||||
INSERT INTO VersionHistory (version, author) VALUES (1, 'Initial Creation');
|
INSERT INTO version_history (component, from_version, to_version, author) VALUES ('ROOT', 0, 1, 'Initial Creation');
|
||||||
|
|
||||||
|
|
||||||
CREATE OR REPLACE FUNCTION update_timestamp_column()
|
CREATE OR REPLACE FUNCTION update_timestamp_column()
|
||||||
@@ -31,76 +33,6 @@ CREATE TABLE bot_config(
|
|||||||
);
|
);
|
||||||
-- }}}
|
-- }}}
|
||||||
|
|
||||||
-- Channel Linker {{{
|
-- TODO: Profile data
|
||||||
|
|
||||||
CREATE TABLE links(
|
|
||||||
linkid SERIAL PRIMARY KEY,
|
|
||||||
name TEXT
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE channel_webhooks(
|
|
||||||
channelid BIGINT PRIMARY KEY,
|
|
||||||
webhookid BIGINT NOT NULL,
|
|
||||||
token TEXT NOT NULL
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE channel_links(
|
|
||||||
linkid INTEGER NOT NULL REFERENCES links (linkid) ON DELETE CASCADE,
|
|
||||||
channelid BIGINT NOT NULL REFERENCES channel_webhooks (channelid) ON DELETE CASCADE,
|
|
||||||
PRIMARY KEY (linkid, channelid)
|
|
||||||
);
|
|
||||||
|
|
||||||
|
|
||||||
-- }}}
|
|
||||||
|
|
||||||
-- Stream Alerts {{{
|
|
||||||
|
|
||||||
-- DROP TABLE IF EXISTS stream_alerts;
|
|
||||||
-- DROP TABLE IF EXISTS streams;
|
|
||||||
-- DROP TABLE IF EXISTS alert_channels;
|
|
||||||
-- DROP TABLE IF EXISTS streamers;
|
|
||||||
|
|
||||||
CREATE TABLE streamers(
|
|
||||||
userid BIGINT PRIMARY KEY,
|
|
||||||
login_name TEXT NOT NULL,
|
|
||||||
display_name TEXT NOT NULL
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE alert_channels(
|
|
||||||
subscriptionid SERIAL PRIMARY KEY,
|
|
||||||
guildid BIGINT NOT NULL,
|
|
||||||
channelid BIGINT NOT NULL,
|
|
||||||
streamerid BIGINT NOT NULL REFERENCES streamers (userid) ON DELETE CASCADE,
|
|
||||||
created_by BIGINT NOT NULL,
|
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
||||||
paused BOOLEAN NOT NULL DEFAULT FALSE,
|
|
||||||
end_delete BOOLEAN NOT NULL DEFAULT FALSE,
|
|
||||||
live_message TEXT,
|
|
||||||
end_message TEXT
|
|
||||||
);
|
|
||||||
CREATE INDEX alert_channels_guilds ON alert_channels (guildid);
|
|
||||||
CREATE UNIQUE INDEX alert_channels_channelid_streamerid ON alert_channels (channelid, streamerid);
|
|
||||||
|
|
||||||
CREATE TABLE streams(
|
|
||||||
streamid SERIAL PRIMARY KEY,
|
|
||||||
streamerid BIGINT NOT NULL REFERENCES streamers (userid) ON DELETE CASCADE,
|
|
||||||
start_at TIMESTAMPTZ NOT NULL,
|
|
||||||
twitch_stream_id BIGINT,
|
|
||||||
game_name TEXT,
|
|
||||||
title TEXT,
|
|
||||||
end_at TIMESTAMPTZ
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE stream_alerts(
|
|
||||||
alertid SERIAL PRIMARY KEY,
|
|
||||||
streamid INTEGER NOT NULL REFERENCES streams (streamid) ON DELETE CASCADE,
|
|
||||||
subscriptionid INTEGER NOT NULL REFERENCES alert_channels (subscriptionid) ON DELETE CASCADE,
|
|
||||||
sent_at TIMESTAMPTZ NOT NULL,
|
|
||||||
messageid BIGINT NOT NULL,
|
|
||||||
resolved_at TIMESTAMPTZ
|
|
||||||
);
|
|
||||||
|
|
||||||
|
|
||||||
-- }}}
|
|
||||||
|
|
||||||
-- vim: set fdm=marker:
|
-- vim: set fdm=marker:
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
aiohttp==3.7.4.post0
|
aiohttp
|
||||||
cachetools==4.2.2
|
cachetools
|
||||||
configparser==5.0.2
|
configparser
|
||||||
discord.py [voice]
|
discord.py [voice]
|
||||||
iso8601==0.1.16
|
iso8601
|
||||||
psycopg[pool]
|
psycopg[pool]
|
||||||
pytz==2021.1
|
pytz
|
||||||
twitchAPI
|
|
||||||
|
|||||||
11
src/bot.py
11
src/bot.py
@@ -13,8 +13,6 @@ from meta.monitor import ComponentMonitor, StatusLevel, ComponentStatus
|
|||||||
|
|
||||||
from data import Database
|
from data import Database
|
||||||
|
|
||||||
from constants import DATA_VERSION
|
|
||||||
|
|
||||||
|
|
||||||
for name in conf.config.options('LOGGING_LEVELS', no_defaults=True):
|
for name in conf.config.options('LOGGING_LEVELS', no_defaults=True):
|
||||||
logging.getLogger(name).setLevel(conf.logging_levels[name])
|
logging.getLogger(name).setLevel(conf.logging_levels[name])
|
||||||
@@ -57,15 +55,10 @@ async def main():
|
|||||||
intents.presences = False
|
intents.presences = False
|
||||||
|
|
||||||
async with db.open():
|
async with db.open():
|
||||||
version = await db.version()
|
|
||||||
if version.version != DATA_VERSION:
|
|
||||||
error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate."
|
|
||||||
logger.critical(error)
|
|
||||||
raise RuntimeError(error)
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with LionBot(
|
async with LionBot(
|
||||||
command_prefix='!leo!',
|
command_prefix=conf.bot.get('prefix', '!!'),
|
||||||
intents=intents,
|
intents=intents,
|
||||||
appname=appname,
|
appname=appname,
|
||||||
shardname=shardname,
|
shardname=shardname,
|
||||||
@@ -81,7 +74,7 @@ async def main():
|
|||||||
shard_count=sharding.shard_count,
|
shard_count=sharding.shard_count,
|
||||||
help_command=None,
|
help_command=None,
|
||||||
proxy=conf.bot.get('proxy', None),
|
proxy=conf.bot.get('proxy', None),
|
||||||
chunk_guilds_at_startup=False,
|
chunk_guilds_at_startup=True,
|
||||||
) as lionbot:
|
) as lionbot:
|
||||||
ctx_bot.set(lionbot)
|
ctx_bot.set(lionbot)
|
||||||
lionbot.system_monitor.add_component(
|
lionbot.system_monitor.add_component(
|
||||||
|
|||||||
26
src/botdata.py
Normal file
26
src/botdata.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from data import Registry, RowModel, Table
|
||||||
|
from data.columns import String, Timestamp, Integer, Bool
|
||||||
|
|
||||||
|
|
||||||
|
class VersionHistory(RowModel):
|
||||||
|
"""
|
||||||
|
CREATE TABLE version_history(
|
||||||
|
component TEXT NOT NULL,
|
||||||
|
from_version INTEGER NOT NULL,
|
||||||
|
to_version INTEGER NOT NULL,
|
||||||
|
author TEXT NOT NULL,
|
||||||
|
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
_tablename_ = 'version_history'
|
||||||
|
_cache_ = {}
|
||||||
|
|
||||||
|
component = String()
|
||||||
|
from_version = Integer()
|
||||||
|
to_version = Integer()
|
||||||
|
author = String()
|
||||||
|
_timestamp = Timestamp()
|
||||||
|
|
||||||
|
|
||||||
|
class BotData(Registry):
|
||||||
|
version_history = VersionHistory.table
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
CONFIG_FILE = "config/bot.conf"
|
CONFIG_FILE = "config/bot.conf"
|
||||||
DATA_VERSION = 1
|
|
||||||
|
|
||||||
MAX_COINS = 2147483647 - 1
|
|
||||||
|
|
||||||
HINT_ICON = "https://projects.iamcal.com/emoji-data/img-apple-64/1f4a1.png"
|
HINT_ICON = "https://projects.iamcal.com/emoji-data/img-apple-64/1f4a1.png"
|
||||||
|
|
||||||
|
SCHEMA_VERSIONS = {
|
||||||
|
'ROOT': 1,
|
||||||
|
}
|
||||||
|
|||||||
1
src/data
Submodule
1
src/data
Submodule
Submodule src/data added at cfdfe0eb50
@@ -1,9 +0,0 @@
|
|||||||
from .conditions import Condition, condition, NULL
|
|
||||||
from .database import Database
|
|
||||||
from .models import RowModel, RowTable, WeakCache
|
|
||||||
from .table import Table
|
|
||||||
from .base import Expression, RawExpr
|
|
||||||
from .columns import ColumnExpr, Column, Integer, String
|
|
||||||
from .registry import Registry, AttachableClass, Attachable
|
|
||||||
from .adapted import RegisterEnum
|
|
||||||
from .queries import ORDER, NULLS, JOINTYPE
|
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
# from enum import Enum
|
|
||||||
from typing import Optional
|
|
||||||
from psycopg.types.enum import register_enum, EnumInfo
|
|
||||||
from psycopg import AsyncConnection
|
|
||||||
from .registry import Attachable, Registry
|
|
||||||
|
|
||||||
|
|
||||||
class RegisterEnum(Attachable):
|
|
||||||
def __init__(self, enum, name: Optional[str] = None, mapper=None):
|
|
||||||
super().__init__()
|
|
||||||
self.enum = enum
|
|
||||||
self.name = name or enum.__name__
|
|
||||||
self.mapping = mapper(enum) if mapper is not None else self._mapper()
|
|
||||||
|
|
||||||
def _mapper(self):
|
|
||||||
return {m: m.value[0] for m in self.enum}
|
|
||||||
|
|
||||||
def attach_to(self, registry: Registry):
|
|
||||||
self._registry = registry
|
|
||||||
registry.init_task(self.on_init)
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def on_init(self, registry: Registry):
|
|
||||||
connector = registry._conn
|
|
||||||
if connector is None:
|
|
||||||
raise ValueError("Cannot initialise without connector!")
|
|
||||||
connector.connect_hook(self.connection_hook)
|
|
||||||
# await connector.refresh_pool()
|
|
||||||
# The below may be somewhat dangerous
|
|
||||||
# But adaption should never write to the database
|
|
||||||
await connector.map_over_pool(self.connection_hook)
|
|
||||||
# if conn := connector.conn:
|
|
||||||
# # Ensure the adaption is run in the current context as well
|
|
||||||
# await self.connection_hook(conn)
|
|
||||||
|
|
||||||
async def connection_hook(self, conn: AsyncConnection):
|
|
||||||
info = await EnumInfo.fetch(conn, self.name)
|
|
||||||
if info is None:
|
|
||||||
raise ValueError(f"Enum {self.name} not found in database.")
|
|
||||||
register_enum(info, conn, self.enum, mapping=list(self.mapping.items()))
|
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
from abc import abstractmethod
|
|
||||||
from typing import Any, Protocol, runtime_checkable
|
|
||||||
from itertools import chain
|
|
||||||
from psycopg import sql
|
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
|
||||||
class Expression(Protocol):
|
|
||||||
__slots__ = ()
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def as_tuple(self) -> tuple[sql.Composable, tuple[Any, ...]]:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class RawExpr(Expression):
|
|
||||||
__slots__ = ('expr', 'values')
|
|
||||||
|
|
||||||
expr: sql.Composable
|
|
||||||
values: tuple[Any, ...]
|
|
||||||
|
|
||||||
def __init__(self, expr: sql.Composable, values: tuple[Any, ...] = ()):
|
|
||||||
self.expr = expr
|
|
||||||
self.values = values
|
|
||||||
|
|
||||||
def as_tuple(self):
|
|
||||||
return (self.expr, self.values)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def join(cls, *expressions: Expression, joiner: sql.SQL = sql.SQL(' ')):
|
|
||||||
"""
|
|
||||||
Join a sequence of Expressions into a single RawExpr.
|
|
||||||
"""
|
|
||||||
tups = (
|
|
||||||
expression.as_tuple()
|
|
||||||
for expression in expressions
|
|
||||||
)
|
|
||||||
return cls.join_tuples(*tups, joiner=joiner)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def join_tuples(cls, *tuples: tuple[sql.Composable, tuple[Any, ...]], joiner: sql.SQL = sql.SQL(' ')):
|
|
||||||
exprs, values = zip(*tuples)
|
|
||||||
expr = joiner.join(exprs)
|
|
||||||
value = tuple(chain(*values))
|
|
||||||
return cls(expr, value)
|
|
||||||
@@ -1,155 +0,0 @@
|
|||||||
from typing import Any, Union, TypeVar, Generic, Type, overload, Optional, TYPE_CHECKING
|
|
||||||
from psycopg import sql
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from .base import RawExpr, Expression
|
|
||||||
from .conditions import Condition, Joiner
|
|
||||||
from .table import Table
|
|
||||||
|
|
||||||
|
|
||||||
class ColumnExpr(RawExpr):
|
|
||||||
__slots__ = ()
|
|
||||||
|
|
||||||
def __lt__(self, obj) -> Condition:
|
|
||||||
expr, values = self.as_tuple()
|
|
||||||
|
|
||||||
if isinstance(obj, Expression):
|
|
||||||
# column < Expression
|
|
||||||
obj_expr, obj_values = obj.as_tuple()
|
|
||||||
cond_exprs = (expr, Joiner.LT, obj_expr)
|
|
||||||
cond_values = (*values, *obj_values)
|
|
||||||
else:
|
|
||||||
# column < Literal
|
|
||||||
cond_exprs = (expr, Joiner.LT, sql.Placeholder())
|
|
||||||
cond_values = (*values, obj)
|
|
||||||
|
|
||||||
return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
|
|
||||||
|
|
||||||
def __le__(self, obj) -> Condition:
|
|
||||||
expr, values = self.as_tuple()
|
|
||||||
|
|
||||||
if isinstance(obj, Expression):
|
|
||||||
# column <= Expression
|
|
||||||
obj_expr, obj_values = obj.as_tuple()
|
|
||||||
cond_exprs = (expr, Joiner.LE, obj_expr)
|
|
||||||
cond_values = (*values, *obj_values)
|
|
||||||
else:
|
|
||||||
# column <= Literal
|
|
||||||
cond_exprs = (expr, Joiner.LE, sql.Placeholder())
|
|
||||||
cond_values = (*values, obj)
|
|
||||||
|
|
||||||
return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
|
|
||||||
|
|
||||||
def __eq__(self, obj) -> Condition: # type: ignore[override]
|
|
||||||
return Condition._expression_equality(self, obj)
|
|
||||||
|
|
||||||
def __ne__(self, obj) -> Condition: # type: ignore[override]
|
|
||||||
return ~(self.__eq__(obj))
|
|
||||||
|
|
||||||
def __gt__(self, obj) -> Condition:
|
|
||||||
return ~(self.__le__(obj))
|
|
||||||
|
|
||||||
def __ge__(self, obj) -> Condition:
|
|
||||||
return ~(self.__lt__(obj))
|
|
||||||
|
|
||||||
def __add__(self, obj: Union[Any, Expression]) -> 'ColumnExpr':
|
|
||||||
if isinstance(obj, Expression):
|
|
||||||
obj_expr, obj_values = obj.as_tuple()
|
|
||||||
return ColumnExpr(
|
|
||||||
sql.SQL("({} + {})").format(self.expr, obj_expr),
|
|
||||||
(*self.values, *obj_values)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return ColumnExpr(
|
|
||||||
sql.SQL("({} + {})").format(self.expr, sql.Placeholder()),
|
|
||||||
(*self.values, obj)
|
|
||||||
)
|
|
||||||
|
|
||||||
def __sub__(self, obj) -> 'ColumnExpr':
|
|
||||||
if isinstance(obj, Expression):
|
|
||||||
obj_expr, obj_values = obj.as_tuple()
|
|
||||||
return ColumnExpr(
|
|
||||||
sql.SQL("({} - {})").format(self.expr, obj_expr),
|
|
||||||
(*self.values, *obj_values)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return ColumnExpr(
|
|
||||||
sql.SQL("({} - {})").format(self.expr, sql.Placeholder()),
|
|
||||||
(*self.values, obj)
|
|
||||||
)
|
|
||||||
|
|
||||||
def __mul__(self, obj) -> 'ColumnExpr':
|
|
||||||
if isinstance(obj, Expression):
|
|
||||||
obj_expr, obj_values = obj.as_tuple()
|
|
||||||
return ColumnExpr(
|
|
||||||
sql.SQL("({} * {})").format(self.expr, obj_expr),
|
|
||||||
(*self.values, *obj_values)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return ColumnExpr(
|
|
||||||
sql.SQL("({} * {})").format(self.expr, sql.Placeholder()),
|
|
||||||
(*self.values, obj)
|
|
||||||
)
|
|
||||||
|
|
||||||
def CAST(self, target_type: sql.Composable):
|
|
||||||
return ColumnExpr(
|
|
||||||
sql.SQL("({}::{})").format(self.expr, target_type),
|
|
||||||
self.values
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar('T')
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .models import RowModel
|
|
||||||
|
|
||||||
|
|
||||||
class Column(ColumnExpr, Generic[T]):
|
|
||||||
def __init__(self, name: Optional[str] = None,
|
|
||||||
primary: bool = False, references: Optional['Column'] = None,
|
|
||||||
type: Optional[Type[T]] = None):
|
|
||||||
self.primary = primary
|
|
||||||
self.references = references
|
|
||||||
self.name: str = name # type: ignore
|
|
||||||
self.owner: Optional['RowModel'] = None
|
|
||||||
self._type = type
|
|
||||||
|
|
||||||
self.expr = sql.Identifier(name) if name else sql.SQL('')
|
|
||||||
self.values = ()
|
|
||||||
|
|
||||||
def __set_name__(self, owner, name):
|
|
||||||
# Only allow setting the owner once
|
|
||||||
self.name = self.name or name
|
|
||||||
self.owner = owner
|
|
||||||
self.expr = sql.Identifier(self.owner._schema_, self.owner._tablename_, self.name)
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def __get__(self: 'Column[T]', obj: None, objtype: "None | Type['RowModel']") -> 'Column[T]':
|
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def __get__(self: 'Column[T]', obj: 'RowModel', objtype: Type['RowModel']) -> T:
|
|
||||||
...
|
|
||||||
|
|
||||||
def __get__(self: 'Column[T]', obj: "RowModel | None", objtype: "Type[RowModel] | None" = None) -> "T | Column[T]":
|
|
||||||
# Get value from row data or session
|
|
||||||
if obj is None:
|
|
||||||
return self
|
|
||||||
else:
|
|
||||||
return obj.data[self.name]
|
|
||||||
|
|
||||||
|
|
||||||
class Integer(Column[int]):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class String(Column[str]):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class Bool(Column[bool]):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class Timestamp(Column[datetime]):
|
|
||||||
pass
|
|
||||||
@@ -1,214 +0,0 @@
|
|||||||
# from meta import sharding
|
|
||||||
from typing import Any, Union
|
|
||||||
from enum import Enum
|
|
||||||
from itertools import chain
|
|
||||||
from psycopg import sql
|
|
||||||
|
|
||||||
from .base import Expression, RawExpr
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
A Condition is a "logical" database expression, intended for use in Where statements.
|
|
||||||
Conditions support bitwise logical operators ~, &, |, each producing another Condition.
|
|
||||||
"""
|
|
||||||
|
|
||||||
NULL = None
|
|
||||||
|
|
||||||
|
|
||||||
class Joiner(Enum):
|
|
||||||
EQUALS = ('=', '!=')
|
|
||||||
IS = ('IS', 'IS NOT')
|
|
||||||
LIKE = ('LIKE', 'NOT LIKE')
|
|
||||||
BETWEEN = ('BETWEEN', 'NOT BETWEEN')
|
|
||||||
IN = ('IN', 'NOT IN')
|
|
||||||
LT = ('<', '>=')
|
|
||||||
LE = ('<=', '>')
|
|
||||||
NONE = ('', '')
|
|
||||||
|
|
||||||
|
|
||||||
class Condition(Expression):
|
|
||||||
__slots__ = ('expr1', 'joiner', 'negated', 'expr2', 'values')
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
expr1: sql.Composable, joiner: Joiner = Joiner.NONE, expr2: sql.Composable = sql.SQL(''),
|
|
||||||
values: tuple[Any, ...] = (), negated=False
|
|
||||||
):
|
|
||||||
self.expr1 = expr1
|
|
||||||
self.joiner = joiner
|
|
||||||
self.negated = negated
|
|
||||||
self.expr2 = expr2
|
|
||||||
self.values = values
|
|
||||||
|
|
||||||
def as_tuple(self):
|
|
||||||
expr = sql.SQL(' ').join((self.expr1, sql.SQL(self.joiner.value[self.negated]), self.expr2))
|
|
||||||
if self.negated and self.joiner is Joiner.NONE:
|
|
||||||
expr = sql.SQL("NOT ({})").format(expr)
|
|
||||||
return (expr, self.values)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def construct(cls, *conditions: 'Condition', **kwargs: Union[Any, Expression]):
|
|
||||||
"""
|
|
||||||
Construct a Condition from a sequence of Conditions,
|
|
||||||
together with some explicit column conditions.
|
|
||||||
"""
|
|
||||||
# TODO: Consider adding a _table identifier here so we can identify implicit columns
|
|
||||||
# Or just require subquery type conditions to always come from modelled tables.
|
|
||||||
implicit_conditions = (
|
|
||||||
cls._expression_equality(RawExpr(sql.Identifier(column)), value) for column, value in kwargs.items()
|
|
||||||
)
|
|
||||||
return cls._and(*conditions, *implicit_conditions)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _and(cls, *conditions: 'Condition'):
|
|
||||||
if not len(conditions):
|
|
||||||
raise ValueError("Cannot combine 0 Conditions")
|
|
||||||
if len(conditions) == 1:
|
|
||||||
return conditions[0]
|
|
||||||
|
|
||||||
exprs, values = zip(*(condition.as_tuple() for condition in conditions))
|
|
||||||
cond_expr = sql.SQL(' AND ').join((sql.SQL('({})').format(expr) for expr in exprs))
|
|
||||||
cond_values = tuple(chain(*values))
|
|
||||||
|
|
||||||
return Condition(cond_expr, values=cond_values)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _or(cls, *conditions: 'Condition'):
|
|
||||||
if not len(conditions):
|
|
||||||
raise ValueError("Cannot combine 0 Conditions")
|
|
||||||
if len(conditions) == 1:
|
|
||||||
return conditions[0]
|
|
||||||
|
|
||||||
exprs, values = zip(*(condition.as_tuple() for condition in conditions))
|
|
||||||
cond_expr = sql.SQL(' OR ').join((sql.SQL('({})').format(expr) for expr in exprs))
|
|
||||||
cond_values = tuple(chain(*values))
|
|
||||||
|
|
||||||
return Condition(cond_expr, values=cond_values)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _not(cls, condition: 'Condition'):
|
|
||||||
condition.negated = not condition.negated
|
|
||||||
return condition
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _expression_equality(cls, column: Expression, value: Union[Any, Expression]) -> 'Condition':
|
|
||||||
# TODO: Check if this supports sbqueries
|
|
||||||
col_expr, col_values = column.as_tuple()
|
|
||||||
|
|
||||||
# TODO: Also support sql.SQL? For joins?
|
|
||||||
if isinstance(value, Expression):
|
|
||||||
# column = Expression
|
|
||||||
value_expr, value_values = value.as_tuple()
|
|
||||||
cond_exprs = (col_expr, Joiner.EQUALS, value_expr)
|
|
||||||
cond_values = (*col_values, *value_values)
|
|
||||||
elif isinstance(value, (tuple, list)):
|
|
||||||
# column in (...)
|
|
||||||
# TODO: Support expressions in value tuple?
|
|
||||||
if not value:
|
|
||||||
raise ValueError("Cannot create Condition from empty iterable!")
|
|
||||||
value_expr = sql.SQL('({})').format(sql.SQL(',').join(sql.Placeholder() * len(value)))
|
|
||||||
cond_exprs = (col_expr, Joiner.IN, value_expr)
|
|
||||||
cond_values = (*col_values, *value)
|
|
||||||
elif value is None:
|
|
||||||
# column IS NULL
|
|
||||||
cond_exprs = (col_expr, Joiner.IS, sql.NULL)
|
|
||||||
cond_values = col_values
|
|
||||||
else:
|
|
||||||
# column = Literal
|
|
||||||
cond_exprs = (col_expr, Joiner.EQUALS, sql.Placeholder())
|
|
||||||
cond_values = (*col_values, value)
|
|
||||||
|
|
||||||
return cls(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
|
|
||||||
|
|
||||||
def __invert__(self) -> 'Condition':
|
|
||||||
self.negated = not self.negated
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __and__(self, condition: 'Condition') -> 'Condition':
|
|
||||||
return self._and(self, condition)
|
|
||||||
|
|
||||||
def __or__(self, condition: 'Condition') -> 'Condition':
|
|
||||||
return self._or(self, condition)
|
|
||||||
|
|
||||||
|
|
||||||
# Helper method to simply condition construction
|
|
||||||
def condition(*args, **kwargs) -> Condition:
|
|
||||||
return Condition.construct(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
# class NOT(Condition):
|
|
||||||
# __slots__ = ('value',)
|
|
||||||
#
|
|
||||||
# def __init__(self, value):
|
|
||||||
# self.value = value
|
|
||||||
#
|
|
||||||
# def apply(self, key, values, conditions):
|
|
||||||
# item = self.value
|
|
||||||
# if isinstance(item, (list, tuple)):
|
|
||||||
# if item:
|
|
||||||
# conditions.append("{} NOT IN ({})".format(key, ", ".join([_replace_char] * len(item))))
|
|
||||||
# values.extend(item)
|
|
||||||
# else:
|
|
||||||
# raise ValueError("Cannot check an empty iterable!")
|
|
||||||
# else:
|
|
||||||
# conditions.append("{}!={}".format(key, _replace_char))
|
|
||||||
# values.append(item)
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# class GEQ(Condition):
|
|
||||||
# __slots__ = ('value',)
|
|
||||||
#
|
|
||||||
# def __init__(self, value):
|
|
||||||
# self.value = value
|
|
||||||
#
|
|
||||||
# def apply(self, key, values, conditions):
|
|
||||||
# item = self.value
|
|
||||||
# if isinstance(item, (list, tuple)):
|
|
||||||
# raise ValueError("Cannot apply GEQ condition to a list!")
|
|
||||||
# else:
|
|
||||||
# conditions.append("{} >= {}".format(key, _replace_char))
|
|
||||||
# values.append(item)
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# class LEQ(Condition):
|
|
||||||
# __slots__ = ('value',)
|
|
||||||
#
|
|
||||||
# def __init__(self, value):
|
|
||||||
# self.value = value
|
|
||||||
#
|
|
||||||
# def apply(self, key, values, conditions):
|
|
||||||
# item = self.value
|
|
||||||
# if isinstance(item, (list, tuple)):
|
|
||||||
# raise ValueError("Cannot apply LEQ condition to a list!")
|
|
||||||
# else:
|
|
||||||
# conditions.append("{} <= {}".format(key, _replace_char))
|
|
||||||
# values.append(item)
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# class Constant(Condition):
|
|
||||||
# __slots__ = ('value',)
|
|
||||||
#
|
|
||||||
# def __init__(self, value):
|
|
||||||
# self.value = value
|
|
||||||
#
|
|
||||||
# def apply(self, key, values, conditions):
|
|
||||||
# conditions.append("{} {}".format(key, self.value))
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# class SHARDID(Condition):
|
|
||||||
# __slots__ = ('shardid', 'shard_count')
|
|
||||||
#
|
|
||||||
# def __init__(self, shardid, shard_count):
|
|
||||||
# self.shardid = shardid
|
|
||||||
# self.shard_count = shard_count
|
|
||||||
#
|
|
||||||
# def apply(self, key, values, conditions):
|
|
||||||
# if self.shard_count > 1:
|
|
||||||
# conditions.append("({} >> 22) %% {} = {}".format(key, self.shard_count, _replace_char))
|
|
||||||
# values.append(self.shardid)
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# # THIS_SHARD = SHARDID(sharding.shard_number, sharding.shard_count)
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# NULL = Constant('IS NULL')
|
|
||||||
# NOTNULL = Constant('IS NOT NULL')
|
|
||||||
@@ -1,135 +0,0 @@
|
|||||||
from typing import Protocol, runtime_checkable, Callable, Awaitable, Optional
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from contextvars import ContextVar
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
import psycopg as psq
|
|
||||||
from psycopg_pool import AsyncConnectionPool
|
|
||||||
from psycopg.pq import TransactionStatus
|
|
||||||
|
|
||||||
from .cursor import AsyncLoggingCursor
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
row_factory = psq.rows.dict_row
|
|
||||||
|
|
||||||
ctx_connection: Optional[ContextVar[psq.AsyncConnection]] = ContextVar('connection', default=None)
|
|
||||||
|
|
||||||
|
|
||||||
class Connector:
|
|
||||||
cursor_factory = AsyncLoggingCursor
|
|
||||||
|
|
||||||
def __init__(self, conn_args):
|
|
||||||
self._conn_args = conn_args
|
|
||||||
self._conn_kwargs = dict(autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory)
|
|
||||||
|
|
||||||
self.pool = self.make_pool()
|
|
||||||
|
|
||||||
self.conn_hooks = []
|
|
||||||
|
|
||||||
@property
|
|
||||||
def conn(self) -> Optional[psq.AsyncConnection]:
|
|
||||||
"""
|
|
||||||
Convenience property for the current context connection.
|
|
||||||
"""
|
|
||||||
return ctx_connection.get()
|
|
||||||
|
|
||||||
@conn.setter
|
|
||||||
def conn(self, conn: psq.AsyncConnection):
|
|
||||||
"""
|
|
||||||
Set the contextual connection in the current context.
|
|
||||||
Always do this in an isolated context!
|
|
||||||
"""
|
|
||||||
ctx_connection.set(conn)
|
|
||||||
|
|
||||||
def make_pool(self) -> AsyncConnectionPool:
|
|
||||||
logger.info("Initialising connection pool.", extra={'action': "Pool Init"})
|
|
||||||
return AsyncConnectionPool(
|
|
||||||
self._conn_args,
|
|
||||||
open=False,
|
|
||||||
min_size=4,
|
|
||||||
max_size=8,
|
|
||||||
configure=self._setup_connection,
|
|
||||||
kwargs=self._conn_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
async def refresh_pool(self):
|
|
||||||
"""
|
|
||||||
Refresh the pool.
|
|
||||||
|
|
||||||
The point of this is to invalidate any existing connections so that the connection set up is run again.
|
|
||||||
Better ways should be sought (a way to
|
|
||||||
"""
|
|
||||||
logger.info("Pool refresh requested, closing and reopening.")
|
|
||||||
old_pool = self.pool
|
|
||||||
self.pool = self.make_pool()
|
|
||||||
await self.pool.open()
|
|
||||||
logger.info(f"Old pool statistics: {self.pool.get_stats()}")
|
|
||||||
await old_pool.close()
|
|
||||||
logger.info("Pool refresh complete.")
|
|
||||||
|
|
||||||
async def map_over_pool(self, callable):
|
|
||||||
"""
|
|
||||||
Dangerous method to call a method on each connection in the pool.
|
|
||||||
|
|
||||||
Utilises private methods of the AsyncConnectionPool.
|
|
||||||
"""
|
|
||||||
async with self.pool._lock:
|
|
||||||
conns = list(self.pool._pool)
|
|
||||||
while conns:
|
|
||||||
conn = conns.pop()
|
|
||||||
try:
|
|
||||||
await callable(conn)
|
|
||||||
except Exception:
|
|
||||||
logger.exception(f"Mapped connection task failed. {callable.__name__}")
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def open(self):
|
|
||||||
try:
|
|
||||||
logger.info("Opening database pool.")
|
|
||||||
await self.pool.open()
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
# May be a different pool!
|
|
||||||
logger.info(f"Closing database pool. Pool statistics: {self.pool.get_stats()}")
|
|
||||||
await self.pool.close()
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def connection(self) -> psq.AsyncConnection:
|
|
||||||
"""
|
|
||||||
Asynchronous context manager to get and manage a connection.
|
|
||||||
|
|
||||||
If the context connection is set, uses this and does not manage the lifetime.
|
|
||||||
Otherwise, requests a new connection from the pool and returns it when done.
|
|
||||||
"""
|
|
||||||
logger.debug("Database connection requested.", extra={'action': "Data Connect"})
|
|
||||||
if (conn := self.conn):
|
|
||||||
yield conn
|
|
||||||
else:
|
|
||||||
async with self.pool.connection() as conn:
|
|
||||||
yield conn
|
|
||||||
|
|
||||||
async def _setup_connection(self, conn: psq.AsyncConnection):
|
|
||||||
logger.debug("Initialising new connection.", extra={'action': "Conn Init"})
|
|
||||||
for hook in self.conn_hooks:
|
|
||||||
try:
|
|
||||||
await hook(conn)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Exception encountered setting up new connection")
|
|
||||||
return conn
|
|
||||||
|
|
||||||
def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]):
|
|
||||||
"""
|
|
||||||
Minimal decorator to register a coroutine to run on connect or reconnect.
|
|
||||||
|
|
||||||
Note that these are only run on connect and reconnect.
|
|
||||||
If a hook is registered after connection, it will not be run.
|
|
||||||
"""
|
|
||||||
self.conn_hooks.append(coro)
|
|
||||||
return coro
|
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
|
||||||
class Connectable(Protocol):
|
|
||||||
def bind(self, connector: Connector):
|
|
||||||
raise NotImplementedError
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
import logging
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from psycopg import AsyncCursor, sql
|
|
||||||
from psycopg.abc import Query, Params
|
|
||||||
from psycopg._encodings import pgconn_encoding
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncLoggingCursor(AsyncCursor):
    """
    AsyncCursor subclass which logs executed queries.

    Queries are rendered to a readable string via `mogrify_query`, both for
    DEBUG-level logging before execution and for the error report when
    execution fails.
    """

    def mogrify_query(self, query: Query):
        """
        Render a query (str, sql.SQL/Composed, or bytes) to a loggable string.

        Falls back to `repr` for unrecognised query types.
        """
        if isinstance(query, str):
            msg = query
        elif isinstance(query, (sql.SQL, sql.Composed)):
            msg = query.as_string(self)
        elif isinstance(query, bytes):
            # Decode with the connection's encoding, replacing invalid bytes.
            msg = query.decode(pgconn_encoding(self._conn.pgconn), 'replace')
        else:
            msg = repr(query)
        return msg

    async def execute(self, query: Query, params: Optional[Params] = None, **kwargs):
        """
        Execute the query, logging it at DEBUG level.

        On failure, the query and parameters are logged with a stack trace
        and the exception is re-raised. (Previously the exception was
        swallowed, making failed queries silently return None.)
        """
        if logging.DEBUG >= logger.getEffectiveLevel():
            msg = self.mogrify_query(query)
            logger.debug(
                "Executing query (%s) with values %s", msg, params,
                extra={'action': "Query Execute"}
            )
        try:
            return await super().execute(query, params=params, **kwargs)
        except Exception:
            msg = self.mogrify_query(query)
            logger.exception(
                "Exception during query execution. Query (%s) with parameters %s.",
                msg, params,
                extra={'action': "Query Execute"},
                stack_info=True
            )
            # Propagate the failure to the caller instead of returning None.
            raise
        # TODO: Possibly log execution time
|
|
||||||
@@ -1,47 +0,0 @@
|
|||||||
from typing import TypeVar
|
|
||||||
import logging
|
|
||||||
from collections import namedtuple
|
|
||||||
|
|
||||||
# from .cursor import AsyncLoggingCursor
|
|
||||||
from .registry import Registry
|
|
||||||
from .connector import Connector
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
Version = namedtuple('Version', ('version', 'time', 'author'))
|
|
||||||
|
|
||||||
T = TypeVar('T', bound=Registry)
|
|
||||||
|
|
||||||
|
|
||||||
class Database(Connector):
    """
    Connector which owns the Registries bound to it and exposes a
    schema-version lookup.
    """
    # cursor_factory = AsyncLoggingCursor

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Registries loaded onto this database, keyed by registry name.
        self.registries: dict[str, Registry] = {}

    def load_registry(self, registry: T) -> T:
        """
        Bind the given registry to this database and record it by name.

        Returns the same registry, allowing chained use.
        """
        logger.debug(
            f"Loading and binding registry '{registry.name}'.",
            extra={'action': f"Reg {registry.name}"}
        )
        registry.bind(self)
        self.registries[registry.name] = registry
        return registry

    async def version(self) -> Version:
        """
        Return the current schema version as a Version namedtuple.
        """
        async with self.connection() as conn:
            async with conn.cursor() as cursor:
                # Get last entry in version table, compare against desired version
                # NOTE(review): the current schema migration defines
                # `version_history` (component, from_version, to_version,
                # author, _timestamp); this query still targets the legacy
                # `VersionHistory` table — confirm which schema applies.
                await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
                row = await cursor.fetchone()
                if row:
                    # Assumes a dict-like row factory (row['version']) — TODO confirm.
                    return Version(row['version'], row['time'], row['author'])
                else:
                    # No versions in the database
                    return Version(-1, None, None)
|
|
||||||
@@ -1,323 +0,0 @@
|
|||||||
from typing import TypeVar, Type, Optional, Generic, Union
|
|
||||||
# from typing_extensions import Self
|
|
||||||
from weakref import WeakValueDictionary
|
|
||||||
from collections.abc import MutableMapping
|
|
||||||
|
|
||||||
from psycopg.rows import DictRow
|
|
||||||
|
|
||||||
from .table import Table
|
|
||||||
from .columns import Column
|
|
||||||
from . import queries as q
|
|
||||||
from .connector import Connector
|
|
||||||
from .registry import Registry
|
|
||||||
|
|
||||||
|
|
||||||
RowT = TypeVar('RowT', bound='RowModel')
|
|
||||||
|
|
||||||
|
|
||||||
class MISSING:
    """
    Sentinel wrapper marking a value as absent, tagged with the relevant oid.
    """
    __slots__ = ('oid',)

    def __init__(self, oid):
        # Only state carried is the oid identifying what is missing.
        self.oid = oid
|
|
||||||
|
|
||||||
|
|
||||||
class RowTable(Table, Generic[RowT]):
    """
    Table bound to a RowModel class.

    Query adapters additionally feed fetched/deleted rows through the model,
    keeping the model's row cache in sync as a side effect.
    """
    __slots__ = (
        'model',
    )

    def __init__(self, name, model: Type[RowT], **kwargs):
        super().__init__(name, **kwargs)
        # The RowModel subclass this table materialises rows into.
        self.model = model

    @property
    def columns(self):
        # Columns are owned by the model, not stored on the table itself.
        return self.model._columns_

    @property
    def id_col(self):
        # Primary key column name(s), as declared on the model.
        return self.model._key_

    @property
    def row_cache(self):
        return self.model._cache_

    def _many_query_adapter(self, *data):
        # Side effect: create/refresh cached model rows for the fetched data.
        self.model._make_rows(*data)
        return data

    def _single_query_adapter(self, *data):
        # As _many_query_adapter, but returns only the first row (or None).
        if data:
            self.model._make_rows(*data)
            return data[0]
        else:
            return None

    def _delete_query_adapter(self, *data):
        # Side effect: evict the deleted rows from the model cache.
        self.model._delete_rows(*data)
        return data

    # New methods to fetch and create rows
    async def create_row(self, *args, **kwargs) -> RowT:
        """Insert a row and return it wrapped in the model class."""
        data = await super().insert(*args, **kwargs)
        return self.model._make_rows(data)[0]

    def fetch_rows_where(self, *args, **kwargs) -> q.Select[list[RowT]]:
        """Build (but do not execute) a Select returning model rows."""
        # TODO: Handle list of rowids here?
        return q.Select(
            self.identifier,
            row_adapter=self.model._make_rows,
            connector=self.connector
        ).where(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
WK = TypeVar('WK')
WV = TypeVar('WV')


class WeakCache(Generic[WK, WV], MutableMapping[WK, WV]):
    """
    Weak-valued mapping layered over a stronger reference cache.

    Lookups and stores touch `ref_cache` as well, so recently used entries
    are kept alive by the stronger cache while the weak layer lets unused
    entries be collected.
    """

    def __init__(self, ref_cache):
        # Stronger cache (e.g. an LRU) keeping recently used values alive.
        self.ref_cache = ref_cache
        # Authoritative weak mapping; entries vanish once unreferenced.
        self.weak_cache = WeakValueDictionary()

    def __getitem__(self, key):
        value = self.weak_cache[key]
        # Refresh the strong cache on every successful read.
        self.ref_cache[key] = value
        return value

    def __setitem__(self, key, value):
        self.weak_cache[key] = value
        self.ref_cache[key] = value

    def __delitem__(self, key):
        del self.weak_cache[key]
        # The strong cache may have already evicted this key.
        try:
            del self.ref_cache[key]
        except KeyError:
            pass

    def __contains__(self, key):
        return key in self.weak_cache

    def __iter__(self):
        return iter(self.weak_cache)

    def __len__(self):
        return len(self.weak_cache)

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default

    def pop(self, key, default=None):
        try:
            value = self[key]
        except KeyError:
            return default
        del self[key]
        return value
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Implement getitem and setitem, for dynamic column access
|
|
||||||
class RowModel:
    """
    Active-record style model for a single table row.

    Subclasses declare `Column` attributes; `__init_subclass__` derives the
    column map, primary key, and backing `RowTable`. Instances are interned
    per row id through `_cache_`, so one id maps to a single live object.
    """
    __slots__ = ('data',)

    # Schema and table this model maps onto.
    _schema_: str = 'public'
    _tablename_: Optional[str] = None
    _columns_: dict[str, Column] = {}

    # Cache to keep track of registered Rows
    _cache_: Union[dict, WeakValueDictionary, WeakCache] = None  # type: ignore

    # Primary key column names; auto-derived from Columns marked primary.
    _key_: tuple[str, ...] = ()
    _connector: Optional[Connector] = None
    _registry: Optional[Registry] = None

    # TODO: Proper typing for a classvariable which gets dynamically assigned in subclass
    table: RowTable = None

    def __init_subclass__(cls: Type[RowT], table: Optional[str] = None):
        """
        Set table, _columns_, and _key_.
        """
        if table is not None:
            cls._tablename_ = table

        if cls._tablename_ is not None:
            # Collect the Column descriptors declared on this subclass.
            columns = {}
            for key, value in cls.__dict__.items():
                if isinstance(value, Column):
                    columns[key] = value

            cls._columns_ = columns
            if not cls._key_:
                # Default key: every column flagged primary, in declaration order.
                cls._key_ = tuple(column.name for column in columns.values() if column.primary)
            cls.table = RowTable(cls._tablename_, cls, schema=cls._schema_)
            if cls._cache_ is None:
                cls._cache_ = WeakValueDictionary()

    def __new__(cls, data):
        # Registry pattern.
        # Ensure each rowid always refers to a single Model instance
        if data is not None:
            rowid = cls._id_from_data(data)

            cache = cls._cache_

            if (row := cache.get(rowid, None)) is not None:
                obj = row
            else:
                obj = cache[rowid] = super().__new__(cls)
        else:
            # data=None constructs an uncached "missing row" sentinel.
            obj = super().__new__(cls)

        return obj

    @classmethod
    def as_tuple(cls):
        # Expression-style (sql, values) pair naming this model's table.
        return (cls.table.identifier, ())

    def __init__(self, data):
        # NOTE: re-runs on cache hits too, refreshing `data` in place.
        self.data = data

    def __getitem__(self, key):
        return self.data[key]

    def __setitem__(self, key, value):
        # Local mutation only; does not write through to the database.
        self.data[key] = value

    @classmethod
    def bind(cls, connector: Connector):
        """Bind this model (and its table) to a connector."""
        if cls.table is None:
            raise ValueError("Cannot bind abstract RowModel")
        cls._connector = connector
        cls.table.bind(connector)
        return cls

    @classmethod
    def attach_to(cls, registry: Registry):
        """Record the registry this model belongs to."""
        cls._registry = registry
        return cls

    @property
    def _dict_(self):
        # Primary key as a column -> value mapping.
        return {key: self.data[key] for key in self._key_}

    @property
    def _rowid_(self):
        # Primary key as a value tuple; also the cache key.
        return tuple(self.data[key] for key in self._key_)

    def __repr__(self):
        return "{}.{}({})".format(
            self.table.schema,
            self.table.name,
            ', '.join(repr(column.__get__(self)) for column in self._columns_.values())
        )

    @classmethod
    def _id_from_data(cls, data):
        # Extract the cache key from a raw data row.
        return tuple(data[key] for key in cls._key_)

    @classmethod
    def _dict_from_id(cls, rowid):
        # Inverse of _id_from_data: key tuple -> column mapping.
        return dict(zip(cls._key_, rowid))

    @classmethod
    def _make_rows(cls: Type[RowT], *data_rows: DictRow) -> list[RowT]:
        """
        Create or retrieve Row objects for each provided data row.
        If the rows already exist in cache, updates the cached row.
        """
        # TODO: Handle partial row data here somehow?
        rows = [cls(data_row) for data_row in data_rows]
        return rows

    @classmethod
    def _delete_rows(cls, *data_rows):
        """
        Remove the given rows from cache, if they exist.
        May be extended to handle object deletion.
        """
        cache = cls._cache_

        for data_row in data_rows:
            rowid = cls._id_from_data(data_row)
            cache.pop(rowid, None)

    @classmethod
    async def create(cls: Type[RowT], *args, **kwargs) -> RowT:
        """Insert a new row and return the model instance."""
        return await cls.table.create_row(*args, **kwargs)

    @classmethod
    def fetch_where(cls: Type[RowT], *args, **kwargs):
        """Build (not execute) a Select of model rows matching the conditions."""
        return cls.table.fetch_rows_where(*args, **kwargs)

    @classmethod
    async def fetch(cls: Type[RowT], *rowid, cached=True) -> Optional[RowT]:
        """
        Fetch the row with the given id, retrieving from cache where possible.
        """
        row = cls._cache_.get(rowid, None) if cached else None
        if row is None:
            rows = await cls.fetch_where(**cls._dict_from_id(rowid))
            row = rows[0] if rows else None
            if row is None:
                # Negative-cache a data=None sentinel so repeated misses
                # skip the database round trip.
                cls._cache_[rowid] = cls(None)
        elif row.data is None:
            # Cached sentinel from a previous miss: report "not found".
            row = None

        return row

    @classmethod
    async def fetch_or_create(cls, *rowid, **kwargs):
        """
        Helper method to fetch a row with the given id or fields, or create it if it doesn't exist.
        """
        if rowid:
            row = await cls.fetch(*rowid)
        else:
            rows = await cls.fetch_where(**kwargs).limit(1)
            row = rows[0] if rows else None

        if row is None:
            # Creation data combines the lookup fields and any explicit id.
            creation_kwargs = kwargs
            if rowid:
                creation_kwargs.update(cls._dict_from_id(rowid))
            row = await cls.create(**creation_kwargs)
        return row

    async def refresh(self: RowT) -> Optional[RowT]:
        """
        Refresh this Row from data.

        The return value may be `None` if the row was deleted.
        """
        rows = await self.table.select_where(**self._dict_)
        if not rows:
            return None
        else:
            self.data = rows[0]
            return self

    async def update(self: RowT, **values) -> Optional[RowT]:
        """
        Update this Row with the given values.

        Internally passes the provided `values` to the `update` Query.
        The return value may be `None` if the row was deleted.
        """
        data = await self.table.update_where(**self._dict_).set(**values).with_adapter(self._make_rows)
        if not data:
            return None
        else:
            return data[0]

    async def delete(self: RowT) -> Optional[RowT]:
        """
        Delete this Row.
        """
        data = await self.table.delete_where(**self._dict_).with_adapter(self._delete_rows)
        return data[0] if data is not None else None
|
|
||||||
@@ -1,644 +0,0 @@
|
|||||||
from typing import Optional, TypeVar, Any, Callable, Generic, List, Union
|
|
||||||
from enum import Enum
|
|
||||||
from itertools import chain
|
|
||||||
from psycopg import AsyncConnection, AsyncCursor
|
|
||||||
from psycopg import sql
|
|
||||||
from psycopg.rows import DictRow
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .conditions import Condition
|
|
||||||
from .base import Expression, RawExpr
|
|
||||||
from .connector import Connector
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
TQueryT = TypeVar('TQueryT', bound='TableQuery')
|
|
||||||
SQueryT = TypeVar('SQueryT', bound='Select')
|
|
||||||
|
|
||||||
QueryResult = TypeVar('QueryResult')
|
|
||||||
|
|
||||||
|
|
||||||
class Query(Generic[QueryResult]):
    """
    ABC for an executable query statement.

    A Query may be bound to a Connector, an explicit AsyncConnection, or an
    explicit AsyncCursor; `execute` uses the most specific resource given.
    Fetched rows are post-processed through an optional row adapter.
    """
    __slots__ = ('conn', 'cursor', '_adapter', 'connector', 'result')

    _adapter: Callable[..., QueryResult]

    def __init__(self, *args, row_adapter=None, connector=None, conn=None, cursor=None, **kwargs):
        self.connector: Optional[Connector] = connector
        self.conn: Optional[AsyncConnection] = conn
        self.cursor: Optional[AsyncCursor] = cursor

        if row_adapter is not None:
            self._adapter = row_adapter
        else:
            self._adapter = self._no_adapter

        # Result of the most recent execution, if any.
        self.result: Optional[QueryResult] = None

    def bind(self, connector: Connector):
        """Attach a Connector used to acquire a connection on execute."""
        self.connector = connector
        return self

    def with_cursor(self, cursor: AsyncCursor):
        """Execute using the given cursor."""
        self.cursor = cursor
        return self

    def with_connection(self, conn: AsyncConnection):
        """Execute using a fresh cursor from the given connection."""
        self.conn = conn
        return self

    def _no_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
        # Identity adapter: pass raw rows through unchanged.
        return data

    def with_adapter(self, callable: Callable[..., QueryResult]):
        # NOTE: Postcomposition functor, Query[QR2] = (QR1 -> QR2) o Query[QR1]
        # For this to work cleanly, callable should have arg type of QR1, not any
        self._adapter = callable
        return self

    def with_no_adapter(self):
        """
        Sets the adapater to the identity.
        """
        self._adapter = self._no_adapter
        return self

    def one(self):
        # TODO: Postcomposition with item functor, Query[List[QR1]] -> Query[QR1]
        return self

    def build(self) -> Expression:
        """Compose the SQL expression for this query. Subclass responsibility."""
        raise NotImplementedError

    async def _execute(self, cursor: AsyncCursor) -> QueryResult:
        """Run the built query on the given cursor and adapt the fetched rows."""
        query, values = self.build().as_tuple()
        # TODO: Move logging out to a custom cursor
        await cursor.execute(sql.Composed((query,)), values)
        data = await cursor.fetchall()
        self.result = self._adapter(*data)
        return self.result

    async def execute(self, cursor=None) -> QueryResult:
        """
        Execute the query, optionally with the provided cursor, and return the result rows.
        If no cursor is provided, and no cursor has been set with `with_cursor`,
        the execution will create a new cursor from the connection and close it automatically.
        """
        # Prefer an explicitly passed cursor over the stored one.
        cursor = cursor if cursor is not None else self.cursor
        if cursor is None:
            # BUGFIX: this previously tested `self.cursor`, which ignored a
            # cursor passed directly to this call and opened a new connection.
            if self.conn is None:
                if self.connector is None:
                    raise ValueError("Cannot execute query without cursor, connection, or connector.")
                else:
                    async with self.connector.connection() as conn:
                        async with conn.cursor() as cursor:
                            data = await self._execute(cursor)
            else:
                async with self.conn.cursor() as cursor:
                    data = await self._execute(cursor)
        else:
            data = await self._execute(cursor)
        return data

    def __await__(self):
        # Allows `await query` as shorthand for `await query.execute()`.
        return self.execute().__await__()
|
|
||||||
|
|
||||||
|
|
||||||
class TableQuery(Query[QueryResult]):
    """
    ABC for an executable query statement expected to be run on a single table.
    """
    __slots__ = (
        'tableid',
        'condition', '_extra', '_limit', '_order', '_joins', '_from', '_group'
    )

    def __init__(self, tableid, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Identifier of the table this query targets.
        self.tableid: sql.Identifier = tableid

    def options(self, **kwargs):
        """
        Hook for subclass-specific query options.

        The base implementation accepts and ignores everything, so callers
        can always chain `.options(...)` safely.
        """
        return self
|
|
||||||
|
|
||||||
|
|
||||||
class WhereMixin(TableQuery[QueryResult]):
    """Mixin adding WHERE-clause support to a table query."""
    __slots__ = ()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.condition: Optional[Condition] = None

    def where(self, *args: Condition, **kwargs):
        """
        Attach filter conditions to this query.

        Positional arguments are Conditions; keyword arguments are
        `column=value` pairs (value may be a Value-type or a literal).
        Everything supplied is AND-ed together, and stacks with any
        previously set condition.
        TODO: Maybe just pass this verbatim to a condition.
        """
        if not args and not kwargs:
            return self
        new_condition = Condition.construct(*args, **kwargs)
        if self.condition is not None:
            new_condition = self.condition & new_condition
        self.condition = new_condition
        return self

    @property
    def _where_section(self) -> Optional[Expression]:
        # Rendered WHERE clause, or None when no condition was set.
        if self.condition is None:
            return None
        return RawExpr.join_tuples((sql.SQL('WHERE'), ()), self.condition.as_tuple())
|
|
||||||
|
|
||||||
|
|
||||||
class JOINTYPE(Enum):
    """SQL join variants; each value is the literal JOIN keyword snippet."""
    LEFT = sql.SQL('LEFT JOIN')
    RIGHT = sql.SQL('RIGHT JOIN')
    INNER = sql.SQL('INNER JOIN')
    OUTER = sql.SQL('OUTER JOIN')
    FULLOUTER = sql.SQL('FULL OUTER JOIN')
|
|
||||||
|
|
||||||
|
|
||||||
class JoinMixin(TableQuery[QueryResult]):
    """Mixin adding JOIN-clause support to a table query."""
    __slots__ = ()
    # TODO: Remember to add join slots to TableQuery

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Join clauses in the order they were added.
        self._joins: list[Expression] = []

    def join(self,
             target: Union[str, Expression],
             on: Optional[Condition] = None, using: Optional[Union[Expression, tuple[str, ...]]] = None,
             join_type: JOINTYPE = JOINTYPE.INNER,
             natural=False):
        """
        Add a join clause to this query. Stacks with earlier joins.

        Exactly one of `on`, `using`, or `natural` must be given.
        `target` is a table name or an arbitrary joinable expression.
        Raises ValueError when zero or multiple join formats are supplied.
        """
        # Count how many of the mutually-exclusive join formats were given.
        available = (on is not None) + (using is not None) + natural
        if available == 0:
            raise ValueError("No conditions given for Query Join")
        if available > 1:
            raise ValueError("Exactly one join format must be given for Query Join")

        sections: list[tuple[sql.Composable, tuple[Any, ...]]] = [(join_type.value, ())]
        if isinstance(target, str):
            sections.append((sql.Identifier(target), ()))
        else:
            sections.append(target.as_tuple())

        if on is not None:
            sections.append((sql.SQL('ON'), ()))
            sections.append(on.as_tuple())
        elif using is not None:
            sections.append((sql.SQL('USING'), ()))
            if isinstance(using, Expression):
                sections.append(using.as_tuple())
            elif isinstance(using, tuple) and len(using) > 0 and isinstance(using[0], str):
                # Tuple of column names -> USING (col1,col2,...)
                cols = sql.SQL("({})").format(sql.SQL(',').join(sql.Identifier(col) for col in using))
                sections.append((cols, ()))
            else:
                raise ValueError("Unrecognised 'using' type.")
        elif natural:
            # NATURAL prefixes the join keyword itself.
            sections.insert(0, (sql.SQL('NATURAL'), ()))

        expr = RawExpr.join_tuples(*sections)
        self._joins.append(expr)
        return self

    def leftjoin(self, *args, **kwargs):
        """Convenience wrapper for a LEFT JOIN."""
        return self.join(*args, join_type=JOINTYPE.LEFT, **kwargs)

    @property
    def _join_section(self) -> Optional[Expression]:
        # All join clauses combined, or None when no joins were added.
        if self._joins:
            return RawExpr.join(*self._joins)
        else:
            return None
|
|
||||||
|
|
||||||
|
|
||||||
class ExtraMixin(TableQuery[QueryResult]):
    """Mixin allowing arbitrary extra SQL to be appended to a query."""
    __slots__ = ()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._extra: Optional[Expression] = None

    def extra(self, extra: sql.Composable, values: tuple[Any, ...] = ()):
        """
        Append an extra SQL fragment (and optional values) to this query.

        The fragment is placed after any condition and before the limit.
        Repeated calls stack in order.
        """
        new_part = RawExpr(extra, values)
        if self._extra is not None:
            new_part = RawExpr.join(self._extra, new_part)
        self._extra = new_part
        return self

    @property
    def _extra_section(self) -> Optional[Expression]:
        # Accumulated extra fragment, or None when none was added.
        return self._extra
|
|
||||||
|
|
||||||
|
|
||||||
class LimitMixin(TableQuery[QueryResult]):
    """Mixin adding LIMIT support to a table query."""
    __slots__ = ()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._limit: Optional[int] = None

    def limit(self, limit: int):
        """
        Cap the number of rows this query returns.
        """
        self._limit = limit
        return self

    @property
    def _limit_section(self) -> Optional[Expression]:
        # Parameterised LIMIT clause, or None when no limit was set.
        if self._limit is None:
            return None
        return RawExpr(sql.SQL("LIMIT {}").format(sql.Placeholder()), (self._limit,))
|
|
||||||
|
|
||||||
|
|
||||||
class FromMixin(TableQuery[QueryResult]):
    """Mixin adding a FROM-expression section to a table query."""
    __slots__ = ()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._from: Optional[Expression] = None

    def from_expr(self, _from: Expression):
        """Set the FROM expression for this query."""
        self._from = _from
        return self

    @property
    def _from_section(self) -> Optional[Expression]:
        # Rendered FROM clause, or None when no expression was set.
        if self._from is None:
            return None
        from_sql, from_values = self._from.as_tuple()
        return RawExpr(sql.SQL("FROM {}").format(from_sql), from_values)
|
|
||||||
|
|
||||||
|
|
||||||
class ORDER(Enum):
    """Sort directions for ORDER BY clauses."""
    ASC = sql.SQL('ASC')
    DESC = sql.SQL('DESC')
|
|
||||||
|
|
||||||
|
|
||||||
class NULLS(Enum):
    """NULL placement options for ORDER BY clauses."""
    FIRST = sql.SQL('NULLS FIRST')
    LAST = sql.SQL('NULLS LAST')
|
|
||||||
|
|
||||||
|
|
||||||
class OrderMixin(TableQuery[QueryResult]):
    """Mixin adding ORDER BY support to a table query."""
    __slots__ = ()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._order: list[Expression] = []

    def order_by(self, expr: Union[Expression, str], direction: Optional[ORDER] = None, nulls: Optional[NULLS] = None):
        """
        Append one sort key to the query, with optional direction and
        NULL-placement. Repeated calls stack in order.
        """
        if isinstance(expr, Expression):
            sort_sql, sort_values = expr.as_tuple()
        else:
            sort_sql, sort_values = sql.Identifier(expr), ()

        components = [sort_sql]
        # Direction first, then NULLS placement, each only when requested.
        components.extend(
            option.value for option in (direction, nulls) if option is not None
        )

        self._order.append(RawExpr(sql.SQL(' ').join(components), sort_values))
        return self

    @property
    def _order_section(self) -> Optional[Expression]:
        # Combined ORDER BY clause, or None when no sort keys were added.
        if not self._order:
            return None
        expr = RawExpr.join(*self._order, joiner=sql.SQL(', '))
        expr.expr = sql.SQL("ORDER BY {}").format(expr.expr)
        return expr
|
|
||||||
|
|
||||||
|
|
||||||
class GroupMixin(TableQuery[QueryResult]):
    """Mixin adding GROUP BY support to a table query."""
    __slots__ = ()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._group: list[Expression] = []

    def group_by(self, *exprs: Union[Expression, str]):
        """
        Append grouping expression(s) to the query.

        Plain strings are treated as column identifiers. Repeated calls stack.
        """
        self._group.extend(
            expr if isinstance(expr, Expression) else RawExpr(sql.Identifier(expr))
            for expr in exprs
        )
        return self

    @property
    def _group_section(self) -> Optional[Expression]:
        # Combined GROUP BY clause, or None when nothing was grouped.
        if not self._group:
            return None
        expr = RawExpr.join(*self._group, joiner=sql.SQL(', '))
        expr.expr = sql.SQL("GROUP BY {}").format(expr.expr)
        return expr
|
|
||||||
|
|
||||||
|
|
||||||
class Insert(ExtraMixin, TableQuery[QueryResult]):
    """
    Query type representing a table insert query.
    """
    # TODO: Support ON CONFLICT for upserts
    __slots__ = ('_columns', '_values', '_conflict')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._columns: tuple[str, ...] = ()
        self._values: tuple[tuple[Any, ...], ...] = ()
        self._conflict: Optional[Expression] = None

    def insert(self, columns, *values):
        """
        Insert the given data.

        Parameters
        ----------
        columns: tuple[str]
            Tuple of column names to insert.

        values: tuple[tuple[Any, ...], ...]
            Value rows to insert, each corresponding to the columns.

        Raises
        ------
        ValueError
            If no rows are given, or any row's length differs from the
            number of columns.
        """
        if not values:
            raise ValueError("Cannot insert zero rows.")
        # Validate every row, not just the first: a ragged later row would
        # otherwise only fail (confusingly) at execution time.
        if any(len(row) != len(columns) for row in values):
            raise ValueError("Number of columns does not match length of values.")

        self._columns = columns
        self._values = values
        return self

    def on_conflict(self, ignore=False):
        """Configure conflict handling; currently only DO NOTHING is supported."""
        # TODO lots more we can do here
        # Maybe return a Conflict object that can chain itself (not the query)
        if ignore:
            self._conflict = RawExpr(sql.SQL('DO NOTHING'))
        return self

    @property
    def _conflict_section(self) -> Optional[Expression]:
        # Rendered ON CONFLICT clause, or None when not configured.
        if self._conflict is not None:
            e, v = self._conflict.as_tuple()
            expr = RawExpr(
                sql.SQL("ON CONFLICT {}").format(
                    e
                ),
                v
            )
            return expr
        return None

    def build(self):
        """Compose the INSERT ... VALUES ... RETURNING * statement."""
        columns = sql.SQL(',').join(map(sql.Identifier, self._columns))
        # One placeholder group per row: (%s,%s,...)
        single_value_str = sql.SQL('({})').format(
            sql.SQL(',').join(sql.Placeholder() * len(self._columns))
        )
        values_str = sql.SQL(',').join(single_value_str * len(self._values))

        # TODO: Check efficiency of inserting multiple values like this
        # Also implement a Copy query
        base = sql.SQL("INSERT INTO {table} ({columns}) VALUES {values_str}").format(
            table=self.tableid,
            columns=columns,
            values_str=values_str
        )

        sections = [
            RawExpr(base, tuple(chain(*self._values))),
            self._conflict_section,
            self._extra_section,
            RawExpr(sql.SQL('RETURNING *'))
        ]

        sections = (section for section in sections if section is not None)
        return RawExpr.join(*sections)
|
|
||||||
|
|
||||||
|
|
||||||
class Select(WhereMixin, ExtraMixin, OrderMixin, LimitMixin, JoinMixin, GroupMixin, TableQuery[QueryResult]):
    """
    Select rows from a table matching provided conditions.
    """
    __slots__ = ('_columns',)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Selected column expressions; empty means SELECT *.
        self._columns: tuple[Expression, ...] = ()

    def select(self, *columns: str, **exprs: Union[str, sql.Composable, Expression]):
        """
        Set the columns and expressions to select.
        If none are given, selects all columns.

        Positional arguments are plain column names. Keyword arguments are
        aliased expressions (`name=expr` renders as `expr AS name`), given
        as raw SQL strings, Composables, or Expressions. Stacks with any
        previously selected columns.
        """
        cols: List[Expression] = []
        if columns:
            cols.extend(map(RawExpr, map(sql.Identifier, columns)))
        if exprs:
            for name, expr in exprs.items():
                if isinstance(expr, str):
                    # NOTE(review): the raw string is injected as SQL verbatim
                    # (no quoting/escaping) — callers must not pass untrusted input.
                    cols.append(
                        RawExpr(sql.SQL(expr) + sql.SQL(' AS ') + sql.Identifier(name))
                    )
                elif isinstance(expr, sql.Composable):
                    cols.append(
                        RawExpr(expr + sql.SQL(' AS ') + sql.Identifier(name))
                    )
                elif isinstance(expr, Expression):
                    value_expr, value_values = expr.as_tuple()
                    cols.append(RawExpr(
                        value_expr + sql.SQL(' AS ') + sql.Identifier(name),
                        value_values
                    ))
        if cols:
            self._columns = (*self._columns, *cols)
        return self

    def build(self):
        """Compose the full SELECT statement from its optional sections."""
        if not self._columns:
            columns, columns_values = sql.SQL('*'), ()
        else:
            columns, columns_values = RawExpr.join(*self._columns, joiner=sql.SQL(',')).as_tuple()

        base = sql.SQL("SELECT {columns} FROM {table}").format(
            columns=columns,
            table=self.tableid
        )

        sections = [
            RawExpr(base, columns_values),
            self._join_section,
            self._where_section,
            self._group_section,
            self._extra_section,
            self._order_section,
            self._limit_section,
        ]

        # Drop sections which were never configured.
        sections = (section for section in sections if section is not None)
        return RawExpr.join(*sections)
|
|
||||||
|
|
||||||
|
|
||||||
class Delete(WhereMixin, ExtraMixin, TableQuery[QueryResult]):
    """
    Query type representing a table delete query.
    """
    # TODO: Cascade option for delete, maybe other options
    # TODO: Require a where unless specifically disabled, for safety

    def build(self):
        """Assemble the DELETE statement from the configured sections."""
        head = sql.SQL("DELETE FROM {table}").format(
            table=self.tableid,
        )
        parts = (
            RawExpr(head),
            self._where_section,
            self._extra_section,
            RawExpr(sql.SQL('RETURNING *')),
        )
        # Unset sections are None and simply dropped.
        return RawExpr.join(*(part for part in parts if part is not None))
|
|
||||||
|
|
||||||
|
|
||||||
class Update(LimitMixin, WhereMixin, ExtraMixin, FromMixin, TableQuery[QueryResult]):
    """
    Query type representing a table update query.
    """
    __slots__ = (
        '_set',
    )
    # TODO: Again, require a where unless specifically disabled

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Accumulated "column = value" expressions for the SET clause.
        self._set: List[Expression] = []

    def set(self, **column_values: Union[Any, Expression]):
        """
        Add "column = value" assignments for the SET clause.

        Plain values become query placeholders; Expressions are embedded
        directly. Returns ``self`` to allow chaining.
        """
        assignments: List[Expression] = []
        for column, value in column_values.items():
            if isinstance(value, Expression):
                rhs = value.as_tuple()
            else:
                rhs = (sql.Placeholder(), (value,))

            assignments.append(
                RawExpr.join_tuples(
                    (sql.Identifier(column), ()),
                    rhs,
                    joiner=sql.SQL(' = ')
                )
            )
        self._set.extend(assignments)
        return self

    def build(self):
        """Assemble the UPDATE statement from the configured sections."""
        if not self._set:
            raise ValueError("No columns provided to update.")
        set_expr, set_values = RawExpr.join(*self._set, joiner=sql.SQL(', ')).as_tuple()

        head = sql.SQL("UPDATE {table} SET {set}").format(
            table=self.tableid,
            set=set_expr
        )
        parts = (
            RawExpr(head, set_values),
            self._from_section,
            self._where_section,
            self._extra_section,
            self._limit_section,
            RawExpr(sql.SQL('RETURNING *')),
        )
        # Unset sections are None and simply dropped.
        return RawExpr.join(*(part for part in parts if part is not None))
|
|
||||||
|
|
||||||
|
|
||||||
# async def upsert(cursor, table, constraint, **values):
|
|
||||||
# """
|
|
||||||
# Insert or on conflict update.
|
|
||||||
# """
|
|
||||||
# valuedict = values
|
|
||||||
# keys, values = zip(*values.items())
|
|
||||||
#
|
|
||||||
# key_str = _format_insertkeys(keys)
|
|
||||||
# value_str, values = _format_insertvalues(values)
|
|
||||||
# update_key_str, update_key_values = _format_updatestr(valuedict)
|
|
||||||
#
|
|
||||||
# if not isinstance(constraint, str):
|
|
||||||
# constraint = ", ".join(constraint)
|
|
||||||
#
|
|
||||||
# await cursor.execute(
|
|
||||||
# 'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format(
|
|
||||||
# table, key_str, value_str, constraint, update_key_str
|
|
||||||
# ),
|
|
||||||
# tuple((*values, *update_key_values))
|
|
||||||
# )
|
|
||||||
# return await cursor.fetchone()
|
|
||||||
|
|
||||||
|
|
||||||
# def update_many(table, *values, set_keys=None, where_keys=None, cast_row=None, cursor=None):
|
|
||||||
# cursor = cursor or conn.cursor()
|
|
||||||
#
|
|
||||||
# # TODO: executemany or copy syntax now
|
|
||||||
# return execute_values(
|
|
||||||
# cursor,
|
|
||||||
# """
|
|
||||||
# UPDATE {table}
|
|
||||||
# SET {set_clause}
|
|
||||||
# FROM (VALUES {cast_row}%s)
|
|
||||||
# AS {temp_table}
|
|
||||||
# WHERE {where_clause}
|
|
||||||
# RETURNING *
|
|
||||||
# """.format(
|
|
||||||
# table=table,
|
|
||||||
# set_clause=', '.join("{0} = _t.{0}".format(key) for key in set_keys),
|
|
||||||
# cast_row=cast_row + ',' if cast_row else '',
|
|
||||||
# where_clause=' AND '.join("{1}.{0} = _t.{0}".format(key, table) for key in where_keys),
|
|
||||||
# temp_table="_t ({})".format(', '.join(set_keys + where_keys))
|
|
||||||
# ),
|
|
||||||
# values,
|
|
||||||
# fetch=True
|
|
||||||
# )
|
|
||||||
@@ -1,102 +0,0 @@
|
|||||||
from typing import Protocol, runtime_checkable, Optional
|
|
||||||
|
|
||||||
from psycopg import AsyncConnection
|
|
||||||
|
|
||||||
from .connector import Connector, Connectable
|
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
class _Attachable(Connectable, Protocol):
    """Structural type for connectable objects that can attach to a Registry."""

    def attach_to(self, registry: 'Registry'):
        raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class Registry:
    """
    Groups attachable data objects declared on a subclass and binds them
    to a database Connector as a unit.
    """
    # NOTE(review): these are class-level attributes; instances of the same
    # subclass share one attached list — confirm this sharing is intentional.
    _attached: list[_Attachable] = []
    _name: Optional[str] = None

    def __init_subclass__(cls, name=None):
        # Collect every _Attachable declared directly on the subclass body.
        cls._attached = [
            member for member in cls.__dict__.values()
            if isinstance(member, _Attachable)
        ]
        cls._name = name or cls.__name__

    def __init__(self, name=None):
        self._conn: Optional[Connector] = None
        self.name: str = self._name if name is None else name
        if self.name is None:
            raise ValueError("A Registry must have a name!")

        self.init_tasks = []

        for member in self._attached:
            member.attach_to(self)

    def bind(self, connector: Connector):
        """Bind this registry and every attached child to the connector."""
        self._conn = connector
        for child in self._attached:
            child.bind(connector)

    def attach(self, attachable):
        """Attach a new child, binding it immediately if already connected."""
        self._attached.append(attachable)
        if self._conn is not None:
            attachable.bind(self._conn)
        return attachable

    def init_task(self, coro):
        """
        Initialisation tasks are run to setup the registry state.
        These tasks will be run in the event loop, after connection to the database.
        These tasks should be idempotent, as they may be run on reload and reconnect.
        """
        self.init_tasks.append(coro)
        return coro

    async def init(self):
        """Run all registered initialisation tasks, in registration order."""
        for task in self.init_tasks:
            await task(self)
        return self
|
|
||||||
|
|
||||||
|
|
||||||
class AttachableClass:
    """ABC for a default implementation of an Attachable class."""

    # Class-level state: one connector/registry shared by the whole class.
    _connector: Optional[Connector] = None
    _registry: Optional[Registry] = None

    @classmethod
    def bind(cls, connector: Connector):
        # Store the connector and subscribe to its (re)connection hook.
        cls._connector = connector
        connector.connect_hook(cls.on_connect)
        return cls

    @classmethod
    def attach_to(cls, registry: Registry):
        cls._registry = registry
        return cls

    @classmethod
    async def on_connect(cls, connection: AsyncConnection):
        # Hook run on each new database connection; subclasses may override.
        pass
|
|
||||||
|
|
||||||
|
|
||||||
class Attachable:
    """ABC for a default implementation of an Attachable object."""

    def __init__(self, *args, **kwargs):
        # Per-instance connector/registry, set by bind()/attach_to().
        self._connector: Optional[Connector] = None
        self._registry: Optional[Registry] = None

    def bind(self, connector: Connector):
        # Store the connector and subscribe to its (re)connection hook.
        self._connector = connector
        connector.connect_hook(self.on_connect)
        return self

    def attach_to(self, registry: Registry):
        self._registry = registry
        return self

    async def on_connect(self, connection: AsyncConnection):
        # Hook run on each new database connection; subclasses may override.
        pass
|
|
||||||
@@ -1,95 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
from psycopg.rows import DictRow
|
|
||||||
from psycopg import sql
|
|
||||||
|
|
||||||
from . import queries as q
|
|
||||||
from .connector import Connector
|
|
||||||
from .registry import Registry
|
|
||||||
|
|
||||||
|
|
||||||
class Table:
    """
    Transparent interface to a single table structure in the database.
    Contains standard methods to access the table.
    """

    def __init__(self, name, *args, schema='public', **kwargs):
        self.name: str = name
        self.schema: str = schema
        # Set by bind(); None until the table is bound to a connector.
        self.connector: Optional[Connector] = None

    @property
    def identifier(self):
        # Schema-qualified identifier; the default 'public' schema is left implicit.
        if self.schema == 'public':
            return sql.Identifier(self.name)
        else:
            return sql.Identifier(self.schema, self.name)

    def bind(self, connector: Connector):
        self.connector = connector
        return self

    def attach_to(self, registry: Registry):
        self._registry = registry
        return self

    def _many_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
        # Pass through every returned row.
        return data

    def _single_query_adapter(self, *data: DictRow) -> Optional[DictRow]:
        # First returned row, or None if the query returned nothing.
        if data:
            return data[0]
        else:
            return None

    def _delete_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
        # Pass through every deleted row.
        return data

    def select_where(self, *args, **kwargs) -> q.Select[tuple[DictRow, ...]]:
        """Build a Select query on this table with the given conditions."""
        return q.Select(
            self.identifier,
            row_adapter=self._many_query_adapter,
            connector=self.connector
        ).where(*args, **kwargs)

    def select_one_where(self, *args, **kwargs) -> q.Select[DictRow]:
        """Build a Select query returning at most one row."""
        return q.Select(
            self.identifier,
            row_adapter=self._single_query_adapter,
            connector=self.connector
        ).where(*args, **kwargs)

    def update_where(self, *args, **kwargs) -> q.Update[tuple[DictRow, ...]]:
        """Build an Update query on this table with the given conditions."""
        return q.Update(
            self.identifier,
            row_adapter=self._many_query_adapter,
            connector=self.connector
        ).where(*args, **kwargs)

    def delete_where(self, *args, **kwargs) -> q.Delete[tuple[DictRow, ...]]:
        """Build a Delete query on this table with the given conditions."""
        return q.Delete(
            self.identifier,
            row_adapter=self._many_query_adapter,
            connector=self.connector
        ).where(*args, **kwargs)

    def insert(self, **column_values) -> q.Insert[DictRow]:
        """Build an Insert query inserting a single row."""
        return q.Insert(
            self.identifier,
            row_adapter=self._single_query_adapter,
            connector=self.connector
        ).insert(column_values.keys(), column_values.values())

    def insert_many(self, *args, **kwargs) -> q.Insert[tuple[DictRow, ...]]:
        """Build an Insert query inserting multiple rows."""
        return q.Insert(
            self.identifier,
            row_adapter=self._many_query_adapter,
            connector=self.connector
        ).insert(*args, **kwargs)
|
|
||||||
|
|
||||||
# def update_many(self, *args, **kwargs):
|
|
||||||
# with self.conn:
|
|
||||||
# return update_many(self.identifier, *args, **kwargs)
|
|
||||||
|
|
||||||
# def upsert(self, *args, **kwargs):
|
|
||||||
# return upsert(self.identifier, *args, **kwargs)
|
|
||||||
@@ -3,6 +3,7 @@ import logging
|
|||||||
import asyncio
|
import asyncio
|
||||||
from weakref import WeakValueDictionary
|
from weakref import WeakValueDictionary
|
||||||
|
|
||||||
|
from constants import SCHEMA_VERSIONS
|
||||||
import discord
|
import discord
|
||||||
from discord.utils import MISSING
|
from discord.utils import MISSING
|
||||||
from discord.ext.commands import Bot, Cog, HybridCommand, HybridCommandError
|
from discord.ext.commands import Bot, Cog, HybridCommand, HybridCommandError
|
||||||
@@ -13,6 +14,7 @@ from aiohttp import ClientSession
|
|||||||
from data import Database
|
from data import Database
|
||||||
from utils.lib import tabulate
|
from utils.lib import tabulate
|
||||||
from babel.translator import LeoBabel
|
from babel.translator import LeoBabel
|
||||||
|
from botdata import BotData, VersionHistory
|
||||||
|
|
||||||
from .config import Conf
|
from .config import Conf
|
||||||
from .logger import logging_context, log_context, log_action_stack, log_wrap, set_logging_context
|
from .logger import logging_context, log_context, log_action_stack, log_wrap, set_logging_context
|
||||||
@@ -43,6 +45,7 @@ class LionBot(Bot):
|
|||||||
self.appname = appname
|
self.appname = appname
|
||||||
self.shardname = shardname
|
self.shardname = shardname
|
||||||
# self.appdata = appdata
|
# self.appdata = appdata
|
||||||
|
self.data: BotData = db.load_registry(BotData())
|
||||||
self.config = config
|
self.config = config
|
||||||
self.translator = LeoBabel()
|
self.translator = LeoBabel()
|
||||||
|
|
||||||
@@ -53,6 +56,10 @@ class LionBot(Bot):
|
|||||||
self._locks = WeakValueDictionary()
|
self._locks = WeakValueDictionary()
|
||||||
self._running_events = set()
|
self._running_events = set()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dbconn(self):
|
||||||
|
return self.db
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def core(self):
|
def core(self):
|
||||||
return self.get_cog('CoreCog')
|
return self.get_cog('CoreCog')
|
||||||
@@ -129,6 +136,10 @@ class LionBot(Bot):
|
|||||||
await wrapper()
|
await wrapper()
|
||||||
|
|
||||||
async def start(self, token: str, *, reconnect: bool = True):
|
async def start(self, token: str, *, reconnect: bool = True):
|
||||||
|
await self.data.init()
|
||||||
|
for component, req in SCHEMA_VERSIONS.items():
|
||||||
|
await self.version_check(component, req)
|
||||||
|
|
||||||
with logging_context(action="Login"):
|
with logging_context(action="Login"):
|
||||||
start_task = asyncio.create_task(self.login(token))
|
start_task = asyncio.create_task(self.login(token))
|
||||||
await start_task
|
await start_task
|
||||||
@@ -137,6 +148,24 @@ class LionBot(Bot):
|
|||||||
run_task = asyncio.create_task(self.connect(reconnect=reconnect))
|
run_task = asyncio.create_task(self.connect(reconnect=reconnect))
|
||||||
await run_task
|
await run_task
|
||||||
|
|
||||||
|
async def version_check(self, component: str, req_version: int):
|
||||||
|
# Query the database to confirm that the given component is listed with the given version.
|
||||||
|
# Typically done upon loading a component
|
||||||
|
rows = await VersionHistory.fetch_where(component=component).order_by('_timestamp', ORDER.DESC).limit(1)
|
||||||
|
|
||||||
|
version = rows[0].to_version if rows else 0
|
||||||
|
|
||||||
|
if version != req_version:
|
||||||
|
raise ValueError(f"Component {component} failed version check. Has version '{version}', required version '{req_version}'")
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
"Component %s passed version check with version %s",
|
||||||
|
component,
|
||||||
|
version
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def dispatch(self, event_name: str, *args, **kwargs):
|
def dispatch(self, event_name: str, *args, **kwargs):
|
||||||
with logging_context(action=f"Dispatch {event_name}"):
|
with logging_context(action=f"Dispatch {event_name}"):
|
||||||
super().dispatch(event_name, *args, **kwargs)
|
super().dispatch(event_name, *args, **kwargs)
|
||||||
@@ -191,7 +220,7 @@ class LionBot(Bot):
|
|||||||
# TODO: Some of these could have more user-feedback
|
# TODO: Some of these could have more user-feedback
|
||||||
logger.debug(f"Handling command error for {ctx}: {exception}")
|
logger.debug(f"Handling command error for {ctx}: {exception}")
|
||||||
if isinstance(ctx.command, HybridCommand) and ctx.command.app_command:
|
if isinstance(ctx.command, HybridCommand) and ctx.command.app_command:
|
||||||
cmd_str = ctx.command.app_command.to_dict()
|
cmd_str = ctx.command.app_command.to_dict(self.tree)
|
||||||
else:
|
else:
|
||||||
cmd_str = str(ctx.command)
|
cmd_str = str(ctx.command)
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -131,7 +131,7 @@ class LionTree(CommandTree):
|
|||||||
return
|
return
|
||||||
|
|
||||||
set_logging_context(action=f"Run {command.qualified_name}")
|
set_logging_context(action=f"Run {command.qualified_name}")
|
||||||
logger.debug(f"Running command '{command.qualified_name}': {command.to_dict()}")
|
logger.debug(f"Running command '{command.qualified_name}': {command.to_dict(self)}")
|
||||||
try:
|
try:
|
||||||
await command._invoke_with_namespace(interaction, namespace)
|
await command._invoke_with_namespace(interaction, namespace)
|
||||||
except AppCommandError as e:
|
except AppCommandError as e:
|
||||||
|
|||||||
@@ -1,9 +1,6 @@
|
|||||||
this_package = 'modules'
|
this_package = 'modules'
|
||||||
|
|
||||||
active = [
|
active = [
|
||||||
'.sysadmin',
|
|
||||||
'.voicefix',
|
|
||||||
'.streamalerts',
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +0,0 @@
|
|||||||
import logging
|
|
||||||
from meta import LionBot
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
async def setup(bot: LionBot):
|
|
||||||
from .cog import AlertCog
|
|
||||||
await bot.add_cog(AlertCog(bot))
|
|
||||||
@@ -1,609 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import discord
|
|
||||||
from discord.ext import commands as cmds
|
|
||||||
from discord import app_commands as appcmds
|
|
||||||
|
|
||||||
from twitchAPI.twitch import Twitch
|
|
||||||
from twitchAPI.helper import first
|
|
||||||
|
|
||||||
from meta import LionBot, LionCog, LionContext
|
|
||||||
from meta.errors import UserInputError
|
|
||||||
from meta.logger import log_wrap
|
|
||||||
from utils.lib import utc_now
|
|
||||||
from data.conditions import NULL
|
|
||||||
|
|
||||||
from . import logger
|
|
||||||
from .data import AlertsData
|
|
||||||
from .settings import AlertConfig, AlertSettings
|
|
||||||
from .editor import AlertEditorUI
|
|
||||||
|
|
||||||
|
|
||||||
class AlertCog(LionCog):
|
|
||||||
POLL_PERIOD = 60
|
|
||||||
|
|
||||||
def __init__(self, bot: LionBot):
    """
    Cog providing Twitch stream live-alerts.

    Polls the Twitch API for watched streamers and posts/resolves
    alert messages in subscribed Discord channels.
    """
    self.bot = bot
    # Registry of alert-related tables, loaded onto the bot's database.
    self.data = bot.db.load_registry(AlertsData())
    # Twitch API client; created in cog_load via twitch_login().
    self.twitch = None
    self.alert_settings = AlertSettings()

    # Background task running poll_live(); started in cog_load.
    self.poll_task = None
    # Strong references to in-flight event handler tasks.
    self.event_tasks = set()

    # Cache of currently live streams, maps streamerid -> stream
    self.live_streams = {}

    # Cache of streamers we are watching state changes for
    # Map of streamerid -> streamer
    self.watching = {}
|
|
||||||
|
|
||||||
async def cog_load(self):
    """Initialise data, authenticate with Twitch, and start the poll loop."""
    await self.data.init()

    await self.twitch_login()
    await self.load_subs()
    self.poll_task = asyncio.create_task(self.poll_live())
|
|
||||||
|
|
||||||
async def twitch_login(self):
    """Create (or re-create) the authenticated Twitch API client."""
    # TODO: Probably abstract this out to core or a dedicated core cog
    # Also handle refresh tokens
    # Drop any previous client before authenticating a fresh one.
    if self.twitch is not None:
        await self.twitch.close()
        self.twitch = None

    app_id = self.bot.config.twitch['app_id'].strip()
    app_secret = self.bot.config.twitch['app_secret'].strip()
    self.twitch = await Twitch(app_id, app_secret)
|
|
||||||
|
|
||||||
async def load_subs(self):
    """
    Build the streamer watchlist and live-stream cache from stored data.

    Populates ``self.watching`` (streamerid -> streamer row) and
    ``self.live_streams`` (streamerid -> stream row).
    """
    # Load active subscriptions
    active_subs = await self.data.AlertChannel.fetch_where()
    to_watch = {sub.streamerid for sub in active_subs}
    # NOTE(review): this fetches streams with a non-NULL end_at, yet they
    # are treated as "(previously) live" below — confirm the condition
    # is intended (live streams would normally have end_at IS NULL).
    live_streams = await self.data.Stream.fetch_where(
        self.data.Stream.end_at != NULL
    )
    # Bugfix: set.union() returns a new set and the result was previously
    # discarded, so these streamers were never watched. Update in place.
    to_watch.update(stream.streamerid for stream in live_streams)

    # Load associated streamers
    watching = {}
    if to_watch:
        streamers = await self.data.Streamer.fetch_where(
            userid=list(to_watch)
        )
        for streamer in streamers:
            watching[streamer.userid] = streamer

    self.watching = watching
    self.live_streams = {stream.streamerid: stream for stream in live_streams}

    logger.info(
        f"Watching {len(watching)} streamers for state changes. "
        f"Loaded {len(live_streams)} (previously) live streams into cache."
    )
|
|
||||||
|
|
||||||
async def poll_live(self):
    """
    Background loop detecting stream start/stop events.

    Every POLL_PERIOD seconds, requests get_streams for a slice of the
    watched streamers, diffs the result against the live_streams cache,
    and spawns on_stream_start / on_stream_end tasks as required.
    """
    # Every PERIOD seconds,
    # request get_streams for the streamers we are currently watching.
    # Check if they are in the live_stream cache,
    # and update cache and data and fire-and-forget start/stop events as required.
    # TODO: Logging
    # TODO: Error handling so the poll loop doesn't die from temporary errors
    # And when it does die it gets logged properly.
    if not self.twitch:
        raise ValueError("Attempting to start alert poll-loop before twitch set.")

    block_i = 0

    self.polling = True
    while self.polling:
        await asyncio.sleep(self.POLL_PERIOD)

        to_request = list(self.watching.keys())
        if not to_request:
            continue
        # Each loop we request the 'next' slice of 100 userids
        blocks = [to_request[i:i+100] for i in range(0, len(to_request), 100)]
        block_i += 1
        block_i %= len(blocks)
        block = blocks[block_i]

        streaming = {}
        async for stream in self.twitch.get_streams(user_id=block, first=100):
            # Note we set page size to 100
            # So we should never get repeat or missed streams
            # Since we can request a max of 100 userids anyway.
            streaming[stream.user_id] = stream

        # Diff against the cache: newly live and newly ended streamers.
        started = set(streaming.keys()).difference(self.live_streams.keys())
        ended = set(self.live_streams.keys()).difference(streaming.keys())

        for streamerid in started:
            stream = streaming[streamerid]
            stream_data = await self.data.Stream.create(
                streamerid=stream.user_id,
                start_at=stream.started_at,
                twitch_stream_id=stream.id,
                game_name=stream.game_name,
                title=stream.title,
            )
            self.live_streams[streamerid] = stream_data
            # Fire-and-forget; keep a strong reference until the task completes.
            task = asyncio.create_task(self.on_stream_start(stream_data))
            self.event_tasks.add(task)
            task.add_done_callback(self.event_tasks.discard)

        for streamerid in ended:
            stream_data = self.live_streams.pop(streamerid)
            await stream_data.update(end_at=utc_now())
            task = asyncio.create_task(self.on_stream_end(stream_data))
            self.event_tasks.add(task)
            task.add_done_callback(self.event_tasks.discard)
|
|
||||||
|
|
||||||
async def on_stream_start(self, stream_data):
    """
    Handle a stream going live: fulfil every channel subscription
    listening for this streamer.
    """
    # Get channel subscriptions listening for this streamer
    uid = stream_data.streamerid
    logger.info(f"Streamer <uid:{uid}> started streaming! {stream_data=}")
    subbed = await self.data.AlertChannel.fetch_where(streamerid=uid)

    # Fulfill those alerts
    for sub in subbed:
        try:
            # If the sub is paused, don't create the alert
            await self.sub_alert(sub, stream_data)
        except discord.HTTPException:
            # TODO: Needs to be handled more gracefully at user level
            # Retry logic?
            logger.warning(
                f"Could not complete subscription {sub=} for {stream_data=}", exc_info=True
            )
        except Exception:
            logger.exception(
                f"Unexpected exception completing {sub=} for {stream_data=}"
            )
            raise
|
|
||||||
|
|
||||||
async def subscription_error(self, subscription, stream_data, err_msg):
    """
    Handle a subscription fulfill failure.
    Stores the error message for user display,
    and deletes the subscription after some number of errors.
    # TODO
    """
    # Currently only logs; the storage/deletion behaviour above is still TODO.
    logger.warning(
        f"Subscription error {subscription=} {stream_data=} {err_msg=}"
    )
|
|
||||||
|
|
||||||
async def sub_alert(self, subscription, stream_data):
    """
    Fulfil a single subscription for a newly live stream.

    Sends the configured live message to the subscription channel and
    records a StreamAlert row; failures are routed to subscription_error.
    """
    # Base alert behaviour is just to send a message
    # and create an alert row

    channel = self.bot.get_channel(subscription.channelid)
    if channel is None or not isinstance(channel, discord.abc.Messageable):
        # Subscription channel is gone!
        # Or the Discord channel cache died
        await self.subscription_error(
            subscription, stream_data,
            "Subscription channel no longer exists."
        )
        return
    permissions = channel.permissions_for(channel.guild.me)
    if not (permissions.send_messages and permissions.embed_links):
        await self.subscription_error(
            subscription, stream_data,
            "Insufficient permissions to post alert message."
        )
        return

    # Build message
    streamer = await self.data.Streamer.fetch(stream_data.streamerid)
    if not streamer:
        # Streamer was deleted while handling the alert
        # Just quietly ignore
        # Don't error out because the stream data row won't exist anymore
        logger.warning(
            f"Cancelling alert for subscription {subscription.subscriptionid}"
            " because the streamer no longer exists."
        )
        return

    alert_config = AlertConfig(subscription.subscriptionid, subscription)
    paused = alert_config.get(self.alert_settings.AlertPaused.setting_id)
    if paused.value:
        logger.info(f"Skipping alert for subscription {subscription=} because it is paused.")
        return

    live_message = alert_config.get(self.alert_settings.AlertMessage.setting_id)

    # Render the configured message template for this stream/streamer.
    formatter = await live_message.generate_formatter(self.bot, stream_data, streamer)
    formatted = await formatter(live_message.value)
    args = live_message.value_to_args(subscription.subscriptionid, formatted)

    try:
        message = await channel.send(**args.send_args)
    except discord.HTTPException as e:
        logger.warning(
            f"Message send failure while sending streamalert {subscription.subscriptionid}",
            exc_info=True
        )
        await self.subscription_error(
            subscription, stream_data,
            "Failed to post live alert."
        )
        return

    # Store sent alert
    alert = await self.data.StreamAlert.create(
        streamid=stream_data.streamid,
        subscriptionid=subscription.subscriptionid,
        sent_at=utc_now(),
        messageid=message.id
    )
    logger.debug(
        f"Fulfilled subscription {subscription.subscriptionid} with alert {alert.alertid}"
    )
|
|
||||||
|
|
||||||
async def on_stream_end(self, stream_data):
    """
    Handle a stream ending: resolve every channel subscription
    listening for this streamer.
    """
    # Get channel subscriptions listening for this streamer
    uid = stream_data.streamerid
    logger.info(f"Streamer <uid:{uid}> stopped streaming! {stream_data=}")
    subbed = await self.data.AlertChannel.fetch_where(streamerid=uid)

    # Resolve subscriptions
    for sub in subbed:
        try:
            await self.sub_resolve(sub, stream_data)
        except discord.HTTPException:
            # TODO: Needs to be handled more gracefully at user level
            # Retry logic?
            logger.warning(
                f"Could not resolve subscription {sub=} for {stream_data=}", exc_info=True
            )
        except Exception:
            logger.exception(
                f"Unexpected exception resolving {sub=} for {stream_data=}"
            )
            raise
|
|
||||||
|
|
||||||
async def sub_resolve(self, subscription, stream_data):
    """
    Resolve a subscription's alert after its stream ends.

    Depending on configuration, deletes or edits the original alert
    message, then marks the alert row as resolved.
    """
    # Check if there is a current active alert to resolve
    alerts = await self.data.StreamAlert.fetch_where(
        streamid=stream_data.streamid,
        subscriptionid=subscription.subscriptionid,
    )
    if not alerts:
        logger.info(
            f"Resolution requested for subscription {subscription.subscriptionid} with stream {stream_data.streamid} "
            "but no active alerts were found."
        )
        return
    alert = alerts[0]
    if alert.resolved_at is not None:
        # Alert was already resolved
        # This is okay, Twitch might have just sent the stream ending twice
        logger.info(
            f"Resolution requested for subscription {subscription.subscriptionid} with stream {stream_data.streamid} "
            "but alert was already resolved."
        )
        return

    # Check if message is to be deleted or edited (or nothing)
    alert_config = AlertConfig(subscription.subscriptionid, subscription)
    del_setting = alert_config.get(self.alert_settings.AlertEndDelete.setting_id)
    edit_setting = alert_config.get(self.alert_settings.AlertEndMessage.setting_id)

    # Delete takes precedence over edit when both are configured.
    if (delmsg := del_setting.value) or (edit_setting.value):
        # Find the message
        message = None
        channel = self.bot.get_channel(subscription.channelid)
        if channel:
            try:
                message = await channel.fetch_message(alert.messageid)
            except discord.HTTPException:
                # Message was probably deleted already
                # Or permissions were changed
                # Or Discord connection broke
                pass
        else:
            # Channel went after posting the alert
            # Or Discord cache sucks
            # Nothing we can do, just mark it handled
            pass
        if message:
            if delmsg:
                # Delete the message
                try:
                    await message.delete()
                except discord.HTTPException:
                    logger.warning(
                        f"Discord exception while del-resolve live alert {alert=}",
                        exc_info=True
                    )
            else:
                # Edit message with custom arguments
                streamer = await self.data.Streamer.fetch(stream_data.streamerid)
                formatter = await edit_setting.generate_formatter(self.bot, stream_data, streamer)
                formatted = await formatter(edit_setting.value)
                args = edit_setting.value_to_args(subscription.subscriptionid, formatted)
                try:
                    await message.edit(**args.edit_args)
                except discord.HTTPException:
                    logger.warning(
                        f"Discord exception while edit-resolve live alert {alert=}",
                        exc_info=True
                    )
    else:
        # Explicitly don't need to do anything to the alert
        pass

    # Save alert as resolved
    await alert.update(resolved_at=utc_now())
|
|
||||||
|
|
||||||
async def cog_unload(self):
    """Stop the polling loop and release the Twitch client."""
    task = self.poll_task
    if task is not None and not task.cancelled():
        task.cancel()

    if self.twitch is not None:
        await self.twitch.close()
        self.twitch = None
|
|
||||||
|
|
||||||
# ----- Commands -----
|
|
||||||
@cmds.hybrid_group(
    name='streamalert',
    description=(
        "Create and configure stream live-alerts."
    )
)
@cmds.guild_only()
@appcmds.default_permissions(manage_channels=True)
async def streamalert_group(self, ctx: LionContext):
    # Placeholder group, method not used
    # (hybrid_group callbacks are never invoked directly; subcommands attach here)
    raise NotImplementedError
|
|
||||||
|
|
||||||
@streamalert_group.command(
    name='create',
    description=(
        "Subscribe a Discord channel to notifications when a Twitch stream goes live."
    )
)
@appcmds.describe(
    streamer="Name of the twitch channel to watch.",
    channel="Which Discord channel to send live alerts in.",
    message="Custom message to send when the channel goes live (may be edited later)."
)
@appcmds.default_permissions(manage_channels=True)
async def streamalert_create_cmd(self, ctx: LionContext,
                                 streamer: str,
                                 channel: discord.TextChannel,
                                 message: Optional[str]):
    """
    Subscribe `channel` to live alerts for Twitch channel `streamer`.

    Looks up the streamer on Twitch, upserts the streamer row, refuses
    duplicate subscriptions for the same (channel, streamer) pair, seeds
    the optional custom live message, and registers the streamer on the
    poll watchlist.
    """
    # Local import: only needed to seed the optional custom live message.
    import json

    # Type guards
    assert ctx.guild is not None, "Guild-only command has no guild ctx."
    assert self.twitch is not None, "Twitch command run with no twitch obj."

    # Wards
    if not channel.permissions_for(ctx.author).manage_channels:
        await ctx.error_reply(
            "Sorry, you need the `MANAGE_CHANNELS` permission "
            "to add a stream alert to a channel."
        )
        return

    # Look up the specified streamer
    tw_user = await first(self.twitch.get_users(logins=[streamer]))
    if not tw_user:
        await ctx.error_reply(
            f"Sorry, could not find `{streamer}` on Twitch! "
            "Make sure you use the name in their channel url."
        )
        return

    # Create streamer data if it doesn't already exist
    streamer_data = await self.data.Streamer.fetch_or_create(
        tw_user.id,
        login_name=tw_user.login,
        display_name=tw_user.display_name,
    )

    # Duplicate guard: alert_channels has a unique index on
    # (channelid, streamerid), so a second insert would raise from the DB.
    existing = await self.data.AlertChannel.fetch_where(
        channelid=channel.id,
        streamerid=streamer_data.userid,
    )
    if existing:
        await ctx.error_reply(
            f"A stream alert for `{streamer}` already exists in that channel!"
        )
        return

    # Honour the (previously ignored) custom `message` argument by seeding
    # the live_message column; live_message stores JSON message data.
    live_message = json.dumps({'content': message}) if message else None

    # Add subscription to alerts list
    sub_data = await self.data.AlertChannel.create(
        streamerid=streamer_data.userid,
        guildid=channel.guild.id,
        channelid=channel.id,
        created_by=ctx.author.id,
        paused=False,
        live_message=live_message,
    )

    # Add to watchlist so the poll loop picks up the new streamer
    self.watching[streamer_data.userid] = streamer_data

    # Open AlertEditorUI for the new subscription
    # TODO
    await ctx.reply("StreamAlert Created.")
|
|
||||||
|
|
||||||
async def alert_acmpl(self, interaction: discord.Interaction, partial: str):
    """
    Shared autocomplete source for 'alert' options on the streamalert commands.

    Matches `partial` case-insensitively against "<streamer> in #<channel>"
    display strings and returns Choices whose value is the subscriptionid.
    """
    if not interaction.guild:
        raise ValueError("Cannot acmpl alert in guildless interaction.")

    # Get all alerts in the server
    alerts = await self.data.AlertChannel.fetch_where(guildid=interaction.guild_id)

    if not alerts:
        # No alerts available
        options = [
            appcmds.Choice(
                name="No stream alerts are set up in this server!",
                value=partial
            )
        ]
    else:
        options = []
        for alert in alerts:
            streamer = await self.data.Streamer.fetch(alert.streamerid)
            if streamer is None:
                # Should be impossible by foreign key condition
                # Might be a stale cache
                continue
            channel = interaction.guild.get_channel(alert.channelid)
            display = f"{streamer.display_name} in #{channel.name if channel else 'unknown'}"
            if partial.lower() in display.lower():
                # Matching option; choice names are capped at 100 characters.
                options.append(appcmds.Choice(name=display[:100], value=str(alert.subscriptionid)))
        if not options:
            options.append(
                appcmds.Choice(
                    # Choice *names* may be up to 100 characters; the old
                    # [:25] confused the name limit with the 25-choice cap.
                    name=f"No stream alerts matching {partial}"[:100],
                    value=partial
                )
            )
    # Discord accepts at most 25 autocomplete choices per response.
    return options[:25]
|
|
||||||
|
|
||||||
async def resolve_alert(self, interaction: discord.Interaction, alert_str: str):
    """
    Turn an autocomplete-supplied alert string into an AlertChannel row.

    Raises UserInputError when the string is not a subscription id or the
    id does not belong to an alert in the interaction's guild.
    """
    if not interaction.guild:
        raise ValueError("Resolving alert outside of a guild.")
    # Expect alert_str to be the integer subscriptionid
    if not alert_str.isdigit():
        raise UserInputError(
            f"No stream alerts in this server matching `{alert_str}`!"
        )
    row = await self.data.AlertChannel.fetch(int(alert_str))
    # Reject missing rows and rows belonging to another guild alike.
    if not row or row.guildid != interaction.guild_id:
        raise UserInputError(
            "Could not find the selected alert! Please try again."
        )
    return row
|
|
||||||
|
|
||||||
@streamalert_group.command(
    name='edit',
    description=(
        "Update settings for an existing Twitch stream alert."
    )
)
@appcmds.describe(
    alert="Which alert do you want to edit?",
    # TODO: Other settings here
)
@appcmds.default_permissions(manage_channels=True)
async def streamalert_edit_cmd(self, ctx: LionContext, alert: str):
    """Open the interactive editor UI for an existing stream alert."""
    # Type guards
    assert ctx.guild is not None, "Guild-only command has no guild ctx."
    assert self.twitch is not None, "Twitch command run with no twitch obj."
    assert ctx.interaction is not None, "Twitch command needs interaction ctx."

    # Look up provided alert
    subscription = await self.resolve_alert(ctx.interaction, alert)

    # Permission gate: require MANAGE_CHANNELS on the alert's channel,
    # falling back to guild-wide permissions if the channel is gone.
    scope = ctx.guild.get_channel(subscription.channelid) or ctx.guild
    if not scope.permissions_for(ctx.author).manage_channels:
        await ctx.error_reply(
            "Sorry, you need the `MANAGE_CHANNELS` permission "
            "in this channel to edit the stream alert."
        )
        return

    # If edit options have been given, save edits and retouch cache if needed
    # If not, open AlertEditorUI
    editor = AlertEditorUI(bot=self.bot, sub_data=subscription, callerid=ctx.author.id)
    await editor.run(ctx.interaction)
    await editor.wait()
|
|
||||||
|
|
||||||
@streamalert_edit_cmd.autocomplete('alert')
async def streamalert_edit_cmd_alert_acmpl(self, interaction, partial):
    # Delegate to the shared alert autocomplete implementation.
    return await self.alert_acmpl(interaction, partial)
|
|
||||||
|
|
||||||
@streamalert_group.command(
    name='pause',
    description=(
        "Pause a streamalert."
    )
)
@appcmds.describe(
    alert="Which alert do you want to pause?",
)
@appcmds.default_permissions(manage_channels=True)
async def streamalert_pause_cmd(self, ctx: LionContext, alert: str):
    """Pause the selected alert so it stops posting live notifications."""
    # Type guards
    assert ctx.guild is not None, "Guild-only command has no guild ctx."
    assert self.twitch is not None, "Twitch command run with no twitch obj."
    assert ctx.interaction is not None, "Twitch command needs interaction ctx."

    # Look up provided alert
    subscription = await self.resolve_alert(ctx.interaction, alert)

    # Permission gate: MANAGE_CHANNELS on the alert channel, or guild-wide
    # permissions when the channel no longer exists.
    scope = ctx.guild.get_channel(subscription.channelid) or ctx.guild
    if not scope.permissions_for(ctx.author).manage_channels:
        await ctx.error_reply(
            "Sorry, you need the `MANAGE_CHANNELS` permission "
            "in this channel to edit the stream alert."
        )
        return

    await subscription.update(paused=True)
    await ctx.reply("This alert is now paused!")
|
|
||||||
|
|
||||||
@streamalert_group.command(
    name='unpause',
    description=(
        "Resume a streamalert."
    )
)
@appcmds.describe(
    alert="Which alert do you want to unpause?",
)
@appcmds.default_permissions(manage_channels=True)
async def streamalert_unpause_cmd(self, ctx: LionContext, alert: str):
    """Resume a previously paused alert."""
    # Type guards
    assert ctx.guild is not None, "Guild-only command has no guild ctx."
    assert self.twitch is not None, "Twitch command run with no twitch obj."
    assert ctx.interaction is not None, "Twitch command needs interaction ctx."

    # Look up provided alert
    subscription = await self.resolve_alert(ctx.interaction, alert)

    # Permission gate: MANAGE_CHANNELS on the alert channel, or guild-wide
    # permissions when the channel no longer exists.
    scope = ctx.guild.get_channel(subscription.channelid) or ctx.guild
    if not scope.permissions_for(ctx.author).manage_channels:
        await ctx.error_reply(
            "Sorry, you need the `MANAGE_CHANNELS` permission "
            "in this channel to edit the stream alert."
        )
        return

    await subscription.update(paused=False)
    await ctx.reply("This alert has been unpaused!")
|
|
||||||
|
|
||||||
@streamalert_group.command(
    name='remove',
    description=(
        "Deactivate a streamalert entirely (see /streamalert pause to temporarily pause it)."
    )
)
@appcmds.describe(
    alert="Which alert do you want to remove?",
)
@appcmds.default_permissions(manage_channels=True)
async def streamalert_remove_cmd(self, ctx: LionContext, alert: str):
    """Permanently delete the selected stream alert subscription."""
    # Type guards
    assert ctx.guild is not None, "Guild-only command has no guild ctx."
    assert self.twitch is not None, "Twitch command run with no twitch obj."
    assert ctx.interaction is not None, "Twitch command needs interaction ctx."

    # Look up provided alert
    subscription = await self.resolve_alert(ctx.interaction, alert)

    # Permission gate: MANAGE_CHANNELS on the alert channel, or guild-wide
    # permissions when the channel no longer exists.
    scope = ctx.guild.get_channel(subscription.channelid) or ctx.guild
    if not scope.permissions_for(ctx.author).manage_channels:
        await ctx.error_reply(
            "Sorry, you need the `MANAGE_CHANNELS` permission "
            "in this channel to edit the stream alert."
        )
        return

    await subscription.delete()
    await ctx.reply("This alert has been deleted.")
|
|
||||||
@@ -1,105 +0,0 @@
|
|||||||
from data import Registry, RowModel
|
|
||||||
from data.columns import Integer, Bool, Timestamp, String
|
|
||||||
from data.models import WeakCache
|
|
||||||
from cachetools import TTLCache
|
|
||||||
|
|
||||||
|
|
||||||
class AlertsData(Registry):
    """
    Data registry for the stream alerts module.

    Each nested RowModel maps one table; the model docstring records the
    SQL schema it mirrors.
    """
    class Streamer(RowModel):
        """
        A watched Twitch streamer, keyed by Twitch user id.

        Schema
        ------
        CREATE TABLE streamers(
            userid BIGINT PRIMARY KEY,
            login_name TEXT NOT NULL,
            display_name TEXT NOT NULL
        );
        """
        _tablename_ = 'streamers'
        _cache_ = {}

        userid = Integer(primary=True)
        login_name = String()
        display_name = String()

    class AlertChannel(RowModel):
        """
        A subscription binding one streamer to one Discord channel.

        Schema
        ------
        CREATE TABLE alert_channels(
            subscriptionid SERIAL PRIMARY KEY,
            guildid BIGINT NOT NULL,
            channelid BIGINT NOT NULL,
            streamerid BIGINT NOT NULL REFERENCES streamers (userid) ON DELETE CASCADE,
            created_by BIGINT NOT NULL,
            created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
            paused BOOLEAN NOT NULL DEFAULT FALSE,
            end_delete BOOLEAN NOT NULL DEFAULT FALSE,
            live_message TEXT,
            end_message TEXT
        );
        CREATE INDEX alert_channels_guilds ON alert_channels (guildid);
        CREATE UNIQUE INDEX alert_channels_channelid_streamerid ON alert_channels (channelid, streamerid);
        """
        _tablename_ = 'alert_channels'
        _cache_ = {}

        subscriptionid = Integer(primary=True)
        guildid = Integer()
        channelid = Integer()
        streamerid = Integer()
        # NOTE(review): `display_name` has no matching column in the schema
        # above and is typed Integer (unlike Streamer.display_name, a String).
        # Looks like a copy-paste leftover — confirm before relying on it.
        display_name = Integer()
        created_by = Integer()
        created_at = Timestamp()
        paused = Bool()
        end_delete = Bool()
        live_message = String()
        end_message = String()

    class Stream(RowModel):
        """
        One observed live stream session for a streamer.

        Schema
        ------
        CREATE TABLE streams(
            streamid SERIAL PRIMARY KEY,
            streamerid BIGINT NOT NULL REFERENCES streamers (userid) ON DELETE CASCADE,
            start_at TIMESTAMPTZ NOT NULL,
            twitch_stream_id BIGINT,
            game_name TEXT,
            title TEXT,
            end_at TIMESTAMPTZ
        );
        """
        _tablename_ = 'streams'
        # Rows expire from cache after 24h or when no longer referenced.
        _cache_ = WeakCache(TTLCache(maxsize=100, ttl=24*60*60))

        streamid = Integer(primary=True)
        streamerid = Integer()
        start_at = Timestamp()
        twitch_stream_id = Integer()
        game_name = String()
        title = String()
        end_at = Timestamp()

    class StreamAlert(RowModel):
        """
        A live-alert message actually posted for a (stream, subscription) pair.

        Schema
        ------
        CREATE TABLE stream_alerts(
            alertid SERIAL PRIMARY KEY,
            streamid INTEGER NOT NULL REFERENCES streams (streamid) ON DELETE CASCADE,
            subscriptionid INTEGER NOT NULL REFERENCES alert_channels (subscriptionid) ON DELETE CASCADE,
            sent_at TIMESTAMPTZ NOT NULL,
            messageid BIGINT NOT NULL,
            resolved_at TIMESTAMPTZ
        );
        """
        _tablename_ = 'stream_alerts'
        # Rows expire from cache after 24h or when no longer referenced.
        _cache_ = WeakCache(TTLCache(maxsize=1000, ttl=24*60*60))

        alertid = Integer(primary=True)
        streamid = Integer()
        subscriptionid = Integer()
        sent_at = Timestamp()
        messageid = Integer()
        # Unset until the alert has been cleaned up after stream end.
        resolved_at = Timestamp()
|
|
||||||
@@ -1,369 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import datetime as dt
|
|
||||||
from collections import namedtuple
|
|
||||||
from functools import wraps
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import discord
|
|
||||||
from discord.ui.button import button, Button, ButtonStyle
|
|
||||||
from discord.ui.select import select, Select, SelectOption, ChannelSelect
|
|
||||||
|
|
||||||
from meta import LionBot, conf
|
|
||||||
|
|
||||||
from utils.lib import MessageArgs, tabulate, utc_now
|
|
||||||
from utils.ui import MessageUI
|
|
||||||
from utils.ui.msgeditor import MsgEditor
|
|
||||||
|
|
||||||
from .settings import AlertSettings as Settings
|
|
||||||
from .settings import AlertConfig as Config
|
|
||||||
from .data import AlertsData
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .cog import AlertCog
|
|
||||||
|
|
||||||
|
|
||||||
# Lightweight stand-in for AlertsData.Stream rows, used to render message
# previews without hitting the database. Field order mirrors the streams table.
FakeStream = namedtuple(
    'FakeStream',
    ["streamid", "streamerid", "start_at", "twitch_stream_id", "game_name", "title", "end_at"]
)
|
|
||||||
|
|
||||||
|
|
||||||
class AlertEditorUI(MessageUI):
    """
    Interactive editor panel for a single stream alert subscription.

    Displays the current alert configuration in an embed and exposes
    buttons/menus to pause, delete, re-channel, and edit the live and
    stream-end messages. Backed by one AlertsData.AlertChannel row.
    """
    # Settings surfaced by this UI (all registered on AlertConfig).
    setting_classes = (
        Settings.AlertPaused,
        Settings.AlertEndDelete,
        Settings.AlertEndMessage,
        Settings.AlertMessage,
        Settings.AlertChannel,
    )

    def __init__(self, bot: LionBot, sub_data: AlertsData.AlertChannel, **kwargs):
        super().__init__(**kwargs)

        self.bot = bot
        # The subscription row this UI edits, and its id for convenience.
        self.sub_data = sub_data
        self.subid = sub_data.subscriptionid
        self.cog: 'AlertCog' = bot.get_cog('AlertCog')
        # Config wrapper giving per-setting access to the row's columns.
        self.config = Config(self.subid, sub_data)

    # ----- UI API -----
    def preview_stream_data(self):
        """Build a synthetic stream (started 1h ago, ended now) for previews."""
        # TODO: Probably makes sense to factor this out to the cog
        # Or even generate it in the formatters themselves
        data = self.sub_data
        return FakeStream(
            -1,
            data.streamerid,
            utc_now() - dt.timedelta(hours=1),
            -1,
            "Discord Admin",
            "Testing Go Live Message",
            utc_now()
        )

    def call_and_refresh(self, func):
        """
        Generate a wrapper which runs coroutine 'func' and then refreshes the UI.
        """
        # TODO: Check whether the UI has finished interaction
        @wraps(func)
        async def wrapped(*args, **kwargs):
            await func(*args, **kwargs)
            await self.refresh()
        return wrapped

    # ----- UI Components -----

    # Pause button
    # Label is a placeholder; pause_button_refresh sets the real label/style.
    @button(label="PAUSE_PLACEHOLDER", style=ButtonStyle.blurple)
    async def pause_button(self, press: discord.Interaction, pressed: Button):
        # Toggle the paused setting and persist it.
        await press.response.defer(thinking=True, ephemeral=True)
        setting = self.config.get(Settings.AlertPaused.setting_id)
        setting.value = not setting.value
        await setting.write()
        await self.refresh(thinking=press)

    async def pause_button_refresh(self):
        # Reflect the current paused state in the button label/colour.
        # (Local name shadows the imported `button` decorator — intentional.)
        button = self.pause_button
        if self.config.get(Settings.AlertPaused.setting_id).value:
            button.label = "UnPause"
            button.style = ButtonStyle.grey
        else:
            button.label = "Pause"
            button.style = ButtonStyle.green

    # Delete button
    @button(label="Delete Alert", style=ButtonStyle.red)
    async def delete_button(self, press: discord.Interaction, pressed: Button):
        # Delete the subscription row, confirm to the user, and close the UI.
        await press.response.defer(thinking=True, ephemeral=True)
        await self.sub_data.delete()
        embed = discord.Embed(
            colour=discord.Colour.brand_green(),
            description="Stream alert removed."
        )
        await press.edit_original_response(embed=embed)
        await self.close()

    # Close button
    @button(emoji=conf.emojis.cancel, style=ButtonStyle.red)
    async def close_button(self, press: discord.Interaction, pressed: Button):
        await press.response.defer(thinking=False)
        await self.close()

    # Edit Alert button
    @button(label="Edit Alert", style=ButtonStyle.blurple)
    async def edit_alert_button(self, press: discord.Interaction, pressed: Button):
        # Spawn MsgEditor for the live alert
        await press.response.defer(thinking=True, ephemeral=True)

        setting = self.config.get(Settings.AlertMessage.setting_id)

        # Preview data so the editor can render substitution keys.
        stream = self.preview_stream_data()
        streamer = await self.cog.data.Streamer.fetch(self.sub_data.streamerid)

        editor = MsgEditor(
            self.bot,
            setting.value,
            callback=self.call_and_refresh(setting.editor_callback),
            formatter=await setting.generate_formatter(self.bot, stream, streamer),
            callerid=press.user.id
        )
        # Track as a child UI so it is cleaned up with this editor.
        self._slaves.append(editor)
        await editor.run(press)

    # Edit End message
    @button(label="Edit Ending Alert", style=ButtonStyle.blurple)
    async def edit_end_button(self, press: discord.Interaction, pressed: Button):
        # Spawn MsgEditor for the ending alert
        await press.response.defer(thinking=True, ephemeral=True)
        await self.open_end_editor(press)

    async def open_end_editor(self, respond_to: discord.Interaction):
        """Open a MsgEditor for the stream-end message and return it."""
        setting = self.config.get(Settings.AlertEndMessage.setting_id)
        # Start from current live alert data if not set
        if not setting.value:
            alert_setting = self.config.get(Settings.AlertMessage.setting_id)
            setting.value = alert_setting.value

        stream = self.preview_stream_data()
        streamer = await self.cog.data.Streamer.fetch(self.sub_data.streamerid)

        editor = MsgEditor(
            self.bot,
            setting.value,
            callback=self.call_and_refresh(setting.editor_callback),
            formatter=await setting.generate_formatter(self.bot, stream, streamer),
            callerid=respond_to.user.id
        )
        self._slaves.append(editor)
        await editor.run(respond_to)
        return editor

    # Ending Mode Menu
    # Options are placeholders here; ending_mode_menu_refresh builds the real ones.
    @select(
        cls=Select,
        placeholder="Select action to take when the stream ends",
        options=[SelectOption(label="DUMMY")],
        min_values=0, max_values=1
    )
    async def ending_mode_menu(self, selection: discord.Interaction, selected: Select):
        # Values: '0' = do nothing, '1' = delete alert, '2' = edit to custom message.
        if not selected.values:
            await selection.response.defer()
            return

        await selection.response.defer(thinking=True, ephemeral=True)
        value = selected.values[0]

        if value == '0':
            # In Do Nothing case,
            # Ensure Delete is off and custom edit message is unset
            setting = self.config.get(Settings.AlertEndDelete.setting_id)
            if setting.value:
                setting.value = False
                await setting.write()
            setting = self.config.get(Settings.AlertEndMessage.setting_id)
            if setting.value:
                setting.value = None
                await setting.write()

            await self.refresh(thinking=selection)
        elif value == '1':
            # In Delete Alert case,
            # Set the delete setting to True
            setting = self.config.get(Settings.AlertEndDelete.setting_id)
            if not setting.value:
                setting.value = True
                await setting.write()

            await self.refresh(thinking=selection)
        elif value == '2':
            # In Edit Message case,
            # Set the delete setting to False,
            setting = self.config.get(Settings.AlertEndDelete.setting_id)
            if setting.value:
                setting.value = False
                await setting.write()

            # And open the edit message editor
            await self.open_end_editor(selection)
            await self.refresh()

    async def ending_mode_menu_refresh(self):
        # Build menu options
        options = [
            SelectOption(
                label="Do Nothing",
                description="Don't modify the live alert message.",
                value="0",
            ),
            SelectOption(
                label="Delete Alert After Stream",
                description="Delete the live alert message.",
                value="1",
            ),
            SelectOption(
                label="Edit Alert After Stream",
                description="Edit the live alert message to a custom message. Opens editor.",
                value="2",
            ),
        ]

        # Calculate the correct default
        # (delete wins over edit; neither set means "Do Nothing" implicitly)
        if self.config.get(Settings.AlertEndDelete.setting_id).value:
            options[1].default = True
        elif self.config.get(Settings.AlertEndMessage.setting_id).value:
            options[2].default = True

        self.ending_mode_menu.options = options

    # Edit channel menu
    @select(cls=ChannelSelect,
            placeholder="Select Alert Channel",
            channel_types=[discord.ChannelType.text, discord.ChannelType.voice],
            min_values=0, max_values=1)
    async def channel_menu(self, selection: discord.Interaction, selected):
        # Persist the newly selected alert channel, or no-op on empty selection.
        if selected.values:
            await selection.response.defer(thinking=True, ephemeral=True)
            setting = self.config.get(Settings.AlertChannel.setting_id)
            setting.value = selected.values[0]
            await setting.write()
            await self.refresh(thinking=selection)
        else:
            await selection.response.defer(thinking=False)

    async def channel_menu_refresh(self):
        # current = self.config.get(Settings.AlertChannel.setting_id).value
        # TODO: Check if discord-typed menus can have defaults yet
        # Impl in stable dpy, but not released to pip yet
        ...

    # ----- UI Flow -----
    async def make_message(self) -> MessageArgs:
        """Render the editor embed: current settings plus live/end behaviour help."""
        streamer = await self.cog.data.Streamer.fetch(self.sub_data.streamerid)
        if streamer is None:
            raise ValueError("Streamer row does not exist in AlertEditor")
        name = streamer.display_name

        # Build relevant setting table
        table_map = {}
        table_map['Channel'] = self.config.get(Settings.AlertChannel.setting_id).formatted
        table_map['Streamer'] = f"https://www.twitch.tv/{streamer.login_name}"
        table_map['Paused'] = self.config.get(Settings.AlertPaused.setting_id).formatted

        prop_table = '\n'.join(tabulate(*table_map.items()))

        embed = discord.Embed(
            colour=discord.Colour.dark_green(),
            title=f"Stream Alert for {name}",
            description=prop_table,
            timestamp=utc_now()
        )

        # Field documenting the go-live alert and its substitution keys.
        message_setting = self.config.get(Settings.AlertMessage.setting_id)
        message_desc_lines = [
            f"An alert message will be posted to {table_map['Channel']}.",
            f"Press `{self.edit_alert_button.label}`"
            " to preview or edit the alert.",
            "The following keys will be substituted in the alert message."
        ]
        keytable = tabulate(*message_setting._subkey_desc.items())
        for line in keytable:
            message_desc_lines.append(f"> {line}")

        embed.add_field(
            name=f"When {name} goes live",
            value='\n'.join(message_desc_lines),
            inline=False
        )

        # Determine the ending behaviour
        del_setting = self.config.get(Settings.AlertEndDelete.setting_id)
        end_msg_setting = self.config.get(Settings.AlertEndMessage.setting_id)

        if del_setting.value:
            # Deleting
            end_msg_desc = "The live alert message will be deleted."
            ...
        elif end_msg_setting.value:
            # Editing
            # NOTE(review): "will edited" is missing "be" — user-facing typo;
            # it is a runtime string, so fix it in a behaviour change, not here.
            lines = [
                "The live alert message will edited to the configured message.",
                f"Press `{self.edit_end_button.label}` to preview or edit the message.",
                "The following substitution keys are supported "
                "*in addition* to the live alert keys."
            ]
            keytable = tabulate(
                *[(k, v) for k, v in end_msg_setting._subkey_desc.items() if k not in message_setting._subkey_desc]
            )
            for line in keytable:
                lines.append(f"> {line}")
            end_msg_desc = '\n'.join(lines)
        else:
            # Doing nothing
            end_msg_desc = "The live alert message will not be changed."

        embed.add_field(
            name=f"When {name} ends their stream",
            value=end_msg_desc,
            inline=False
        )

        return MessageArgs(embed=embed)

    async def reload(self):
        # Re-fetch the subscription row from the database.
        await self.sub_data.refresh()
        # Note self.config references the sub_data, and doesn't need reloading.

    async def refresh_layout(self):
        """Update dynamic components and arrange rows before (re)display."""
        to_refresh = (
            self.pause_button_refresh(),
            self.channel_menu_refresh(),
            self.ending_mode_menu_refresh(),
        )
        await asyncio.gather(*to_refresh)

        # Only surface the end-message editor button when an end message
        # applies (i.e. not deleting, and a custom end message is set).
        show_end_edit = (
            not self.config.get(Settings.AlertEndDelete.setting_id).value
            and
            self.config.get(Settings.AlertEndMessage.setting_id).value
        )


        if not show_end_edit:
            # Don't show edit end button
            buttons = (
                self.edit_alert_button,
                self.pause_button, self.delete_button, self.close_button
            )
        else:
            buttons = (
                self.edit_alert_button, self.edit_end_button,
                self.pause_button, self.delete_button, self.close_button
            )

        self.set_layout(
            buttons,
            (self.ending_mode_menu,),
            (self.channel_menu,),
        )
|
|
||||||
|
|
||||||
@@ -1,264 +0,0 @@
|
|||||||
from typing import Optional, Any
|
|
||||||
import json
|
|
||||||
|
|
||||||
from meta.LionBot import LionBot
|
|
||||||
from settings import ModelData
|
|
||||||
from settings.groups import SettingGroup, ModelConfig, SettingDotDict
|
|
||||||
from settings.setting_types import BoolSetting, ChannelSetting
|
|
||||||
from core.setting_types import MessageSetting
|
|
||||||
from babel.translator import LocalBabel
|
|
||||||
from utils.lib import recurse_map, replace_multiple, tabulate
|
|
||||||
|
|
||||||
from .data import AlertsData
|
|
||||||
|
|
||||||
|
|
||||||
# Localisation helpers for this module; _p marks strings for translation.
babel = LocalBabel('streamalerts')
_p = babel._p
|
|
||||||
|
|
||||||
|
|
||||||
class AlertConfig(ModelConfig):
    """
    Per-subscription configuration container.

    Settings registered via `register_model_setting` read and write
    columns of the backing AlertsData.AlertChannel row.
    """
    settings = SettingDotDict()
    _model_settings = set()
    model = AlertsData.AlertChannel
|
|
||||||
|
|
||||||
|
|
||||||
class AlertSettings(SettingGroup):
|
|
||||||
@AlertConfig.register_model_setting
|
|
||||||
class AlertMessage(ModelData, MessageSetting):
|
|
||||||
setting_id = 'alert_live_message'
|
|
||||||
_display_name = _p('', 'live_message')
|
|
||||||
|
|
||||||
_desc = _p(
|
|
||||||
'',
|
|
||||||
'Message sent to the channel when the streamer goes live.'
|
|
||||||
)
|
|
||||||
_long_desc = _p(
|
|
||||||
'',
|
|
||||||
'Message sent to the attached channel when the Twitch streamer goes live.'
|
|
||||||
)
|
|
||||||
_accepts = _p('', 'JSON formatted greeting message data')
|
|
||||||
_default = json.dumps({'content': "**{display_name}** just went live at {channel_link}"})
|
|
||||||
|
|
||||||
_model = AlertsData.AlertChannel
|
|
||||||
_column = AlertsData.AlertChannel.live_message.name
|
|
||||||
|
|
||||||
_subkey_desc = {
|
|
||||||
'{display_name}': "Twitch channel name (with capitalisation)",
|
|
||||||
'{login_name}': "Twitch channel login name (as in url)",
|
|
||||||
'{channel_link}': "Link to the live twitch channel",
|
|
||||||
'{stream_start}': "Numeric timestamp when stream went live",
|
|
||||||
}
|
|
||||||
# TODO: More stuff
|
|
||||||
|
|
||||||
@property
|
|
||||||
def update_message(self) -> str:
|
|
||||||
return "The go-live notification message has been updated!"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def generate_formatter(cls, bot: LionBot, stream: AlertsData.Stream, streamer: AlertsData.Streamer, **kwargs):
|
|
||||||
"""
|
|
||||||
Generate a formatter function for this message
|
|
||||||
from the provided stream and streamer data.
|
|
||||||
|
|
||||||
The formatter function accepts and returns a message data dict.
|
|
||||||
"""
|
|
||||||
async def formatter(data_dict: Optional[dict[str, Any]]):
|
|
||||||
if not data_dict:
|
|
||||||
return None
|
|
||||||
|
|
||||||
mapping = {
|
|
||||||
'{display_name}': streamer.display_name,
|
|
||||||
'{login_name}': streamer.login_name,
|
|
||||||
'{channel_link}': f"https://www.twitch.tv/{streamer.login_name}",
|
|
||||||
'{stream_start}': int(stream.start_at.timestamp()),
|
|
||||||
}
|
|
||||||
|
|
||||||
recurse_map(
|
|
||||||
lambda loc, value: replace_multiple(value, mapping) if isinstance(value, str) else value,
|
|
||||||
data_dict,
|
|
||||||
)
|
|
||||||
return data_dict
|
|
||||||
return formatter
|
|
||||||
|
|
||||||
async def editor_callback(self, editor_data):
|
|
||||||
self.value = editor_data
|
|
||||||
await self.write()
|
|
||||||
|
|
||||||
def _desc_table(self, show_value: Optional[str] = None) -> list[tuple[str, str]]:
|
|
||||||
lines = super()._desc_table(show_value=show_value)
|
|
||||||
keytable = tabulate(*self._subkey_desc.items(), colon='')
|
|
||||||
expline = (
|
|
||||||
"The following placeholders will be substituted with their values."
|
|
||||||
)
|
|
||||||
keyfield = (
|
|
||||||
"Placeholders",
|
|
||||||
expline + '\n' + '\n'.join(f"> {line}" for line in keytable)
|
|
||||||
)
|
|
||||||
lines.append(keyfield)
|
|
||||||
return lines
|
|
||||||
|
|
||||||
@AlertConfig.register_model_setting
class AlertEndMessage(ModelData, MessageSetting):
    """
    Custom ending message to edit the live alert to.

    If not set, doesn't edit the alert.
    """
    # Unique identifier for this setting within the AlertConfig registry
    setting_id = 'alert_end_message'
    # User-facing name of the setting
    _display_name = _p('', 'end_message')

    # Short one-line description shown in setting listings
    _desc = _p(
        '',
        'Optional message to edit the live alert with when the stream ends.'
    )
    # Extended description shown in the setting's own UI
    _long_desc = _p(
        '',
        "If set, and `end_delete` is not on, "
        "the live alert will be edited with this custom message "
        "when the stream ends."
    )
    _accepts = _p('', 'JSON formatted greeting message data')
    # None means "do not edit the alert when the stream ends"
    _default = None

    # Backing model and column for persistence (per-alert-channel storage)
    _model = AlertsData.AlertChannel
    _column = AlertsData.AlertChannel.end_message.name

    # Placeholder -> human description, used to document substitutions in the UI.
    # Keys must match the mapping built in generate_formatter below.
    _subkey_desc = {
        '{display_name}': "Twitch channel name (with capitalisation)",
        '{login_name}': "Twitch channel login name (as in url)",
        '{channel_link}': "Link to the live twitch channel",
        '{stream_start}': "Numeric timestamp when stream went live",
        '{stream_end}': "Numeric timestamp when stream ended",
    }

    @property
    def update_message(self) -> str:
        # Feedback string shown to the user after they change this setting
        if self.value:
            return "The stream ending message has been updated."
        else:
            return "The stream ending message has been unset."

    @classmethod
    async def generate_formatter(cls, bot: LionBot, stream: AlertsData.Stream, streamer: AlertsData.Streamer, **kwargs):
        """
        Generate a formatter function for this message
        from the provided stream and streamer data.

        The formatter function accepts and returns a message data dict.
        """
        # TODO: Fake stream data maker (namedtuple?) for previewing
        async def formatter(data_dict: Optional[dict[str, Any]]):
            # A falsy dict means "no message configured"; pass that through.
            if not data_dict:
                return None

            # Concrete values for each documented placeholder
            mapping = {
                '{display_name}': streamer.display_name,
                '{login_name}': streamer.login_name,
                '{channel_link}': f"https://www.twitch.tv/{streamer.login_name}",
                '{stream_start}': int(stream.start_at.timestamp()),
                '{stream_end}': int(stream.end_at.timestamp()),
            }

            # Substitute placeholders in every string value, at any nesting depth.
            # NOTE(review): recurse_map appears to mutate data_dict in place —
            # confirm against its definition.
            recurse_map(
                lambda loc, value: replace_multiple(value, mapping) if isinstance(value, str) else value,
                data_dict,
            )
            return data_dict
        return formatter

    async def editor_callback(self, editor_data):
        # Persist data coming back from the interactive message editor
        self.value = editor_data
        await self.write()

    def _desc_table(self, show_value: Optional[str] = None) -> list[tuple[str, str]]:
        # Extend the standard description table with a placeholder legend.
        lines = super()._desc_table(show_value=show_value)
        keytable = tabulate(*self._subkey_desc.items(), colon='')
        expline = (
            "The following placeholders will be substituted with their values."
        )
        keyfield = (
            "Placeholders",
            expline + '\n' + '\n'.join(f"> {line}" for line in keytable)
        )
        lines.append(keyfield)
        return lines
    # Leftover no-op statement; harmless but could be removed.
    ...
|
|
||||||
|
|
||||||
@AlertConfig.register_model_setting
class AlertEndDelete(ModelData, BoolSetting):
    """
    Whether to delete the live alert after the stream ends.
    """
    # Unique identifier within the AlertConfig registry
    setting_id = 'alert_end_delete'
    _display_name = _p('', 'end_delete')
    _desc = _p(
        '',
        'Whether to delete the live alert after the stream ends.'
    )
    _long_desc = _p(
        '',
        "If enabled, the live alert message will be deleted when the stream ends. "
        "This overrides the `end_message` setting."
    )
    # By default, alerts are kept after the stream ends
    _default = False

    # Backing model and column for persistence
    _model = AlertsData.AlertChannel
    _column = AlertsData.AlertChannel.end_delete.name

    @property
    def update_message(self) -> str:
        # Feedback string shown to the user after changing this setting
        if self.value:
            return "The live alert will be deleted at the end of the stream."
        else:
            return "The live alert will not be deleted when the stream ends."
|
|
||||||
|
|
||||||
@AlertConfig.register_model_setting
class AlertPaused(ModelData, BoolSetting):
    """
    Whether this live alert is currently paused.
    """
    # Unique identifier within the AlertConfig registry
    setting_id = 'alert_paused'
    _display_name = _p('', 'paused')
    _desc = _p(
        '',
        "Whether the alert is currently paused."
    )
    _long_desc = _p(
        '',
        "Paused alerts will not trigger live notifications, "
        "although the streams will still be tracked internally."
    )
    # Alerts start out active
    _default = False

    # Backing model and column for persistence
    _model = AlertsData.AlertChannel
    _column = AlertsData.AlertChannel.paused.name

    @property
    def update_message(self) -> str:
        # Feedback string shown to the user after changing this setting
        if self.value:
            return "This alert is now paused"
        else:
            return "This alert has been unpaused"
|
|
||||||
|
|
||||||
@AlertConfig.register_model_setting
class AlertChannel(ModelData, ChannelSetting):
    """
    The channel associated to this alert.
    """
    # Unique identifier within the AlertConfig registry
    setting_id = 'alert_channel'
    _display_name = _p('', 'channel')
    _desc = _p(
        '',
        "The Discord channel this live alert will be sent in."
    )
    # Short and long descriptions are intentionally the same here
    _long_desc = _desc

    # Note that this cannot actually be None,
    # as there is no UI pathway to unset the setting.
    _default = None

    # Backing model and column for persistence
    _model = AlertsData.AlertChannel
    _column = AlertsData.AlertChannel.channelid.name

    @property
    def update_message(self) -> str:
        # Feedback string shown after the user changes the alert channel
        return f"This alert will now be posted to {self.value.channel.mention}"
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
|
|
||||||
async def setup(bot):
    """Extension entry point: register the Exec cog on the given bot."""
    # Imported lazily so loading this module has no side effects.
    from .exec_cog import Exec

    cog = Exec(bot)
    await bot.add_cog(cog)
|
|
||||||
@@ -1,336 +0,0 @@
|
|||||||
import io
|
|
||||||
import ast
|
|
||||||
import sys
|
|
||||||
import types
|
|
||||||
import asyncio
|
|
||||||
import traceback
|
|
||||||
import builtins
|
|
||||||
import inspect
|
|
||||||
import logging
|
|
||||||
from io import StringIO
|
|
||||||
|
|
||||||
from typing import Callable, Any, Optional
|
|
||||||
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
import discord
|
|
||||||
from discord.ext import commands
|
|
||||||
from discord.ext.commands.errors import CheckFailure
|
|
||||||
from discord.ui import TextInput, View
|
|
||||||
from discord.ui.button import button
|
|
||||||
import discord.app_commands as appcmd
|
|
||||||
|
|
||||||
from meta.logger import logging_context, log_wrap
|
|
||||||
from meta import conf
|
|
||||||
from meta.context import context, ctx_bot
|
|
||||||
from meta.LionContext import LionContext
|
|
||||||
from meta.LionCog import LionCog
|
|
||||||
from meta.LionBot import LionBot
|
|
||||||
|
|
||||||
from utils.ui import FastModal, input
|
|
||||||
|
|
||||||
from wards import sys_admin
|
|
||||||
|
|
||||||
|
|
||||||
def _(arg): return arg
|
|
||||||
|
|
||||||
def _p(ctx, arg): return arg
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class ExecModal(FastModal, title="Execute"):
    """Modal dialog prompting the admin for code to execute."""
    # Multi-line required text input holding the source to run
    code: TextInput = TextInput(
        label="Code to execute",
        style=discord.TextStyle.long,
        required=True
    )
|
|
||||||
|
|
||||||
|
|
||||||
class ExecStyle(Enum):
    """Compilation mode passed through to compile(): statements vs expression."""
    EXEC = 'exec'
    EVAL = 'eval'
|
|
||||||
|
|
||||||
|
|
||||||
class ExecUI(View):
    """
    Interactive admin UI around code execution.

    Prompts for code via a modal (when not supplied up front), runs it
    through `_async`, shows the output, and offers Recompile / Show Source
    buttons on the result message. Only the invoking author may interact.
    """

    def __init__(self, ctx, code=None, style=ExecStyle.EXEC, ephemeral=True) -> None:
        super().__init__()

        self.ctx: LionContext = ctx
        # The interaction currently "owned" by this UI; consumed as responses are sent
        self.interaction: Optional[discord.Interaction] = ctx.interaction
        # Source code to run; None triggers the input modal
        self.code: Optional[str] = code
        self.style: ExecStyle = style
        # Whether result messages should be ephemeral
        self.ephemeral: bool = ephemeral

        self._modal: Optional[ExecModal] = None
        # Non-ephemeral output message, kept so later runs can edit it
        self._msg: Optional[discord.Message] = None

    async def interaction_check(self, interaction: discord.Interaction):
        """Only allow the original author to use this View"""
        if interaction.user.id != self.ctx.author.id:
            await interaction.response.send_message(
                ("You cannot use this interface!"),
                ephemeral=True
            )
            return False
        else:
            return True

    async def run(self):
        # Entry point: either open the input modal or execute immediately.
        if self.code is None:
            if (interaction := self.interaction) is not None:
                # Consume the interaction for the modal response
                self.interaction = None
                await interaction.response.send_modal(self.get_modal())
                await self.wait()
            else:
                # Complain
                # TODO: error_reply
                await self.ctx.reply("Pls give code.")
        else:
            await self.interaction.response.defer(thinking=True, ephemeral=self.ephemeral)
            await self.compile()
            await self.wait()

    @button(label="Recompile")
    async def recompile_button(self, interaction, butt):
        # Interaction response with modal
        await interaction.response.send_modal(self.get_modal())

    @button(label="Show Source")
    async def source_button(self, interaction, butt):
        # Show the last-executed source, as a file if too long for a message.
        if len(self.code) > 1900:
            # Send as file
            with StringIO(self.code) as fp:
                fp.seek(0)
                file = discord.File(fp, filename="source.py")
                await interaction.response.send_message(file=file, ephemeral=True)
        else:
            # Send as message
            await interaction.response.send_message(
                content=f"```py\n{self.code}```",
                ephemeral=True
            )

    def create_modal(self) -> ExecModal:
        # Build a fresh modal whose submit handler re-runs the code.
        modal = ExecModal()

        @modal.submit_callback()
        async def exec_submit(interaction: discord.Interaction):
            # First submission adopts the interaction for the output message
            if self.interaction is None:
                self.interaction = interaction
                await interaction.response.defer(thinking=True)
            else:
                await interaction.response.defer()

            # Set code
            self.code = modal.code.value

            # Call compile
            await self.compile()

        return modal

    def get_modal(self):
        # Always create a new modal, pre-filled with the previous source.
        self._modal = self.create_modal()
        self._modal.code.default = self.code
        return self._modal

    async def compile(self):
        # Call _async
        result = await _async(self.code, style=self.style.value)

        # Display output
        await self.show_output(result)

    async def show_output(self, output):
        # Format output
        # If output message exists and not ephemeral, edit
        # Otherwise, send message, add buttons
        if len(output) > 1900:
            # Send as file
            # NOTE(review): args escapes the `with` block, so the StringIO may be
            # closed before the File is actually uploaded — confirm this works.
            with StringIO(output) as fp:
                fp.seek(0)
                args = {
                    'content': None,
                    'attachments': [discord.File(fp, filename="output.md")]
                }
        else:
            args = {
                'content': f"```md\n{output}```",
                'attachments': []
            }

        if self._msg is None:
            if self.interaction is not None:
                msg = await self.interaction.edit_original_response(**args, view=self)
            else:
                # Send new message
                if args['content'] is None:
                    # Message.reply takes `file`, not `attachments`
                    args['file'] = args.pop('attachments')[0]
                msg = await self.ctx.reply(**args, ephemeral=self.ephemeral, view=self)

            if not self.ephemeral:
                self._msg = msg
        else:
            if self.interaction is not None:
                await self.interaction.edit_original_response(**args, view=self)
            else:
                # Edit message
                await self._msg.edit(**args)
|
|
||||||
|
|
||||||
|
|
||||||
def mk_print(fp: io.StringIO) -> Callable[..., None]:
    """
    Create a drop-in replacement for print() that writes to `fp` by default.

    The returned callable forwards everything to the builtin print(); callers
    may still pass `file=` explicitly to redirect a single call elsewhere.
    """
    def buffered_print(*args, file: Any = fp, **kwargs):
        return print(*args, file=file, **kwargs)

    return buffered_print
|
|
||||||
|
|
||||||
|
|
||||||
def mk_status_printer(bot, printer):
    """
    Build an async helper that fetches the bot's system status,
    prints it through `printer`, and returns it.

    `details=True` requests the full overview; otherwise only the summary.
    """
    async def _status(details=False):
        monitor = bot.system_monitor
        status = await (monitor.get_overview() if details else monitor.get_summary())
        printer(status)
        return status

    return _status
|
|
||||||
|
|
||||||
|
|
||||||
@log_wrap(action="Code Exec")
async def _async(to_eval: str, style='exec'):
    """
    Compile and execute `to_eval` in a scope seeded with the current
    bot/context, capturing all print output and any traceback.

    Parameters
    ----------
    to_eval: source code to run.
    style: 'exec' for statements or 'eval' for a single expression.

    Returns the captured output (stripped), never raises for user-code errors.
    """
    # Only prefix a newline in the log when the source is multi-line
    newline = '\n' * ('\n' in to_eval)
    logger.info(
        f"Exec code with {style}: {newline}{to_eval}"
    )

    output = io.StringIO()
    _print = mk_print(output)

    # Execution scope: all loaded modules + builtins, plus convenience names.
    scope: dict[str, Any] = dict(sys.modules)
    scope['__builtins__'] = builtins
    scope.update(builtins.__dict__)
    scope['ctx'] = ctx = context.get()
    scope['bot'] = ctx_bot.get()
    # print() inside evaluated code writes to the capture buffer
    scope['print'] = _print  # type: ignore
    scope['print_status'] = mk_status_printer(scope['bot'], _print)

    try:
        # Label the compiled code with the triggering message/interaction
        # so tracebacks identify where the code came from.
        if ctx and ctx.message:
            source_str = f"<msg: {ctx.message.id}>"
        elif ctx and ctx.interaction:
            source_str = f"<iid: {ctx.interaction.id}>"
        else:
            source_str = "Unknown async"

        # PyCF_ALLOW_TOP_LEVEL_AWAIT lets the snippet use bare `await`
        code = compile(
            to_eval,
            source_str,
            style,
            ast.PyCF_ALLOW_TOP_LEVEL_AWAIT
        )
        func = types.FunctionType(code, scope)

        ret = func()
        # Top-level await makes the function return a coroutine; run it.
        if inspect.iscoroutine(ret):
            ret = await ret
        if ret is not None:
            _print(repr(ret))
    except Exception:
        # Report user-code failures into the captured output instead of raising
        _, exc, tb = sys.exc_info()
        _print("".join(traceback.format_tb(tb)))
        _print(f"{type(exc).__name__}: {exc}")

    result = output.getvalue().strip()
    newline = '\n' * ('\n' in result)
    logger.info(
        f"Exec complete, output: {newline}{result}"
    )
    return result
|
|
||||||
|
|
||||||
|
|
||||||
class Exec(LionCog):
    """
    Owner-only administration cog: run code, reload extensions, shut down.

    All commands are gated by `cog_check` (sys_admin ward) and the slash
    versions are restricted to the configured admin guilds.
    """
    # Guilds where the admin slash commands are registered
    guild_ids = conf.bot.getintlist('admin_guilds')

    def __init__(self, bot: LionBot):
        self.bot = bot

    async def cog_check(self, ctx: LionContext) -> bool:  # type: ignore
        # Gate every command in this cog behind the sys_admin ward.
        passed = await sys_admin(ctx.bot, ctx.author.id)
        if passed:
            return True
        else:
            raise CheckFailure(
                "You must be a bot owner to do this!"
            )

    @commands.hybrid_command(
        name=_('async'),
        description=_("Execute arbitrary code with Exec")
    )
    @appcmd.describe(
        string="Code to execute.",
    )
    async def async_cmd(self, ctx: LionContext,
                        string: Optional[str] = None,
                        ):
        # Launch the ExecUI; with no code given it will prompt via modal.
        await ExecUI(ctx, string, ExecStyle.EXEC, ephemeral=False).run()

    @commands.hybrid_command(
        name=_('reload'),
        description=_("Reload a given LionBot extension. Launches an ExecUI.")
    )
    @appcmd.describe(
        extension=_("Name of the extension to reload. See autocomplete for options."),
        force=_("Whether to force an extension reload even if it doesn't exist.")
    )
    @appcmd.guilds(*guild_ids)
    async def reload_cmd(self, ctx: LionContext, extension: str, force: Optional[bool] = False):
        """
        This is essentially just a friendly wrapper to reload an extension.
        It is equivalent to running "await bot.reload_extension(extension)" in eval,
        with a slightly nicer interface through the autocomplete and error handling.
        """
        exists = (extension in self.bot.extensions)
        if not (force or exists):
            embed = discord.Embed(description=f"Unknown extension {extension}", colour=discord.Colour.red())
            await ctx.reply(embed=embed)
        else:
            # Uses an ExecUI to simplify error handling and re-execution
            if exists:
                string = f"await bot.reload_extension('{extension}')"
            else:
                # force=True on an unloaded extension loads it fresh instead
                string = f"await bot.load_extension('{extension}')"
            await ExecUI(ctx, string, ExecStyle.EVAL).run()

    @reload_cmd.autocomplete('extension')
    async def reload_extension_acmpl(self, interaction: discord.Interaction, partial: str):
        # Offer currently-loaded extensions matching the typed substring.
        keys = set(self.bot.extensions.keys())
        results = [
            appcmd.Choice(name=key, value=key)
            for key in keys
            if partial.lower() in key.lower()
        ]
        if not results:
            results = [
                appcmd.Choice(name=f"No extensions found matching {partial}", value="None")
            ]
        # Discord caps autocomplete results at 25 choices
        return results[:25]

    @commands.hybrid_command(
        name=_('shutdown'),
        description=_("Shutdown (or restart) the client.")
    )
    @appcmd.guilds(*guild_ids)
    async def shutdown_cmd(self, ctx: LionContext):
        """
        Shutdown the client.
        Maybe do something friendly here?
        """
        logger.info("Shutting down on admin request.")
        await ctx.reply(
            embed=discord.Embed(
                description=f"Understood {ctx.author.mention}, cleaning up and shutting down!",
                colour=discord.Colour.orange()
            )
        )
        await self.bot.close()
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
import logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
async def setup(bot):
    """Extension entry point: register the VoiceFixCog on the given bot."""
    # Lazy import keeps module load free of side effects.
    from .cog import VoiceFixCog

    instance = VoiceFixCog(bot)
    await bot.add_cog(instance)
|
|
||||||
@@ -1,449 +0,0 @@
|
|||||||
from collections import defaultdict
|
|
||||||
from typing import Optional
|
|
||||||
import asyncio
|
|
||||||
from cachetools import FIFOCache
|
|
||||||
|
|
||||||
import discord
|
|
||||||
from discord.abc import GuildChannel
|
|
||||||
from discord.ext import commands as cmds
|
|
||||||
from discord import app_commands as appcmds
|
|
||||||
|
|
||||||
from meta import LionBot, LionCog, LionContext
|
|
||||||
from meta.errors import ResponseTimedOut, SafeCancellation, UserInputError
|
|
||||||
from utils.ui import Confirm
|
|
||||||
|
|
||||||
from . import logger
|
|
||||||
from .data import LinkData
|
|
||||||
|
|
||||||
|
|
||||||
async def prepare_attachments(attachments: list[discord.Attachment]):
    """
    Re-download the given attachments as uploadable File objects.

    Attachments that can no longer be fetched are silently skipped,
    preserving spoiler status on the ones that succeed.
    """
    files = []
    for attachment in attachments:
        try:
            prepared = await attachment.to_file(spoiler=attachment.is_spoiler())
        except discord.HTTPException:
            # Source attachment has expired or is otherwise unfetchable
            continue
        files.append(prepared)
    return files
|
|
||||||
|
|
||||||
|
|
||||||
async def prepare_embeds(message: discord.Message):
    """
    Collect the embeds to mirror for a message.

    Keeps only author-supplied 'rich' embeds (dropping link previews etc.),
    and appends a small grey embed pointing at the replied-to message when
    the source message is a reply.
    """
    mirrored = [e for e in message.embeds if e.type == 'rich']
    ref = message.reference
    if ref:
        mirrored.append(discord.Embed(
            colour=discord.Colour.dark_gray(),
            description=f"Reply to {ref.jump_url}"
        ))
    return mirrored
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class VoiceFixCog(LionCog):
|
|
||||||
    def __init__(self, bot: LionBot):
        """Set up in-memory link/webhook caches; actual data loads in cog_load."""
        self.bot = bot
        self.data = bot.db.load_registry(LinkData())

        # Map of linkids to list of channelids
        self.link_channels = {}

        # Map of channelids to linkids
        self.channel_links = {}

        # Map of channelids to initialised discord.Webhook
        self.hooks = {}

        # Map of messageid to list of (channelid, webhookmsg) pairs, for updates
        self.message_cache = FIFOCache(maxsize=200)
        # webhook msgid -> orig msgid
        self.wmessages = FIFOCache(maxsize=600)

        # Serialises all forwarding/edit/delete handlers to avoid races
        self.lock = asyncio.Lock()
|
|
||||||
|
|
||||||
|
|
||||||
    async def cog_load(self):
        """Initialise the data registry and populate the link caches."""
        await self.data.init()

        await self.reload_links()
|
|
||||||
|
|
||||||
    async def reload_links(self):
        """
        Rebuild the in-memory link maps from the channel_links table.

        Also pre-fetches LinkHook rows and initialises a webhook for every
        linked channel, so message forwarding needs no further queries.
        """
        records = await self.data.channel_links.select_where()
        channel_links = defaultdict(set)
        link_channels = defaultdict(set)

        for record in records:
            linkid = record['linkid']
            channelid = record['channelid']

            channel_links[channelid].add(linkid)
            link_channels[linkid].add(channelid)

        channelids = list(channel_links.keys())
        if channelids:
            # Warm the LinkHook row cache in one query
            await self.data.LinkHook.fetch_where(channelid=channelids)
            for channelid in channelids:
                # Will hit cache, so don't need any more data queries
                await self.fetch_webhook_for(channelid)

        # Freeze the maps as tuples for cheap, safe iteration
        self.channel_links = {cid: tuple(linkids) for cid, linkids in channel_links.items()}
        self.link_channels = {lid: tuple(cids) for lid, cids in link_channels.items()}

        logger.info(
            f"Loaded '{len(link_channels)}' channel links with '{len(self.channel_links)}' linked channels."
        )
|
|
||||||
|
|
||||||
    @LionCog.listener('on_message')
    async def on_message(self, message: discord.Message):
        """
        Mirror a new message into every channel linked with its source channel.

        Webhook-authored messages are ignored — mirrored copies are webhook
        messages themselves, so this also prevents echo loops.
        """
        # Don't need this because everything except explicit messages are webhooks now
        # if self.bot.user and (message.author.id == self.bot.user.id):
        #     return
        if message.webhook_id:
            return

        async with self.lock:
            sent = []
            linkids = self.channel_links.get(message.channel.id, ())
            if linkids:
                for linkid in linkids:
                    for channelid in self.link_channels[linkid]:
                        if channelid != message.channel.id:
                            if message.attachments:
                                files = await prepare_attachments(message.attachments)
                            else:
                                files = []

                            hook = self.hooks[channelid]
                            avatar = message.author.avatar or message.author.default_avatar
                            # Send via webhook so the copy shows the original
                            # author's name and avatar; mentions are disarmed.
                            msg = await hook.send(
                                content=message.content,
                                wait=True,
                                username=message.author.display_name,
                                avatar_url=avatar.url,
                                embeds=await prepare_embeds(message),
                                files=files,
                                allowed_mentions=discord.AllowedMentions.none()
                            )
                            sent.append((channelid, msg))
                            # Remember which original each copy belongs to
                            self.wmessages[msg.id] = message.id
            if sent:
                # For easier lookup
                self.wmessages[message.id] = message.id
                sent.append((message.channel.id, message))

                self.message_cache[message.id] = sent
                logger.info(f"Forwarded message {message.id}")
|
|
||||||
|
|
||||||
|
|
||||||
    @LionCog.listener('on_message_edit')
    async def on_message_edit(self, before, after):
        """Propagate an edit of an original message to all mirrored copies."""
        async with self.lock:
            cached_sent = self.message_cache.pop(before.id, ())
            new_sent = []
            for cid, msg in cached_sent:
                try:
                    # Skip the original itself; only edit the webhook copies
                    if msg.id != before.id:
                        msg = await msg.edit(
                            content=after.content,
                            embeds=await prepare_embeds(after),
                        )
                    new_sent.append((cid, msg))
                except discord.NotFound:
                    # Copy was deleted out from under us; drop it from the cache
                    pass
            if new_sent:
                self.message_cache[after.id] = new_sent
|
|
||||||
|
|
||||||
    @LionCog.listener('on_message_delete')
    async def on_message_delete(self, message):
        """When any copy (or the original) is deleted, delete all the others."""
        async with self.lock:
            # Resolve to the original message id, whichever copy was deleted
            origid = self.wmessages.get(message.id, None)
            if origid:
                cached_sent = self.message_cache.pop(origid, ())
                for _, msg in cached_sent:
                    try:
                        if msg.id != message.id:
                            await msg.delete()
                    except discord.NotFound:
                        # Already gone; nothing to do
                        pass
|
|
||||||
|
|
||||||
    @LionCog.listener('on_reaction_add')
    async def on_reaction_add(self, reaction: discord.Reaction, user: discord.User):
        """Mirror the first addition of each reaction onto all linked copies."""
        async with self.lock:
            message = reaction.message
            emoji = reaction.emoji
            origid = self.wmessages.get(message.id, None)
            # count == 1 means this is the first of its kind; mirroring only
            # then avoids re-triggering when the bot's own reactions land.
            if origid and reaction.count == 1:
                cached_sent = self.message_cache.get(origid, ())
                for _, msg in cached_sent:
                    # TODO: Would be better to have a Message and check the reactions
                    try:
                        if msg.id != message.id:
                            await msg.add_reaction(emoji)
                    except discord.HTTPException:
                        # Copy missing or emoji unusable there; best-effort only
                        pass
|
|
||||||
|
|
||||||
    async def fetch_webhook_for(self, channelid) -> discord.Webhook:
        """
        Return the webhook for a channel, creating and persisting one if needed.

        Checks the in-memory cache first, then the LinkHook table; only when
        neither has an entry is a new webhook created in the channel.

        Raises ValueError if the channel is not visible to the client.
        """
        hook = self.hooks.get(channelid, None)
        if hook is None:
            row = await self.data.LinkHook.fetch(channelid)
            if row is None:
                channel = self.bot.get_channel(channelid)
                if channel is None:
                    raise ValueError("Cannot find channel to create hook.")
                hook = await channel.create_webhook(name="LabRat Channel Link")
                # Persist id+token so the hook survives restarts
                await self.data.LinkHook.create(
                    channelid=channelid,
                    webhookid=hook.id,
                    token=hook.token,
                )
            else:
                # Rebuild the webhook from stored credentials, no API call
                hook = discord.Webhook.partial(row.webhookid, row.token, client=self.bot)
            self.hooks[channelid] = hook
        return hook
|
|
||||||
|
|
||||||
    @cmds.hybrid_group(
        name='linker',
        description="Base command group for the channel linker"
    )
    @appcmds.default_permissions(manage_channels=True)
    async def linker_group(self, ctx: LionContext):
        # Group root carries no behaviour; subcommands attach to it below.
        ...
|
|
||||||
|
|
||||||
    @linker_group.command(
        name='link',
        description="Create a new link, or add a channel to an existing link."
    )
    @appcmds.describe(
        name="Name of the new or existing channel link.",
        channel1="First channel to add to the link.",
        channel2="Second channel to add to the link.",
        channel3="Third channel to add to the link.",
        channel4="Fourth channel to add to the link.",
        channel5="Fifth channel to add to the link.",
        channelid="Optionally add a channel by id (for e.g. cross-server links).",
    )
    async def linker_link(self, ctx: LionContext,
                          name: str,
                          channel1: Optional[discord.TextChannel | discord.VoiceChannel] = None,
                          channel2: Optional[discord.TextChannel | discord.VoiceChannel] = None,
                          channel3: Optional[discord.TextChannel | discord.VoiceChannel] = None,
                          channel4: Optional[discord.TextChannel | discord.VoiceChannel] = None,
                          channel5: Optional[discord.TextChannel | discord.VoiceChannel] = None,
                          channelid: Optional[str] = None,
                          ):
        """
        Create a link called `name` (or extend an existing one, matched
        case-insensitively) to include the given channels, preparing a
        webhook for each channel and persisting the membership rows.
        """
        # Slash-command only; prefix invocations are ignored
        if not ctx.interaction:
            return
        await ctx.interaction.response.defer(thinking=True)

        # Check if link 'name' already exists, create if not
        existing = await self.data.Link.fetch_where()
        link_row = next((row for row in existing if row.name.lower() == name.lower()), None)
        if link_row is None:
            # Create
            link_row = await self.data.Link.create(name=name)
            link_channels = set()
            created = True
        else:
            records = await self.data.channel_links.select_where(linkid=link_row.linkid)
            link_channels = {record['channelid'] for record in records}
            created = False

        # Create webhooks and webhook rows on channels if required
        maybe_channels = [
            channel1, channel2, channel3, channel4, channel5,
        ]
        if channelid and channelid.isdigit():
            # Raw-id path for channels outside this guild
            channel = self.bot.get_channel(int(channelid))
            maybe_channels.append(channel)

        channels = [channel for channel in maybe_channels if channel]
        for channel in channels:
            await self.fetch_webhook_for(channel.id)

        # Insert or update the links
        for channel in channels:
            if channel.id not in link_channels:
                await self.data.channel_links.insert(linkid=link_row.linkid, channelid=channel.id)

        await self.reload_links()

        if created:
            embed = discord.Embed(
                colour=discord.Colour.brand_green(),
                title="Link Created",
                description=(
                    "Created the link **{name}** and linked channels:\n{channels}"
                ).format(name=name, channels=', '.join(channel.mention for channel in channels))
            )
        else:
            # Report the full post-update membership, not just the additions
            channelids = self.link_channels[link_row.linkid]
            channelstr = ', '.join(f"<#{cid}>" for cid in channelids)
            embed = discord.Embed(
                colour=discord.Colour.brand_green(),
                title="Channels Linked",
                description=(
                    "Updated the link **{name}** to link the following channels:\n{channelstr}"
                ).format(name=link_row.name, channelstr=channelstr)
            )
        await ctx.reply(embed=embed)
|
|
||||||
|
|
||||||
    @linker_group.command(
        name='unlink',
        description="Destroy a link, or remove a channel from a link."
    )
    @appcmds.describe(
        name="Name of the link to destroy",
        channel="Channel to remove from the link.",
    )
    async def linker_unlink(self, ctx: LionContext,
                            name: str, channel: Optional[GuildChannel] = None):
        """
        With a channel: remove that channel from the named link.
        Without one: confirm with the user, then destroy the whole link.
        """
        # Slash-command only; prefix invocations are ignored
        if not ctx.interaction:
            return
        # Get the link, error if it doesn't exist
        existing = await self.data.Link.fetch_where()
        link_row = next((row for row in existing if row.name.lower() == name.lower()), None)
        if link_row is None:
            raise UserInputError(
                f"Link **{name}** doesn't exist!"
            )

        link_channelids = self.link_channels.get(link_row.linkid, ())

        if channel is not None:
            # If channel was given, remove channel from link and ack
            if channel.id not in link_channelids:
                raise UserInputError(
                    f"{channel.mention} is not linked in **{link_row.name}**!"
                )
            await self.data.channel_links.delete_where(channelid=channel.id, linkid=link_row.linkid)
            embed = discord.Embed(
                colour=discord.Colour.brand_green(),
                title="Channel Unlinked",
                description=f"{channel.mention} has been removed from **{link_row.name}**."
            )
        else:
            # Otherwise, confirm link destroy, delete link row, and ack
            channels = ', '.join(f"<#{cid}>" for cid in link_channelids)
            confirm = Confirm(
                f"Are you sure you want to remove the link **{link_row.name}**?\nLinked channels: {channels}",
                ctx.author.id,
            )
            confirm.embed.colour = discord.Colour.red()
            try:
                result = await confirm.ask(ctx.interaction)
            except ResponseTimedOut:
                # Treat a timeout as a refusal
                result = False
            if not result:
                raise SafeCancellation

            embed = discord.Embed(
                colour=discord.Colour.brand_green(),
                title="Link removed",
                description=f"Link **{link_row.name}** removed, the following channels were unlinked:\n{channels}"
            )
            await link_row.delete()

        await self.reload_links()
        await ctx.reply(embed=embed)
|
|
||||||
|
|
||||||
@linker_link.autocomplete('name')
async def _acmpl_link_name(self, interaction: discord.Interaction, partial: str):
    """
    Autocomplete an existing link name for the link command.

    Offers every stored link whose name contains the typed fragment
    (case-insensitive). When nothing matches, offers a single choice
    that will create a new link with exactly the typed name.
    """
    existing = await self.data.Link.fetch_where()
    # Case-insensitive substring match against the stored link names.
    # (Removed an unused `names` list that was built and never read.)
    matching = [row.name for row in existing if partial.lower() in row.name.lower()]
    if not matching:
        # No match: offer to create a new link with the typed name.
        choice = appcmds.Choice(
            name=f"Create a new link '{partial}'",
            value=partial
        )
        choices = [choice]
    else:
        choices = [
            appcmds.Choice(
                name=f"Link {name}",
                value=name
            )
            for name in matching
        ]
    return choices
|
|
||||||
|
|
||||||
@linker_unlink.autocomplete('name')
async def _acmpl_unlink_name(self, interaction: discord.Interaction, partial: str):
    """
    Autocomplete an existing link name for the unlink command.

    Suggests every stored link whose name contains the typed fragment,
    case-insensitively. When no link matches, a single placeholder
    choice is returned so the user can see that nothing matched.
    """
    rows = await self.data.Link.fetch_where()
    fragment = partial.lower()
    hits = [row.name for row in rows if fragment in row.name.lower()]
    if hits:
        return [
            appcmds.Choice(name=f"Link {name}", value=name)
            for name in hits
        ]
    # Nothing matched; surface a non-actionable placeholder choice.
    placeholder = appcmds.Choice(
        name=f"No existing links matching '{partial}'",
        value=partial
    )
    return [placeholder]
|
|
||||||
|
|
||||||
@linker_group.command(
    name='links',
    description="Display the existing channel links."
)
async def linker_links(self, ctx: LionContext):
    """
    Display all configured channel links and their member channels.

    Shows a grey "nothing configured" embed when no links exist,
    otherwise one embed field per link listing its channel mentions.
    """
    # Slash-command only; ignore prefix invocations.
    if not ctx.interaction:
        return
    await ctx.interaction.response.defer(thinking=True)

    links = await self.data.Link.fetch_where()

    if not links:
        embed = discord.Embed(
            colour=discord.Colour.light_grey(),
            title="No channel links have been set up!",
            description="Create a new link and add channels with {linker}".format(
                linker=self.bot.core.mention_cmd('linker link')
            )
        )
    else:
        embed = discord.Embed(
            colour=discord.Colour.brand_green(),
            title=f"Channel Links in {ctx.guild.name}",
        )
        for link in links:
            channelids = self.link_channels.get(link.linkid, ())
            channelstr = ', '.join(f"<#{cid}>" for cid in channelids)
            embed.add_field(
                name=f"Link **{link.name}**",
                # Discord rejects embed fields with an empty value, which
                # previously happened for a link with no channels attached.
                value=channelstr or "No channels linked.",
                inline=False
            )
        # TODO: May want paging if over 25 links....
    await ctx.reply(embed=embed)
|
|
||||||
|
|
||||||
@linker_group.command(
    name="webhook",
    description='Manually configure the webhook for a given channel.'
)
async def linker_webhook(self, ctx: LionContext, channel: discord.abc.GuildChannel, webhook: str):
    """
    Set or replace the stored webhook used for the given channel.

    Parses the provided webhook URL, upserts the credentials into the
    channel_webhooks data, and refreshes the in-memory webhook cache.
    """
    # Slash-command only; ignore prefix invocations.
    if not ctx.interaction:
        return

    hook = discord.Webhook.from_url(webhook, client=self.bot)
    row = await self.data.LinkHook.fetch(channel.id)
    if row:
        # A webhook is already registered for this channel; overwrite it.
        await row.update(webhookid=hook.id, token=hook.token)
    else:
        # First webhook registered for this channel.
        await self.data.LinkHook.create(
            channelid=channel.id,
            webhookid=hook.id,
            token=hook.token,
        )
    # Keep the live cache in sync with the stored credentials.
    self.hooks[channel.id] = hook
    await ctx.reply(f"Webhook for {channel.mention} updated!")
|
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
from data import Registry, RowModel, Table
|
|
||||||
from data.columns import Integer, Bool, Timestamp, String
|
|
||||||
|
|
||||||
|
|
||||||
class LinkData(Registry):
    # Data registry for the channel-linker module: link definitions,
    # the channel<->link join table, and per-channel webhook credentials.

    class Link(RowModel):
        """
        A named link grouping channels together.

        Schema
        ------
        CREATE TABLE links(
            linkid SERIAL PRIMARY KEY,
            name TEXT
        );
        """
        _tablename_ = 'links'
        _cache_ = {}

        linkid = Integer(primary=True)
        name = String()

    # Join table associating channels (channelid) with links (linkid).
    channel_links = Table('channel_links')

    class LinkHook(RowModel):
        """
        Stored webhook credentials for a linked channel.

        Schema
        ------
        CREATE TABLE channel_webhooks(
            channelid BIGINT PRIMARY KEY,
            webhookid BIGINT NOT NULL,
            token TEXT NOT NULL
        );
        """
        _tablename_ = 'channel_webhooks'
        _cache_ = {}

        channelid = Integer(primary=True)
        webhookid = Integer()
        token = String()
|
|
||||||
@@ -54,7 +54,10 @@ class MsgEditor(MessageUI):
|
|||||||
By default, uses the provided `formatter` callback (if provided).
|
By default, uses the provided `formatter` callback (if provided).
|
||||||
"""
|
"""
|
||||||
if self._formatter is not None:
|
if self._formatter is not None:
|
||||||
await self._formatter(data)
|
return await self._formatter(data)
|
||||||
|
else:
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
def copy_data(self):
    # Deep-copy the newest history entry so callers can mutate the
    # returned data without corrupting the saved edit history.
    latest = self.history[-1]
    return copy.deepcopy(latest)
|
||||||
@@ -78,7 +81,8 @@ class MsgEditor(MessageUI):
|
|||||||
|
|
||||||
if 'embed' in new_data:
|
if 'embed' in new_data:
|
||||||
try:
|
try:
|
||||||
discord.Embed.from_dict(new_data['embed'])
|
formatted_data = copy.deepcopy(new_data)
|
||||||
|
discord.Embed.from_dict(await self.format_data(formatted_data['embed']))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise UserInputError(
|
raise UserInputError(
|
||||||
t(_p(
|
t(_p(
|
||||||
@@ -445,8 +449,17 @@ class MsgEditor(MessageUI):
|
|||||||
embed_data.pop('footer', None)
|
embed_data.pop('footer', None)
|
||||||
|
|
||||||
if (ts := timestamp_field.value):
|
if (ts := timestamp_field.value):
|
||||||
|
if ts.isdigit():
|
||||||
|
# Treat as UTC timestamp
|
||||||
|
timestamp = dt.datetime.fromtimestamp(int(ts), dt.timezone.utc)
|
||||||
|
ts = timestamp.isoformat()
|
||||||
|
to_validate = ts
|
||||||
|
elif self._formatter:
|
||||||
|
to_validate = await self._formatter(ts)
|
||||||
|
else:
|
||||||
|
to_validate = ts
|
||||||
try:
|
try:
|
||||||
dt.datetime.fromisoformat(ts)
|
dt.datetime.fromisoformat(to_validate)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise UserInputError(
|
raise UserInputError(
|
||||||
t(_p(
|
t(_p(
|
||||||
|
|||||||
Reference in New Issue
Block a user