Initial framework
.gitignore (vendored, new file, 151 lines)
@@ -0,0 +1,151 @@
src/modules/test/*
pending-rewrite/
logs/*
notes/*
tmp/*
output/*
locales/domains

.idea/*

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

config/**
config/example-bot.conf (new file, 16 lines)
@@ -0,0 +1,16 @@
[CROCBOT]
prefix = ?
owner_id =
bot_id =

ALSO_READ = config/secrets.conf

wshost = localhost
wsport = 4343
wsdomain = localhost:4343

[LOGGING]
general_log =
warning_log =
error_log =
critical_log =
config/example-secrets.conf (new file, 7 lines)
@@ -0,0 +1,7 @@
[CROCBOT]
client_id =
client_secret =

[DATA]
args =
appid =
data/.gitignore (vendored, new file, empty)
data/schema.sql (new file, 116 lines)
@@ -0,0 +1,116 @@
-- Metadata {{{
CREATE TABLE VersionHistory(
    version INTEGER NOT NULL,
    time TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL,
    author TEXT
);
INSERT INTO VersionHistory (version, author) VALUES (1, 'Initial Creation');


CREATE OR REPLACE FUNCTION update_timestamp_column()
RETURNS TRIGGER AS $$
BEGIN
    NEW._timestamp = now();
    RETURN NEW;
END;
$$ language 'plpgsql';
-- }}}

-- App metadata {{{

CREATE TABLE app_config(
    appname TEXT PRIMARY KEY,
    created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);

-- }}}

-- Twitch Auth {{{

-- Authorisation tokens allowing us to take actions on behalf of certain users or channels.
-- For example, channels we have joined will need to be authorised with a 'channel:bot' scope.
CREATE TABLE user_auth(
    userid TEXT PRIMARY KEY,
    token TEXT NOT NULL,
    refresh_token TEXT NOT NULL,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    _timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TRIGGER user_auth_timestamp BEFORE UPDATE ON user_auth
    FOR EACH ROW EXECUTE FUNCTION update_timestamp_column();

CREATE TABLE user_auth_scopes(
    userid TEXT NOT NULL REFERENCES user_auth(userid) ON DELETE CASCADE,
    scope TEXT NOT NULL
);

-- Which channels will be joined at startup,
-- and any configurable choices needed when joining the channel
CREATE TABLE bot_channels(
    userid TEXT PRIMARY KEY REFERENCES user_auth(userid) ON DELETE CASCADE,
    autojoin BOOLEAN DEFAULT true,
    listen_redeems BOOLEAN,
    joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    _timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TRIGGER bot_channels_timestamp BEFORE UPDATE ON bot_channels
    FOR EACH ROW EXECUTE FUNCTION update_timestamp_column();


-- }}}

-- Twitch user data {{{

---- Users are internally represented by 'profiles' with a unique profileid
---- This is associated to the user's twitch userid.
---- Any user-specific configuration data or preferences can be added here
CREATE TABLE user_profiles(
    profileid INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
    twitchid TEXT NOT NULL,
    name TEXT,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    _timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TRIGGER user_profiles_timestamp BEFORE UPDATE ON user_profiles
    FOR EACH ROW EXECUTE FUNCTION update_timestamp_column();


CREATE UNIQUE INDEX user_profile_twitchid ON user_profiles (twitchid);

-- }}}

-- Twitch channel data {{{

---- Similar to user profiles, we associate twitch channels with 'communities'
---- This slight abstraction gives us more flexibility and control over the community and user data.
CREATE TABLE communities(
    communityid INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
    twitchid TEXT NOT NULL,
    name TEXT,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    _timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TRIGGER communities_timestamp BEFORE UPDATE ON communities
    FOR EACH ROW EXECUTE FUNCTION update_timestamp_column();

-- }}}


-- Koan data {{{

---- !koans lists koans. !koan gives a random koan. !koans add name ... !koans del name ...

CREATE TABLE koans(
    koanid INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
    communityid INTEGER NOT NULL REFERENCES communities ON UPDATE CASCADE ON DELETE CASCADE,
    name TEXT NOT NULL,
    message TEXT NOT NULL,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    _timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TRIGGER koans_timestamp BEFORE UPDATE ON koans
    FOR EACH ROW EXECUTE FUNCTION update_timestamp_column();

-- }}}

-- vim: set fdm=marker:
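Editor's sketch (not part of this commit): one way the schema above might be applied and the initial VersionHistory row read back with psycopg 3. The connection string and file path are placeholder assumptions; running `psql -f data/schema.sql` would be equivalent.

import asyncio
import psycopg

async def apply_schema():
    # Placeholder conninfo; the real bot reads it from config/secrets.conf [DATA] args.
    async with await psycopg.AsyncConnection.connect("dbname=crocbot") as conn:
        with open("data/schema.sql") as f:
            # Multiple statements in one execute() are fine when no parameters are passed.
            await conn.execute(f.read())
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT version, author FROM VersionHistory ORDER BY time DESC LIMIT 1"
            )
            print(await cur.fetchone())  # e.g. (1, 'Initial Creation')

asyncio.run(apply_schema())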
requirements.txt (new file, 4 lines)
@@ -0,0 +1,4 @@
websockets
twitchio
psycopg[pool]
cachetools
scripts/start_bot.py (new file, 12 lines)
@@ -0,0 +1,12 @@
#!/usr/bin/env python3

import sys
import os

sys.path.insert(0, os.path.join(os.getcwd()))
sys.path.insert(0, os.path.join(os.getcwd(), "src"))


if __name__ == '__main__':
    from bot import _main
    _main()
src/bot.py (new file, 46 lines)
@@ -0,0 +1,46 @@
import asyncio
import logging

from twitchio.web import AiohttpAdapter

from meta import CrocBot, conf, setup_main_logger, args
from data import Database
from constants import DATA_VERSION

logger = logging.getLogger(__name__)


async def main():
    db = Database(conf.data['args'])

    async with db.open():
        version = await db.version()
        if version.version != DATA_VERSION:
            error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate."
            logger.critical(error)
            raise RuntimeError(error)

        adapter = AiohttpAdapter(
            host=conf.bot.get('wshost', None),
            port=conf.bot.getint('wsport', None),
            domain=conf.bot.get('wsdomain', None),
        )

        bot = CrocBot(
            config=conf,
            dbconn=db,
            adapter=adapter,
        )

        # await bot.load_module('modules')

        try:
            await bot.start()
        finally:
            await bot.close()


def _main():
    setup_main_logger()

    asyncio.run(main())
src/constants.py (new file, 2 lines)
@@ -0,0 +1,2 @@
CONFIG_FILE = 'config/bot.conf'
DATA_VERSION = 1
src/data/__init__.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from .conditions import Condition, condition, NULL
from .database import Database
from .models import RowModel, RowTable, WeakCache
from .table import Table
from .base import Expression, RawExpr
from .columns import ColumnExpr, Column, Integer, String
from .registry import Registry, AttachableClass, Attachable
from .adapted import RegisterEnum
from .queries import ORDER, NULLS, JOINTYPE
src/data/adapted.py (new file, 40 lines)
@@ -0,0 +1,40 @@
# from enum import Enum
from typing import Optional
from psycopg.types.enum import register_enum, EnumInfo
from psycopg import AsyncConnection
from .registry import Attachable, Registry


class RegisterEnum(Attachable):
    def __init__(self, enum, name: Optional[str] = None, mapper=None):
        super().__init__()
        self.enum = enum
        self.name = name or enum.__name__
        self.mapping = mapper(enum) if mapper is not None else self._mapper()

    def _mapper(self):
        return {m: m.value[0] for m in self.enum}

    def attach_to(self, registry: Registry):
        self._registry = registry
        registry.init_task(self.on_init)
        return self

    async def on_init(self, registry: Registry):
        connector = registry._conn
        if connector is None:
            raise ValueError("Cannot initialise without connector!")
        connector.connect_hook(self.connection_hook)
        # await connector.refresh_pool()
        # The below may be somewhat dangerous
        # But adaptation should never write to the database
        await connector.map_over_pool(self.connection_hook)
        # if conn := connector.conn:
        #     # Ensure the adaptation is run in the current context as well
        #     await self.connection_hook(conn)

    async def connection_hook(self, conn: AsyncConnection):
        info = await EnumInfo.fetch(conn, self.name)
        if info is None:
            raise ValueError(f"Enum {self.name} not found in database.")
        register_enum(info, conn, self.enum, mapping=list(self.mapping.items()))
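Editor's sketch (not part of this commit, names are hypothetical): how RegisterEnum might be used to map a Python Enum onto a Postgres enum type. The Registry class lives in src/data/registry.py, which is not included in this diff, so the attachment step is only indicated in comments.

from enum import Enum

class ChannelMode(Enum):
    # value[0] becomes the Postgres label via RegisterEnum._mapper()
    NORMAL = ('normal',)
    QUIET = ('quiet',)

# Assuming 'registry' is an existing, bound Registry instance:
#   RegisterEnum(ChannelMode, name='channelmode').attach_to(registry)
# During registry initialisation, on_init registers a connection hook so the enum
# is (re-)adapted on every new pool connection.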
src/data/base.py (new file, 45 lines)
@@ -0,0 +1,45 @@
from abc import abstractmethod
from typing import Any, Protocol, runtime_checkable
from itertools import chain
from psycopg import sql


@runtime_checkable
class Expression(Protocol):
    __slots__ = ()

    @abstractmethod
    def as_tuple(self) -> tuple[sql.Composable, tuple[Any, ...]]:
        raise NotImplementedError


class RawExpr(Expression):
    __slots__ = ('expr', 'values')

    expr: sql.Composable
    values: tuple[Any, ...]

    def __init__(self, expr: sql.Composable, values: tuple[Any, ...] = ()):
        self.expr = expr
        self.values = values

    def as_tuple(self):
        return (self.expr, self.values)

    @classmethod
    def join(cls, *expressions: Expression, joiner: sql.SQL = sql.SQL(' ')):
        """
        Join a sequence of Expressions into a single RawExpr.
        """
        tups = (
            expression.as_tuple()
            for expression in expressions
        )
        return cls.join_tuples(*tups, joiner=joiner)

    @classmethod
    def join_tuples(cls, *tuples: tuple[sql.Composable, tuple[Any, ...]], joiner: sql.SQL = sql.SQL(' ')):
        exprs, values = zip(*tuples)
        expr = joiner.join(exprs)
        value = tuple(chain(*values))
        return cls(expr, value)
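Editor's sketch (not part of this commit): composing RawExpr fragments. Each expression carries its SQL and its parameter values, and join keeps both in lockstep, which is the invariant the rest of the query builder relies on.

from psycopg import sql
from data import RawExpr

left = RawExpr(sql.SQL("userid ="), ())
right = RawExpr(sql.Placeholder(), ("12345",))
combined = RawExpr.join(left, right)

expr, values = combined.as_tuple()
# expr renders roughly as:  userid = %s
# values == ('12345',)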
src/data/columns.py (new file, 155 lines)
@@ -0,0 +1,155 @@
from typing import Any, Union, TypeVar, Generic, Type, overload, Optional, TYPE_CHECKING
from psycopg import sql
from datetime import datetime

from .base import RawExpr, Expression
from .conditions import Condition, Joiner
from .table import Table


class ColumnExpr(RawExpr):
    __slots__ = ()

    def __lt__(self, obj) -> Condition:
        expr, values = self.as_tuple()

        if isinstance(obj, Expression):
            # column < Expression
            obj_expr, obj_values = obj.as_tuple()
            cond_exprs = (expr, Joiner.LT, obj_expr)
            cond_values = (*values, *obj_values)
        else:
            # column < Literal
            cond_exprs = (expr, Joiner.LT, sql.Placeholder())
            cond_values = (*values, obj)

        return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)

    def __le__(self, obj) -> Condition:
        expr, values = self.as_tuple()

        if isinstance(obj, Expression):
            # column <= Expression
            obj_expr, obj_values = obj.as_tuple()
            cond_exprs = (expr, Joiner.LE, obj_expr)
            cond_values = (*values, *obj_values)
        else:
            # column <= Literal
            cond_exprs = (expr, Joiner.LE, sql.Placeholder())
            cond_values = (*values, obj)

        return Condition(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)

    def __eq__(self, obj) -> Condition:  # type: ignore[override]
        return Condition._expression_equality(self, obj)

    def __ne__(self, obj) -> Condition:  # type: ignore[override]
        return ~(self.__eq__(obj))

    def __gt__(self, obj) -> Condition:
        return ~(self.__le__(obj))

    def __ge__(self, obj) -> Condition:
        return ~(self.__lt__(obj))

    def __add__(self, obj: Union[Any, Expression]) -> 'ColumnExpr':
        if isinstance(obj, Expression):
            obj_expr, obj_values = obj.as_tuple()
            return ColumnExpr(
                sql.SQL("({} + {})").format(self.expr, obj_expr),
                (*self.values, *obj_values)
            )
        else:
            return ColumnExpr(
                sql.SQL("({} + {})").format(self.expr, sql.Placeholder()),
                (*self.values, obj)
            )

    def __sub__(self, obj) -> 'ColumnExpr':
        if isinstance(obj, Expression):
            obj_expr, obj_values = obj.as_tuple()
            return ColumnExpr(
                sql.SQL("({} - {})").format(self.expr, obj_expr),
                (*self.values, *obj_values)
            )
        else:
            return ColumnExpr(
                sql.SQL("({} - {})").format(self.expr, sql.Placeholder()),
                (*self.values, obj)
            )

    def __mul__(self, obj) -> 'ColumnExpr':
        if isinstance(obj, Expression):
            obj_expr, obj_values = obj.as_tuple()
            return ColumnExpr(
                sql.SQL("({} * {})").format(self.expr, obj_expr),
                (*self.values, *obj_values)
            )
        else:
            return ColumnExpr(
                sql.SQL("({} * {})").format(self.expr, sql.Placeholder()),
                (*self.values, obj)
            )

    def CAST(self, target_type: sql.Composable):
        return ColumnExpr(
            sql.SQL("({}::{})").format(self.expr, target_type),
            self.values
        )


T = TypeVar('T')

if TYPE_CHECKING:
    from .models import RowModel


class Column(ColumnExpr, Generic[T]):
    def __init__(self, name: Optional[str] = None,
                 primary: bool = False, references: Optional['Column'] = None,
                 type: Optional[Type[T]] = None):
        self.primary = primary
        self.references = references
        self.name: str = name  # type: ignore
        self.owner: Optional['RowModel'] = None
        self._type = type

        self.expr = sql.Identifier(name) if name else sql.SQL('')
        self.values = ()

    def __set_name__(self, owner, name):
        # Only allow setting the owner once
        self.name = self.name or name
        self.owner = owner
        self.expr = sql.Identifier(self.owner._schema_, self.owner._tablename_, self.name)

    @overload
    def __get__(self: 'Column[T]', obj: None, objtype: "None | Type['RowModel']") -> 'Column[T]':
        ...

    @overload
    def __get__(self: 'Column[T]', obj: 'RowModel', objtype: Type['RowModel']) -> T:
        ...

    def __get__(self: 'Column[T]', obj: "RowModel | None", objtype: "Type[RowModel] | None" = None) -> "T | Column[T]":
        # Get value from row data or session
        if obj is None:
            return self
        else:
            return obj.data[self.name]


class Integer(Column[int]):
    pass


class String(Column[str]):
    pass


class Bool(Column[bool]):
    pass


class Timestamp(Column[datetime]):
    pass
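Editor's sketch (not part of this commit): the comparison operators on a Column build Condition objects carrying placeholders and values rather than interpolated SQL. A bare Column with no owning RowModel renders as a plain identifier; 'price' is an illustrative column name.

from data import Column

price = Column('price')
cond = (price >= 10) & (price < 100)

expr, values = cond.as_tuple()
# expr renders roughly as:  ("price" >= %s) AND ("price" < %s)
# values == (10, 100)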
src/data/conditions.py (new file, 214 lines)
@@ -0,0 +1,214 @@
# from meta import sharding
from typing import Any, Union
from enum import Enum
from itertools import chain
from psycopg import sql

from .base import Expression, RawExpr


"""
A Condition is a "logical" database expression, intended for use in Where statements.
Conditions support bitwise logical operators ~, &, |, each producing another Condition.
"""

NULL = None


class Joiner(Enum):
    EQUALS = ('=', '!=')
    IS = ('IS', 'IS NOT')
    LIKE = ('LIKE', 'NOT LIKE')
    BETWEEN = ('BETWEEN', 'NOT BETWEEN')
    IN = ('IN', 'NOT IN')
    LT = ('<', '>=')
    LE = ('<=', '>')
    NONE = ('', '')


class Condition(Expression):
    __slots__ = ('expr1', 'joiner', 'negated', 'expr2', 'values')

    def __init__(self,
                 expr1: sql.Composable, joiner: Joiner = Joiner.NONE, expr2: sql.Composable = sql.SQL(''),
                 values: tuple[Any, ...] = (), negated=False
                 ):
        self.expr1 = expr1
        self.joiner = joiner
        self.negated = negated
        self.expr2 = expr2
        self.values = values

    def as_tuple(self):
        expr = sql.SQL(' ').join((self.expr1, sql.SQL(self.joiner.value[self.negated]), self.expr2))
        if self.negated and self.joiner is Joiner.NONE:
            expr = sql.SQL("NOT ({})").format(expr)
        return (expr, self.values)

    @classmethod
    def construct(cls, *conditions: 'Condition', **kwargs: Union[Any, Expression]):
        """
        Construct a Condition from a sequence of Conditions,
        together with some explicit column conditions.
        """
        # TODO: Consider adding a _table identifier here so we can identify implicit columns
        # Or just require subquery type conditions to always come from modelled tables.
        implicit_conditions = (
            cls._expression_equality(RawExpr(sql.Identifier(column)), value) for column, value in kwargs.items()
        )
        return cls._and(*conditions, *implicit_conditions)

    @classmethod
    def _and(cls, *conditions: 'Condition'):
        if not len(conditions):
            raise ValueError("Cannot combine 0 Conditions")
        if len(conditions) == 1:
            return conditions[0]

        exprs, values = zip(*(condition.as_tuple() for condition in conditions))
        cond_expr = sql.SQL(' AND ').join((sql.SQL('({})').format(expr) for expr in exprs))
        cond_values = tuple(chain(*values))

        return Condition(cond_expr, values=cond_values)

    @classmethod
    def _or(cls, *conditions: 'Condition'):
        if not len(conditions):
            raise ValueError("Cannot combine 0 Conditions")
        if len(conditions) == 1:
            return conditions[0]

        exprs, values = zip(*(condition.as_tuple() for condition in conditions))
        cond_expr = sql.SQL(' OR ').join((sql.SQL('({})').format(expr) for expr in exprs))
        cond_values = tuple(chain(*values))

        return Condition(cond_expr, values=cond_values)

    @classmethod
    def _not(cls, condition: 'Condition'):
        condition.negated = not condition.negated
        return condition

    @classmethod
    def _expression_equality(cls, column: Expression, value: Union[Any, Expression]) -> 'Condition':
        # TODO: Check if this supports subqueries
        col_expr, col_values = column.as_tuple()

        # TODO: Also support sql.SQL? For joins?
        if isinstance(value, Expression):
            # column = Expression
            value_expr, value_values = value.as_tuple()
            cond_exprs = (col_expr, Joiner.EQUALS, value_expr)
            cond_values = (*col_values, *value_values)
        elif isinstance(value, (tuple, list)):
            # column in (...)
            # TODO: Support expressions in value tuple?
            if not value:
                raise ValueError("Cannot create Condition from empty iterable!")
            value_expr = sql.SQL('({})').format(sql.SQL(',').join(sql.Placeholder() * len(value)))
            cond_exprs = (col_expr, Joiner.IN, value_expr)
            cond_values = (*col_values, *value)
        elif value is None:
            # column IS NULL
            cond_exprs = (col_expr, Joiner.IS, sql.NULL)
            cond_values = col_values
        else:
            # column = Literal
            cond_exprs = (col_expr, Joiner.EQUALS, sql.Placeholder())
            cond_values = (*col_values, value)

        return cls(cond_exprs[0], cond_exprs[1], cond_exprs[2], cond_values)

    def __invert__(self) -> 'Condition':
        self.negated = not self.negated
        return self

    def __and__(self, condition: 'Condition') -> 'Condition':
        return self._and(self, condition)

    def __or__(self, condition: 'Condition') -> 'Condition':
        return self._or(self, condition)


# Helper method to simplify condition construction
def condition(*args, **kwargs) -> Condition:
    return Condition.construct(*args, **kwargs)


# class NOT(Condition):
#     __slots__ = ('value',)
#
#     def __init__(self, value):
#         self.value = value
#
#     def apply(self, key, values, conditions):
#         item = self.value
#         if isinstance(item, (list, tuple)):
#             if item:
#                 conditions.append("{} NOT IN ({})".format(key, ", ".join([_replace_char] * len(item))))
#                 values.extend(item)
#             else:
#                 raise ValueError("Cannot check an empty iterable!")
#         else:
#             conditions.append("{}!={}".format(key, _replace_char))
#             values.append(item)
#
#
# class GEQ(Condition):
#     __slots__ = ('value',)
#
#     def __init__(self, value):
#         self.value = value
#
#     def apply(self, key, values, conditions):
#         item = self.value
#         if isinstance(item, (list, tuple)):
#             raise ValueError("Cannot apply GEQ condition to a list!")
#         else:
#             conditions.append("{} >= {}".format(key, _replace_char))
#             values.append(item)
#
#
# class LEQ(Condition):
#     __slots__ = ('value',)
#
#     def __init__(self, value):
#         self.value = value
#
#     def apply(self, key, values, conditions):
#         item = self.value
#         if isinstance(item, (list, tuple)):
#             raise ValueError("Cannot apply LEQ condition to a list!")
#         else:
#             conditions.append("{} <= {}".format(key, _replace_char))
#             values.append(item)
#
#
# class Constant(Condition):
#     __slots__ = ('value',)
#
#     def __init__(self, value):
#         self.value = value
#
#     def apply(self, key, values, conditions):
#         conditions.append("{} {}".format(key, self.value))
#
#
# class SHARDID(Condition):
#     __slots__ = ('shardid', 'shard_count')
#
#     def __init__(self, shardid, shard_count):
#         self.shardid = shardid
#         self.shard_count = shard_count
#
#     def apply(self, key, values, conditions):
#         if self.shard_count > 1:
#             conditions.append("({} >> 22) %% {} = {}".format(key, self.shard_count, _replace_char))
#             values.append(self.shardid)
#
#
# # THIS_SHARD = SHARDID(sharding.shard_number, sharding.shard_count)
#
#
# NULL = Constant('IS NULL')
# NOTNULL = Constant('IS NOT NULL')
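Editor's sketch (not part of this commit): the condition() helper and-s together keyword equality checks, handling NULL and IN (...) automatically. Column names here are borrowed from the bot_channels table in data/schema.sql purely for illustration.

from data import condition, NULL

cond = condition(userid='12345', autojoin=True, listen_redeems=NULL)
expr, values = cond.as_tuple()
# expr renders roughly as:
#   ("userid" = %s) AND ("autojoin" = %s) AND ("listen_redeems" IS NULL)
# values == ('12345', True)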
src/data/connector.py (new file, 135 lines)
@@ -0,0 +1,135 @@
from typing import Protocol, runtime_checkable, Callable, Awaitable, Optional
import logging

from contextvars import ContextVar
from contextlib import asynccontextmanager
import psycopg as psq
from psycopg_pool import AsyncConnectionPool
from psycopg.pq import TransactionStatus

from .cursor import AsyncLoggingCursor

logger = logging.getLogger(__name__)

row_factory = psq.rows.dict_row

ctx_connection: ContextVar[Optional[psq.AsyncConnection]] = ContextVar('connection', default=None)


class Connector:
    cursor_factory = AsyncLoggingCursor

    def __init__(self, conn_args):
        self._conn_args = conn_args
        self._conn_kwargs = dict(autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory)

        self.pool = self.make_pool()

        self.conn_hooks = []

    @property
    def conn(self) -> Optional[psq.AsyncConnection]:
        """
        Convenience property for the current context connection.
        """
        return ctx_connection.get()

    @conn.setter
    def conn(self, conn: psq.AsyncConnection):
        """
        Set the contextual connection in the current context.
        Always do this in an isolated context!
        """
        ctx_connection.set(conn)

    def make_pool(self) -> AsyncConnectionPool:
        logger.info("Initialising connection pool.", extra={'action': "Pool Init"})
        return AsyncConnectionPool(
            self._conn_args,
            open=False,
            min_size=1,
            max_size=4,
            configure=self._setup_connection,
            kwargs=self._conn_kwargs
        )

    async def refresh_pool(self):
        """
        Refresh the pool.

        The point of this is to invalidate any existing connections so that the connection set up is run again.
        Better ways should be sought (a way to
        """
        logger.info("Pool refresh requested, closing and reopening.")
        old_pool = self.pool
        self.pool = self.make_pool()
        await self.pool.open()
        logger.info(f"Old pool statistics: {self.pool.get_stats()}")
        await old_pool.close()
        logger.info("Pool refresh complete.")

    async def map_over_pool(self, callable):
        """
        Dangerous method to call a method on each connection in the pool.

        Utilises private methods of the AsyncConnectionPool.
        """
        async with self.pool._lock:
            conns = list(self.pool._pool)
            while conns:
                conn = conns.pop()
                try:
                    await callable(conn)
                except Exception:
                    logger.exception(f"Mapped connection task failed. {callable.__name__}")

    @asynccontextmanager
    async def open(self):
        try:
            logger.info("Opening database pool.")
            await self.pool.open()
            yield
        finally:
            # May be a different pool!
            logger.info(f"Closing database pool. Pool statistics: {self.pool.get_stats()}")
            await self.pool.close()

    @asynccontextmanager
    async def connection(self) -> psq.AsyncConnection:
        """
        Asynchronous context manager to get and manage a connection.

        If the context connection is set, uses this and does not manage the lifetime.
        Otherwise, requests a new connection from the pool and returns it when done.
        """
        logger.debug("Database connection requested.", extra={'action': "Data Connect"})
        if (conn := self.conn):
            yield conn
        else:
            async with self.pool.connection() as conn:
                yield conn

    async def _setup_connection(self, conn: psq.AsyncConnection):
        logger.debug("Initialising new connection.", extra={'action': "Conn Init"})
        for hook in self.conn_hooks:
            try:
                await hook(conn)
            except Exception:
                logger.exception("Exception encountered setting up new connection")
        return conn

    def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]):
        """
        Minimal decorator to register a coroutine to run on connect or reconnect.

        Note that these are only run on connect and reconnect.
        If a hook is registered after connection, it will not be run.
        """
        self.conn_hooks.append(coro)
        return coro


@runtime_checkable
class Connectable(Protocol):
    def bind(self, connector: Connector):
        raise NotImplementedError
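Editor's sketch (not part of this commit): the Connector lifecycle as exposed above. open() manages the pool, connection() hands out pooled connections, and hooks registered with connect_hook run for every new connection. The conninfo string is a placeholder assumption.

import asyncio
from data.connector import Connector

async def demo():
    db = Connector("dbname=crocbot user=crocbot")

    @db.connect_hook
    async def on_connect(conn):
        # Runs for each new pool connection (e.g. registering enum adapters).
        pass

    async with db.open():
        async with db.connection() as conn:
            async with conn.cursor() as cur:
                await cur.execute("SELECT 1")
                print(await cur.fetchone())

asyncio.run(demo())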
src/data/cursor.py (new file, 42 lines)
@@ -0,0 +1,42 @@
import logging
from typing import Optional

from psycopg import AsyncCursor, sql
from psycopg.abc import Query, Params
from psycopg._encodings import conn_encoding

logger = logging.getLogger(__name__)


class AsyncLoggingCursor(AsyncCursor):
    def mogrify_query(self, query: Query):
        if isinstance(query, str):
            msg = query
        elif isinstance(query, (sql.SQL, sql.Composed)):
            msg = query.as_string(self)
        elif isinstance(query, bytes):
            msg = query.decode(conn_encoding(self._conn.pgconn), 'replace')
        else:
            msg = repr(query)
        return msg

    async def execute(self, query: Query, params: Optional[Params] = None, **kwargs):
        if logging.DEBUG >= logger.getEffectiveLevel():
            msg = self.mogrify_query(query)
            logger.debug(
                "Executing query (%s) with values %s", msg, params,
                extra={'action': "Query Execute"}
            )
        try:
            return await super().execute(query, params=params, **kwargs)
        except Exception:
            msg = self.mogrify_query(query)
            logger.exception(
                "Exception during query execution. Query (%s) with parameters %s.",
                msg, params,
                extra={'action': "Query Execute"},
                stack_info=True
            )
        else:
            # TODO: Possibly log execution time
            pass
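Editor's note with a small sketch (not part of this commit): the logging cursor only mogrifies and logs when DEBUG is enabled for this module, so query logging can be switched on per deployment without code changes. The logger name below assumes the package is importable as 'data'.

import logging

logging.basicConfig(level=logging.INFO)
logging.getLogger('data.cursor').setLevel(logging.DEBUG)
# Every execute() now emits: "Executing query (...) with values ..."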
src/data/database.py (new file, 47 lines)
@@ -0,0 +1,47 @@
from typing import TypeVar
import logging
from collections import namedtuple

# from .cursor import AsyncLoggingCursor
from .registry import Registry
from .connector import Connector


logger = logging.getLogger(__name__)

Version = namedtuple('Version', ('version', 'time', 'author'))

T = TypeVar('T', bound=Registry)


class Database(Connector):
    # cursor_factory = AsyncLoggingCursor

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.registries: dict[str, Registry] = {}

    def load_registry(self, registry: T) -> T:
        logger.debug(
            f"Loading and binding registry '{registry.name}'.",
            extra={'action': f"Reg {registry.name}"}
        )
        registry.bind(self)
        self.registries[registry.name] = registry
        return registry

    async def version(self) -> Version:
        """
        Return the current schema version as a Version namedtuple.
        """
        async with self.connection() as conn:
            async with conn.cursor() as cursor:
                # Get last entry in version table, compare against desired version
                await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
                row = await cursor.fetchone()
                if row:
                    return Version(row['version'], row['time'], row['author'])
                else:
                    # No versions in the database
                    return Version(-1, None, None)
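Editor's sketch (not part of this commit): the startup check that src/bot.py performs with this class, comparing the stored schema version against constants.DATA_VERSION before serving. The connection arguments are a placeholder assumption.

import asyncio
from data import Database
from constants import DATA_VERSION

async def check():
    db = Database("dbname=crocbot")
    async with db.open():
        version = await db.version()
        if version.version != DATA_VERSION:
            raise RuntimeError(f"Schema at {version.version}, expected {DATA_VERSION}")

asyncio.run(check())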
src/data/models.py (new file, 323 lines)
@@ -0,0 +1,323 @@
from typing import TypeVar, Type, Optional, Generic, Union
# from typing_extensions import Self
from weakref import WeakValueDictionary
from collections.abc import MutableMapping

from psycopg.rows import DictRow

from .table import Table
from .columns import Column
from . import queries as q
from .connector import Connector
from .registry import Registry


RowT = TypeVar('RowT', bound='RowModel')


class MISSING:
    __slots__ = ('oid',)

    def __init__(self, oid):
        self.oid = oid


class RowTable(Table, Generic[RowT]):
    __slots__ = (
        'model',
    )

    def __init__(self, name, model: Type[RowT], **kwargs):
        super().__init__(name, **kwargs)
        self.model = model

    @property
    def columns(self):
        return self.model._columns_

    @property
    def id_col(self):
        return self.model._key_

    @property
    def row_cache(self):
        return self.model._cache_

    def _many_query_adapter(self, *data):
        self.model._make_rows(*data)
        return data

    def _single_query_adapter(self, *data):
        if data:
            self.model._make_rows(*data)
            return data[0]
        else:
            return None

    def _delete_query_adapter(self, *data):
        self.model._delete_rows(*data)
        return data

    # New methods to fetch and create rows
    async def create_row(self, *args, **kwargs) -> RowT:
        data = await super().insert(*args, **kwargs)
        return self.model._make_rows(data)[0]

    def fetch_rows_where(self, *args, **kwargs) -> q.Select[list[RowT]]:
        # TODO: Handle list of rowids here?
        return q.Select(
            self.identifier,
            row_adapter=self.model._make_rows,
            connector=self.connector
        ).where(*args, **kwargs)


WK = TypeVar('WK')
WV = TypeVar('WV')


class WeakCache(Generic[WK, WV], MutableMapping[WK, WV]):
    def __init__(self, ref_cache):
        self.ref_cache = ref_cache
        self.weak_cache = WeakValueDictionary()

    def __getitem__(self, key):
        value = self.weak_cache[key]
        self.ref_cache[key] = value
        return value

    def __setitem__(self, key, value):
        self.weak_cache[key] = value
        self.ref_cache[key] = value

    def __delitem__(self, key):
        del self.weak_cache[key]
        try:
            del self.ref_cache[key]
        except KeyError:
            pass

    def __contains__(self, key):
        return key in self.weak_cache

    def __iter__(self):
        return iter(self.weak_cache)

    def __len__(self):
        return len(self.weak_cache)

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default

    def pop(self, key, default=None):
        if key in self:
            value = self[key]
            del self[key]
        else:
            value = default
        return value


# TODO: Implement getitem and setitem, for dynamic column access
class RowModel:
    __slots__ = ('data',)

    _schema_: str = 'public'
    _tablename_: Optional[str] = None
    _columns_: dict[str, Column] = {}

    # Cache to keep track of registered Rows
    _cache_: Union[dict, WeakValueDictionary, WeakCache] = None  # type: ignore

    _key_: tuple[str, ...] = ()
    _connector: Optional[Connector] = None
    _registry: Optional[Registry] = None

    # TODO: Proper typing for a classvariable which gets dynamically assigned in subclass
    table: RowTable = None

    def __init_subclass__(cls: Type[RowT], table: Optional[str] = None):
        """
        Set table, _columns_, and _key_.
        """
        if table is not None:
            cls._tablename_ = table

        if cls._tablename_ is not None:
            columns = {}
            for key, value in cls.__dict__.items():
                if isinstance(value, Column):
                    columns[key] = value

            cls._columns_ = columns
            if not cls._key_:
                cls._key_ = tuple(column.name for column in columns.values() if column.primary)
            cls.table = RowTable(cls._tablename_, cls, schema=cls._schema_)
            if cls._cache_ is None:
                cls._cache_ = WeakValueDictionary()

    def __new__(cls, data):
        # Registry pattern.
        # Ensure each rowid always refers to a single Model instance
        if data is not None:
            rowid = cls._id_from_data(data)

            cache = cls._cache_

            if (row := cache.get(rowid, None)) is not None:
                obj = row
            else:
                obj = cache[rowid] = super().__new__(cls)
        else:
            obj = super().__new__(cls)

        return obj

    @classmethod
    def as_tuple(cls):
        return (cls.table.identifier, ())

    def __init__(self, data):
        self.data = data

    def __getitem__(self, key):
        return self.data[key]

    def __setitem__(self, key, value):
        self.data[key] = value

    @classmethod
    def bind(cls, connector: Connector):
        if cls.table is None:
            raise ValueError("Cannot bind abstract RowModel")
        cls._connector = connector
        cls.table.bind(connector)
        return cls

    @classmethod
    def attach_to(cls, registry: Registry):
        cls._registry = registry
        return cls

    @property
    def _dict_(self):
        return {key: self.data[key] for key in self._key_}

    @property
    def _rowid_(self):
        return tuple(self.data[key] for key in self._key_)

    def __repr__(self):
        return "{}.{}({})".format(
            self.table.schema,
            self.table.name,
            ', '.join(repr(column.__get__(self)) for column in self._columns_.values())
        )

    @classmethod
    def _id_from_data(cls, data):
        return tuple(data[key] for key in cls._key_)

    @classmethod
    def _dict_from_id(cls, rowid):
        return dict(zip(cls._key_, rowid))

    @classmethod
    def _make_rows(cls: Type[RowT], *data_rows: DictRow) -> list[RowT]:
        """
        Create or retrieve Row objects for each provided data row.
        If the rows already exist in cache, updates the cached row.
        """
        # TODO: Handle partial row data here somehow?
        rows = [cls(data_row) for data_row in data_rows]
        return rows

    @classmethod
    def _delete_rows(cls, *data_rows):
        """
        Remove the given rows from cache, if they exist.
        May be extended to handle object deletion.
        """
        cache = cls._cache_

        for data_row in data_rows:
            rowid = cls._id_from_data(data_row)
            cache.pop(rowid, None)

    @classmethod
    async def create(cls: Type[RowT], *args, **kwargs) -> RowT:
        return await cls.table.create_row(*args, **kwargs)

    @classmethod
    def fetch_where(cls: Type[RowT], *args, **kwargs):
        return cls.table.fetch_rows_where(*args, **kwargs)

    @classmethod
    async def fetch(cls: Type[RowT], *rowid, cached=True) -> Optional[RowT]:
        """
        Fetch the row with the given id, retrieving from cache where possible.
        """
        row = cls._cache_.get(rowid, None) if cached else None
        if row is None:
            rows = await cls.fetch_where(**cls._dict_from_id(rowid))
            row = rows[0] if rows else None
            if row is None:
                cls._cache_[rowid] = cls(None)
        elif row.data is None:
            row = None

        return row

    @classmethod
    async def fetch_or_create(cls, *rowid, **kwargs):
        """
        Helper method to fetch a row with the given id or fields, or create it if it doesn't exist.
        """
        if rowid:
            row = await cls.fetch(*rowid)
        else:
            rows = await cls.fetch_where(**kwargs).limit(1)
            row = rows[0] if rows else None

        if row is None:
            creation_kwargs = kwargs
            if rowid:
                creation_kwargs.update(cls._dict_from_id(rowid))
            row = await cls.create(**creation_kwargs)
        return row

    async def refresh(self: RowT) -> Optional[RowT]:
        """
        Refresh this Row from data.

        The return value may be `None` if the row was deleted.
        """
        rows = await self.table.select_where(**self._dict_)
        if not rows:
            return None
        else:
            self.data = rows[0]
            return self

    async def update(self: RowT, **values) -> Optional[RowT]:
        """
        Update this Row with the given values.

        Internally passes the provided `values` to the `update` Query.
        The return value may be `None` if the row was deleted.
        """
        data = await self.table.update_where(**self._dict_).set(**values).with_adapter(self._make_rows)
        if not data:
            return None
        else:
            return data[0]

    async def delete(self: RowT) -> Optional[RowT]:
        """
        Delete this Row.
        """
        data = await self.table.delete_where(**self._dict_).with_adapter(self._delete_rows)
        return data[0] if data is not None else None
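Editor's sketch (not part of this commit, class and identifiers are hypothetical): how a RowModel might map onto the bot_channels table from data/schema.sql. Binding to a live Connector, normally done through a Registry, is assumed to have happened elsewhere.

from data import RowModel, String
from data.columns import Bool

class BotChannel(RowModel, table='bot_channels'):
    # Primary key column drives _key_ and the instance cache.
    userid = String(primary=True)
    autojoin = Bool()
    listen_redeems = Bool()

# After BotChannel.bind(connector):
#   channel = await BotChannel.fetch_or_create('1234567', autojoin=True)
#   await channel.update(listen_redeems=False)
# Instances are cached per primary key, so repeated fetches return the same object.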
644
src/data/queries.py
Normal file
644
src/data/queries.py
Normal file
@@ -0,0 +1,644 @@
|
|||||||
|
from typing import Optional, TypeVar, Any, Callable, Generic, List, Union
|
||||||
|
from enum import Enum
|
||||||
|
from itertools import chain
|
||||||
|
from psycopg import AsyncConnection, AsyncCursor
|
||||||
|
from psycopg import sql
|
||||||
|
from psycopg.rows import DictRow
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from .conditions import Condition
|
||||||
|
from .base import Expression, RawExpr
|
||||||
|
from .connector import Connector
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
TQueryT = TypeVar('TQueryT', bound='TableQuery')
|
||||||
|
SQueryT = TypeVar('SQueryT', bound='Select')
|
||||||
|
|
||||||
|
QueryResult = TypeVar('QueryResult')
|
||||||
|
|
||||||
|
|
||||||
|
class Query(Generic[QueryResult]):
|
||||||
|
"""
|
||||||
|
ABC for an executable query statement.
|
||||||
|
"""
|
||||||
|
__slots__ = ('conn', 'cursor', '_adapter', 'connector', 'result')
|
||||||
|
|
||||||
|
_adapter: Callable[..., QueryResult]
|
||||||
|
|
||||||
|
def __init__(self, *args, row_adapter=None, connector=None, conn=None, cursor=None, **kwargs):
|
||||||
|
self.connector: Optional[Connector] = connector
|
||||||
|
self.conn: Optional[AsyncConnection] = conn
|
||||||
|
self.cursor: Optional[AsyncCursor] = cursor
|
||||||
|
|
||||||
|
if row_adapter is not None:
|
||||||
|
self._adapter = row_adapter
|
||||||
|
else:
|
||||||
|
self._adapter = self._no_adapter
|
||||||
|
|
||||||
|
self.result: Optional[QueryResult] = None
|
||||||
|
|
||||||
|
def bind(self, connector: Connector):
|
||||||
|
self.connector = connector
|
||||||
|
return self
|
||||||
|
|
||||||
|
def with_cursor(self, cursor: AsyncCursor):
|
||||||
|
self.cursor = cursor
|
||||||
|
return self
|
||||||
|
|
||||||
|
def with_connection(self, conn: AsyncConnection):
|
||||||
|
self.conn = conn
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _no_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
|
||||||
|
return data
|
||||||
|
|
||||||
|
def with_adapter(self, callable: Callable[..., QueryResult]):
|
||||||
|
# NOTE: Postcomposition functor, Query[QR2] = (QR1 -> QR2) o Query[QR1]
|
||||||
|
# For this to work cleanly, callable should have arg type of QR1, not any
|
||||||
|
self._adapter = callable
|
||||||
|
return self
|
||||||
|
|
||||||
|
def with_no_adapter(self):
|
||||||
|
"""
|
||||||
|
Sets the adapater to the identity.
|
||||||
|
"""
|
||||||
|
self._adapter = self._no_adapter
|
||||||
|
return self
|
||||||
|
|
||||||
|
def one(self):
|
||||||
|
# TODO: Postcomposition with item functor, Query[List[QR1]] -> Query[QR1]
|
||||||
|
return self
|
||||||
|
|
||||||
|
def build(self) -> Expression:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def _execute(self, cursor: AsyncCursor) -> QueryResult:
|
||||||
|
query, values = self.build().as_tuple()
|
||||||
|
# TODO: Move logging out to a custom cursor
|
||||||
|
# logger.debug(
|
||||||
|
# f"Executing query ({query.as_string(cursor)}) with values {values}",
|
||||||
|
# extra={'action': "Query"}
|
||||||
|
# )
|
||||||
|
await cursor.execute(sql.Composed((query,)), values)
|
||||||
|
data = await cursor.fetchall()
|
||||||
|
self.result = self._adapter(*data)
|
||||||
|
return self.result
|
||||||
|
|
||||||
|
async def execute(self, cursor=None) -> QueryResult:
|
||||||
|
"""
|
||||||
|
Execute the query, optionally with the provided cursor, and return the result rows.
|
||||||
|
If no cursor is provided, and no cursor has been set with `with_cursor`,
|
||||||
|
the execution will create a new cursor from the connection and close it automatically.
|
||||||
|
"""
|
||||||
|
# Create a cursor if possible
|
||||||
|
cursor = cursor if cursor is not None else self.cursor
|
||||||
|
if self.cursor is None:
|
||||||
|
if self.conn is None:
|
||||||
|
if self.connector is None:
|
||||||
|
raise ValueError("Cannot execute query without cursor, connection, or connector.")
|
||||||
|
else:
|
||||||
|
async with self.connector.connection() as conn:
|
||||||
|
async with conn.cursor() as cursor:
|
||||||
|
data = await self._execute(cursor)
|
||||||
|
else:
|
||||||
|
async with self.conn.cursor() as cursor:
|
||||||
|
data = await self._execute(cursor)
|
||||||
|
else:
|
||||||
|
data = await self._execute(cursor)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def __await__(self):
|
||||||
|
return self.execute().__await__()
|
||||||
|
|
||||||
|
|
||||||
|
class TableQuery(Query[QueryResult]):
|
||||||
|
"""
|
||||||
|
ABC for an executable query statement expected to be run on a single table.
|
||||||
|
"""
|
||||||
|
__slots__ = (
|
||||||
|
'tableid',
|
||||||
|
'condition', '_extra', '_limit', '_order', '_joins', '_from', '_group'
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, tableid, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.tableid: sql.Identifier = tableid
|
||||||
|
|
||||||
|
def options(self, **kwargs):
|
||||||
|
"""
|
||||||
|
Set some query options.
|
||||||
|
Default implementation does nothing.
|
||||||
|
Should be overridden to provide specific options.
|
||||||
|
"""
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class WhereMixin(TableQuery[QueryResult]):
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.condition: Optional[Condition] = None
|
||||||
|
|
||||||
|
def where(self, *args: Condition, **kwargs):
|
||||||
|
"""
|
||||||
|
Add a Condition to the query.
|
||||||
|
Position arguments should be Conditions,
|
||||||
|
and keyword arguments should be of the form `column=Value`,
|
||||||
|
where Value may be a Value-type or a literal value.
|
||||||
|
All provided Conditions will be and-ed together to create a new Condition.
|
||||||
|
TODO: Maybe just pass this verbatim to a condition.
|
||||||
|
"""
|
||||||
|
if args or kwargs:
|
||||||
|
condition = Condition.construct(*args, **kwargs)
|
||||||
|
if self.condition is not None:
|
||||||
|
condition = self.condition & condition
|
||||||
|
|
||||||
|
self.condition = condition
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _where_section(self) -> Optional[Expression]:
|
||||||
|
if self.condition is not None:
|
||||||
|
return RawExpr.join_tuples((sql.SQL('WHERE'), ()), self.condition.as_tuple())
|
||||||
|
else:
|
||||||
|
return None


class JOINTYPE(Enum):
    LEFT = sql.SQL('LEFT JOIN')
    RIGHT = sql.SQL('RIGHT JOIN')
    INNER = sql.SQL('INNER JOIN')
    OUTER = sql.SQL('OUTER JOIN')
    FULLOUTER = sql.SQL('FULL OUTER JOIN')


class JoinMixin(TableQuery[QueryResult]):
    __slots__ = ()
    # TODO: Remember to add join slots to TableQuery

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._joins: list[Expression] = []

    def join(self,
             target: Union[str, Expression],
             on: Optional[Condition] = None, using: Optional[Union[Expression, tuple[str, ...]]] = None,
             join_type: JOINTYPE = JOINTYPE.INNER,
             natural=False):
        available = (on is not None) + (using is not None) + natural
        if available == 0:
            raise ValueError("No conditions given for Query Join")
        if available > 1:
            raise ValueError("Exactly one join format must be given for Query Join")

        sections: list[tuple[sql.Composable, tuple[Any, ...]]] = [(join_type.value, ())]
        if isinstance(target, str):
            sections.append((sql.Identifier(target), ()))
        else:
            sections.append(target.as_tuple())

        if on is not None:
            sections.append((sql.SQL('ON'), ()))
            sections.append(on.as_tuple())
        elif using is not None:
            sections.append((sql.SQL('USING'), ()))
            if isinstance(using, Expression):
                sections.append(using.as_tuple())
            elif isinstance(using, tuple) and len(using) > 0 and isinstance(using[0], str):
                cols = sql.SQL("({})").format(sql.SQL(',').join(sql.Identifier(col) for col in using))
                sections.append((cols, ()))
            else:
                raise ValueError("Unrecognised 'using' type.")
        elif natural:
            sections.insert(0, (sql.SQL('NATURAL'), ()))

        expr = RawExpr.join_tuples(*sections)
        self._joins.append(expr)
        return self

    def leftjoin(self, *args, **kwargs):
        return self.join(*args, join_type=JOINTYPE.LEFT, **kwargs)

    @property
    def _join_section(self) -> Optional[Expression]:
        if self._joins:
            return RawExpr.join(*self._joins)
        else:
            return None
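
A sketch of intended use (hypothetical table names): exactly one of `on`, `using`, or `natural` must be supplied to `join()`:

    query = Select(sql.Identifier('koans')).join('communities', using=('communityid',))
    query = query.leftjoin('user_profiles', natural=True)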


class ExtraMixin(TableQuery[QueryResult]):
    __slots__ = ()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._extra: Optional[Expression] = None

    def extra(self, extra: sql.Composable, values: tuple[Any, ...] = ()):
        """
        Add an extra string, and optionally values, to this query.
        The extra string is inserted after any condition, and before the limit.
        """
        extra_expr = RawExpr(extra, values)
        if self._extra is not None:
            extra_expr = RawExpr.join(self._extra, extra_expr)
        self._extra = extra_expr
        return self

    @property
    def _extra_section(self) -> Optional[Expression]:
        if self._extra is None:
            return None
        else:
            return self._extra


class LimitMixin(TableQuery[QueryResult]):
    __slots__ = ()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._limit: Optional[int] = None

    def limit(self, limit: int):
        """
        Add a limit to this query.
        """
        self._limit = limit
        return self

    @property
    def _limit_section(self) -> Optional[Expression]:
        if self._limit is not None:
            return RawExpr(sql.SQL("LIMIT {}").format(sql.Placeholder()), (self._limit,))
        else:
            return None


class FromMixin(TableQuery[QueryResult]):
    __slots__ = ()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._from: Optional[Expression] = None

    def from_expr(self, _from: Expression):
        self._from = _from
        return self

    @property
    def _from_section(self) -> Optional[Expression]:
        if self._from is not None:
            expr, values = self._from.as_tuple()
            return RawExpr(sql.SQL("FROM {}").format(expr), values)
        else:
            return None


class ORDER(Enum):
    ASC = sql.SQL('ASC')
    DESC = sql.SQL('DESC')


class NULLS(Enum):
    FIRST = sql.SQL('NULLS FIRST')
    LAST = sql.SQL('NULLS LAST')


class OrderMixin(TableQuery[QueryResult]):
    __slots__ = ()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._order: list[Expression] = []

    def order_by(self, expr: Union[Expression, str], direction: Optional[ORDER] = None, nulls: Optional[NULLS] = None):
        """
        Add a single sort expression to the query.
        This method stacks.
        """
        if isinstance(expr, Expression):
            string, values = expr.as_tuple()
        else:
            string = sql.Identifier(expr)
            values = ()

        parts = [string]
        if direction is not None:
            parts.append(direction.value)
        if nulls is not None:
            parts.append(nulls.value)

        order_string = sql.SQL(' ').join(parts)
        self._order.append(RawExpr(order_string, values))
        return self

    @property
    def _order_section(self) -> Optional[Expression]:
        if self._order:
            expr = RawExpr.join(*self._order, joiner=sql.SQL(', '))
            expr.expr = sql.SQL("ORDER BY {}").format(expr.expr)
            return expr
        else:
            return None


class GroupMixin(TableQuery[QueryResult]):
    __slots__ = ()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._group: list[Expression] = []

    def group_by(self, *exprs: Union[Expression, str]):
        """
        Add one or more group expressions to the query.
        This method stacks.
        """
        for expr in exprs:
            if isinstance(expr, Expression):
                self._group.append(expr)
            else:
                self._group.append(RawExpr(sql.Identifier(expr)))
        return self

    @property
    def _group_section(self) -> Optional[Expression]:
        if self._group:
            expr = RawExpr.join(*self._group, joiner=sql.SQL(', '))
            expr.expr = sql.SQL("GROUP BY {}").format(expr.expr)
            return expr
        else:
            return None


class Insert(ExtraMixin, TableQuery[QueryResult]):
    """
    Query type representing a table insert query.
    """
    # TODO: Support ON CONFLICT for upserts
    __slots__ = ('_columns', '_values', '_conflict')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._columns: tuple[str, ...] = ()
        self._values: tuple[tuple[Any, ...], ...] = ()
        self._conflict: Optional[Expression] = None

    def insert(self, columns, *values):
        """
        Insert the given data.

        Parameters
        ----------
        columns: tuple[str]
            Tuple of column names to insert.

        values: tuple[tuple[Any, ...], ...]
            Tuple of values to insert, corresponding to the columns.
        """
        if not values:
            raise ValueError("Cannot insert zero rows.")
        if len(values[0]) != len(columns):
            raise ValueError("Number of columns does not match length of values.")

        self._columns = columns
        self._values = values
        return self

    def on_conflict(self, ignore=False):
        # TODO lots more we can do here
        # Maybe return a Conflict object that can chain itself (not the query)
        if ignore:
            self._conflict = RawExpr(sql.SQL('DO NOTHING'))
        return self

    @property
    def _conflict_section(self) -> Optional[Expression]:
        if self._conflict is not None:
            e, v = self._conflict.as_tuple()
            expr = RawExpr(
                sql.SQL("ON CONFLICT {}").format(
                    e
                ),
                v
            )
            return expr
        return None

    def build(self):
        columns = sql.SQL(',').join(map(sql.Identifier, self._columns))
        single_value_str = sql.SQL('({})').format(
            sql.SQL(',').join(sql.Placeholder() * len(self._columns))
        )
        values_str = sql.SQL(',').join(single_value_str * len(self._values))

        # TODO: Check efficiency of inserting multiple values like this
        # Also implement a Copy query
        base = sql.SQL("INSERT INTO {table} ({columns}) VALUES {values_str}").format(
            table=self.tableid,
            columns=columns,
            values_str=values_str
        )

        sections = [
            RawExpr(base, tuple(chain(*self._values))),
            self._conflict_section,
            self._extra_section,
            RawExpr(sql.SQL('RETURNING *'))
        ]

        sections = (section for section in sections if section is not None)
        return RawExpr.join(*sections)
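
Illustrative use of Insert (hypothetical identifiers): `insert()` takes the column names once and then one tuple per row, which is why the guards above compare the first row against the column count:

    query = Insert(sql.Identifier('koans')).insert(
        ('communityid', 'name', 'message'),
        (1, 'opening', 'hello'),
        (1, 'closing', 'goodbye'),
    ).on_conflict(ignore=True)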


class Select(WhereMixin, ExtraMixin, OrderMixin, LimitMixin, JoinMixin, GroupMixin, TableQuery[QueryResult]):
    """
    Select rows from a table matching provided conditions.
    """
    __slots__ = ('_columns',)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._columns: tuple[Expression, ...] = ()

    def select(self, *columns: str, **exprs: Union[str, sql.Composable, Expression]):
        """
        Set the columns and expressions to select.
        If none are given, selects all columns.
        """
        cols: List[Expression] = []
        if columns:
            cols.extend(map(RawExpr, map(sql.Identifier, columns)))
        if exprs:
            for name, expr in exprs.items():
                if isinstance(expr, str):
                    cols.append(
                        RawExpr(sql.SQL(expr) + sql.SQL(' AS ') + sql.Identifier(name))
                    )
                elif isinstance(expr, sql.Composable):
                    cols.append(
                        RawExpr(expr + sql.SQL(' AS ') + sql.Identifier(name))
                    )
                elif isinstance(expr, Expression):
                    value_expr, value_values = expr.as_tuple()
                    cols.append(RawExpr(
                        value_expr + sql.SQL(' AS ') + sql.Identifier(name),
                        value_values
                    ))
        if cols:
            self._columns = (*self._columns, *cols)
        return self

    def build(self):
        if not self._columns:
            columns, columns_values = sql.SQL('*'), ()
        else:
            columns, columns_values = RawExpr.join(*self._columns, joiner=sql.SQL(',')).as_tuple()

        base = sql.SQL("SELECT {columns} FROM {table}").format(
            columns=columns,
            table=self.tableid
        )

        sections = [
            RawExpr(base, columns_values),
            self._join_section,
            self._where_section,
            self._group_section,
            self._extra_section,
            self._order_section,
            self._limit_section,
        ]

        sections = (section for section in sections if section is not None)
        return RawExpr.join(*sections)
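
A hedged sketch of the chaining style these mixins enable (identifiers are hypothetical); `build()` assembles the sections in the order listed above:

    query = (
        Select(sql.Identifier('user_profiles'))
        .select('profileid', 'name')
        .where(twitchid='1234')
        .order_by('created_at', direction=ORDER.DESC)
        .limit(10)
    )
    expr, values = query.build().as_tuple()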


class Delete(WhereMixin, ExtraMixin, TableQuery[QueryResult]):
    """
    Query type representing a table delete query.
    """
    # TODO: Cascade option for delete, maybe other options
    # TODO: Require a where unless specifically disabled, for safety

    def build(self):
        base = sql.SQL("DELETE FROM {table}").format(
            table=self.tableid,
        )
        sections = [
            RawExpr(base),
            self._where_section,
            self._extra_section,
            RawExpr(sql.SQL('RETURNING *'))
        ]

        sections = (section for section in sections if section is not None)
        return RawExpr.join(*sections)


class Update(LimitMixin, WhereMixin, ExtraMixin, FromMixin, TableQuery[QueryResult]):
    __slots__ = (
        '_set',
    )
    # TODO: Again, require a where unless specifically disabled

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._set: List[Expression] = []

    def set(self, **column_values: Union[Any, Expression]):
        exprs: List[Expression] = []
        for name, value in column_values.items():
            if isinstance(value, Expression):
                value_tup = value.as_tuple()
            else:
                value_tup = (sql.Placeholder(), (value,))

            exprs.append(
                RawExpr.join_tuples(
                    (sql.Identifier(name), ()),
                    value_tup,
                    joiner=sql.SQL(' = ')
                )
            )
        self._set.extend(exprs)
        return self

    def build(self):
        if not self._set:
            raise ValueError("No columns provided to update.")
        set_expr, set_values = RawExpr.join(*self._set, joiner=sql.SQL(', ')).as_tuple()

        base = sql.SQL("UPDATE {table} SET {set}").format(
            table=self.tableid,
            set=set_expr
        )
        sections = [
            RawExpr(base, set_values),
            self._from_section,
            self._where_section,
            self._extra_section,
            self._limit_section,
            RawExpr(sql.SQL('RETURNING *'))
        ]

        sections = (section for section in sections if section is not None)
        return RawExpr.join(*sections)
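
For illustration (hypothetical identifiers): `set()` accepts literal values or Expressions, and `build()` refuses to run without at least one SET column:

    query = Update(sql.Identifier('bot_channels')).set(autojoin=False).where(userid='1234')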


# async def upsert(cursor, table, constraint, **values):
#     """
#     Insert or on conflict update.
#     """
#     valuedict = values
#     keys, values = zip(*values.items())
#
#     key_str = _format_insertkeys(keys)
#     value_str, values = _format_insertvalues(values)
#     update_key_str, update_key_values = _format_updatestr(valuedict)
#
#     if not isinstance(constraint, str):
#         constraint = ", ".join(constraint)
#
#     await cursor.execute(
#         'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format(
#             table, key_str, value_str, constraint, update_key_str
#         ),
#         tuple((*values, *update_key_values))
#     )
#     return await cursor.fetchone()


# def update_many(table, *values, set_keys=None, where_keys=None, cast_row=None, cursor=None):
#     cursor = cursor or conn.cursor()
#
#     # TODO: executemany or copy syntax now
#     return execute_values(
#         cursor,
#         """
#         UPDATE {table}
#         SET {set_clause}
#         FROM (VALUES {cast_row}%s)
#         AS {temp_table}
#         WHERE {where_clause}
#         RETURNING *
#         """.format(
#             table=table,
#             set_clause=', '.join("{0} = _t.{0}".format(key) for key in set_keys),
#             cast_row=cast_row + ',' if cast_row else '',
#             where_clause=' AND '.join("{1}.{0} = _t.{0}".format(key, table) for key in where_keys),
#             temp_table="_t ({})".format(', '.join(set_keys + where_keys))
#         ),
#         values,
#         fetch=True
#     )

102  src/data/registry.py  Normal file
@@ -0,0 +1,102 @@
from typing import Protocol, runtime_checkable, Optional

from psycopg import AsyncConnection

from .connector import Connector, Connectable


@runtime_checkable
class _Attachable(Connectable, Protocol):
    def attach_to(self, registry: 'Registry'):
        raise NotImplementedError


class Registry:
    _attached: list[_Attachable] = []
    _name: Optional[str] = None

    def __init_subclass__(cls, name=None):
        attached = []
        for _, member in cls.__dict__.items():
            if isinstance(member, _Attachable):
                attached.append(member)
        cls._attached = attached
        cls._name = name or cls.__name__

    def __init__(self, name=None):
        self._conn: Optional[Connector] = None
        self.name: str = name if name is not None else self._name
        if self.name is None:
            raise ValueError("A Registry must have a name!")

        self.init_tasks = []

        for member in self._attached:
            member.attach_to(self)

    def bind(self, connector: Connector):
        self._conn = connector
        for child in self._attached:
            child.bind(connector)

    def attach(self, attachable):
        self._attached.append(attachable)
        if self._conn is not None:
            attachable.bind(self._conn)
        return attachable

    def init_task(self, coro):
        """
        Initialisation tasks are run to setup the registry state.
        These tasks will be run in the event loop, after connection to the database.
        These tasks should be idempotent, as they may be run on reload and reconnect.
        """
        self.init_tasks.append(coro)
        return coro

    async def init(self):
        for task in self.init_tasks:
            await task(self)
        return self
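
A minimal sketch of the intended subclassing pattern (the names are hypothetical, and the attachable here assumes something like the Table class defined in src/data/table.py): attachables declared on the class body are collected by __init_subclass__, and init() runs the registered tasks once a connector has been bound:

    class MyData(Registry, name='mydata'):
        events = Table('events')

    mydata = MyData()

    @mydata.init_task
    async def preload(registry):
        ...   # idempotent setup, run after the database connection is made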


class AttachableClass:
    """ABC for a default implementation of an Attachable class."""

    _connector: Optional[Connector] = None
    _registry: Optional[Registry] = None

    @classmethod
    def bind(cls, connector: Connector):
        cls._connector = connector
        connector.connect_hook(cls.on_connect)
        return cls

    @classmethod
    def attach_to(cls, registry: Registry):
        cls._registry = registry
        return cls

    @classmethod
    async def on_connect(cls, connection: AsyncConnection):
        pass


class Attachable:
    """ABC for a default implementation of an Attachable object."""

    def __init__(self, *args, **kwargs):
        self._connector: Optional[Connector] = None
        self._registry: Optional[Registry] = None

    def bind(self, connector: Connector):
        self._connector = connector
        connector.connect_hook(self.on_connect)
        return self

    def attach_to(self, registry: Registry):
        self._registry = registry
        return self

    async def on_connect(self, connection: AsyncConnection):
        pass

95  src/data/table.py  Normal file
@@ -0,0 +1,95 @@
from typing import Optional
from psycopg.rows import DictRow
from psycopg import sql

from . import queries as q
from .connector import Connector
from .registry import Registry


class Table:
    """
    Transparent interface to a single table structure in the database.
    Contains standard methods to access the table.
    """

    def __init__(self, name, *args, schema='public', **kwargs):
        self.name: str = name
        self.schema: str = schema
        self.connector: Connector = None

    @property
    def identifier(self):
        if self.schema == 'public':
            return sql.Identifier(self.name)
        else:
            return sql.Identifier(self.schema, self.name)

    def bind(self, connector: Connector):
        self.connector = connector
        return self

    def attach_to(self, registry: Registry):
        self._registry = registry
        return self

    def _many_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
        return data

    def _single_query_adapter(self, *data: DictRow) -> Optional[DictRow]:
        if data:
            return data[0]
        else:
            return None

    def _delete_query_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
        return data

    def select_where(self, *args, **kwargs) -> q.Select[tuple[DictRow, ...]]:
        return q.Select(
            self.identifier,
            row_adapter=self._many_query_adapter,
            connector=self.connector
        ).where(*args, **kwargs)

    def select_one_where(self, *args, **kwargs) -> q.Select[DictRow]:
        return q.Select(
            self.identifier,
            row_adapter=self._single_query_adapter,
            connector=self.connector
        ).where(*args, **kwargs)

    def update_where(self, *args, **kwargs) -> q.Update[tuple[DictRow, ...]]:
        return q.Update(
            self.identifier,
            row_adapter=self._many_query_adapter,
            connector=self.connector
        ).where(*args, **kwargs)

    def delete_where(self, *args, **kwargs) -> q.Delete[tuple[DictRow, ...]]:
        return q.Delete(
            self.identifier,
            row_adapter=self._many_query_adapter,
            connector=self.connector
        ).where(*args, **kwargs)

    def insert(self, **column_values) -> q.Insert[DictRow]:
        return q.Insert(
            self.identifier,
            row_adapter=self._single_query_adapter,
            connector=self.connector
        ).insert(column_values.keys(), column_values.values())

    def insert_many(self, *args, **kwargs) -> q.Insert[tuple[DictRow, ...]]:
        return q.Insert(
            self.identifier,
            row_adapter=self._many_query_adapter,
            connector=self.connector
        ).insert(*args, **kwargs)

    # def update_many(self, *args, **kwargs):
    #     with self.conn:
    #         return update_many(self.identifier, *args, **kwargs)

    # def upsert(self, *args, **kwargs):
    #     return upsert(self.identifier, *args, **kwargs)
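
Illustrative usage of Table (hypothetical names), building on the query classes above; each helper returns an awaitable query object, and `connector` is assumed to be an already-configured Connector:

    scopes = Table('user_auth_scopes').bind(connector)
    rows = await scopes.select_where(userid='1234')
    await scopes.insert(userid='1234', scope='channel:bot')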

133  src/datamodels.py  Normal file
@@ -0,0 +1,133 @@
from data import Registry, RowModel, Table
from data.columns import String, Timestamp, Integer, Bool


class UserAuth(RowModel):
    """
    Schema
    ======
    CREATE TABLE user_auth(
        userid TEXT PRIMARY KEY,
        token TEXT NOT NULL,
        refresh_token TEXT NOT NULL,
        created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
        _timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
    );
    """
    _tablename_ = 'user_auth'
    _cache_ = {}

    userid = String(primary=True)
    token = String()
    refresh_token = String()
    created_at = Timestamp()
    _timestamp = Timestamp()


class BotChannel(RowModel):
    """
    Schema
    ======
    CREATE TABLE bot_channels(
        userid TEXT PRIMARY KEY REFERENCES user_auth(userid) ON DELETE CASCADE,
        autojoin BOOLEAN DEFAULT true,
        listen_redeems BOOLEAN,
        joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
        _timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
    );
    """
    _tablename_ = 'bot_channels'
    _cache_ = {}

    userid = String(primary=True)
    autojoin = Bool()
    listen_redeems = Bool()
    joined_at = Timestamp()
    _timestamp = Timestamp()


class UserProfile(RowModel):
    """
    Schema
    ======
    CREATE TABLE user_profiles(
        profileid INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
        twitchid TEXT NOT NULL,
        name TEXT,
        created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
        _timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
    );
    """
    _tablename_ = 'user_profiles'
    _cache_ = {}

    profileid = Integer(primary=True)
    twitchid = String()
    name = String()
    created_at = Timestamp()
    _timestamp = Timestamp()


class Communities(RowModel):
    """
    Schema
    ======
    CREATE TABLE communities(
        communityid INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
        twitchid TEXT NOT NULL,
        name TEXT,
        created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
        _timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
    );
    """
    _tablename_ = 'communities'
    _cache_ = {}

    communityid = Integer(primary=True)
    twitchid = Integer()
    name = String()
    created_at = Timestamp()
    _timestamp = Timestamp()


class Koan(RowModel):
    """
    Schema
    ======
    CREATE TABLE koans(
        koanid INTEGER GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
        communityid INTEGER NOT NULL REFERENCES communities ON UPDATE CASCADE ON DELETE CASCADE,
        name TEXT NOT NULL,
        message TEXT NOT NULL,
        created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
        _timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW()
    );
    """
    _tablename_ = 'koans'
    _cache_ = {}

    koanid = Integer(primary=True)
    communityid = Integer()
    name = String()
    message = String()
    created_at = Timestamp()
    _timestamp = Timestamp()


class BotData(Registry):
    user_auth = UserAuth.table

    """
    CREATE TABLE user_auth_scopes(
        userid TEXT NOT NULL REFERENCES user_auth(userid) ON DELETE CASCADE,
        scope TEXT NOT NULL
    );
    """
    user_auth_scopes = Table('user_auth_scopes')

    bot_channels = BotChannel.table
    user_profiles = UserProfile.table
    communities = Communities.table

    koans = Koan.table

4  src/meta/__init__.py  Normal file
@@ -0,0 +1,4 @@
from .args import args
from .crocbot import CrocBot
from .config import Conf, conf
from .logger import setup_main_logger, log_context, log_action_stack, log_app, set_logging_context, logging_context, with_log_ctx, persist_task

28  src/meta/args.py  Normal file
@@ -0,0 +1,28 @@
import argparse

from constants import CONFIG_FILE

# ------------------------------
# Parsed commandline arguments
# ------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
    '--conf',
    dest='config',
    default=CONFIG_FILE,
    help="Path to configuration file."
)
parser.add_argument(
    '--host',
    dest='host',
    default='127.0.0.1',
    help="IP address to run the websocket server on."
)
parser.add_argument(
    '--port',
    dest='port',
    default='5001',
    help="Port to run the websocket server on."
)

args = parser.parse_args()

105  src/meta/config.py  Normal file
@@ -0,0 +1,105 @@
import configparser as cfgp

from .args import args


class MapDotProxy:
    """
    Allows dot access to an underlying Mappable object.
    """
    __slots__ = ("_map", "_converter")

    def __init__(self, mappable, converter=None):
        self._map = mappable
        self._converter = converter

    def __getattribute__(self, key):
        _map = object.__getattribute__(self, '_map')
        if key == '_map':
            return _map
        if key in _map:
            _converter = object.__getattribute__(self, '_converter')
            if _converter:
                return _converter(_map[key])
            else:
                return _map[key]
        else:
            return object.__getattribute__(_map, key)

    def __getitem__(self, key):
        return self._map.__getitem__(key)


class ConfigParser(cfgp.ConfigParser):
    """
    Extension of base ConfigParser allowing optional
    section option retrieval without defaults.
    """
    def options(self, section, no_defaults=False, **kwargs):
        if no_defaults:
            try:
                return list(self._sections[section].keys())
            except KeyError:
                raise cfgp.NoSectionError(section)
        else:
            return super().options(section, **kwargs)


class Conf:
    def __init__(self, configfile, section_name="DEFAULT"):
        self.configfile = configfile

        self.config = ConfigParser(
            converters={
                "intlist": self._getintlist,
                "list": self._getlist,
            }
        )

        with open(configfile) as conff:
            # Opening with read_file mainly to ensure the file exists
            self.config.read_file(conff)

        self.section_name = section_name if section_name in self.config else 'DEFAULT'

        self.default = self.config["DEFAULT"]
        self.section = MapDotProxy(self.config[self.section_name])
        self.bot = self.section

        # Config file recursion, read in configuration files specified in every "ALSO_READ" key.
        more_to_read = self.section.getlist("ALSO_READ", [])
        read = set()
        while more_to_read:
            to_read = more_to_read.pop(0)
            read.add(to_read)
            self.config.read(to_read)
            new_paths = [path for path in self.section.getlist("ALSO_READ", [])
                         if path not in read and path not in more_to_read]
            more_to_read.extend(new_paths)

        global conf
        conf = self

    def __getitem__(self, key):
        return self.section[key].strip()

    def __getattr__(self, section):
        name = section.upper()
        return self.config[name]

    def get(self, name, fallback=None):
        result = self.section.get(name, fallback)
        return result.strip() if result else result

    def _getintlist(self, value):
        return [int(item.strip()) for item in value.split(',')]

    def _getlist(self, value):
        return [item.strip() for item in value.split(',')]

    def write(self):
        with open(self.configfile, 'w') as conffile:
            self.config.write(conffile)


conf = Conf(args.config, 'CROCBOT')
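
Sketch of how the proxy access is meant to read (the keys are hypothetical and depend on the loaded config file): `conf.bot` is a MapDotProxy over the selected section, so mapping access, attribute access, and the registered converters all work:

    prefix = conf.bot['prefix']
    prefix = conf.bot.prefix              # equivalent, via MapDotProxy
    extra = conf.bot.getlist('ALSO_READ', [])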

157  src/meta/crocbot.py  Normal file
@@ -0,0 +1,157 @@
import logging
from typing import Optional

from twitchio.authentication import UserTokenPayload
from twitchio.ext import commands
from twitchio import Scopes, eventsub

from data import Database
from datamodels import BotData, UserAuth, BotChannel

from .config import Conf


logger = logging.getLogger(__name__)


class CrocBot(commands.Bot):
    def __init__(self, *args, config: Conf, dbconn: Database, **kwargs):
        kwargs.setdefault('client_id', config.bot['client_id'])
        kwargs.setdefault('client_secret', config.bot['client_secret'])
        kwargs.setdefault('bot_id', config.bot['bot_id'])
        kwargs.setdefault('prefix', config.bot['prefix'])

        super().__init__(*args, **kwargs)

        self.config = config
        self.dbconn = dbconn
        self.data: BotData = dbconn.load_registry(BotData())

        self.joined: dict[str, BotChannel] = {}

    async def event_ready(self):
        # logger.info(f"Logged in as {self.nick}. User id is {self.user_id}")
        logger.info("Logged in as %s", self.bot_id)

    async def setup_hook(self):
        await self.data.init()

        # Get all current bot channels
        channels = await BotChannel.fetch_where(autojoin=True)

        # Join the channels
        await self.join_channels(*channels)

        # Build bot account's own url
        scopes = Scopes((
            Scopes.user_read_chat,
            Scopes.user_write_chat,
            Scopes.user_bot,
            Scopes.channel_read_redemptions,
            Scopes.channel_manage_redemptions,
            Scopes.channel_bot,
        ))
        url = self.get_auth_url(scopes)
        logger.info("Bot account authorisation url: %s", url)

        # Build everyone else's url
        scopes = Scopes((
            Scopes.channel_bot,
            Scopes.channel_read_redemptions,
            Scopes.channel_manage_redemptions,
        ))
        url = self.get_auth_url(scopes)
        logger.info("User account authorisation url: %s", url)

        logger.info("Finished setup")

    def get_auth_url(self, scopes: Optional[Scopes] = None):
        if scopes is None:
            scopes = Scopes((Scopes.channel_bot,))

        url = self._adapter.get_authorization_url(scopes=scopes)
        return url

    async def join_channels(self, *channels: BotChannel):
        """
        Register webhook subscriptions to the given channel(s).
        """
        # TODO: If channels are already joined, unsubscribe
        # TODO: Determine (or switch?) whether to use webhook or websocket
        for channel in channels:
            sub = None
            try:
                sub = eventsub.ChatMessageSubscription(
                    broadcaster_user_id=channel.userid,
                    user_id=self.bot_id,
                )
                resp = await self.subscribe_websocket(sub)
                logger.info("Subscribed to %s with %s response %s", channel.userid, sub, resp)
                if channel.listen_redeems:
                    sub = eventsub.ChannelPointsRedeemAddSubscription(
                        broadcaster_user_id=channel.userid
                    )
                    resp = await self.subscribe_websocket(sub, as_bot=False, token_for=channel.userid)
                    logger.info("Subscribed to %s with %s response %s", channel.userid, sub, resp)
                self.joined[channel.userid] = channel
            except Exception:
                logger.exception("Failed to subscribe to %s with %s", channel.userid, sub)

    async def event_oauth_authorized(self, payload: UserTokenPayload):
        logger.debug("Oauth flow authorization with payload %s", repr(payload))
        # Save the token and scopes and update internal authorisations
        resp = await self.add_token(payload.access_token, payload.refresh_token)
        if resp.user_id is None:
            logger.warning(
                "Oauth flow received with no user_id. Payload was: %s",
                repr(payload)
            )
            return

        # If the scopes authorised included channel:bot, ensure a BotChannel exists
        # And join it if needed
        if Scopes.channel_bot.value in resp.scopes:
            bot_channel = await BotChannel.fetch_or_create(
                resp.user_id,
                autojoin=True,
            )
            if bot_channel.autojoin:
                await self.join_channels(bot_channel)

        logger.info("Oauth flow authorization complete for payload %s", repr(payload))

    async def add_token(self, token: str, refresh: str):
        # Update the tokens in internal cache
        # This also validates the token
        # And hopefully gets the userid and scopes
        resp = await super().add_token(token, refresh)
        if resp.user_id is None:
            logger.warning(
                "Added a token with no user_id. Response was: %s",
                repr(resp)
            )
            return resp
        userid = resp.user_id
        new_scopes = resp.scopes

        # Save the token and scopes to data
        # Wrap this in a transaction so if it fails halfway we rollback correctly
        async with self.dbconn.connection() as conn:
            self.dbconn.conn = conn
            async with conn.transaction():
                row = await UserAuth.fetch_or_create(userid, token=token, refresh_token=refresh)
                if row.token != token or row.refresh_token != refresh:
                    await row.update(token=token, refresh_token=refresh)
                await self.data.user_auth_scopes.delete_where(userid=userid)
                await self.data.user_auth_scopes.insert_many(
                    ('userid', 'scope'),
                    *((userid, scope) for scope in new_scopes)
                )

        logger.info("Updated auth token for user '%s' with scopes: %s", resp.user_id, ', '.join(new_scopes))
        return resp

    async def load_tokens(self, path: str | None = None):
        for row in await UserAuth.fetch_where():
            await self.add_token(row.token, row.refresh_token)

466  src/meta/logger.py  Normal file
@@ -0,0 +1,466 @@
import sys
import logging
import asyncio
from typing import List, Optional
from logging.handlers import QueueListener, QueueHandler
import queue
import multiprocessing
from contextlib import contextmanager
from io import StringIO
from functools import wraps
from contextvars import ContextVar

import aiohttp

from .config import conf
from utils.ratelimits import Bucket, BucketOverFull, BucketFull


log_logger = logging.getLogger(__name__)
log_logger.propagate = False


log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT')
log_action_stack: ContextVar[tuple[str, ...]] = ContextVar('logging_action_stack', default=())
log_app: ContextVar[str] = ContextVar('logging_shard', default="CROCBOT")

# TODO merge into TwitchIO context
context: ContextVar[str] = ContextVar('context', default=None)


def set_logging_context(
    context: Optional[str] = None,
    action: Optional[str] = None,
    stack: Optional[tuple[str, ...]] = None
):
    """
    Statically set the logging context variables to the given values.

    If `action` is given, pushes it onto the `log_action_stack`.
    """
    if context is not None:
        log_context.set(context)
    if action is not None or stack is not None:
        astack = log_action_stack.get()
        newstack = stack if stack is not None else astack
        if action is not None:
            newstack = (*newstack, action)
        log_action_stack.set(newstack)


@contextmanager
def logging_context(context=None, action=None, stack=None):
    """
    Context manager for executing a block of code in a given logging context.

    This context manager should only be used around synchronous code.
    This is because async code *may* get cancelled or externally garbage collected,
    in which case the finally block will be executed in the wrong context.
    See https://github.com/python/cpython/issues/93740
    This can be refactored nicely if this gets merged:
    https://github.com/python/cpython/pull/99634

    (It will not necessarily break on async code,
    if the async code can be guaranteed to clean up in its own context.)
    """
    if context is not None:
        oldcontext = log_context.get()
        log_context.set(context)
    if action is not None or stack is not None:
        astack = log_action_stack.get()
        newstack = stack if stack is not None else astack
        if action is not None:
            newstack = (*newstack, action)
        log_action_stack.set(newstack)
    try:
        yield
    finally:
        if context is not None:
            log_context.set(oldcontext)
        if stack is not None or action is not None:
            log_action_stack.set(astack)
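
A short usage sketch (the context and action strings are arbitrary examples): the context manager restores the previous values on exit, and the ContextInjection filter defined below renders them into each log record:

    with logging_context(context="channel: 1234", action="Join"):
        logger.info("Joining channel")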


def with_log_ctx(isolate=True, **kwargs):
    """
    Execute a coroutine inside a given logging context.

    If `isolate` is true, ensures that context does not leak
    outside the coroutine.

    If `isolate` is false, just statically set the context,
    which will leak unless the coroutine is
    called in an externally copied context.
    """
    def decorator(func):
        @wraps(func)
        async def wrapped(*w_args, **w_kwargs):
            if isolate:
                with logging_context(**kwargs):
                    # Task creation will synchronously copy the context
                    # This is gc safe
                    name = kwargs.get('action', f"log-wrapped-{func.__name__}")
                    task = asyncio.create_task(func(*w_args, **w_kwargs), name=name)
                    return await task
            else:
                # This will leak context changes
                set_logging_context(**kwargs)
                return await func(*w_args, **w_kwargs)
        return wrapped
    return decorator


# For backwards compatibility
log_wrap = with_log_ctx


def persist_task(task_collection: set):
    """
    Coroutine decorator that ensures the coroutine is scheduled as a task
    and added to the given task_collection for strong reference
    when it is called.

    This is just a hack to handle discord.py events potentially
    being unexpectedly garbage collected.

    Since this also implicitly schedules the coroutine as a task when it is called,
    the coroutine will also be run inside an isolated context.
    """
    def decorator(coro):
        @wraps(coro)
        async def wrapped(*w_args, **w_kwargs):
            name = f"persisted-{coro.__name__}"
            task = asyncio.create_task(coro(*w_args, **w_kwargs), name=name)
            task_collection.add(task)
            task.add_done_callback(lambda f: task_collection.discard(f))
            await task


RESET_SEQ = "\033[0m"
COLOR_SEQ = "\033[3%dm"
BOLD_SEQ = "\033[1m"
"]]]"
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)


def colour_escape(fmt: str) -> str:
    cmap = {
        '%(black)': COLOR_SEQ % BLACK,
        '%(red)': COLOR_SEQ % RED,
        '%(green)': COLOR_SEQ % GREEN,
        '%(yellow)': COLOR_SEQ % YELLOW,
        '%(blue)': COLOR_SEQ % BLUE,
        '%(magenta)': COLOR_SEQ % MAGENTA,
        '%(cyan)': COLOR_SEQ % CYAN,
        '%(white)': COLOR_SEQ % WHITE,
        '%(reset)': RESET_SEQ,
        '%(bold)': BOLD_SEQ,
    }
    for key, value in cmap.items():
        fmt = fmt.replace(key, value)
    return fmt


log_format = ('%(green)%(asctime)-19s%(reset)|%(red)%(levelname)-8s%(reset)|' +
              '%(cyan)%(app)-15s%(reset)|' +
              '%(cyan)%(context)-24s%(reset)|' +
              '%(cyan)%(actionstr)-22s%(reset)|' +
              ' %(bold)%(cyan)%(name)s:%(reset)' +
              ' %(white)%(message)s%(ctxstr)s%(reset)')
log_format = colour_escape(log_format)


# Setup the logger
logger = logging.getLogger()
log_fmt = logging.Formatter(
    fmt=log_format,
    # datefmt='%Y-%m-%d %H:%M:%S'
)
logger.setLevel(logging.NOTSET)


class LessThanFilter(logging.Filter):
    def __init__(self, exclusive_maximum, name=""):
        super(LessThanFilter, self).__init__(name)
        self.max_level = exclusive_maximum

    def filter(self, record):
        # non-zero return means we log this message
        return 1 if record.levelno < self.max_level else 0


class ExactLevelFilter(logging.Filter):
    def __init__(self, target_level, name=""):
        super().__init__(name)
        self.target_level = target_level

    def filter(self, record):
        return (record.levelno == self.target_level)


class ThreadFilter(logging.Filter):
    def __init__(self, thread_name):
        super().__init__("")
        self.thread = thread_name

    def filter(self, record):
        # non-zero return means we log this message
        return 1 if record.threadName == self.thread else 0


class ContextInjection(logging.Filter):
    def filter(self, record):
        # These guards are to allow override through _extra
        # And to ensure the injection is idempotent
        if not hasattr(record, 'context'):
            record.context = log_context.get()

        if not hasattr(record, 'actionstr'):
            action_stack = log_action_stack.get()
            if hasattr(record, 'action'):
                action_stack = (*action_stack, record.action)
            if action_stack:
                record.actionstr = ' ➔ '.join(action_stack)
            else:
                record.actionstr = "Unknown Action"

        if not hasattr(record, 'app'):
            record.app = log_app.get()

        if not hasattr(record, 'ctx'):
            if ctx := context.get():
                record.ctx = repr(ctx)
            else:
                record.ctx = None

        if getattr(record, 'with_ctx', False) and record.ctx:
            record.ctxstr = '\n' + record.ctx
        else:
            record.ctxstr = ""
        return True


logging_handler_out = logging.StreamHandler(sys.stdout)
logging_handler_out.setLevel(logging.DEBUG)
logging_handler_out.setFormatter(log_fmt)
logging_handler_out.addFilter(ContextInjection())
logger.addHandler(logging_handler_out)
log_logger.addHandler(logging_handler_out)

logging_handler_err = logging.StreamHandler(sys.stderr)
logging_handler_err.setLevel(logging.WARNING)
logging_handler_err.setFormatter(log_fmt)
logging_handler_err.addFilter(ContextInjection())
logger.addHandler(logging_handler_err)
log_logger.addHandler(logging_handler_err)


class LocalQueueHandler(QueueHandler):
    def _emit(self, record: logging.LogRecord) -> None:
        # Removed the call to self.prepare(), handle task cancellation
        try:
            self.enqueue(record)
        except asyncio.CancelledError:
            raise
        except Exception:
            self.handleError(record)


class WebHookHandler(logging.StreamHandler):
    def __init__(self, webhook_url, prefix="", batch=True, loop=None):
        super().__init__()
        self.webhook_url = webhook_url
        self.prefix = prefix
        self.batched = ""
        self.batch = batch
        self.loop = loop
        self.batch_delay = 10
        self.batch_task = None
        self.last_batched = None
        self.waiting = []

        self.bucket = Bucket(20, 40)
        self.ignored = 0

        self.session = None
        self.webhook = None

    def get_loop(self):
        if self.loop is None:
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)
        return self.loop

    def emit(self, record):
        self.format(record)
        self.get_loop().call_soon_threadsafe(self._post, record)

    def _post(self, record):
        if self.session is None:
            self.setup()
        asyncio.create_task(self.post(record))

    def setup(self):
        self.session = aiohttp.ClientSession()
        self.webhook = Webhook.from_url(self.webhook_url, session=self.session)

    async def post(self, record):
        if record.context == 'Webhook Logger':
            # Don't livelog livelog errors
            # Otherwise we recurse and Cloudflare hates us
            return
        log_context.set("Webhook Logger")
        log_action_stack.set(("Logging",))
        log_app.set(record.app)

        try:
            timestamp = utc_now().strftime("%d/%m/%Y, %H:%M:%S")
            header = f"[{record.asctime}][{record.levelname}][{record.app}][{record.actionstr}] <{record.context}>"
            context = f"\n# Context: {record.ctx}" if record.ctx else ""
            message = f"{header}\n{record.msg}{context}"

            if len(message) > 1900:
                as_file = True
            else:
                as_file = False
                message = "```md\n{}\n```".format(message)

            # Post the log message(s)
            if self.batch:
                if len(message) > 1500:
                    await self._send_batched_now()
                    await self._send(message, as_file=as_file)
                else:
                    self.batched += message
                    if len(self.batched) + len(message) > 1500:
                        await self._send_batched_now()
                    else:
                        asyncio.create_task(self._schedule_batched())
            else:
                await self._send(message, as_file=as_file)
        except Exception as ex:
            print(f"Unexpected error occurred while logging to webhook: {repr(ex)}", file=sys.stderr)

    async def _schedule_batched(self):
        if self.batch_task is not None and not (self.batch_task.done() or self.batch_task.cancelled()):
            # noop, don't reschedule if it is already scheduled
            return
        try:
            self.batch_task = asyncio.create_task(asyncio.sleep(self.batch_delay))
            await self.batch_task
            await self._send_batched()
        except asyncio.CancelledError:
            return
        except Exception as ex:
            print(f"Unexpected error occurred while scheduling batched webhook log: {repr(ex)}", file=sys.stderr)

    async def _send_batched_now(self):
        if self.batch_task is not None and not self.batch_task.done():
            self.batch_task.cancel()
        self.last_batched = None
        await self._send_batched()

    async def _send_batched(self):
        if self.batched:
            batched = self.batched
            self.batched = ""
            await self._send(batched)

    async def _send(self, message, as_file=False):
        try:
            self.bucket.request()
        except BucketOverFull:
            # Silently ignore
            self.ignored += 1
            return
        except BucketFull:
            logger.warning(
                "Can't keep up! "
                f"Ignoring records on live-logger {self.webhook.id}."
            )
            self.ignored += 1
            return
        else:
            if self.ignored > 0:
                logger.warning(
                    "Can't keep up! "
                    f"{self.ignored} live logging records on webhook {self.webhook.id} skipped, continuing."
                )
                self.ignored = 0

        try:
            if as_file or len(message) > 1900:
                with StringIO(message) as fp:
                    fp.seek(0)
                    await self.webhook.send(
                        f"{self.prefix}\n`{message.splitlines()[0]}`",
                        file=File(fp, filename="logs.md"),
                        username=log_app.get()
                    )
            else:
                await self.webhook.send(self.prefix + '\n' + message, username=log_app.get())
        except discord.HTTPException:
            logger.exception(
                "Live logger errored. Slowing down live logger."
            )
            self.bucket.fill()


handlers = []
if webhook := conf.logging['general_log']:
    handler = WebHookHandler(webhook, batch=True)
    handlers.append(handler)

if webhook := conf.logging['warning_log']:
    handler = WebHookHandler(webhook, prefix=conf.logging['warning_prefix'], batch=True)
    handler.addFilter(ExactLevelFilter(logging.WARNING))
    handler.setLevel(logging.WARNING)
    handlers.append(handler)

if webhook := conf.logging['error_log']:
    handler = WebHookHandler(webhook, prefix=conf.logging['error_prefix'], batch=True)
    handler.setLevel(logging.ERROR)
    handlers.append(handler)

if webhook := conf.logging['critical_log']:
    handler = WebHookHandler(webhook, prefix=conf.logging['critical_prefix'], batch=False)
    handler.setLevel(logging.CRITICAL)
    handlers.append(handler)


def make_queue_handler(queue):
    qhandler = QueueHandler(queue)
    qhandler.setLevel(logging.INFO)
    qhandler.addFilter(ContextInjection())
    return qhandler


def setup_main_logger(multiprocess=False):
    q = multiprocessing.Queue() if multiprocess else queue.SimpleQueue()
    if handlers:
        # First create a separate loop to run the handlers on
        import threading

        def run_loop(loop):
            asyncio.set_event_loop(loop)
            try:
                loop.run_forever()
            finally:
                loop.run_until_complete(loop.shutdown_asyncgens())
                loop.close()

        loop = asyncio.new_event_loop()
        loop_thread = threading.Thread(target=lambda: run_loop(loop))
        loop_thread.daemon = True
        loop_thread.start()

        for handler in handlers:
            handler.loop = loop

    qhandler = make_queue_handler(q)
    # qhandler.addFilter(ThreadFilter('MainThread'))
    logger.addHandler(qhandler)

    listener = QueueListener(
        q, *handlers, respect_handler_level=True
    )
    listener.start()
    return q

15  src/modules/__init__.py  Normal file
@@ -0,0 +1,15 @@
|
||||||
|
this_package = 'modules'
|
||||||
|
|
||||||
|
active = [
|
||||||
|
'.vstudio',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def prepare(bot):
|
||||||
|
for ext in active:
|
||||||
|
bot.load_module(this_package + ext)
|
||||||
|
|
||||||
|
async def setup(bot):
|
||||||
|
for ext in active:
|
||||||
|
await bot.load_module(this_package + ext)
|
||||||
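A hedged sketch of how an entry point might pull this package in; the bot object and its load_module coroutine are assumed from elsewhere in the framework and are not defined here.

# Illustrative only: `bot` comes from the surrounding framework.
import modules


async def launch(bot):
    # Loads every extension listed in modules.active, e.g. 'modules.vstudio'.
    await modules.setup(bot)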
68
src/sockets.py
Normal file
@@ -0,0 +1,68 @@
from abc import ABC
from collections import defaultdict
import json
from typing import Any
import logging

import websockets


logger = logging.getLogger(__name__)


class Channel(ABC):
    """
    A channel is a stateful connection handler for a group of connected websockets.
    """
    name = "Root Channel"

    def __init__(self, **kwargs):
        self.connections = set()

    @property
    def empty(self):
        return not self.connections

    async def on_connection(self, websocket: websockets.WebSocketServerProtocol, event: dict[str, Any]):
        logger.info(f"Channel '{self.name}' attached new connection {websocket=} {event=}")
        self.connections.add(websocket)

    async def del_connection(self, websocket: websockets.WebSocketServerProtocol):
        logger.info(f"Channel '{self.name}' dropped connection {websocket=}")
        self.connections.discard(websocket)

    async def handle_message(self, websocket: websockets.WebSocketServerProtocol, message):
        raise NotImplementedError

    async def send_event(self, event, websocket=None):
        message = json.dumps(event)
        if not websocket:
            for ws in self.connections:
                await ws.send(message)
        else:
            await websocket.send(message)


channels = {}


def register_channel(name, channel: Channel):
    channels[name] = channel


async def root_handler(websocket: websockets.WebSocketServerProtocol):
    message = await websocket.recv()
    event = json.loads(message)

    if event.get('type', None) != 'init':
        raise ValueError("Received Websocket connection with no init.")

    if (channel_name := event.get('channel', None)) not in channels:
        raise ValueError(f"Received Init for unhandled channel {channel_name=}")
    channel = channels[channel_name]

    try:
        await channel.on_connection(websocket, event)
        async for message in websocket:
            await channel.handle_message(websocket, message)
    finally:
        await channel.del_connection(websocket)
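A hedged sketch of how a concrete channel might plug into this handler. EchoChannel and the 'echo' name are hypothetical; only Channel, register_channel and root_handler come from the file above (imported here assuming src/ is on the path). Clients are expected to open with an init event such as {"type": "init", "channel": "echo"}.

# Illustrative only; EchoChannel is a made-up example channel.
import asyncio
import websockets

from sockets import Channel, register_channel, root_handler


class EchoChannel(Channel):
    name = "Echo Channel"

    async def handle_message(self, websocket, message):
        # Broadcast every received message back to all sockets attached to this channel.
        await self.send_event({'type': 'echo', 'data': message})


register_channel('echo', EchoChannel())


async def serve():
    # Serve the shared root_handler; it dispatches each connection to its registered channel.
    async with websockets.serve(root_handler, 'localhost', 8765):
        await asyncio.Future()  # run until cancelled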
88
src/utils/lib.py
Normal file
@@ -0,0 +1,88 @@
import re
import datetime as dt


def strfdelta(delta: dt.timedelta, sec=False, minutes=True, short=False) -> str:
    """
    Convert a datetime.timedelta object into an easily readable duration string.

    Parameters
    ----------
    delta: datetime.timedelta
        The timedelta object to convert into a readable string.
    sec: bool
        Whether to include the seconds from the timedelta object in the string.
    minutes: bool
        Whether to include the minutes from the timedelta object in the string.
    short: bool
        Whether to abbreviate the units of time ("hour" to "h", "minute" to "m", "second" to "s").

    Returns: str
        A string containing the time from the datetime.timedelta object, in a readable format.
        Time units will be abbreviated if short was set to True.
    """
    output = [[delta.days, 'd' if short else ' day'],
              [delta.seconds // 3600, 'h' if short else ' hour']]
    if minutes:
        output.append([delta.seconds // 60 % 60, 'm' if short else ' minute'])
    if sec:
        output.append([delta.seconds % 60, 's' if short else ' second'])
    for i in range(len(output)):
        if output[i][0] != 1 and not short:
            output[i][1] += 's'  # type: ignore
    reply_msg = []
    if output[0][0] != 0:
        reply_msg.append("{}{} ".format(output[0][0], output[0][1]))
    if output[0][0] != 0 or output[1][0] != 0 or len(output) == 2:
        reply_msg.append("{}{} ".format(output[1][0], output[1][1]))
    for i in range(2, len(output) - 1):
        reply_msg.append("{}{} ".format(output[i][0], output[i][1]))
    if not short and reply_msg:
        reply_msg.append("and ")
    reply_msg.append("{}{}".format(output[-1][0], output[-1][1]))
    return "".join(reply_msg)


def utc_now() -> dt.datetime:
    """
    Return the current timezone-aware utc timestamp.
    """
    return dt.datetime.utcnow().replace(tzinfo=dt.timezone.utc)


def replace_multiple(format_string, mapping):
    """
    Substitutes the keys from the mapping with their corresponding values.

    Substitution is non-chained, and done in a single pass via regex.
    """
    if not mapping:
        raise ValueError("Empty mapping passed.")

    keys = list(mapping.keys())
    pattern = '|'.join(f"({key})" for key in keys)
    string = re.sub(pattern, lambda match: str(mapping[keys[match.lastindex - 1]]), format_string)
    return string


def parse_dur(time_str):
    """
    Parses a user-provided time duration string into a number of seconds.

    Parameters
    ----------
    time_str: str
        The time string to parse. String can include days, hours, minutes, and seconds.

    Returns: int
        The number of seconds the duration represents.
    """
    funcs = {'d': lambda x: x * 24 * 60 * 60,
             'h': lambda x: x * 60 * 60,
             'm': lambda x: x * 60,
             's': lambda x: x}
    time_str = time_str.strip(" ,")
    found = re.findall(r'(\d+)\s?(\w+?)', time_str)
    seconds = 0
    for bit in found:
        if bit[1] in funcs:
            seconds += funcs[bit[1]](int(bit[0]))
    return seconds
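A few hedged usage examples for the helpers above; the expected values in the comments were worked out by hand from the implementations, and the imports assume src/ is on the path.

# Hedged usage examples only.
from datetime import timedelta

from utils.lib import strfdelta, parse_dur, replace_multiple

strfdelta(timedelta(days=1, hours=2, minutes=5))       # '1 day 2 hours and 5 minutes'
strfdelta(timedelta(hours=2, minutes=5), short=True)   # '2h 5m'
parse_dur("1d 2h 30m")                                 # 95400 (seconds)
replace_multiple("Hello NAME, today is DAY", {"NAME": "Ari", "DAY": "Monday"})
# -> 'Hello Ari, today is Monday'  (single regex pass; values are not re-scanned for keys)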
166
src/utils/ratelimits.py
Normal file
@@ -0,0 +1,166 @@
import asyncio
import time
import logging

from cachetools import TTLCache

logger = logging.getLogger()


class BucketFull(Exception):
    """
    Throw when a requested Bucket is already full
    """
    pass


class BucketOverFull(BucketFull):
    """
    Throw when a requested Bucket is overfull
    """
    pass


class Bucket:
    __slots__ = ('max_level', 'empty_time', 'leak_rate', '_level', '_last_checked', '_last_full', '_wait_lock')

    def __init__(self, max_level, empty_time):
        self.max_level = max_level
        self.empty_time = empty_time
        self.leak_rate = max_level / empty_time

        self._level = 0
        self._last_checked = time.monotonic()

        self._last_full = False
        self._wait_lock = asyncio.Lock()

    @property
    def full(self) -> bool:
        """
        Return whether the bucket is 'full',
        that is, whether an immediate request against the bucket will raise `BucketFull`.
        """
        self._leak()
        return self._level + 1 > self.max_level

    @property
    def overfull(self):
        self._leak()
        return self._level > self.max_level

    @property
    def delay(self):
        self._leak()
        if self._level + 1 > self.max_level:
            delay = (self._level + 1 - self.max_level) * self.leak_rate
        else:
            delay = 0
        return delay

    def _leak(self):
        if self._level:
            elapsed = time.monotonic() - self._last_checked
            self._level = max(0, self._level - (elapsed * self.leak_rate))

        self._last_checked = time.monotonic()

    def request(self):
        self._leak()
        if self._level > self.max_level:
            raise BucketOverFull
        elif self._level == self.max_level:
            self._level += 1
            if self._last_full:
                raise BucketOverFull
            else:
                self._last_full = True
                raise BucketFull
        else:
            self._last_full = False
            self._level += 1

    def fill(self):
        self._leak()
        self._level = max(self._level, self.max_level + 1)

    async def wait(self):
        """
        Wait until the bucket has room.

        Guarantees that a `request` directly afterwards will not raise `BucketFull`.
        """
        # Wrapped in a lock so that waiters are correctly handled in wait-order
        # Otherwise multiple waiters will have the same delay,
        # and race for the wakeup after sleep.
        # Also avoids short-circuiting in the 0 delay case, which would not correctly handle wait-order
        async with self._wait_lock:
            # We do this in a loop in case asyncio.sleep throws us out early,
            # or a synchronous request overflows the bucket while we are waiting.
            while self.full:
                await asyncio.sleep(self.delay)

    async def wrapped(self, coro):
        await self.wait()
        self.request()
        await coro
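A hedged sketch of the leaky-bucket behaviour above: each request() adds one unit of level, the level drains at max_level/empty_time units per second, and callers can either catch BucketFull or await wait() for room. The send_batches coroutine and its payloads are illustrative names only, and the import assumes src/ is on the path.

# Illustrative only; uses the Bucket defined above.
import asyncio

from utils.ratelimits import Bucket


async def send_batches(payloads):
    # Allow at most 5 requests per 10 seconds; excess callers sleep on wait().
    bucket = Bucket(max_level=5, empty_time=10)
    for payload in payloads:
        await bucket.wait()    # returns once a request is guaranteed to fit
        bucket.request()       # record this request against the bucket
        print("sending", payload)

# asyncio.run(send_batches(range(12)))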
class RateLimit:
    def __init__(self, max_level, empty_time, error=None, cache=TTLCache(1000, 60 * 60)):
        self.max_level = max_level
        self.empty_time = empty_time

        self.error = error or "Too many requests, please slow down!"
        self.buckets = cache

    def request_for(self, key):
        if not (bucket := self.buckets.get(key, None)):
            bucket = self.buckets[key] = Bucket(self.max_level, self.empty_time)

        bucket.request()

    def ward(self, member=True, key=None):
        """
        Command ratelimit decorator.
        """
        key = key or ((lambda ctx: (ctx.guild.id, ctx.author.id)) if member else (lambda ctx: ctx.author.id))

        def decorator(func):
            async def wrapper(ctx, *args, **kwargs):
                self.request_for(key(ctx))
                return await func(ctx, *args, **kwargs)
            return wrapper
        return decorator


async def limit_concurrency(aws, limit):
    """
    Run provided awaitables concurrently,
    ensuring that no more than `limit` are running at once.
    """
    aws = iter(aws)
    aws_ended = False
    pending = set()
    count = 0
    logger.debug("Starting limited concurrency executor")

    while pending or not aws_ended:
        while len(pending) < limit and not aws_ended:
            aw = next(aws, None)
            if aw is None:
                aws_ended = True
            else:
                pending.add(asyncio.create_task(aw))
                count += 1

        if not pending:
            break

        done, pending = await asyncio.wait(
            pending, return_when=asyncio.FIRST_COMPLETED
        )
        while done:
            yield done.pop()
    logger.debug(f"Completed {count} tasks")
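Hedged usage sketches for RateLimit.ward and limit_concurrency. The ctx object, ping_command and the fetched coroutines are hypothetical stand-ins for whatever command framework the bot uses; only RateLimit, BucketFull and limit_concurrency come from the file above, imported assuming src/ is on the path.

# Illustrative only; the command and ctx API are assumptions.
from utils.ratelimits import RateLimit, limit_concurrency

ping_limit = RateLimit(max_level=3, empty_time=60, error="Too many pings, please slow down!")


@ping_limit.ward(member=True)   # keyed on (guild id, author id); request_for raises BucketFull past 3 calls per 60s
async def ping_command(ctx):
    await ctx.reply("Pong!")    # hypothetical ctx API


async def gather_limited(coros):
    # Run the coroutines at most 10 at a time, collecting results as tasks complete.
    results = []
    async for task in limit_concurrency(coros, limit=10):
        results.append(task.result())
    return results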