From 0183b63c5512c37574f569b9d83c853c9cf44482 Mon Sep 17 00:00:00 2001 From: Conatum Date: Sun, 12 Sep 2021 11:04:49 +0300 Subject: [PATCH] Data system refactor and core redesign for public. Redesigned data and core systems to be public-capable. --- bot/LionModule.py | 70 ++++ bot/core/__init__.py | 6 +- bot/core/data.py | 81 ++++ bot/core/lion.py | 145 +++++++ bot/core/module.py | 36 ++ bot/core/tables.py | 31 -- bot/core/user.py | 105 ----- bot/data/__init__.py | 8 +- bot/data/conditions.py | 59 +++ bot/data/connection.py | 40 ++ bot/data/{custom_cursor.py => cursor.py} | 0 bot/data/data.py | 505 ----------------------- bot/data/formatters.py | 113 +++++ bot/data/interfaces.py | 282 +++++++++++++ bot/data/queries.py | 149 +++++++ bot/main.py | 3 + bot/meta/logger.py | 32 +- bot/modules/__init__.py | 6 + bot/modules/economy/__init__.py | 3 +- bot/modules/economy/commands.py | 30 +- bot/modules/economy/module.py | 4 +- bot/modules/meta/__init__.py | 3 + bot/modules/moderation/__init__.py | 5 + bot/modules/study/__init__.py | 9 +- bot/modules/study/commands.py | 111 ----- bot/modules/study/module.py | 4 +- bot/modules/sysadmin/__init__.py | 2 + bot/modules/sysadmin/exec_cmds.py | 32 +- bot/modules/todo/__init__.py | 6 + bot/modules/workout/__init__.py | 5 + bot/utils/lib.py | 39 +- bot/wards.py | 24 ++ config/example-bot.conf | 12 + 33 files changed, 1170 insertions(+), 790 deletions(-) create mode 100644 bot/LionModule.py create mode 100644 bot/core/data.py create mode 100644 bot/core/lion.py create mode 100644 bot/core/module.py delete mode 100644 bot/core/tables.py delete mode 100644 bot/core/user.py create mode 100644 bot/data/conditions.py create mode 100644 bot/data/connection.py rename bot/data/{custom_cursor.py => cursor.py} (100%) delete mode 100644 bot/data/data.py create mode 100644 bot/data/formatters.py create mode 100644 bot/data/interfaces.py create mode 100644 bot/data/queries.py delete mode 100644 bot/modules/study/commands.py create mode 100644 bot/wards.py diff --git a/bot/LionModule.py b/bot/LionModule.py new file mode 100644 index 00000000..d56ac38b --- /dev/null +++ b/bot/LionModule.py @@ -0,0 +1,70 @@ +from cmdClient import Command, Module + +from meta import log + + +class LionCommand(Command): + """ + Subclass to allow easy attachment of custom hooks and structure to commands. + """ + ... + + +class LionModule(Module): + """ + Custom module for Lion systems. + + Adds command wrappers and various event handlers. + """ + name = "Base Lion Module" + + def __init__(self, name, baseCommand=LionCommand): + super().__init__(name, baseCommand) + + self.unload_tasks = [] + + def unload_task(self, func): + """ + Decorator adding an unload task for deactivating the module. + Should sync unsaved transactions and finalise user interaction. + If possible, should also remove attached data and handlers. + """ + self.unload_tasks.append(func) + log("Adding unload task '{}'.".format(func.__name__), context=self.name) + return func + + async def unload(self, client): + """ + Run the unloading tasks. + """ + log("Unloading module.", context=self.name, post=False) + for task in self.unload_tasks: + log("Running unload task '{}'".format(task.__name__), + context=self.name, post=False) + await task(client) + + async def launch(self, client): + """ + Launch hook. + Executed in `client.on_ready`. + Must set `ready` to `True`, otherwise all commands will hang. + Overrides the parent launcher to not post the log as a discord message. + """ + if not self.ready: + log("Running launch tasks.", context=self.name, post=False) + + for task in self.launch_tasks: + log("Running launch task '{}'.".format(task.__name__), + context=self.name, post=False) + await task(client) + + self.ready = True + else: + log("Already launched, skipping launch.", context=self.name, post=False) + + async def pre_command(self, ctx): + """ + Lion pre-command hook. + """ + # TODO: Add blacklist and auto-fetch of lion here. + ... diff --git a/bot/core/__init__.py b/bot/core/__init__.py index 4094d206..651b6553 100644 --- a/bot/core/__init__.py +++ b/bot/core/__init__.py @@ -1,2 +1,4 @@ -from . import tables -from .user import User +from . import data # noqa + +from .module import module +from .lion import Lion # noqa diff --git a/bot/core/data.py b/bot/core/data.py new file mode 100644 index 00000000..06f240e3 --- /dev/null +++ b/bot/core/data.py @@ -0,0 +1,81 @@ +from psycopg2.extras import execute_values + +from cachetools import TTLCache +from data import RowTable, Table + + +meta = RowTable( + 'AppData', + ('appid', 'last_study_badge_scan'), + 'appid', + attach_as='meta', +) + + +user_config = RowTable( + 'user_config', + ('userid', 'timezone'), + 'userid', + cache=TTLCache(5000, ttl=60*5) +) + + +@user_config.save_query +def add_pending(pending): + """ + pending: + List of tuples of the form `(userid, pending_coins, pending_time)`. + """ + with lions.conn: + cursor = lions.conn.cursor() + data = execute_values( + cursor, + """ + UPDATE members + SET + coins = coins + t.coin_diff, + tracked_time = tracked_time + t.time_diff + FROM + (VALUES %s) + AS + t (guildid, userid, coin_diff, time_diff) + WHERE + members.guildid = t.guildid + AND + members.userid = t.userid + RETURNING * + """, + pending, + fetch=True + ) + return lions._make_rows(*data) + + +guild_config = RowTable( + 'guild_config', + ('guildid', 'admin_role', 'mod_role', 'event_log_channel', + 'min_workout_length', 'workout_reward', + 'max_tasks', 'task_reward', 'task_reward_limit', + 'study_hourly_reward', 'study_hourly_live_bonus', + 'study_ban_role', 'max_study_bans'), + 'guildid', + cache=TTLCache(1000, ttl=60*5) +) + +unranked_roles = Table('unranked_roles') + +donator_roles = Table('donator_roles') + + +lions = RowTable( + 'members', + ('guildid', 'userid', + 'tracked_time', 'coins', + 'workout_count', 'last_workout_start', + 'last_study_badgeid', + 'study_ban_count', + ), + ('guildid', 'userid'), + cache=TTLCache(5000, ttl=60*5), + attach_as='lions' +) diff --git a/bot/core/lion.py b/bot/core/lion.py new file mode 100644 index 00000000..d090f176 --- /dev/null +++ b/bot/core/lion.py @@ -0,0 +1,145 @@ +import pytz + +from meta import client +from data import tables as tb +from settings import UserSettings + + +class Lion: + """ + Class representing a guild Member. + Mostly acts as a transparent interface to the corresponding Row, + but also adds some transaction caching logic to `coins` and `tracked_time`. + """ + __slots__ = ('guildid', 'userid', '_pending_coins', '_pending_time', '_member') + + # Members with pending transactions + _pending = {} # userid -> User + + # Lion cache. Currently lions don't expire + _lions = {} # (guildid, userid) -> Lion + + def __init__(self, guildid, userid): + self.guildid = guildid + self.userid = userid + + self._pending_coins = 0 + self._pending_time = 0 + + self._member = None + + self._lions[self.key] = self + + @classmethod + def fetch(cls, guildid, userid): + """ + Fetch a Lion with the given member. + If they don't exist, creates them. + If possible, retrieves the user from the user cache. + """ + key = (guildid, userid) + if key in cls._lions: + return cls._lions[key] + else: + tb.lions.fetch_or_create(key) + return cls(guildid, userid) + + @property + def key(self): + return (self.guildid, self.userid) + + @property + def member(self): + """ + The discord `Member` corresponding to this user. + May be `None` if the member is no longer in the guild or the caches aren't populated. + Not guaranteed to be `None` if the member is not in the guild. + """ + if self._member is None: + guild = client.get_guild(self.guildid) + if guild: + self._member = guild.get_member(self.userid) + return self._member + + @property + def data(self): + """ + The Row corresponding to this user. + """ + return tb.lions.fetch(self.key) + + @property + def settings(self): + """ + The UserSettings object for this user. + """ + return UserSettings(self.userid) + + @property + def time(self): + """ + Amount of time the user has spent studying, accounting for pending values. + """ + return int(self.data.tracked_time + self._pending_time) + + @property + def coins(self): + """ + Number of coins the user has, accounting for the pending value. + """ + return int(self.data.coins + self._pending_coins) + + def localize(self, naive_utc_dt): + """ + Localise the provided naive UTC datetime into the user's timezone. + """ + timezone = self.settings.timezone.value + return naive_utc_dt.replace(tzinfo=pytz.UTC).astimezone(timezone) + + def addCoins(self, amount, flush=True): + """ + Add coins to the user, optionally store the transaction in pending. + """ + self._pending_coins += amount + self._pending[self.key] = self + if flush: + self.flush() + + def addTime(self, amount, flush=True): + """ + Add time to a user (in seconds), optionally storing the transaction in pending. + """ + self._pending_time += amount + self._pending[self.key] = self + if flush: + self.flush() + + def flush(self): + """ + Flush any pending transactions to the database. + """ + self.sync(self) + + @classmethod + def sync(cls, *lions): + """ + Flush pending transactions to the database. + Also refreshes the Row cache for updated lions. + """ + lions = lions or list(cls._pending.values()) + + if lions: + # Build userid to pending coin map + pending = [ + (lion.guildid, lion.userid, int(lion._pending_coins), int(lion._pending_time)) + for lion in lions + ] + + # Write to database + tb.lions.queries.add_pending(pending) + + # Cleanup pending users + for lion in lions: + lion._pending_coins -= int(lion._pending_coins) + lion._pending_time -= int(lion._pending_time) + cls._pending.pop(lion.key, None) diff --git a/bot/core/module.py b/bot/core/module.py new file mode 100644 index 00000000..cd87f41f --- /dev/null +++ b/bot/core/module.py @@ -0,0 +1,36 @@ +import logging +import asyncio + +from meta import client, conf +from LionModule import LionModule + +from .lion import Lion + + +module = LionModule("Core") + + +async def _lion_sync_loop(): + while True: + while not client.is_ready(): + await asyncio.sleep(1) + + client.log( + "Running lion data sync.", + context="CORE", + level=logging.DEBUG, + post=False + ) + + Lion.sync() + await asyncio.sleep(conf.bot.getint("lion_sync_period")) + + +@module.launch_task +async def launch_lion_sync_loop(client): + asyncio.create_task(_lion_sync_loop()) + + +@module.unload_task +async def final_lion_sync(client): + Lion.sync() diff --git a/bot/core/tables.py b/bot/core/tables.py deleted file mode 100644 index b0f407ff..00000000 --- a/bot/core/tables.py +++ /dev/null @@ -1,31 +0,0 @@ -from psycopg2.extras import execute_values - -from cachetools import TTLCache -from data import RowTable, Table - - -users = RowTable( - 'lions', - ('userid', 'tracked_time', 'coins'), - 'userid', - cache=TTLCache(5000, ttl=60*5) -) - - -@users.save_query -def add_coins(userid_coins): - with users.conn: - cursor = users.conn.cursor() - data = execute_values( - cursor, - """ - UPDATE lions - SET coins = coins + t.diff - FROM (VALUES %s) AS t (userid, diff) - WHERE lions.userid = t.userid - RETURNING * - """, - userid_coins, - fetch=True - ) - return users._make_rows(*data) diff --git a/bot/core/user.py b/bot/core/user.py deleted file mode 100644 index 1f717474..00000000 --- a/bot/core/user.py +++ /dev/null @@ -1,105 +0,0 @@ -from . import tables as tb -from meta import conf, client - - -class User: - """ - Class representing a "Lion", i.e. a member of the managed guild. - Mostly acts as a transparent interface to the corresponding Row, - but also adds some transaction caching logic to `coins`. - """ - __slots__ = ('userid', '_pending_coins', '_member') - - # Users with pending transactions - _pending = {} # userid -> User - - # User cache. Currently users don't expire - _users = {} # userid -> User - - def __init__(self, userid): - self.userid = userid - self._pending_coins = 0 - - self._users[self.userid] = self - - @classmethod - def fetch(cls, userid): - """ - Fetch a User with the given userid. - If they don't exist, creates them. - If possible, retrieves the user from the user cache. - """ - if userid in cls._users: - return cls._users[userid] - else: - tb.users.fetch_or_create(userid) - return cls(userid) - - @property - def member(self): - """ - The discord `Member` corresponding to this user. - May be `None` if the member is no longer in the guild or the caches aren't populated. - Not guaranteed to be `None` if the member is not in the guild. - """ - if self._member is None: - self._member = client.get_guild(conf.meta.getint('managed_guild_id')).get_member(self.userid) - - @property - def data(self): - """ - The Row corresponding to this user. - """ - return tb.users.fetch(self.userid) - - @property - def time(self): - """ - Amount of time the user has spent.. studying? - """ - return self.data.tracked_time - - @property - def coins(self): - """ - Number of coins the user has, accounting for the pending value. - """ - return self.data.coins + self._pending_coins - - def addCoins(self, amount, flush=True): - """ - Add coins to the user, optionally store the transaction in pending. - """ - self._pending_coins += amount - if self._pending_coins != 0: - self._pending[self.userid] = self - else: - self._pending.pop(self.userid, None) - if flush: - self.flush() - - def flush(self): - """ - Flush any pending transactions to the database. - """ - self.sync(self) - - @classmethod - def sync(cls, *users): - """ - Flush pending transactions to the database. - Also refreshes the Row cache for updated users. - """ - users = users or list(cls._pending.values()) - - if users: - # Build userid to pending coin map - userid_coins = [(user.userid, user._pending_coins) for user in users] - - # Write to database - tb.users.queries.add_coins(userid_coins) - - # Cleanup pending users - for user in users: - user._pending_coins = 0 - cls._pending.pop(user.userid, None) diff --git a/bot/data/__init__.py b/bot/data/__init__.py index ad8ddd02..f048ce37 100644 --- a/bot/data/__init__.py +++ b/bot/data/__init__.py @@ -1,3 +1,5 @@ -from .data import * -# from . import tables -# from . import queries +from .connection import conn # noqa +from .formatters import UpdateValue, UpdateValueAdd # noqa +from .interfaces import Table, RowTable, Row, tables # noqa +from .queries import insert, insert_many, select_where, update_where, upsert, delete_where # noqa +from .conditions import Condition, NOT, Constant, NULL, NOTNULL # noqa diff --git a/bot/data/conditions.py b/bot/data/conditions.py new file mode 100644 index 00000000..ca01ea5d --- /dev/null +++ b/bot/data/conditions.py @@ -0,0 +1,59 @@ +from .connection import _replace_char + + +class Condition: + """ + ABC representing a selection condition. + """ + __slots__ = () + + def apply(self, key, values, conditions): + raise NotImplementedError + + +class NOT(Condition): + __slots__ = ('value',) + + def __init__(self, value): + self.value = value + + def apply(self, key, values, conditions): + item = self.value + if isinstance(item, (list, tuple)): + if item: + conditions.append("{} NOT IN ({})".format(key, ", ".join([_replace_char] * len(item)))) + values.extend(item) + else: + raise ValueError("Cannot check an empty iterable!") + else: + conditions.append("{}!={}".format(key, _replace_char)) + values.append(item) + + +class GEQ(Condition): + __slots__ = ('value',) + + def __init__(self, value): + self.value = value + + def apply(self, key, values, conditions): + item = self.value + if isinstance(item, (list, tuple)): + raise ValueError("Cannot apply GEQ condition to a list!") + else: + conditions.append("{} >= {}".format(key, _replace_char)) + values.append(item) + + +class Constant(Condition): + __slots__ = ('value',) + + def __init__(self, value): + self.value = value + + def apply(self, key, values, conditions): + conditions.append("{} {}".format(key, self.value)) + + +NULL = Constant('IS NULL') +NOTNULL = Constant('IS NOT NULL') diff --git a/bot/data/connection.py b/bot/data/connection.py new file mode 100644 index 00000000..1f35eda2 --- /dev/null +++ b/bot/data/connection.py @@ -0,0 +1,40 @@ +import logging + +import psycopg2 as psy + +from meta import log, conf +from constants import DATA_VERSION +from .cursor import DictLoggingCursor + + +# Set up database connection +log("Establishing connection.", "DB_INIT", level=logging.DEBUG) +conn = psy.connect(conf.bot['database'], cursor_factory=DictLoggingCursor) + +# Replace char used by the connection for query formatting +_replace_char: str = '%s' + +# conn.set_trace_callback(lambda message: log(message, context="DB_CONNECTOR", level=logging.DEBUG)) +# sq.register_adapter(datetime, lambda dt: dt.timestamp()) + + +# Check the version matches the required version +with conn: + log("Checking db version.", "DB_INIT") + cursor = conn.cursor() + + # Get last entry in version table, compare against desired version + cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1") + current_version, _, _ = cursor.fetchone() + + if current_version != DATA_VERSION: + # Complain + raise Exception( + ("Database version is {}, required version is {}. " + "Please migrate database.").format(current_version, DATA_VERSION) + ) + + cursor.close() + + +log("Established connection.", "DB_INIT") diff --git a/bot/data/custom_cursor.py b/bot/data/cursor.py similarity index 100% rename from bot/data/custom_cursor.py rename to bot/data/cursor.py diff --git a/bot/data/data.py b/bot/data/data.py deleted file mode 100644 index b498f7c6..00000000 --- a/bot/data/data.py +++ /dev/null @@ -1,505 +0,0 @@ -import logging -import contextlib -from itertools import chain -from enum import Enum - -import psycopg2 as psy -from cachetools import LRUCache - -from utils.lib import DotDict -from meta import log, conf -from constants import DATA_VERSION -from .custom_cursor import DictLoggingCursor - - -# Set up database connection -log("Establishing connection.", "DB_INIT", level=logging.DEBUG) -conn = psy.connect(conf.bot['database'], cursor_factory=DictLoggingCursor) - -# conn.set_trace_callback(lambda message: log(message, context="DB_CONNECTOR", level=logging.DEBUG)) -# sq.register_adapter(datetime, lambda dt: dt.timestamp()) - - -# Check the version matches the required version -with conn: - log("Checking db version.", "DB_INIT") - cursor = conn.cursor() - - # Get last entry in version table, compare against desired version - cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1") - current_version, _, _ = cursor.fetchone() - - if current_version != DATA_VERSION: - # Complain - raise Exception( - ("Database version is {}, required version is {}. " - "Please migrate database.").format(current_version, DATA_VERSION) - ) - - cursor.close() - - -log("Established connection.", "DB_INIT") - - -# --------------- Data Interface Classes --------------- -class Table: - """ - Transparent interface to a single table structure in the database. - Contains standard methods to access the table. - Intended to be subclassed to provide more derivative access for specific tables. - """ - conn = conn - queries = DotDict() - - def __init__(self, name): - self.name = name - - def select_where(self, *args, **kwargs): - with self.conn: - return select_where(self.name, *args, **kwargs) - - def select_one_where(self, *args, **kwargs): - with self.conn: - rows = self.select_where(*args, **kwargs) - return rows[0] if rows else None - - def update_where(self, *args, **kwargs): - with self.conn: - return update_where(self.name, *args, **kwargs) - - def delete_where(self, *args, **kwargs): - with self.conn: - return delete_where(self.name, *args, **kwargs) - - def insert(self, *args, **kwargs): - with self.conn: - return insert(self.name, *args, **kwargs) - - def insert_many(self, *args, **kwargs): - with self.conn: - return insert_many(self.name, *args, **kwargs) - - def upsert(self, *args, **kwargs): - with self.conn: - return upsert(self.name, *args, **kwargs) - - def save_query(self, func): - """ - Decorator to add a saved query to the table. - """ - self.queries[func.__name__] = func - - -class Row: - __slots__ = ('table', 'data', '_pending') - - conn = conn - - def __init__(self, table, data, *args, **kwargs): - super().__setattr__('table', table) - self.data = data - self._pending = None - - @property - def rowid(self): - return self.data[self.table.id_col] - - def __repr__(self): - return "Row[{}]({})".format( - self.table.name, - ', '.join("{}={!r}".format(field, getattr(self, field)) for field in self.table.columns) - ) - - def __getattr__(self, key): - if key in self.table.columns: - if self._pending and key in self._pending: - return self._pending[key] - else: - return self.data[key] - else: - raise AttributeError(key) - - def __setattr__(self, key, value): - if key in self.table.columns: - if self._pending is None: - self.update(**{key: value}) - else: - self._pending[key] = value - else: - super().__setattr__(key, value) - - @contextlib.contextmanager - def batch_update(self): - if self._pending: - raise ValueError("Nested batch updates for {}!".format(self.__class__.__name__)) - - self._pending = {} - try: - yield self._pending - finally: - self.update(**self._pending) - self._pending = None - - def _refresh(self): - row = self.table.select_one_where(**{self.table.id_col: self.rowid}) - if not row: - raise ValueError("Refreshing a {} which no longer exists!".format(type(self).__name__)) - self.data = row - - def update(self, **values): - rows = self.table.update_where(values, **{self.table.id_col: self.rowid}) - self.data = rows[0] - - @classmethod - def _select_where(cls, _extra=None, **conditions): - return select_where(cls._table, **conditions) - - @classmethod - def _insert(cls, **values): - return insert(cls._table, **values) - - @classmethod - def _update_where(cls, values, **conditions): - return update_where(cls._table, values, **conditions) - - -class RowTable(Table): - __slots__ = ( - 'name', - 'columns', - 'id_col', - 'row_cache' - ) - - conn = conn - - def __init__(self, name, columns, id_col, use_cache=True, cache=None, cache_size=1000): - self.name = name - self.columns = columns - self.id_col = id_col - self.row_cache = (cache or LRUCache(cache_size)) if use_cache else None - - # Extend original Table update methods to modify the cached rows - def update_where(self, *args, **kwargs): - data = super().update_where(*args, **kwargs) - if self.row_cache is not None: - for data_row in data: - cached_row = self.row_cache.get(data_row[self.id_col], None) - if cached_row is not None: - cached_row.data = data_row - return data - - def delete_where(self, *args, **kwargs): - data = super().delete_where(*args, **kwargs) - if self.row_cache is not None: - for data_row in data: - self.row_cache.pop(data_row[self.id_col], None) - return data - - def upsert(self, *args, **kwargs): - data = super().upsert(*args, **kwargs) - if self.row_cache is not None: - cached_row = self.row_cache.get(data[self.id_col], None) - if cached_row is not None: - cached_row.data = data - return data - - # New methods to fetch and create rows - def _make_rows(self, *data_rows): - """ - Create or retrieve Row objects for each provided data row. - If the rows already exist in cache, updates the cached row. - """ - if self.row_cache is not None: - rows = [] - for data_row in data_rows: - rowid = data_row[self.id_col] - - cached_row = self.row_cache.get(rowid, None) - if cached_row is not None: - cached_row.data = data_row - row = cached_row - else: - row = Row(self, data_row) - self.row_cache[rowid] = row - rows.append(row) - else: - rows = [Row(self, data_row) for data_row in data_rows] - return rows - - def create_row(self, *args, **kwargs): - data = self.insert(*args, **kwargs) - return self._make_rows(data)[0] - - def fetch_rows_where(self, *args, **kwargs): - # TODO: Handle list of rowids here? - data = self.select_where(*args, **kwargs) - return self._make_rows(*data) - - def fetch(self, rowid): - """ - Fetch the row with the given id, retrieving from cache where possible. - """ - row = self.row_cache.get(rowid, None) if self.row_cache is not None else None - if row is None: - rows = self.fetch_rows_where(**{self.id_col: rowid}) - row = rows[0] if rows else None - return row - - def fetch_or_create(self, rowid=None, **kwargs): - """ - Helper method to fetch a row with the given id or fields, or create it if it doesn't exist. - """ - if rowid is not None: - row = self.fetch(rowid) - else: - data = self.select_where(**kwargs) - row = self._make_rows(data[0])[0] if data else None - - if row is None: - creation_kwargs = kwargs - if rowid is not None: - creation_kwargs[self.id_col] = rowid - row = self.create_row(**creation_kwargs) - return row - - -# --------------- Query Builders --------------- -def select_where(table, select_columns=None, cursor=None, _extra='', **conditions): - """ - Select rows from the given table matching the conditions - """ - criteria, criteria_values = _format_conditions(conditions) - col_str = _format_selectkeys(select_columns) - - if conditions: - where_str = "WHERE {}".format(criteria) - else: - where_str = "" - - cursor = cursor or conn.cursor() - cursor.execute( - 'SELECT {} FROM {} {} {}'.format(col_str, table, where_str, _extra), - criteria_values - ) - return cursor.fetchall() - - -def update_where(table, valuedict, cursor=None, **conditions): - """ - Update rows in the given table matching the conditions - """ - key_str, key_values = _format_updatestr(valuedict) - criteria, criteria_values = _format_conditions(conditions) - - if conditions: - where_str = "WHERE {}".format(criteria) - else: - where_str = "" - - cursor = cursor or conn.cursor() - cursor.execute( - 'UPDATE {} SET {} {} RETURNING *'.format(table, key_str, where_str), - tuple((*key_values, *criteria_values)) - ) - return cursor.fetchall() - - -def delete_where(table, cursor=None, **conditions): - """ - Delete rows in the given table matching the conditions - """ - criteria, criteria_values = _format_conditions(conditions) - - cursor = cursor or conn.cursor() - cursor.execute( - 'DELETE FROM {} WHERE {}'.format(table, criteria), - criteria_values - ) - return cursor.fetchall() - - -def insert(table, cursor=None, allow_replace=False, **values): - """ - Insert the given values into the table - """ - keys, values = zip(*values.items()) - - key_str = _format_insertkeys(keys) - value_str, values = _format_insertvalues(values) - - action = 'REPLACE' if allow_replace else 'INSERT' - - cursor = cursor or conn.cursor() - cursor.execute( - '{} INTO {} {} VALUES {} RETURNING *'.format(action, table, key_str, value_str), - values - ) - return cursor.fetchone() - - -def insert_many(table, *value_tuples, insert_keys=None, cursor=None): - """ - Insert all the given values into the table - """ - key_str = _format_insertkeys(insert_keys) - value_strs, value_tuples = zip(*(_format_insertvalues(value_tuple) for value_tuple in value_tuples)) - - value_str = ", ".join(value_strs) - values = tuple(chain(*value_tuples)) - - cursor = cursor or conn.cursor() - cursor.execute( - 'INSERT INTO {} {} VALUES {} RETURNING *'.format(table, key_str, value_str), - values - ) - return cursor.fetchall() - - -def upsert(table, constraint, cursor=None, **values): - """ - Insert or on conflict update. - """ - valuedict = values - keys, values = zip(*values.items()) - - key_str = _format_insertkeys(keys) - value_str, values = _format_insertvalues(values) - update_key_str, update_key_values = _format_updatestr(valuedict) - - if not isinstance(constraint, str): - constraint = ", ".join(constraint) - - cursor = cursor or conn.cursor() - cursor.execute( - 'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format( - table, key_str, value_str, constraint, update_key_str - ), - tuple((*values, *update_key_values)) - ) - return cursor.fetchone() - - -# --------------- Query Formatting Tools --------------- -# Replace char used by the connection for query formatting -_replace_char: str = '%s' - - -class fieldConstants(Enum): - """ - A collection of database field constants to use for selection conditions. - """ - NULL = "IS NULL" - NOTNULL = "IS NOT NULL" - - -class _updateField: - __slots__ = () - _EMPTY = object() # Return value for `value` indicating no value should be added - - def key_field(self, key): - raise NotImplementedError - - def value_field(self, key): - raise NotImplementedError - - -class UpdateValue(_updateField): - __slots__ = ('key_str', 'value') - - def __init__(self, key_str, value=_updateField._EMPTY): - self.key_str = key_str - self.value = value - - def key_field(self, key): - return self.key_str.format(key=key, value=_replace_char, replace=_replace_char) - - def value_field(self, key): - return self.value - - -class UpdateValueAdd(_updateField): - __slots__ = ('value',) - - def __init__(self, value): - self.value = value - - def key_field(self, key): - return "{key} = {key} + {replace}".format(key=key, replace=_replace_char) - - def value_field(self, key): - return self.value - - -def _format_conditions(conditions): - """ - Formats a dictionary of conditions into a string suitable for 'WHERE' clauses. - Supports `IN` type conditionals. - """ - if not conditions: - return ("", tuple()) - - values = [] - conditional_strings = [] - for key, item in conditions.items(): - if isinstance(item, (list, tuple)): - conditional_strings.append("{} IN ({})".format(key, ", ".join([_replace_char] * len(item)))) - values.extend(item) - elif isinstance(item, fieldConstants): - conditional_strings.append("{} {}".format(key, item.value)) - else: - conditional_strings.append("{}={}".format(key, _replace_char)) - values.append(item) - - return (' AND '.join(conditional_strings), values) - - -def _format_selectkeys(keys): - """ - Formats a list of keys into a string suitable for `SELECT`. - """ - if not keys: - return "*" - else: - return ", ".join(keys) - - -def _format_insertkeys(keys): - """ - Formats a list of keys into a string suitable for `INSERT` - """ - if not keys: - return "" - else: - return "({})".format(", ".join(keys)) - - -def _format_insertvalues(values): - """ - Formats a list of values into a string suitable for `INSERT` - """ - value_str = "({})".format(", ".join(_replace_char for value in values)) - return (value_str, values) - - -def _format_updatestr(valuedict): - """ - Formats a dictionary of keys and values into a string suitable for 'SET' clauses. - """ - if not valuedict: - return ("", tuple()) - - key_fields = [] - values = [] - for key, value in valuedict.items(): - if isinstance(value, _updateField): - key_fields.append(value.key_field(key)) - v = value.value_field(key) - if v is not _updateField._EMPTY: - values.append(value.value_field(key)) - else: - key_fields.append("{} = {}".format(key, _replace_char)) - values.append(value) - - return (', '.join(key_fields), values) diff --git a/bot/data/formatters.py b/bot/data/formatters.py new file mode 100644 index 00000000..4bdccbc3 --- /dev/null +++ b/bot/data/formatters.py @@ -0,0 +1,113 @@ +from .connection import _replace_char +from .conditions import Condition + + +class _updateField: + __slots__ = () + _EMPTY = object() # Return value for `value` indicating no value should be added + + def key_field(self, key): + raise NotImplementedError + + def value_field(self, key): + raise NotImplementedError + + +class UpdateValue(_updateField): + __slots__ = ('key_str', 'value') + + def __init__(self, key_str, value=_updateField._EMPTY): + self.key_str = key_str + self.value = value + + def key_field(self, key): + return self.key_str.format(key=key, value=_replace_char, replace=_replace_char) + + def value_field(self, key): + return self.value + + +class UpdateValueAdd(_updateField): + __slots__ = ('value',) + + def __init__(self, value): + self.value = value + + def key_field(self, key): + return "{key} = {key} + {replace}".format(key=key, replace=_replace_char) + + def value_field(self, key): + return self.value + + +def _format_conditions(conditions): + """ + Formats a dictionary of conditions into a string suitable for 'WHERE' clauses. + Supports `IN` type conditionals. + """ + if not conditions: + return ("", tuple()) + + values = [] + conditional_strings = [] + for key, item in conditions.items(): + if isinstance(item, (list, tuple)): + conditional_strings.append("{} IN ({})".format(key, ", ".join([_replace_char] * len(item)))) + values.extend(item) + elif isinstance(item, Condition): + item.apply(key, values, conditional_strings) + else: + conditional_strings.append("{}={}".format(key, _replace_char)) + values.append(item) + + return (' AND '.join(conditional_strings), values) + + +def _format_selectkeys(keys): + """ + Formats a list of keys into a string suitable for `SELECT`. + """ + if not keys: + return "*" + else: + return ", ".join(keys) + + +def _format_insertkeys(keys): + """ + Formats a list of keys into a string suitable for `INSERT` + """ + if not keys: + return "" + else: + return "({})".format(", ".join(keys)) + + +def _format_insertvalues(values): + """ + Formats a list of values into a string suitable for `INSERT` + """ + value_str = "({})".format(", ".join(_replace_char for value in values)) + return (value_str, values) + + +def _format_updatestr(valuedict): + """ + Formats a dictionary of keys and values into a string suitable for 'SET' clauses. + """ + if not valuedict: + return ("", tuple()) + + key_fields = [] + values = [] + for key, value in valuedict.items(): + if isinstance(value, _updateField): + key_fields.append(value.key_field(key)) + v = value.value_field(key) + if v is not _updateField._EMPTY: + values.append(value.value_field(key)) + else: + key_fields.append("{} = {}".format(key, _replace_char)) + values.append(value) + + return (', '.join(key_fields), values) diff --git a/bot/data/interfaces.py b/bot/data/interfaces.py new file mode 100644 index 00000000..88e60a15 --- /dev/null +++ b/bot/data/interfaces.py @@ -0,0 +1,282 @@ +from __future__ import annotations + +import contextlib +from cachetools import LRUCache +from typing import Mapping + +from utils.lib import DotDict + +from .connection import conn +from .queries import insert, insert_many, select_where, update_where, upsert, delete_where, update_many + + +# Global cache of interfaces +tables: Mapping[str, Table] = DotDict() + + +class Table: + """ + Transparent interface to a single table structure in the database. + Contains standard methods to access the table. + Intended to be subclassed to provide more derivative access for specific tables. + """ + conn = conn + queries = DotDict() + + def __init__(self, name, attach_as=None): + self.name = name + tables[attach_as or name] = self + + def select_where(self, *args, **kwargs): + with self.conn: + return select_where(self.name, *args, **kwargs) + + def select_one_where(self, *args, **kwargs): + rows = self.select_where(*args, **kwargs) + return rows[0] if rows else None + + def update_where(self, *args, **kwargs): + with self.conn: + return update_where(self.name, *args, **kwargs) + + def delete_where(self, *args, **kwargs): + with self.conn: + return delete_where(self.name, *args, **kwargs) + + def insert(self, *args, **kwargs): + with self.conn: + return insert(self.name, *args, **kwargs) + + def insert_many(self, *args, **kwargs): + with self.conn: + return insert_many(self.name, *args, **kwargs) + + def update_many(self, *args, **kwargs): + with self.conn: + return update_many(self.name, *args, **kwargs) + + def upsert(self, *args, **kwargs): + with self.conn: + return upsert(self.name, *args, **kwargs) + + def save_query(self, func): + """ + Decorator to add a saved query to the table. + """ + self.queries[func.__name__] = func + + +class Row: + __slots__ = ('table', 'data', '_pending') + + conn = conn + + def __init__(self, table, data, *args, **kwargs): + super().__setattr__('table', table) + self.data = data + self._pending = None + + @property + def rowid(self): + return self.table.id_from_row(self.data) + + def __repr__(self): + return "Row[{}]({})".format( + self.table.name, + ', '.join("{}={!r}".format(field, getattr(self, field)) for field in self.table.columns) + ) + + def __getattr__(self, key): + if key in self.table.columns: + if self._pending and key in self._pending: + return self._pending[key] + else: + return self.data[key] + else: + raise AttributeError(key) + + def __setattr__(self, key, value): + if key in self.table.columns: + if self._pending is None: + self.update(**{key: value}) + else: + self._pending[key] = value + else: + super().__setattr__(key, value) + + @contextlib.contextmanager + def batch_update(self): + if self._pending: + raise ValueError("Nested batch updates for {}!".format(self.__class__.__name__)) + + self._pending = {} + try: + yield self._pending + finally: + self.update(**self._pending) + self._pending = None + + def _refresh(self): + row = self.table.select_one_where(self.table.dict_from_id(self.rowid)) + if not row: + raise ValueError("Refreshing a {} which no longer exists!".format(type(self).__name__)) + self.data = row + + def update(self, **values): + rows = self.table.update_where(values, **self.table.dict_from_id(self.rowid)) + self.data = rows[0] + + @classmethod + def _select_where(cls, _extra=None, **conditions): + return select_where(cls._table, **conditions) + + @classmethod + def _insert(cls, **values): + return insert(cls._table, **values) + + @classmethod + def _update_where(cls, values, **conditions): + return update_where(cls._table, values, **conditions) + + +class RowTable(Table): + __slots__ = ( + 'name', + 'columns', + 'id_col', + 'multi_key', + 'row_cache' + ) + + conn = conn + + def __init__(self, name, columns, id_col, use_cache=True, cache=None, cache_size=1000, **kwargs): + super().__init__(name, **kwargs) + self.name = name + self.columns = columns + self.id_col = id_col + self.multi_key = isinstance(id_col, tuple) + self.row_cache = (cache or LRUCache(cache_size)) if use_cache else None + + def id_from_row(self, row): + if self.multi_key: + return tuple(row[key] for key in self.id_col) + else: + return row[self.id_col] + + def dict_from_id(self, rowid): + if self.multi_key: + return dict(zip(self.id_col, rowid)) + else: + return {self.id_col: rowid} + + # Extend original Table update methods to modify the cached rows + def insert(self, *args, **kwargs): + data = super().insert(*args, **kwargs) + if self.row_cache is not None: + self.row_cache[self.id_from_row(data)] = Row(self, data) + return data + + def insert_many(self, *args, **kwargs): + data = super().insert_many(*args, **kwargs) + if self.row_cache is not None: + for data_row in data: + cached_row = self.row_cache.get(self.id_from_row(data_row), None) + if cached_row is not None: + cached_row.data = data_row + return data + + def update_where(self, *args, **kwargs): + data = super().update_where(*args, **kwargs) + if self.row_cache is not None: + for data_row in data: + cached_row = self.row_cache.get(self.id_from_row(data_row), None) + if cached_row is not None: + cached_row.data = data_row + return data + + def update_many(self, *args, **kwargs): + data = super().update_many(*args, **kwargs) + if self.row_cache is not None: + for data_row in data: + cached_row = self.row_cache.get(self.id_from_row(data_row), None) + if cached_row is not None: + cached_row.data = data_row + return data + + def delete_where(self, *args, **kwargs): + data = super().delete_where(*args, **kwargs) + if self.row_cache is not None: + for data_row in data: + self.row_cache.pop(self.id_from_row(data_row), None) + return data + + def upsert(self, *args, **kwargs): + data = super().upsert(*args, **kwargs) + if self.row_cache is not None: + rowid = self.id_from_row(data) + cached_row = self.row_cache.get(rowid, None) + if cached_row is not None: + cached_row.data = data + else: + self.row_cache[rowid] = Row(self, data) + return data + + # New methods to fetch and create rows + def _make_rows(self, *data_rows): + """ + Create or retrieve Row objects for each provided data row. + If the rows already exist in cache, updates the cached row. + """ + if self.row_cache is not None: + rows = [] + for data_row in data_rows: + rowid = self.id_from_row(data_row) + + cached_row = self.row_cache.get(rowid, None) + if cached_row is not None: + cached_row.data = data_row + row = cached_row + else: + row = Row(self, data_row) + self.row_cache[rowid] = row + rows.append(row) + else: + rows = [Row(self, data_row) for data_row in data_rows] + return rows + + def create_row(self, *args, **kwargs): + data = self.insert(*args, **kwargs) + return self._make_rows(data)[0] + + def fetch_rows_where(self, *args, **kwargs): + # TODO: Handle list of rowids here? + data = self.select_where(*args, **kwargs) + return self._make_rows(*data) + + def fetch(self, rowid): + """ + Fetch the row with the given id, retrieving from cache where possible. + """ + row = self.row_cache.get(rowid, None) if self.row_cache is not None else None + if row is None: + rows = self.fetch_rows_where(**self.dict_from_id(rowid)) + row = rows[0] if rows else None + return row + + def fetch_or_create(self, rowid=None, **kwargs): + """ + Helper method to fetch a row with the given id or fields, or create it if it doesn't exist. + """ + if rowid is not None: + row = self.fetch(rowid) + else: + data = self.select_where(**kwargs) + row = self._make_rows(data[0])[0] if data else None + + if row is None: + creation_kwargs = kwargs + if rowid is not None: + creation_kwargs.update(self.dict_from_id(rowid)) + row = self.create_row(**creation_kwargs) + return row diff --git a/bot/data/queries.py b/bot/data/queries.py new file mode 100644 index 00000000..e4cce52d --- /dev/null +++ b/bot/data/queries.py @@ -0,0 +1,149 @@ +from itertools import chain +from psycopg2.extras import execute_values + +from .connection import conn +from .formatters import (_format_updatestr, _format_conditions, _format_insertkeys, + _format_selectkeys, _format_insertvalues) + + +def select_where(table, select_columns=None, cursor=None, _extra='', **conditions): + """ + Select rows from the given table matching the conditions + """ + criteria, criteria_values = _format_conditions(conditions) + col_str = _format_selectkeys(select_columns) + + if criteria: + where_str = "WHERE {}".format(criteria) + else: + where_str = "" + + cursor = cursor or conn.cursor() + cursor.execute( + 'SELECT {} FROM {} {} {}'.format(col_str, table, where_str, _extra), + criteria_values + ) + return cursor.fetchall() + + +def update_where(table, valuedict, cursor=None, **conditions): + """ + Update rows in the given table matching the conditions + """ + key_str, key_values = _format_updatestr(valuedict) + criteria, criteria_values = _format_conditions(conditions) + + if criteria: + where_str = "WHERE {}".format(criteria) + else: + where_str = "" + + cursor = cursor or conn.cursor() + cursor.execute( + 'UPDATE {} SET {} {} RETURNING *'.format(table, key_str, where_str), + tuple((*key_values, *criteria_values)) + ) + return cursor.fetchall() + + +def delete_where(table, cursor=None, **conditions): + """ + Delete rows in the given table matching the conditions + """ + criteria, criteria_values = _format_conditions(conditions) + + if criteria: + where_str = "WHERE {}".format(criteria) + else: + where_str = "" + + cursor = cursor or conn.cursor() + cursor.execute( + 'DELETE FROM {} {} RETURNING *'.format(table, where_str), + criteria_values + ) + return cursor.fetchall() + + +def insert(table, cursor=None, allow_replace=False, **values): + """ + Insert the given values into the table + """ + keys, values = zip(*values.items()) + + key_str = _format_insertkeys(keys) + value_str, values = _format_insertvalues(values) + + action = 'REPLACE' if allow_replace else 'INSERT' + + cursor = cursor or conn.cursor() + cursor.execute( + '{} INTO {} {} VALUES {} RETURNING *'.format(action, table, key_str, value_str), + values + ) + return cursor.fetchone() + + +def insert_many(table, *value_tuples, insert_keys=None, cursor=None): + """ + Insert all the given values into the table + """ + key_str = _format_insertkeys(insert_keys) + value_strs, value_tuples = zip(*(_format_insertvalues(value_tuple) for value_tuple in value_tuples)) + + value_str = ", ".join(value_strs) + values = tuple(chain(*value_tuples)) + + cursor = cursor or conn.cursor() + cursor.execute( + 'INSERT INTO {} {} VALUES {} RETURNING *'.format(table, key_str, value_str), + values + ) + return cursor.fetchall() + + +def upsert(table, constraint, cursor=None, **values): + """ + Insert or on conflict update. + """ + valuedict = values + keys, values = zip(*values.items()) + + key_str = _format_insertkeys(keys) + value_str, values = _format_insertvalues(values) + update_key_str, update_key_values = _format_updatestr(valuedict) + + if not isinstance(constraint, str): + constraint = ", ".join(constraint) + + cursor = cursor or conn.cursor() + cursor.execute( + 'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format( + table, key_str, value_str, constraint, update_key_str + ), + tuple((*values, *update_key_values)) + ) + return cursor.fetchone() + + +def update_many(table, *values, set_keys=None, where_keys=None, cursor=None): + cursor = cursor or conn.cursor() + + return execute_values( + cursor, + """ + UPDATE {table} + SET {set_clause} + FROM (VALUES %s) + AS {temp_table} + WHERE {where_clause} + RETURNING * + """.format( + table=table, + set_clause=', '.join("{0} = _t.{0}".format(key) for key in set_keys), + where_clause=' AND '.join("{1}.{0} = _t.{0}".format(key, table) for key in where_keys), + temp_table="_t ({})".format(', '.join(set_keys + where_keys)) + ), + values, + fetch=True + ) diff --git a/bot/main.py b/bot/main.py index 4e0b6663..60cffa5e 100644 --- a/bot/main.py +++ b/bot/main.py @@ -6,6 +6,9 @@ import core # noqa import modules # noqa +# Load and attach app specific data +client.appdata = core.data.meta.fetch_or_create(conf.bot['data_appid']) + # Initialise all modules client.initialise_modules() diff --git a/bot/meta/logger.py b/bot/meta/logger.py index dfb618ad..c6d916d9 100644 --- a/bot/meta/logger.py +++ b/bot/meta/logger.py @@ -14,10 +14,34 @@ from .config import conf # Setup the logger logger = logging.getLogger() log_fmt = logging.Formatter(fmt='[{asctime}][{levelname:^8}] {message}', datefmt='%d/%m | %H:%M:%S', style='{') -term_handler = logging.StreamHandler(sys.stdout) -term_handler.setFormatter(log_fmt) -logger.addHandler(term_handler) -logger.setLevel(logging.INFO) +# term_handler = logging.StreamHandler(sys.stdout) +# term_handler.setFormatter(log_fmt) +# logger.addHandler(term_handler) +# logger.setLevel(logging.INFO) + + +class LessThanFilter(logging.Filter): + def __init__(self, exclusive_maximum, name=""): + super(LessThanFilter, self).__init__(name) + self.max_level = exclusive_maximum + + def filter(self, record): + # non-zero return means we log this message + return 1 if record.levelno < self.max_level else 0 + + +logger.setLevel(logging.NOTSET) + +logging_handler_out = logging.StreamHandler(sys.stdout) +logging_handler_out.setLevel(logging.DEBUG) +logging_handler_out.setFormatter(log_fmt) +logging_handler_out.addFilter(LessThanFilter(logging.WARNING)) +logger.addHandler(logging_handler_out) + +logging_handler_err = logging.StreamHandler(sys.stderr) +logging_handler_err.setLevel(logging.WARNING) +logging_handler_err.setFormatter(log_fmt) +logger.addHandler(logging_handler_err) # Define the context log format and attach it to the command logger as well diff --git a/bot/modules/__init__.py b/bot/modules/__init__.py index 427ec9f6..fcedb483 100644 --- a/bot/modules/__init__.py +++ b/bot/modules/__init__.py @@ -1,3 +1,9 @@ from .sysadmin import * +from .guild_admin import * +from .meta import * from .economy import * from .study import * +from .user_config import * +from .workout import * +from .todo import * +# from .moderation import * diff --git a/bot/modules/economy/__init__.py b/bot/modules/economy/__init__.py index 22c7b131..963fc7ff 100644 --- a/bot/modules/economy/__init__.py +++ b/bot/modules/economy/__init__.py @@ -1,2 +1,3 @@ -from . import module +from .module import module + from . import commands diff --git a/bot/modules/economy/commands.py b/bot/modules/economy/commands.py index 4b34f99a..a7d220c4 100644 --- a/bot/modules/economy/commands.py +++ b/bot/modules/economy/commands.py @@ -1,6 +1,8 @@ -from core import User -from core.tables import users +from cmdClient.checks import in_guild +import data +from data import tables +from core import Lion from utils import interactive # noqa from .module import module @@ -11,17 +13,19 @@ second_emoji = "🥈" third_emoji = "🥉" -# TODO: in_guild ward @module.cmd( - "topcoin", - short_help="View the LionCoin leaderboard.", - aliases=('topc', 'ctop') + "cointop", + group="Statistics", + desc="View the LionCoin leaderboard.", + aliases=('topc', 'ctop', 'topcoins', 'topcoin', 'cointop100'), + help_aliases={'cointop100': "View the LionCoin top 100."} ) +@in_guild() async def cmd_topcoin(ctx): """ Usage``: - {prefix}topcoin - {prefix}topcoin 100 + {prefix}cointop + {prefix}cointop 100 Description: Display the LionCoin leaderboard, or top 100. @@ -30,15 +34,17 @@ async def cmd_topcoin(ctx): # Handle args if ctx.args and not ctx.args == "100": return await ctx.error_reply( - "**Usage:**`{prefix}topcoin` or `{prefix}topcoin100`.".format(prefix=ctx.client.prefix) + "**Usage:**`{prefix}topcoin` or `{prefix}topcoin100`.".format(prefix=ctx.best_prefix) ) - top100 = ctx.args == "100" + top100 = (ctx.args == "100" or ctx.alias == "contop100") # Flush any pending coin transactions - User.sync() + Lion.sync() # Fetch the leaderboard - user_data = users.select_where( + user_data = tables.lions.select_where( + guildid=ctx.guild.id, + userid=data.NOT([m.id for m in ctx.guild_settings.unranked_roles.members]), select_columns=('userid', 'coins'), _extra="ORDER BY coins DESC " + ("LIMIT 100" if top100 else "") ) diff --git a/bot/modules/economy/module.py b/bot/modules/economy/module.py index e5192e7f..880fce2f 100644 --- a/bot/modules/economy/module.py +++ b/bot/modules/economy/module.py @@ -1,4 +1,4 @@ -from cmdClient import Module +from LionModule import LionModule -module = Module("Economy") +module = LionModule("Economy") diff --git a/bot/modules/meta/__init__.py b/bot/modules/meta/__init__.py index e69de29b..d1888a17 100644 --- a/bot/modules/meta/__init__.py +++ b/bot/modules/meta/__init__.py @@ -0,0 +1,3 @@ +from .module import module + +from . import help diff --git a/bot/modules/moderation/__init__.py b/bot/modules/moderation/__init__.py index e69de29b..52a384ed 100644 --- a/bot/modules/moderation/__init__.py +++ b/bot/modules/moderation/__init__.py @@ -0,0 +1,5 @@ +from .module import module + +from . import admin +# from . import video_channels +from . import Ticket diff --git a/bot/modules/study/__init__.py b/bot/modules/study/__init__.py index 7ef2765b..eec16e1d 100644 --- a/bot/modules/study/__init__.py +++ b/bot/modules/study/__init__.py @@ -1,2 +1,9 @@ from .module import module -from . import commands + +from . import data +from . import admin +from . import badge_tracker +from . import time_tracker +from . import top_cmd +from . import studybadge_cmd +from . import stats_cmd diff --git a/bot/modules/study/commands.py b/bot/modules/study/commands.py deleted file mode 100644 index 1e78a903..00000000 --- a/bot/modules/study/commands.py +++ /dev/null @@ -1,111 +0,0 @@ -import datetime as dt - -from core import User -from core.tables import users - -from utils import interactive # noqa - -from .module import module - - -first_emoji = "🥇" -second_emoji = "🥈" -third_emoji = "🥉" - - -# TODO: in_guild ward -@module.cmd( - "top", - short_help="View the Study Time leaderboard.", - aliases=('ttop', 'toptime') -) -async def cmd_top(ctx): - """ - Usage``: - {prefix}top - {prefix}top 100 - Description: - Display the study time leaderboard, or the top 100. - - Use the paging reactions or send `p` to switch pages (e.g. `p11` to switch to page 11). - """ - # Handle args - if ctx.args and not ctx.args == "100": - return await ctx.error_reply( - "**Usage:**`{prefix}top` or `{prefix}top100`.".format(prefix=ctx.client.prefix) - ) - top100 = ctx.args == "100" - - # Flush any pending coin transactions - User.sync() - - # Fetch the leaderboard - user_data = users.select_where( - select_columns=('userid', 'tracked_time'), - _extra="ORDER BY tracked_time DESC " + ("LIMIT 100" if top100 else "") - ) - - # Quit early if the leaderboard is empty - if not user_data: - return await ctx.reply("No leaderboard entries yet!") - - # Extract entries - author_index = None - entries = [] - for i, (userid, time) in enumerate(user_data): - member = ctx.guild.get_member(userid) - name = member.display_name if member else str(userid) - name = name.replace('*', ' ').replace('_', ' ') - - num_str = "{}.".format(i+1) - - hours = time // 3600 - minutes = time // 60 % 60 - seconds = time % 60 - - time_str = "{}:{:02}:{:02}".format( - hours, - minutes, - seconds - ) - - if ctx.author.id == userid: - author_index = i - - entries.append((num_str, name, time_str)) - - # Extract blocks - blocks = [entries[i:i+20] for i in range(0, len(entries), 20)] - block_count = len(blocks) - - # Build strings - header = "Study Time Top 100" if top100 else "Study Time Leaderboard" - if block_count > 1: - header += " (Page {{page}}/{})".format(block_count) - - # Build pages - pages = [] - for i, block in enumerate(blocks): - max_num_l, max_name_l, max_time_l = [max(len(e[i]) for e in block) for i in (0, 1, 2)] - body = '\n'.join( - "{:>{}} {:<{}} \t {:>{}} {} {}".format( - entry[0], max_num_l, - entry[1], max_name_l + 2, - entry[2], max_time_l + 1, - first_emoji if i == 0 and j == 0 else ( - second_emoji if i == 0 and j == 1 else ( - third_emoji if i == 0 and j == 2 else '' - ) - ), - "⮜" if author_index is not None and author_index == i * 20 + j else "" - ) - for j, entry in enumerate(block) - ) - title = header.format(page=i+1) - line = '='*len(title) - pages.append( - "```md\n{}\n{}\n{}```".format(title, line, body) - ) - - # Finally, page the results - await ctx.pager(pages, start_at=(author_index or 0)//20 if not top100 else 0) diff --git a/bot/modules/study/module.py b/bot/modules/study/module.py index 613746ca..ae88f7dd 100644 --- a/bot/modules/study/module.py +++ b/bot/modules/study/module.py @@ -1,4 +1,4 @@ -from cmdClient import Module +from LionModule import LionModule -module = Module("Study_Stats") +module = LionModule("Study_Stats") diff --git a/bot/modules/sysadmin/__init__.py b/bot/modules/sysadmin/__init__.py index 5bccdb52..5401a965 100644 --- a/bot/modules/sysadmin/__init__.py +++ b/bot/modules/sysadmin/__init__.py @@ -1 +1,3 @@ +from .module import module + from .exec_cmds import * diff --git a/bot/modules/sysadmin/exec_cmds.py b/bot/modules/sysadmin/exec_cmds.py index d3af7f24..40260f5b 100644 --- a/bot/modules/sysadmin/exec_cmds.py +++ b/bot/modules/sysadmin/exec_cmds.py @@ -5,35 +5,45 @@ import asyncio from cmdClient import cmd, checks +from core import Lion +from LionModule import LionModule + """ Exec level commands to manage the bot. Commands provided: async: Executes provided code in an async executor - exec: - Executes code using standard python exec eval: Executes code and awaits it if required """ -@cmd("reboot") +@cmd("shutdown", + desc="Sync data and shutdown.", + group="Bot Admin", + aliases=('restart', 'reboot')) @checks.is_owner() -async def cmd_reboot(ctx): +async def cmd_shutdown(ctx): """ Usage``: reboot Description: - Update the timer status save file and reboot the client. + Run unload tasks and shutdown/reboot. """ - ctx.client.interface.update_save("reboot") - ctx.client.interface.shutdown() - await ctx.reply("Saved state. Rebooting now!") + # Run module logout tasks + for module in ctx.client.modules: + if isinstance(module, LionModule): + await module.unload(ctx.client) + + # Reply and logout + await ctx.reply("All modules synced. Shutting down!") await ctx.client.close() -@cmd("async") +@cmd("async", + desc="Execute arbitrary code with `async`.", + group="Bot Admin") @checks.is_owner() async def cmd_async(ctx): """ @@ -55,7 +65,9 @@ async def cmd_async(ctx): output)) -@cmd("eval") +@cmd("eval", + desc="Execute arbitrary code with `eval`.", + group="Bot Admin") @checks.is_owner() async def cmd_eval(ctx): """ diff --git a/bot/modules/todo/__init__.py b/bot/modules/todo/__init__.py index e69de29b..df71f3a2 100644 --- a/bot/modules/todo/__init__.py +++ b/bot/modules/todo/__init__.py @@ -0,0 +1,6 @@ +from .module import module + +from . import Tasklist +from . import admin +from . import data +from . import commands diff --git a/bot/modules/workout/__init__.py b/bot/modules/workout/__init__.py index e69de29b..c209e42e 100644 --- a/bot/modules/workout/__init__.py +++ b/bot/modules/workout/__init__.py @@ -0,0 +1,5 @@ +from .module import module + +from . import admin +from . import data +from . import tracker diff --git a/bot/utils/lib.py b/bot/utils/lib.py index 82c48116..eb2416fe 100644 --- a/bot/utils/lib.py +++ b/bot/utils/lib.py @@ -4,6 +4,8 @@ import re import discord +from cmdClient.lib import SafeCancellation + def prop_tabulate(prop_list, value_list, indent=True): """ @@ -193,6 +195,25 @@ def parse_dur(time_str): return seconds +def strfdur(duration): + """ + Convert a duration given in seconds to a number of hours, minutes, and seconds. + """ + hours = duration // 3600 + minutes = duration // 60 % 60 + seconds = duration % 60 + + parts = [] + if hours: + parts.append('{}h'.format(hours)) + if minutes: + parts.append('{}m'.format(minutes)) + if seconds or duration == 0: + parts.append('{}s'.format(seconds)) + + return ' '.join(parts) + + def substitute_ranges(ranges_str, max_match=20, max_range=1000, separator=','): """ Substitutes a user provided list of numbers and ranges, @@ -213,12 +234,28 @@ def substitute_ranges(ranges_str, max_match=20, max_range=1000, separator=','): n1 = int(match.group(1)) n2 = int(match.group(2)) if n2 - n1 > max_range: - raise ValueError("Provided range exceeds the allowed maximum.") + raise SafeCancellation("Provided range is too large!") return separator.join(str(i) for i in range(n1, n2 + 1)) return re.sub(r'(\d+)\s*-\s*(\d+)', _repl, ranges_str, max_match) +def parse_ranges(ranges_str, ignore_errors=False, separator=',', **kwargs): + """ + Parses a user provided range string into a list of numbers. + Extra keyword arguments are transparently passed to the underlying parser `substitute_ranges`. + """ + substituted = substitute_ranges(ranges_str, separator=separator, **kwargs) + numbers = (item.strip() for item in substituted.split(',')) + numbers = [item for item in numbers if item] + integers = [int(item) for item in numbers if item.isdigit()] + + if not ignore_errors and len(integers) != len(numbers): + raise SafeCancellation("Couldn't parse the provided selection!") + + return integers + + def msg_string(msg, mask_link=False, line_break=False, tz=None, clean=True): """ Format a message into a string with various information, such as: diff --git a/bot/wards.py b/bot/wards.py new file mode 100644 index 00000000..d867ccbe --- /dev/null +++ b/bot/wards.py @@ -0,0 +1,24 @@ +from cmdClient import check +from cmdClient.checks import in_guild + +from data import tables + + +def is_guild_admin(member): + # First check guild admin permissions + admin = member.guild_permissions.administrator + + # Then check the admin role, if it is set + if not admin: + admin_role_id = tables.guild_config.fetch_or_create(member.guild.id).admin_role + admin = admin_role_id and (admin_role_id in (r.id for r in member.roles)) + return admin + + +@check( + name="ADMIN", + msg=("You need to be a server admin to do this!"), + requires=[in_guild] +) +async def guild_admin(ctx, *args, **kwargs): + return is_guild_admin(ctx.author) diff --git a/config/example-bot.conf b/config/example-bot.conf index e69de29b..52226759 100644 --- a/config/example-bot.conf +++ b/config/example-bot.conf @@ -0,0 +1,12 @@ +[DEFAULT] +log_file = bot.log +log_channel = + +prefix = ! +token = +owners = 413668234269818890, 389399222400712714 + +database = dbname=lionbot +data_appid = LionBot + +lion_sync_period = 60