Merge branch 'plugin-refactor'
This commit is contained in:
9
.gitmodules
vendored
Normal file
9
.gitmodules
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
[submodule "src/modules/profiles"]
|
||||
path = src/modules/profiles
|
||||
url = https://git.thewisewolf.dev/HoloTech/profiles-plugin.git
|
||||
[submodule "src/data"]
|
||||
path = src/data
|
||||
url = https://git.thewisewolf.dev/HoloTech/psqlmapper.git
|
||||
[submodule "src/modules/koans"]
|
||||
path = src/modules/koans
|
||||
url = https://git.thewisewolf.dev/Foxfire/sideeyes-koans-plugin.git
|
||||
@@ -1,10 +1,12 @@
|
||||
-- Metadata {{{
|
||||
CREATE TABLE VersionHistory(
|
||||
version INTEGER NOT NULL,
|
||||
time TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
author TEXT
|
||||
CREATE TABLE version_history(
|
||||
component TEXT NOT NULL,
|
||||
from_version INTEGER NOT NULL,
|
||||
to_version INTEGER NOT NULL,
|
||||
author TEXT NOT NULL,
|
||||
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
INSERT INTO VersionHistory (version, author) VALUES (1, 'Initial Creation');
|
||||
INSERT INTO version_history (component, from_version, to_version, author) VALUES ('ROOT', 0, 1, 'Initial Creation');
|
||||
|
||||
|
||||
CREATE OR REPLACE FUNCTION update_timestamp_column()
|
||||
@@ -14,9 +16,6 @@ BEGIN
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ language 'plpgsql';
|
||||
-- }}}
|
||||
|
||||
-- App metadata {{{
|
||||
|
||||
CREATE TABLE app_config(
|
||||
appname TEXT PRIMARY KEY,
|
||||
@@ -26,6 +25,7 @@ CREATE TABLE app_config(
|
||||
-- }}}
|
||||
|
||||
-- Twitch Auth {{{
|
||||
INSERT INTO version_history (component, from_version, to_version, author) VALUES ('TWITCH_AUTH', 0, 1, 'Initial Creation');
|
||||
|
||||
-- Authorisation tokens allowing us to take actions on behalf of certain users or channels.
|
||||
-- For example, channels we have joined will need to be authorised with a 'channel:bot' scope.
|
||||
@@ -49,7 +49,6 @@ CREATE TABLE user_auth_scopes(
|
||||
CREATE TABLE bot_channels(
|
||||
userid TEXT PRIMARY KEY REFERENCES user_auth(userid) ON DELETE CASCADE,
|
||||
autojoin BOOLEAN DEFAULT true,
|
||||
listen_redeems BOOLEAN,
|
||||
joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
@@ -59,58 +58,6 @@ CREATE TRIGGER bot_channels_timestamp BEFORE UPDATE ON bot_channels
|
||||
|
||||
-- }}}
|
||||
|
||||
-- Twitch user data {{{
|
||||
|
||||
---- Users are internally represented by 'profiles' with a unique profileid
|
||||
---- This is associated to the user's twitch userid.
|
||||
---- Any user-specific configuration data or preferences can be added here
|
||||
CREATE TABLE user_profiles(
|
||||
profileid INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
|
||||
twitchid TEXT NOT NULL,
|
||||
name TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE TRIGGER user_profiles_timestamp BEFORE UPDATE ON user_profiles
|
||||
FOR EACH ROW EXECUTE FUNCTION update_timestamp_column();
|
||||
|
||||
|
||||
CREATE UNIQUE INDEX user_profile_twitchid ON user_profiles (twitchid);
|
||||
|
||||
-- }}}
|
||||
|
||||
-- Twitch channel data {{{
|
||||
|
||||
---- Similar to user profiles, we associate twitch channels with 'communities'
|
||||
---- This slight abstraction gives us more flexibility and control over the community and user data.
|
||||
CREATE TABLE communities(
|
||||
communityid INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
|
||||
twitchid TEXT NOT NULL,
|
||||
name TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE TRIGGER communities_timestamp BEFORE UPDATE ON communities
|
||||
FOR EACH ROW EXECUTE FUNCTION update_timestamp_column();
|
||||
|
||||
-- }}}
|
||||
|
||||
|
||||
-- Koan data {{{
|
||||
|
||||
---- !koans lists koans. !koan gives a random koan. !koans add name ... !koans del name ...
|
||||
|
||||
CREATE TABLE koans(
|
||||
koanid INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
|
||||
communityid INTEGER NOT NULL REFERENCES communities ON UPDATE CASCADE ON DELETE CASCADE,
|
||||
name TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE TRIGGER koans_timestamp BEFORE UPDATE ON koans
|
||||
FOR EACH ROW EXECUTE FUNCTION update_timestamp_column();
|
||||
|
||||
-- }}}
|
||||
|
||||
-- vim: set fdm=marker:
|
||||
|
||||
21
example-config/bot.conf
Normal file
21
example-config/bot.conf
Normal file
@@ -0,0 +1,21 @@
|
||||
[BOT]
|
||||
prefix = ?
|
||||
bot_id =
|
||||
|
||||
ALSO_READ = config/secrets.conf
|
||||
|
||||
[TWTICH]
|
||||
host =
|
||||
port =
|
||||
domain =
|
||||
redirect_path =
|
||||
oauth_path =
|
||||
evenstub_path =
|
||||
|
||||
webhooks =
|
||||
|
||||
[LOGGING]
|
||||
general_log =
|
||||
warning_log =
|
||||
error_log =
|
||||
critical_log =
|
||||
10
example-config/secrets.conf
Normal file
10
example-config/secrets.conf
Normal file
@@ -0,0 +1,10 @@
|
||||
[BOT]
|
||||
client_id =
|
||||
client_secret =
|
||||
|
||||
[TWITCH]
|
||||
eventsub_secret =
|
||||
|
||||
[DATA]
|
||||
args =
|
||||
appid =
|
||||
@@ -2,3 +2,4 @@ websockets
|
||||
twitchio
|
||||
psycopg[pool]
|
||||
cachetools
|
||||
discord.py
|
||||
|
||||
51
src/bot.py
51
src/bot.py
@@ -1,44 +1,57 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import websockets
|
||||
|
||||
from twitchio.web import AiohttpAdapter
|
||||
|
||||
from meta import CrocBot, conf, setup_main_logger, args
|
||||
from meta import Bot, conf, setup_main_logger, args, sockets
|
||||
from data import Database
|
||||
from constants import DATA_VERSION
|
||||
|
||||
from modules import setup
|
||||
from modules import twitch_setup
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProxyAiohttpAdapter(AiohttpAdapter):
|
||||
"""
|
||||
Overrides the computed AiohttpAdapter redirect_url
|
||||
to always use provided domain.
|
||||
"""
|
||||
def _find_redirect(self, request):
|
||||
return self.redirect_url
|
||||
|
||||
|
||||
async def main():
|
||||
db = Database(conf.data['args'])
|
||||
|
||||
async with db.open():
|
||||
version = await db.version()
|
||||
if version.version != DATA_VERSION:
|
||||
error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate."
|
||||
logger.critical(error)
|
||||
raise RuntimeError(error)
|
||||
|
||||
adapter = AiohttpAdapter(
|
||||
host=conf.bot.get('wshost', None),
|
||||
port=conf.bot.getint('wsport', None),
|
||||
domain=conf.bot.get('wsdomain', None),
|
||||
adapter_keys = (
|
||||
'host', 'domain', 'port',
|
||||
'redirect_path', 'oauth_path', 'eventsub_path',
|
||||
'eventsub_secret',
|
||||
)
|
||||
adapter_args = {}
|
||||
for key in adapter_keys:
|
||||
value = conf.twitch.get(key, '').strip()
|
||||
if value:
|
||||
if key == 'port':
|
||||
value = int(value)
|
||||
adapter_args[key] = value
|
||||
adapter = ProxyAiohttpAdapter(**adapter_args)
|
||||
|
||||
bot = CrocBot(
|
||||
bot = Bot(
|
||||
config=conf,
|
||||
dbconn=db,
|
||||
adapter=adapter,
|
||||
setup=setup,
|
||||
setup=twitch_setup,
|
||||
using_webhooks=conf.twitch.getboolean('webhooks', False)
|
||||
)
|
||||
|
||||
try:
|
||||
await bot.start()
|
||||
finally:
|
||||
await bot.close()
|
||||
async with websockets.serve(sockets.root_handler, '', conf.wserver.getint('port')):
|
||||
try:
|
||||
await bot.start()
|
||||
finally:
|
||||
await bot.close()
|
||||
|
||||
|
||||
def _main():
|
||||
|
||||
82
src/botdata.py
Normal file
82
src/botdata.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from data import Registry, RowModel, Table
|
||||
from data.columns import String, Timestamp, Integer, Bool
|
||||
|
||||
|
||||
class UserAuth(RowModel):
|
||||
"""
|
||||
Schema
|
||||
======
|
||||
CREATE TABLE user_auth(
|
||||
userid TEXT PRIMARY KEY,
|
||||
token TEXT NOT NULL,
|
||||
refresh_token TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'user_auth'
|
||||
_cache_ = {}
|
||||
|
||||
userid = String(primary=True)
|
||||
token = String()
|
||||
refresh_token = String()
|
||||
created_at = Timestamp()
|
||||
_timestamp = Timestamp()
|
||||
|
||||
|
||||
class BotChannel(RowModel):
|
||||
"""
|
||||
Schema
|
||||
======
|
||||
CREATE TABLE bot_channels(
|
||||
userid TEXT PRIMARY KEY REFERENCES user_auth(userid) ON DELETE CASCADE,
|
||||
autojoin BOOLEAN DEFAULT true,
|
||||
listen_redeems BOOLEAN,
|
||||
joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'bot_channels'
|
||||
_cache_ = {}
|
||||
|
||||
userid = String(primary=True)
|
||||
autojoin = Bool()
|
||||
listen_redeems = Bool()
|
||||
joined_at = Timestamp()
|
||||
_timestamp = Timestamp()
|
||||
|
||||
|
||||
class VersionHistory(RowModel):
|
||||
"""
|
||||
CREATE TABLE version_history(
|
||||
component TEXT NOT NULL,
|
||||
from_version INTEGER NOT NULL,
|
||||
to_version INTEGER NOT NULL,
|
||||
author TEXT NOT NULL,
|
||||
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'version_history'
|
||||
_cache_ = {}
|
||||
|
||||
component = String()
|
||||
from_version = Integer()
|
||||
to_version = Integer()
|
||||
author = String()
|
||||
_timestamp = Timestamp()
|
||||
|
||||
|
||||
class BotData(Registry):
|
||||
version_history = VersionHistory.table
|
||||
|
||||
user_auth = UserAuth.table
|
||||
bot_channels = BotChannel.table
|
||||
|
||||
"""
|
||||
CREATE TABLE user_auth_scopes(
|
||||
userid TEXT NOT NULL REFERENCES user_auth(userid) ON DELETE CASCADE,
|
||||
scope TEXT NOT NULL
|
||||
);
|
||||
"""
|
||||
user_auth_scopes = Table('user_auth_scopes')
|
||||
|
||||
@@ -1,2 +1,22 @@
|
||||
from twitchio import Scopes
|
||||
|
||||
|
||||
CONFIG_FILE = 'config/bot.conf'
|
||||
DATA_VERSION = 1
|
||||
|
||||
SCHEMA_VERSIONS = {
|
||||
'ROOT': 1,
|
||||
'TWITCH_AUTH': 1
|
||||
}
|
||||
|
||||
# Requested scopes for the bots own twitch user
|
||||
BOTUSER_SCOPES = Scopes((
|
||||
Scopes.user_read_chat,
|
||||
Scopes.user_write_chat,
|
||||
Scopes.user_bot,
|
||||
Scopes.channel_bot,
|
||||
))
|
||||
|
||||
# Default requested scopes for joining a channel
|
||||
CHANNEL_SCOPES = Scopes((
|
||||
Scopes.channel_bot,
|
||||
))
|
||||
|
||||
1
src/data
Submodule
1
src/data
Submodule
Submodule src/data added at c495b4d097
@@ -1,9 +0,0 @@
|
||||
from .conditions import Condition, condition, NULL
|
||||
from .database import Database
|
||||
from .models import RowModel, RowTable, WeakCache
|
||||
from .table import Table
|
||||
from .base import Expression, RawExpr
|
||||
from .columns import ColumnExpr, Column, Integer, String
|
||||
from .registry import Registry, AttachableClass, Attachable
|
||||
from .adapted import RegisterEnum
|
||||
from .queries import ORDER, NULLS, JOINTYPE
|
||||
@@ -1,40 +0,0 @@
|
||||
# from enum import Enum
|
||||
from typing import Optional
|
||||
from psycopg.types.enum import register_enum, EnumInfo
|
||||
from psycopg import AsyncConnection
|
||||
from .registry import Attachable, Registry
|
||||
|
||||
|
||||
class RegisterEnum(Attachable):
|
||||
def __init__(self, enum, name: Optional[str] = None, mapper=None):
|
||||
super().__init__()
|
||||
self.enum = enum
|
||||
self.name = name or enum.__name__
|
||||
self.mapping = mapper(enum) if mapper is not None else self._mapper()
|
||||
|
||||
def _mapper(self):
|
||||
return {m: m.value[0] for m in self.enum}
|
||||
|
||||
def attach_to(self, registry: Registry):
|
||||
self._registry = registry
|
||||
registry.init_task(self.on_init)
|
||||
return self
|
||||
|
||||
async def on_init(self, registry: Registry):
|
||||
connector = registry._conn
|
||||
if connector is None:
|
||||
raise ValueError("Cannot initialise without connector!")
|
||||
connector.connect_hook(self.connection_hook)
|
||||
# await connector.refresh_pool()
|
||||
# The below may be somewhat dangerous
|
||||
# But adaption should never write to the database
|
||||
await connector.map_over_pool(self.connection_hook)
|
||||
# if conn := connector.conn:
|
||||
# # Ensure the adaption is run in the current context as well
|
||||
# await self.connection_hook(conn)
|
||||
|
||||
async def connection_hook(self, conn: AsyncConnection):
|
||||
info = await EnumInfo.fetch(conn, self.name)
|
||||
if info is None:
|
||||
raise ValueError(f"Enum {self.name} not found in database.")
|
||||
register_enum(info, conn, self.enum, mapping=list(self.mapping.items()))
|
||||
@@ -1,45 +0,0 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
from itertools import chain
|
||||
from psycopg import sql
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Expression(Protocol):
|
||||
__slots__ = ()
|
||||
|
||||
@abstractmethod
|
||||
def as_tuple(self) -> tuple[sql.Composable, tuple[Any, ...]]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RawExpr(Expression):
|
||||
__slots__ = ('expr', 'values')
|
||||
|
||||
expr: sql.Composable
|
||||
values: tuple[Any, ...]
|
||||
|
||||
def __init__(self, expr: sql.Composable, values: tuple[Any, ...] = ()):
|
||||
self.expr = expr
|
||||
self.values = values
|
||||
|
||||
def as_tuple(self):
|
||||
return (self.expr, self.values)
|
||||
|
||||
@classmethod
|
||||
def join(cls, *expressions: Expression, joiner: sql.SQL = sql.SQL(' ')):
|
||||
"""
|
||||
Join a sequence of Expressions into a single RawExpr.
|
||||
"""
|
||||
tups = (
|
||||
expression.as_tuple()
|
||||
for expression in expressions
|
||||
)
|
||||
return cls.join_tuples(*tups, joiner=joiner)
|
||||
|
||||
@classmethod
|
||||
def join_tuples(cls, *tuples: tuple[sql.Composable, tuple[Any, ...]], joiner: sql.SQL = sql.SQL(' ')):
|
||||
exprs, values = zip(*tuples)
|
||||
expr = joiner.join(exprs)
|
||||
value = tuple(chain(*values))
|
||||
return cls(expr, value)
|
||||
@@ -1,155 +0,0 @@
|
||||
from typing import Any, Union, TypeVar, Generic, Type, overload, Optional, TYPE_CHECKING
|
||||
from psycopg import sql
|
||||
from datetime import datetime
|
||||
|
||||
from .base import RawExpr, Expression
|
||||
from .conditions import Condition, Joiner
|
||||
from .table import Table
|
||||
|
||||
|
||||
class ColumnExpr(RawExpr):
|
||||
__slots__ = ()
|
||||
|
||||
def __lt__(self, obj) -> Condition:
|
||||
expr, values = self.as_tuple()
|
||||
|
||||
if isinstance(obj, Expression):
|
||||
# column < Expression
|
||||
obj_expr, obj_values = obj.as_tuple()
|
||||
cond_exprs = (expr, Joiner.LT, obj_expr)
|
||||
cond_values = (*values, *obj_values)
|
||||
else:
|
||||
# column < Literal
|
||||
cond_exprs = (expr, Joiner.LT, sql.Placeholder())
|
||||
cond_values = (*values, obj)
|
||||
|
||||
return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
|
||||
|
||||
def __le__(self, obj) -> Condition:
|
||||
expr, values = self.as_tuple()
|
||||
|
||||
if isinstance(obj, Expression):
|
||||
# column <= Expression
|
||||
obj_expr, obj_values = obj.as_tuple()
|
||||
cond_exprs = (expr, Joiner.LE, obj_expr)
|
||||
cond_values = (*values, *obj_values)
|
||||
else:
|
||||
# column <= Literal
|
||||
cond_exprs = (expr, Joiner.LE, sql.Placeholder())
|
||||
cond_values = (*values, obj)
|
||||
|
||||
return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
|
||||
|
||||
def __eq__(self, obj) -> Condition: # type: ignore[override]
|
||||
return Condition._expression_equality(self, obj)
|
||||
|
||||
def __ne__(self, obj) -> Condition: # type: ignore[override]
|
||||
return ~(self.__eq__(obj))
|
||||
|
||||
def __gt__(self, obj) -> Condition:
|
||||
return ~(self.__le__(obj))
|
||||
|
||||
def __ge__(self, obj) -> Condition:
|
||||
return ~(self.__lt__(obj))
|
||||
|
||||
def __add__(self, obj: Union[Any, Expression]) -> 'ColumnExpr':
|
||||
if isinstance(obj, Expression):
|
||||
obj_expr, obj_values = obj.as_tuple()
|
||||
return ColumnExpr(
|
||||
sql.SQL("({} + {})").format(self.expr, obj_expr),
|
||||
(*self.values, *obj_values)
|
||||
)
|
||||
else:
|
||||
return ColumnExpr(
|
||||
sql.SQL("({} + {})").format(self.expr, sql.Placeholder()),
|
||||
(*self.values, obj)
|
||||
)
|
||||
|
||||
def __sub__(self, obj) -> 'ColumnExpr':
|
||||
if isinstance(obj, Expression):
|
||||
obj_expr, obj_values = obj.as_tuple()
|
||||
return ColumnExpr(
|
||||
sql.SQL("({} - {})").format(self.expr, obj_expr),
|
||||
(*self.values, *obj_values)
|
||||
)
|
||||
else:
|
||||
return ColumnExpr(
|
||||
sql.SQL("({} - {})").format(self.expr, sql.Placeholder()),
|
||||
(*self.values, obj)
|
||||
)
|
||||
|
||||
def __mul__(self, obj) -> 'ColumnExpr':
|
||||
if isinstance(obj, Expression):
|
||||
obj_expr, obj_values = obj.as_tuple()
|
||||
return ColumnExpr(
|
||||
sql.SQL("({} * {})").format(self.expr, obj_expr),
|
||||
(*self.values, *obj_values)
|
||||
)
|
||||
else:
|
||||
return ColumnExpr(
|
||||
sql.SQL("({} * {})").format(self.expr, sql.Placeholder()),
|
||||
(*self.values, obj)
|
||||
)
|
||||
|
||||
def CAST(self, target_type: sql.Composable):
|
||||
return ColumnExpr(
|
||||
sql.SQL("({}::{})").format(self.expr, target_type),
|
||||
self.values
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .models import RowModel
|
||||
|
||||
|
||||
class Column(ColumnExpr, Generic[T]):
|
||||
def __init__(self, name: Optional[str] = None,
|
||||
primary: bool = False, references: Optional['Column'] = None,
|
||||
type: Optional[Type[T]] = None):
|
||||
self.primary = primary
|
||||
self.references = references
|
||||
self.name: str = name # type: ignore
|
||||
self.owner: Optional['RowModel'] = None
|
||||
self._type = type
|
||||
|
||||
self.expr = sql.Identifier(name) if name else sql.SQL('')
|
||||
self.values = ()
|
||||
|
||||
def __set_name__(self, owner, name):
|
||||
# Only allow setting the owner once
|
||||
self.name = self.name or name
|
||||
self.owner = owner
|
||||
self.expr = sql.Identifier(self.owner._schema_, self.owner._tablename_, self.name)
|
||||
|
||||
@overload
|
||||
def __get__(self: 'Column[T]', obj: None, objtype: "None | Type['RowModel']") -> 'Column[T]':
|
||||
...
|
||||
|
||||
@overload
|
||||
def __get__(self: 'Column[T]', obj: 'RowModel', objtype: Type['RowModel']) -> T:
|
||||
...
|
||||
|
||||
def __get__(self: 'Column[T]', obj: "RowModel | None", objtype: "Type[RowModel] | None" = None) -> "T | Column[T]":
|
||||
# Get value from row data or session
|
||||
if obj is None:
|
||||
return self
|
||||
else:
|
||||
return obj.data[self.name]
|
||||
|
||||
|
||||
class Integer(Column[int]):
|
||||
pass
|
||||
|
||||
|
||||
class String(Column[str]):
|
||||
pass
|
||||
|
||||
|
||||
class Bool(Column[bool]):
|
||||
pass
|
||||
|
||||
|
||||
class Timestamp(Column[datetime]):
|
||||
pass
|
||||
@@ -1,214 +0,0 @@
|
||||
# from meta import sharding
|
||||
from typing import Any, Union
|
||||
from enum import Enum
|
||||
from itertools import chain
|
||||
from psycopg import sql
|
||||
|
||||
from .base import Expression, RawExpr
|
||||
|
||||
|
||||
"""
|
||||
A Condition is a "logical" database expression, intended for use in Where statements.
|
||||
Conditions support bitwise logical operators ~, &, |, each producing another Condition.
|
||||
"""
|
||||
|
||||
NULL = None
|
||||
|
||||
|
||||
class Joiner(Enum):
|
||||
EQUALS = ('=', '!=')
|
||||
IS = ('IS', 'IS NOT')
|
||||
LIKE = ('LIKE', 'NOT LIKE')
|
||||
BETWEEN = ('BETWEEN', 'NOT BETWEEN')
|
||||
IN = ('IN', 'NOT IN')
|
||||
LT = ('<', '>=')
|
||||
LE = ('<=', '>')
|
||||
NONE = ('', '')
|
||||
|
||||
|
||||
class Condition(Expression):
|
||||
__slots__ = ('expr1', 'joiner', 'negated', 'expr2', 'values')
|
||||
|
||||
def __init__(self,
|
||||
expr1: sql.Composable, joiner: Joiner = Joiner.NONE, expr2: sql.Composable = sql.SQL(''),
|
||||
values: tuple[Any, ...] = (), negated=False
|
||||
):
|
||||
self.expr1 = expr1
|
||||
self.joiner = joiner
|
||||
self.negated = negated
|
||||
self.expr2 = expr2
|
||||
self.values = values
|
||||
|
||||
def as_tuple(self):
|
||||
expr = sql.SQL(' ').join((self.expr1, sql.SQL(self.joiner.value[self.negated]), self.expr2))
|
||||
if self.negated and self.joiner is Joiner.NONE:
|
||||
expr = sql.SQL("NOT ({})").format(expr)
|
||||
return (expr, self.values)
|
||||
|
||||
@classmethod
|
||||
def construct(cls, *conditions: 'Condition', **kwargs: Union[Any, Expression]):
|
||||
"""
|
||||
Construct a Condition from a sequence of Conditions,
|
||||
together with some explicit column conditions.
|
||||
"""
|
||||
# TODO: Consider adding a _table identifier here so we can identify implicit columns
|
||||
# Or just require subquery type conditions to always come from modelled tables.
|
||||
implicit_conditions = (
|
||||
cls._expression_equality(RawExpr(sql.Identifier(column)), value) for column, value in kwargs.items()
|
||||
)
|
||||
return cls._and(*conditions, *implicit_conditions)
|
||||
|
||||
@classmethod
|
||||
def _and(cls, *conditions: 'Condition'):
|
||||
if not len(conditions):
|
||||
raise ValueError("Cannot combine 0 Conditions")
|
||||
if len(conditions) == 1:
|
||||
return conditions[0]
|
||||
|
||||
exprs, values = zip(*(condition.as_tuple() for condition in conditions))
|
||||
cond_expr = sql.SQL(' AND ').join((sql.SQL('({})').format(expr) for expr in exprs))
|
||||
cond_values = tuple(chain(*values))
|
||||
|
||||
return Condition(cond_expr, values=cond_values)
|
||||
|
||||
@classmethod
|
||||
def _or(cls, *conditions: 'Condition'):
|
||||
if not len(conditions):
|
||||
raise ValueError("Cannot combine 0 Conditions")
|
||||
if len(conditions) == 1:
|
||||
return conditions[0]
|
||||
|
||||
exprs, values = zip(*(condition.as_tuple() for condition in conditions))
|
||||
cond_expr = sql.SQL(' OR ').join((sql.SQL('({})').format(expr) for expr in exprs))
|
||||
cond_values = tuple(chain(*values))
|
||||
|
||||
return Condition(cond_expr, values=cond_values)
|
||||
|
||||
@classmethod
|
||||
def _not(cls, condition: 'Condition'):
|
||||
condition.negated = not condition.negated
|
||||
return condition
|
||||
|
||||
@classmethod
|
||||
def _expression_equality(cls, column: Expression, value: Union[Any, Expression]) -> 'Condition':
|
||||
# TODO: Check if this supports sbqueries
|
||||
col_expr, col_values = column.as_tuple()
|
||||
|
||||
# TODO: Also support sql.SQL? For joins?
|
||||
if isinstance(value, Expression):
|
||||
# column = Expression
|
||||
value_expr, value_values = value.as_tuple()
|
||||
cond_exprs = (col_expr, Joiner.EQUALS, value_expr)
|
||||
cond_values = (*col_values, *value_values)
|
||||
elif isinstance(value, (tuple, list)):
|
||||
# column in (...)
|
||||
# TODO: Support expressions in value tuple?
|
||||
if not value:
|
||||
raise ValueError("Cannot create Condition from empty iterable!")
|
||||
value_expr = sql.SQL('({})').format(sql.SQL(',').join(sql.Placeholder() * len(value)))
|
||||
cond_exprs = (col_expr, Joiner.IN, value_expr)
|
||||
cond_values = (*col_values, *value)
|
||||
elif value is None:
|
||||
# column IS NULL
|
||||
cond_exprs = (col_expr, Joiner.IS, sql.NULL)
|
||||
cond_values = col_values
|
||||
else:
|
||||
# column = Literal
|
||||
cond_exprs = (col_expr, Joiner.EQUALS, sql.Placeholder())
|
||||
cond_values = (*col_values, value)
|
||||
|
||||
return cls(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)
|
||||
|
||||
def __invert__(self) -> 'Condition':
|
||||
self.negated = not self.negated
|
||||
return self
|
||||
|
||||
def __and__(self, condition: 'Condition') -> 'Condition':
|
||||
return self._and(self, condition)
|
||||
|
||||
def __or__(self, condition: 'Condition') -> 'Condition':
|
||||
return self._or(self, condition)
|
||||
|
||||
|
||||
# Helper method to simply condition construction
|
||||
def condition(*args, **kwargs) -> Condition:
|
||||
return Condition.construct(*args, **kwargs)
|
||||
|
||||
|
||||
# class NOT(Condition):
|
||||
# __slots__ = ('value',)
|
||||
#
|
||||
# def __init__(self, value):
|
||||
# self.value = value
|
||||
#
|
||||
# def apply(self, key, values, conditions):
|
||||
# item = self.value
|
||||
# if isinstance(item, (list, tuple)):
|
||||
# if item:
|
||||
# conditions.append("{} NOT IN ({})".format(key, ", ".join([_replace_char] * len(item))))
|
||||
# values.extend(item)
|
||||
# else:
|
||||
# raise ValueError("Cannot check an empty iterable!")
|
||||
# else:
|
||||
# conditions.append("{}!={}".format(key, _replace_char))
|
||||
# values.append(item)
|
||||
#
|
||||
#
|
||||
# class GEQ(Condition):
|
||||
# __slots__ = ('value',)
|
||||
#
|
||||
# def __init__(self, value):
|
||||
# self.value = value
|
||||
#
|
||||
# def apply(self, key, values, conditions):
|
||||
# item = self.value
|
||||
# if isinstance(item, (list, tuple)):
|
||||
# raise ValueError("Cannot apply GEQ condition to a list!")
|
||||
# else:
|
||||
# conditions.append("{} >= {}".format(key, _replace_char))
|
||||
# values.append(item)
|
||||
#
|
||||
#
|
||||
# class LEQ(Condition):
|
||||
# __slots__ = ('value',)
|
||||
#
|
||||
# def __init__(self, value):
|
||||
# self.value = value
|
||||
#
|
||||
# def apply(self, key, values, conditions):
|
||||
# item = self.value
|
||||
# if isinstance(item, (list, tuple)):
|
||||
# raise ValueError("Cannot apply LEQ condition to a list!")
|
||||
# else:
|
||||
# conditions.append("{} <= {}".format(key, _replace_char))
|
||||
# values.append(item)
|
||||
#
|
||||
#
|
||||
# class Constant(Condition):
|
||||
# __slots__ = ('value',)
|
||||
#
|
||||
# def __init__(self, value):
|
||||
# self.value = value
|
||||
#
|
||||
# def apply(self, key, values, conditions):
|
||||
# conditions.append("{} {}".format(key, self.value))
|
||||
#
|
||||
#
|
||||
# class SHARDID(Condition):
|
||||
# __slots__ = ('shardid', 'shard_count')
|
||||
#
|
||||
# def __init__(self, shardid, shard_count):
|
||||
# self.shardid = shardid
|
||||
# self.shard_count = shard_count
|
||||
#
|
||||
# def apply(self, key, values, conditions):
|
||||
# if self.shard_count > 1:
|
||||
# conditions.append("({} >> 22) %% {} = {}".format(key, self.shard_count, _replace_char))
|
||||
# values.append(self.shardid)
|
||||
#
|
||||
#
|
||||
# # THIS_SHARD = SHARDID(sharding.shard_number, sharding.shard_count)
|
||||
#
|
||||
#
|
||||
# NULL = Constant('IS NULL')
|
||||
# NOTNULL = Constant('IS NOT NULL')
|
||||
@@ -1,135 +0,0 @@
|
||||
from typing import Protocol, runtime_checkable, Callable, Awaitable, Optional
|
||||
import logging
|
||||
|
||||
from contextvars import ContextVar
|
||||
from contextlib import asynccontextmanager
|
||||
import psycopg as psq
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
from psycopg.pq import TransactionStatus
|
||||
|
||||
from .cursor import AsyncLoggingCursor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
row_factory = psq.rows.dict_row
|
||||
|
||||
ctx_connection: Optional[ContextVar[psq.AsyncConnection]] = ContextVar('connection', default=None)
|
||||
|
||||
|
||||
class Connector:
|
||||
cursor_factory = AsyncLoggingCursor
|
||||
|
||||
def __init__(self, conn_args):
|
||||
self._conn_args = conn_args
|
||||
self._conn_kwargs = dict(autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory)
|
||||
|
||||
self.pool = self.make_pool()
|
||||
|
||||
self.conn_hooks = []
|
||||
|
||||
@property
|
||||
def conn(self) -> Optional[psq.AsyncConnection]:
|
||||
"""
|
||||
Convenience property for the current context connection.
|
||||
"""
|
||||
return ctx_connection.get()
|
||||
|
||||
@conn.setter
|
||||
def conn(self, conn: psq.AsyncConnection):
|
||||
"""
|
||||
Set the contextual connection in the current context.
|
||||
Always do this in an isolated context!
|
||||
"""
|
||||
ctx_connection.set(conn)
|
||||
|
||||
def make_pool(self) -> AsyncConnectionPool:
|
||||
logger.info("Initialising connection pool.", extra={'action': "Pool Init"})
|
||||
return AsyncConnectionPool(
|
||||
self._conn_args,
|
||||
open=False,
|
||||
min_size=1,
|
||||
max_size=4,
|
||||
configure=self._setup_connection,
|
||||
kwargs=self._conn_kwargs
|
||||
)
|
||||
|
||||
async def refresh_pool(self):
|
||||
"""
|
||||
Refresh the pool.
|
||||
|
||||
The point of this is to invalidate any existing connections so that the connection set up is run again.
|
||||
Better ways should be sought (a way to
|
||||
"""
|
||||
logger.info("Pool refresh requested, closing and reopening.")
|
||||
old_pool = self.pool
|
||||
self.pool = self.make_pool()
|
||||
await self.pool.open()
|
||||
logger.info(f"Old pool statistics: {self.pool.get_stats()}")
|
||||
await old_pool.close()
|
||||
logger.info("Pool refresh complete.")
|
||||
|
||||
async def map_over_pool(self, callable):
|
||||
"""
|
||||
Dangerous method to call a method on each connection in the pool.
|
||||
|
||||
Utilises private methods of the AsyncConnectionPool.
|
||||
"""
|
||||
async with self.pool._lock:
|
||||
conns = list(self.pool._pool)
|
||||
while conns:
|
||||
conn = conns.pop()
|
||||
try:
|
||||
await callable(conn)
|
||||
except Exception:
|
||||
logger.exception(f"Mapped connection task failed. {callable.__name__}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def open(self):
|
||||
try:
|
||||
logger.info("Opening database pool.")
|
||||
await self.pool.open()
|
||||
yield
|
||||
finally:
|
||||
# May be a different pool!
|
||||
logger.info(f"Closing database pool. Pool statistics: {self.pool.get_stats()}")
|
||||
await self.pool.close()
|
||||
|
||||
@asynccontextmanager
|
||||
async def connection(self) -> psq.AsyncConnection:
|
||||
"""
|
||||
Asynchronous context manager to get and manage a connection.
|
||||
|
||||
If the context connection is set, uses this and does not manage the lifetime.
|
||||
Otherwise, requests a new connection from the pool and returns it when done.
|
||||
"""
|
||||
logger.debug("Database connection requested.", extra={'action': "Data Connect"})
|
||||
if (conn := self.conn):
|
||||
yield conn
|
||||
else:
|
||||
async with self.pool.connection() as conn:
|
||||
yield conn
|
||||
|
||||
async def _setup_connection(self, conn: psq.AsyncConnection):
|
||||
logger.debug("Initialising new connection.", extra={'action': "Conn Init"})
|
||||
for hook in self.conn_hooks:
|
||||
try:
|
||||
await hook(conn)
|
||||
except Exception:
|
||||
logger.exception("Exception encountered setting up new connection")
|
||||
return conn
|
||||
|
||||
def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]):
|
||||
"""
|
||||
Minimal decorator to register a coroutine to run on connect or reconnect.
|
||||
|
||||
Note that these are only run on connect and reconnect.
|
||||
If a hook is registered after connection, it will not be run.
|
||||
"""
|
||||
self.conn_hooks.append(coro)
|
||||
return coro
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Connectable(Protocol):
|
||||
def bind(self, connector: Connector):
|
||||
raise NotImplementedError
|
||||
@@ -1,42 +0,0 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from psycopg import AsyncCursor, sql
|
||||
from psycopg.abc import Query, Params
|
||||
from psycopg._encodings import conn_encoding
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AsyncLoggingCursor(AsyncCursor):
|
||||
def mogrify_query(self, query: Query):
|
||||
if isinstance(query, str):
|
||||
msg = query
|
||||
elif isinstance(query, (sql.SQL, sql.Composed)):
|
||||
msg = query.as_string(self)
|
||||
elif isinstance(query, bytes):
|
||||
msg = query.decode(conn_encoding(self._conn.pgconn), 'replace')
|
||||
else:
|
||||
msg = repr(query)
|
||||
return msg
|
||||
|
||||
async def execute(self, query: Query, params: Optional[Params] = None, **kwargs):
|
||||
if logging.DEBUG >= logger.getEffectiveLevel():
|
||||
msg = self.mogrify_query(query)
|
||||
logger.debug(
|
||||
"Executing query (%s) with values %s", msg, params,
|
||||
extra={'action': "Query Execute"}
|
||||
)
|
||||
try:
|
||||
return await super().execute(query, params=params, **kwargs)
|
||||
except Exception:
|
||||
msg = self.mogrify_query(query)
|
||||
logger.exception(
|
||||
"Exception during query execution. Query (%s) with parameters %s.",
|
||||
msg, params,
|
||||
extra={'action': "Query Execute"},
|
||||
stack_info=True
|
||||
)
|
||||
else:
|
||||
# TODO: Possibly log execution time
|
||||
pass
|
||||
@@ -1,47 +0,0 @@
|
||||
from typing import TypeVar
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
|
||||
# from .cursor import AsyncLoggingCursor
|
||||
from .registry import Registry
|
||||
from .connector import Connector
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Version = namedtuple('Version', ('version', 'time', 'author'))
|
||||
|
||||
T = TypeVar('T', bound=Registry)
|
||||
|
||||
|
||||
class Database(Connector):
|
||||
# cursor_factory = AsyncLoggingCursor
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.registries: dict[str, Registry] = {}
|
||||
|
||||
def load_registry(self, registry: T) -> T:
|
||||
logger.debug(
|
||||
f"Loading and binding registry '{registry.name}'.",
|
||||
extra={'action': f"Reg {registry.name}"}
|
||||
)
|
||||
registry.bind(self)
|
||||
self.registries[registry.name] = registry
|
||||
return registry
|
||||
|
||||
async def version(self) -> Version:
|
||||
"""
|
||||
Return the current schema version as a Version namedtuple.
|
||||
"""
|
||||
async with self.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
# Get last entry in version table, compare against desired version
|
||||
await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return Version(row['version'], row['time'], row['author'])
|
||||
else:
|
||||
# No versions in the database
|
||||
return Version(-1, None, None)
|
||||
@@ -1,323 +0,0 @@
|
||||
from typing import TypeVar, Type, Optional, Generic, Union
|
||||
# from typing_extensions import Self
|
||||
from weakref import WeakValueDictionary
|
||||
from collections.abc import MutableMapping
|
||||
|
||||
from psycopg.rows import DictRow
|
||||
|
||||
from .table import Table
|
||||
from .columns import Column
|
||||
from . import queries as q
|
||||
from .connector import Connector
|
||||
from .registry import Registry
|
||||
|
||||
|
||||
RowT = TypeVar('RowT', bound='RowModel')
|
||||
|
||||
|
||||
class MISSING:
|
||||
__slots__ = ('oid',)
|
||||
|
||||
def __init__(self, oid):
|
||||
self.oid = oid
|
||||
|
||||
|
||||
class RowTable(Table, Generic[RowT]):
|
||||
__slots__ = (
|
||||
'model',
|
||||
)
|
||||
|
||||
def __init__(self, name, model: Type[RowT], **kwargs):
|
||||
super().__init__(name, **kwargs)
|
||||
self.model = model
|
||||
|
||||
@property
|
||||
def columns(self):
|
||||
return self.model._columns_
|
||||
|
||||
@property
|
||||
def id_col(self):
|
||||
return self.model._key_
|
||||
|
||||
@property
|
||||
def row_cache(self):
|
||||
return self.model._cache_
|
||||
|
||||
def _many_query_adapter(self, *data):
|
||||
self.model._make_rows(*data)
|
||||
return data
|
||||
|
||||
def _single_query_adapter(self, *data):
|
||||
if data:
|
||||
self.model._make_rows(*data)
|
||||
return data[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
def _delete_query_adapter(self, *data):
|
||||
self.model._delete_rows(*data)
|
||||
return data
|
||||
|
||||
# New methods to fetch and create rows
|
||||
async def create_row(self, *args, **kwargs) -> RowT:
|
||||
data = await super().insert(*args, **kwargs)
|
||||
return self.model._make_rows(data)[0]
|
||||
|
||||
def fetch_rows_where(self, *args, **kwargs) -> q.Select[list[RowT]]:
|
||||
# TODO: Handle list of rowids here?
|
||||
return q.Select(
|
||||
self.identifier,
|
||||
row_adapter=self.model._make_rows,
|
||||
connector=self.connector
|
||||
).where(*args, **kwargs)
|
||||
|
||||
|
||||
WK = TypeVar('WK')
|
||||
WV = TypeVar('WV')
|
||||
|
||||
|
||||
class WeakCache(Generic[WK, WV], MutableMapping[WK, WV]):
|
||||
def __init__(self, ref_cache):
|
||||
self.ref_cache = ref_cache
|
||||
self.weak_cache = WeakValueDictionary()
|
||||
|
||||
def __getitem__(self, key):
|
||||
value = self.weak_cache[key]
|
||||
self.ref_cache[key] = value
|
||||
return value
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.weak_cache[key] = value
|
||||
self.ref_cache[key] = value
|
||||
|
||||
def __delitem__(self, key):
|
||||
del self.weak_cache[key]
|
||||
try:
|
||||
del self.ref_cache[key]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self.weak_cache
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.weak_cache)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.weak_cache)
|
||||
|
||||
def get(self, key, default=None):
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def pop(self, key, default=None):
|
||||
if key in self:
|
||||
value = self[key]
|
||||
del self[key]
|
||||
else:
|
||||
value = default
|
||||
return value
|
||||
|
||||
|
||||
# TODO: Implement getitem and setitem, for dynamic column access
|
||||
class RowModel:
|
||||
__slots__ = ('data',)
|
||||
|
||||
_schema_: str = 'public'
|
||||
_tablename_: Optional[str] = None
|
||||
_columns_: dict[str, Column] = {}
|
||||
|
||||
# Cache to keep track of registered Rows
|
||||
_cache_: Union[dict, WeakValueDictionary, WeakCache] = None # type: ignore
|
||||
|
||||
_key_: tuple[str, ...] = ()
|
||||
_connector: Optional[Connector] = None
|
||||
_registry: Optional[Registry] = None
|
||||
|
||||
# TODO: Proper typing for a classvariable which gets dynamically assigned in subclass
|
||||
table: RowTable = None
|
||||
|
||||
def __init_subclass__(cls: Type[RowT], table: Optional[str] = None):
|
||||
"""
|
||||
Set table, _columns_, and _key_.
|
||||
"""
|
||||
if table is not None:
|
||||
cls._tablename_ = table
|
||||
|
||||
if cls._tablename_ is not None:
|
||||
columns = {}
|
||||
for key, value in cls.__dict__.items():
|
||||
if isinstance(value, Column):
|
||||
columns[key] = value
|
||||
|
||||
cls._columns_ = columns
|
||||
if not cls._key_:
|
||||
cls._key_ = tuple(column.name for column in columns.values() if column.primary)
|
||||
cls.table = RowTable(cls._tablename_, cls, schema=cls._schema_)
|
||||
if cls._cache_ is None:
|
||||
cls._cache_ = WeakValueDictionary()
|
||||
|
||||
def __new__(cls, data):
|
||||
# Registry pattern.
|
||||
# Ensure each rowid always refers to a single Model instance
|
||||
if data is not None:
|
||||
rowid = cls._id_from_data(data)
|
||||
|
||||
cache = cls._cache_
|
||||
|
||||
if (row := cache.get(rowid, None)) is not None:
|
||||
obj = row
|
||||
else:
|
||||
obj = cache[rowid] = super().__new__(cls)
|
||||
else:
|
||||
obj = super().__new__(cls)
|
||||
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
def as_tuple(cls):
|
||||
return (cls.table.identifier, ())
|
||||
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.data[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.data[key] = value
|
||||
|
||||
@classmethod
|
||||
def bind(cls, connector: Connector):
|
||||
if cls.table is None:
|
||||
raise ValueError("Cannot bind abstract RowModel")
|
||||
cls._connector = connector
|
||||
cls.table.bind(connector)
|
||||
return cls
|
||||
|
||||
@classmethod
|
||||
def attach_to(cls, registry: Registry):
|
||||
cls._registry = registry
|
||||
return cls
|
||||
|
||||
@property
|
||||
def _dict_(self):
|
||||
return {key: self.data[key] for key in self._key_}
|
||||
|
||||
@property
|
||||
def _rowid_(self):
|
||||
return tuple(self.data[key] for key in self._key_)
|
||||
|
||||
def __repr__(self):
|
||||
return "{}.{}({})".format(
|
||||
self.table.schema,
|
||||
self.table.name,
|
||||
', '.join(repr(column.__get__(self)) for column in self._columns_.values())
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _id_from_data(cls, data):
|
||||
return tuple(data[key] for key in cls._key_)
|
||||
|
||||
@classmethod
|
||||
def _dict_from_id(cls, rowid):
|
||||
return dict(zip(cls._key_, rowid))
|
||||
|
||||
@classmethod
|
||||
def _make_rows(cls: Type[RowT], *data_rows: DictRow) -> list[RowT]:
|
||||
"""
|
||||
Create or retrieve Row objects for each provided data row.
|
||||
If the rows already exist in cache, updates the cached row.
|
||||
"""
|
||||
# TODO: Handle partial row data here somehow?
|
||||
rows = [cls(data_row) for data_row in data_rows]
|
||||
return rows
|
||||
|
||||
@classmethod
|
||||
def _delete_rows(cls, *data_rows):
|
||||
"""
|
||||
Remove the given rows from cache, if they exist.
|
||||
May be extended to handle object deletion.
|
||||
"""
|
||||
cache = cls._cache_
|
||||
|
||||
for data_row in data_rows:
|
||||
rowid = cls._id_from_data(data_row)
|
||||
cache.pop(rowid, None)
|
||||
|
||||
@classmethod
|
||||
async def create(cls: Type[RowT], *args, **kwargs) -> RowT:
|
||||
return await cls.table.create_row(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def fetch_where(cls: Type[RowT], *args, **kwargs):
|
||||
return cls.table.fetch_rows_where(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def fetch(cls: Type[RowT], *rowid, cached=True) -> Optional[RowT]:
|
||||
"""
|
||||
Fetch the row with the given id, retrieving from cache where possible.
|
||||
"""
|
||||
row = cls._cache_.get(rowid, None) if cached else None
|
||||
if row is None:
|
||||
rows = await cls.fetch_where(**cls._dict_from_id(rowid))
|
||||
row = rows[0] if rows else None
|
||||
if row is None:
|
||||
cls._cache_[rowid] = cls(None)
|
||||
elif row.data is None:
|
||||
row = None
|
||||
|
||||
return row
|
||||
|
||||
@classmethod
|
||||
async def fetch_or_create(cls, *rowid, **kwargs):
|
||||
"""
|
||||
Helper method to fetch a row with the given id or fields, or create it if it doesn't exist.
|
||||
"""
|
||||
if rowid:
|
||||
row = await cls.fetch(*rowid)
|
||||
else:
|
||||
rows = await cls.fetch_where(**kwargs).limit(1)
|
||||
row = rows[0] if rows else None
|
||||
|
||||
if row is None:
|
||||
creation_kwargs = kwargs
|
||||
if rowid:
|
||||
creation_kwargs.update(cls._dict_from_id(rowid))
|
||||
row = await cls.create(**creation_kwargs)
|
||||
return row
|
||||
|
||||
async def refresh(self: RowT) -> Optional[RowT]:
|
||||
"""
|
||||
Refresh this Row from data.
|
||||
|
||||
The return value may be `None` if the row was deleted.
|
||||
"""
|
||||
rows = await self.table.select_where(**self._dict_)
|
||||
if not rows:
|
||||
return None
|
||||
else:
|
||||
self.data = rows[0]
|
||||
return self
|
||||
|
||||
async def update(self: RowT, **values) -> Optional[RowT]:
|
||||
"""
|
||||
Update this Row with the given values.
|
||||
|
||||
Internally passes the provided `values` to the `update` Query.
|
||||
The return value may be `None` if the row was deleted.
|
||||
"""
|
||||
data = await self.table.update_where(**self._dict_).set(**values).with_adapter(self._make_rows)
|
||||
if not data:
|
||||
return None
|
||||
else:
|
||||
return data[0]
|
||||
|
||||
async def delete(self: RowT) -> Optional[RowT]:
|
||||
"""
|
||||
Delete this Row.
|
||||
"""
|
||||
data = await self.table.delete_where(**self._dict_).with_adapter(self._delete_rows)
|
||||
return data[0] if data is not None else None
|
||||
@@ -1,644 +0,0 @@
|
||||
from typing import Optional, TypeVar, Any, Callable, Generic, List, Union
|
||||
from enum import Enum
|
||||
from itertools import chain
|
||||
from psycopg import AsyncConnection, AsyncCursor
|
||||
from psycopg import sql
|
||||
from psycopg.rows import DictRow
|
||||
|
||||
import logging
|
||||
|
||||
from .conditions import Condition
|
||||
from .base import Expression, RawExpr
|
||||
from .connector import Connector
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
TQueryT = TypeVar('TQueryT', bound='TableQuery')
|
||||
SQueryT = TypeVar('SQueryT', bound='Select')
|
||||
|
||||
QueryResult = TypeVar('QueryResult')
|
||||
|
||||
|
||||
class Query(Generic[QueryResult]):
|
||||
"""
|
||||
ABC for an executable query statement.
|
||||
"""
|
||||
__slots__ = ('conn', 'cursor', '_adapter', 'connector', 'result')
|
||||
|
||||
_adapter: Callable[..., QueryResult]
|
||||
|
||||
def __init__(self, *args, row_adapter=None, connector=None, conn=None, cursor=None, **kwargs):
|
||||
self.connector: Optional[Connector] = connector
|
||||
self.conn: Optional[AsyncConnection] = conn
|
||||
self.cursor: Optional[AsyncCursor] = cursor
|
||||
|
||||
if row_adapter is not None:
|
||||
self._adapter = row_adapter
|
||||
else:
|
||||
self._adapter = self._no_adapter
|
||||
|
||||
self.result: Optional[QueryResult] = None
|
||||
|
||||
def bind(self, connector: Connector):
|
||||
self.connector = connector
|
||||
return self
|
||||
|
||||
def with_cursor(self, cursor: AsyncCursor):
|
||||
self.cursor = cursor
|
||||
return self
|
||||
|
||||
def with_connection(self, conn: AsyncConnection):
|
||||
self.conn = conn
|
||||
return self
|
||||
|
||||
def _no_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
|
||||
return data
|
||||
|
||||
def with_adapter(self, callable: Callable[..., QueryResult]):
|
||||
# NOTE: Postcomposition functor, Query[QR2] = (QR1 -> QR2) o Query[QR1]
|
||||
# For this to work cleanly, callable should have arg type of QR1, not any
|
||||
self._adapter = callable
|
||||
return self
|
||||
|
||||
def with_no_adapter(self):
|
||||
"""
|
||||
Sets the adapater to the identity.
|
||||
"""
|
||||
self._adapter = self._no_adapter
|
||||
return self
|
||||
|
||||
def one(self):
|
||||
# TODO: Postcomposition with item functor, Query[List[QR1]] -> Query[QR1]
|
||||
return self
|
||||
|
||||
def build(self) -> Expression:
|
||||
raise NotImplementedError
|
||||
|
||||
async def _execute(self, cursor: AsyncCursor) -> QueryResult:
|
||||
query, values = self.build().as_tuple()
|
||||
# TODO: Move logging out to a custom cursor
|
||||
# logger.debug(
|
||||
# f"Executing query ({query.as_string(cursor)}) with values {values}",
|
||||
# extra={'action': "Query"}
|
||||
# )
|
||||
await cursor.execute(sql.Composed((query,)), values)
|
||||
data = await cursor.fetchall()
|
||||
self.result = self._adapter(*data)
|
||||
return self.result
|
||||
|
||||
async def execute(self, cursor=None) -> QueryResult:
|
||||
"""
|
||||
Execute the query, optionally with the provided cursor, and return the result rows.
|
||||
If no cursor is provided, and no cursor has been set with `with_cursor`,
|
||||
the execution will create a new cursor from the connection and close it automatically.
|
||||
"""
|
||||
# Create a cursor if possible
|
||||
cursor = cursor if cursor is not None else self.cursor
|
||||
if self.cursor is None:
|
||||
if self.conn is None:
|
||||
if self.connector is None:
|
||||
raise ValueError("Cannot execute query without cursor, connection, or connector.")
|
||||
else:
|
||||
async with self.connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
data = await self._execute(cursor)
|
||||
else:
|
||||
async with self.conn.cursor() as cursor:
|
||||
data = await self._execute(cursor)
|
||||
else:
|
||||
data = await self._execute(cursor)
|
||||
return data
|
||||
|
||||
def __await__(self):
|
||||
return self.execute().__await__()
|
||||
|
||||
|
||||
class TableQuery(Query[QueryResult]):
|
||||
"""
|
||||
ABC for an executable query statement expected to be run on a single table.
|
||||
"""
|
||||
__slots__ = (
|
||||
'tableid',
|
||||
'condition', '_extra', '_limit', '_order', '_joins', '_from', '_group'
|
||||
)
|
||||
|
||||
def __init__(self, tableid, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.tableid: sql.Identifier = tableid
|
||||
|
||||
def options(self, **kwargs):
|
||||
"""
|
||||
Set some query options.
|
||||
Default implementation does nothing.
|
||||
Should be overridden to provide specific options.
|
||||
"""
|
||||
return self
|
||||
|
||||
|
||||
class WhereMixin(TableQuery[QueryResult]):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.condition: Optional[Condition] = None
|
||||
|
||||
def where(self, *args: Condition, **kwargs):
|
||||
"""
|
||||
Add a Condition to the query.
|
||||
Position arguments should be Conditions,
|
||||
and keyword arguments should be of the form `column=Value`,
|
||||
where Value may be a Value-type or a literal value.
|
||||
All provided Conditions will be and-ed together to create a new Condition.
|
||||
TODO: Maybe just pass this verbatim to a condition.
|
||||
"""
|
||||
if args or kwargs:
|
||||
condition = Condition.construct(*args, **kwargs)
|
||||
if self.condition is not None:
|
||||
condition = self.condition & condition
|
||||
|
||||
self.condition = condition
|
||||
|
||||
return self
|
||||
|
||||
@property
|
||||
def _where_section(self) -> Optional[Expression]:
|
||||
if self.condition is not None:
|
||||
return RawExpr.join_tuples((sql.SQL('WHERE'), ()), self.condition.as_tuple())
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class JOINTYPE(Enum):
|
||||
LEFT = sql.SQL('LEFT JOIN')
|
||||
RIGHT = sql.SQL('RIGHT JOIN')
|
||||
INNER = sql.SQL('INNER JOIN')
|
||||
OUTER = sql.SQL('OUTER JOIN')
|
||||
FULLOUTER = sql.SQL('FULL OUTER JOIN')
|
||||
|
||||
|
||||
class JoinMixin(TableQuery[QueryResult]):
|
||||
__slots__ = ()
|
||||
# TODO: Remember to add join slots to TableQuery
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._joins: list[Expression] = []
|
||||
|
||||
def join(self,
|
||||
target: Union[str, Expression],
|
||||
on: Optional[Condition] = None, using: Optional[Union[Expression, tuple[str, ...]]] = None,
|
||||
join_type: JOINTYPE = JOINTYPE.INNER,
|
||||
natural=False):
|
||||
available = (on is not None) + (using is not None) + natural
|
||||
if available == 0:
|
||||
raise ValueError("No conditions given for Query Join")
|
||||
if available > 1:
|
||||
raise ValueError("Exactly one join format must be given for Query Join")
|
||||
|
||||
sections: list[tuple[sql.Composable, tuple[Any, ...]]] = [(join_type.value, ())]
|
||||
if isinstance(target, str):
|
||||
sections.append((sql.Identifier(target), ()))
|
||||
else:
|
||||
sections.append(target.as_tuple())
|
||||
|
||||
if on is not None:
|
||||
sections.append((sql.SQL('ON'), ()))
|
||||
sections.append(on.as_tuple())
|
||||
elif using is not None:
|
||||
sections.append((sql.SQL('USING'), ()))
|
||||
if isinstance(using, Expression):
|
||||
sections.append(using.as_tuple())
|
||||
elif isinstance(using, tuple) and len(using) > 0 and isinstance(using[0], str):
|
||||
cols = sql.SQL("({})").format(sql.SQL(',').join(sql.Identifier(col) for col in using))
|
||||
sections.append((cols, ()))
|
||||
else:
|
||||
raise ValueError("Unrecognised 'using' type.")
|
||||
elif natural:
|
||||
sections.insert(0, (sql.SQL('NATURAL'), ()))
|
||||
|
||||
expr = RawExpr.join_tuples(*sections)
|
||||
self._joins.append(expr)
|
||||
return self
|
||||
|
||||
def leftjoin(self, *args, **kwargs):
|
||||
return self.join(*args, join_type=JOINTYPE.LEFT, **kwargs)
|
||||
|
||||
@property
|
||||
def _join_section(self) -> Optional[Expression]:
|
||||
if self._joins:
|
||||
return RawExpr.join(*self._joins)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class ExtraMixin(TableQuery[QueryResult]):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._extra: Optional[Expression] = None
|
||||
|
||||
def extra(self, extra: sql.Composable, values: tuple[Any, ...] = ()):
|
||||
"""
|
||||
Add an extra string, and optionally values, to this query.
|
||||
The extra string is inserted after any condition, and before the limit.
|
||||
"""
|
||||
extra_expr = RawExpr(extra, values)
|
||||
if self._extra is not None:
|
||||
extra_expr = RawExpr.join(self._extra, extra_expr)
|
||||
self._extra = extra_expr
|
||||
return self
|
||||
|
||||
@property
|
||||
def _extra_section(self) -> Optional[Expression]:
|
||||
if self._extra is None:
|
||||
return None
|
||||
else:
|
||||
return self._extra
|
||||
|
||||
|
||||
class LimitMixin(TableQuery[QueryResult]):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._limit: Optional[int] = None
|
||||
|
||||
def limit(self, limit: int):
|
||||
"""
|
||||
Add a limit to this query.
|
||||
"""
|
||||
self._limit = limit
|
||||
return self
|
||||
|
||||
@property
|
||||
def _limit_section(self) -> Optional[Expression]:
|
||||
if self._limit is not None:
|
||||
return RawExpr(sql.SQL("LIMIT {}").format(sql.Placeholder()), (self._limit,))
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class FromMixin(TableQuery[QueryResult]):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._from: Optional[Expression] = None
|
||||
|
||||
def from_expr(self, _from: Expression):
|
||||
self._from = _from
|
||||
return self
|
||||
|
||||
@property
|
||||
def _from_section(self) -> Optional[Expression]:
|
||||
if self._from is not None:
|
||||
expr, values = self._from.as_tuple()
|
||||
return RawExpr(sql.SQL("FROM {}").format(expr), values)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class ORDER(Enum):
|
||||
ASC = sql.SQL('ASC')
|
||||
DESC = sql.SQL('DESC')
|
||||
|
||||
|
||||
class NULLS(Enum):
|
||||
FIRST = sql.SQL('NULLS FIRST')
|
||||
LAST = sql.SQL('NULLS LAST')
|
||||
|
||||
|
||||
class OrderMixin(TableQuery[QueryResult]):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._order: list[Expression] = []
|
||||
|
||||
def order_by(self, expr: Union[Expression, str], direction: Optional[ORDER] = None, nulls: Optional[NULLS] = None):
|
||||
"""
|
||||
Add a single sort expression to the query.
|
||||
This method stacks.
|
||||
"""
|
||||
if isinstance(expr, Expression):
|
||||
string, values = expr.as_tuple()
|
||||
else:
|
||||
string = sql.Identifier(expr)
|
||||
values = ()
|
||||
|
||||
parts = [string]
|
||||
if direction is not None:
|
||||
parts.append(direction.value)
|
||||
if nulls is not None:
|
||||
parts.append(nulls.value)
|
||||
|
||||
order_string = sql.SQL(' ').join(parts)
|
||||
self._order.append(RawExpr(order_string, values))
|
||||
return self
|
||||
|
||||
@property
|
||||
def _order_section(self) -> Optional[Expression]:
|
||||
if self._order:
|
||||
expr = RawExpr.join(*self._order, joiner=sql.SQL(', '))
|
||||
expr.expr = sql.SQL("ORDER BY {}").format(expr.expr)
|
||||
return expr
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class GroupMixin(TableQuery[QueryResult]):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._group: list[Expression] = []
|
||||
|
||||
def group_by(self, *exprs: Union[Expression, str]):
|
||||
"""
|
||||
Add a group expression(s) to the query.
|
||||
This method stacks.
|
||||
"""
|
||||
for expr in exprs:
|
||||
if isinstance(expr, Expression):
|
||||
self._group.append(expr)
|
||||
else:
|
||||
self._group.append(RawExpr(sql.Identifier(expr)))
|
||||
return self
|
||||
|
||||
@property
|
||||
def _group_section(self) -> Optional[Expression]:
|
||||
if self._group:
|
||||
expr = RawExpr.join(*self._group, joiner=sql.SQL(', '))
|
||||
expr.expr = sql.SQL("GROUP BY {}").format(expr.expr)
|
||||
return expr
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class Insert(ExtraMixin, TableQuery[QueryResult]):
|
||||
"""
|
||||
Query type representing a table insert query.
|
||||
"""
|
||||
# TODO: Support ON CONFLICT for upserts
|
||||
__slots__ = ('_columns', '_values', '_conflict')
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._columns: tuple[str, ...] = ()
|
||||
self._values: tuple[tuple[Any, ...], ...] = ()
|
||||
self._conflict: Optional[Expression] = None
|
||||
|
||||
def insert(self, columns, *values):
|
||||
"""
|
||||
Insert the given data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
columns: tuple[str]
|
||||
Tuple of column names to insert.
|
||||
|
||||
values: tuple[tuple[Any, ...], ...]
|
||||
Tuple of values to insert, corresponding to the columns.
|
||||
"""
|
||||
if not values:
|
||||
raise ValueError("Cannot insert zero rows.")
|
||||
if len(values[0]) != len(columns):
|
||||
raise ValueError("Number of columns does not match length of values.")
|
||||
|
||||
self._columns = columns
|
||||
self._values = values
|
||||
return self
|
||||
|
||||
def on_conflict(self, ignore=False):
|
||||
# TODO lots more we can do here
|
||||
# Maybe return a Conflict object that can chain itself (not the query)
|
||||
if ignore:
|
||||
self._conflict = RawExpr(sql.SQL('DO NOTHING'))
|
||||
return self
|
||||
|
||||
@property
|
||||
def _conflict_section(self) -> Optional[Expression]:
|
||||
if self._conflict is not None:
|
||||
e, v = self._conflict.as_tuple()
|
||||
expr = RawExpr(
|
||||
sql.SQL("ON CONFLICT {}").format(
|
||||
e
|
||||
),
|
||||
v
|
||||
)
|
||||
return expr
|
||||
return None
|
||||
|
||||
def build(self):
|
||||
columns = sql.SQL(',').join(map(sql.Identifier, self._columns))
|
||||
single_value_str = sql.SQL('({})').format(
|
||||
sql.SQL(',').join(sql.Placeholder() * len(self._columns))
|
||||
)
|
||||
values_str = sql.SQL(',').join(single_value_str * len(self._values))
|
||||
|
||||
# TODO: Check efficiency of inserting multiple values like this
|
||||
# Also implement a Copy query
|
||||
base = sql.SQL("INSERT INTO {table} ({columns}) VALUES {values_str}").format(
|
||||
table=self.tableid,
|
||||
columns=columns,
|
||||
values_str=values_str
|
||||
)
|
||||
|
||||
sections = [
|
||||
RawExpr(base, tuple(chain(*self._values))),
|
||||
self._conflict_section,
|
||||
self._extra_section,
|
||||
RawExpr(sql.SQL('RETURNING *'))
|
||||
]
|
||||
|
||||
sections = (section for section in sections if section is not None)
|
||||
return RawExpr.join(*sections)
|
||||
|
||||
|
||||
class Select(WhereMixin, ExtraMixin, OrderMixin, LimitMixin, JoinMixin, GroupMixin, TableQuery[QueryResult]):
|
||||
"""
|
||||
Select rows from a table matching provided conditions.
|
||||
"""
|
||||
__slots__ = ('_columns',)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._columns: tuple[Expression, ...] = ()
|
||||
|
||||
def select(self, *columns: str, **exprs: Union[str, sql.Composable, Expression]):
|
||||
"""
|
||||
Set the columns and expressions to select.
|
||||
If none are given, selects all columns.
|
||||
"""
|
||||
cols: List[Expression] = []
|
||||
if columns:
|
||||
cols.extend(map(RawExpr, map(sql.Identifier, columns)))
|
||||
if exprs:
|
||||
for name, expr in exprs.items():
|
||||
if isinstance(expr, str):
|
||||
cols.append(
|
||||
RawExpr(sql.SQL(expr) + sql.SQL(' AS ') + sql.Identifier(name))
|
||||
)
|
||||
elif isinstance(expr, sql.Composable):
|
||||
cols.append(
|
||||
RawExpr(expr + sql.SQL(' AS ') + sql.Identifier(name))
|
||||
)
|
||||
elif isinstance(expr, Expression):
|
||||
value_expr, value_values = expr.as_tuple()
|
||||
cols.append(RawExpr(
|
||||
value_expr + sql.SQL(' AS ') + sql.Identifier(name),
|
||||
value_values
|
||||
))
|
||||
if cols:
|
||||
self._columns = (*self._columns, *cols)
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
if not self._columns:
|
||||
columns, columns_values = sql.SQL('*'), ()
|
||||
else:
|
||||
columns, columns_values = RawExpr.join(*self._columns, joiner=sql.SQL(',')).as_tuple()
|
||||
|
||||
base = sql.SQL("SELECT {columns} FROM {table}").format(
|
||||
columns=columns,
|
||||
table=self.tableid
|
||||
)
|
||||
|
||||
sections = [
|
||||
RawExpr(base, columns_values),
|
||||
self._join_section,
|
||||
self._where_section,
|
||||
self._group_section,
|
||||
self._extra_section,
|
||||
self._order_section,
|
||||
self._limit_section,
|
||||
]
|
||||
|
||||
sections = (section for section in sections if section is not None)
|
||||
return RawExpr.join(*sections)
|
||||
|
||||
|
||||
class Delete(WhereMixin, ExtraMixin, TableQuery[QueryResult]):
|
||||
"""
|
||||
Query type representing a table delete query.
|
||||
"""
|
||||
# TODO: Cascade option for delete, maybe other options
|
||||
# TODO: Require a where unless specifically disabled, for safety
|
||||
|
||||
def build(self):
|
||||
base = sql.SQL("DELETE FROM {table}").format(
|
||||
table=self.tableid,
|
||||
)
|
||||
sections = [
|
||||
RawExpr(base),
|
||||
self._where_section,
|
||||
self._extra_section,
|
||||
RawExpr(sql.SQL('RETURNING *'))
|
||||
]
|
||||
|
||||
sections = (section for section in sections if section is not None)
|
||||
return RawExpr.join(*sections)
|
||||
|
||||
|
||||
class Update(LimitMixin, WhereMixin, ExtraMixin, FromMixin, TableQuery[QueryResult]):
|
||||
__slots__ = (
|
||||
'_set',
|
||||
)
|
||||
# TODO: Again, require a where unless specifically disabled
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._set: List[Expression] = []
|
||||
|
||||
def set(self, **column_values: Union[Any, Expression]):
|
||||
exprs: List[Expression] = []
|
||||
for name, value in column_values.items():
|
||||
if isinstance(value, Expression):
|
||||
value_tup = value.as_tuple()
|
||||
else:
|
||||
value_tup = (sql.Placeholder(), (value,))
|
||||
|
||||
exprs.append(
|
||||
RawExpr.join_tuples(
|
||||
(sql.Identifier(name), ()),
|
||||
value_tup,
|
||||
joiner=sql.SQL(' = ')
|
||||
)
|
||||
)
|
||||
self._set.extend(exprs)
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
if not self._set:
|
||||
raise ValueError("No columns provided to update.")
|
||||
set_expr, set_values = RawExpr.join(*self._set, joiner=sql.SQL(', ')).as_tuple()
|
||||
|
||||
base = sql.SQL("UPDATE {table} SET {set}").format(
|
||||
table=self.tableid,
|
||||
set=set_expr
|
||||
)
|
||||
sections = [
|
||||
RawExpr(base, set_values),
|
||||
self._from_section,
|
||||
self._where_section,
|
||||
self._extra_section,
|
||||
self._limit_section,
|
||||
RawExpr(sql.SQL('RETURNING *'))
|
||||
]
|
||||
|
||||
sections = (section for section in sections if section is not None)
|
||||
return RawExpr.join(*sections)
|
||||
|
||||
|
||||
# async def upsert(cursor, table, constraint, **values):
|
||||
# """
|
||||
# Insert or on conflict update.
|
||||
# """
|
||||
# valuedict = values
|
||||
# keys, values = zip(*values.items())
|
||||
#
|
||||
# key_str = _format_insertkeys(keys)
|
||||
# value_str, values = _format_insertvalues(values)
|
||||
# update_key_str, update_key_values = _format_updatestr(valuedict)
|
||||
#
|
||||
# if not isinstance(constraint, str):
|
||||
# constraint = ", ".join(constraint)
|
||||
#
|
||||
# await cursor.execute(
|
||||
# 'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format(
|
||||
# table, key_str, value_str, constraint, update_key_str
|
||||
# ),
|
||||
# tuple((*values, *update_key_values))
|
||||
# )
|
||||
# return await cursor.fetchone()
|
||||
|
||||
|
||||
# def update_many(table, *values, set_keys=None, where_keys=None, cast_row=None, cursor=None):
|
||||
# cursor = cursor or conn.cursor()
|
||||
#
|
||||
# # TODO: executemany or copy syntax now
|
||||
# return execute_values(
|
||||
# cursor,
|
||||
# """
|
||||
# UPDATE {table}
|
||||
# SET {set_clause}
|
||||
# FROM (VALUES {cast_row}%s)
|
||||
# AS {temp_table}
|
||||
# WHERE {where_clause}
|
||||
# RETURNING *
|
||||
# """.format(
|
||||
# table=table,
|
||||
# set_clause=', '.join("{0} = _t.{0}".format(key) for key in set_keys),
|
||||
# cast_row=cast_row + ',' if cast_row else '',
|
||||
# where_clause=' AND '.join("{1}.{0} = _t.{0}".format(key, table) for key in where_keys),
|
||||
# temp_table="_t ({})".format(', '.join(set_keys + where_keys))
|
||||
# ),
|
||||
# values,
|
||||
# fetch=True
|
||||
# )
|
||||
@@ -1,102 +0,0 @@
|
||||
from typing import Protocol, runtime_checkable, Optional
|
||||
|
||||
from psycopg import AsyncConnection
|
||||
|
||||
from .connector import Connector, Connectable
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class _Attachable(Connectable, Protocol):
|
||||
def attach_to(self, registry: 'Registry'):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Registry:
|
||||
_attached: list[_Attachable] = []
|
||||
_name: Optional[str] = None
|
||||
|
||||
def __init_subclass__(cls, name=None):
|
||||
attached = []
|
||||
for _, member in cls.__dict__.items():
|
||||
if isinstance(member, _Attachable):
|
||||
attached.append(member)
|
||||
cls._attached = attached
|
||||
cls._name = name or cls.__name__
|
||||
|
||||
def __init__(self, name=None):
|
||||
self._conn: Optional[Connector] = None
|
||||
self.name: str = name if name is not None else self._name
|
||||
if self.name is None:
|
||||
raise ValueError("A Registry must have a name!")
|
||||
|
||||
self.init_tasks = []
|
||||
|
||||
for member in self._attached:
|
||||
member.attach_to(self)
|
||||
|
||||
def bind(self, connector: Connector):
|
||||
self._conn = connector
|
||||
for child in self._attached:
|
||||
child.bind(connector)
|
||||
|
||||
def attach(self, attachable):
|
||||
self._attached.append(attachable)
|
||||
if self._conn is not None:
|
||||
attachable.bind(self._conn)
|
||||
return attachable
|
||||
|
||||
def init_task(self, coro):
|
||||
"""
|
||||
Initialisation tasks are run to setup the registry state.
|
||||
These tasks will be run in the event loop, after connection to the database.
|
||||
These tasks should be idempotent, as they may be run on reload and reconnect.
|
||||
"""
|
||||
self.init_tasks.append(coro)
|
||||
return coro
|
||||
|
||||
async def init(self):
|
||||
for task in self.init_tasks:
|
||||
await task(self)
|
||||
return self
|
||||
|
||||
|
||||
class AttachableClass:
|
||||
"""ABC for a default implementation of an Attachable class."""
|
||||
|
||||
_connector: Optional[Connector] = None
|
||||
_registry: Optional[Registry] = None
|
||||
|
||||
@classmethod
|
||||
def bind(cls, connector: Connector):
|
||||
cls._connector = connector
|
||||
connector.connect_hook(cls.on_connect)
|
||||
return cls
|
||||
|
||||
@classmethod
|
||||
def attach_to(cls, registry: Registry):
|
||||
cls._registry = registry
|
||||
return cls
|
||||
|
||||
@classmethod
|
||||
async def on_connect(cls, connection: AsyncConnection):
|
||||
pass
|
||||
|
||||
|
||||
class Attachable:
|
||||
"""ABC for a default implementation of an Attachable object."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._connector: Optional[Connector] = None
|
||||
self._registry: Optional[Registry] = None
|
||||
|
||||
def bind(self, connector: Connector):
|
||||
self._connector = connector
|
||||
connector.connect_hook(self.on_connect)
|
||||
return self
|
||||
|
||||
def attach_to(self, registry: Registry):
|
||||
self._registry = registry
|
||||
return self
|
||||
|
||||
async def on_connect(self, connection: AsyncConnection):
|
||||
pass
|
||||
@@ -1,95 +0,0 @@
|
||||
from typing import Optional
|
||||
from psycopg.rows import DictRow
|
||||
from psycopg import sql
|
||||
|
||||
from . import queries as q
|
||||
from .connector import Connector
|
||||
from .registry import Registry
|
||||
|
||||
|
||||
class Table:
|
||||
"""
|
||||
Transparent interface to a single table structure in the database.
|
||||
Contains standard methods to access the table.
|
||||
"""
|
||||
|
||||
def __init__(self, name, *args, schema='public', **kwargs):
|
||||
self.name: str = name
|
||||
self.schema: str = schema
|
||||
self.connector: Connector = None
|
||||
|
||||
@property
|
||||
def identifier(self):
|
||||
if self.schema == 'public':
|
||||
return sql.Identifier(self.name)
|
||||
else:
|
||||
return sql.Identifier(self.schema, self.name)
|
||||
|
||||
def bind(self, connector: Connector):
|
||||
self.connector = connector
|
||||
return self
|
||||
|
||||
def attach_to(self, registry: Registry):
|
||||
self._registry = registry
|
||||
return self
|
||||
|
||||
def _many_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
|
||||
return data
|
||||
|
||||
def _single_query_adapter(self, *data: DictRow) -> Optional[DictRow]:
|
||||
if data:
|
||||
return data[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
def _delete_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
|
||||
return data
|
||||
|
||||
def select_where(self, *args, **kwargs) -> q.Select[tuple[DictRow, ...]]:
|
||||
return q.Select(
|
||||
self.identifier,
|
||||
row_adapter=self._many_query_adapter,
|
||||
connector=self.connector
|
||||
).where(*args, **kwargs)
|
||||
|
||||
def select_one_where(self, *args, **kwargs) -> q.Select[DictRow]:
|
||||
return q.Select(
|
||||
self.identifier,
|
||||
row_adapter=self._single_query_adapter,
|
||||
connector=self.connector
|
||||
).where(*args, **kwargs)
|
||||
|
||||
def update_where(self, *args, **kwargs) -> q.Update[tuple[DictRow, ...]]:
|
||||
return q.Update(
|
||||
self.identifier,
|
||||
row_adapter=self._many_query_adapter,
|
||||
connector=self.connector
|
||||
).where(*args, **kwargs)
|
||||
|
||||
def delete_where(self, *args, **kwargs) -> q.Delete[tuple[DictRow, ...]]:
|
||||
return q.Delete(
|
||||
self.identifier,
|
||||
row_adapter=self._many_query_adapter,
|
||||
connector=self.connector
|
||||
).where(*args, **kwargs)
|
||||
|
||||
def insert(self, **column_values) -> q.Insert[DictRow]:
|
||||
return q.Insert(
|
||||
self.identifier,
|
||||
row_adapter=self._single_query_adapter,
|
||||
connector=self.connector
|
||||
).insert(column_values.keys(), column_values.values())
|
||||
|
||||
def insert_many(self, *args, **kwargs) -> q.Insert[tuple[DictRow, ...]]:
|
||||
return q.Insert(
|
||||
self.identifier,
|
||||
row_adapter=self._many_query_adapter,
|
||||
connector=self.connector
|
||||
).insert(*args, **kwargs)
|
||||
|
||||
# def update_many(self, *args, **kwargs):
|
||||
# with self.conn:
|
||||
# return update_many(self.identifier, *args, **kwargs)
|
||||
|
||||
# def upsert(self, *args, **kwargs):
|
||||
# return upsert(self.identifier, *args, **kwargs)
|
||||
@@ -1,133 +0,0 @@
|
||||
from data import Registry, RowModel, Table
|
||||
from data.columns import String, Timestamp, Integer, Bool
|
||||
|
||||
|
||||
class UserAuth(RowModel):
|
||||
"""
|
||||
Schema
|
||||
======
|
||||
CREATE TABLE user_auth(
|
||||
userid TEXT PRIMARY KEY,
|
||||
token TEXT NOT NULL,
|
||||
refresh_token TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'user_auth'
|
||||
_cache_ = {}
|
||||
|
||||
userid = String(primary=True)
|
||||
token = String()
|
||||
refresh_token = String()
|
||||
created_at = Timestamp()
|
||||
_timestamp = Timestamp()
|
||||
|
||||
|
||||
class BotChannel(RowModel):
|
||||
"""
|
||||
Schema
|
||||
======
|
||||
CREATE TABLE bot_channels(
|
||||
userid TEXT PRIMARY KEY REFERENCES user_auth(userid) ON DELETE CASCADE,
|
||||
autojoin BOOLEAN DEFAULT true,
|
||||
listen_redeems BOOLEAN,
|
||||
joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'bot_channels'
|
||||
_cache_ = {}
|
||||
|
||||
userid = String(primary=True)
|
||||
autojoin = Bool()
|
||||
listen_redeems = Bool()
|
||||
joined_at = Timestamp()
|
||||
_timestamp = Timestamp()
|
||||
|
||||
|
||||
class UserProfile(RowModel):
|
||||
"""
|
||||
Schema
|
||||
======
|
||||
CREATE TABLE user_profiles(
|
||||
profileid INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
|
||||
twitchid TEXT NOT NULL,
|
||||
name TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'user_profiles'
|
||||
_cache_ = {}
|
||||
|
||||
profileid = Integer(primary=True)
|
||||
twitchid = String()
|
||||
name = String()
|
||||
created_at = Timestamp()
|
||||
_timestamp = Timestamp()
|
||||
|
||||
|
||||
class Communities(RowModel):
|
||||
"""
|
||||
Schema
|
||||
======
|
||||
CREATE TABLE communities(
|
||||
communityid INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
|
||||
twitchid TEXT NOT NULL,
|
||||
name TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'communities'
|
||||
_cache_ = {}
|
||||
|
||||
communityid = Integer(primary=True)
|
||||
twitchid = Integer()
|
||||
name = String()
|
||||
created_at = Timestamp()
|
||||
_timestamp = Timestamp()
|
||||
|
||||
|
||||
class Koan(RowModel):
|
||||
"""
|
||||
Schema
|
||||
======
|
||||
CREATE TABLE koans(
|
||||
koanid INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
|
||||
communityid INTEGER NOT NULL REFERENCES communities ON UPDATE CASCADE ON DELETE CASCADE,
|
||||
name TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
_timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
_tablename_ = 'koans'
|
||||
_cache_ = {}
|
||||
|
||||
koanid = Integer(primary=True)
|
||||
communityid = Integer()
|
||||
name = String()
|
||||
message = String()
|
||||
created_at = Timestamp()
|
||||
_timestamp = Timestamp()
|
||||
|
||||
|
||||
|
||||
class BotData(Registry):
|
||||
user_auth = UserAuth.table
|
||||
|
||||
"""
|
||||
CREATE TABLE user_auth_scopes(
|
||||
userid TEXT NOT NULL REFERENCES user_auth(userid) ON DELETE CASCADE,
|
||||
scope TEXT NOT NULL
|
||||
);
|
||||
"""
|
||||
user_auth_scopes = Table('user_auth_scopes')
|
||||
|
||||
bot_channels = BotChannel.table
|
||||
user_profiles = UserProfile.table
|
||||
communities = Communities.table
|
||||
|
||||
koans = Koan.table
|
||||
@@ -1,4 +1,5 @@
|
||||
from .args import args
|
||||
from .crocbot import CrocBot
|
||||
from .bot import Bot
|
||||
from .context import Context
|
||||
from .config import Conf, conf
|
||||
from .logger import setup_main_logger, log_context, log_action_stack, log_app, set_logging_context, logging_context, with_log_ctx, persist_task
|
||||
|
||||
@@ -1,20 +1,25 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, overload
|
||||
|
||||
from twitchio.authentication import UserTokenPayload
|
||||
from twitchio.ext import commands
|
||||
from twitchio import Scopes, eventsub
|
||||
|
||||
from data import Database
|
||||
from datamodels import BotData, UserAuth, BotChannel
|
||||
from data import Database, ORDER
|
||||
from botdata import BotData, UserAuth, BotChannel, VersionHistory
|
||||
from constants import BOTUSER_SCOPES, CHANNEL_SCOPES, SCHEMA_VERSIONS
|
||||
|
||||
from .config import Conf
|
||||
from .context import Context
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from modules.profiles.profiles.twitch.component import ProfilesComponent
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CrocBot(commands.Bot):
|
||||
class Bot(commands.Bot):
|
||||
def __init__(self, *args, config: Conf, dbconn: Database, setup=None, **kwargs):
|
||||
kwargs.setdefault('client_id', config.bot['client_id'])
|
||||
kwargs.setdefault('client_secret', config.bot['client_secret'])
|
||||
@@ -23,6 +28,9 @@ class CrocBot(commands.Bot):
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Whether we should do eventsub via webhooks or websockets
|
||||
self.using_webhooks = kwargs.get('using_webhooks', False)
|
||||
|
||||
self.config = config
|
||||
self.dbconn = dbconn
|
||||
self.data: BotData = dbconn.load_registry(BotData())
|
||||
@@ -30,12 +38,56 @@ class CrocBot(commands.Bot):
|
||||
|
||||
self.joined: dict[str, BotChannel] = {}
|
||||
|
||||
# Make the type checker happy about fetching components by name
|
||||
# TODO: Move to stubs
|
||||
|
||||
@property
|
||||
def profiles(self):
|
||||
return self.get_component('ProfilesComponent')
|
||||
|
||||
@overload
|
||||
def get_component(self, name: Literal['ProfilesComponent']) -> 'ProfilesComponent':
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_component(self, name: str) -> Optional[commands.Component]:
|
||||
...
|
||||
|
||||
def get_component(self, name: str) -> Optional[commands.Component]:
|
||||
return super().get_component(name)
|
||||
|
||||
def get_context(self, payload, *, cls: Any = None) -> Context:
|
||||
cls = cls or Context
|
||||
return cls(payload, bot=self)
|
||||
|
||||
async def event_ready(self):
|
||||
# logger.info(f"Logged in as {self.nick}. User id is {self.user_id}")
|
||||
logger.info("Logged in as %s", self.bot_id)
|
||||
|
||||
async def version_check(self, component: str, req_version: int):
|
||||
# Query the database to confirm that the given component is listed with the given version.
|
||||
# Typically done upon loading a component
|
||||
rows = await VersionHistory.fetch_where(component=component).order_by('_timestamp', ORDER.DESC).limit(1)
|
||||
|
||||
version = rows[0].to_version if rows else 0
|
||||
|
||||
if version != req_version:
|
||||
raise ValueError(f"Component {component} failed version check. Has version '{version}', required version '{req_version}'")
|
||||
else:
|
||||
logger.debug(
|
||||
"Component %s passed version check with version %s",
|
||||
component,
|
||||
version
|
||||
)
|
||||
return True
|
||||
|
||||
async def setup_hook(self):
|
||||
await self.data.init()
|
||||
for component, req in SCHEMA_VERSIONS.items():
|
||||
await self.version_check(component, req)
|
||||
|
||||
if self._setup_hook is not None:
|
||||
await self._setup_hook(self)
|
||||
|
||||
# Get all current bot channels
|
||||
channels = await BotChannel.fetch_where(autojoin=True)
|
||||
@@ -44,29 +96,15 @@ class CrocBot(commands.Bot):
|
||||
await self.join_channels(*channels)
|
||||
|
||||
# Build bot account's own url
|
||||
scopes = Scopes((
|
||||
Scopes.user_read_chat,
|
||||
Scopes.user_write_chat,
|
||||
Scopes.user_bot,
|
||||
Scopes.channel_read_redemptions,
|
||||
Scopes.channel_manage_redemptions,
|
||||
Scopes.channel_bot,
|
||||
))
|
||||
scopes = BOTUSER_SCOPES
|
||||
url = self.get_auth_url(scopes)
|
||||
logger.info("Bot account authorisation url: %s", url)
|
||||
|
||||
# Build everyone else's url
|
||||
scopes = Scopes((
|
||||
Scopes.channel_bot,
|
||||
Scopes.channel_read_redemptions,
|
||||
Scopes.channel_manage_redemptions,
|
||||
))
|
||||
scopes = CHANNEL_SCOPES
|
||||
url = self.get_auth_url(scopes)
|
||||
logger.info("User account authorisation url: %s", url)
|
||||
|
||||
if self._setup_hook is not None:
|
||||
await self._setup_hook(self)
|
||||
|
||||
logger.info("Finished setup")
|
||||
|
||||
def get_auth_url(self, scopes: Optional[Scopes] = None):
|
||||
@@ -81,7 +119,6 @@ class CrocBot(commands.Bot):
|
||||
Register webhook subscriptions to the given channel(s).
|
||||
"""
|
||||
# TODO: If channels are already joined, unsubscribe
|
||||
# TODO: Determine (or switch?) whether to use webhook or websocket
|
||||
for channel in channels:
|
||||
sub = None
|
||||
try:
|
||||
@@ -89,19 +126,16 @@ class CrocBot(commands.Bot):
|
||||
broadcaster_user_id=channel.userid,
|
||||
user_id=self.bot_id,
|
||||
)
|
||||
resp = await self.subscribe_websocket(sub)
|
||||
if self.using_webhooks:
|
||||
resp = await self.subscribe_webhook(sub)
|
||||
else:
|
||||
resp = await self.subscribe_websocket(sub)
|
||||
logger.info("Subscribed to %s with %s response %s", channel.userid, sub, resp)
|
||||
if channel.listen_redeems:
|
||||
sub = eventsub.ChannelPointsRedeemAddSubscription(
|
||||
broadcaster_user_id=channel.userid
|
||||
)
|
||||
resp = await self.subscribe_websocket(sub, as_bot=False, token_for=channel.userid)
|
||||
logger.info("Subscribed to %s with %s response %s", channel.userid, sub, resp)
|
||||
self.joined[channel.userid] = channel
|
||||
self.safe_dispatch('channel_joined', payload=channel)
|
||||
except Exception:
|
||||
logger.exception("Failed to subscribe to %s with %s", channel.userid, sub)
|
||||
|
||||
|
||||
async def event_oauth_authorized(self, payload: UserTokenPayload):
|
||||
logger.debug("Oauth flow authorization with payload %s", repr(payload))
|
||||
# Save the token and scopes and update internal authorisations
|
||||
@@ -141,21 +175,22 @@ class CrocBot(commands.Bot):
|
||||
|
||||
# Save the token and scopes to data
|
||||
# Wrap this in a transaction so if it fails halfway we rollback correctly
|
||||
async with self.dbconn.connection() as conn:
|
||||
self.dbconn.conn = conn
|
||||
async with conn.transaction():
|
||||
row = await UserAuth.fetch_or_create(userid, token=token, refresh_token=refresh)
|
||||
if row.token != token or row.refresh_token != refresh:
|
||||
await row.update(token=token, refresh_token=refresh)
|
||||
await self.data.user_auth_scopes.delete_where(userid=userid)
|
||||
await self.data.user_auth_scopes.insert_many(
|
||||
('userid', 'scope'),
|
||||
*((userid, scope) for scope in new_scopes)
|
||||
)
|
||||
# TODO
|
||||
row = await UserAuth.fetch_or_create(userid, token=token, refresh_token=refresh)
|
||||
if row.token != token or row.refresh_token != refresh:
|
||||
await row.update(token=token, refresh_token=refresh)
|
||||
await self.data.user_auth_scopes.delete_where(userid=userid)
|
||||
await self.data.user_auth_scopes.insert_many(
|
||||
('userid', 'scope'),
|
||||
*((userid, scope) for scope in new_scopes)
|
||||
)
|
||||
|
||||
logger.info("Updated auth token for user '%s' with scopes: %s", resp.user_id, ', '.join(new_scopes))
|
||||
return resp
|
||||
|
||||
async def load_tokens(self, path: str | None = None):
|
||||
for row in await UserAuth.fetch_where():
|
||||
await self.add_token(row.token, row.refresh_token)
|
||||
try:
|
||||
await self.add_token(row.token, row.refresh_token)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to add token for {row}")
|
||||
@@ -102,4 +102,4 @@ class Conf:
|
||||
self.config.write(conffile)
|
||||
|
||||
|
||||
conf = Conf(args.config, 'CROCBOT')
|
||||
conf = Conf(args.config, 'BOT')
|
||||
|
||||
4
src/meta/context.py
Normal file
4
src/meta/context.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from twitchio.ext import commands as cmds
|
||||
|
||||
class Context(cmds.Context):
|
||||
...
|
||||
@@ -14,6 +14,7 @@ import aiohttp
|
||||
|
||||
from .config import conf
|
||||
from utils.ratelimits import Bucket, BucketOverFull, BucketFull
|
||||
from utils.lib import utc_now
|
||||
|
||||
|
||||
log_logger = logging.getLogger(__name__)
|
||||
@@ -24,8 +25,7 @@ log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT
|
||||
log_action_stack: ContextVar[tuple[str, ...]] = ContextVar('logging_action_stack', default=())
|
||||
log_app: ContextVar[str] = ContextVar('logging_shard', default="CROCBOT")
|
||||
|
||||
# TODO merge into TwithIO context
|
||||
context: ContextVar[str] = ContextVar('context', default=None)
|
||||
context: ContextVar[Optional[str]] = ContextVar('context', default=None)
|
||||
|
||||
def set_logging_context(
|
||||
context: Optional[str] = None,
|
||||
@@ -299,6 +299,7 @@ class WebHookHandler(logging.StreamHandler):
|
||||
asyncio.create_task(self.post(record))
|
||||
|
||||
def setup(self):
|
||||
from discord import Webhook
|
||||
self.session = aiohttp.ClientSession()
|
||||
self.webhook = Webhook.from_url(self.webhook_url, session=self.session)
|
||||
|
||||
@@ -365,6 +366,8 @@ class WebHookHandler(logging.StreamHandler):
|
||||
await self._send(batched)
|
||||
|
||||
async def _send(self, message, as_file=False):
|
||||
import discord
|
||||
from discord import File
|
||||
try:
|
||||
self.bucket.request()
|
||||
except BucketOverFull:
|
||||
|
||||
@@ -1,3 +1,11 @@
|
||||
async def setup(bot):
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from meta import Bot
|
||||
|
||||
|
||||
async def twitch_setup(bot: 'Bot'):
|
||||
from . import profiles
|
||||
await profiles.twitch_setup(bot)
|
||||
from . import koans
|
||||
await koans.setup(bot)
|
||||
await koans.twitch_setup(bot)
|
||||
|
||||
1
src/modules/koans
Submodule
1
src/modules/koans
Submodule
Submodule src/modules/koans added at 076987a4d5
@@ -1,7 +0,0 @@
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def setup(bot):
|
||||
from .component import KoanComponent
|
||||
await bot.add_component(KoanComponent(bot))
|
||||
@@ -1,123 +0,0 @@
|
||||
from typing import Optional
|
||||
import random
|
||||
import twitchio
|
||||
from twitchio.ext import commands as cmds
|
||||
|
||||
from datamodels import Koan, Communities
|
||||
|
||||
from . import logger
|
||||
|
||||
|
||||
class KoanComponent(cmds.Component):
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
|
||||
@cmds.Component.listener()
|
||||
async def event_message(self, payload: twitchio.ChatMessage) -> None:
|
||||
print(f"[{payload.broadcaster.name}] - {payload.chatter.name}: {payload.text}")
|
||||
|
||||
@cmds.group(invoke_fallback=True)
|
||||
async def koans(self, ctx: cmds.Context) -> None:
|
||||
"""
|
||||
List the (names of the) koans in this channel.
|
||||
|
||||
!koans
|
||||
"""
|
||||
community = await Communities.fetch_or_create(twitchid=ctx.channel.id, name=ctx.channel.name)
|
||||
cid = community.communityid
|
||||
|
||||
koans = await Koan.fetch_where(communityid=cid)
|
||||
if koans:
|
||||
names = ', '.join(koan.name for koan in koans)
|
||||
await ctx.reply(
|
||||
f"Koans: {names}"
|
||||
)
|
||||
else:
|
||||
await ctx.reply("No koans have been made in this channel!")
|
||||
|
||||
@koans.command(name='add', aliases=['new', 'create'])
|
||||
async def koans_add(self, ctx: cmds.Context, name: str, *, text: str):
|
||||
"""
|
||||
Add or overwrite a koan to this channel.
|
||||
|
||||
!koans add wind This is a wind koan
|
||||
"""
|
||||
community = await Communities.fetch_or_create(twitchid=ctx.channel.id, name=ctx.channel.name)
|
||||
cid = community.communityid
|
||||
|
||||
name = name.lower()
|
||||
|
||||
assert isinstance(ctx.author, twitchio.Chatter)
|
||||
if (ctx.author.moderator or ctx.author.broadcaster):
|
||||
# Delete the koan with this name if it exists
|
||||
existing = await Koan.table.delete_where(
|
||||
communityid=cid,
|
||||
name=name,
|
||||
)
|
||||
|
||||
# Insert the new koan
|
||||
await Koan.create(
|
||||
communityid=cid,
|
||||
name=name,
|
||||
message=text
|
||||
)
|
||||
|
||||
# Ack
|
||||
if existing:
|
||||
await ctx.reply(f"Updated the koan '{name}'")
|
||||
else:
|
||||
await ctx.reply(f"Created the new koan '{name}'")
|
||||
|
||||
@koans.command(name='del', aliases=['delete', 'rm', 'remove'])
|
||||
async def koans_del(self, ctx: cmds.Context, name: str):
|
||||
"""
|
||||
Remove a koan from this channel by name.
|
||||
|
||||
!koans del wind
|
||||
"""
|
||||
community = await Communities.fetch_or_create(twitchid=ctx.channel.id, name=ctx.channel.name)
|
||||
cid = community.communityid
|
||||
|
||||
name = name.lower()
|
||||
|
||||
assert isinstance(ctx.author, twitchio.Chatter)
|
||||
if (ctx.author.moderator or ctx.author.broadcaster):
|
||||
# Delete the koan with this name if it exists
|
||||
existing = await Koan.table.delete_where(
|
||||
communityid=cid,
|
||||
name=name,
|
||||
)
|
||||
if existing:
|
||||
await ctx.reply(f"Deleted the koan '{name}'")
|
||||
else:
|
||||
await ctx.reply(f"The koan '{name}' does not exist to delete!")
|
||||
|
||||
@cmds.command(name='koan')
|
||||
async def koan(self, ctx: cmds.Context, name: Optional[str] = None):
|
||||
"""
|
||||
Show a koan from this channel. Optionally by name.
|
||||
|
||||
!koan
|
||||
!koan wind
|
||||
"""
|
||||
community = await Communities.fetch_or_create(twitchid=ctx.channel.id, name=ctx.channel.name)
|
||||
cid = community.communityid
|
||||
|
||||
if name is not None:
|
||||
name = name.lower()
|
||||
koans = await Koan.fetch_where(
|
||||
communityid=cid,
|
||||
name=name
|
||||
)
|
||||
if koans:
|
||||
koan = koans[0]
|
||||
await ctx.reply(koan.message)
|
||||
else:
|
||||
await ctx.reply(f"The requested koan '{name}' does not exist! Use '{ctx.prefix}koans' to see all the koans.")
|
||||
else:
|
||||
koans = await Koan.fetch_where(communityid=cid)
|
||||
if koans:
|
||||
koan = random.choice(koans)
|
||||
await ctx.reply(koan.message)
|
||||
else:
|
||||
await ctx.reply("This channel doesn't have any koans!")
|
||||
1
src/modules/profiles
Submodule
1
src/modules/profiles
Submodule
Submodule src/modules/profiles added at 0363dc2bcd
Reference in New Issue
Block a user