Data system refactor and core redesign for public.
Redesigned data and core systems to be public-capable.
This commit is contained in:
70
bot/LionModule.py
Normal file
70
bot/LionModule.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
from cmdClient import Command, Module
|
||||||
|
|
||||||
|
from meta import log
|
||||||
|
|
||||||
|
|
||||||
|
class LionCommand(Command):
|
||||||
|
"""
|
||||||
|
Subclass to allow easy attachment of custom hooks and structure to commands.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class LionModule(Module):
|
||||||
|
"""
|
||||||
|
Custom module for Lion systems.
|
||||||
|
|
||||||
|
Adds command wrappers and various event handlers.
|
||||||
|
"""
|
||||||
|
name = "Base Lion Module"
|
||||||
|
|
||||||
|
def __init__(self, name, baseCommand=LionCommand):
|
||||||
|
super().__init__(name, baseCommand)
|
||||||
|
|
||||||
|
self.unload_tasks = []
|
||||||
|
|
||||||
|
def unload_task(self, func):
|
||||||
|
"""
|
||||||
|
Decorator adding an unload task for deactivating the module.
|
||||||
|
Should sync unsaved transactions and finalise user interaction.
|
||||||
|
If possible, should also remove attached data and handlers.
|
||||||
|
"""
|
||||||
|
self.unload_tasks.append(func)
|
||||||
|
log("Adding unload task '{}'.".format(func.__name__), context=self.name)
|
||||||
|
return func
|
||||||
|
|
||||||
|
async def unload(self, client):
|
||||||
|
"""
|
||||||
|
Run the unloading tasks.
|
||||||
|
"""
|
||||||
|
log("Unloading module.", context=self.name, post=False)
|
||||||
|
for task in self.unload_tasks:
|
||||||
|
log("Running unload task '{}'".format(task.__name__),
|
||||||
|
context=self.name, post=False)
|
||||||
|
await task(client)
|
||||||
|
|
||||||
|
async def launch(self, client):
|
||||||
|
"""
|
||||||
|
Launch hook.
|
||||||
|
Executed in `client.on_ready`.
|
||||||
|
Must set `ready` to `True`, otherwise all commands will hang.
|
||||||
|
Overrides the parent launcher to not post the log as a discord message.
|
||||||
|
"""
|
||||||
|
if not self.ready:
|
||||||
|
log("Running launch tasks.", context=self.name, post=False)
|
||||||
|
|
||||||
|
for task in self.launch_tasks:
|
||||||
|
log("Running launch task '{}'.".format(task.__name__),
|
||||||
|
context=self.name, post=False)
|
||||||
|
await task(client)
|
||||||
|
|
||||||
|
self.ready = True
|
||||||
|
else:
|
||||||
|
log("Already launched, skipping launch.", context=self.name, post=False)
|
||||||
|
|
||||||
|
async def pre_command(self, ctx):
|
||||||
|
"""
|
||||||
|
Lion pre-command hook.
|
||||||
|
"""
|
||||||
|
# TODO: Add blacklist and auto-fetch of lion here.
|
||||||
|
...
|
||||||
@@ -1,2 +1,4 @@
|
|||||||
from . import tables
|
from . import data # noqa
|
||||||
from .user import User
|
|
||||||
|
from .module import module
|
||||||
|
from .lion import Lion # noqa
|
||||||
|
|||||||
81
bot/core/data.py
Normal file
81
bot/core/data.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
from psycopg2.extras import execute_values
|
||||||
|
|
||||||
|
from cachetools import TTLCache
|
||||||
|
from data import RowTable, Table
|
||||||
|
|
||||||
|
|
||||||
|
meta = RowTable(
|
||||||
|
'AppData',
|
||||||
|
('appid', 'last_study_badge_scan'),
|
||||||
|
'appid',
|
||||||
|
attach_as='meta',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
user_config = RowTable(
|
||||||
|
'user_config',
|
||||||
|
('userid', 'timezone'),
|
||||||
|
'userid',
|
||||||
|
cache=TTLCache(5000, ttl=60*5)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@user_config.save_query
|
||||||
|
def add_pending(pending):
|
||||||
|
"""
|
||||||
|
pending:
|
||||||
|
List of tuples of the form `(userid, pending_coins, pending_time)`.
|
||||||
|
"""
|
||||||
|
with lions.conn:
|
||||||
|
cursor = lions.conn.cursor()
|
||||||
|
data = execute_values(
|
||||||
|
cursor,
|
||||||
|
"""
|
||||||
|
UPDATE members
|
||||||
|
SET
|
||||||
|
coins = coins + t.coin_diff,
|
||||||
|
tracked_time = tracked_time + t.time_diff
|
||||||
|
FROM
|
||||||
|
(VALUES %s)
|
||||||
|
AS
|
||||||
|
t (guildid, userid, coin_diff, time_diff)
|
||||||
|
WHERE
|
||||||
|
members.guildid = t.guildid
|
||||||
|
AND
|
||||||
|
members.userid = t.userid
|
||||||
|
RETURNING *
|
||||||
|
""",
|
||||||
|
pending,
|
||||||
|
fetch=True
|
||||||
|
)
|
||||||
|
return lions._make_rows(*data)
|
||||||
|
|
||||||
|
|
||||||
|
guild_config = RowTable(
|
||||||
|
'guild_config',
|
||||||
|
('guildid', 'admin_role', 'mod_role', 'event_log_channel',
|
||||||
|
'min_workout_length', 'workout_reward',
|
||||||
|
'max_tasks', 'task_reward', 'task_reward_limit',
|
||||||
|
'study_hourly_reward', 'study_hourly_live_bonus',
|
||||||
|
'study_ban_role', 'max_study_bans'),
|
||||||
|
'guildid',
|
||||||
|
cache=TTLCache(1000, ttl=60*5)
|
||||||
|
)
|
||||||
|
|
||||||
|
unranked_roles = Table('unranked_roles')
|
||||||
|
|
||||||
|
donator_roles = Table('donator_roles')
|
||||||
|
|
||||||
|
|
||||||
|
lions = RowTable(
|
||||||
|
'members',
|
||||||
|
('guildid', 'userid',
|
||||||
|
'tracked_time', 'coins',
|
||||||
|
'workout_count', 'last_workout_start',
|
||||||
|
'last_study_badgeid',
|
||||||
|
'study_ban_count',
|
||||||
|
),
|
||||||
|
('guildid', 'userid'),
|
||||||
|
cache=TTLCache(5000, ttl=60*5),
|
||||||
|
attach_as='lions'
|
||||||
|
)
|
||||||
145
bot/core/lion.py
Normal file
145
bot/core/lion.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
import pytz
|
||||||
|
|
||||||
|
from meta import client
|
||||||
|
from data import tables as tb
|
||||||
|
from settings import UserSettings
|
||||||
|
|
||||||
|
|
||||||
|
class Lion:
|
||||||
|
"""
|
||||||
|
Class representing a guild Member.
|
||||||
|
Mostly acts as a transparent interface to the corresponding Row,
|
||||||
|
but also adds some transaction caching logic to `coins` and `tracked_time`.
|
||||||
|
"""
|
||||||
|
__slots__ = ('guildid', 'userid', '_pending_coins', '_pending_time', '_member')
|
||||||
|
|
||||||
|
# Members with pending transactions
|
||||||
|
_pending = {} # userid -> User
|
||||||
|
|
||||||
|
# Lion cache. Currently lions don't expire
|
||||||
|
_lions = {} # (guildid, userid) -> Lion
|
||||||
|
|
||||||
|
def __init__(self, guildid, userid):
|
||||||
|
self.guildid = guildid
|
||||||
|
self.userid = userid
|
||||||
|
|
||||||
|
self._pending_coins = 0
|
||||||
|
self._pending_time = 0
|
||||||
|
|
||||||
|
self._member = None
|
||||||
|
|
||||||
|
self._lions[self.key] = self
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def fetch(cls, guildid, userid):
|
||||||
|
"""
|
||||||
|
Fetch a Lion with the given member.
|
||||||
|
If they don't exist, creates them.
|
||||||
|
If possible, retrieves the user from the user cache.
|
||||||
|
"""
|
||||||
|
key = (guildid, userid)
|
||||||
|
if key in cls._lions:
|
||||||
|
return cls._lions[key]
|
||||||
|
else:
|
||||||
|
tb.lions.fetch_or_create(key)
|
||||||
|
return cls(guildid, userid)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def key(self):
|
||||||
|
return (self.guildid, self.userid)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def member(self):
|
||||||
|
"""
|
||||||
|
The discord `Member` corresponding to this user.
|
||||||
|
May be `None` if the member is no longer in the guild or the caches aren't populated.
|
||||||
|
Not guaranteed to be `None` if the member is not in the guild.
|
||||||
|
"""
|
||||||
|
if self._member is None:
|
||||||
|
guild = client.get_guild(self.guildid)
|
||||||
|
if guild:
|
||||||
|
self._member = guild.get_member(self.userid)
|
||||||
|
return self._member
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data(self):
|
||||||
|
"""
|
||||||
|
The Row corresponding to this user.
|
||||||
|
"""
|
||||||
|
return tb.lions.fetch(self.key)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def settings(self):
|
||||||
|
"""
|
||||||
|
The UserSettings object for this user.
|
||||||
|
"""
|
||||||
|
return UserSettings(self.userid)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def time(self):
|
||||||
|
"""
|
||||||
|
Amount of time the user has spent studying, accounting for pending values.
|
||||||
|
"""
|
||||||
|
return int(self.data.tracked_time + self._pending_time)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def coins(self):
|
||||||
|
"""
|
||||||
|
Number of coins the user has, accounting for the pending value.
|
||||||
|
"""
|
||||||
|
return int(self.data.coins + self._pending_coins)
|
||||||
|
|
||||||
|
def localize(self, naive_utc_dt):
|
||||||
|
"""
|
||||||
|
Localise the provided naive UTC datetime into the user's timezone.
|
||||||
|
"""
|
||||||
|
timezone = self.settings.timezone.value
|
||||||
|
return naive_utc_dt.replace(tzinfo=pytz.UTC).astimezone(timezone)
|
||||||
|
|
||||||
|
def addCoins(self, amount, flush=True):
|
||||||
|
"""
|
||||||
|
Add coins to the user, optionally store the transaction in pending.
|
||||||
|
"""
|
||||||
|
self._pending_coins += amount
|
||||||
|
self._pending[self.key] = self
|
||||||
|
if flush:
|
||||||
|
self.flush()
|
||||||
|
|
||||||
|
def addTime(self, amount, flush=True):
|
||||||
|
"""
|
||||||
|
Add time to a user (in seconds), optionally storing the transaction in pending.
|
||||||
|
"""
|
||||||
|
self._pending_time += amount
|
||||||
|
self._pending[self.key] = self
|
||||||
|
if flush:
|
||||||
|
self.flush()
|
||||||
|
|
||||||
|
def flush(self):
|
||||||
|
"""
|
||||||
|
Flush any pending transactions to the database.
|
||||||
|
"""
|
||||||
|
self.sync(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def sync(cls, *lions):
|
||||||
|
"""
|
||||||
|
Flush pending transactions to the database.
|
||||||
|
Also refreshes the Row cache for updated lions.
|
||||||
|
"""
|
||||||
|
lions = lions or list(cls._pending.values())
|
||||||
|
|
||||||
|
if lions:
|
||||||
|
# Build userid to pending coin map
|
||||||
|
pending = [
|
||||||
|
(lion.guildid, lion.userid, int(lion._pending_coins), int(lion._pending_time))
|
||||||
|
for lion in lions
|
||||||
|
]
|
||||||
|
|
||||||
|
# Write to database
|
||||||
|
tb.lions.queries.add_pending(pending)
|
||||||
|
|
||||||
|
# Cleanup pending users
|
||||||
|
for lion in lions:
|
||||||
|
lion._pending_coins -= int(lion._pending_coins)
|
||||||
|
lion._pending_time -= int(lion._pending_time)
|
||||||
|
cls._pending.pop(lion.key, None)
|
||||||
36
bot/core/module.py
Normal file
36
bot/core/module.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from meta import client, conf
|
||||||
|
from LionModule import LionModule
|
||||||
|
|
||||||
|
from .lion import Lion
|
||||||
|
|
||||||
|
|
||||||
|
module = LionModule("Core")
|
||||||
|
|
||||||
|
|
||||||
|
async def _lion_sync_loop():
|
||||||
|
while True:
|
||||||
|
while not client.is_ready():
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
client.log(
|
||||||
|
"Running lion data sync.",
|
||||||
|
context="CORE",
|
||||||
|
level=logging.DEBUG,
|
||||||
|
post=False
|
||||||
|
)
|
||||||
|
|
||||||
|
Lion.sync()
|
||||||
|
await asyncio.sleep(conf.bot.getint("lion_sync_period"))
|
||||||
|
|
||||||
|
|
||||||
|
@module.launch_task
|
||||||
|
async def launch_lion_sync_loop(client):
|
||||||
|
asyncio.create_task(_lion_sync_loop())
|
||||||
|
|
||||||
|
|
||||||
|
@module.unload_task
|
||||||
|
async def final_lion_sync(client):
|
||||||
|
Lion.sync()
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
from psycopg2.extras import execute_values
|
|
||||||
|
|
||||||
from cachetools import TTLCache
|
|
||||||
from data import RowTable, Table
|
|
||||||
|
|
||||||
|
|
||||||
users = RowTable(
|
|
||||||
'lions',
|
|
||||||
('userid', 'tracked_time', 'coins'),
|
|
||||||
'userid',
|
|
||||||
cache=TTLCache(5000, ttl=60*5)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@users.save_query
|
|
||||||
def add_coins(userid_coins):
|
|
||||||
with users.conn:
|
|
||||||
cursor = users.conn.cursor()
|
|
||||||
data = execute_values(
|
|
||||||
cursor,
|
|
||||||
"""
|
|
||||||
UPDATE lions
|
|
||||||
SET coins = coins + t.diff
|
|
||||||
FROM (VALUES %s) AS t (userid, diff)
|
|
||||||
WHERE lions.userid = t.userid
|
|
||||||
RETURNING *
|
|
||||||
""",
|
|
||||||
userid_coins,
|
|
||||||
fetch=True
|
|
||||||
)
|
|
||||||
return users._make_rows(*data)
|
|
||||||
105
bot/core/user.py
105
bot/core/user.py
@@ -1,105 +0,0 @@
|
|||||||
from . import tables as tb
|
|
||||||
from meta import conf, client
|
|
||||||
|
|
||||||
|
|
||||||
class User:
|
|
||||||
"""
|
|
||||||
Class representing a "Lion", i.e. a member of the managed guild.
|
|
||||||
Mostly acts as a transparent interface to the corresponding Row,
|
|
||||||
but also adds some transaction caching logic to `coins`.
|
|
||||||
"""
|
|
||||||
__slots__ = ('userid', '_pending_coins', '_member')
|
|
||||||
|
|
||||||
# Users with pending transactions
|
|
||||||
_pending = {} # userid -> User
|
|
||||||
|
|
||||||
# User cache. Currently users don't expire
|
|
||||||
_users = {} # userid -> User
|
|
||||||
|
|
||||||
def __init__(self, userid):
|
|
||||||
self.userid = userid
|
|
||||||
self._pending_coins = 0
|
|
||||||
|
|
||||||
self._users[self.userid] = self
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def fetch(cls, userid):
|
|
||||||
"""
|
|
||||||
Fetch a User with the given userid.
|
|
||||||
If they don't exist, creates them.
|
|
||||||
If possible, retrieves the user from the user cache.
|
|
||||||
"""
|
|
||||||
if userid in cls._users:
|
|
||||||
return cls._users[userid]
|
|
||||||
else:
|
|
||||||
tb.users.fetch_or_create(userid)
|
|
||||||
return cls(userid)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def member(self):
|
|
||||||
"""
|
|
||||||
The discord `Member` corresponding to this user.
|
|
||||||
May be `None` if the member is no longer in the guild or the caches aren't populated.
|
|
||||||
Not guaranteed to be `None` if the member is not in the guild.
|
|
||||||
"""
|
|
||||||
if self._member is None:
|
|
||||||
self._member = client.get_guild(conf.meta.getint('managed_guild_id')).get_member(self.userid)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def data(self):
|
|
||||||
"""
|
|
||||||
The Row corresponding to this user.
|
|
||||||
"""
|
|
||||||
return tb.users.fetch(self.userid)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def time(self):
|
|
||||||
"""
|
|
||||||
Amount of time the user has spent.. studying?
|
|
||||||
"""
|
|
||||||
return self.data.tracked_time
|
|
||||||
|
|
||||||
@property
|
|
||||||
def coins(self):
|
|
||||||
"""
|
|
||||||
Number of coins the user has, accounting for the pending value.
|
|
||||||
"""
|
|
||||||
return self.data.coins + self._pending_coins
|
|
||||||
|
|
||||||
def addCoins(self, amount, flush=True):
|
|
||||||
"""
|
|
||||||
Add coins to the user, optionally store the transaction in pending.
|
|
||||||
"""
|
|
||||||
self._pending_coins += amount
|
|
||||||
if self._pending_coins != 0:
|
|
||||||
self._pending[self.userid] = self
|
|
||||||
else:
|
|
||||||
self._pending.pop(self.userid, None)
|
|
||||||
if flush:
|
|
||||||
self.flush()
|
|
||||||
|
|
||||||
def flush(self):
|
|
||||||
"""
|
|
||||||
Flush any pending transactions to the database.
|
|
||||||
"""
|
|
||||||
self.sync(self)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def sync(cls, *users):
|
|
||||||
"""
|
|
||||||
Flush pending transactions to the database.
|
|
||||||
Also refreshes the Row cache for updated users.
|
|
||||||
"""
|
|
||||||
users = users or list(cls._pending.values())
|
|
||||||
|
|
||||||
if users:
|
|
||||||
# Build userid to pending coin map
|
|
||||||
userid_coins = [(user.userid, user._pending_coins) for user in users]
|
|
||||||
|
|
||||||
# Write to database
|
|
||||||
tb.users.queries.add_coins(userid_coins)
|
|
||||||
|
|
||||||
# Cleanup pending users
|
|
||||||
for user in users:
|
|
||||||
user._pending_coins = 0
|
|
||||||
cls._pending.pop(user.userid, None)
|
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
from .data import *
|
from .connection import conn # noqa
|
||||||
# from . import tables
|
from .formatters import UpdateValue, UpdateValueAdd # noqa
|
||||||
# from . import queries
|
from .interfaces import Table, RowTable, Row, tables # noqa
|
||||||
|
from .queries import insert, insert_many, select_where, update_where, upsert, delete_where # noqa
|
||||||
|
from .conditions import Condition, NOT, Constant, NULL, NOTNULL # noqa
|
||||||
|
|||||||
59
bot/data/conditions.py
Normal file
59
bot/data/conditions.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
from .connection import _replace_char
|
||||||
|
|
||||||
|
|
||||||
|
class Condition:
|
||||||
|
"""
|
||||||
|
ABC representing a selection condition.
|
||||||
|
"""
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
|
def apply(self, key, values, conditions):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class NOT(Condition):
|
||||||
|
__slots__ = ('value',)
|
||||||
|
|
||||||
|
def __init__(self, value):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def apply(self, key, values, conditions):
|
||||||
|
item = self.value
|
||||||
|
if isinstance(item, (list, tuple)):
|
||||||
|
if item:
|
||||||
|
conditions.append("{} NOT IN ({})".format(key, ", ".join([_replace_char] * len(item))))
|
||||||
|
values.extend(item)
|
||||||
|
else:
|
||||||
|
raise ValueError("Cannot check an empty iterable!")
|
||||||
|
else:
|
||||||
|
conditions.append("{}!={}".format(key, _replace_char))
|
||||||
|
values.append(item)
|
||||||
|
|
||||||
|
|
||||||
|
class GEQ(Condition):
|
||||||
|
__slots__ = ('value',)
|
||||||
|
|
||||||
|
def __init__(self, value):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def apply(self, key, values, conditions):
|
||||||
|
item = self.value
|
||||||
|
if isinstance(item, (list, tuple)):
|
||||||
|
raise ValueError("Cannot apply GEQ condition to a list!")
|
||||||
|
else:
|
||||||
|
conditions.append("{} >= {}".format(key, _replace_char))
|
||||||
|
values.append(item)
|
||||||
|
|
||||||
|
|
||||||
|
class Constant(Condition):
|
||||||
|
__slots__ = ('value',)
|
||||||
|
|
||||||
|
def __init__(self, value):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def apply(self, key, values, conditions):
|
||||||
|
conditions.append("{} {}".format(key, self.value))
|
||||||
|
|
||||||
|
|
||||||
|
NULL = Constant('IS NULL')
|
||||||
|
NOTNULL = Constant('IS NOT NULL')
|
||||||
40
bot/data/connection.py
Normal file
40
bot/data/connection.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import psycopg2 as psy
|
||||||
|
|
||||||
|
from meta import log, conf
|
||||||
|
from constants import DATA_VERSION
|
||||||
|
from .cursor import DictLoggingCursor
|
||||||
|
|
||||||
|
|
||||||
|
# Set up database connection
|
||||||
|
log("Establishing connection.", "DB_INIT", level=logging.DEBUG)
|
||||||
|
conn = psy.connect(conf.bot['database'], cursor_factory=DictLoggingCursor)
|
||||||
|
|
||||||
|
# Replace char used by the connection for query formatting
|
||||||
|
_replace_char: str = '%s'
|
||||||
|
|
||||||
|
# conn.set_trace_callback(lambda message: log(message, context="DB_CONNECTOR", level=logging.DEBUG))
|
||||||
|
# sq.register_adapter(datetime, lambda dt: dt.timestamp())
|
||||||
|
|
||||||
|
|
||||||
|
# Check the version matches the required version
|
||||||
|
with conn:
|
||||||
|
log("Checking db version.", "DB_INIT")
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# Get last entry in version table, compare against desired version
|
||||||
|
cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
|
||||||
|
current_version, _, _ = cursor.fetchone()
|
||||||
|
|
||||||
|
if current_version != DATA_VERSION:
|
||||||
|
# Complain
|
||||||
|
raise Exception(
|
||||||
|
("Database version is {}, required version is {}. "
|
||||||
|
"Please migrate database.").format(current_version, DATA_VERSION)
|
||||||
|
)
|
||||||
|
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
|
||||||
|
log("Established connection.", "DB_INIT")
|
||||||
505
bot/data/data.py
505
bot/data/data.py
@@ -1,505 +0,0 @@
|
|||||||
import logging
|
|
||||||
import contextlib
|
|
||||||
from itertools import chain
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
import psycopg2 as psy
|
|
||||||
from cachetools import LRUCache
|
|
||||||
|
|
||||||
from utils.lib import DotDict
|
|
||||||
from meta import log, conf
|
|
||||||
from constants import DATA_VERSION
|
|
||||||
from .custom_cursor import DictLoggingCursor
|
|
||||||
|
|
||||||
|
|
||||||
# Set up database connection
|
|
||||||
log("Establishing connection.", "DB_INIT", level=logging.DEBUG)
|
|
||||||
conn = psy.connect(conf.bot['database'], cursor_factory=DictLoggingCursor)
|
|
||||||
|
|
||||||
# conn.set_trace_callback(lambda message: log(message, context="DB_CONNECTOR", level=logging.DEBUG))
|
|
||||||
# sq.register_adapter(datetime, lambda dt: dt.timestamp())
|
|
||||||
|
|
||||||
|
|
||||||
# Check the version matches the required version
|
|
||||||
with conn:
|
|
||||||
log("Checking db version.", "DB_INIT")
|
|
||||||
cursor = conn.cursor()
|
|
||||||
|
|
||||||
# Get last entry in version table, compare against desired version
|
|
||||||
cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
|
|
||||||
current_version, _, _ = cursor.fetchone()
|
|
||||||
|
|
||||||
if current_version != DATA_VERSION:
|
|
||||||
# Complain
|
|
||||||
raise Exception(
|
|
||||||
("Database version is {}, required version is {}. "
|
|
||||||
"Please migrate database.").format(current_version, DATA_VERSION)
|
|
||||||
)
|
|
||||||
|
|
||||||
cursor.close()
|
|
||||||
|
|
||||||
|
|
||||||
log("Established connection.", "DB_INIT")
|
|
||||||
|
|
||||||
|
|
||||||
# --------------- Data Interface Classes ---------------
|
|
||||||
class Table:
|
|
||||||
"""
|
|
||||||
Transparent interface to a single table structure in the database.
|
|
||||||
Contains standard methods to access the table.
|
|
||||||
Intended to be subclassed to provide more derivative access for specific tables.
|
|
||||||
"""
|
|
||||||
conn = conn
|
|
||||||
queries = DotDict()
|
|
||||||
|
|
||||||
def __init__(self, name):
|
|
||||||
self.name = name
|
|
||||||
|
|
||||||
def select_where(self, *args, **kwargs):
|
|
||||||
with self.conn:
|
|
||||||
return select_where(self.name, *args, **kwargs)
|
|
||||||
|
|
||||||
def select_one_where(self, *args, **kwargs):
|
|
||||||
with self.conn:
|
|
||||||
rows = self.select_where(*args, **kwargs)
|
|
||||||
return rows[0] if rows else None
|
|
||||||
|
|
||||||
def update_where(self, *args, **kwargs):
|
|
||||||
with self.conn:
|
|
||||||
return update_where(self.name, *args, **kwargs)
|
|
||||||
|
|
||||||
def delete_where(self, *args, **kwargs):
|
|
||||||
with self.conn:
|
|
||||||
return delete_where(self.name, *args, **kwargs)
|
|
||||||
|
|
||||||
def insert(self, *args, **kwargs):
|
|
||||||
with self.conn:
|
|
||||||
return insert(self.name, *args, **kwargs)
|
|
||||||
|
|
||||||
def insert_many(self, *args, **kwargs):
|
|
||||||
with self.conn:
|
|
||||||
return insert_many(self.name, *args, **kwargs)
|
|
||||||
|
|
||||||
def upsert(self, *args, **kwargs):
|
|
||||||
with self.conn:
|
|
||||||
return upsert(self.name, *args, **kwargs)
|
|
||||||
|
|
||||||
def save_query(self, func):
|
|
||||||
"""
|
|
||||||
Decorator to add a saved query to the table.
|
|
||||||
"""
|
|
||||||
self.queries[func.__name__] = func
|
|
||||||
|
|
||||||
|
|
||||||
class Row:
|
|
||||||
__slots__ = ('table', 'data', '_pending')
|
|
||||||
|
|
||||||
conn = conn
|
|
||||||
|
|
||||||
def __init__(self, table, data, *args, **kwargs):
|
|
||||||
super().__setattr__('table', table)
|
|
||||||
self.data = data
|
|
||||||
self._pending = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def rowid(self):
|
|
||||||
return self.data[self.table.id_col]
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "Row[{}]({})".format(
|
|
||||||
self.table.name,
|
|
||||||
', '.join("{}={!r}".format(field, getattr(self, field)) for field in self.table.columns)
|
|
||||||
)
|
|
||||||
|
|
||||||
def __getattr__(self, key):
|
|
||||||
if key in self.table.columns:
|
|
||||||
if self._pending and key in self._pending:
|
|
||||||
return self._pending[key]
|
|
||||||
else:
|
|
||||||
return self.data[key]
|
|
||||||
else:
|
|
||||||
raise AttributeError(key)
|
|
||||||
|
|
||||||
def __setattr__(self, key, value):
|
|
||||||
if key in self.table.columns:
|
|
||||||
if self._pending is None:
|
|
||||||
self.update(**{key: value})
|
|
||||||
else:
|
|
||||||
self._pending[key] = value
|
|
||||||
else:
|
|
||||||
super().__setattr__(key, value)
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def batch_update(self):
|
|
||||||
if self._pending:
|
|
||||||
raise ValueError("Nested batch updates for {}!".format(self.__class__.__name__))
|
|
||||||
|
|
||||||
self._pending = {}
|
|
||||||
try:
|
|
||||||
yield self._pending
|
|
||||||
finally:
|
|
||||||
self.update(**self._pending)
|
|
||||||
self._pending = None
|
|
||||||
|
|
||||||
def _refresh(self):
|
|
||||||
row = self.table.select_one_where(**{self.table.id_col: self.rowid})
|
|
||||||
if not row:
|
|
||||||
raise ValueError("Refreshing a {} which no longer exists!".format(type(self).__name__))
|
|
||||||
self.data = row
|
|
||||||
|
|
||||||
def update(self, **values):
|
|
||||||
rows = self.table.update_where(values, **{self.table.id_col: self.rowid})
|
|
||||||
self.data = rows[0]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _select_where(cls, _extra=None, **conditions):
|
|
||||||
return select_where(cls._table, **conditions)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _insert(cls, **values):
|
|
||||||
return insert(cls._table, **values)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _update_where(cls, values, **conditions):
|
|
||||||
return update_where(cls._table, values, **conditions)
|
|
||||||
|
|
||||||
|
|
||||||
class RowTable(Table):
|
|
||||||
__slots__ = (
|
|
||||||
'name',
|
|
||||||
'columns',
|
|
||||||
'id_col',
|
|
||||||
'row_cache'
|
|
||||||
)
|
|
||||||
|
|
||||||
conn = conn
|
|
||||||
|
|
||||||
def __init__(self, name, columns, id_col, use_cache=True, cache=None, cache_size=1000):
|
|
||||||
self.name = name
|
|
||||||
self.columns = columns
|
|
||||||
self.id_col = id_col
|
|
||||||
self.row_cache = (cache or LRUCache(cache_size)) if use_cache else None
|
|
||||||
|
|
||||||
# Extend original Table update methods to modify the cached rows
|
|
||||||
def update_where(self, *args, **kwargs):
|
|
||||||
data = super().update_where(*args, **kwargs)
|
|
||||||
if self.row_cache is not None:
|
|
||||||
for data_row in data:
|
|
||||||
cached_row = self.row_cache.get(data_row[self.id_col], None)
|
|
||||||
if cached_row is not None:
|
|
||||||
cached_row.data = data_row
|
|
||||||
return data
|
|
||||||
|
|
||||||
def delete_where(self, *args, **kwargs):
|
|
||||||
data = super().delete_where(*args, **kwargs)
|
|
||||||
if self.row_cache is not None:
|
|
||||||
for data_row in data:
|
|
||||||
self.row_cache.pop(data_row[self.id_col], None)
|
|
||||||
return data
|
|
||||||
|
|
||||||
def upsert(self, *args, **kwargs):
|
|
||||||
data = super().upsert(*args, **kwargs)
|
|
||||||
if self.row_cache is not None:
|
|
||||||
cached_row = self.row_cache.get(data[self.id_col], None)
|
|
||||||
if cached_row is not None:
|
|
||||||
cached_row.data = data
|
|
||||||
return data
|
|
||||||
|
|
||||||
# New methods to fetch and create rows
|
|
||||||
def _make_rows(self, *data_rows):
|
|
||||||
"""
|
|
||||||
Create or retrieve Row objects for each provided data row.
|
|
||||||
If the rows already exist in cache, updates the cached row.
|
|
||||||
"""
|
|
||||||
if self.row_cache is not None:
|
|
||||||
rows = []
|
|
||||||
for data_row in data_rows:
|
|
||||||
rowid = data_row[self.id_col]
|
|
||||||
|
|
||||||
cached_row = self.row_cache.get(rowid, None)
|
|
||||||
if cached_row is not None:
|
|
||||||
cached_row.data = data_row
|
|
||||||
row = cached_row
|
|
||||||
else:
|
|
||||||
row = Row(self, data_row)
|
|
||||||
self.row_cache[rowid] = row
|
|
||||||
rows.append(row)
|
|
||||||
else:
|
|
||||||
rows = [Row(self, data_row) for data_row in data_rows]
|
|
||||||
return rows
|
|
||||||
|
|
||||||
def create_row(self, *args, **kwargs):
|
|
||||||
data = self.insert(*args, **kwargs)
|
|
||||||
return self._make_rows(data)[0]
|
|
||||||
|
|
||||||
def fetch_rows_where(self, *args, **kwargs):
|
|
||||||
# TODO: Handle list of rowids here?
|
|
||||||
data = self.select_where(*args, **kwargs)
|
|
||||||
return self._make_rows(*data)
|
|
||||||
|
|
||||||
def fetch(self, rowid):
|
|
||||||
"""
|
|
||||||
Fetch the row with the given id, retrieving from cache where possible.
|
|
||||||
"""
|
|
||||||
row = self.row_cache.get(rowid, None) if self.row_cache is not None else None
|
|
||||||
if row is None:
|
|
||||||
rows = self.fetch_rows_where(**{self.id_col: rowid})
|
|
||||||
row = rows[0] if rows else None
|
|
||||||
return row
|
|
||||||
|
|
||||||
def fetch_or_create(self, rowid=None, **kwargs):
|
|
||||||
"""
|
|
||||||
Helper method to fetch a row with the given id or fields, or create it if it doesn't exist.
|
|
||||||
"""
|
|
||||||
if rowid is not None:
|
|
||||||
row = self.fetch(rowid)
|
|
||||||
else:
|
|
||||||
data = self.select_where(**kwargs)
|
|
||||||
row = self._make_rows(data[0])[0] if data else None
|
|
||||||
|
|
||||||
if row is None:
|
|
||||||
creation_kwargs = kwargs
|
|
||||||
if rowid is not None:
|
|
||||||
creation_kwargs[self.id_col] = rowid
|
|
||||||
row = self.create_row(**creation_kwargs)
|
|
||||||
return row
|
|
||||||
|
|
||||||
|
|
||||||
# --------------- Query Builders ---------------
|
|
||||||
def select_where(table, select_columns=None, cursor=None, _extra='', **conditions):
|
|
||||||
"""
|
|
||||||
Select rows from the given table matching the conditions
|
|
||||||
"""
|
|
||||||
criteria, criteria_values = _format_conditions(conditions)
|
|
||||||
col_str = _format_selectkeys(select_columns)
|
|
||||||
|
|
||||||
if conditions:
|
|
||||||
where_str = "WHERE {}".format(criteria)
|
|
||||||
else:
|
|
||||||
where_str = ""
|
|
||||||
|
|
||||||
cursor = cursor or conn.cursor()
|
|
||||||
cursor.execute(
|
|
||||||
'SELECT {} FROM {} {} {}'.format(col_str, table, where_str, _extra),
|
|
||||||
criteria_values
|
|
||||||
)
|
|
||||||
return cursor.fetchall()
|
|
||||||
|
|
||||||
|
|
||||||
def update_where(table, valuedict, cursor=None, **conditions):
|
|
||||||
"""
|
|
||||||
Update rows in the given table matching the conditions
|
|
||||||
"""
|
|
||||||
key_str, key_values = _format_updatestr(valuedict)
|
|
||||||
criteria, criteria_values = _format_conditions(conditions)
|
|
||||||
|
|
||||||
if conditions:
|
|
||||||
where_str = "WHERE {}".format(criteria)
|
|
||||||
else:
|
|
||||||
where_str = ""
|
|
||||||
|
|
||||||
cursor = cursor or conn.cursor()
|
|
||||||
cursor.execute(
|
|
||||||
'UPDATE {} SET {} {} RETURNING *'.format(table, key_str, where_str),
|
|
||||||
tuple((*key_values, *criteria_values))
|
|
||||||
)
|
|
||||||
return cursor.fetchall()
|
|
||||||
|
|
||||||
|
|
||||||
def delete_where(table, cursor=None, **conditions):
|
|
||||||
"""
|
|
||||||
Delete rows in the given table matching the conditions
|
|
||||||
"""
|
|
||||||
criteria, criteria_values = _format_conditions(conditions)
|
|
||||||
|
|
||||||
cursor = cursor or conn.cursor()
|
|
||||||
cursor.execute(
|
|
||||||
'DELETE FROM {} WHERE {}'.format(table, criteria),
|
|
||||||
criteria_values
|
|
||||||
)
|
|
||||||
return cursor.fetchall()
|
|
||||||
|
|
||||||
|
|
||||||
def insert(table, cursor=None, allow_replace=False, **values):
|
|
||||||
"""
|
|
||||||
Insert the given values into the table
|
|
||||||
"""
|
|
||||||
keys, values = zip(*values.items())
|
|
||||||
|
|
||||||
key_str = _format_insertkeys(keys)
|
|
||||||
value_str, values = _format_insertvalues(values)
|
|
||||||
|
|
||||||
action = 'REPLACE' if allow_replace else 'INSERT'
|
|
||||||
|
|
||||||
cursor = cursor or conn.cursor()
|
|
||||||
cursor.execute(
|
|
||||||
'{} INTO {} {} VALUES {} RETURNING *'.format(action, table, key_str, value_str),
|
|
||||||
values
|
|
||||||
)
|
|
||||||
return cursor.fetchone()
|
|
||||||
|
|
||||||
|
|
||||||
def insert_many(table, *value_tuples, insert_keys=None, cursor=None):
|
|
||||||
"""
|
|
||||||
Insert all the given values into the table
|
|
||||||
"""
|
|
||||||
key_str = _format_insertkeys(insert_keys)
|
|
||||||
value_strs, value_tuples = zip(*(_format_insertvalues(value_tuple) for value_tuple in value_tuples))
|
|
||||||
|
|
||||||
value_str = ", ".join(value_strs)
|
|
||||||
values = tuple(chain(*value_tuples))
|
|
||||||
|
|
||||||
cursor = cursor or conn.cursor()
|
|
||||||
cursor.execute(
|
|
||||||
'INSERT INTO {} {} VALUES {} RETURNING *'.format(table, key_str, value_str),
|
|
||||||
values
|
|
||||||
)
|
|
||||||
return cursor.fetchall()
|
|
||||||
|
|
||||||
|
|
||||||
def upsert(table, constraint, cursor=None, **values):
|
|
||||||
"""
|
|
||||||
Insert or on conflict update.
|
|
||||||
"""
|
|
||||||
valuedict = values
|
|
||||||
keys, values = zip(*values.items())
|
|
||||||
|
|
||||||
key_str = _format_insertkeys(keys)
|
|
||||||
value_str, values = _format_insertvalues(values)
|
|
||||||
update_key_str, update_key_values = _format_updatestr(valuedict)
|
|
||||||
|
|
||||||
if not isinstance(constraint, str):
|
|
||||||
constraint = ", ".join(constraint)
|
|
||||||
|
|
||||||
cursor = cursor or conn.cursor()
|
|
||||||
cursor.execute(
|
|
||||||
'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format(
|
|
||||||
table, key_str, value_str, constraint, update_key_str
|
|
||||||
),
|
|
||||||
tuple((*values, *update_key_values))
|
|
||||||
)
|
|
||||||
return cursor.fetchone()
|
|
||||||
|
|
||||||
|
|
||||||
# --------------- Query Formatting Tools ---------------
|
|
||||||
# Replace char used by the connection for query formatting
|
|
||||||
_replace_char: str = '%s'
|
|
||||||
|
|
||||||
|
|
||||||
class fieldConstants(Enum):
|
|
||||||
"""
|
|
||||||
A collection of database field constants to use for selection conditions.
|
|
||||||
"""
|
|
||||||
NULL = "IS NULL"
|
|
||||||
NOTNULL = "IS NOT NULL"
|
|
||||||
|
|
||||||
|
|
||||||
class _updateField:
|
|
||||||
__slots__ = ()
|
|
||||||
_EMPTY = object() # Return value for `value` indicating no value should be added
|
|
||||||
|
|
||||||
def key_field(self, key):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def value_field(self, key):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class UpdateValue(_updateField):
|
|
||||||
__slots__ = ('key_str', 'value')
|
|
||||||
|
|
||||||
def __init__(self, key_str, value=_updateField._EMPTY):
|
|
||||||
self.key_str = key_str
|
|
||||||
self.value = value
|
|
||||||
|
|
||||||
def key_field(self, key):
|
|
||||||
return self.key_str.format(key=key, value=_replace_char, replace=_replace_char)
|
|
||||||
|
|
||||||
def value_field(self, key):
|
|
||||||
return self.value
|
|
||||||
|
|
||||||
|
|
||||||
class UpdateValueAdd(_updateField):
|
|
||||||
__slots__ = ('value',)
|
|
||||||
|
|
||||||
def __init__(self, value):
|
|
||||||
self.value = value
|
|
||||||
|
|
||||||
def key_field(self, key):
|
|
||||||
return "{key} = {key} + {replace}".format(key=key, replace=_replace_char)
|
|
||||||
|
|
||||||
def value_field(self, key):
|
|
||||||
return self.value
|
|
||||||
|
|
||||||
|
|
||||||
def _format_conditions(conditions):
|
|
||||||
"""
|
|
||||||
Formats a dictionary of conditions into a string suitable for 'WHERE' clauses.
|
|
||||||
Supports `IN` type conditionals.
|
|
||||||
"""
|
|
||||||
if not conditions:
|
|
||||||
return ("", tuple())
|
|
||||||
|
|
||||||
values = []
|
|
||||||
conditional_strings = []
|
|
||||||
for key, item in conditions.items():
|
|
||||||
if isinstance(item, (list, tuple)):
|
|
||||||
conditional_strings.append("{} IN ({})".format(key, ", ".join([_replace_char] * len(item))))
|
|
||||||
values.extend(item)
|
|
||||||
elif isinstance(item, fieldConstants):
|
|
||||||
conditional_strings.append("{} {}".format(key, item.value))
|
|
||||||
else:
|
|
||||||
conditional_strings.append("{}={}".format(key, _replace_char))
|
|
||||||
values.append(item)
|
|
||||||
|
|
||||||
return (' AND '.join(conditional_strings), values)
|
|
||||||
|
|
||||||
|
|
||||||
def _format_selectkeys(keys):
|
|
||||||
"""
|
|
||||||
Formats a list of keys into a string suitable for `SELECT`.
|
|
||||||
"""
|
|
||||||
if not keys:
|
|
||||||
return "*"
|
|
||||||
else:
|
|
||||||
return ", ".join(keys)
|
|
||||||
|
|
||||||
|
|
||||||
def _format_insertkeys(keys):
|
|
||||||
"""
|
|
||||||
Formats a list of keys into a string suitable for `INSERT`
|
|
||||||
"""
|
|
||||||
if not keys:
|
|
||||||
return ""
|
|
||||||
else:
|
|
||||||
return "({})".format(", ".join(keys))
|
|
||||||
|
|
||||||
|
|
||||||
def _format_insertvalues(values):
|
|
||||||
"""
|
|
||||||
Formats a list of values into a string suitable for `INSERT`
|
|
||||||
"""
|
|
||||||
value_str = "({})".format(", ".join(_replace_char for value in values))
|
|
||||||
return (value_str, values)
|
|
||||||
|
|
||||||
|
|
||||||
def _format_updatestr(valuedict):
|
|
||||||
"""
|
|
||||||
Formats a dictionary of keys and values into a string suitable for 'SET' clauses.
|
|
||||||
"""
|
|
||||||
if not valuedict:
|
|
||||||
return ("", tuple())
|
|
||||||
|
|
||||||
key_fields = []
|
|
||||||
values = []
|
|
||||||
for key, value in valuedict.items():
|
|
||||||
if isinstance(value, _updateField):
|
|
||||||
key_fields.append(value.key_field(key))
|
|
||||||
v = value.value_field(key)
|
|
||||||
if v is not _updateField._EMPTY:
|
|
||||||
values.append(value.value_field(key))
|
|
||||||
else:
|
|
||||||
key_fields.append("{} = {}".format(key, _replace_char))
|
|
||||||
values.append(value)
|
|
||||||
|
|
||||||
return (', '.join(key_fields), values)
|
|
||||||
113
bot/data/formatters.py
Normal file
113
bot/data/formatters.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
from .connection import _replace_char
|
||||||
|
from .conditions import Condition
|
||||||
|
|
||||||
|
|
||||||
|
class _updateField:
|
||||||
|
__slots__ = ()
|
||||||
|
_EMPTY = object() # Return value for `value` indicating no value should be added
|
||||||
|
|
||||||
|
def key_field(self, key):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def value_field(self, key):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateValue(_updateField):
|
||||||
|
__slots__ = ('key_str', 'value')
|
||||||
|
|
||||||
|
def __init__(self, key_str, value=_updateField._EMPTY):
|
||||||
|
self.key_str = key_str
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def key_field(self, key):
|
||||||
|
return self.key_str.format(key=key, value=_replace_char, replace=_replace_char)
|
||||||
|
|
||||||
|
def value_field(self, key):
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateValueAdd(_updateField):
|
||||||
|
__slots__ = ('value',)
|
||||||
|
|
||||||
|
def __init__(self, value):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def key_field(self, key):
|
||||||
|
return "{key} = {key} + {replace}".format(key=key, replace=_replace_char)
|
||||||
|
|
||||||
|
def value_field(self, key):
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
|
||||||
|
def _format_conditions(conditions):
|
||||||
|
"""
|
||||||
|
Formats a dictionary of conditions into a string suitable for 'WHERE' clauses.
|
||||||
|
Supports `IN` type conditionals.
|
||||||
|
"""
|
||||||
|
if not conditions:
|
||||||
|
return ("", tuple())
|
||||||
|
|
||||||
|
values = []
|
||||||
|
conditional_strings = []
|
||||||
|
for key, item in conditions.items():
|
||||||
|
if isinstance(item, (list, tuple)):
|
||||||
|
conditional_strings.append("{} IN ({})".format(key, ", ".join([_replace_char] * len(item))))
|
||||||
|
values.extend(item)
|
||||||
|
elif isinstance(item, Condition):
|
||||||
|
item.apply(key, values, conditional_strings)
|
||||||
|
else:
|
||||||
|
conditional_strings.append("{}={}".format(key, _replace_char))
|
||||||
|
values.append(item)
|
||||||
|
|
||||||
|
return (' AND '.join(conditional_strings), values)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_selectkeys(keys):
|
||||||
|
"""
|
||||||
|
Formats a list of keys into a string suitable for `SELECT`.
|
||||||
|
"""
|
||||||
|
if not keys:
|
||||||
|
return "*"
|
||||||
|
else:
|
||||||
|
return ", ".join(keys)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_insertkeys(keys):
|
||||||
|
"""
|
||||||
|
Formats a list of keys into a string suitable for `INSERT`
|
||||||
|
"""
|
||||||
|
if not keys:
|
||||||
|
return ""
|
||||||
|
else:
|
||||||
|
return "({})".format(", ".join(keys))
|
||||||
|
|
||||||
|
|
||||||
|
def _format_insertvalues(values):
|
||||||
|
"""
|
||||||
|
Formats a list of values into a string suitable for `INSERT`
|
||||||
|
"""
|
||||||
|
value_str = "({})".format(", ".join(_replace_char for value in values))
|
||||||
|
return (value_str, values)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_updatestr(valuedict):
|
||||||
|
"""
|
||||||
|
Formats a dictionary of keys and values into a string suitable for 'SET' clauses.
|
||||||
|
"""
|
||||||
|
if not valuedict:
|
||||||
|
return ("", tuple())
|
||||||
|
|
||||||
|
key_fields = []
|
||||||
|
values = []
|
||||||
|
for key, value in valuedict.items():
|
||||||
|
if isinstance(value, _updateField):
|
||||||
|
key_fields.append(value.key_field(key))
|
||||||
|
v = value.value_field(key)
|
||||||
|
if v is not _updateField._EMPTY:
|
||||||
|
values.append(value.value_field(key))
|
||||||
|
else:
|
||||||
|
key_fields.append("{} = {}".format(key, _replace_char))
|
||||||
|
values.append(value)
|
||||||
|
|
||||||
|
return (', '.join(key_fields), values)
|
||||||
282
bot/data/interfaces.py
Normal file
282
bot/data/interfaces.py
Normal file
@@ -0,0 +1,282 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
from cachetools import LRUCache
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
|
from utils.lib import DotDict
|
||||||
|
|
||||||
|
from .connection import conn
|
||||||
|
from .queries import insert, insert_many, select_where, update_where, upsert, delete_where, update_many
|
||||||
|
|
||||||
|
|
||||||
|
# Global cache of interfaces
|
||||||
|
tables: Mapping[str, Table] = DotDict()
|
||||||
|
|
||||||
|
|
||||||
|
class Table:
|
||||||
|
"""
|
||||||
|
Transparent interface to a single table structure in the database.
|
||||||
|
Contains standard methods to access the table.
|
||||||
|
Intended to be subclassed to provide more derivative access for specific tables.
|
||||||
|
"""
|
||||||
|
conn = conn
|
||||||
|
queries = DotDict()
|
||||||
|
|
||||||
|
def __init__(self, name, attach_as=None):
|
||||||
|
self.name = name
|
||||||
|
tables[attach_as or name] = self
|
||||||
|
|
||||||
|
def select_where(self, *args, **kwargs):
|
||||||
|
with self.conn:
|
||||||
|
return select_where(self.name, *args, **kwargs)
|
||||||
|
|
||||||
|
def select_one_where(self, *args, **kwargs):
|
||||||
|
rows = self.select_where(*args, **kwargs)
|
||||||
|
return rows[0] if rows else None
|
||||||
|
|
||||||
|
def update_where(self, *args, **kwargs):
|
||||||
|
with self.conn:
|
||||||
|
return update_where(self.name, *args, **kwargs)
|
||||||
|
|
||||||
|
def delete_where(self, *args, **kwargs):
|
||||||
|
with self.conn:
|
||||||
|
return delete_where(self.name, *args, **kwargs)
|
||||||
|
|
||||||
|
def insert(self, *args, **kwargs):
|
||||||
|
with self.conn:
|
||||||
|
return insert(self.name, *args, **kwargs)
|
||||||
|
|
||||||
|
def insert_many(self, *args, **kwargs):
|
||||||
|
with self.conn:
|
||||||
|
return insert_many(self.name, *args, **kwargs)
|
||||||
|
|
||||||
|
def update_many(self, *args, **kwargs):
|
||||||
|
with self.conn:
|
||||||
|
return update_many(self.name, *args, **kwargs)
|
||||||
|
|
||||||
|
def upsert(self, *args, **kwargs):
|
||||||
|
with self.conn:
|
||||||
|
return upsert(self.name, *args, **kwargs)
|
||||||
|
|
||||||
|
def save_query(self, func):
|
||||||
|
"""
|
||||||
|
Decorator to add a saved query to the table.
|
||||||
|
"""
|
||||||
|
self.queries[func.__name__] = func
|
||||||
|
|
||||||
|
|
||||||
|
class Row:
|
||||||
|
__slots__ = ('table', 'data', '_pending')
|
||||||
|
|
||||||
|
conn = conn
|
||||||
|
|
||||||
|
def __init__(self, table, data, *args, **kwargs):
|
||||||
|
super().__setattr__('table', table)
|
||||||
|
self.data = data
|
||||||
|
self._pending = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rowid(self):
|
||||||
|
return self.table.id_from_row(self.data)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "Row[{}]({})".format(
|
||||||
|
self.table.name,
|
||||||
|
', '.join("{}={!r}".format(field, getattr(self, field)) for field in self.table.columns)
|
||||||
|
)
|
||||||
|
|
||||||
|
def __getattr__(self, key):
|
||||||
|
if key in self.table.columns:
|
||||||
|
if self._pending and key in self._pending:
|
||||||
|
return self._pending[key]
|
||||||
|
else:
|
||||||
|
return self.data[key]
|
||||||
|
else:
|
||||||
|
raise AttributeError(key)
|
||||||
|
|
||||||
|
def __setattr__(self, key, value):
|
||||||
|
if key in self.table.columns:
|
||||||
|
if self._pending is None:
|
||||||
|
self.update(**{key: value})
|
||||||
|
else:
|
||||||
|
self._pending[key] = value
|
||||||
|
else:
|
||||||
|
super().__setattr__(key, value)
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def batch_update(self):
|
||||||
|
if self._pending:
|
||||||
|
raise ValueError("Nested batch updates for {}!".format(self.__class__.__name__))
|
||||||
|
|
||||||
|
self._pending = {}
|
||||||
|
try:
|
||||||
|
yield self._pending
|
||||||
|
finally:
|
||||||
|
self.update(**self._pending)
|
||||||
|
self._pending = None
|
||||||
|
|
||||||
|
def _refresh(self):
|
||||||
|
row = self.table.select_one_where(self.table.dict_from_id(self.rowid))
|
||||||
|
if not row:
|
||||||
|
raise ValueError("Refreshing a {} which no longer exists!".format(type(self).__name__))
|
||||||
|
self.data = row
|
||||||
|
|
||||||
|
def update(self, **values):
|
||||||
|
rows = self.table.update_where(values, **self.table.dict_from_id(self.rowid))
|
||||||
|
self.data = rows[0]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _select_where(cls, _extra=None, **conditions):
|
||||||
|
return select_where(cls._table, **conditions)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _insert(cls, **values):
|
||||||
|
return insert(cls._table, **values)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _update_where(cls, values, **conditions):
|
||||||
|
return update_where(cls._table, values, **conditions)
|
||||||
|
|
||||||
|
|
||||||
|
class RowTable(Table):
|
||||||
|
__slots__ = (
|
||||||
|
'name',
|
||||||
|
'columns',
|
||||||
|
'id_col',
|
||||||
|
'multi_key',
|
||||||
|
'row_cache'
|
||||||
|
)
|
||||||
|
|
||||||
|
conn = conn
|
||||||
|
|
||||||
|
def __init__(self, name, columns, id_col, use_cache=True, cache=None, cache_size=1000, **kwargs):
|
||||||
|
super().__init__(name, **kwargs)
|
||||||
|
self.name = name
|
||||||
|
self.columns = columns
|
||||||
|
self.id_col = id_col
|
||||||
|
self.multi_key = isinstance(id_col, tuple)
|
||||||
|
self.row_cache = (cache or LRUCache(cache_size)) if use_cache else None
|
||||||
|
|
||||||
|
def id_from_row(self, row):
|
||||||
|
if self.multi_key:
|
||||||
|
return tuple(row[key] for key in self.id_col)
|
||||||
|
else:
|
||||||
|
return row[self.id_col]
|
||||||
|
|
||||||
|
def dict_from_id(self, rowid):
|
||||||
|
if self.multi_key:
|
||||||
|
return dict(zip(self.id_col, rowid))
|
||||||
|
else:
|
||||||
|
return {self.id_col: rowid}
|
||||||
|
|
||||||
|
# Extend original Table update methods to modify the cached rows
|
||||||
|
def insert(self, *args, **kwargs):
|
||||||
|
data = super().insert(*args, **kwargs)
|
||||||
|
if self.row_cache is not None:
|
||||||
|
self.row_cache[self.id_from_row(data)] = Row(self, data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def insert_many(self, *args, **kwargs):
|
||||||
|
data = super().insert_many(*args, **kwargs)
|
||||||
|
if self.row_cache is not None:
|
||||||
|
for data_row in data:
|
||||||
|
cached_row = self.row_cache.get(self.id_from_row(data_row), None)
|
||||||
|
if cached_row is not None:
|
||||||
|
cached_row.data = data_row
|
||||||
|
return data
|
||||||
|
|
||||||
|
def update_where(self, *args, **kwargs):
|
||||||
|
data = super().update_where(*args, **kwargs)
|
||||||
|
if self.row_cache is not None:
|
||||||
|
for data_row in data:
|
||||||
|
cached_row = self.row_cache.get(self.id_from_row(data_row), None)
|
||||||
|
if cached_row is not None:
|
||||||
|
cached_row.data = data_row
|
||||||
|
return data
|
||||||
|
|
||||||
|
def update_many(self, *args, **kwargs):
|
||||||
|
data = super().update_many(*args, **kwargs)
|
||||||
|
if self.row_cache is not None:
|
||||||
|
for data_row in data:
|
||||||
|
cached_row = self.row_cache.get(self.id_from_row(data_row), None)
|
||||||
|
if cached_row is not None:
|
||||||
|
cached_row.data = data_row
|
||||||
|
return data
|
||||||
|
|
||||||
|
def delete_where(self, *args, **kwargs):
|
||||||
|
data = super().delete_where(*args, **kwargs)
|
||||||
|
if self.row_cache is not None:
|
||||||
|
for data_row in data:
|
||||||
|
self.row_cache.pop(self.id_from_row(data_row), None)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def upsert(self, *args, **kwargs):
|
||||||
|
data = super().upsert(*args, **kwargs)
|
||||||
|
if self.row_cache is not None:
|
||||||
|
rowid = self.id_from_row(data)
|
||||||
|
cached_row = self.row_cache.get(rowid, None)
|
||||||
|
if cached_row is not None:
|
||||||
|
cached_row.data = data
|
||||||
|
else:
|
||||||
|
self.row_cache[rowid] = Row(self, data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
# New methods to fetch and create rows
|
||||||
|
def _make_rows(self, *data_rows):
|
||||||
|
"""
|
||||||
|
Create or retrieve Row objects for each provided data row.
|
||||||
|
If the rows already exist in cache, updates the cached row.
|
||||||
|
"""
|
||||||
|
if self.row_cache is not None:
|
||||||
|
rows = []
|
||||||
|
for data_row in data_rows:
|
||||||
|
rowid = self.id_from_row(data_row)
|
||||||
|
|
||||||
|
cached_row = self.row_cache.get(rowid, None)
|
||||||
|
if cached_row is not None:
|
||||||
|
cached_row.data = data_row
|
||||||
|
row = cached_row
|
||||||
|
else:
|
||||||
|
row = Row(self, data_row)
|
||||||
|
self.row_cache[rowid] = row
|
||||||
|
rows.append(row)
|
||||||
|
else:
|
||||||
|
rows = [Row(self, data_row) for data_row in data_rows]
|
||||||
|
return rows
|
||||||
|
|
||||||
|
def create_row(self, *args, **kwargs):
|
||||||
|
data = self.insert(*args, **kwargs)
|
||||||
|
return self._make_rows(data)[0]
|
||||||
|
|
||||||
|
def fetch_rows_where(self, *args, **kwargs):
|
||||||
|
# TODO: Handle list of rowids here?
|
||||||
|
data = self.select_where(*args, **kwargs)
|
||||||
|
return self._make_rows(*data)
|
||||||
|
|
||||||
|
def fetch(self, rowid):
|
||||||
|
"""
|
||||||
|
Fetch the row with the given id, retrieving from cache where possible.
|
||||||
|
"""
|
||||||
|
row = self.row_cache.get(rowid, None) if self.row_cache is not None else None
|
||||||
|
if row is None:
|
||||||
|
rows = self.fetch_rows_where(**self.dict_from_id(rowid))
|
||||||
|
row = rows[0] if rows else None
|
||||||
|
return row
|
||||||
|
|
||||||
|
def fetch_or_create(self, rowid=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Helper method to fetch a row with the given id or fields, or create it if it doesn't exist.
|
||||||
|
"""
|
||||||
|
if rowid is not None:
|
||||||
|
row = self.fetch(rowid)
|
||||||
|
else:
|
||||||
|
data = self.select_where(**kwargs)
|
||||||
|
row = self._make_rows(data[0])[0] if data else None
|
||||||
|
|
||||||
|
if row is None:
|
||||||
|
creation_kwargs = kwargs
|
||||||
|
if rowid is not None:
|
||||||
|
creation_kwargs.update(self.dict_from_id(rowid))
|
||||||
|
row = self.create_row(**creation_kwargs)
|
||||||
|
return row
|
||||||
149
bot/data/queries.py
Normal file
149
bot/data/queries.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
from itertools import chain
|
||||||
|
from psycopg2.extras import execute_values
|
||||||
|
|
||||||
|
from .connection import conn
|
||||||
|
from .formatters import (_format_updatestr, _format_conditions, _format_insertkeys,
|
||||||
|
_format_selectkeys, _format_insertvalues)
|
||||||
|
|
||||||
|
|
||||||
|
def select_where(table, select_columns=None, cursor=None, _extra='', **conditions):
|
||||||
|
"""
|
||||||
|
Select rows from the given table matching the conditions
|
||||||
|
"""
|
||||||
|
criteria, criteria_values = _format_conditions(conditions)
|
||||||
|
col_str = _format_selectkeys(select_columns)
|
||||||
|
|
||||||
|
if criteria:
|
||||||
|
where_str = "WHERE {}".format(criteria)
|
||||||
|
else:
|
||||||
|
where_str = ""
|
||||||
|
|
||||||
|
cursor = cursor or conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
'SELECT {} FROM {} {} {}'.format(col_str, table, where_str, _extra),
|
||||||
|
criteria_values
|
||||||
|
)
|
||||||
|
return cursor.fetchall()
|
||||||
|
|
||||||
|
|
||||||
|
def update_where(table, valuedict, cursor=None, **conditions):
|
||||||
|
"""
|
||||||
|
Update rows in the given table matching the conditions
|
||||||
|
"""
|
||||||
|
key_str, key_values = _format_updatestr(valuedict)
|
||||||
|
criteria, criteria_values = _format_conditions(conditions)
|
||||||
|
|
||||||
|
if criteria:
|
||||||
|
where_str = "WHERE {}".format(criteria)
|
||||||
|
else:
|
||||||
|
where_str = ""
|
||||||
|
|
||||||
|
cursor = cursor or conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
'UPDATE {} SET {} {} RETURNING *'.format(table, key_str, where_str),
|
||||||
|
tuple((*key_values, *criteria_values))
|
||||||
|
)
|
||||||
|
return cursor.fetchall()
|
||||||
|
|
||||||
|
|
||||||
|
def delete_where(table, cursor=None, **conditions):
|
||||||
|
"""
|
||||||
|
Delete rows in the given table matching the conditions
|
||||||
|
"""
|
||||||
|
criteria, criteria_values = _format_conditions(conditions)
|
||||||
|
|
||||||
|
if criteria:
|
||||||
|
where_str = "WHERE {}".format(criteria)
|
||||||
|
else:
|
||||||
|
where_str = ""
|
||||||
|
|
||||||
|
cursor = cursor or conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
'DELETE FROM {} {} RETURNING *'.format(table, where_str),
|
||||||
|
criteria_values
|
||||||
|
)
|
||||||
|
return cursor.fetchall()
|
||||||
|
|
||||||
|
|
||||||
|
def insert(table, cursor=None, allow_replace=False, **values):
|
||||||
|
"""
|
||||||
|
Insert the given values into the table
|
||||||
|
"""
|
||||||
|
keys, values = zip(*values.items())
|
||||||
|
|
||||||
|
key_str = _format_insertkeys(keys)
|
||||||
|
value_str, values = _format_insertvalues(values)
|
||||||
|
|
||||||
|
action = 'REPLACE' if allow_replace else 'INSERT'
|
||||||
|
|
||||||
|
cursor = cursor or conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
'{} INTO {} {} VALUES {} RETURNING *'.format(action, table, key_str, value_str),
|
||||||
|
values
|
||||||
|
)
|
||||||
|
return cursor.fetchone()
|
||||||
|
|
||||||
|
|
||||||
|
def insert_many(table, *value_tuples, insert_keys=None, cursor=None):
|
||||||
|
"""
|
||||||
|
Insert all the given values into the table
|
||||||
|
"""
|
||||||
|
key_str = _format_insertkeys(insert_keys)
|
||||||
|
value_strs, value_tuples = zip(*(_format_insertvalues(value_tuple) for value_tuple in value_tuples))
|
||||||
|
|
||||||
|
value_str = ", ".join(value_strs)
|
||||||
|
values = tuple(chain(*value_tuples))
|
||||||
|
|
||||||
|
cursor = cursor or conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
'INSERT INTO {} {} VALUES {} RETURNING *'.format(table, key_str, value_str),
|
||||||
|
values
|
||||||
|
)
|
||||||
|
return cursor.fetchall()
|
||||||
|
|
||||||
|
|
||||||
|
def upsert(table, constraint, cursor=None, **values):
|
||||||
|
"""
|
||||||
|
Insert or on conflict update.
|
||||||
|
"""
|
||||||
|
valuedict = values
|
||||||
|
keys, values = zip(*values.items())
|
||||||
|
|
||||||
|
key_str = _format_insertkeys(keys)
|
||||||
|
value_str, values = _format_insertvalues(values)
|
||||||
|
update_key_str, update_key_values = _format_updatestr(valuedict)
|
||||||
|
|
||||||
|
if not isinstance(constraint, str):
|
||||||
|
constraint = ", ".join(constraint)
|
||||||
|
|
||||||
|
cursor = cursor or conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format(
|
||||||
|
table, key_str, value_str, constraint, update_key_str
|
||||||
|
),
|
||||||
|
tuple((*values, *update_key_values))
|
||||||
|
)
|
||||||
|
return cursor.fetchone()
|
||||||
|
|
||||||
|
|
||||||
|
def update_many(table, *values, set_keys=None, where_keys=None, cursor=None):
|
||||||
|
cursor = cursor or conn.cursor()
|
||||||
|
|
||||||
|
return execute_values(
|
||||||
|
cursor,
|
||||||
|
"""
|
||||||
|
UPDATE {table}
|
||||||
|
SET {set_clause}
|
||||||
|
FROM (VALUES %s)
|
||||||
|
AS {temp_table}
|
||||||
|
WHERE {where_clause}
|
||||||
|
RETURNING *
|
||||||
|
""".format(
|
||||||
|
table=table,
|
||||||
|
set_clause=', '.join("{0} = _t.{0}".format(key) for key in set_keys),
|
||||||
|
where_clause=' AND '.join("{1}.{0} = _t.{0}".format(key, table) for key in where_keys),
|
||||||
|
temp_table="_t ({})".format(', '.join(set_keys + where_keys))
|
||||||
|
),
|
||||||
|
values,
|
||||||
|
fetch=True
|
||||||
|
)
|
||||||
@@ -6,6 +6,9 @@ import core # noqa
|
|||||||
|
|
||||||
import modules # noqa
|
import modules # noqa
|
||||||
|
|
||||||
|
# Load and attach app specific data
|
||||||
|
client.appdata = core.data.meta.fetch_or_create(conf.bot['data_appid'])
|
||||||
|
|
||||||
# Initialise all modules
|
# Initialise all modules
|
||||||
client.initialise_modules()
|
client.initialise_modules()
|
||||||
|
|
||||||
|
|||||||
@@ -14,10 +14,34 @@ from .config import conf
|
|||||||
# Setup the logger
|
# Setup the logger
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
log_fmt = logging.Formatter(fmt='[{asctime}][{levelname:^8}] {message}', datefmt='%d/%m | %H:%M:%S', style='{')
|
log_fmt = logging.Formatter(fmt='[{asctime}][{levelname:^8}] {message}', datefmt='%d/%m | %H:%M:%S', style='{')
|
||||||
term_handler = logging.StreamHandler(sys.stdout)
|
# term_handler = logging.StreamHandler(sys.stdout)
|
||||||
term_handler.setFormatter(log_fmt)
|
# term_handler.setFormatter(log_fmt)
|
||||||
logger.addHandler(term_handler)
|
# logger.addHandler(term_handler)
|
||||||
logger.setLevel(logging.INFO)
|
# logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
class LessThanFilter(logging.Filter):
|
||||||
|
def __init__(self, exclusive_maximum, name=""):
|
||||||
|
super(LessThanFilter, self).__init__(name)
|
||||||
|
self.max_level = exclusive_maximum
|
||||||
|
|
||||||
|
def filter(self, record):
|
||||||
|
# non-zero return means we log this message
|
||||||
|
return 1 if record.levelno < self.max_level else 0
|
||||||
|
|
||||||
|
|
||||||
|
logger.setLevel(logging.NOTSET)
|
||||||
|
|
||||||
|
logging_handler_out = logging.StreamHandler(sys.stdout)
|
||||||
|
logging_handler_out.setLevel(logging.DEBUG)
|
||||||
|
logging_handler_out.setFormatter(log_fmt)
|
||||||
|
logging_handler_out.addFilter(LessThanFilter(logging.WARNING))
|
||||||
|
logger.addHandler(logging_handler_out)
|
||||||
|
|
||||||
|
logging_handler_err = logging.StreamHandler(sys.stderr)
|
||||||
|
logging_handler_err.setLevel(logging.WARNING)
|
||||||
|
logging_handler_err.setFormatter(log_fmt)
|
||||||
|
logger.addHandler(logging_handler_err)
|
||||||
|
|
||||||
|
|
||||||
# Define the context log format and attach it to the command logger as well
|
# Define the context log format and attach it to the command logger as well
|
||||||
|
|||||||
@@ -1,3 +1,9 @@
|
|||||||
from .sysadmin import *
|
from .sysadmin import *
|
||||||
|
from .guild_admin import *
|
||||||
|
from .meta import *
|
||||||
from .economy import *
|
from .economy import *
|
||||||
from .study import *
|
from .study import *
|
||||||
|
from .user_config import *
|
||||||
|
from .workout import *
|
||||||
|
from .todo import *
|
||||||
|
# from .moderation import *
|
||||||
|
|||||||
@@ -1,2 +1,3 @@
|
|||||||
from . import module
|
from .module import module
|
||||||
|
|
||||||
from . import commands
|
from . import commands
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
from core import User
|
from cmdClient.checks import in_guild
|
||||||
from core.tables import users
|
|
||||||
|
|
||||||
|
import data
|
||||||
|
from data import tables
|
||||||
|
from core import Lion
|
||||||
from utils import interactive # noqa
|
from utils import interactive # noqa
|
||||||
|
|
||||||
from .module import module
|
from .module import module
|
||||||
@@ -11,17 +13,19 @@ second_emoji = "🥈"
|
|||||||
third_emoji = "🥉"
|
third_emoji = "🥉"
|
||||||
|
|
||||||
|
|
||||||
# TODO: in_guild ward
|
|
||||||
@module.cmd(
|
@module.cmd(
|
||||||
"topcoin",
|
"cointop",
|
||||||
short_help="View the LionCoin leaderboard.",
|
group="Statistics",
|
||||||
aliases=('topc', 'ctop')
|
desc="View the LionCoin leaderboard.",
|
||||||
|
aliases=('topc', 'ctop', 'topcoins', 'topcoin', 'cointop100'),
|
||||||
|
help_aliases={'cointop100': "View the LionCoin top 100."}
|
||||||
)
|
)
|
||||||
|
@in_guild()
|
||||||
async def cmd_topcoin(ctx):
|
async def cmd_topcoin(ctx):
|
||||||
"""
|
"""
|
||||||
Usage``:
|
Usage``:
|
||||||
{prefix}topcoin
|
{prefix}cointop
|
||||||
{prefix}topcoin 100
|
{prefix}cointop 100
|
||||||
Description:
|
Description:
|
||||||
Display the LionCoin leaderboard, or top 100.
|
Display the LionCoin leaderboard, or top 100.
|
||||||
|
|
||||||
@@ -30,15 +34,17 @@ async def cmd_topcoin(ctx):
|
|||||||
# Handle args
|
# Handle args
|
||||||
if ctx.args and not ctx.args == "100":
|
if ctx.args and not ctx.args == "100":
|
||||||
return await ctx.error_reply(
|
return await ctx.error_reply(
|
||||||
"**Usage:**`{prefix}topcoin` or `{prefix}topcoin100`.".format(prefix=ctx.client.prefix)
|
"**Usage:**`{prefix}topcoin` or `{prefix}topcoin100`.".format(prefix=ctx.best_prefix)
|
||||||
)
|
)
|
||||||
top100 = ctx.args == "100"
|
top100 = (ctx.args == "100" or ctx.alias == "contop100")
|
||||||
|
|
||||||
# Flush any pending coin transactions
|
# Flush any pending coin transactions
|
||||||
User.sync()
|
Lion.sync()
|
||||||
|
|
||||||
# Fetch the leaderboard
|
# Fetch the leaderboard
|
||||||
user_data = users.select_where(
|
user_data = tables.lions.select_where(
|
||||||
|
guildid=ctx.guild.id,
|
||||||
|
userid=data.NOT([m.id for m in ctx.guild_settings.unranked_roles.members]),
|
||||||
select_columns=('userid', 'coins'),
|
select_columns=('userid', 'coins'),
|
||||||
_extra="ORDER BY coins DESC " + ("LIMIT 100" if top100 else "")
|
_extra="ORDER BY coins DESC " + ("LIMIT 100" if top100 else "")
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from cmdClient import Module
|
from LionModule import LionModule
|
||||||
|
|
||||||
|
|
||||||
module = Module("Economy")
|
module = LionModule("Economy")
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
from .module import module
|
||||||
|
|
||||||
|
from . import help
|
||||||
|
|||||||
@@ -0,0 +1,5 @@
|
|||||||
|
from .module import module
|
||||||
|
|
||||||
|
from . import admin
|
||||||
|
# from . import video_channels
|
||||||
|
from . import Ticket
|
||||||
|
|||||||
@@ -1,2 +1,9 @@
|
|||||||
from .module import module
|
from .module import module
|
||||||
from . import commands
|
|
||||||
|
from . import data
|
||||||
|
from . import admin
|
||||||
|
from . import badge_tracker
|
||||||
|
from . import time_tracker
|
||||||
|
from . import top_cmd
|
||||||
|
from . import studybadge_cmd
|
||||||
|
from . import stats_cmd
|
||||||
|
|||||||
@@ -1,111 +0,0 @@
|
|||||||
import datetime as dt
|
|
||||||
|
|
||||||
from core import User
|
|
||||||
from core.tables import users
|
|
||||||
|
|
||||||
from utils import interactive # noqa
|
|
||||||
|
|
||||||
from .module import module
|
|
||||||
|
|
||||||
|
|
||||||
first_emoji = "🥇"
|
|
||||||
second_emoji = "🥈"
|
|
||||||
third_emoji = "🥉"
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: in_guild ward
|
|
||||||
@module.cmd(
|
|
||||||
"top",
|
|
||||||
short_help="View the Study Time leaderboard.",
|
|
||||||
aliases=('ttop', 'toptime')
|
|
||||||
)
|
|
||||||
async def cmd_top(ctx):
|
|
||||||
"""
|
|
||||||
Usage``:
|
|
||||||
{prefix}top
|
|
||||||
{prefix}top 100
|
|
||||||
Description:
|
|
||||||
Display the study time leaderboard, or the top 100.
|
|
||||||
|
|
||||||
Use the paging reactions or send `p<n>` to switch pages (e.g. `p11` to switch to page 11).
|
|
||||||
"""
|
|
||||||
# Handle args
|
|
||||||
if ctx.args and not ctx.args == "100":
|
|
||||||
return await ctx.error_reply(
|
|
||||||
"**Usage:**`{prefix}top` or `{prefix}top100`.".format(prefix=ctx.client.prefix)
|
|
||||||
)
|
|
||||||
top100 = ctx.args == "100"
|
|
||||||
|
|
||||||
# Flush any pending coin transactions
|
|
||||||
User.sync()
|
|
||||||
|
|
||||||
# Fetch the leaderboard
|
|
||||||
user_data = users.select_where(
|
|
||||||
select_columns=('userid', 'tracked_time'),
|
|
||||||
_extra="ORDER BY tracked_time DESC " + ("LIMIT 100" if top100 else "")
|
|
||||||
)
|
|
||||||
|
|
||||||
# Quit early if the leaderboard is empty
|
|
||||||
if not user_data:
|
|
||||||
return await ctx.reply("No leaderboard entries yet!")
|
|
||||||
|
|
||||||
# Extract entries
|
|
||||||
author_index = None
|
|
||||||
entries = []
|
|
||||||
for i, (userid, time) in enumerate(user_data):
|
|
||||||
member = ctx.guild.get_member(userid)
|
|
||||||
name = member.display_name if member else str(userid)
|
|
||||||
name = name.replace('*', ' ').replace('_', ' ')
|
|
||||||
|
|
||||||
num_str = "{}.".format(i+1)
|
|
||||||
|
|
||||||
hours = time // 3600
|
|
||||||
minutes = time // 60 % 60
|
|
||||||
seconds = time % 60
|
|
||||||
|
|
||||||
time_str = "{}:{:02}:{:02}".format(
|
|
||||||
hours,
|
|
||||||
minutes,
|
|
||||||
seconds
|
|
||||||
)
|
|
||||||
|
|
||||||
if ctx.author.id == userid:
|
|
||||||
author_index = i
|
|
||||||
|
|
||||||
entries.append((num_str, name, time_str))
|
|
||||||
|
|
||||||
# Extract blocks
|
|
||||||
blocks = [entries[i:i+20] for i in range(0, len(entries), 20)]
|
|
||||||
block_count = len(blocks)
|
|
||||||
|
|
||||||
# Build strings
|
|
||||||
header = "Study Time Top 100" if top100 else "Study Time Leaderboard"
|
|
||||||
if block_count > 1:
|
|
||||||
header += " (Page {{page}}/{})".format(block_count)
|
|
||||||
|
|
||||||
# Build pages
|
|
||||||
pages = []
|
|
||||||
for i, block in enumerate(blocks):
|
|
||||||
max_num_l, max_name_l, max_time_l = [max(len(e[i]) for e in block) for i in (0, 1, 2)]
|
|
||||||
body = '\n'.join(
|
|
||||||
"{:>{}} {:<{}} \t {:>{}} {} {}".format(
|
|
||||||
entry[0], max_num_l,
|
|
||||||
entry[1], max_name_l + 2,
|
|
||||||
entry[2], max_time_l + 1,
|
|
||||||
first_emoji if i == 0 and j == 0 else (
|
|
||||||
second_emoji if i == 0 and j == 1 else (
|
|
||||||
third_emoji if i == 0 and j == 2 else ''
|
|
||||||
)
|
|
||||||
),
|
|
||||||
"⮜" if author_index is not None and author_index == i * 20 + j else ""
|
|
||||||
)
|
|
||||||
for j, entry in enumerate(block)
|
|
||||||
)
|
|
||||||
title = header.format(page=i+1)
|
|
||||||
line = '='*len(title)
|
|
||||||
pages.append(
|
|
||||||
"```md\n{}\n{}\n{}```".format(title, line, body)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Finally, page the results
|
|
||||||
await ctx.pager(pages, start_at=(author_index or 0)//20 if not top100 else 0)
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
from cmdClient import Module
|
from LionModule import LionModule
|
||||||
|
|
||||||
|
|
||||||
module = Module("Study_Stats")
|
module = LionModule("Study_Stats")
|
||||||
|
|||||||
@@ -1 +1,3 @@
|
|||||||
|
from .module import module
|
||||||
|
|
||||||
from .exec_cmds import *
|
from .exec_cmds import *
|
||||||
|
|||||||
@@ -5,35 +5,45 @@ import asyncio
|
|||||||
|
|
||||||
from cmdClient import cmd, checks
|
from cmdClient import cmd, checks
|
||||||
|
|
||||||
|
from core import Lion
|
||||||
|
from LionModule import LionModule
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Exec level commands to manage the bot.
|
Exec level commands to manage the bot.
|
||||||
|
|
||||||
Commands provided:
|
Commands provided:
|
||||||
async:
|
async:
|
||||||
Executes provided code in an async executor
|
Executes provided code in an async executor
|
||||||
exec:
|
|
||||||
Executes code using standard python exec
|
|
||||||
eval:
|
eval:
|
||||||
Executes code and awaits it if required
|
Executes code and awaits it if required
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@cmd("reboot")
|
@cmd("shutdown",
|
||||||
|
desc="Sync data and shutdown.",
|
||||||
|
group="Bot Admin",
|
||||||
|
aliases=('restart', 'reboot'))
|
||||||
@checks.is_owner()
|
@checks.is_owner()
|
||||||
async def cmd_reboot(ctx):
|
async def cmd_shutdown(ctx):
|
||||||
"""
|
"""
|
||||||
Usage``:
|
Usage``:
|
||||||
reboot
|
reboot
|
||||||
Description:
|
Description:
|
||||||
Update the timer status save file and reboot the client.
|
Run unload tasks and shutdown/reboot.
|
||||||
"""
|
"""
|
||||||
ctx.client.interface.update_save("reboot")
|
# Run module logout tasks
|
||||||
ctx.client.interface.shutdown()
|
for module in ctx.client.modules:
|
||||||
await ctx.reply("Saved state. Rebooting now!")
|
if isinstance(module, LionModule):
|
||||||
|
await module.unload(ctx.client)
|
||||||
|
|
||||||
|
# Reply and logout
|
||||||
|
await ctx.reply("All modules synced. Shutting down!")
|
||||||
await ctx.client.close()
|
await ctx.client.close()
|
||||||
|
|
||||||
|
|
||||||
@cmd("async")
|
@cmd("async",
|
||||||
|
desc="Execute arbitrary code with `async`.",
|
||||||
|
group="Bot Admin")
|
||||||
@checks.is_owner()
|
@checks.is_owner()
|
||||||
async def cmd_async(ctx):
|
async def cmd_async(ctx):
|
||||||
"""
|
"""
|
||||||
@@ -55,7 +65,9 @@ async def cmd_async(ctx):
|
|||||||
output))
|
output))
|
||||||
|
|
||||||
|
|
||||||
@cmd("eval")
|
@cmd("eval",
|
||||||
|
desc="Execute arbitrary code with `eval`.",
|
||||||
|
group="Bot Admin")
|
||||||
@checks.is_owner()
|
@checks.is_owner()
|
||||||
async def cmd_eval(ctx):
|
async def cmd_eval(ctx):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -0,0 +1,6 @@
|
|||||||
|
from .module import module
|
||||||
|
|
||||||
|
from . import Tasklist
|
||||||
|
from . import admin
|
||||||
|
from . import data
|
||||||
|
from . import commands
|
||||||
|
|||||||
@@ -0,0 +1,5 @@
|
|||||||
|
from .module import module
|
||||||
|
|
||||||
|
from . import admin
|
||||||
|
from . import data
|
||||||
|
from . import tracker
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import re
|
|||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
|
||||||
|
from cmdClient.lib import SafeCancellation
|
||||||
|
|
||||||
|
|
||||||
def prop_tabulate(prop_list, value_list, indent=True):
|
def prop_tabulate(prop_list, value_list, indent=True):
|
||||||
"""
|
"""
|
||||||
@@ -193,6 +195,25 @@ def parse_dur(time_str):
|
|||||||
return seconds
|
return seconds
|
||||||
|
|
||||||
|
|
||||||
|
def strfdur(duration):
|
||||||
|
"""
|
||||||
|
Convert a duration given in seconds to a number of hours, minutes, and seconds.
|
||||||
|
"""
|
||||||
|
hours = duration // 3600
|
||||||
|
minutes = duration // 60 % 60
|
||||||
|
seconds = duration % 60
|
||||||
|
|
||||||
|
parts = []
|
||||||
|
if hours:
|
||||||
|
parts.append('{}h'.format(hours))
|
||||||
|
if minutes:
|
||||||
|
parts.append('{}m'.format(minutes))
|
||||||
|
if seconds or duration == 0:
|
||||||
|
parts.append('{}s'.format(seconds))
|
||||||
|
|
||||||
|
return ' '.join(parts)
|
||||||
|
|
||||||
|
|
||||||
def substitute_ranges(ranges_str, max_match=20, max_range=1000, separator=','):
|
def substitute_ranges(ranges_str, max_match=20, max_range=1000, separator=','):
|
||||||
"""
|
"""
|
||||||
Substitutes a user provided list of numbers and ranges,
|
Substitutes a user provided list of numbers and ranges,
|
||||||
@@ -213,12 +234,28 @@ def substitute_ranges(ranges_str, max_match=20, max_range=1000, separator=','):
|
|||||||
n1 = int(match.group(1))
|
n1 = int(match.group(1))
|
||||||
n2 = int(match.group(2))
|
n2 = int(match.group(2))
|
||||||
if n2 - n1 > max_range:
|
if n2 - n1 > max_range:
|
||||||
raise ValueError("Provided range exceeds the allowed maximum.")
|
raise SafeCancellation("Provided range is too large!")
|
||||||
return separator.join(str(i) for i in range(n1, n2 + 1))
|
return separator.join(str(i) for i in range(n1, n2 + 1))
|
||||||
|
|
||||||
return re.sub(r'(\d+)\s*-\s*(\d+)', _repl, ranges_str, max_match)
|
return re.sub(r'(\d+)\s*-\s*(\d+)', _repl, ranges_str, max_match)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_ranges(ranges_str, ignore_errors=False, separator=',', **kwargs):
|
||||||
|
"""
|
||||||
|
Parses a user provided range string into a list of numbers.
|
||||||
|
Extra keyword arguments are transparently passed to the underlying parser `substitute_ranges`.
|
||||||
|
"""
|
||||||
|
substituted = substitute_ranges(ranges_str, separator=separator, **kwargs)
|
||||||
|
numbers = (item.strip() for item in substituted.split(','))
|
||||||
|
numbers = [item for item in numbers if item]
|
||||||
|
integers = [int(item) for item in numbers if item.isdigit()]
|
||||||
|
|
||||||
|
if not ignore_errors and len(integers) != len(numbers):
|
||||||
|
raise SafeCancellation("Couldn't parse the provided selection!")
|
||||||
|
|
||||||
|
return integers
|
||||||
|
|
||||||
|
|
||||||
def msg_string(msg, mask_link=False, line_break=False, tz=None, clean=True):
|
def msg_string(msg, mask_link=False, line_break=False, tz=None, clean=True):
|
||||||
"""
|
"""
|
||||||
Format a message into a string with various information, such as:
|
Format a message into a string with various information, such as:
|
||||||
|
|||||||
24
bot/wards.py
Normal file
24
bot/wards.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
from cmdClient import check
|
||||||
|
from cmdClient.checks import in_guild
|
||||||
|
|
||||||
|
from data import tables
|
||||||
|
|
||||||
|
|
||||||
|
def is_guild_admin(member):
|
||||||
|
# First check guild admin permissions
|
||||||
|
admin = member.guild_permissions.administrator
|
||||||
|
|
||||||
|
# Then check the admin role, if it is set
|
||||||
|
if not admin:
|
||||||
|
admin_role_id = tables.guild_config.fetch_or_create(member.guild.id).admin_role
|
||||||
|
admin = admin_role_id and (admin_role_id in (r.id for r in member.roles))
|
||||||
|
return admin
|
||||||
|
|
||||||
|
|
||||||
|
@check(
|
||||||
|
name="ADMIN",
|
||||||
|
msg=("You need to be a server admin to do this!"),
|
||||||
|
requires=[in_guild]
|
||||||
|
)
|
||||||
|
async def guild_admin(ctx, *args, **kwargs):
|
||||||
|
return is_guild_admin(ctx.author)
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
[DEFAULT]
|
||||||
|
log_file = bot.log
|
||||||
|
log_channel =
|
||||||
|
|
||||||
|
prefix = !
|
||||||
|
token =
|
||||||
|
owners = 413668234269818890, 389399222400712714
|
||||||
|
|
||||||
|
database = dbname=lionbot
|
||||||
|
data_appid = LionBot
|
||||||
|
|
||||||
|
lion_sync_period = 60
|
||||||
|
|||||||
Reference in New Issue
Block a user