Core bot framework

This commit is contained in:
2021-08-25 22:56:45 +03:00
parent 87f16b6a37
commit 05cb9650ee
17 changed files with 1233 additions and 0 deletions

2
bot/constants.py Normal file
View File

@@ -0,0 +1,2 @@
CONFIG_FILE = "config/bot.conf"
DATA_VERSION = 0

3
bot/data/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .data import *
from . import tables
# from . import queries

30
bot/data/custom_cursor.py Normal file
View File

@@ -0,0 +1,30 @@
import logging
from psycopg2.extras import DictCursor, _ext
from meta import log
class DictLoggingCursor(DictCursor):
def log(self):
msg = self.query
if isinstance(msg, bytes):
msg = msg.decode(_ext.encodings[self.connection.encoding], 'replace')
log(
msg,
context="DATABASE_QUERY",
level=logging.DEBUG,
post=False
)
def execute(self, query, vars=None):
try:
return super().execute(query, vars)
finally:
self.log()
def callproc(self, procname, vars=None):
try:
return super().callproc(procname, vars)
finally:
self.log()

450
bot/data/data.py Normal file
View File

@@ -0,0 +1,450 @@
import logging
import contextlib
from itertools import chain
from enum import Enum
import psycopg2 as psy
from cachetools import LRUCache
from meta import log, conf
from constants import DATA_VERSION
from .custom_cursor import DictLoggingCursor
# Set up database connection
log("Establishing connection.", "DB_INIT", level=logging.DEBUG)
conn = psy.connect(conf.bot['database'], cursor_factory=DictLoggingCursor)
# conn.set_trace_callback(lambda message: log(message, context="DB_CONNECTOR", level=logging.DEBUG))
# sq.register_adapter(datetime, lambda dt: dt.timestamp())
# Check the version matches the required version
with conn:
log("Checking db version.", "DB_INIT")
cursor = conn.cursor()
# Get last entry in version table, compare against desired version
cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
current_version, _, _ = cursor.fetchone()
if current_version != DATA_VERSION:
# Complain
raise Exception(
("Database version is {}, required version is {}. "
"Please migrate database.").format(current_version, DATA_VERSION)
)
cursor.close()
log("Established connection.", "DB_INIT")
# --------------- Data Interface Classes ---------------
class Table:
"""
Transparent interface to a single table structure in the database.
Contains standard methods to access the table.
Intended to be subclassed to provide more derivative access for specific tables.
"""
conn = conn
def __init__(self, name):
self.name = name
def select_where(self, *args, **kwargs):
with self.conn:
return select_where(self.name, *args, **kwargs)
def select_one_where(self, *args, **kwargs):
with self.conn:
rows = self.select_where(*args, **kwargs)
return rows[0] if rows else None
def update_where(self, *args, **kwargs):
with self.conn:
return update_where(self.name, *args, **kwargs)
def delete_where(self, *args, **kwargs):
with self.conn:
return delete_where(self.name, *args, **kwargs)
def insert(self, *args, **kwargs):
with self.conn:
return insert(self.name, *args, **kwargs)
def insert_many(self, *args, **kwargs):
with self.conn:
return insert_many(self.name, *args, **kwargs)
def upsert(self, *args, **kwargs):
with self.conn:
return upsert(self.name, *args, **kwargs)
class Row:
__slots__ = ('table', 'data', '_pending')
conn = conn
def __init__(self, table, data, *args, **kwargs):
super().__setattr__('table', table)
self.data = data
self._pending = None
@property
def rowid(self):
return self.data[self.table.id_col]
def __repr__(self):
return "Row[{}]({})".format(
self.table.name,
', '.join("{}={!r}".format(field, getattr(self, field)) for field in self.table.columns)
)
def __getattr__(self, key):
if key in self.table.columns:
if self._pending and key in self._pending:
return self._pending[key]
else:
return self.data[key]
else:
raise AttributeError(key)
def __setattr__(self, key, value):
if key in self.table.columns:
if self._pending is None:
self.update(**{key: value})
else:
self._pending[key] = value
else:
super().__setattr__(key, value)
@contextlib.contextmanager
def batch_update(self):
if self._pending:
raise ValueError("Nested batch updates for {}!".format(self.__class__.__name__))
self._pending = {}
try:
yield self._pending
finally:
self.update(**self._pending)
self._pending = None
def _refresh(self):
row = self.table.select_one_where(**{self.table.id_col: self.rowid})
if not row:
raise ValueError("Refreshing a {} which no longer exists!".format(type(self).__name__))
self.data = row
def update(self, **values):
rows = self.table.update_where(values, **{self.table.id_col: self.rowid})
self.data = rows[0]
@classmethod
def _select_where(cls, _extra=None, **conditions):
return select_where(cls._table, **conditions)
@classmethod
def _insert(cls, **values):
return insert(cls._table, **values)
@classmethod
def _update_where(cls, values, **conditions):
return update_where(cls._table, values, **conditions)
class RowTable(Table):
__slots__ = (
'name',
'columns',
'id_col',
'row_cache'
)
conn = conn
def __init__(self, name, columns, id_col, use_cache=True, cache=None, cache_size=1000):
self.name = name
self.columns = columns
self.id_col = id_col
self.row_cache = (cache or LRUCache(cache_size)) if use_cache else None
# Extend original Table update methods to modify the cached rows
def update_where(self, *args, **kwargs):
data = super().update_where(*args, **kwargs)
if self.row_cache is not None:
for data_row in data:
cached_row = self.row_cache.get(data_row[self.id_col], None)
if cached_row is not None:
cached_row.data = data_row
return data
def delete_where(self, *args, **kwargs):
data = super().delete_where(*args, **kwargs)
if self.row_cache is not None:
for data_row in data:
self.row_cache.pop(data_row[self.id_col], None)
return data
def upsert(self, *args, **kwargs):
data = super().upsert(*args, **kwargs)
if self.row_cache is not None:
cached_row = self.row_cache.get(data[self.id_col], None)
if cached_row is not None:
cached_row.data = data
return data
# New methods to fetch and create rows
def _make_rows(self, *data_rows):
"""
Create or retrieve Row objects for each provided data row.
If the rows already exist in cache, updates the cached row.
"""
if self.row_cache is not None:
rows = []
for data_row in data_rows:
rowid = data_row[self.id_col]
cached_row = self.row_cache.get(rowid, None)
if cached_row is not None:
cached_row.data = data_row
row = cached_row
else:
row = Row(self, data_row)
self.row_cache[rowid] = row
rows.append(row)
else:
rows = [Row(self, data_row) for data_row in data_rows]
return rows
def create_row(self, *args, **kwargs):
data = self.insert(*args, **kwargs)
return self._make_rows(data)[0]
def fetch_rows_where(self, *args, **kwargs):
# TODO: Handle list of rowids here?
data = self.select_where(*args, **kwargs)
return self._make_rows(*data)
def fetch(self, rowid):
"""
Fetch the row with the given id, retrieving from cache where possible.
"""
row = self.row_cache.get(rowid, None) if self.row_cache is not None else None
if row is None:
rows = self.fetch_rows_where(**{self.id_col: rowid})
row = rows[0] if rows else None
return row
def fetch_or_create(self, rowid=None, **kwargs):
"""
Helper method to fetch a row with the given id or fields, or create it if it doesn't exist.
"""
if rowid is not None:
row = self.fetch(rowid)
else:
data = self.select_where(**kwargs)
row = self._make_rows(data[0])[0] if data else None
if row is None:
creation_kwargs = kwargs
if rowid is not None:
creation_kwargs[self.id_col] = rowid
row = self.create_row(**creation_kwargs)
return row
# --------------- Query Builders ---------------
def select_where(table, select_columns=None, cursor=None, _extra='', **conditions):
"""
Select rows from the given table matching the conditions
"""
criteria, criteria_values = _format_conditions(conditions)
col_str = _format_selectkeys(select_columns)
if conditions:
where_str = "WHERE {}".format(criteria)
else:
where_str = ""
cursor = cursor or conn.cursor()
cursor.execute(
'SELECT {} FROM {} {} {}'.format(col_str, table, where_str, _extra),
criteria_values
)
return cursor.fetchall()
def update_where(table, valuedict, cursor=None, **conditions):
"""
Update rows in the given table matching the conditions
"""
key_str, key_values = _format_updatestr(valuedict)
criteria, criteria_values = _format_conditions(conditions)
if conditions:
where_str = "WHERE {}".format(criteria)
else:
where_str = ""
cursor = cursor or conn.cursor()
cursor.execute(
'UPDATE {} SET {} {} RETURNING *'.format(table, key_str, where_str),
tuple((*key_values, *criteria_values))
)
return cursor.fetchall()
def delete_where(table, cursor=None, **conditions):
"""
Delete rows in the given table matching the conditions
"""
criteria, criteria_values = _format_conditions(conditions)
cursor = cursor or conn.cursor()
cursor.execute(
'DELETE FROM {} WHERE {}'.format(table, criteria),
criteria_values
)
return cursor.fetchall()
def insert(table, cursor=None, allow_replace=False, **values):
"""
Insert the given values into the table
"""
keys, values = zip(*values.items())
key_str = _format_insertkeys(keys)
value_str, values = _format_insertvalues(values)
action = 'REPLACE' if allow_replace else 'INSERT'
cursor = cursor or conn.cursor()
cursor.execute(
'{} INTO {} {} VALUES {} RETURNING *'.format(action, table, key_str, value_str),
values
)
return cursor.fetchone()
def insert_many(table, *value_tuples, insert_keys=None, cursor=None):
"""
Insert all the given values into the table
"""
key_str = _format_insertkeys(insert_keys)
value_strs, value_tuples = zip(*(_format_insertvalues(value_tuple) for value_tuple in value_tuples))
value_str = ", ".join(value_strs)
values = tuple(chain(*value_tuples))
cursor = cursor or conn.cursor()
cursor.execute(
'INSERT INTO {} {} VALUES {} RETURNING *'.format(table, key_str, value_str),
values
)
return cursor.fetchall()
def upsert(table, constraint, cursor=None, **values):
"""
Insert or on conflict update.
"""
valuedict = values
keys, values = zip(*values.items())
key_str = _format_insertkeys(keys)
value_str, values = _format_insertvalues(values)
update_key_str, update_key_values = _format_updatestr(valuedict)
if not isinstance(constraint, str):
constraint = ", ".join(constraint)
cursor = cursor or conn.cursor()
cursor.execute(
'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format(
table, key_str, value_str, constraint, update_key_str
),
tuple((*values, *update_key_values))
)
return cursor.fetchone()
# --------------- Query Formatting Tools ---------------
# Replace char used by the connection for query formatting
_replace_char: str = '%s'
class fieldConstants(Enum):
"""
A collection of database field constants to use for selection conditions.
"""
NULL = "IS NULL"
NOTNULL = "IS NOT NULL"
def _format_conditions(conditions):
"""
Formats a dictionary of conditions into a string suitable for 'WHERE' clauses.
Supports `IN` type conditionals.
"""
if not conditions:
return ("", tuple())
values = []
conditional_strings = []
for key, item in conditions.items():
if isinstance(item, (list, tuple)):
conditional_strings.append("{} IN ({})".format(key, ", ".join([_replace_char] * len(item))))
values.extend(item)
elif isinstance(item, fieldConstants):
conditional_strings.append("{} {}".format(key, item.value))
else:
conditional_strings.append("{}={}".format(key, _replace_char))
values.append(item)
return (' AND '.join(conditional_strings), values)
def _format_selectkeys(keys):
"""
Formats a list of keys into a string suitable for `SELECT`.
"""
if not keys:
return "*"
else:
return ", ".join(keys)
def _format_insertkeys(keys):
"""
Formats a list of keys into a string suitable for `INSERT`
"""
if not keys:
return ""
else:
return "({})".format(", ".join(keys))
def _format_insertvalues(values):
"""
Formats a list of values into a string suitable for `INSERT`
"""
value_str = "({})".format(", ".join(_replace_char for value in values))
return (value_str, values)
def _format_updatestr(valuedict):
"""
Formats a dictionary of keys and values into a string suitable for 'SET' clauses.
"""
if not valuedict:
return ("", tuple())
keys, values = zip(*valuedict.items())
set_str = ", ".join("{} = {}".format(key, _replace_char) for key in keys)
return (set_str, values)

