Compare commits

..

15 Commits

37 changed files with 144 additions and 3833 deletions

2
.gitignore vendored
View File

@@ -147,5 +147,3 @@ dmypy.json
# Cython debug symbols
cython_debug/
config/**

9
.gitmodules vendored Normal file
View File

@@ -0,0 +1,9 @@
[submodule "src/modules/voicefix"]
path = src/modules/voicefix
url = git@github.com:Intery/StudyLion-voicefix.git
[submodule "src/modules/streamalerts"]
path = src/modules/streamalerts
url = git@github.com:Intery/StudyLion-streamalerts.git
[submodule "src/data"]
path = src/data
url = https://git.thewisewolf.dev/HoloTech/psqlmapper.git

1
config/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
*.conf

View File

@@ -1,28 +0,0 @@
[BOT]
prefix = !!
admins =
admin_guilds =
shard_count = 1
ALSO_READ = config/emojis.conf, config/secrets.conf
[LOGGING]
log_file = bot.log
general_log =
error_log = %(general_log)
critical_log = %(general_log)
warning_log = %(general_log)
warning_prefix =
error_prefix =
critical_prefix =
[LOGGING_LEVELS]
root = DEBUG
discord = INFO
discord.http = INFO
discord.gateway = INFO

View File

@@ -1,10 +1,14 @@
BEGIN;
-- Metadata {{{
CREATE TABLE VersionHistory(
version INTEGER NOT NULL,
time TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL,
author TEXT
CREATE TABLE version_history(
component TEXT NOT NULL,
from_version INTEGER NOT NULL,
to_version INTEGER NOT NULL,
author TEXT NOT NULL,
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
INSERT INTO VersionHistory (version, author) VALUES (1, 'Initial Creation');
INSERT INTO version_history (component, from_version, to_version, author) VALUES ('ROOT', 0, 1, 'Initial Creation');
CREATE OR REPLACE FUNCTION update_timestamp_column()
@@ -31,76 +35,8 @@ CREATE TABLE bot_config(
);
-- }}}
-- Channel Linker {{{
CREATE TABLE links(
linkid SERIAL PRIMARY KEY,
name TEXT
);
CREATE TABLE channel_webhooks(
channelid BIGINT PRIMARY KEY,
webhookid BIGINT NOT NULL,
token TEXT NOT NULL
);
CREATE TABLE channel_links(
linkid INTEGER NOT NULL REFERENCES links (linkid) ON DELETE CASCADE,
channelid BIGINT NOT NULL REFERENCES channel_webhooks (channelid) ON DELETE CASCADE,
PRIMARY KEY (linkid, channelid)
);
-- TODO: Profile data
-- }}}
-- Stream Alerts {{{
-- DROP TABLE IF EXISTS stream_alerts;
-- DROP TABLE IF EXISTS streams;
-- DROP TABLE IF EXISTS alert_channels;
-- DROP TABLE IF EXISTS streamers;
CREATE TABLE streamers(
userid BIGINT PRIMARY KEY,
login_name TEXT NOT NULL,
display_name TEXT NOT NULL
);
CREATE TABLE alert_channels(
subscriptionid SERIAL PRIMARY KEY,
guildid BIGINT NOT NULL,
channelid BIGINT NOT NULL,
streamerid BIGINT NOT NULL REFERENCES streamers (userid) ON DELETE CASCADE,
created_by BIGINT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
paused BOOLEAN NOT NULL DEFAULT FALSE,
end_delete BOOLEAN NOT NULL DEFAULT FALSE,
live_message TEXT,
end_message TEXT
);
CREATE INDEX alert_channels_guilds ON alert_channels (guildid);
CREATE UNIQUE INDEX alert_channels_channelid_streamerid ON alert_channels (channelid, streamerid);
CREATE TABLE streams(
streamid SERIAL PRIMARY KEY,
streamerid BIGINT NOT NULL REFERENCES streamers (userid) ON DELETE CASCADE,
start_at TIMESTAMPTZ NOT NULL,
twitch_stream_id BIGINT,
game_name TEXT,
title TEXT,
end_at TIMESTAMPTZ
);
CREATE TABLE stream_alerts(
alertid SERIAL PRIMARY KEY,
streamid INTEGER NOT NULL REFERENCES streams (streamid) ON DELETE CASCADE,
subscriptionid INTEGER NOT NULL REFERENCES alert_channels (subscriptionid) ON DELETE CASCADE,
sent_at TIMESTAMPTZ NOT NULL,
messageid BIGINT NOT NULL,
resolved_at TIMESTAMPTZ
);
-- }}}
COMMIT;
-- vim: set fdm=marker:

View File

@@ -0,0 +1,9 @@
[EMOJIS]
tick = :✅:
clock = :⏱️:
warning = :⚠️:
config = :⚙️:
stats = :📊:
utility = :⏱️:
cancel = :❌:

View File

@@ -0,0 +1,27 @@
[BOT]
prefix = t!
admins = 413668234269818890
admin_guilds = 1265249490063851571
shard_count = 1
ALSO_READ = config/emojis.conf, config/secrets.conf
[LOGGING]
log_file = bot.log
general_log = https://discord.com/api/webhooks/1409394313552593009/5SB_zbzyPa_ccshoe3ePGjCnbT9s6mPfCfpY8P7bL_Zn6vNkeF4CFFbAFEykHZlZl7e8
error_log = %(general_log)s
critical_log = %(general_log)s
warning_log = %(general_log)s
warning_prefix = **WARNING**
error_prefix = **ERROR**
critical_prefix = ***CRITICAL***
[LOGGING_LEVELS]
root = DEBUG
discord = INFO
discord.http = INFO
discord.gateway = INFO

View File

@@ -1,4 +1,4 @@
[STUDYLION]
[BOT]
token =
[DATA]

View File

@@ -1,8 +1,7 @@
aiohttp==3.7.4.post0
cachetools==4.2.2
configparser==5.0.2
aiohttp
cachetools
configparser
discord.py [voice]
iso8601==0.1.16
iso8601
psycopg[pool]
pytz==2021.1
twitchAPI
pytz

View File

@@ -13,8 +13,6 @@ from meta.monitor import ComponentMonitor, StatusLevel, ComponentStatus
from data import Database
from constants import DATA_VERSION
for name in conf.config.options('LOGGING_LEVELS', no_defaults=True):
logging.getLogger(name).setLevel(conf.logging_levels[name])
@@ -57,15 +55,10 @@ async def main():
intents.presences = False
async with db.open():
version = await db.version()
if version.version != DATA_VERSION:
error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate."
logger.critical(error)
raise RuntimeError(error)
async with aiohttp.ClientSession() as session:
async with LionBot(
command_prefix='!leo!',
command_prefix=conf.bot.get('prefix', '!!'),
intents=intents,
appname=appname,
shardname=shardname,
@@ -81,7 +74,7 @@ async def main():
shard_count=sharding.shard_count,
help_command=None,
proxy=conf.bot.get('proxy', None),
chunk_guilds_at_startup=False,
chunk_guilds_at_startup=True,
) as lionbot:
ctx_bot.set(lionbot)
lionbot.system_monitor.add_component(

26
src/botdata.py Normal file
View File

@@ -0,0 +1,26 @@
from data import Registry, RowModel, Table
from data.columns import String, Timestamp, Integer, Bool
class VersionHistory(RowModel):
"""
CREATE TABLE version_history(
component TEXT NOT NULL,
from_version INTEGER NOT NULL,
to_version INTEGER NOT NULL,
author TEXT NOT NULL,
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW(),
);
"""
_tablename_ = 'version_history'
_cache_ = {}
component = String()
from_version = Integer()
to_version = Integer()
author = String()
_timestamp = Timestamp()
class BotData(Registry):
version_history = VersionHistory.table

View File

@@ -1,6 +1,7 @@
CONFIG_FILE = "config/bot.conf"
DATA_VERSION = 1
MAX_COINS = 2147483647 - 1
HINT_ICON = "https://projects.iamcal.com/emoji-data/img-apple-64/1f4a1.png"
SCHEMA_VERSIONS = {
'ROOT': 1,
}

1
src/data Submodule

Submodule src/data added at cfdfe0eb50

View File

@@ -1,9 +0,0 @@
from .conditions import Condition, condition, NULL
from .database import Database
from .models import RowModel, RowTable, WeakCache
from .table import Table
from .base import Expression, RawExpr
from .columns import ColumnExpr, Column, Integer, String
from .registry import Registry, AttachableClass, Attachable
from .adapted import RegisterEnum
from .queries import ORDER, NULLS, JOINTYPE

View File

@@ -1,40 +0,0 @@
# from enum import Enum
from typing import Optional
from psycopg.types.enum import register_enum, EnumInfo
from psycopg import AsyncConnection
from .registry import Attachable, Registry
class RegisterEnum(Attachable):
def __init__(self, enum, name: Optional[str] = None, mapper=None):
super().__init__()
self.enum = enum
self.name = name or enum.__name__
self.mapping = mapper(enum) if mapper is not None else self._mapper()
def _mapper(self):
return {m: m.value[0] for m in self.enum}
def attach_to(self, registry: Registry):
self._registry = registry
registry.init_task(self.on_init)
return self
async def on_init(self, registry: Registry):
connector = registry._conn
if connector is None:
raise ValueError("Cannot initialise without connector!")
connector.connect_hook(self.connection_hook)
# await connector.refresh_pool()
# The below may be somewhat dangerous
# But adaption should never write to the database
await connector.map_over_pool(self.connection_hook)
# if conn := connector.conn:
# # Ensure the adaption is run in the current context as well
# await self.connection_hook(conn)
async def connection_hook(self, conn: AsyncConnection):
info = await EnumInfo.fetch(conn, self.name)
if info is None:
raise ValueError(f"Enum {self.name} not found in database.")
register_enum(info, conn, self.enum, mapping=list(self.mapping.items()))

View File

@@ -1,45 +0,0 @@
from abc import abstractmethod
from typing import Any, Protocol, runtime_checkable
from itertools import chain
from psycopg import sql
@runtime_checkable
class Expression(Protocol):
__slots__ = ()
@abstractmethod
def as_tuple(self) -> tuple[sql.Composable, tuple[Any, ...]]:
raise NotImplementedError
class RawExpr(Expression):
__slots__ = ('expr', 'values')
expr: sql.Composable
values: tuple[Any, ...]
def __init__(self, expr: sql.Composable, values: tuple[Any, ...] = ()):
self.expr = expr
self.values = values
def as_tuple(self):
return (self.expr, self.values)
@classmethod
def join(cls, *expressions: Expression, joiner: sql.SQL = sql.SQL(' ')):
"""
Join a sequence of Expressions into a single RawExpr.
"""
tups = (
expression.as_tuple()
for expression in expressions
)
return cls.join_tuples(*tups, joiner=joiner)
@classmethod
def join_tuples(cls, *tuples: tuple[sql.Composable, tuple[Any, ...]], joiner: sql.SQL = sql.SQL(' ')):
exprs, values = zip(*tuples)
expr = joiner.join(exprs)
value = tuple(chain(*values))
return cls(expr, value)

View File

@@ -1,155 +0,0 @@
from typing import Any, Union, TypeVar, Generic, Type, overload, Optional, TYPE_CHECKING
from psycopg import sql
from datetime import datetime
from .base import RawExpr, Expression
from .conditions import Condition, Joiner
from .table import Table
class ColumnExpr(RawExpr):
__slots__ = ()
def __lt__(self, obj) -> Condition:
expr, values = self.as_tuple()
if isinstance(obj, Expression):
# column < Expression
obj_expr, obj_values = obj.as_tuple()
cond_exprs = (expr, Joiner.LT, obj_expr)
cond_values = (*values, *obj_values)
else:
# column < Literal
cond_exprs = (expr, Joiner.LT, sql.Placeholder())
cond_values = (*values, obj)
return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
def __le__(self, obj) -> Condition:
expr, values = self.as_tuple()
if isinstance(obj, Expression):
# column <= Expression
obj_expr, obj_values = obj.as_tuple()
cond_exprs = (expr, Joiner.LE, obj_expr)
cond_values = (*values, *obj_values)
else:
# column <= Literal
cond_exprs = (expr, Joiner.LE, sql.Placeholder())
cond_values = (*values, obj)
return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
def __eq__(self, obj) -> Condition: # type: ignore[override]
return Condition._expression_equality(self, obj)
def __ne__(self, obj) -> Condition: # type: ignore[override]
return ~(self.__eq__(obj))
def __gt__(self, obj) -> Condition:
return ~(self.__le__(obj))
def __ge__(self, obj) -> Condition:
return ~(self.__lt__(obj))
def __add__(self, obj: Union[Any, Expression]) -> 'ColumnExpr':
if isinstance(obj, Expression):
obj_expr, obj_values = obj.as_tuple()
return ColumnExpr(
sql.SQL("({} + {})").format(self.expr, obj_expr),
(*self.values, *obj_values)
)
else:
return ColumnExpr(
sql.SQL("({} + {})").format(self.expr, sql.Placeholder()),
(*self.values, obj)
)
def __sub__(self, obj) -> 'ColumnExpr':
if isinstance(obj, Expression):
obj_expr, obj_values = obj.as_tuple()
return ColumnExpr(
sql.SQL("({} - {})").format(self.expr, obj_expr),
(*self.values, *obj_values)
)
else:
return ColumnExpr(
sql.SQL("({} - {})").format(self.expr, sql.Placeholder()),
(*self.values, obj)
)
def __mul__(self, obj) -> 'ColumnExpr':
if isinstance(obj, Expression):
obj_expr, obj_values = obj.as_tuple()
return ColumnExpr(
sql.SQL("({} * {})").format(self.expr, obj_expr),
(*self.values, *obj_values)
)
else:
return ColumnExpr(
sql.SQL("({} * {})").format(self.expr, sql.Placeholder()),
(*self.values, obj)
)
def CAST(self, target_type: sql.Composable):
return ColumnExpr(
sql.SQL("({}::{})").format(self.expr, target_type),
self.values
)
T = TypeVar('T')
if TYPE_CHECKING:
from .models import RowModel
class Column(ColumnExpr, Generic[T]):
def __init__(self, name: Optional[str] = None,
primary: bool = False, references: Optional['Column'] = None,
type: Optional[Type[T]] = None):
self.primary = primary
self.references = references
self.name: str = name # type: ignore
self.owner: Optional['RowModel'] = None
self._type = type
self.expr = sql.Identifier(name) if name else sql.SQL('')
self.values = ()
def __set_name__(self, owner, name):
# Only allow setting the owner once
self.name = self.name or name
self.owner = owner
self.expr = sql.Identifier(self.owner._schema_, self.owner._tablename_, self.name)
@overload
def __get__(self: 'Column[T]', obj: None, objtype: "None | Type['RowModel']") -> 'Column[T]':
...
@overload
def __get__(self: 'Column[T]', obj: 'RowModel', objtype: Type['RowModel']) -> T:
...
def __get__(self: 'Column[T]', obj: "RowModel | None", objtype: "Type[RowModel] | None" = None) -> "T | Column[T]":
# Get value from row data or session
if obj is None:
return self
else:
return obj.data[self.name]
class Integer(Column[int]):
pass
class String(Column[str]):
pass
class Bool(Column[bool]):
pass
class Timestamp(Column[datetime]):
pass

View File

@@ -1,214 +0,0 @@
# from meta import sharding
from typing import Any, Union
from enum import Enum
from itertools import chain
from psycopg import sql
from .base import Expression, RawExpr
"""
A Condition is a "logical" database expression, intended for use in Where statements.
Conditions support bitwise logical operators ~, &, |, each producing another Condition.
"""
NULL = None
class Joiner(Enum):
EQUALS = ('=', '!=')
IS = ('IS', 'IS NOT')
LIKE = ('LIKE', 'NOT LIKE')
BETWEEN = ('BETWEEN', 'NOT BETWEEN')
IN = ('IN', 'NOT IN')
LT = ('<', '>=')
LE = ('<=', '>')
NONE = ('', '')
class Condition(Expression):
__slots__ = ('expr1', 'joiner', 'negated', 'expr2', 'values')
def __init__(self,
expr1: sql.Composable, joiner: Joiner = Joiner.NONE, expr2: sql.Composable = sql.SQL(''),
values: tuple[Any, ...] = (), negated=False
):
self.expr1 = expr1
self.joiner = joiner
self.negated = negated
self.expr2 = expr2
self.values = values
def as_tuple(self):
expr = sql.SQL(' ').join((self.expr1, sql.SQL(self.joiner.value[self.negated]), self.expr2))
if self.negated and self.joiner is Joiner.NONE:
expr = sql.SQL("NOT ({})").format(expr)
return (expr, self.values)
@classmethod
def construct(cls, *conditions: 'Condition', **kwargs: Union[Any, Expression]):
"""
Construct a Condition from a sequence of Conditions,
together with some explicit column conditions.
"""
# TODO: Consider adding a _table identifier here so we can identify implicit columns
# Or just require subquery type conditions to always come from modelled tables.
implicit_conditions = (
cls._expression_equality(RawExpr(sql.Identifier(column)), value) for column, value in kwargs.items()
)
return cls._and(*conditions, *implicit_conditions)
@classmethod
def _and(cls, *conditions: 'Condition'):
if not len(conditions):
raise ValueError("Cannot combine 0 Conditions")
if len(conditions) == 1:
return conditions[0]
exprs, values = zip(*(condition.as_tuple() for condition in conditions))
cond_expr = sql.SQL(' AND ').join((sql.SQL('({})').format(expr) for expr in exprs))
cond_values = tuple(chain(*values))
return Condition(cond_expr, values=cond_values)
@classmethod
def _or(cls, *conditions: 'Condition'):
if not len(conditions):
raise ValueError("Cannot combine 0 Conditions")
if len(conditions) == 1:
return conditions[0]
exprs, values = zip(*(condition.as_tuple() for condition in conditions))
cond_expr = sql.SQL(' OR ').join((sql.SQL('({})').format(expr) for expr in exprs))
cond_values = tuple(chain(*values))
return Condition(cond_expr, values=cond_values)
@classmethod
def _not(cls, condition: 'Condition'):
condition.negated = not condition.negated
return condition
@classmethod
def _expression_equality(cls, column: Expression, value: Union[Any, Expression]) -> 'Condition':
# TODO: Check if this supports sbqueries
col_expr, col_values = column.as_tuple()
# TODO: Also support sql.SQL? For joins?
if isinstance(value, Expression):
# column = Expression
value_expr, value_values = value.as_tuple()
cond_exprs = (col_expr, Joiner.EQUALS, value_expr)
cond_values = (*col_values, *value_values)
elif isinstance(value, (tuple, list)):
# column in (...)
# TODO: Support expressions in value tuple?
if not value:
raise ValueError("Cannot create Condition from empty iterable!")
value_expr = sql.SQL('({})').format(sql.SQL(',').join(sql.Placeholder() * len(value)))
cond_exprs = (col_expr, Joiner.IN, value_expr)
cond_values = (*col_values, *value)
elif value is None:
# column IS NULL
cond_exprs = (col_expr, Joiner.IS, sql.NULL)
cond_values = col_values
else:
# column = Literal
cond_exprs = (col_expr, Joiner.EQUALS, sql.Placeholder())
cond_values = (*col_values, value)
return cls(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
def __invert__(self) -> 'Condition':
self.negated = not self.negated
return self
def __and__(self, condition: 'Condition') -> 'Condition':
return self._and(self, condition)
def __or__(self, condition: 'Condition') -> 'Condition':
return self._or(self, condition)
# Helper method to simply condition construction
def condition(*args, **kwargs) -> Condition:
return Condition.construct(*args, **kwargs)
# class NOT(Condition):
# __slots__ = ('value',)
#
# def __init__(self, value):
# self.value = value
#
# def apply(self, key, values, conditions):
# item = self.value
# if isinstance(item, (list, tuple)):
# if item:
# conditions.append("{} NOT IN ({})".format(key, ", ".join([_replace_char] * len(item))))
# values.extend(item)
# else:
# raise ValueError("Cannot check an empty iterable!")
# else:
# conditions.append("{}!={}".format(key, _replace_char))
# values.append(item)
#
#
# class GEQ(Condition):
# __slots__ = ('value',)
#
# def __init__(self, value):
# self.value = value
#
# def apply(self, key, values, conditions):
# item = self.value
# if isinstance(item, (list, tuple)):
# raise ValueError("Cannot apply GEQ condition to a list!")
# else:
# conditions.append("{} >= {}".format(key, _replace_char))
# values.append(item)
#
#
# class LEQ(Condition):
# __slots__ = ('value',)
#
# def __init__(self, value):
# self.value = value
#
# def apply(self, key, values, conditions):
# item = self.value
# if isinstance(item, (list, tuple)):
# raise ValueError("Cannot apply LEQ condition to a list!")
# else:
# conditions.append("{} <= {}".format(key, _replace_char))
# values.append(item)
#
#
# class Constant(Condition):
# __slots__ = ('value',)
#
# def __init__(self, value):
# self.value = value
#
# def apply(self, key, values, conditions):
# conditions.append("{} {}".format(key, self.value))
#
#
# class SHARDID(Condition):
# __slots__ = ('shardid', 'shard_count')
#
# def __init__(self, shardid, shard_count):
# self.shardid = shardid
# self.shard_count = shard_count
#
# def apply(self, key, values, conditions):
# if self.shard_count > 1:
# conditions.append("({} >> 22) %% {} = {}".format(key, self.shard_count, _replace_char))
# values.append(self.shardid)
#
#
# # THIS_SHARD = SHARDID(sharding.shard_number, sharding.shard_count)
#
#
# NULL = Constant('IS NULL')
# NOTNULL = Constant('IS NOT NULL')

View File

@@ -1,135 +0,0 @@
from typing import Protocol, runtime_checkable, Callable, Awaitable, Optional
import logging
from contextvars import ContextVar
from contextlib import asynccontextmanager
import psycopg as psq
from psycopg_pool import AsyncConnectionPool
from psycopg.pq import TransactionStatus
from .cursor import AsyncLoggingCursor
logger = logging.getLogger(__name__)
row_factory = psq.rows.dict_row
ctx_connection: Optional[ContextVar[psq.AsyncConnection]] = ContextVar('connection', default=None)
class Connector:
cursor_factory = AsyncLoggingCursor
def __init__(self, conn_args):
self._conn_args = conn_args
self._conn_kwargs = dict(autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory)
self.pool = self.make_pool()
self.conn_hooks = []
@property
def conn(self) -> Optional[psq.AsyncConnection]:
"""
Convenience property for the current context connection.
"""
return ctx_connection.get()
@conn.setter
def conn(self, conn: psq.AsyncConnection):
"""
Set the contextual connection in the current context.
Always do this in an isolated context!
"""
ctx_connection.set(conn)
def make_pool(self) -> AsyncConnectionPool:
logger.info("Initialising connection pool.", extra={'action': "Pool Init"})
return AsyncConnectionPool(
self._conn_args,
open=False,
min_size=4,
max_size=8,
configure=self._setup_connection,
kwargs=self._conn_kwargs
)
async def refresh_pool(self):
"""
Refresh the pool.
The point of this is to invalidate any existing connections so that the connection set up is run again.
Better ways should be sought (a way to
"""
logger.info("Pool refresh requested, closing and reopening.")
old_pool = self.pool
self.pool = self.make_pool()
await self.pool.open()
logger.info(f"Old pool statistics: {self.pool.get_stats()}")
await old_pool.close()
logger.info("Pool refresh complete.")
async def map_over_pool(self, callable):
"""
Dangerous method to call a method on each connection in the pool.
Utilises private methods of the AsyncConnectionPool.
"""
async with self.pool._lock:
conns = list(self.pool._pool)
while conns:
conn = conns.pop()
try:
await callable(conn)
except Exception:
logger.exception(f"Mapped connection task failed. {callable.__name__}")
@asynccontextmanager
async def open(self):
try:
logger.info("Opening database pool.")
await self.pool.open()
yield
finally:
# May be a different pool!
logger.info(f"Closing database pool. Pool statistics: {self.pool.get_stats()}")
await self.pool.close()
@asynccontextmanager
async def connection(self) -> psq.AsyncConnection:
"""
Asynchronous context manager to get and manage a connection.
If the context connection is set, uses this and does not manage the lifetime.
Otherwise, requests a new connection from the pool and returns it when done.
"""
logger.debug("Database connection requested.", extra={'action': "Data Connect"})
if (conn := self.conn):
yield conn
else:
async with self.pool.connection() as conn:
yield conn
async def _setup_connection(self, conn: psq.AsyncConnection):
logger.debug("Initialising new connection.", extra={'action': "Conn Init"})
for hook in self.conn_hooks:
try:
await hook(conn)
except Exception:
logger.exception("Exception encountered setting up new connection")
return conn
def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]):
"""
Minimal decorator to register a coroutine to run on connect or reconnect.
Note that these are only run on connect and reconnect.
If a hook is registered after connection, it will not be run.
"""
self.conn_hooks.append(coro)
return coro
@runtime_checkable
class Connectable(Protocol):
def bind(self, connector: Connector):
raise NotImplementedError

View File

@@ -1,42 +0,0 @@
import logging
from typing import Optional
from psycopg import AsyncCursor, sql
from psycopg.abc import Query, Params
from psycopg._encodings import pgconn_encoding
logger = logging.getLogger(__name__)
class AsyncLoggingCursor(AsyncCursor):
def mogrify_query(self, query: Query):
if isinstance(query, str):
msg = query
elif isinstance(query, (sql.SQL, sql.Composed)):
msg = query.as_string(self)
elif isinstance(query, bytes):
msg = query.decode(pgconn_encoding(self._conn.pgconn), 'replace')
else:
msg = repr(query)
return msg
async def execute(self, query: Query, params: Optional[Params] = None, **kwargs):
if logging.DEBUG >= logger.getEffectiveLevel():
msg = self.mogrify_query(query)
logger.debug(
"Executing query (%s) with values %s", msg, params,
extra={'action': "Query Execute"}
)
try:
return await super().execute(query, params=params, **kwargs)
except Exception:
msg = self.mogrify_query(query)
logger.exception(
"Exception during query execution. Query (%s) with parameters %s.",
msg, params,
extra={'action': "Query Execute"},
stack_info=True
)
else:
# TODO: Possibly log execution time
pass

View File

@@ -1,47 +0,0 @@
from typing import TypeVar
import logging
from collections import namedtuple
# from .cursor import AsyncLoggingCursor
from .registry import Registry
from .connector import Connector
logger = logging.getLogger(__name__)
Version = namedtuple('Version', ('version', 'time', 'author'))
T = TypeVar('T', bound=Registry)
class Database(Connector):
# cursor_factory = AsyncLoggingCursor
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.registries: dict[str, Registry] = {}
def load_registry(self, registry: T) -> T:
logger.debug(
f"Loading and binding registry '{registry.name}'.",
extra={'action': f"Reg {registry.name}"}
)
registry.bind(self)
self.registries[registry.name] = registry
return registry
async def version(self) -> Version:
"""
Return the current schema version as a Version namedtuple.
"""
async with self.connection() as conn:
async with conn.cursor() as cursor:
# Get last entry in version table, compare against desired version
await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
row = await cursor.fetchone()
if row:
return Version(row['version'], row['time'], row['author'])
else:
# No versions in the database
return Version(-1, None, None)

View File

@@ -1,323 +0,0 @@
from typing import TypeVar, Type, Optional, Generic, Union
# from typing_extensions import Self
from weakref import WeakValueDictionary
from collections.abc import MutableMapping
from psycopg.rows import DictRow
from .table import Table
from .columns import Column
from . import queries as q
from .connector import Connector
from .registry import Registry
RowT = TypeVar('RowT', bound='RowModel')
class MISSING:
__slots__ = ('oid',)
def __init__(self, oid):
self.oid = oid
class RowTable(Table, Generic[RowT]):
__slots__ = (
'model',
)
def __init__(self, name, model: Type[RowT], **kwargs):
super().__init__(name, **kwargs)
self.model = model
@property
def columns(self):
return self.model._columns_
@property
def id_col(self):
return self.model._key_
@property
def row_cache(self):
return self.model._cache_
def _many_query_adapter(self, *data):
self.model._make_rows(*data)
return data
def _single_query_adapter(self, *data):
if data:
self.model._make_rows(*data)
return data[0]
else:
return None
def _delete_query_adapter(self, *data):
self.model._delete_rows(*data)
return data
# New methods to fetch and create rows
async def create_row(self, *args, **kwargs) -> RowT:
data = await super().insert(*args, **kwargs)
return self.model._make_rows(data)[0]
def fetch_rows_where(self, *args, **kwargs) -> q.Select[list[RowT]]:
# TODO: Handle list of rowids here?
return q.Select(
self.identifier,
row_adapter=self.model._make_rows,
connector=self.connector
).where(*args, **kwargs)
WK = TypeVar('WK')
WV = TypeVar('WV')
class WeakCache(Generic[WK, WV], MutableMapping[WK, WV]):
def __init__(self, ref_cache):
self.ref_cache = ref_cache
self.weak_cache = WeakValueDictionary()
def __getitem__(self, key):
value = self.weak_cache[key]
self.ref_cache[key] = value
return value
def __setitem__(self, key, value):
self.weak_cache[key] = value
self.ref_cache[key] = value
def __delitem__(self, key):
del self.weak_cache[key]
try:
del self.ref_cache[key]
except KeyError:
pass
def __contains__(self, key):
return key in self.weak_cache
def __iter__(self):
return iter(self.weak_cache)
def __len__(self):
return len(self.weak_cache)
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def pop(self, key, default=None):
if key in self:
value = self[key]
del self[key]
else:
value = default
return value
# TODO: Implement getitem and setitem, for dynamic column access
class RowModel:
__slots__ = ('data',)
_schema_: str = 'public'
_tablename_: Optional[str] = None
_columns_: dict[str, Column] = {}
# Cache to keep track of registered Rows
_cache_: Union[dict, WeakValueDictionary, WeakCache] = None # type: ignore
_key_: tuple[str, ...] = ()
_connector: Optional[Connector] = None
_registry: Optional[Registry] = None
# TODO: Proper typing for a classvariable which gets dynamically assigned in subclass
table: RowTable = None
def __init_subclass__(cls: Type[RowT], table: Optional[str] = None):
"""
Set table, _columns_, and _key_.
"""
if table is not None:
cls._tablename_ = table
if cls._tablename_ is not None:
columns = {}
for key, value in cls.__dict__.items():
if isinstance(value, Column):
columns[key] = value
cls._columns_ = columns
if not cls._key_:
cls._key_ = tuple(column.name for column in columns.values() if column.primary)
cls.table = RowTable(cls._tablename_, cls, schema=cls._schema_)
if cls._cache_ is None:
cls._cache_ = WeakValueDictionary()
def __new__(cls, data):
# Registry pattern.
# Ensure each rowid always refers to a single Model instance
if data is not None:
rowid = cls._id_from_data(data)
cache = cls._cache_
if (row := cache.get(rowid, None)) is not None:
obj = row
else:
obj = cache[rowid] = super().__new__(cls)
else:
obj = super().__new__(cls)
return obj
@classmethod
def as_tuple(cls):
return (cls.table.identifier, ())
def __init__(self, data):
self.data = data
def __getitem__(self, key):
return self.data[key]
def __setitem__(self, key, value):
self.data[key] = value
@classmethod
def bind(cls, connector: Connector):
if cls.table is None:
raise ValueError("Cannot bind abstract RowModel")
cls._connector = connector
cls.table.bind(connector)
return cls
@classmethod
def attach_to(cls, registry: Registry):
cls._registry = registry
return cls
@property
def _dict_(self):
return {key: self.data[key] for key in self._key_}
@property
def _rowid_(self):
return tuple(self.data[key] for key in self._key_)
def __repr__(self):
return "{}.{}({})".format(
self.table.schema,
self.table.name,
', '.join(repr(column.__get__(self)) for column in self._columns_.values())
)
@classmethod
def _id_from_data(cls, data):
return tuple(data[key] for key in cls._key_)
@classmethod
def _dict_from_id(cls, rowid):
return dict(zip(cls._key_, rowid))
@classmethod
def _make_rows(cls: Type[RowT], *data_rows: DictRow) -> list[RowT]:
"""
Create or retrieve Row objects for each provided data row.
If the rows already exist in cache, updates the cached row.
"""
# TODO: Handle partial row data here somehow?
rows = [cls(data_row) for data_row in data_rows]
return rows
@classmethod
def _delete_rows(cls, *data_rows):
"""
Remove the given rows from cache, if they exist.
May be extended to handle object deletion.
"""
cache = cls._cache_
for data_row in data_rows:
rowid = cls._id_from_data(data_row)
cache.pop(rowid, None)
@classmethod
async def create(cls: Type[RowT], *args, **kwargs) -> RowT:
return await cls.table.create_row(*args, **kwargs)
@classmethod
def fetch_where(cls: Type[RowT], *args, **kwargs):
return cls.table.fetch_rows_where(*args, **kwargs)
@classmethod
async def fetch(cls: Type[RowT], *rowid, cached=True) -> Optional[RowT]:
"""
Fetch the row with the given id, retrieving from cache where possible.
"""
row = cls._cache_.get(rowid, None) if cached else None
if row is None:
rows = await cls.fetch_where(**cls._dict_from_id(rowid))
row = rows[0] if rows else None
if row is None:
cls._cache_[rowid] = cls(None)
elif row.data is None:
row = None
return row
@classmethod
async def fetch_or_create(cls, *rowid, **kwargs):
"""
Helper method to fetch a row with the given id or fields, or create it if it doesn't exist.
"""
if rowid:
row = await cls.fetch(*rowid)
else:
rows = await cls.fetch_where(**kwargs).limit(1)
row = rows[0] if rows else None
if row is None:
creation_kwargs = kwargs
if rowid:
creation_kwargs.update(cls._dict_from_id(rowid))
row = await cls.create(**creation_kwargs)
return row
async def refresh(self: RowT) -> Optional[RowT]:
"""
Refresh this Row from data.
The return value may be `None` if the row was deleted.
"""
rows = await self.table.select_where(**self._dict_)
if not rows:
return None
else:
self.data = rows[0]
return self
async def update(self: RowT, **values) -> Optional[RowT]:
"""
Update this Row with the given values.
Internally passes the provided `values` to the `update` Query.
The return value may be `None` if the row was deleted.
"""
data = await self.table.update_where(**self._dict_).set(**values).with_adapter(self._make_rows)
if not data:
return None
else:
return data[0]
async def delete(self: RowT) -> Optional[RowT]:
"""
Delete this Row.
"""
data = await self.table.delete_where(**self._dict_).with_adapter(self._delete_rows)
return data[0] if data is not None else None

View File

@@ -1,644 +0,0 @@
from typing import Optional, TypeVar, Any, Callable, Generic, List, Union
from enum import Enum
from itertools import chain
from psycopg import AsyncConnection, AsyncCursor
from psycopg import sql
from psycopg.rows import DictRow
import logging
from .conditions import Condition
from .base import Expression, RawExpr
from .connector import Connector
logger = logging.getLogger(__name__)
TQueryT = TypeVar('TQueryT', bound='TableQuery')
SQueryT = TypeVar('SQueryT', bound='Select')
QueryResult = TypeVar('QueryResult')
class Query(Generic[QueryResult]):
"""
ABC for an executable query statement.
"""
__slots__ = ('conn', 'cursor', '_adapter', 'connector', 'result')
_adapter: Callable[..., QueryResult]
def __init__(self, *args, row_adapter=None, connector=None, conn=None, cursor=None, **kwargs):
self.connector: Optional[Connector] = connector
self.conn: Optional[AsyncConnection] = conn
self.cursor: Optional[AsyncCursor] = cursor
if row_adapter is not None:
self._adapter = row_adapter
else:
self._adapter = self._no_adapter
self.result: Optional[QueryResult] = None
def bind(self, connector: Connector):
self.connector = connector
return self
def with_cursor(self, cursor: AsyncCursor):
self.cursor = cursor
return self
def with_connection(self, conn: AsyncConnection):
self.conn = conn
return self
def _no_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
return data
def with_adapter(self, callable: Callable[..., QueryResult]):
# NOTE: Postcomposition functor, Query[QR2] = (QR1 -> QR2) o Query[QR1]
# For this to work cleanly, callable should have arg type of QR1, not any
self._adapter = callable
return self
def with_no_adapter(self):
"""
Sets the adapater to the identity.
"""
self._adapter = self._no_adapter
return self
def one(self):
# TODO: Postcomposition with item functor, Query[List[QR1]] -> Query[QR1]
return self
def build(self) -> Expression:
raise NotImplementedError
async def _execute(self, cursor: AsyncCursor) -> QueryResult:
query, values = self.build().as_tuple()
# TODO: Move logging out to a custom cursor
# logger.debug(
# f"Executing query ({query.as_string(cursor)}) with values {values}",
# extra={'action': "Query"}
# )
await cursor.execute(sql.Composed((query,)), values)
data = await cursor.fetchall()
self.result = self._adapter(*data)
return self.result
async def execute(self, cursor=None) -> QueryResult:
"""
Execute the query, optionally with the provided cursor, and return the result rows.
If no cursor is provided, and no cursor has been set with `with_cursor`,
the execution will create a new cursor from the connection and close it automatically.
"""
# Create a cursor if possible
cursor = cursor if cursor is not None else self.cursor
if self.cursor is None:
if self.conn is None:
if self.connector is None:
raise ValueError("Cannot execute query without cursor, connection, or connector.")
else:
async with self.connector.connection() as conn:
async with conn.cursor() as cursor:
data = await self._execute(cursor)
else:
async with self.conn.cursor() as cursor:
data = await self._execute(cursor)
else:
data = await self._execute(cursor)
return data
def __await__(self):
return self.execute().__await__()
class TableQuery(Query[QueryResult]):
"""
ABC for an executable query statement expected to be run on a single table.
"""
__slots__ = (
'tableid',
'condition', '_extra', '_limit', '_order', '_joins', '_from', '_group'
)
def __init__(self, tableid, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tableid: sql.Identifier = tableid
def options(self, **kwargs):
"""
Set some query options.
Default implementation does nothing.
Should be overridden to provide specific options.
"""
return self
class WhereMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.condition: Optional[Condition] = None
def where(self, *args: Condition, **kwargs):
"""
Add a Condition to the query.
Position arguments should be Conditions,
and keyword arguments should be of the form `column=Value`,
where Value may be a Value-type or a literal value.
All provided Conditions will be and-ed together to create a new Condition.
TODO: Maybe just pass this verbatim to a condition.
"""
if args or kwargs:
condition = Condition.construct(*args, **kwargs)
if self.condition is not None:
condition = self.condition & condition
self.condition = condition
return self
@property
def _where_section(self) -> Optional[Expression]:
if self.condition is not None:
return RawExpr.join_tuples((sql.SQL('WHERE'), ()), self.condition.as_tuple())
else:
return None
class JOINTYPE(Enum):
LEFT = sql.SQL('LEFT JOIN')
RIGHT = sql.SQL('RIGHT JOIN')
INNER = sql.SQL('INNER JOIN')
OUTER = sql.SQL('OUTER JOIN')
FULLOUTER = sql.SQL('FULL OUTER JOIN')
class JoinMixin(TableQuery[QueryResult]):
__slots__ = ()
# TODO: Remember to add join slots to TableQuery
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._joins: list[Expression] = []
def join(self,
target: Union[str, Expression],
on: Optional[Condition] = None, using: Optional[Union[Expression, tuple[str, ...]]] = None,
join_type: JOINTYPE = JOINTYPE.INNER,
natural=False):
available = (on is not None) + (using is not None) + natural
if available == 0:
raise ValueError("No conditions given for Query Join")
if available > 1:
raise ValueError("Exactly one join format must be given for Query Join")
sections: list[tuple[sql.Composable, tuple[Any, ...]]] = [(join_type.value, ())]
if isinstance(target, str):
sections.append((sql.Identifier(target), ()))
else:
sections.append(target.as_tuple())
if on is not None:
sections.append((sql.SQL('ON'), ()))
sections.append(on.as_tuple())
elif using is not None:
sections.append((sql.SQL('USING'), ()))
if isinstance(using, Expression):
sections.append(using.as_tuple())
elif isinstance(using, tuple) and len(using) > 0 and isinstance(using[0], str):
cols = sql.SQL("({})").format(sql.SQL(',').join(sql.Identifier(col) for col in using))
sections.append((cols, ()))
else:
raise ValueError("Unrecognised 'using' type.")
elif natural:
sections.insert(0, (sql.SQL('NATURAL'), ()))
expr = RawExpr.join_tuples(*sections)
self._joins.append(expr)
return self
def leftjoin(self, *args, **kwargs):
return self.join(*args, join_type=JOINTYPE.LEFT, **kwargs)
@property
def _join_section(self) -> Optional[Expression]:
if self._joins:
return RawExpr.join(*self._joins)
else:
return None
class ExtraMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._extra: Optional[Expression] = None
def extra(self, extra: sql.Composable, values: tuple[Any, ...] = ()):
"""
Add an extra string, and optionally values, to this query.
The extra string is inserted after any condition, and before the limit.
"""
extra_expr = RawExpr(extra, values)
if self._extra is not None:
extra_expr = RawExpr.join(self._extra, extra_expr)
self._extra = extra_expr
return self
@property
def _extra_section(self) -> Optional[Expression]:
if self._extra is None:
return None
else:
return self._extra
class LimitMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._limit: Optional[int] = None
def limit(self, limit: int):
"""
Add a limit to this query.
"""
self._limit = limit
return self
@property
def _limit_section(self) -> Optional[Expression]:
if self._limit is not None:
return RawExpr(sql.SQL("LIMIT {}").format(sql.Placeholder()), (self._limit,))
else:
return None
class FromMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._from: Optional[Expression] = None
def from_expr(self, _from: Expression):
self._from = _from
return self
@property
def _from_section(self) -> Optional[Expression]:
if self._from is not None:
expr, values = self._from.as_tuple()
return RawExpr(sql.SQL("FROM {}").format(expr), values)
else:
return None
class ORDER(Enum):
ASC = sql.SQL('ASC')
DESC = sql.SQL('DESC')
class NULLS(Enum):
FIRST = sql.SQL('NULLS FIRST')
LAST = sql.SQL('NULLS LAST')
class OrderMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._order: list[Expression] = []
def order_by(self, expr: Union[Expression, str], direction: Optional[ORDER] = None, nulls: Optional[NULLS] = None):
"""
Add a single sort expression to the query.
This method stacks.
"""
if isinstance(expr, Expression):
string, values = expr.as_tuple()
else:
string = sql.Identifier(expr)
values = ()
parts = [string]
if direction is not None:
parts.append(direction.value)
if nulls is not None:
parts.append(nulls.value)
order_string = sql.SQL(' ').join(parts)
self._order.append(RawExpr(order_string, values))
return self
@property
def _order_section(self) -> Optional[Expression]:
if self._order:
expr = RawExpr.join(*self._order, joiner=sql.SQL(', '))
expr.expr = sql.SQL("ORDER BY {}").format(expr.expr)
return expr
else:
return None
class GroupMixin(TableQuery[QueryResult]):
__slots__ = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._group: list[Expression] = []
def group_by(self, *exprs: Union[Expression, str]):
"""
Add a group expression(s) to the query.
This method stacks.
"""
for expr in exprs:
if isinstance(expr, Expression):
self._group.append(expr)
else:
self._group.append(RawExpr(sql.Identifier(expr)))
return self
@property
def _group_section(self) -> Optional[Expression]:
if self._group:
expr = RawExpr.join(*self._group, joiner=sql.SQL(', '))
expr.expr = sql.SQL("GROUP BY {}").format(expr.expr)
return expr
else:
return None
class Insert(ExtraMixin, TableQuery[QueryResult]):
"""
Query type representing a table insert query.
"""
# TODO: Support ON CONFLICT for upserts
__slots__ = ('_columns', '_values', '_conflict')
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._columns: tuple[str, ...] = ()
self._values: tuple[tuple[Any, ...], ...] = ()
self._conflict: Optional[Expression] = None
def insert(self, columns, *values):
"""
Insert the given data.
Parameters
----------
columns: tuple[str]
Tuple of column names to insert.
values: tuple[tuple[Any, ...], ...]
Tuple of values to insert, corresponding to the columns.
"""
if not values:
raise ValueError("Cannot insert zero rows.")
if len(values[0]) != len(columns):
raise ValueError("Number of columns does not match length of values.")
self._columns = columns
self._values = values
return self
def on_conflict(self, ignore=False):
# TODO lots more we can do here
# Maybe return a Conflict object that can chain itself (not the query)
if ignore:
self._conflict = RawExpr(sql.SQL('DO NOTHING'))
return self
@property
def _conflict_section(self) -> Optional[Expression]:
if self._conflict is not None:
e, v = self._conflict.as_tuple()
expr = RawExpr(
sql.SQL("ON CONFLICT {}").format(
e
),
v
)
return expr
return None
def build(self):
columns = sql.SQL(',').join(map(sql.Identifier, self._columns))
single_value_str = sql.SQL('({})').format(
sql.SQL(',').join(sql.Placeholder() * len(self._columns))
)
values_str = sql.SQL(',').join(single_value_str * len(self._values))
# TODO: Check efficiency of inserting multiple values like this
# Also implement a Copy query
base = sql.SQL("INSERT INTO {table} ({columns}) VALUES {values_str}").format(
table=self.tableid,
columns=columns,
values_str=values_str
)
sections = [
RawExpr(base, tuple(chain(*self._values))),
self._conflict_section,
self._extra_section,
RawExpr(sql.SQL('RETURNING *'))
]
sections = (section for section in sections if section is not None)
return RawExpr.join(*sections)
class Select(WhereMixin, ExtraMixin, OrderMixin, LimitMixin, JoinMixin, GroupMixin, TableQuery[QueryResult]):
"""
Select rows from a table matching provided conditions.
"""
__slots__ = ('_columns',)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._columns: tuple[Expression, ...] = ()
def select(self, *columns: str, **exprs: Union[str, sql.Composable, Expression]):
"""
Set the columns and expressions to select.
If none are given, selects all columns.
"""
cols: List[Expression] = []
if columns:
cols.extend(map(RawExpr, map(sql.Identifier, columns)))
if exprs:
for name, expr in exprs.items():
if isinstance(expr, str):
cols.append(
RawExpr(sql.SQL(expr) + sql.SQL(' AS ') + sql.Identifier(name))
)
elif isinstance(expr, sql.Composable):
cols.append(
RawExpr(expr + sql.SQL(' AS ') + sql.Identifier(name))
)
elif isinstance(expr, Expression):
value_expr, value_values = expr.as_tuple()
cols.append(RawExpr(
value_expr + sql.SQL(' AS ') + sql.Identifier(name),
value_values
))
if cols:
self._columns = (*self._columns, *cols)
return self
def build(self):
if not self._columns:
columns, columns_values = sql.SQL('*'), ()
else:
columns, columns_values = RawExpr.join(*self._columns, joiner=sql.SQL(',')).as_tuple()
base = sql.SQL("SELECT {columns} FROM {table}").format(
columns=columns,
table=self.tableid
)
sections = [
RawExpr(base, columns_values),
self._join_section,
self._where_section,
self._group_section,
self._extra_section,
self._order_section,
self._limit_section,
]
sections = (section for section in sections if section is not None)
return RawExpr.join(*sections)
class Delete(WhereMixin, ExtraMixin, TableQuery[QueryResult]):
"""
Query type representing a table delete query.
"""
# TODO: Cascade option for delete, maybe other options
# TODO: Require a where unless specifically disabled, for safety
def build(self):
base = sql.SQL("DELETE FROM {table}").format(
table=self.tableid,
)
sections = [
RawExpr(base),
self._where_section,
self._extra_section,
RawExpr(sql.SQL('RETURNING *'))
]
sections = (section for section in sections if section is not None)
return RawExpr.join(*sections)
class Update(LimitMixin, WhereMixin, ExtraMixin, FromMixin, TableQuery[QueryResult]):
__slots__ = (
'_set',
)
# TODO: Again, require a where unless specifically disabled
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._set: List[Expression] = []
def set(self, **column_values: Union[Any, Expression]):
exprs: List[Expression] = []
for name, value in column_values.items():
if isinstance(value, Expression):
value_tup = value.as_tuple()
else:
value_tup = (sql.Placeholder(), (value,))
exprs.append(
RawExpr.join_tuples(
(sql.Identifier(name), ()),
value_tup,
joiner=sql.SQL(' = ')
)
)
self._set.extend(exprs)
return self
def build(self):
if not self._set:
raise ValueError("No columns provided to update.")
set_expr, set_values = RawExpr.join(*self._set, joiner=sql.SQL(', ')).as_tuple()
base = sql.SQL("UPDATE {table} SET {set}").format(
table=self.tableid,
set=set_expr
)
sections = [
RawExpr(base, set_values),
self._from_section,
self._where_section,
self._extra_section,
self._limit_section,
RawExpr(sql.SQL('RETURNING *'))
]
sections = (section for section in sections if section is not None)
return RawExpr.join(*sections)
# async def upsert(cursor, table, constraint, **values):
# """
# Insert or on conflict update.
# """
# valuedict = values
# keys, values = zip(*values.items())
#
# key_str = _format_insertkeys(keys)
# value_str, values = _format_insertvalues(values)
# update_key_str, update_key_values = _format_updatestr(valuedict)
#
# if not isinstance(constraint, str):
# constraint = ", ".join(constraint)
#
# await cursor.execute(
# 'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format(
# table, key_str, value_str, constraint, update_key_str
# ),
# tuple((*values, *update_key_values))
# )
# return await cursor.fetchone()
# def update_many(table, *values, set_keys=None, where_keys=None, cast_row=None, cursor=None):
# cursor = cursor or conn.cursor()
#
# # TODO: executemany or copy syntax now
# return execute_values(
# cursor,
# """
# UPDATE {table}
# SET {set_clause}
# FROM (VALUES {cast_row}%s)
# AS {temp_table}
# WHERE {where_clause}
# RETURNING *
# """.format(
# table=table,
# set_clause=', '.join("{0} = _t.{0}".format(key) for key in set_keys),
# cast_row=cast_row + ',' if cast_row else '',
# where_clause=' AND '.join("{1}.{0} = _t.{0}".format(key, table) for key in where_keys),
# temp_table="_t ({})".format(', '.join(set_keys + where_keys))
# ),
# values,
# fetch=True
# )

View File

@@ -1,102 +0,0 @@
from typing import Protocol, runtime_checkable, Optional
from psycopg import AsyncConnection
from .connector import Connector, Connectable
@runtime_checkable
class _Attachable(Connectable, Protocol):
def attach_to(self, registry: 'Registry'):
raise NotImplementedError
class Registry:
_attached: list[_Attachable] = []
_name: Optional[str] = None
def __init_subclass__(cls, name=None):
attached = []
for _, member in cls.__dict__.items():
if isinstance(member, _Attachable):
attached.append(member)
cls._attached = attached
cls._name = name or cls.__name__
def __init__(self, name=None):
self._conn: Optional[Connector] = None
self.name: str = name if name is not None else self._name
if self.name is None:
raise ValueError("A Registry must have a name!")
self.init_tasks = []
for member in self._attached:
member.attach_to(self)
def bind(self, connector: Connector):
self._conn = connector
for child in self._attached:
child.bind(connector)
def attach(self, attachable):
self._attached.append(attachable)
if self._conn is not None:
attachable.bind(self._conn)
return attachable
def init_task(self, coro):
"""
Initialisation tasks are run to setup the registry state.
These tasks will be run in the event loop, after connection to the database.
These tasks should be idempotent, as they may be run on reload and reconnect.
"""
self.init_tasks.append(coro)
return coro
async def init(self):
for task in self.init_tasks:
await task(self)
return self
class AttachableClass:
"""ABC for a default implementation of an Attachable class."""
_connector: Optional[Connector] = None
_registry: Optional[Registry] = None
@classmethod
def bind(cls, connector: Connector):
cls._connector = connector
connector.connect_hook(cls.on_connect)
return cls
@classmethod
def attach_to(cls, registry: Registry):
cls._registry = registry
return cls
@classmethod
async def on_connect(cls, connection: AsyncConnection):
pass
class Attachable:
"""ABC for a default implementation of an Attachable object."""
def __init__(self, *args, **kwargs):
self._connector: Optional[Connector] = None
self._registry: Optional[Registry] = None
def bind(self, connector: Connector):
self._connector = connector
connector.connect_hook(self.on_connect)
return self
def attach_to(self, registry: Registry):
self._registry = registry
return self
async def on_connect(self, connection: AsyncConnection):
pass

View File

@@ -1,95 +0,0 @@
from typing import Optional
from psycopg.rows import DictRow
from psycopg import sql
from . import queries as q
from .connector import Connector
from .registry import Registry
class Table:
"""
Transparent interface to a single table structure in the database.
Contains standard methods to access the table.
"""
def __init__(self, name, *args, schema='public', **kwargs):
self.name: str = name
self.schema: str = schema
self.connector: Connector = None
@property
def identifier(self):
if self.schema == 'public':
return sql.Identifier(self.name)
else:
return sql.Identifier(self.schema, self.name)
def bind(self, connector: Connector):
self.connector = connector
return self
def attach_to(self, registry: Registry):
self._registry = registry
return self
def _many_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
return data
def _single_query_adapter(self, *data: DictRow) -> Optional[DictRow]:
if data:
return data[0]
else:
return None
def _delete_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
return data
def select_where(self, *args, **kwargs) -> q.Select[tuple[DictRow, ...]]:
return q.Select(
self.identifier,
row_adapter=self._many_query_adapter,
connector=self.connector
).where(*args, **kwargs)
def select_one_where(self, *args, **kwargs) -> q.Select[DictRow]:
return q.Select(
self.identifier,
row_adapter=self._single_query_adapter,
connector=self.connector
).where(*args, **kwargs)
def update_where(self, *args, **kwargs) -> q.Update[tuple[DictRow, ...]]:
return q.Update(
self.identifier,
row_adapter=self._many_query_adapter,
connector=self.connector
).where(*args, **kwargs)
def delete_where(self, *args, **kwargs) -> q.Delete[tuple[DictRow, ...]]:
return q.Delete(
self.identifier,
row_adapter=self._many_query_adapter,
connector=self.connector
).where(*args, **kwargs)
def insert(self, **column_values) -> q.Insert[DictRow]:
return q.Insert(
self.identifier,
row_adapter=self._single_query_adapter,
connector=self.connector
).insert(column_values.keys(), column_values.values())
def insert_many(self, *args, **kwargs) -> q.Insert[tuple[DictRow, ...]]:
return q.Insert(
self.identifier,
row_adapter=self._many_query_adapter,
connector=self.connector
).insert(*args, **kwargs)
# def update_many(self, *args, **kwargs):
# with self.conn:
# return update_many(self.identifier, *args, **kwargs)
# def upsert(self, *args, **kwargs):
# return upsert(self.identifier, *args, **kwargs)

View File

@@ -3,6 +3,7 @@ import logging
import asyncio
from weakref import WeakValueDictionary
from constants import SCHEMA_VERSIONS
import discord
from discord.utils import MISSING
from discord.ext.commands import Bot, Cog, HybridCommand, HybridCommandError
@@ -10,9 +11,10 @@ from discord.ext.commands.errors import CommandInvokeError, CheckFailure
from discord.app_commands.errors import CommandInvokeError as appCommandInvokeError, TransformerError
from aiohttp import ClientSession
from data import Database
from data import Database, ORDER
from utils.lib import tabulate
from babel.translator import LeoBabel
from botdata import BotData, VersionHistory
from .config import Conf
from .logger import logging_context, log_context, log_action_stack, log_wrap, set_logging_context
@@ -43,6 +45,7 @@ class LionBot(Bot):
self.appname = appname
self.shardname = shardname
# self.appdata = appdata
self.data: BotData = db.load_registry(BotData())
self.config = config
self.translator = LeoBabel()
@@ -53,6 +56,10 @@ class LionBot(Bot):
self._locks = WeakValueDictionary()
self._running_events = set()
@property
def dbconn(self):
return self.db
@property
def core(self):
return self.get_cog('CoreCog')
@@ -129,6 +136,10 @@ class LionBot(Bot):
await wrapper()
async def start(self, token: str, *, reconnect: bool = True):
await self.data.init()
for component, req in SCHEMA_VERSIONS.items():
await self.version_check(component, req)
with logging_context(action="Login"):
start_task = asyncio.create_task(self.login(token))
await start_task
@@ -137,6 +148,24 @@ class LionBot(Bot):
run_task = asyncio.create_task(self.connect(reconnect=reconnect))
await run_task
async def version_check(self, component: str, req_version: int):
# Query the database to confirm that the given component is listed with the given version.
# Typically done upon loading a component
rows = await VersionHistory.fetch_where(component=component).order_by('_timestamp', ORDER.DESC).limit(1)
version = rows[0].to_version if rows else 0
if version != req_version:
raise ValueError(f"Component {component} failed version check. Has version '{version}', required version '{req_version}'")
else:
logger.debug(
"Component %s passed version check with version %s",
component,
version
)
return True
def dispatch(self, event_name: str, *args, **kwargs):
with logging_context(action=f"Dispatch {event_name}"):
super().dispatch(event_name, *args, **kwargs)
@@ -191,7 +220,7 @@ class LionBot(Bot):
# TODO: Some of these could have more user-feedback
logger.debug(f"Handling command error for {ctx}: {exception}")
if isinstance(ctx.command, HybridCommand) and ctx.command.app_command:
cmd_str = ctx.command.app_command.to_dict()
cmd_str = ctx.command.app_command.to_dict(self.tree)
else:
cmd_str = str(ctx.command)
try:

View File

@@ -131,7 +131,7 @@ class LionTree(CommandTree):
return
set_logging_context(action=f"Run {command.qualified_name}")
logger.debug(f"Running command '{command.qualified_name}': {command.to_dict()}")
logger.debug(f"Running command '{command.qualified_name}': {command.to_dict(self)}")
try:
await command._invoke_with_namespace(interaction, namespace)
except AppCommandError as e:

View File

@@ -2,8 +2,6 @@ this_package = 'modules'
active = [
'.sysadmin',
'.voicefix',
'.streamalerts',
]

View File

@@ -1,8 +0,0 @@
import logging
from meta import LionBot
logger = logging.getLogger(__name__)
async def setup(bot: LionBot):
from .cog import AlertCog
await bot.add_cog(AlertCog(bot))

View File

@@ -1,609 +0,0 @@
import asyncio
from typing import Optional
import discord
from discord.ext import commands as cmds
from discord import app_commands as appcmds
from twitchAPI.twitch import Twitch
from twitchAPI.helper import first
from meta import LionBot, LionCog, LionContext
from meta.errors import UserInputError
from meta.logger import log_wrap
from utils.lib import utc_now
from data.conditions import NULL
from . import logger
from .data import AlertsData
from .settings import AlertConfig, AlertSettings
from .editor import AlertEditorUI
class AlertCog(LionCog):
POLL_PERIOD = 60
def __init__(self, bot: LionBot):
self.bot = bot
self.data = bot.db.load_registry(AlertsData())
self.twitch = None
self.alert_settings = AlertSettings()
self.poll_task = None
self.event_tasks = set()
# Cache of currently live streams, maps streamerid -> stream
self.live_streams = {}
# Cache of streamers we are watching state changes for
# Map of streamerid -> streamer
self.watching = {}
async def cog_load(self):
await self.data.init()
await self.twitch_login()
await self.load_subs()
self.poll_task = asyncio.create_task(self.poll_live())
async def twitch_login(self):
# TODO: Probably abstract this out to core or a dedicated core cog
# Also handle refresh tokens
if self.twitch is not None:
await self.twitch.close()
self.twitch = None
self.twitch = await Twitch(
self.bot.config.twitch['app_id'].strip(),
self.bot.config.twitch['app_secret'].strip()
)
async def load_subs(self):
# Load active subscriptions
active_subs = await self.data.AlertChannel.fetch_where()
to_watch = {sub.streamerid for sub in active_subs}
live_streams = await self.data.Stream.fetch_where(
self.data.Stream.end_at != NULL
)
to_watch.union(stream.streamerid for stream in live_streams)
# Load associated streamers
watching = {}
if to_watch:
streamers = await self.data.Streamer.fetch_where(
userid=list(to_watch)
)
for streamer in streamers:
watching[streamer.userid] = streamer
self.watching = watching
self.live_streams = {stream.streamerid: stream for stream in live_streams}
logger.info(
f"Watching {len(watching)} streamers for state changes. "
f"Loaded {len(live_streams)} (previously) live streams into cache."
)
async def poll_live(self):
# Every PERIOD seconds,
# request get_streams for the streamers we are currently watching.
# Check if they are in the live_stream cache,
# and update cache and data and fire-and-forget start/stop events as required.
# TODO: Logging
# TODO: Error handling so the poll loop doesn't die from temporary errors
# And when it does die it gets logged properly.
if not self.twitch:
raise ValueError("Attempting to start alert poll-loop before twitch set.")
block_i = 0
self.polling = True
while self.polling:
await asyncio.sleep(self.POLL_PERIOD)
to_request = list(self.watching.keys())
if not to_request:
continue
# Each loop we request the 'next' slice of 100 userids
blocks = [to_request[i:i+100] for i in range(0, len(to_request), 100)]
block_i += 1
block_i %= len(blocks)
block = blocks[block_i]
streaming = {}
async for stream in self.twitch.get_streams(user_id=block, first=100):
# Note we set page size to 100
# So we should never get repeat or missed streams
# Since we can request a max of 100 userids anyway.
streaming[stream.user_id] = stream
started = set(streaming.keys()).difference(self.live_streams.keys())
ended = set(self.live_streams.keys()).difference(streaming.keys())
for streamerid in started:
stream = streaming[streamerid]
stream_data = await self.data.Stream.create(
streamerid=stream.user_id,
start_at=stream.started_at,
twitch_stream_id=stream.id,
game_name=stream.game_name,
title=stream.title,
)
self.live_streams[streamerid] = stream_data
task = asyncio.create_task(self.on_stream_start(stream_data))
self.event_tasks.add(task)
task.add_done_callback(self.event_tasks.discard)
for streamerid in ended:
stream_data = self.live_streams.pop(streamerid)
await stream_data.update(end_at=utc_now())
task = asyncio.create_task(self.on_stream_end(stream_data))
self.event_tasks.add(task)
task.add_done_callback(self.event_tasks.discard)
async def on_stream_start(self, stream_data):
# Get channel subscriptions listening for this streamer
uid = stream_data.streamerid
logger.info(f"Streamer <uid:{uid}> started streaming! {stream_data=}")
subbed = await self.data.AlertChannel.fetch_where(streamerid=uid)
# Fulfill those alerts
for sub in subbed:
try:
# If the sub is paused, don't create the alert
await self.sub_alert(sub, stream_data)
except discord.HTTPException:
# TODO: Needs to be handled more gracefully at user level
# Retry logic?
logger.warning(
f"Could not complete subscription {sub=} for {stream_data=}", exc_info=True
)
except Exception:
logger.exception(
f"Unexpected exception completing {sub=} for {stream_data=}"
)
raise
async def subscription_error(self, subscription, stream_data, err_msg):
"""
Handle a subscription fulfill failure.
Stores the error message for user display,
and deletes the subscription after some number of errors.
# TODO
"""
logger.warning(
f"Subscription error {subscription=} {stream_data=} {err_msg=}"
)
async def sub_alert(self, subscription, stream_data):
# Base alert behaviour is just to send a message
# and create an alert row
channel = self.bot.get_channel(subscription.channelid)
if channel is None or not isinstance(channel, discord.abc.Messageable):
# Subscription channel is gone!
# Or the Discord channel cache died
await self.subscription_error(
subscription, stream_data,
"Subscription channel no longer exists."
)
return
permissions = channel.permissions_for(channel.guild.me)
if not (permissions.send_messages and permissions.embed_links):
await self.subscription_error(
subscription, stream_data,
"Insufficient permissions to post alert message."
)
return
# Build message
streamer = await self.data.Streamer.fetch(stream_data.streamerid)
if not streamer:
# Streamer was deleted while handling the alert
# Just quietly ignore
# Don't error out because the stream data row won't exist anymore
logger.warning(
f"Cancelling alert for subscription {subscription.subscriptionid}"
" because the streamer no longer exists."
)
return
alert_config = AlertConfig(subscription.subscriptionid, subscription)
paused = alert_config.get(self.alert_settings.AlertPaused.setting_id)
if paused.value:
logger.info(f"Skipping alert for subscription {subscription=} because it is paused.")
return
live_message = alert_config.get(self.alert_settings.AlertMessage.setting_id)
formatter = await live_message.generate_formatter(self.bot, stream_data, streamer)
formatted = await formatter(live_message.value)
args = live_message.value_to_args(subscription.subscriptionid, formatted)
try:
message = await channel.send(**args.send_args)
except discord.HTTPException as e:
logger.warning(
f"Message send failure while sending streamalert {subscription.subscriptionid}",
exc_info=True
)
await self.subscription_error(
subscription, stream_data,
"Failed to post live alert."
)
return
# Store sent alert
alert = await self.data.StreamAlert.create(
streamid=stream_data.streamid,
subscriptionid=subscription.subscriptionid,
sent_at=utc_now(),
messageid=message.id
)
logger.debug(
f"Fulfilled subscription {subscription.subscriptionid} with alert {alert.alertid}"
)
async def on_stream_end(self, stream_data):
# Get channel subscriptions listening for this streamer
uid = stream_data.streamerid
logger.info(f"Streamer <uid:{uid}> stopped streaming! {stream_data=}")
subbed = await self.data.AlertChannel.fetch_where(streamerid=uid)
# Resolve subscriptions
for sub in subbed:
try:
await self.sub_resolve(sub, stream_data)
except discord.HTTPException:
# TODO: Needs to be handled more gracefully at user level
# Retry logic?
logger.warning(
f"Could not resolve subscription {sub=} for {stream_data=}", exc_info=True
)
except Exception:
logger.exception(
f"Unexpected exception resolving {sub=} for {stream_data=}"
)
raise
async def sub_resolve(self, subscription, stream_data):
# Check if there is a current active alert to resolve
alerts = await self.data.StreamAlert.fetch_where(
streamid=stream_data.streamid,
subscriptionid=subscription.subscriptionid,
)
if not alerts:
logger.info(
f"Resolution requested for subscription {subscription.subscriptionid} with stream {stream_data.streamid} "
"but no active alerts were found."
)
return
alert = alerts[0]
if alert.resolved_at is not None:
# Alert was already resolved
# This is okay, Twitch might have just sent the stream ending twice
logger.info(
f"Resolution requested for subscription {subscription.subscriptionid} with stream {stream_data.streamid} "
"but alert was already resolved."
)
return
# Check if message is to be deleted or edited (or nothing)
alert_config = AlertConfig(subscription.subscriptionid, subscription)
del_setting = alert_config.get(self.alert_settings.AlertEndDelete.setting_id)
edit_setting = alert_config.get(self.alert_settings.AlertEndMessage.setting_id)
if (delmsg := del_setting.value) or (edit_setting.value):
# Find the message
message = None
channel = self.bot.get_channel(subscription.channelid)
if channel:
try:
message = await channel.fetch_message(alert.messageid)
except discord.HTTPException:
# Message was probably deleted already
# Or permissions were changed
# Or Discord connection broke
pass
else:
# Channel went after posting the alert
# Or Discord cache sucks
# Nothing we can do, just mark it handled
pass
if message:
if delmsg:
# Delete the message
try:
await message.delete()
except discord.HTTPException:
logger.warning(
f"Discord exception while del-resolve live alert {alert=}",
exc_info=True
)
else:
# Edit message with custom arguments
streamer = await self.data.Streamer.fetch(stream_data.streamerid)
formatter = await edit_setting.generate_formatter(self.bot, stream_data, streamer)
formatted = await formatter(edit_setting.value)
args = edit_setting.value_to_args(subscription.subscriptionid, formatted)
try:
await message.edit(**args.edit_args)
except discord.HTTPException:
logger.warning(
f"Discord exception while edit-resolve live alert {alert=}",
exc_info=True
)
else:
# Explicitly don't need to do anything to the alert
pass
# Save alert as resolved
await alert.update(resolved_at=utc_now())
async def cog_unload(self):
if self.poll_task is not None and not self.poll_task.cancelled():
self.poll_task.cancel()
if self.twitch is not None:
await self.twitch.close()
self.twitch = None
# ----- Commands -----
@cmds.hybrid_group(
name='streamalert',
description=(
"Create and configure stream live-alerts."
)
)
@cmds.guild_only()
@appcmds.default_permissions(manage_channels=True)
async def streamalert_group(self, ctx: LionContext):
# Placeholder group, method not used
raise NotImplementedError
@streamalert_group.command(
name='create',
description=(
"Subscribe a Discord channel to notifications when a Twitch stream goes live."
)
)
@appcmds.describe(
streamer="Name of the twitch channel to watch.",
channel="Which Discord channel to send live alerts in.",
message="Custom message to send when the channel goes live (may be edited later)."
)
@appcmds.default_permissions(manage_channels=True)
async def streamalert_create_cmd(self, ctx: LionContext,
streamer: str,
channel: discord.TextChannel,
message: Optional[str]):
# Type guards
assert ctx.guild is not None, "Guild-only command has no guild ctx."
assert self.twitch is not None, "Twitch command run with no twitch obj."
# Wards
if not channel.permissions_for(ctx.author).manage_channels:
await ctx.error_reply(
"Sorry, you need the `MANAGE_CHANNELS` permission "
"to add a stream alert to a channel."
)
return
# Look up the specified streamer
tw_user = await first(self.twitch.get_users(logins=[streamer]))
if not tw_user:
await ctx.error_reply(
f"Sorry, could not find `{streamer}` on Twitch! "
"Make sure you use the name in their channel url."
)
return
# Create streamer data if it doesn't already exist
streamer_data = await self.data.Streamer.fetch_or_create(
tw_user.id,
login_name=tw_user.login,
display_name=tw_user.display_name,
)
# Add subscription to alerts list
sub_data = await self.data.AlertChannel.create(
streamerid=streamer_data.userid,
guildid=channel.guild.id,
channelid=channel.id,
created_by=ctx.author.id,
paused=False
)
# Add to watchlist
self.watching[streamer_data.userid] = streamer_data
# Open AlertEditorUI for the new subscription
# TODO
await ctx.reply("StreamAlert Created.")
async def alert_acmpl(self, interaction: discord.Interaction, partial: str):
if not interaction.guild:
raise ValueError("Cannot acmpl alert in guildless interaction.")
# Get all alerts in the server
alerts = await self.data.AlertChannel.fetch_where(guildid=interaction.guild_id)
if not alerts:
# No alerts available
options = [
appcmds.Choice(
name="No stream alerts are set up in this server!",
value=partial
)
]
else:
options = []
for alert in alerts:
streamer = await self.data.Streamer.fetch(alert.streamerid)
if streamer is None:
# Should be impossible by foreign key condition
# Might be a stale cache
continue
channel = interaction.guild.get_channel(alert.channelid)
display = f"{streamer.display_name} in #{channel.name if channel else 'unknown'}"
if partial.lower() in display.lower():
# Matching option
options.append(appcmds.Choice(name=display, value=str(alert.subscriptionid)))
if not options:
options.append(
appcmds.Choice(
name=f"No stream alerts matching {partial}"[:25],
value=partial
)
)
return options
async def resolve_alert(self, interaction: discord.Interaction, alert_str: str):
if not interaction.guild:
raise ValueError("Resolving alert outside of a guild.")
# Expect alert_str to be the integer subscriptionid
if not alert_str.isdigit():
raise UserInputError(
f"No stream alerts in this server matching `{alert_str}`!"
)
alert = await self.data.AlertChannel.fetch(int(alert_str))
if not alert or not alert.guildid == interaction.guild_id:
raise UserInputError(
"Could not find the selected alert! Please try again."
)
return alert
@streamalert_group.command(
name='edit',
description=(
"Update settings for an existing Twitch stream alert."
)
)
@appcmds.describe(
alert="Which alert do you want to edit?",
# TODO: Other settings here
)
@appcmds.default_permissions(manage_channels=True)
async def streamalert_edit_cmd(self, ctx: LionContext, alert: str):
# Type guards
assert ctx.guild is not None, "Guild-only command has no guild ctx."
assert self.twitch is not None, "Twitch command run with no twitch obj."
assert ctx.interaction is not None, "Twitch command needs interaction ctx."
# Look up provided alert
sub_data = await self.resolve_alert(ctx.interaction, alert)
# Check user permissions for editing this alert
channel = ctx.guild.get_channel(sub_data.channelid)
permlevel = channel if channel else ctx.guild
if not permlevel.permissions_for(ctx.author).manage_channels:
await ctx.error_reply(
"Sorry, you need the `MANAGE_CHANNELS` permission "
"in this channel to edit the stream alert."
)
return
# If edit options have been given, save edits and retouch cache if needed
# If not, open AlertEditorUI
ui = AlertEditorUI(bot=self.bot, sub_data=sub_data, callerid=ctx.author.id)
await ui.run(ctx.interaction)
await ui.wait()
@streamalert_edit_cmd.autocomplete('alert')
async def streamalert_edit_cmd_alert_acmpl(self, interaction, partial):
return await self.alert_acmpl(interaction, partial)
@streamalert_group.command(
name='pause',
description=(
"Pause a streamalert."
)
)
@appcmds.describe(
alert="Which alert do you want to pause?",
)
@appcmds.default_permissions(manage_channels=True)
async def streamalert_pause_cmd(self, ctx: LionContext, alert: str):
# Type guards
assert ctx.guild is not None, "Guild-only command has no guild ctx."
assert self.twitch is not None, "Twitch command run with no twitch obj."
assert ctx.interaction is not None, "Twitch command needs interaction ctx."
# Look up provided alert
sub_data = await self.resolve_alert(ctx.interaction, alert)
# Check user permissions for editing this alert
channel = ctx.guild.get_channel(sub_data.channelid)
permlevel = channel if channel else ctx.guild
if not permlevel.permissions_for(ctx.author).manage_channels:
await ctx.error_reply(
"Sorry, you need the `MANAGE_CHANNELS` permission "
"in this channel to edit the stream alert."
)
return
await sub_data.update(paused=True)
await ctx.reply("This alert is now paused!")
@streamalert_group.command(
name='unpause',
description=(
"Resume a streamalert."
)
)
@appcmds.describe(
alert="Which alert do you want to unpause?",
)
@appcmds.default_permissions(manage_channels=True)
async def streamalert_unpause_cmd(self, ctx: LionContext, alert: str):
# Type guards
assert ctx.guild is not None, "Guild-only command has no guild ctx."
assert self.twitch is not None, "Twitch command run with no twitch obj."
assert ctx.interaction is not None, "Twitch command needs interaction ctx."
# Look up provided alert
sub_data = await self.resolve_alert(ctx.interaction, alert)
# Check user permissions for editing this alert
channel = ctx.guild.get_channel(sub_data.channelid)
permlevel = channel if channel else ctx.guild
if not permlevel.permissions_for(ctx.author).manage_channels:
await ctx.error_reply(
"Sorry, you need the `MANAGE_CHANNELS` permission "
"in this channel to edit the stream alert."
)
return
await sub_data.update(paused=False)
await ctx.reply("This alert has been unpaused!")
@streamalert_group.command(
name='remove',
description=(
"Deactivate a streamalert entirely (see /streamalert pause to temporarily pause it)."
)
)
@appcmds.describe(
alert="Which alert do you want to remove?",
)
@appcmds.default_permissions(manage_channels=True)
async def streamalert_remove_cmd(self, ctx: LionContext, alert: str):
# Type guards
assert ctx.guild is not None, "Guild-only command has no guild ctx."
assert self.twitch is not None, "Twitch command run with no twitch obj."
assert ctx.interaction is not None, "Twitch command needs interaction ctx."
# Look up provided alert
sub_data = await self.resolve_alert(ctx.interaction, alert)
# Check user permissions for editing this alert
channel = ctx.guild.get_channel(sub_data.channelid)
permlevel = channel if channel else ctx.guild
if not permlevel.permissions_for(ctx.author).manage_channels:
await ctx.error_reply(
"Sorry, you need the `MANAGE_CHANNELS` permission "
"in this channel to edit the stream alert."
)
return
await sub_data.delete()
await ctx.reply("This alert has been deleted.")

View File

@@ -1,105 +0,0 @@
from data import Registry, RowModel
from data.columns import Integer, Bool, Timestamp, String
from data.models import WeakCache
from cachetools import TTLCache
class AlertsData(Registry):
class Streamer(RowModel):
"""
Schema
------
CREATE TABLE streamers(
userid BIGINT PRIMARY KEY,
login_name TEXT NOT NULL,
display_name TEXT NOT NULL
);
"""
_tablename_ = 'streamers'
_cache_ = {}
userid = Integer(primary=True)
login_name = String()
display_name = String()
class AlertChannel(RowModel):
"""
Schema
------
CREATE TABLE alert_channels(
subscriptionid SERIAL PRIMARY KEY,
guildid BIGINT NOT NULL,
channelid BIGINT NOT NULL,
streamerid BIGINT NOT NULL REFERENCES streamers (userid) ON DELETE CASCADE,
created_by BIGINT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
paused BOOLEAN NOT NULL DEFAULT FALSE,
end_delete BOOLEAN NOT NULL DEFAULT FALSE,
live_message TEXT,
end_message TEXT
);
CREATE INDEX alert_channels_guilds ON alert_channels (guildid);
CREATE UNIQUE INDEX alert_channels_channelid_streamerid ON alert_channels (channelid, streamerid);
"""
_tablename_ = 'alert_channels'
_cache_ = {}
subscriptionid = Integer(primary=True)
guildid = Integer()
channelid = Integer()
streamerid = Integer()
display_name = Integer()
created_by = Integer()
created_at = Timestamp()
paused = Bool()
end_delete = Bool()
live_message = String()
end_message = String()
class Stream(RowModel):
"""
Schema
------
CREATE TABLE streams(
streamid SERIAL PRIMARY KEY,
streamerid BIGINT NOT NULL REFERENCES streamers (userid) ON DELETE CASCADE,
start_at TIMESTAMPTZ NOT NULL,
twitch_stream_id BIGINT,
game_name TEXT,
title TEXT,
end_at TIMESTAMPTZ
);
"""
_tablename_ = 'streams'
_cache_ = WeakCache(TTLCache(maxsize=100, ttl=24*60*60))
streamid = Integer(primary=True)
streamerid = Integer()
start_at = Timestamp()
twitch_stream_id = Integer()
game_name = String()
title = String()
end_at = Timestamp()
class StreamAlert(RowModel):
"""
Schema
------
CREATE TABLE stream_alerts(
alertid SERIAL PRIMARY KEY,
streamid INTEGER NOT NULL REFERENCES streams (streamid) ON DELETE CASCADE,
subscriptionid INTEGER NOT NULL REFERENCES alert_channels (subscriptionid) ON DELETE CASCADE,
sent_at TIMESTAMPTZ NOT NULL,
messageid BIGINT NOT NULL,
resolved_at TIMESTAMPTZ
);
"""
_tablename_ = 'stream_alerts'
_cache_ = WeakCache(TTLCache(maxsize=1000, ttl=24*60*60))
alertid = Integer(primary=True)
streamid = Integer()
subscriptionid = Integer()
sent_at = Timestamp()
messageid = Integer()
resolved_at = Timestamp()

View File

@@ -1,369 +0,0 @@
import asyncio
import datetime as dt
from collections import namedtuple
from functools import wraps
from typing import TYPE_CHECKING
import discord
from discord.ui.button import button, Button, ButtonStyle
from discord.ui.select import select, Select, SelectOption, ChannelSelect
from meta import LionBot, conf
from utils.lib import MessageArgs, tabulate, utc_now
from utils.ui import MessageUI
from utils.ui.msgeditor import MsgEditor
from .settings import AlertSettings as Settings
from .settings import AlertConfig as Config
from .data import AlertsData
if TYPE_CHECKING:
from .cog import AlertCog
FakeStream = namedtuple(
'FakeStream',
["streamid", "streamerid", "start_at", "twitch_stream_id", "game_name", "title", "end_at"]
)
class AlertEditorUI(MessageUI):
setting_classes = (
Settings.AlertPaused,
Settings.AlertEndDelete,
Settings.AlertEndMessage,
Settings.AlertMessage,
Settings.AlertChannel,
)
def __init__(self, bot: LionBot, sub_data: AlertsData.AlertChannel, **kwargs):
super().__init__(**kwargs)
self.bot = bot
self.sub_data = sub_data
self.subid = sub_data.subscriptionid
self.cog: 'AlertCog' = bot.get_cog('AlertCog')
self.config = Config(self.subid, sub_data)
# ----- UI API -----
def preview_stream_data(self):
# TODO: Probably makes sense to factor this out to the cog
# Or even generate it in the formatters themselves
data = self.sub_data
return FakeStream(
-1,
data.streamerid,
utc_now() - dt.timedelta(hours=1),
-1,
"Discord Admin",
"Testing Go Live Message",
utc_now()
)
def call_and_refresh(self, func):
"""
Generate a wrapper which runs coroutine 'func' and then refreshes the UI.
"""
# TODO: Check whether the UI has finished interaction
@wraps(func)
async def wrapped(*args, **kwargs):
await func(*args, **kwargs)
await self.refresh()
return wrapped
# ----- UI Components -----
# Pause button
@button(label="PAUSE_PLACEHOLDER", style=ButtonStyle.blurple)
async def pause_button(self, press: discord.Interaction, pressed: Button):
await press.response.defer(thinking=True, ephemeral=True)
setting = self.config.get(Settings.AlertPaused.setting_id)
setting.value = not setting.value
await setting.write()
await self.refresh(thinking=press)
async def pause_button_refresh(self):
button = self.pause_button
if self.config.get(Settings.AlertPaused.setting_id).value:
button.label = "UnPause"
button.style = ButtonStyle.grey
else:
button.label = "Pause"
button.style = ButtonStyle.green
# Delete button
@button(label="Delete Alert", style=ButtonStyle.red)
async def delete_button(self, press: discord.Interaction, pressed: Button):
await press.response.defer(thinking=True, ephemeral=True)
await self.sub_data.delete()
embed = discord.Embed(
colour=discord.Colour.brand_green(),
description="Stream alert removed."
)
await press.edit_original_response(embed=embed)
await self.close()
# Close button
@button(emoji=conf.emojis.cancel, style=ButtonStyle.red)
async def close_button(self, press: discord.Interaction, pressed: Button):
await press.response.defer(thinking=False)
await self.close()
# Edit Alert button
@button(label="Edit Alert", style=ButtonStyle.blurple)
async def edit_alert_button(self, press: discord.Interaction, pressed: Button):
# Spawn MsgEditor for the live alert
await press.response.defer(thinking=True, ephemeral=True)
setting = self.config.get(Settings.AlertMessage.setting_id)
stream = self.preview_stream_data()
streamer = await self.cog.data.Streamer.fetch(self.sub_data.streamerid)
editor = MsgEditor(
self.bot,
setting.value,
callback=self.call_and_refresh(setting.editor_callback),
formatter=await setting.generate_formatter(self.bot, stream, streamer),
callerid=press.user.id
)
self._slaves.append(editor)
await editor.run(press)
# Edit End message
@button(label="Edit Ending Alert", style=ButtonStyle.blurple)
async def edit_end_button(self, press: discord.Interaction, pressed: Button):
# Spawn MsgEditor for the ending alert
await press.response.defer(thinking=True, ephemeral=True)
await self.open_end_editor(press)
async def open_end_editor(self, respond_to: discord.Interaction):
setting = self.config.get(Settings.AlertEndMessage.setting_id)
# Start from current live alert data if not set
if not setting.value:
alert_setting = self.config.get(Settings.AlertMessage.setting_id)
setting.value = alert_setting.value
stream = self.preview_stream_data()
streamer = await self.cog.data.Streamer.fetch(self.sub_data.streamerid)
editor = MsgEditor(
self.bot,
setting.value,
callback=self.call_and_refresh(setting.editor_callback),
formatter=await setting.generate_formatter(self.bot, stream, streamer),
callerid=respond_to.user.id
)
self._slaves.append(editor)
await editor.run(respond_to)
return editor
# Ending Mode Menu
@select(
cls=Select,
placeholder="Select action to take when the stream ends",
options=[SelectOption(label="DUMMY")],
min_values=0, max_values=1
)
async def ending_mode_menu(self, selection: discord.Interaction, selected: Select):
if not selected.values:
await selection.response.defer()
return
await selection.response.defer(thinking=True, ephemeral=True)
value = selected.values[0]
if value == '0':
# In Do Nothing case,
# Ensure Delete is off and custom edit message is unset
setting = self.config.get(Settings.AlertEndDelete.setting_id)
if setting.value:
setting.value = False
await setting.write()
setting = self.config.get(Settings.AlertEndMessage.setting_id)
if setting.value:
setting.value = None
await setting.write()
await self.refresh(thinking=selection)
elif value == '1':
# In Delete Alert case,
# Set the delete setting to True
setting = self.config.get(Settings.AlertEndDelete.setting_id)
if not setting.value:
setting.value = True
await setting.write()
await self.refresh(thinking=selection)
elif value == '2':
# In Edit Message case,
# Set the delete setting to False,
setting = self.config.get(Settings.AlertEndDelete.setting_id)
if setting.value:
setting.value = False
await setting.write()
# And open the edit message editor
await self.open_end_editor(selection)
await self.refresh()
async def ending_mode_menu_refresh(self):
# Build menu options
options = [
SelectOption(
label="Do Nothing",
description="Don't modify the live alert message.",
value="0",
),
SelectOption(
label="Delete Alert After Stream",
description="Delete the live alert message.",
value="1",
),
SelectOption(
label="Edit Alert After Stream",
description="Edit the live alert message to a custom message. Opens editor.",
value="2",
),
]
# Calculate the correct default
if self.config.get(Settings.AlertEndDelete.setting_id).value:
options[1].default = True
elif self.config.get(Settings.AlertEndMessage.setting_id).value:
options[2].default = True
self.ending_mode_menu.options = options
# Edit channel menu
@select(cls=ChannelSelect,
placeholder="Select Alert Channel",
channel_types=[discord.ChannelType.text, discord.ChannelType.voice],
min_values=0, max_values=1)
async def channel_menu(self, selection: discord.Interaction, selected):
if selected.values:
await selection.response.defer(thinking=True, ephemeral=True)
setting = self.config.get(Settings.AlertChannel.setting_id)
setting.value = selected.values[0]
await setting.write()
await self.refresh(thinking=selection)
else:
await selection.response.defer(thinking=False)
async def channel_menu_refresh(self):
# current = self.config.get(Settings.AlertChannel.setting_id).value
# TODO: Check if discord-typed menus can have defaults yet
# Impl in stable dpy, but not released to pip yet
...
# ----- UI Flow -----
async def make_message(self) -> MessageArgs:
streamer = await self.cog.data.Streamer.fetch(self.sub_data.streamerid)
if streamer is None:
raise ValueError("Streamer row does not exist in AlertEditor")
name = streamer.display_name
# Build relevant setting table
table_map = {}
table_map['Channel'] = self.config.get(Settings.AlertChannel.setting_id).formatted
table_map['Streamer'] = f"https://www.twitch.tv/{streamer.login_name}"
table_map['Paused'] = self.config.get(Settings.AlertPaused.setting_id).formatted
prop_table = '\n'.join(tabulate(*table_map.items()))
embed = discord.Embed(
colour=discord.Colour.dark_green(),
title=f"Stream Alert for {name}",
description=prop_table,
timestamp=utc_now()
)
message_setting = self.config.get(Settings.AlertMessage.setting_id)
message_desc_lines = [
f"An alert message will be posted to {table_map['Channel']}.",
f"Press `{self.edit_alert_button.label}`"
" to preview or edit the alert.",
"The following keys will be substituted in the alert message."
]
keytable = tabulate(*message_setting._subkey_desc.items())
for line in keytable:
message_desc_lines.append(f"> {line}")
embed.add_field(
name=f"When {name} goes live",
value='\n'.join(message_desc_lines),
inline=False
)
# Determine the ending behaviour
del_setting = self.config.get(Settings.AlertEndDelete.setting_id)
end_msg_setting = self.config.get(Settings.AlertEndMessage.setting_id)
if del_setting.value:
# Deleting
end_msg_desc = "The live alert message will be deleted."
...
elif end_msg_setting.value:
# Editing
lines = [
"The live alert message will edited to the configured message.",
f"Press `{self.edit_end_button.label}` to preview or edit the message.",
"The following substitution keys are supported "
"*in addition* to the live alert keys."
]
keytable = tabulate(
*[(k, v) for k, v in end_msg_setting._subkey_desc.items() if k not in message_setting._subkey_desc]
)
for line in keytable:
lines.append(f"> {line}")
end_msg_desc = '\n'.join(lines)
else:
# Doing nothing
end_msg_desc = "The live alert message will not be changed."
embed.add_field(
name=f"When {name} ends their stream",
value=end_msg_desc,
inline=False
)
return MessageArgs(embed=embed)
async def reload(self):
await self.sub_data.refresh()
# Note self.config references the sub_data, and doesn't need reloading.
async def refresh_layout(self):
to_refresh = (
self.pause_button_refresh(),
self.channel_menu_refresh(),
self.ending_mode_menu_refresh(),
)
await asyncio.gather(*to_refresh)
show_end_edit = (
not self.config.get(Settings.AlertEndDelete.setting_id).value
and
self.config.get(Settings.AlertEndMessage.setting_id).value
)
if not show_end_edit:
# Don't show edit end button
buttons = (
self.edit_alert_button,
self.pause_button, self.delete_button, self.close_button
)
else:
buttons = (
self.edit_alert_button, self.edit_end_button,
self.pause_button, self.delete_button, self.close_button
)
self.set_layout(
buttons,
(self.ending_mode_menu,),
(self.channel_menu,),
)

View File

@@ -1,264 +0,0 @@
from typing import Optional, Any
import json
from meta.LionBot import LionBot
from settings import ModelData
from settings.groups import SettingGroup, ModelConfig, SettingDotDict
from settings.setting_types import BoolSetting, ChannelSetting
from core.setting_types import MessageSetting
from babel.translator import LocalBabel
from utils.lib import recurse_map, replace_multiple, tabulate
from .data import AlertsData
babel = LocalBabel('streamalerts')
_p = babel._p
class AlertConfig(ModelConfig):
settings = SettingDotDict()
_model_settings = set()
model = AlertsData.AlertChannel
class AlertSettings(SettingGroup):
@AlertConfig.register_model_setting
class AlertMessage(ModelData, MessageSetting):
setting_id = 'alert_live_message'
_display_name = _p('', 'live_message')
_desc = _p(
'',
'Message sent to the channel when the streamer goes live.'
)
_long_desc = _p(
'',
'Message sent to the attached channel when the Twitch streamer goes live.'
)
_accepts = _p('', 'JSON formatted greeting message data')
_default = json.dumps({'content': "**{display_name}** just went live at {channel_link}"})
_model = AlertsData.AlertChannel
_column = AlertsData.AlertChannel.live_message.name
_subkey_desc = {
'{display_name}': "Twitch channel name (with capitalisation)",
'{login_name}': "Twitch channel login name (as in url)",
'{channel_link}': "Link to the live twitch channel",
'{stream_start}': "Numeric timestamp when stream went live",
}
# TODO: More stuff
@property
def update_message(self) -> str:
return "The go-live notification message has been updated!"
@classmethod
async def generate_formatter(cls, bot: LionBot, stream: AlertsData.Stream, streamer: AlertsData.Streamer, **kwargs):
"""
Generate a formatter function for this message
from the provided stream and streamer data.
The formatter function accepts and returns a message data dict.
"""
async def formatter(data_dict: Optional[dict[str, Any]]):
if not data_dict:
return None
mapping = {
'{display_name}': streamer.display_name,
'{login_name}': streamer.login_name,
'{channel_link}': f"https://www.twitch.tv/{streamer.login_name}",
'{stream_start}': int(stream.start_at.timestamp()),
}
recurse_map(
lambda loc, value: replace_multiple(value, mapping) if isinstance(value, str) else value,
data_dict,
)
return data_dict
return formatter
async def editor_callback(self, editor_data):
self.value = editor_data
await self.write()
def _desc_table(self, show_value: Optional[str] = None) -> list[tuple[str, str]]:
lines = super()._desc_table(show_value=show_value)
keytable = tabulate(*self._subkey_desc.items(), colon='')
expline = (
"The following placeholders will be substituted with their values."
)
keyfield = (
"Placeholders",
expline + '\n' + '\n'.join(f"> {line}" for line in keytable)
)
lines.append(keyfield)
return lines
@AlertConfig.register_model_setting
class AlertEndMessage(ModelData, MessageSetting):
"""
Custom ending message to edit the live alert to.
If not set, doesn't edit the alert.
"""
setting_id = 'alert_end_message'
_display_name = _p('', 'end_message')
_desc = _p(
'',
'Optional message to edit the live alert with when the stream ends.'
)
_long_desc = _p(
'',
"If set, and `end_delete` is not on, "
"the live alert will be edited with this custom message "
"when the stream ends."
)
_accepts = _p('', 'JSON formatted greeting message data')
_default = None
_model = AlertsData.AlertChannel
_column = AlertsData.AlertChannel.end_message.name
_subkey_desc = {
'{display_name}': "Twitch channel name (with capitalisation)",
'{login_name}': "Twitch channel login name (as in url)",
'{channel_link}': "Link to the live twitch channel",
'{stream_start}': "Numeric timestamp when stream went live",
'{stream_end}': "Numeric timestamp when stream ended",
}
@property
def update_message(self) -> str:
if self.value:
return "The stream ending message has been updated."
else:
return "The stream ending message has been unset."
@classmethod
async def generate_formatter(cls, bot: LionBot, stream: AlertsData.Stream, streamer: AlertsData.Streamer, **kwargs):
"""
Generate a formatter function for this message
from the provided stream and streamer data.
The formatter function accepts and returns a message data dict.
"""
# TODO: Fake stream data maker (namedtuple?) for previewing
async def formatter(data_dict: Optional[dict[str, Any]]):
if not data_dict:
return None
mapping = {
'{display_name}': streamer.display_name,
'{login_name}': streamer.login_name,
'{channel_link}': f"https://www.twitch.tv/{streamer.login_name}",
'{stream_start}': int(stream.start_at.timestamp()),
'{stream_end}': int(stream.end_at.timestamp()),
}
recurse_map(
lambda loc, value: replace_multiple(value, mapping) if isinstance(value, str) else value,
data_dict,
)
return data_dict
return formatter
async def editor_callback(self, editor_data):
self.value = editor_data
await self.write()
def _desc_table(self, show_value: Optional[str] = None) -> list[tuple[str, str]]:
lines = super()._desc_table(show_value=show_value)
keytable = tabulate(*self._subkey_desc.items(), colon='')
expline = (
"The following placeholders will be substituted with their values."
)
keyfield = (
"Placeholders",
expline + '\n' + '\n'.join(f"> {line}" for line in keytable)
)
lines.append(keyfield)
return lines
...
@AlertConfig.register_model_setting
class AlertEndDelete(ModelData, BoolSetting):
"""
Whether to delete the live alert after the stream ends.
"""
setting_id = 'alert_end_delete'
_display_name = _p('', 'end_delete')
_desc = _p(
'',
'Whether to delete the live alert after the stream ends.'
)
_long_desc = _p(
'',
"If enabled, the live alert message will be deleted when the stream ends. "
"This overrides the `end_message` setting."
)
_default = False
_model = AlertsData.AlertChannel
_column = AlertsData.AlertChannel.end_delete.name
@property
def update_message(self) -> str:
if self.value:
return "The live alert will be deleted at the end of the stream."
else:
return "The live alert will not be deleted when the stream ends."
@AlertConfig.register_model_setting
class AlertPaused(ModelData, BoolSetting):
"""
Whether this live alert is currently paused.
"""
setting_id = 'alert_paused'
_display_name = _p('', 'paused')
_desc = _p(
'',
"Whether the alert is currently paused."
)
_long_desc = _p(
'',
"Paused alerts will not trigger live notifications, "
"although the streams will still be tracked internally."
)
_default = False
_model = AlertsData.AlertChannel
_column = AlertsData.AlertChannel.paused.name
@property
def update_message(self):
if self.value:
return "This alert is now paused"
else:
return "This alert has been unpaused"
@AlertConfig.register_model_setting
class AlertChannel(ModelData, ChannelSetting):
"""
The channel associated to this alert.
"""
setting_id = 'alert_channel'
_display_name = _p('', 'channel')
_desc = _p(
'',
"The Discord channel this live alert will be sent in."
)
_long_desc = _desc
# Note that this cannot actually be None,
# as there is no UI pathway to unset the setting.
_default = None
_model = AlertsData.AlertChannel
_column = AlertsData.AlertChannel.channelid.name
@property
def update_message(self):
return f"This alert will now be posted to {self.value.channel.mention}"

View File

@@ -1,7 +0,0 @@
import logging
logger = logging.getLogger(__name__)
async def setup(bot):
from .cog import VoiceFixCog
await bot.add_cog(VoiceFixCog(bot))

View File

@@ -1,449 +0,0 @@
from collections import defaultdict
from typing import Optional
import asyncio
from cachetools import FIFOCache
import discord
from discord.abc import GuildChannel
from discord.ext import commands as cmds
from discord import app_commands as appcmds
from meta import LionBot, LionCog, LionContext
from meta.errors import ResponseTimedOut, SafeCancellation, UserInputError
from utils.ui import Confirm
from . import logger
from .data import LinkData
async def prepare_attachments(attachments: list[discord.Attachment]):
results = []
for attach in attachments:
try:
as_file = await attach.to_file(spoiler=attach.is_spoiler())
results.append(as_file)
except discord.HTTPException:
pass
return results
async def prepare_embeds(message: discord.Message):
embeds = [embed for embed in message.embeds if embed.type == 'rich']
if message.reference:
embed = discord.Embed(
colour=discord.Colour.dark_gray(),
description=f"Reply to {message.reference.jump_url}"
)
embeds.append(embed)
return embeds
class VoiceFixCog(LionCog):
def __init__(self, bot: LionBot):
self.bot = bot
self.data = bot.db.load_registry(LinkData())
# Map of linkids to list of channelids
self.link_channels = {}
# Map of channelids to linkids
self.channel_links = {}
# Map of channelids to initialised discord.Webhook
self.hooks = {}
# Map of messageid to list of (channelid, webhookmsg) pairs, for updates
self.message_cache = FIFOCache(maxsize=200)
# webhook msgid -> orig msgid
self.wmessages = FIFOCache(maxsize=600)
self.lock = asyncio.Lock()
async def cog_load(self):
await self.data.init()
await self.reload_links()
async def reload_links(self):
records = await self.data.channel_links.select_where()
channel_links = defaultdict(set)
link_channels = defaultdict(set)
for record in records:
linkid = record['linkid']
channelid = record['channelid']
channel_links[channelid].add(linkid)
link_channels[linkid].add(channelid)
channelids = list(channel_links.keys())
if channelids:
await self.data.LinkHook.fetch_where(channelid=channelids)
for channelid in channelids:
# Will hit cache, so don't need any more data queries
await self.fetch_webhook_for(channelid)
self.channel_links = {cid: tuple(linkids) for cid, linkids in channel_links.items()}
self.link_channels = {lid: tuple(cids) for lid, cids in link_channels.items()}
logger.info(
f"Loaded '{len(link_channels)}' channel links with '{len(self.channel_links)}' linked channels."
)
@LionCog.listener('on_message')
async def on_message(self, message: discord.Message):
# Don't need this because everything except explicit messages are webhooks now
# if self.bot.user and (message.author.id == self.bot.user.id):
# return
if message.webhook_id:
return
async with self.lock:
sent = []
linkids = self.channel_links.get(message.channel.id, ())
if linkids:
for linkid in linkids:
for channelid in self.link_channels[linkid]:
if channelid != message.channel.id:
if message.attachments:
files = await prepare_attachments(message.attachments)
else:
files = []
hook = self.hooks[channelid]
avatar = message.author.avatar or message.author.default_avatar
msg = await hook.send(
content=message.content,
wait=True,
username=message.author.display_name,
avatar_url=avatar.url,
embeds=await prepare_embeds(message),
files=files,
allowed_mentions=discord.AllowedMentions.none()
)
sent.append((channelid, msg))
self.wmessages[msg.id] = message.id
if sent:
# For easier lookup
self.wmessages[message.id] = message.id
sent.append((message.channel.id, message))
self.message_cache[message.id] = sent
logger.info(f"Forwarded message {message.id}")
@LionCog.listener('on_message_edit')
async def on_message_edit(self, before, after):
async with self.lock:
cached_sent = self.message_cache.pop(before.id, ())
new_sent = []
for cid, msg in cached_sent:
try:
if msg.id != before.id:
msg = await msg.edit(
content=after.content,
embeds=await prepare_embeds(after),
)
new_sent.append((cid, msg))
except discord.NotFound:
pass
if new_sent:
self.message_cache[after.id] = new_sent
@LionCog.listener('on_message_delete')
async def on_message_delete(self, message):
async with self.lock:
origid = self.wmessages.get(message.id, None)
if origid:
cached_sent = self.message_cache.pop(origid, ())
for _, msg in cached_sent:
try:
if msg.id != message.id:
await msg.delete()
except discord.NotFound:
pass
@LionCog.listener('on_reaction_add')
async def on_reaction_add(self, reaction: discord.Reaction, user: discord.User):
async with self.lock:
message = reaction.message
emoji = reaction.emoji
origid = self.wmessages.get(message.id, None)
if origid and reaction.count == 1:
cached_sent = self.message_cache.get(origid, ())
for _, msg in cached_sent:
# TODO: Would be better to have a Message and check the reactions
try:
if msg.id != message.id:
await msg.add_reaction(emoji)
except discord.HTTPException:
pass
async def fetch_webhook_for(self, channelid) -> discord.Webhook:
hook = self.hooks.get(channelid, None)
if hook is None:
row = await self.data.LinkHook.fetch(channelid)
if row is None:
channel = self.bot.get_channel(channelid)
if channel is None:
raise ValueError("Cannot find channel to create hook.")
hook = await channel.create_webhook(name="LabRat Channel Link")
await self.data.LinkHook.create(
channelid=channelid,
webhookid=hook.id,
token=hook.token,
)
else:
hook = discord.Webhook.partial(row.webhookid, row.token, client=self.bot)
self.hooks[channelid] = hook
return hook
@cmds.hybrid_group(
name='linker',
description="Base command group for the channel linker"
)
@appcmds.default_permissions(manage_channels=True)
async def linker_group(self, ctx: LionContext):
...
@linker_group.command(
name='link',
description="Create a new link, or add a channel to an existing link."
)
@appcmds.describe(
name="Name of the new or existing channel link.",
channel1="First channel to add to the link.",
channel2="Second channel to add to the link.",
channel3="Third channel to add to the link.",
channel4="Fourth channel to add to the link.",
channel5="Fifth channel to add to the link.",
channelid="Optionally add a channel by id (for e.g. cross-server links).",
)
async def linker_link(self, ctx: LionContext,
name: str,
channel1: Optional[discord.TextChannel | discord.VoiceChannel] = None,
channel2: Optional[discord.TextChannel | discord.VoiceChannel] = None,
channel3: Optional[discord.TextChannel | discord.VoiceChannel] = None,
channel4: Optional[discord.TextChannel | discord.VoiceChannel] = None,
channel5: Optional[discord.TextChannel | discord.VoiceChannel] = None,
channelid: Optional[str] = None,
):
if not ctx.interaction:
return
await ctx.interaction.response.defer(thinking=True)
# Check if link 'name' already exists, create if not
existing = await self.data.Link.fetch_where()
link_row = next((row for row in existing if row.name.lower() == name.lower()), None)
if link_row is None:
# Create
link_row = await self.data.Link.create(name=name)
link_channels = set()
created = True
else:
records = await self.data.channel_links.select_where(linkid=link_row.linkid)
link_channels = {record['channelid'] for record in records}
created = False
# Create webhooks and webhook rows on channels if required
maybe_channels = [
channel1, channel2, channel3, channel4, channel5,
]
if channelid and channelid.isdigit():
channel = self.bot.get_channel(int(channelid))
maybe_channels.append(channel)
channels = [channel for channel in maybe_channels if channel]
for channel in channels:
await self.fetch_webhook_for(channel.id)
# Insert or update the links
for channel in channels:
if channel.id not in link_channels:
await self.data.channel_links.insert(linkid=link_row.linkid, channelid=channel.id)
await self.reload_links()
if created:
embed = discord.Embed(
colour=discord.Colour.brand_green(),
title="Link Created",
description=(
"Created the link **{name}** and linked channels:\n{channels}"
).format(name=name, channels=', '.join(channel.mention for channel in channels))
)
else:
channelids = self.link_channels[link_row.linkid]
channelstr = ', '.join(f"<#{cid}>" for cid in channelids)
embed = discord.Embed(
colour=discord.Colour.brand_green(),
title="Channels Linked",
description=(
"Updated the link **{name}** to link the following channels:\n{channelstr}"
).format(name=link_row.name, channelstr=channelstr)
)
await ctx.reply(embed=embed)
@linker_group.command(
name='unlink',
description="Destroy a link, or remove a channel from a link."
)
@appcmds.describe(
name="Name of the link to destroy",
channel="Channel to remove from the link.",
)
async def linker_unlink(self, ctx: LionContext,
name: str, channel: Optional[GuildChannel] = None):
if not ctx.interaction:
return
# Get the link, error if it doesn't exist
existing = await self.data.Link.fetch_where()
link_row = next((row for row in existing if row.name.lower() == name.lower()), None)
if link_row is None:
raise UserInputError(
f"Link **{name}** doesn't exist!"
)
link_channelids = self.link_channels.get(link_row.linkid, ())
if channel is not None:
# If channel was given, remove channel from link and ack
if channel.id not in link_channelids:
raise UserInputError(
f"{channel.mention} is not linked in **{link_row.name}**!"
)
await self.data.channel_links.delete_where(channelid=channel.id, linkid=link_row.linkid)
embed = discord.Embed(
colour=discord.Colour.brand_green(),
title="Channel Unlinked",
description=f"{channel.mention} has been removed from **{link_row.name}**."
)
else:
# Otherwise, confirm link destroy, delete link row, and ack
channels = ', '.join(f"<#{cid}>" for cid in link_channelids)
confirm = Confirm(
f"Are you sure you want to remove the link **{link_row.name}**?\nLinked channels: {channels}",
ctx.author.id,
)
confirm.embed.colour = discord.Colour.red()
try:
result = await confirm.ask(ctx.interaction)
except ResponseTimedOut:
result = False
if not result:
raise SafeCancellation
embed = discord.Embed(
colour=discord.Colour.brand_green(),
title="Link removed",
description=f"Link **{link_row.name}** removed, the following channels were unlinked:\n{channels}"
)
await link_row.delete()
await self.reload_links()
await ctx.reply(embed=embed)
@linker_link.autocomplete('name')
async def _acmpl_link_name(self, interaction: discord.Interaction, partial: str):
"""
Autocomplete an existing link.
"""
existing = await self.data.Link.fetch_where()
names = [row.name for row in existing]
matching = [row.name for row in existing if partial.lower() in row.name.lower()]
if not matching:
choice = appcmds.Choice(
name=f"Create a new link '{partial}'",
value=partial
)
choices = [choice]
else:
choices = [
appcmds.Choice(
name=f"Link {name}",
value=name
)
for name in matching
]
return choices
@linker_unlink.autocomplete('name')
async def _acmpl_unlink_name(self, interaction: discord.Interaction, partial: str):
"""
Autocomplete an existing link.
"""
existing = await self.data.Link.fetch_where()
matching = [row.name for row in existing if partial.lower() in row.name.lower()]
if not matching:
choice = appcmds.Choice(
name=f"No existing links matching '{partial}'",
value=partial
)
choices = [choice]
else:
choices = [
appcmds.Choice(
name=f"Link {name}",
value=name
)
for name in matching
]
return choices
@linker_group.command(
name='links',
description="Display the existing channel links."
)
async def linker_links(self, ctx: LionContext):
if not ctx.interaction:
return
await ctx.interaction.response.defer(thinking=True)
links = await self.data.Link.fetch_where()
if not links:
embed = discord.Embed(
colour=discord.Colour.light_grey(),
title="No channel links have been set up!",
description="Create a new link and add channels with {linker}".format(
linker=self.bot.core.mention_cmd('linker link')
)
)
else:
embed = discord.Embed(
colour=discord.Colour.brand_green(),
title=f"Channel Links in {ctx.guild.name}",
)
for link in links:
channelids = self.link_channels.get(link.linkid, ())
channelstr = ', '.join(f"<#{cid}>" for cid in channelids)
embed.add_field(
name=f"Link **{link.name}**",
value=channelstr,
inline=False
)
# TODO: May want paging if over 25 links....
await ctx.reply(embed=embed)
@linker_group.command(
name="webhook",
description='Manually configure the webhook for a given channel.'
)
async def linker_webhook(self, ctx: LionContext, channel: discord.abc.GuildChannel, webhook: str):
if not ctx.interaction:
return
hook = discord.Webhook.from_url(webhook, client=self.bot)
existing = await self.data.LinkHook.fetch(channel.id)
if existing:
await existing.update(webhookid=hook.id, token=hook.token)
else:
await self.data.LinkHook.create(
channelid=channel.id,
webhookid=hook.id,
token=hook.token,
)
self.hooks[channel.id] = hook
await ctx.reply(f"Webhook for {channel.mention} updated!")

View File

@@ -1,39 +0,0 @@
from data import Registry, RowModel, Table
from data.columns import Integer, Bool, Timestamp, String
class LinkData(Registry):
class Link(RowModel):
"""
Schema
------
CREATE TABLE links(
linkid SERIAL PRIMARY KEY,
name TEXT
);
"""
_tablename_ = 'links'
_cache_ = {}
linkid = Integer(primary=True)
name = String()
channel_links = Table('channel_links')
class LinkHook(RowModel):
"""
Schema
------
CREATE TABLE channel_webhooks(
channelid BIGINT PRIMARY KEY,
webhookid BIGINT NOT NULL,
token TEXT NOT NULL
);
"""
_tablename_ = 'channel_webhooks'
_cache_ = {}
channelid = Integer(primary=True)
webhookid = Integer()
token = String()

View File

@@ -54,7 +54,10 @@ class MsgEditor(MessageUI):
By default, uses the provided `formatter` callback (if provided).
"""
if self._formatter is not None:
await self._formatter(data)
return await self._formatter(data)
else:
return data
def copy_data(self):
return copy.deepcopy(self.history[-1])
@@ -78,7 +81,8 @@ class MsgEditor(MessageUI):
if 'embed' in new_data:
try:
discord.Embed.from_dict(new_data['embed'])
formatted_data = copy.deepcopy(new_data)
discord.Embed.from_dict(await self.format_data(formatted_data['embed']))
except Exception as e:
raise UserInputError(
t(_p(
@@ -445,8 +449,17 @@ class MsgEditor(MessageUI):
embed_data.pop('footer', None)
if (ts := timestamp_field.value):
if ts.isdigit():
# Treat as UTC timestamp
timestamp = dt.datetime.fromtimestamp(int(ts), dt.timezone.utc)
ts = timestamp.isoformat()
to_validate = ts
elif self._formatter:
to_validate = await self._formatter(ts)
else:
to_validate = ts
try:
dt.datetime.fromisoformat(ts)
dt.datetime.fromisoformat(to_validate)
except ValueError:
raise UserInputError(
t(_p(