Core bot framework
This commit is contained in:
2
bot/constants.py
Normal file
2
bot/constants.py
Normal file
@@ -0,0 +1,2 @@
|
||||
CONFIG_FILE = "config/bot.conf"
|
||||
DATA_VERSION = 0
|
||||
3
bot/data/__init__.py
Normal file
3
bot/data/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .data import *
|
||||
from . import tables
|
||||
# from . import queries
|
||||
30
bot/data/custom_cursor.py
Normal file
30
bot/data/custom_cursor.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import logging
|
||||
from psycopg2.extras import DictCursor, _ext
|
||||
|
||||
from meta import log
|
||||
|
||||
|
||||
class DictLoggingCursor(DictCursor):
|
||||
def log(self):
|
||||
msg = self.query
|
||||
if isinstance(msg, bytes):
|
||||
msg = msg.decode(_ext.encodings[self.connection.encoding], 'replace')
|
||||
|
||||
log(
|
||||
msg,
|
||||
context="DATABASE_QUERY",
|
||||
level=logging.DEBUG,
|
||||
post=False
|
||||
)
|
||||
|
||||
def execute(self, query, vars=None):
|
||||
try:
|
||||
return super().execute(query, vars)
|
||||
finally:
|
||||
self.log()
|
||||
|
||||
def callproc(self, procname, vars=None):
|
||||
try:
|
||||
return super().callproc(procname, vars)
|
||||
finally:
|
||||
self.log()
|
||||
450
bot/data/data.py
Normal file
450
bot/data/data.py
Normal file
@@ -0,0 +1,450 @@
|
||||
import logging
|
||||
import contextlib
|
||||
from itertools import chain
|
||||
from enum import Enum
|
||||
|
||||
import psycopg2 as psy
|
||||
from cachetools import LRUCache
|
||||
|
||||
from meta import log, conf
|
||||
from constants import DATA_VERSION
|
||||
from .custom_cursor import DictLoggingCursor
|
||||
|
||||
|
||||
# Set up database connection
|
||||
log("Establishing connection.", "DB_INIT", level=logging.DEBUG)
|
||||
conn = psy.connect(conf.bot['database'], cursor_factory=DictLoggingCursor)
|
||||
|
||||
# conn.set_trace_callback(lambda message: log(message, context="DB_CONNECTOR", level=logging.DEBUG))
|
||||
# sq.register_adapter(datetime, lambda dt: dt.timestamp())
|
||||
|
||||
|
||||
# Check the version matches the required version
|
||||
with conn:
|
||||
log("Checking db version.", "DB_INIT")
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get last entry in version table, compare against desired version
|
||||
cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
|
||||
current_version, _, _ = cursor.fetchone()
|
||||
|
||||
if current_version != DATA_VERSION:
|
||||
# Complain
|
||||
raise Exception(
|
||||
("Database version is {}, required version is {}. "
|
||||
"Please migrate database.").format(current_version, DATA_VERSION)
|
||||
)
|
||||
|
||||
cursor.close()
|
||||
|
||||
|
||||
log("Established connection.", "DB_INIT")
|
||||
|
||||
|
||||
# --------------- Data Interface Classes ---------------
|
||||
class Table:
|
||||
"""
|
||||
Transparent interface to a single table structure in the database.
|
||||
Contains standard methods to access the table.
|
||||
Intended to be subclassed to provide more derivative access for specific tables.
|
||||
"""
|
||||
conn = conn
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def select_where(self, *args, **kwargs):
|
||||
with self.conn:
|
||||
return select_where(self.name, *args, **kwargs)
|
||||
|
||||
def select_one_where(self, *args, **kwargs):
|
||||
with self.conn:
|
||||
rows = self.select_where(*args, **kwargs)
|
||||
return rows[0] if rows else None
|
||||
|
||||
def update_where(self, *args, **kwargs):
|
||||
with self.conn:
|
||||
return update_where(self.name, *args, **kwargs)
|
||||
|
||||
def delete_where(self, *args, **kwargs):
|
||||
with self.conn:
|
||||
return delete_where(self.name, *args, **kwargs)
|
||||
|
||||
def insert(self, *args, **kwargs):
|
||||
with self.conn:
|
||||
return insert(self.name, *args, **kwargs)
|
||||
|
||||
def insert_many(self, *args, **kwargs):
|
||||
with self.conn:
|
||||
return insert_many(self.name, *args, **kwargs)
|
||||
|
||||
def upsert(self, *args, **kwargs):
|
||||
with self.conn:
|
||||
return upsert(self.name, *args, **kwargs)
|
||||
|
||||
|
||||
class Row:
|
||||
__slots__ = ('table', 'data', '_pending')
|
||||
|
||||
conn = conn
|
||||
|
||||
def __init__(self, table, data, *args, **kwargs):
|
||||
super().__setattr__('table', table)
|
||||
self.data = data
|
||||
self._pending = None
|
||||
|
||||
@property
|
||||
def rowid(self):
|
||||
return self.data[self.table.id_col]
|
||||
|
||||
def __repr__(self):
|
||||
return "Row[{}]({})".format(
|
||||
self.table.name,
|
||||
', '.join("{}={!r}".format(field, getattr(self, field)) for field in self.table.columns)
|
||||
)
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key in self.table.columns:
|
||||
if self._pending and key in self._pending:
|
||||
return self._pending[key]
|
||||
else:
|
||||
return self.data[key]
|
||||
else:
|
||||
raise AttributeError(key)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key in self.table.columns:
|
||||
if self._pending is None:
|
||||
self.update(**{key: value})
|
||||
else:
|
||||
self._pending[key] = value
|
||||
else:
|
||||
super().__setattr__(key, value)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def batch_update(self):
|
||||
if self._pending:
|
||||
raise ValueError("Nested batch updates for {}!".format(self.__class__.__name__))
|
||||
|
||||
self._pending = {}
|
||||
try:
|
||||
yield self._pending
|
||||
finally:
|
||||
self.update(**self._pending)
|
||||
self._pending = None
|
||||
|
||||
def _refresh(self):
|
||||
row = self.table.select_one_where(**{self.table.id_col: self.rowid})
|
||||
if not row:
|
||||
raise ValueError("Refreshing a {} which no longer exists!".format(type(self).__name__))
|
||||
self.data = row
|
||||
|
||||
def update(self, **values):
|
||||
rows = self.table.update_where(values, **{self.table.id_col: self.rowid})
|
||||
self.data = rows[0]
|
||||
|
||||
@classmethod
|
||||
def _select_where(cls, _extra=None, **conditions):
|
||||
return select_where(cls._table, **conditions)
|
||||
|
||||
@classmethod
|
||||
def _insert(cls, **values):
|
||||
return insert(cls._table, **values)
|
||||
|
||||
@classmethod
|
||||
def _update_where(cls, values, **conditions):
|
||||
return update_where(cls._table, values, **conditions)
|
||||
|
||||
|
||||
class RowTable(Table):
|
||||
__slots__ = (
|
||||
'name',
|
||||
'columns',
|
||||
'id_col',
|
||||
'row_cache'
|
||||
)
|
||||
|
||||
conn = conn
|
||||
|
||||
def __init__(self, name, columns, id_col, use_cache=True, cache=None, cache_size=1000):
|
||||
self.name = name
|
||||
self.columns = columns
|
||||
self.id_col = id_col
|
||||
self.row_cache = (cache or LRUCache(cache_size)) if use_cache else None
|
||||
|
||||
# Extend original Table update methods to modify the cached rows
|
||||
def update_where(self, *args, **kwargs):
|
||||
data = super().update_where(*args, **kwargs)
|
||||
if self.row_cache is not None:
|
||||
for data_row in data:
|
||||
cached_row = self.row_cache.get(data_row[self.id_col], None)
|
||||
if cached_row is not None:
|
||||
cached_row.data = data_row
|
||||
return data
|
||||
|
||||
def delete_where(self, *args, **kwargs):
|
||||
data = super().delete_where(*args, **kwargs)
|
||||
if self.row_cache is not None:
|
||||
for data_row in data:
|
||||
self.row_cache.pop(data_row[self.id_col], None)
|
||||
return data
|
||||
|
||||
def upsert(self, *args, **kwargs):
|
||||
data = super().upsert(*args, **kwargs)
|
||||
if self.row_cache is not None:
|
||||
cached_row = self.row_cache.get(data[self.id_col], None)
|
||||
if cached_row is not None:
|
||||
cached_row.data = data
|
||||
return data
|
||||
|
||||
# New methods to fetch and create rows
|
||||
def _make_rows(self, *data_rows):
|
||||
"""
|
||||
Create or retrieve Row objects for each provided data row.
|
||||
If the rows already exist in cache, updates the cached row.
|
||||
"""
|
||||
if self.row_cache is not None:
|
||||
rows = []
|
||||
for data_row in data_rows:
|
||||
rowid = data_row[self.id_col]
|
||||
|
||||
cached_row = self.row_cache.get(rowid, None)
|
||||
if cached_row is not None:
|
||||
cached_row.data = data_row
|
||||
row = cached_row
|
||||
else:
|
||||
row = Row(self, data_row)
|
||||
self.row_cache[rowid] = row
|
||||
rows.append(row)
|
||||
else:
|
||||
rows = [Row(self, data_row) for data_row in data_rows]
|
||||
return rows
|
||||
|
||||
def create_row(self, *args, **kwargs):
|
||||
data = self.insert(*args, **kwargs)
|
||||
return self._make_rows(data)[0]
|
||||
|
||||
def fetch_rows_where(self, *args, **kwargs):
|
||||
# TODO: Handle list of rowids here?
|
||||
data = self.select_where(*args, **kwargs)
|
||||
return self._make_rows(*data)
|
||||
|
||||
def fetch(self, rowid):
|
||||
"""
|
||||
Fetch the row with the given id, retrieving from cache where possible.
|
||||
"""
|
||||
row = self.row_cache.get(rowid, None) if self.row_cache is not None else None
|
||||
if row is None:
|
||||
rows = self.fetch_rows_where(**{self.id_col: rowid})
|
||||
row = rows[0] if rows else None
|
||||
return row
|
||||
|
||||
def fetch_or_create(self, rowid=None, **kwargs):
|
||||
"""
|
||||
Helper method to fetch a row with the given id or fields, or create it if it doesn't exist.
|
||||
"""
|
||||
if rowid is not None:
|
||||
row = self.fetch(rowid)
|
||||
else:
|
||||
data = self.select_where(**kwargs)
|
||||
row = self._make_rows(data[0])[0] if data else None
|
||||
|
||||
if row is None:
|
||||
creation_kwargs = kwargs
|
||||
if rowid is not None:
|
||||
creation_kwargs[self.id_col] = rowid
|
||||
row = self.create_row(**creation_kwargs)
|
||||
return row
|
||||
|
||||
|
||||
# --------------- Query Builders ---------------
|
||||
def select_where(table, select_columns=None, cursor=None, _extra='', **conditions):
|
||||
"""
|
||||
Select rows from the given table matching the conditions
|
||||
"""
|
||||
criteria, criteria_values = _format_conditions(conditions)
|
||||
col_str = _format_selectkeys(select_columns)
|
||||
|
||||
if conditions:
|
||||
where_str = "WHERE {}".format(criteria)
|
||||
else:
|
||||
where_str = ""
|
||||
|
||||
cursor = cursor or conn.cursor()
|
||||
cursor.execute(
|
||||
'SELECT {} FROM {} {} {}'.format(col_str, table, where_str, _extra),
|
||||
criteria_values
|
||||
)
|
||||
return cursor.fetchall()
|
||||
|
||||
|
||||
def update_where(table, valuedict, cursor=None, **conditions):
|
||||
"""
|
||||
Update rows in the given table matching the conditions
|
||||
"""
|
||||
key_str, key_values = _format_updatestr(valuedict)
|
||||
criteria, criteria_values = _format_conditions(conditions)
|
||||
|
||||
if conditions:
|
||||
where_str = "WHERE {}".format(criteria)
|
||||
else:
|
||||
where_str = ""
|
||||
|
||||
cursor = cursor or conn.cursor()
|
||||
cursor.execute(
|
||||
'UPDATE {} SET {} {} RETURNING *'.format(table, key_str, where_str),
|
||||
tuple((*key_values, *criteria_values))
|
||||
)
|
||||
return cursor.fetchall()
|
||||
|
||||
|
||||
def delete_where(table, cursor=None, **conditions):
|
||||
"""
|
||||
Delete rows in the given table matching the conditions
|
||||
"""
|
||||
criteria, criteria_values = _format_conditions(conditions)
|
||||
|
||||
cursor = cursor or conn.cursor()
|
||||
cursor.execute(
|
||||
'DELETE FROM {} WHERE {}'.format(table, criteria),
|
||||
criteria_values
|
||||
)
|
||||
return cursor.fetchall()
|
||||
|
||||
|
||||
def insert(table, cursor=None, allow_replace=False, **values):
|
||||
"""
|
||||
Insert the given values into the table
|
||||
"""
|
||||
keys, values = zip(*values.items())
|
||||
|
||||
key_str = _format_insertkeys(keys)
|
||||
value_str, values = _format_insertvalues(values)
|
||||
|
||||
action = 'REPLACE' if allow_replace else 'INSERT'
|
||||
|
||||
cursor = cursor or conn.cursor()
|
||||
cursor.execute(
|
||||
'{} INTO {} {} VALUES {} RETURNING *'.format(action, table, key_str, value_str),
|
||||
values
|
||||
)
|
||||
return cursor.fetchone()
|
||||
|
||||
|
||||
def insert_many(table, *value_tuples, insert_keys=None, cursor=None):
|
||||
"""
|
||||
Insert all the given values into the table
|
||||
"""
|
||||
key_str = _format_insertkeys(insert_keys)
|
||||
value_strs, value_tuples = zip(*(_format_insertvalues(value_tuple) for value_tuple in value_tuples))
|
||||
|
||||
value_str = ", ".join(value_strs)
|
||||
values = tuple(chain(*value_tuples))
|
||||
|
||||
cursor = cursor or conn.cursor()
|
||||
cursor.execute(
|
||||
'INSERT INTO {} {} VALUES {} RETURNING *'.format(table, key_str, value_str),
|
||||
values
|
||||
)
|
||||
return cursor.fetchall()
|
||||
|
||||
|
||||
def upsert(table, constraint, cursor=None, **values):
|
||||
"""
|
||||
Insert or on conflict update.
|
||||
"""
|
||||
valuedict = values
|
||||
keys, values = zip(*values.items())
|
||||
|
||||
key_str = _format_insertkeys(keys)
|
||||
value_str, values = _format_insertvalues(values)
|
||||
update_key_str, update_key_values = _format_updatestr(valuedict)
|
||||
|
||||
if not isinstance(constraint, str):
|
||||
constraint = ", ".join(constraint)
|
||||
|
||||
cursor = cursor or conn.cursor()
|
||||
cursor.execute(
|
||||
'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format(
|
||||
table, key_str, value_str, constraint, update_key_str
|
||||
),
|
||||
tuple((*values, *update_key_values))
|
||||
)
|
||||
return cursor.fetchone()
|
||||
|
||||
|
||||
# --------------- Query Formatting Tools ---------------
|
||||
# Replace char used by the connection for query formatting
|
||||
_replace_char: str = '%s'
|
||||
|
||||
|
||||
class fieldConstants(Enum):
|
||||
"""
|
||||
A collection of database field constants to use for selection conditions.
|
||||
"""
|
||||
NULL = "IS NULL"
|
||||
NOTNULL = "IS NOT NULL"
|
||||
|
||||
|
||||
def _format_conditions(conditions):
|
||||
"""
|
||||
Formats a dictionary of conditions into a string suitable for 'WHERE' clauses.
|
||||
Supports `IN` type conditionals.
|
||||
"""
|
||||
if not conditions:
|
||||
return ("", tuple())
|
||||
|
||||
values = []
|
||||
conditional_strings = []
|
||||
for key, item in conditions.items():
|
||||
if isinstance(item, (list, tuple)):
|
||||
conditional_strings.append("{} IN ({})".format(key, ", ".join([_replace_char] * len(item))))
|
||||
values.extend(item)
|
||||
elif isinstance(item, fieldConstants):
|
||||
conditional_strings.append("{} {}".format(key, item.value))
|
||||
else:
|
||||
conditional_strings.append("{}={}".format(key, _replace_char))
|
||||
values.append(item)
|
||||
|
||||
return (' AND '.join(conditional_strings), values)
|
||||
|
||||
|
||||
def _format_selectkeys(keys):
|
||||
"""
|
||||
Formats a list of keys into a string suitable for `SELECT`.
|
||||
"""
|
||||
if not keys:
|
||||
return "*"
|
||||
else:
|
||||
return ", ".join(keys)
|
||||
|
||||
|
||||
def _format_insertkeys(keys):
|
||||
"""
|
||||
Formats a list of keys into a string suitable for `INSERT`
|
||||
"""
|
||||
if not keys:
|
||||
return ""
|
||||
else:
|
||||
return "({})".format(", ".join(keys))
|
||||
|
||||
|
||||
def _format_insertvalues(values):
|
||||
"""
|
||||
Formats a list of values into a string suitable for `INSERT`
|
||||
"""
|
||||
value_str = "({})".format(", ".join(_replace_char for value in values))
|
||||
return (value_str, values)
|
||||
|
||||
|
||||
def _format_updatestr(valuedict):
|
||||
"""
|
||||
Formats a dictionary of keys and values into a string suitable for 'SET' clauses.
|
||||
"""
|
||||
if not valuedict:
|
||||
return ("", tuple())
|
||||
keys, values = zip(*valuedict.items())
|
||||
|
||||
set_str = ", ".join("{} = {}".format(key, _replace_char) for key in keys)
|
||||
|
||||
return (set_str, values)
|
||||
8
bot/data/tables.py
Normal file
8
bot/data/tables.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from .data import RowTable, Table
|
||||
|
||||
raw_users = Table('Users')
|
||||
users = RowTable(
|
||||
'users',
|
||||
('userid', 'tracked_time', 'coins'),
|
||||
'userid',
|
||||
)
|
||||
7
bot/dev_main.py
Normal file
7
bot/dev_main.py
Normal file
@@ -0,0 +1,7 @@
|
||||
import logging
|
||||
import meta
|
||||
|
||||
meta.logger.setLevel(logging.DEBUG)
|
||||
logging.getLogger("discord").setLevel(logging.INFO)
|
||||
|
||||
import main # noqa
|
||||
12
bot/main.py
12
bot/main.py
@@ -0,0 +1,12 @@
|
||||
from meta import client, conf, log
|
||||
|
||||
import data # noqa
|
||||
|
||||
import modules # noqa
|
||||
|
||||
# Initialise all modules
|
||||
client.initialise_modules()
|
||||
|
||||
# Log readyness and execute
|
||||
log("Initial setup complete, logging in", context='SETUP')
|
||||
client.run(conf.bot['TOKEN'])
|
||||
|
||||
3
bot/meta/__init__.py
Normal file
3
bot/meta/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .client import client
|
||||
from .config import conf
|
||||
from .logger import log, logger
|
||||
13
bot/meta/client.py
Normal file
13
bot/meta/client.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from cmdClient.cmdClient import cmdClient
|
||||
|
||||
from .config import Conf
|
||||
|
||||
from constants import CONFIG_FILE
|
||||
|
||||
# Initialise config
|
||||
conf = Conf(CONFIG_FILE)
|
||||
|
||||
# Initialise client
|
||||
owners = [int(owner) for owner in conf.bot.getlist('owners')]
|
||||
client = cmdClient(prefix=conf.bot['prefix'], owners=owners)
|
||||
client.conf = conf
|
||||
59
bot/meta/config.py
Normal file
59
bot/meta/config.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import configparser as cfgp
|
||||
|
||||
|
||||
conf = None # type: Conf
|
||||
|
||||
CONF_FILE = "bot/bot.conf"
|
||||
|
||||
|
||||
class Conf:
|
||||
def __init__(self, configfile, section_name="DEFAULT"):
|
||||
self.configfile = configfile
|
||||
|
||||
self.config = cfgp.ConfigParser(
|
||||
converters={
|
||||
"intlist": self._getintlist,
|
||||
"list": self._getlist,
|
||||
}
|
||||
)
|
||||
self.config.read(configfile)
|
||||
|
||||
self.section_name = section_name if section_name in self.config else 'DEFAULT'
|
||||
|
||||
self.default = self.config["DEFAULT"]
|
||||
self.section = self.config[self.section_name]
|
||||
self.bot = self.section
|
||||
|
||||
# Config file recursion, read in configuration files specified in every "ALSO_READ" key.
|
||||
more_to_read = self.section.getlist("ALSO_READ", [])
|
||||
read = set()
|
||||
while more_to_read:
|
||||
to_read = more_to_read.pop(0)
|
||||
read.add(to_read)
|
||||
self.config.read(to_read)
|
||||
new_paths = [path for path in self.section.getlist("ALSO_READ", [])
|
||||
if path not in read and path not in more_to_read]
|
||||
more_to_read.extend(new_paths)
|
||||
|
||||
global conf
|
||||
conf = self
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.section[key].strip()
|
||||
|
||||
def __getattr__(self, section):
|
||||
return self.config[section]
|
||||
|
||||
def get(self, name, fallback=None):
|
||||
result = self.section.get(name, fallback)
|
||||
return result.strip() if result else result
|
||||
|
||||
def _getintlist(self, value):
|
||||
return [int(item.strip()) for item in value.split(',')]
|
||||
|
||||
def _getlist(self, value):
|
||||
return [item.strip() for item in value.split(',')]
|
||||
|
||||
def write(self):
|
||||
with open(self.configfile, 'w') as conffile:
|
||||
self.config.write(conffile)
|
||||
72
bot/meta/logger.py
Normal file
72
bot/meta/logger.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import sys
|
||||
import logging
|
||||
import asyncio
|
||||
from discord import AllowedMentions
|
||||
|
||||
from cmdClient.logger import cmd_log_handler
|
||||
|
||||
from utils.lib import mail, split_text
|
||||
|
||||
from .client import client
|
||||
from .config import conf
|
||||
|
||||
|
||||
# Setup the logger
|
||||
logger = logging.getLogger()
|
||||
log_fmt = logging.Formatter(fmt='[{asctime}][{levelname:^8}] {message}', datefmt='%d/%m | %H:%M:%S', style='{')
|
||||
term_handler = logging.StreamHandler(sys.stdout)
|
||||
term_handler.setFormatter(log_fmt)
|
||||
logger.addHandler(term_handler)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
# Define the context log format and attach it to the command logger as well
|
||||
@cmd_log_handler
|
||||
def log(message, context="GLOBAL", level=logging.INFO, post=True):
|
||||
# Add prefixes to lines for better parsing capability
|
||||
lines = message.splitlines()
|
||||
if len(lines) > 1:
|
||||
lines = [
|
||||
'┌ ' * (i == 0) + '│ ' * (0 < i < len(lines) - 1) + '└ ' * (i == len(lines) - 1) + line
|
||||
for i, line in enumerate(lines)
|
||||
]
|
||||
else:
|
||||
lines = ['─ ' + message]
|
||||
|
||||
for line in lines:
|
||||
logger.log(level, '\b[{}] {}'.format(
|
||||
str(context).center(22, '='),
|
||||
line
|
||||
))
|
||||
|
||||
# Fire and forget to the channel logger, if it is set up
|
||||
if post and client.is_ready():
|
||||
asyncio.ensure_future(live_log(message, context, level))
|
||||
|
||||
|
||||
# Live logger that posts to the logging channels
|
||||
async def live_log(message, context, level):
|
||||
if level >= logging.INFO:
|
||||
log_chid = conf.bot.getint('log_channel')
|
||||
|
||||
# Generate the log messages
|
||||
header = "[{}][{}]".format(logging.getLevelName(level), str(context))
|
||||
if len(message) > 1900:
|
||||
blocks = split_text(message, blocksize=1900, code=False)
|
||||
else:
|
||||
blocks = [message]
|
||||
|
||||
if len(blocks) > 1:
|
||||
blocks = [
|
||||
"```md\n{}[{}/{}]\n{}\n```".format(header, i+1, len(blocks), block) for i, block in enumerate(blocks)
|
||||
]
|
||||
else:
|
||||
blocks = ["```md\n{}\n{}\n```".format(header, blocks[0])]
|
||||
|
||||
# Post the log messages
|
||||
if log_chid:
|
||||
[await mail(client, log_chid, content=block, allowed_mentions=AllowedMentions.none()) for block in blocks]
|
||||
|
||||
|
||||
# Attach logger to client, for convenience
|
||||
client.log = log
|
||||
@@ -0,0 +1 @@
|
||||
from .sysadmin import *
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from .exec_cmds import *
|
||||
|
||||
123
bot/modules/sysadmin/exec_cmds.py
Normal file
123
bot/modules/sysadmin/exec_cmds.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import sys
|
||||
from io import StringIO
|
||||
import traceback
|
||||
import asyncio
|
||||
|
||||
from cmdClient import cmd, checks
|
||||
|
||||
"""
|
||||
Exec level commands to manage the bot.
|
||||
|
||||
Commands provided:
|
||||
async:
|
||||
Executes provided code in an async executor
|
||||
exec:
|
||||
Executes code using standard python exec
|
||||
eval:
|
||||
Executes code and awaits it if required
|
||||
"""
|
||||
|
||||
|
||||
@cmd("reboot")
|
||||
@checks.is_owner()
|
||||
async def cmd_reboot(ctx):
|
||||
"""
|
||||
Usage``:
|
||||
reboot
|
||||
Description:
|
||||
Update the timer status save file and reboot the client.
|
||||
"""
|
||||
ctx.client.interface.update_save("reboot")
|
||||
ctx.client.interface.shutdown()
|
||||
await ctx.reply("Saved state. Rebooting now!")
|
||||
await ctx.client.close()
|
||||
|
||||
|
||||
@cmd("async")
|
||||
@checks.is_owner()
|
||||
async def cmd_async(ctx):
|
||||
"""
|
||||
Usage:
|
||||
{prefix}async <code>
|
||||
Description:
|
||||
Runs <code> as an asynchronous coroutine and prints the output or error.
|
||||
"""
|
||||
if ctx.arg_str == "":
|
||||
await ctx.error_reply("You must give me something to run!")
|
||||
return
|
||||
output, error = await _async(ctx)
|
||||
await ctx.reply(
|
||||
"**Async input:**\
|
||||
\n```py\n{}\n```\
|
||||
\n**Output {}:** \
|
||||
\n```py\n{}\n```".format(ctx.arg_str,
|
||||
"error" if error else "",
|
||||
output))
|
||||
|
||||
|
||||
@cmd("eval")
|
||||
@checks.is_owner()
|
||||
async def cmd_eval(ctx):
|
||||
"""
|
||||
Usage:
|
||||
{prefix}eval <code>
|
||||
Description:
|
||||
Runs <code> in current environment using eval() and prints the output or error.
|
||||
"""
|
||||
if ctx.arg_str == "":
|
||||
await ctx.error_reply("You must give me something to run!")
|
||||
return
|
||||
output, error = await _eval(ctx)
|
||||
await ctx.reply(
|
||||
"**Eval input:**\
|
||||
\n```py\n{}\n```\
|
||||
\n**Output {}:** \
|
||||
\n```py\n{}\n```".format(ctx.arg_str,
|
||||
"error" if error else "",
|
||||
output)
|
||||
)
|
||||
|
||||
|
||||
async def _eval(ctx):
|
||||
output = None
|
||||
try:
|
||||
output = eval(ctx.arg_str)
|
||||
except Exception:
|
||||
return (str(traceback.format_exc()), 1)
|
||||
if asyncio.iscoroutine(output):
|
||||
output = await output
|
||||
return (output, 0)
|
||||
|
||||
|
||||
async def _async(ctx):
|
||||
env = {
|
||||
'ctx': ctx,
|
||||
'client': ctx.client,
|
||||
'message': ctx.msg,
|
||||
'arg_str': ctx.arg_str
|
||||
}
|
||||
env.update(globals())
|
||||
old_stdout = sys.stdout
|
||||
redirected_output = sys.stdout = StringIO()
|
||||
result = None
|
||||
exec_string = "async def _temp_exec():\n"
|
||||
exec_string += '\n'.join(' ' * 4 + line for line in ctx.arg_str.split('\n'))
|
||||
try:
|
||||
exec(exec_string, env)
|
||||
result = (redirected_output.getvalue(), 0)
|
||||
except Exception:
|
||||
result = (str(traceback.format_exc()), 1)
|
||||
return result
|
||||
_temp_exec = env['_temp_exec']
|
||||
try:
|
||||
returnval = await _temp_exec()
|
||||
value = redirected_output.getvalue()
|
||||
if returnval is None:
|
||||
result = (value, 0)
|
||||
else:
|
||||
result = (value + '\n' + str(returnval), 0)
|
||||
except Exception:
|
||||
result = (str(traceback.format_exc()), 1)
|
||||
finally:
|
||||
sys.stdout = old_stdout
|
||||
return result
|
||||
443
bot/utils/lib.py
Normal file
443
bot/utils/lib.py
Normal file
@@ -0,0 +1,443 @@
|
||||
import datetime
|
||||
import iso8601
|
||||
import re
|
||||
|
||||
import discord
|
||||
|
||||
|
||||
def prop_tabulate(prop_list, value_list, indent=True):
|
||||
"""
|
||||
Turns a list of properties and corresponding list of values into
|
||||
a pretty string with one `prop: value` pair each line,
|
||||
padded so that the colons in each line are lined up.
|
||||
Handles empty props by using an extra couple of spaces instead of a `:`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
prop_list: List[str]
|
||||
List of short names to put on the right side of the list.
|
||||
Empty props are considered to be "newlines" for the corresponding value.
|
||||
value_list: List[str]
|
||||
List of values corresponding to the properties above.
|
||||
indent: bool
|
||||
Whether to add padding so the properties are right-adjusted.
|
||||
|
||||
Returns: str
|
||||
"""
|
||||
max_len = max(len(prop) for prop in prop_list)
|
||||
return "".join(["`{}{}{}`\t{}{}".format(" " * (max_len - len(prop)) if indent else "",
|
||||
prop,
|
||||
":" if len(prop) else " " * 2,
|
||||
value_list[i],
|
||||
'' if str(value_list[i]).endswith("```") else '\n')
|
||||
for i, prop in enumerate(prop_list)])
|
||||
|
||||
|
||||
def paginate_list(item_list, block_length=20, style="markdown", title=None):
|
||||
"""
|
||||
Create pretty codeblock pages from a list of strings.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
item_list: List[str]
|
||||
List of strings to paginate.
|
||||
block_length: int
|
||||
Maximum number of strings per page.
|
||||
style: str
|
||||
Codeblock style to use.
|
||||
Title formatting assumes the `markdown` style, and numbered lists work well with this.
|
||||
However, `markdown` sometimes messes up formatting in the list.
|
||||
title: str
|
||||
Optional title to add to the top of each page.
|
||||
|
||||
Returns: List[str]
|
||||
List of pages, each formatted into a codeblock,
|
||||
and containing at most `block_length` of the provided strings.
|
||||
"""
|
||||
lines = ["{0:<5}{1:<5}".format("{}.".format(i + 1), str(line)) for i, line in enumerate(item_list)]
|
||||
page_blocks = [lines[i:i + block_length] for i in range(0, len(lines), block_length)]
|
||||
pages = []
|
||||
for i, block in enumerate(page_blocks):
|
||||
pagenum = "Page {}/{}".format(i + 1, len(page_blocks))
|
||||
if title:
|
||||
header = "{} ({})".format(title, pagenum) if len(page_blocks) > 1 else title
|
||||
else:
|
||||
header = pagenum
|
||||
header_line = "=" * len(header)
|
||||
full_header = "{}\n{}\n".format(header, header_line) if len(page_blocks) > 1 or title else ""
|
||||
pages.append("```{}\n{}{}```".format(style, full_header, "\n".join(block)))
|
||||
return pages
|
||||
|
||||
|
||||
def timestamp_utcnow():
|
||||
"""
|
||||
Return the current integer UTC timestamp.
|
||||
"""
|
||||
return int(datetime.datetime.timestamp(datetime.datetime.utcnow()))
|
||||
|
||||
|
||||
def split_text(text, blocksize=2000, code=True, syntax="", maxheight=50):
|
||||
"""
|
||||
Break the text into blocks of maximum length blocksize
|
||||
If possible, break across nearby newlines. Otherwise just break at blocksize chars
|
||||
|
||||
Parameters
|
||||
----------
|
||||
text: str
|
||||
Text to break into blocks.
|
||||
blocksize: int
|
||||
Maximum character length for each block.
|
||||
code: bool
|
||||
Whether to wrap each block in codeblocks (these are counted in the blocksize).
|
||||
syntax: str
|
||||
The markdown formatting language to use for the codeblocks, if applicable.
|
||||
maxheight: int
|
||||
The maximum number of lines in each block
|
||||
|
||||
Returns: List[str]
|
||||
List of blocks,
|
||||
each containing at most `block_size` characters,
|
||||
of height at most `maxheight`.
|
||||
"""
|
||||
# Adjust blocksize to account for the codeblocks if required
|
||||
blocksize = blocksize - 8 - len(syntax) if code else blocksize
|
||||
|
||||
# Build the blocks
|
||||
blocks = []
|
||||
while True:
|
||||
# If the remaining text is already small enough, append it
|
||||
if len(text) <= blocksize:
|
||||
blocks.append(text)
|
||||
break
|
||||
text = text.strip('\n')
|
||||
|
||||
# Find the last newline in the prototype block
|
||||
split_on = text[0:blocksize].rfind('\n')
|
||||
split_on = blocksize if split_on < blocksize // 5 else split_on
|
||||
|
||||
# Add the block and truncate the text
|
||||
blocks.append(text[0:split_on])
|
||||
text = text[split_on:]
|
||||
|
||||
# Add the codeblock ticks and the code syntax header, if required
|
||||
if code:
|
||||
blocks = ["```{}\n{}\n```".format(syntax, block) for block in blocks]
|
||||
|
||||
return blocks
|
||||
|
||||
|
||||
def strfdelta(delta, sec=False, minutes=True, short=False):
|
||||
"""
|
||||
Convert a datetime.timedelta object into an easily readable duration string.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
delta: datetime.timedelta
|
||||
The timedelta object to convert into a readable string.
|
||||
sec: bool
|
||||
Whether to include the seconds from the timedelta object in the string.
|
||||
minutes: bool
|
||||
Whether to include the minutes from the timedelta object in the string.
|
||||
short: bool
|
||||
Whether to abbreviate the units of time ("hour" to "h", "minute" to "m", "second" to "s").
|
||||
|
||||
Returns: str
|
||||
A string containing a time from the datetime.timedelta object, in a readable format.
|
||||
Time units will be abbreviated if short was set to True.
|
||||
"""
|
||||
|
||||
output = [[delta.days, 'd' if short else ' day'],
|
||||
[delta.seconds // 3600, 'h' if short else ' hour']]
|
||||
if minutes:
|
||||
output.append([delta.seconds // 60 % 60, 'm' if short else ' minute'])
|
||||
if sec:
|
||||
output.append([delta.seconds % 60, 's' if short else ' second'])
|
||||
for i in range(len(output)):
|
||||
if output[i][0] != 1 and not short:
|
||||
output[i][1] += 's'
|
||||
reply_msg = []
|
||||
if output[0][0] != 0:
|
||||
reply_msg.append("{}{} ".format(output[0][0], output[0][1]))
|
||||
if output[0][0] != 0 or output[1][0] != 0 or len(output) == 2:
|
||||
reply_msg.append("{}{} ".format(output[1][0], output[1][1]))
|
||||
for i in range(2, len(output) - 1):
|
||||
reply_msg.append("{}{} ".format(output[i][0], output[i][1]))
|
||||
if not short and reply_msg:
|
||||
reply_msg.append("and ")
|
||||
reply_msg.append("{}{}".format(output[-1][0], output[-1][1]))
|
||||
return "".join(reply_msg)
|
||||
|
||||
|
||||
def parse_dur(time_str):
|
||||
"""
|
||||
Parses a user provided time duration string into a timedelta object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
time_str: str
|
||||
The time string to parse. String can include days, hours, minutes, and seconds.
|
||||
|
||||
Returns: int
|
||||
The number of seconds the duration represents.
|
||||
"""
|
||||
funcs = {'d': lambda x: x * 24 * 60 * 60,
|
||||
'h': lambda x: x * 60 * 60,
|
||||
'm': lambda x: x * 60,
|
||||
's': lambda x: x}
|
||||
time_str = time_str.strip(" ,")
|
||||
found = re.findall(r'(\d+)\s?(\w+?)', time_str)
|
||||
seconds = 0
|
||||
for bit in found:
|
||||
if bit[1] in funcs:
|
||||
seconds += funcs[bit[1]](int(bit[0]))
|
||||
return seconds
|
||||
|
||||
|
||||
def substitute_ranges(ranges_str, max_match=20, max_range=1000, separator=','):
|
||||
"""
|
||||
Substitutes a user provided list of numbers and ranges,
|
||||
and replaces the ranges by the corresponding list of numbers.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ranges_str: str
|
||||
The string to ranges in.
|
||||
max_match: int
|
||||
The maximum number of ranges to replace.
|
||||
Any ranges exceeding this will be ignored.
|
||||
max_range: int
|
||||
The maximum length of range to replace.
|
||||
Attempting to replace a range longer than this will raise a `ValueError`.
|
||||
"""
|
||||
def _repl(match):
|
||||
n1 = int(match.group(1))
|
||||
n2 = int(match.group(2))
|
||||
if n2 - n1 > max_range:
|
||||
raise ValueError("Provided range exceeds the allowed maximum.")
|
||||
return separator.join(str(i) for i in range(n1, n2 + 1))
|
||||
|
||||
return re.sub(r'(\d+)\s*-\s*(\d+)', _repl, ranges_str, max_match)
|
||||
|
||||
|
||||
def msg_string(msg, mask_link=False, line_break=False, tz=None, clean=True):
|
||||
"""
|
||||
Format a message into a string with various information, such as:
|
||||
the timestamp of the message, author, message content, and attachments.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
msg: Message
|
||||
The message to format.
|
||||
mask_link: bool
|
||||
Whether to mask the URLs of any attachments.
|
||||
line_break: bool
|
||||
Whether a line break should be used in the string.
|
||||
tz: Timezone
|
||||
The timezone to use in the formatted message.
|
||||
clean: bool
|
||||
Whether to use the clean content of the original message.
|
||||
|
||||
Returns: str
|
||||
A formatted string containing various information:
|
||||
User timezone, message author, message content, attachments
|
||||
"""
|
||||
timestr = "%I:%M %p, %d/%m/%Y"
|
||||
if tz:
|
||||
time = iso8601.parse_date(msg.timestamp.isoformat()).astimezone(tz).strftime(timestr)
|
||||
else:
|
||||
time = msg.timestamp.strftime(timestr)
|
||||
user = str(msg.author)
|
||||
attach_list = [attach["url"] for attach in msg.attachments if "url" in attach]
|
||||
if mask_link:
|
||||
attach_list = ["[Link]({})".format(url) for url in attach_list]
|
||||
attachments = "\nAttachments: {}".format(", ".join(attach_list)) if attach_list else ""
|
||||
return "`[{time}]` **{user}:** {line_break}{message} {attachments}".format(
|
||||
time=time,
|
||||
user=user,
|
||||
line_break="\n" if line_break else "",
|
||||
message=msg.clean_content if clean else msg.content,
|
||||
attachments=attachments
|
||||
)
|
||||
|
||||
|
||||
def convdatestring(datestring):
|
||||
"""
|
||||
Convert a date string into a datetime.timedelta object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
datestring: str
|
||||
The string to convert to a datetime.timedelta object.
|
||||
|
||||
Returns: datetime.timedelta
|
||||
A datetime.timedelta object formed from the string provided.
|
||||
"""
|
||||
datestring = datestring.strip(' ,')
|
||||
datearray = []
|
||||
funcs = {'d': lambda x: x * 24 * 60 * 60,
|
||||
'h': lambda x: x * 60 * 60,
|
||||
'm': lambda x: x * 60,
|
||||
's': lambda x: x}
|
||||
currentnumber = ''
|
||||
for char in datestring:
|
||||
if char.isdigit():
|
||||
currentnumber += char
|
||||
else:
|
||||
if currentnumber == '':
|
||||
continue
|
||||
datearray.append((int(currentnumber), char))
|
||||
currentnumber = ''
|
||||
seconds = 0
|
||||
if currentnumber:
|
||||
seconds += int(currentnumber)
|
||||
for i in datearray:
|
||||
if i[1] in funcs:
|
||||
seconds += funcs[i[1]](i[0])
|
||||
return datetime.timedelta(seconds=seconds)
|
||||
|
||||
|
||||
class _rawChannel(discord.abc.Messageable):
|
||||
"""
|
||||
Raw messageable class representing an arbitrary channel,
|
||||
not necessarially seen by the gateway.
|
||||
"""
|
||||
def __init__(self, state, id):
|
||||
self._state = state
|
||||
self.id = id
|
||||
|
||||
async def _get_channel(self):
|
||||
return discord.Object(self.id)
|
||||
|
||||
|
||||
async def mail(client: discord.Client, channelid: int, **msg_args):
|
||||
"""
|
||||
Mails a message to a channelid which may be invisible to the gateway.
|
||||
|
||||
Parameters:
|
||||
client: discord.Client
|
||||
The client to use for mailing.
|
||||
Must at least have static authentication and have a valid `_connection`.
|
||||
channelid: int
|
||||
The channel id to mail to.
|
||||
msg_args: Any
|
||||
Message keyword arguments which are passed transparently to `_rawChannel.send(...)`.
|
||||
"""
|
||||
# Create the raw channel
|
||||
channel = _rawChannel(client._connection, channelid)
|
||||
return await channel.send(**msg_args)
|
||||
|
||||
|
||||
def emb_add_fields(embed, emb_fields):
|
||||
"""
|
||||
Append embed fields to an embed.
|
||||
Parameters
|
||||
----------
|
||||
embed: discord.Embed
|
||||
The embed to add the field to.
|
||||
emb_fields: tuple
|
||||
The values to add to a field.
|
||||
name: str
|
||||
The name of the field.
|
||||
value: str
|
||||
The value of the field.
|
||||
inline: bool
|
||||
Whether the embed field should be inline or not.
|
||||
"""
|
||||
for field in emb_fields:
|
||||
embed.add_field(name=str(field[0]), value=str(field[1]), inline=bool(field[2]))
|
||||
|
||||
|
||||
def join_list(string, nfs=False):
|
||||
"""
|
||||
Join a list together, separated with commas, plus add "and" to the beginning of the last value.
|
||||
Parameters
|
||||
----------
|
||||
string: list
|
||||
The list to join together.
|
||||
nfs: bool
|
||||
(no fullstops)
|
||||
Whether to exclude fullstops/periods from the output messages.
|
||||
If not provided, fullstops will be appended to the output.
|
||||
"""
|
||||
if len(string) > 1:
|
||||
return "{}{} and {}{}".format((", ").join(string[:-1]),
|
||||
"," if len(string) > 2 else "", string[-1], "" if nfs else ".")
|
||||
else:
|
||||
return "{}{}".format("".join(string), "" if nfs else ".")
|
||||
|
||||
|
||||
def format_activity(user):
|
||||
"""
|
||||
Format a user's activity string, depending on the type of activity.
|
||||
Currently supported types are:
|
||||
- Nothing
|
||||
- Custom status
|
||||
- Playing (with rich presence support)
|
||||
- Streaming
|
||||
- Listening (with rich presence support)
|
||||
- Watching
|
||||
- Unknown
|
||||
Parameters
|
||||
----------
|
||||
user: discord.Member
|
||||
The user to format the status of.
|
||||
If the user has no activity, "Nothing" will be returned.
|
||||
|
||||
Returns: str
|
||||
A formatted string with various information about the user's current activity like the name,
|
||||
and any extra information about the activity (such as current song artists for Spotify)
|
||||
"""
|
||||
if not user.activity:
|
||||
return "Nothing"
|
||||
|
||||
AT = user.activity.type
|
||||
a = user.activity
|
||||
if str(AT) == "ActivityType.custom":
|
||||
return "Status: {}".format(a)
|
||||
|
||||
if str(AT) == "ActivityType.playing":
|
||||
string = "Playing {}".format(a.name)
|
||||
try:
|
||||
string += " ({})".format(a.details)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return string
|
||||
|
||||
if str(AT) == "ActivityType.streaming":
|
||||
return "Streaming {}".format(a.name)
|
||||
|
||||
if str(AT) == "ActivityType.listening":
|
||||
try:
|
||||
string = "Listening to `{}`".format(a.title)
|
||||
if len(a.artists) > 1:
|
||||
string += " by {}".format(join_list(string=a.artists))
|
||||
else:
|
||||
string += " by **{}**".format(a.artist)
|
||||
except Exception:
|
||||
string = "Listening to `{}`".format(a.name)
|
||||
return string
|
||||
|
||||
if str(AT) == "ActivityType.watching":
|
||||
return "Watching `{}`".format(a.name)
|
||||
|
||||
if str(AT) == "ActivityType.unknown":
|
||||
return "Unknown"
|
||||
|
||||
|
||||
def shard_of(shard_count: int, guildid: int):
|
||||
"""
|
||||
Calculate the shard number of a given guild.
|
||||
"""
|
||||
return (guildid >> 22) % shard_count if shard_count and shard_count > 0 else 0
|
||||
|
||||
|
||||
def jumpto(guildid: int, channeldid: int, messageid: int):
|
||||
"""
|
||||
Build a jump link for a message given its location.
|
||||
"""
|
||||
return 'https://discord.com/channels/{}/{}/{}'.format(
|
||||
guildid,
|
||||
channeldid,
|
||||
messageid
|
||||
)
|
||||
Reference in New Issue
Block a user