8
bot/data/tables.py Normal file
View File

@@ -0,0 +1,8 @@
from .data import RowTable, Table
raw_users = Table('Users')
users = RowTable(
'users',
('userid', 'tracked_time', 'coins'),
'userid',
)

View File

7
bot/dev_main.py Normal file
View File

@@ -0,0 +1,7 @@
import logging
import meta
meta.logger.setLevel(logging.DEBUG)
logging.getLogger("discord").setLevel(logging.INFO)
import main # noqa

View File

@@ -0,0 +1,12 @@
from meta import client, conf, log
import data # noqa
import modules # noqa
# Initialise all modules
client.initialise_modules()
# Log readyness and execute
log("Initial setup complete, logging in", context='SETUP')
client.run(conf.bot['TOKEN'])

3
bot/meta/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .client import client
from .config import conf
from .logger import log, logger

13
bot/meta/client.py Normal file
View File

@@ -0,0 +1,13 @@
from cmdClient.cmdClient import cmdClient
from .config import Conf
from constants import CONFIG_FILE
# Initialise config
conf = Conf(CONFIG_FILE)
# Initialise client
owners = [int(owner) for owner in conf.bot.getlist('owners')]
client = cmdClient(prefix=conf.bot['prefix'], owners=owners)
client.conf = conf

59
bot/meta/config.py Normal file
View File

@@ -0,0 +1,59 @@
import configparser as cfgp
conf = None # type: Conf
CONF_FILE = "bot/bot.conf"
class Conf:
def __init__(self, configfile, section_name="DEFAULT"):
self.configfile = configfile
self.config = cfgp.ConfigParser(
converters={
"intlist": self._getintlist,
"list": self._getlist,
}
)
self.config.read(configfile)
self.section_name = section_name if section_name in self.config else 'DEFAULT'
self.default = self.config["DEFAULT"]
self.section = self.config[self.section_name]
self.bot = self.section
# Config file recursion, read in configuration files specified in every "ALSO_READ" key.
more_to_read = self.section.getlist("ALSO_READ", [])
read = set()
while more_to_read:
to_read = more_to_read.pop(0)
read.add(to_read)
self.config.read(to_read)
new_paths = [path for path in self.section.getlist("ALSO_READ", [])
if path not in read and path not in more_to_read]
more_to_read.extend(new_paths)
global conf
conf = self
def __getitem__(self, key):
return self.section[key].strip()
def __getattr__(self, section):
return self.config[section]
def get(self, name, fallback=None):
result = self.section.get(name, fallback)
return result.strip() if result else result
def _getintlist(self, value):
return [int(item.strip()) for item in value.split(',')]
def _getlist(self, value):
return [item.strip() for item in value.split(',')]
def write(self):
with open(self.configfile, 'w') as conffile:
self.config.write(conffile)

72
bot/meta/logger.py Normal file
View File

@@ -0,0 +1,72 @@
import sys
import logging
import asyncio
from discord import AllowedMentions
from cmdClient.logger import cmd_log_handler
from utils.lib import mail, split_text
from .client import client
from .config import conf
# Setup the logger
logger = logging.getLogger()
log_fmt = logging.Formatter(fmt='[{asctime}][{levelname:^8}] {message}', datefmt='%d/%m | %H:%M:%S', style='{')
term_handler = logging.StreamHandler(sys.stdout)
term_handler.setFormatter(log_fmt)
logger.addHandler(term_handler)
logger.setLevel(logging.INFO)
# Define the context log format and attach it to the command logger as well
@cmd_log_handler
def log(message, context="GLOBAL", level=logging.INFO, post=True):
# Add prefixes to lines for better parsing capability
lines = message.splitlines()
if len(lines) > 1:
lines = [
'' * (i == 0) + '' * (0 < i < len(lines) - 1) + '' * (i == len(lines) - 1) + line
for i, line in enumerate(lines)
]
else:
lines = ['' + message]
for line in lines:
logger.log(level, '\b[{}] {}'.format(
str(context).center(22, '='),
line
))
# Fire and forget to the channel logger, if it is set up
if post and client.is_ready():
asyncio.ensure_future(live_log(message, context, level))
# Live logger that posts to the logging channels
async def live_log(message, context, level):
if level >= logging.INFO:
log_chid = conf.bot.getint('log_channel')
# Generate the log messages
header = "[{}][{}]".format(logging.getLevelName(level), str(context))
if len(message) > 1900:
blocks = split_text(message, blocksize=1900, code=False)
else:
blocks = [message]
if len(blocks) > 1:
blocks = [
"```md\n{}[{}/{}]\n{}\n```".format(header, i+1, len(blocks), block) for i, block in enumerate(blocks)
]
else:
blocks = ["```md\n{}\n{}\n```".format(header, blocks[0])]
# Post the log messages
if log_chid:
[await mail(client, log_chid, content=block, allowed_mentions=AllowedMentions.none()) for block in blocks]
# Attach logger to client, for convenience
client.log = log

View File

@@ -0,0 +1 @@
from .sysadmin import *

View File

@@ -0,0 +1 @@
from .exec_cmds import *

View File

@@ -0,0 +1,123 @@
import sys
from io import StringIO
import traceback
import asyncio
from cmdClient import cmd, checks
"""
Exec level commands to manage the bot.
Commands provided:
async:
Executes provided code in an async executor
exec:
Executes code using standard python exec
eval:
Executes code and awaits it if required
"""
@cmd("reboot")
@checks.is_owner()
async def cmd_reboot(ctx):
"""
Usage``:
reboot
Description:
Update the timer status save file and reboot the client.
"""
ctx.client.interface.update_save("reboot")
ctx.client.interface.shutdown()
await ctx.reply("Saved state. Rebooting now!")
await ctx.client.close()
@cmd("async")
@checks.is_owner()
async def cmd_async(ctx):
"""
Usage:
{prefix}async <code>
Description:
Runs <code> as an asynchronous coroutine and prints the output or error.
"""
if ctx.arg_str == "":
await ctx.error_reply("You must give me something to run!")
return
output, error = await _async(ctx)
await ctx.reply(
"**Async input:**\
\n```py\n{}\n```\
\n**Output {}:** \
\n```py\n{}\n```".format(ctx.arg_str,
"error" if error else "",
output))
@cmd("eval")
@checks.is_owner()
async def cmd_eval(ctx):
"""
Usage:
{prefix}eval <code>
Description:
Runs <code> in current environment using eval() and prints the output or error.
"""
if ctx.arg_str == "":
await ctx.error_reply("You must give me something to run!")
return
output, error = await _eval(ctx)
await ctx.reply(
"**Eval input:**\
\n```py\n{}\n```\
\n**Output {}:** \
\n```py\n{}\n```".format(ctx.arg_str,
"error" if error else "",
output)
)
async def _eval(ctx):
output = None
try:
output = eval(ctx.arg_str)
except Exception:
return (str(traceback.format_exc()), 1)
if asyncio.iscoroutine(output):
output = await output
return (output, 0)
async def _async(ctx):
env = {
'ctx': ctx,
'client': ctx.client,
'message': ctx.msg,
'arg_str': ctx.arg_str
}
env.update(globals())
old_stdout = sys.stdout
redirected_output = sys.stdout = StringIO()
result = None
exec_string = "async def _temp_exec():\n"
exec_string += '\n'.join(' ' * 4 + line for line in ctx.arg_str.split('\n'))
try:
exec(exec_string, env)
result = (redirected_output.getvalue(), 0)
except Exception:
result = (str(traceback.format_exc()), 1)
return result
_temp_exec = env['_temp_exec']
try:
returnval = await _temp_exec()
value = redirected_output.getvalue()
if returnval is None:
result = (value, 0)
else:
result = (value + '\n' + str(returnval), 0)
except Exception:
result = (str(traceback.format_exc()), 1)
finally:
sys.stdout = old_stdout
return result

443
bot/utils/lib.py Normal file
View File

@@ -0,0 +1,443 @@
import datetime
import iso8601
import re
import discord
def prop_tabulate(prop_list, value_list, indent=True):
"""
Turns a list of properties and corresponding list of values into
a pretty string with one `prop: value` pair each line,
padded so that the colons in each line are lined up.
Handles empty props by using an extra couple of spaces instead of a `:`.
Parameters
----------
prop_list: List[str]
List of short names to put on the right side of the list.
Empty props are considered to be "newlines" for the corresponding value.
value_list: List[str]
List of values corresponding to the properties above.
indent: bool
Whether to add padding so the properties are right-adjusted.
Returns: str
"""
max_len = max(len(prop) for prop in prop_list)
return "".join(["`{}{}{}`\t{}{}".format(" " * (max_len - len(prop)) if indent else "",
prop,
":" if len(prop) else " " * 2,
value_list[i],
'' if str(value_list[i]).endswith("```") else '\n')
for i, prop in enumerate(prop_list)])
def paginate_list(item_list, block_length=20, style="markdown", title=None):
"""
Create pretty codeblock pages from a list of strings.
Parameters
----------
item_list: List[str]
List of strings to paginate.
block_length: int
Maximum number of strings per page.
style: str
Codeblock style to use.
Title formatting assumes the `markdown` style, and numbered lists work well with this.
However, `markdown` sometimes messes up formatting in the list.
title: str
Optional title to add to the top of each page.
Returns: List[str]
List of pages, each formatted into a codeblock,
and containing at most `block_length` of the provided strings.
"""
lines = ["{0:<5}{1:<5}".format("{}.".format(i + 1), str(line)) for i, line in enumerate(item_list)]
page_blocks = [lines[i:i + block_length] for i in range(0, len(lines), block_length)]
pages = []
for i, block in enumerate(page_blocks):
pagenum = "Page {}/{}".format(i + 1, len(page_blocks))
if title:
header = "{} ({})".format(title, pagenum) if len(page_blocks) > 1 else title
else:
header = pagenum
header_line = "=" * len(header)
full_header = "{}\n{}\n".format(header, header_line) if len(page_blocks) > 1 or title else ""
pages.append("```{}\n{}{}```".format(style, full_header, "\n".join(block)))
return pages
def timestamp_utcnow():
"""
Return the current integer UTC timestamp.
"""
return int(datetime.datetime.timestamp(datetime.datetime.utcnow()))
def split_text(text, blocksize=2000, code=True, syntax="", maxheight=50):
"""
Break the text into blocks of maximum length blocksize
If possible, break across nearby newlines. Otherwise just break at blocksize chars
Parameters
----------
text: str
Text to break into blocks.
blocksize: int
Maximum character length for each block.
code: bool
Whether to wrap each block in codeblocks (these are counted in the blocksize).
syntax: str
The markdown formatting language to use for the codeblocks, if applicable.
maxheight: int
The maximum number of lines in each block
Returns: List[str]
List of blocks,
each containing at most `block_size` characters,
of height at most `maxheight`.
"""
# Adjust blocksize to account for the codeblocks if required
blocksize = blocksize - 8 - len(syntax) if code else blocksize
# Build the blocks
blocks = []
while True:
# If the remaining text is already small enough, append it
if len(text) <= blocksize:
blocks.append(text)
break
text = text.strip('\n')
# Find the last newline in the prototype block
split_on = text[0:blocksize].rfind('\n')
split_on = blocksize if split_on < blocksize // 5 else split_on
# Add the block and truncate the text
blocks.append(text[0:split_on])
text = text[split_on:]
# Add the codeblock ticks and the code syntax header, if required
if code:
blocks = ["```{}\n{}\n```".format(syntax, block) for block in blocks]
return blocks
def strfdelta(delta, sec=False, minutes=True, short=False):
"""
Convert a datetime.timedelta object into an easily readable duration string.
Parameters
----------
delta: datetime.timedelta
The timedelta object to convert into a readable string.
sec: bool
Whether to include the seconds from the timedelta object in the string.
minutes: bool
Whether to include the minutes from the timedelta object in the string.
short: bool
Whether to abbreviate the units of time ("hour" to "h", "minute" to "m", "second" to "s").
Returns: str
A string containing a time from the datetime.timedelta object, in a readable format.
Time units will be abbreviated if short was set to True.
"""
output = [[delta.days, 'd' if short else ' day'],
[delta.seconds // 3600, 'h' if short else ' hour']]
if minutes:
output.append([delta.seconds // 60 % 60, 'm' if short else ' minute'])
if sec:
output.append([delta.seconds % 60, 's' if short else ' second'])
for i in range(len(output)):
if output[i][0] != 1 and not short:
output[i][1] += 's'
reply_msg = []
if output[0][0] != 0:
reply_msg.append("{}{} ".format(output[0][0], output[0][1]))
if output[0][0] != 0 or output[1][0] != 0 or len(output) == 2:
reply_msg.append("{}{} ".format(output[1][0], output[1][1]))
for i in range(2, len(output) - 1):
reply_msg.append("{}{} ".format(output[i][0], output[i][1]))
if not short and reply_msg:
reply_msg.append("and ")
reply_msg.append("{}{}".format(output[-1][0], output[-1][1]))
return "".join(reply_msg)
def parse_dur(time_str):
"""
Parses a user provided time duration string into a timedelta object.
Parameters
----------
time_str: str
The time string to parse. String can include days, hours, minutes, and seconds.
Returns: int
The number of seconds the duration represents.
"""
funcs = {'d': lambda x: x * 24 * 60 * 60,
'h': lambda x: x * 60 * 60,
'm': lambda x: x * 60,
's': lambda x: x}
time_str = time_str.strip(" ,")
found = re.findall(r'(\d+)\s?(\w+?)', time_str)
seconds = 0
for bit in found:
if bit[1] in funcs:
seconds += funcs[bit[1]](int(bit[0]))
return seconds
def substitute_ranges(ranges_str, max_match=20, max_range=1000, separator=','):
"""
Substitutes a user provided list of numbers and ranges,
and replaces the ranges by the corresponding list of numbers.
Parameters
----------
ranges_str: str
The string to ranges in.
max_match: int
The maximum number of ranges to replace.
Any ranges exceeding this will be ignored.
max_range: int
The maximum length of range to replace.
Attempting to replace a range longer than this will raise a `ValueError`.
"""
def _repl(match):
n1 = int(match.group(1))
n2 = int(match.group(2))
if n2 - n1 > max_range:
raise ValueError("Provided range exceeds the allowed maximum.")
return separator.join(str(i) for i in range(n1, n2 + 1))
return re.sub(r'(\d+)\s*-\s*(\d+)', _repl, ranges_str, max_match)
def msg_string(msg, mask_link=False, line_break=False, tz=None, clean=True):
"""
Format a message into a string with various information, such as:
the timestamp of the message, author, message content, and attachments.
Parameters
----------
msg: Message
The message to format.
mask_link: bool
Whether to mask the URLs of any attachments.
line_break: bool
Whether a line break should be used in the string.
tz: Timezone
The timezone to use in the formatted message.
clean: bool
Whether to use the clean content of the original message.
Returns: str
A formatted string containing various information:
User timezone, message author, message content, attachments
"""
timestr = "%I:%M %p, %d/%m/%Y"
if tz:
time = iso8601.parse_date(msg.timestamp.isoformat()).astimezone(tz).strftime(timestr)
else:
time = msg.timestamp.strftime(timestr)
user = str(msg.author)
attach_list = [attach["url"] for attach in msg.attachments if "url" in attach]
if mask_link:
attach_list = ["[Link]({})".format(url) for url in attach_list]
attachments = "\nAttachments: {}".format(", ".join(attach_list)) if attach_list else ""
return "`[{time}]` **{user}:** {line_break}{message} {attachments}".format(
time=time,
user=user,
line_break="\n" if line_break else "",
message=msg.clean_content if clean else msg.content,
attachments=attachments
)
def convdatestring(datestring):
"""
Convert a date string into a datetime.timedelta object.
Parameters
----------
datestring: str
The string to convert to a datetime.timedelta object.
Returns: datetime.timedelta
A datetime.timedelta object formed from the string provided.
"""
datestring = datestring.strip(' ,')
datearray = []
funcs = {'d': lambda x: x * 24 * 60 * 60,
'h': lambda x: x * 60 * 60,
'm': lambda x: x * 60,
's': lambda x: x}
currentnumber = ''
for char in datestring:
if char.isdigit():
currentnumber += char
else:
if currentnumber == '':
continue
datearray.append((int(currentnumber), char))
currentnumber = ''
seconds = 0
if currentnumber:
seconds += int(currentnumber)
for i in datearray:
if i[1] in funcs:
seconds += funcs[i[1]](i[0])
return datetime.timedelta(seconds=seconds)
class _rawChannel(discord.abc.Messageable):
"""
Raw messageable class representing an arbitrary channel,
not necessarially seen by the gateway.
"""
def __init__(self, state, id):
self._state = state
self.id = id
async def _get_channel(self):
return discord.Object(self.id)
async def mail(client: discord.Client, channelid: int, **msg_args):
"""
Mails a message to a channelid which may be invisible to the gateway.
Parameters:
client: discord.Client
The client to use for mailing.
Must at least have static authentication and have a valid `_connection`.
channelid: int
The channel id to mail to.
msg_args: Any
Message keyword arguments which are passed transparently to `_rawChannel.send(...)`.
"""
# Create the raw channel
channel = _rawChannel(client._connection, channelid)
return await channel.send(**msg_args)
def emb_add_fields(embed, emb_fields):
"""
Append embed fields to an embed.
Parameters
----------
embed: discord.Embed
The embed to add the field to.
emb_fields: tuple
The values to add to a field.
name: str
The name of the field.
value: str
The value of the field.
inline: bool
Whether the embed field should be inline or not.
"""
for field in emb_fields:
embed.add_field(name=str(field[0]), value=str(field[1]), inline=bool(field[2]))
def join_list(string, nfs=False):
"""
Join a list together, separated with commas, plus add "and" to the beginning of the last value.
Parameters
----------
string: list
The list to join together.
nfs: bool
(no fullstops)
Whether to exclude fullstops/periods from the output messages.
If not provided, fullstops will be appended to the output.
"""
if len(string) > 1:
return "{}{} and {}{}".format((", ").join(string[:-1]),
"," if len(string) > 2 else "", string[-1], "" if nfs else ".")
else:
return "{}{}".format("".join(string), "" if nfs else ".")
def format_activity(user):
"""
Format a user's activity string, depending on the type of activity.
Currently supported types are:
- Nothing
- Custom status
- Playing (with rich presence support)
- Streaming
- Listening (with rich presence support)
- Watching
- Unknown
Parameters
----------
user: discord.Member
The user to format the status of.
If the user has no activity, "Nothing" will be returned.
Returns: str
A formatted string with various information about the user's current activity like the name,
and any extra information about the activity (such as current song artists for Spotify)
"""
if not user.activity:
return "Nothing"
AT = user.activity.type
a = user.activity
if str(AT) == "ActivityType.custom":
return "Status: {}".format(a)
if str(AT) == "ActivityType.playing":
string = "Playing {}".format(a.name)
try:
string += " ({})".format(a.details)
except Exception:
pass
return string
if str(AT) == "ActivityType.streaming":
return "Streaming {}".format(a.name)
if str(AT) == "ActivityType.listening":
try:
string = "Listening to `{}`".format(a.title)
if len(a.artists) > 1:
string += " by {}".format(join_list(string=a.artists))
else:
string += " by **{}**".format(a.artist)
except Exception:
string = "Listening to `{}`".format(a.name)
return string
if str(AT) == "ActivityType.watching":
return "Watching `{}`".format(a.name)
if str(AT) == "ActivityType.unknown":
return "Unknown"
def shard_of(shard_count: int, guildid: int):
"""
Calculate the shard number of a given guild.
"""
return (guildid >> 22) % shard_count if shard_count and shard_count > 0 else 0
def jumpto(guildid: int, channeldid: int, messageid: int):
"""
Build a jump link for a message given its location.
"""
return 'https://discord.com/channels/{}/{}/{}'.format(
guildid,
channeldid,
messageid
)

6
run.py
View File

@@ -0,0 +1,6 @@
import sys
import os
sys.path.insert(0, os.path.join(os.getcwd(), "bot"))
from bot import dev_main