diff --git a/bot/constants.py b/bot/constants.py new file mode 100644 index 00000000..f50111af --- /dev/null +++ b/bot/constants.py @@ -0,0 +1,2 @@ +CONFIG_FILE = "config/bot.conf" +DATA_VERSION = 0 diff --git a/bot/data/__init__.py b/bot/data/__init__.py new file mode 100644 index 00000000..b2670d99 --- /dev/null +++ b/bot/data/__init__.py @@ -0,0 +1,3 @@ +from .data import * +from . import tables +# from . import queries diff --git a/bot/data/custom_cursor.py b/bot/data/custom_cursor.py new file mode 100644 index 00000000..a63f08c6 --- /dev/null +++ b/bot/data/custom_cursor.py @@ -0,0 +1,30 @@ +import logging +from psycopg2.extras import DictCursor, _ext + +from meta import log + + +class DictLoggingCursor(DictCursor): + def log(self): + msg = self.query + if isinstance(msg, bytes): + msg = msg.decode(_ext.encodings[self.connection.encoding], 'replace') + + log( + msg, + context="DATABASE_QUERY", + level=logging.DEBUG, + post=False + ) + + def execute(self, query, vars=None): + try: + return super().execute(query, vars) + finally: + self.log() + + def callproc(self, procname, vars=None): + try: + return super().callproc(procname, vars) + finally: + self.log() diff --git a/bot/data/data.py b/bot/data/data.py new file mode 100644 index 00000000..74f85e5c --- /dev/null +++ b/bot/data/data.py @@ -0,0 +1,450 @@ +import logging +import contextlib +from itertools import chain +from enum import Enum + +import psycopg2 as psy +from cachetools import LRUCache + +from meta import log, conf +from constants import DATA_VERSION +from .custom_cursor import DictLoggingCursor + + +# Set up database connection +log("Establishing connection.", "DB_INIT", level=logging.DEBUG) +conn = psy.connect(conf.bot['database'], cursor_factory=DictLoggingCursor) + +# conn.set_trace_callback(lambda message: log(message, context="DB_CONNECTOR", level=logging.DEBUG)) +# sq.register_adapter(datetime, lambda dt: dt.timestamp()) + + +# Check the version matches the required version +with conn: + log("Checking db version.", "DB_INIT") + cursor = conn.cursor() + + # Get last entry in version table, compare against desired version + cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1") + current_version, _, _ = cursor.fetchone() + + if current_version != DATA_VERSION: + # Complain + raise Exception( + ("Database version is {}, required version is {}. " + "Please migrate database.").format(current_version, DATA_VERSION) + ) + + cursor.close() + + +log("Established connection.", "DB_INIT") + + +# --------------- Data Interface Classes --------------- +class Table: + """ + Transparent interface to a single table structure in the database. + Contains standard methods to access the table. + Intended to be subclassed to provide more derivative access for specific tables. + """ + conn = conn + + def __init__(self, name): + self.name = name + + def select_where(self, *args, **kwargs): + with self.conn: + return select_where(self.name, *args, **kwargs) + + def select_one_where(self, *args, **kwargs): + with self.conn: + rows = self.select_where(*args, **kwargs) + return rows[0] if rows else None + + def update_where(self, *args, **kwargs): + with self.conn: + return update_where(self.name, *args, **kwargs) + + def delete_where(self, *args, **kwargs): + with self.conn: + return delete_where(self.name, *args, **kwargs) + + def insert(self, *args, **kwargs): + with self.conn: + return insert(self.name, *args, **kwargs) + + def insert_many(self, *args, **kwargs): + with self.conn: + return insert_many(self.name, *args, **kwargs) + + def upsert(self, *args, **kwargs): + with self.conn: + return upsert(self.name, *args, **kwargs) + + +class Row: + __slots__ = ('table', 'data', '_pending') + + conn = conn + + def __init__(self, table, data, *args, **kwargs): + super().__setattr__('table', table) + self.data = data + self._pending = None + + @property + def rowid(self): + return self.data[self.table.id_col] + + def __repr__(self): + return "Row[{}]({})".format( + self.table.name, + ', '.join("{}={!r}".format(field, getattr(self, field)) for field in self.table.columns) + ) + + def __getattr__(self, key): + if key in self.table.columns: + if self._pending and key in self._pending: + return self._pending[key] + else: + return self.data[key] + else: + raise AttributeError(key) + + def __setattr__(self, key, value): + if key in self.table.columns: + if self._pending is None: + self.update(**{key: value}) + else: + self._pending[key] = value + else: + super().__setattr__(key, value) + + @contextlib.contextmanager + def batch_update(self): + if self._pending: + raise ValueError("Nested batch updates for {}!".format(self.__class__.__name__)) + + self._pending = {} + try: + yield self._pending + finally: + self.update(**self._pending) + self._pending = None + + def _refresh(self): + row = self.table.select_one_where(**{self.table.id_col: self.rowid}) + if not row: + raise ValueError("Refreshing a {} which no longer exists!".format(type(self).__name__)) + self.data = row + + def update(self, **values): + rows = self.table.update_where(values, **{self.table.id_col: self.rowid}) + self.data = rows[0] + + @classmethod + def _select_where(cls, _extra=None, **conditions): + return select_where(cls._table, **conditions) + + @classmethod + def _insert(cls, **values): + return insert(cls._table, **values) + + @classmethod + def _update_where(cls, values, **conditions): + return update_where(cls._table, values, **conditions) + + +class RowTable(Table): + __slots__ = ( + 'name', + 'columns', + 'id_col', + 'row_cache' + ) + + conn = conn + + def __init__(self, name, columns, id_col, use_cache=True, cache=None, cache_size=1000): + self.name = name + self.columns = columns + self.id_col = id_col + self.row_cache = (cache or LRUCache(cache_size)) if use_cache else None + + # Extend original Table update methods to modify the cached rows + def update_where(self, *args, **kwargs): + data = super().update_where(*args, **kwargs) + if self.row_cache is not None: + for data_row in data: + cached_row = self.row_cache.get(data_row[self.id_col], None) + if cached_row is not None: + cached_row.data = data_row + return data + + def delete_where(self, *args, **kwargs): + data = super().delete_where(*args, **kwargs) + if self.row_cache is not None: + for data_row in data: + self.row_cache.pop(data_row[self.id_col], None) + return data + + def upsert(self, *args, **kwargs): + data = super().upsert(*args, **kwargs) + if self.row_cache is not None: + cached_row = self.row_cache.get(data[self.id_col], None) + if cached_row is not None: + cached_row.data = data + return data + + # New methods to fetch and create rows + def _make_rows(self, *data_rows): + """ + Create or retrieve Row objects for each provided data row. + If the rows already exist in cache, updates the cached row. + """ + if self.row_cache is not None: + rows = [] + for data_row in data_rows: + rowid = data_row[self.id_col] + + cached_row = self.row_cache.get(rowid, None) + if cached_row is not None: + cached_row.data = data_row + row = cached_row + else: + row = Row(self, data_row) + self.row_cache[rowid] = row + rows.append(row) + else: + rows = [Row(self, data_row) for data_row in data_rows] + return rows + + def create_row(self, *args, **kwargs): + data = self.insert(*args, **kwargs) + return self._make_rows(data)[0] + + def fetch_rows_where(self, *args, **kwargs): + # TODO: Handle list of rowids here? + data = self.select_where(*args, **kwargs) + return self._make_rows(*data) + + def fetch(self, rowid): + """ + Fetch the row with the given id, retrieving from cache where possible. + """ + row = self.row_cache.get(rowid, None) if self.row_cache is not None else None + if row is None: + rows = self.fetch_rows_where(**{self.id_col: rowid}) + row = rows[0] if rows else None + return row + + def fetch_or_create(self, rowid=None, **kwargs): + """ + Helper method to fetch a row with the given id or fields, or create it if it doesn't exist. + """ + if rowid is not None: + row = self.fetch(rowid) + else: + data = self.select_where(**kwargs) + row = self._make_rows(data[0])[0] if data else None + + if row is None: + creation_kwargs = kwargs + if rowid is not None: + creation_kwargs[self.id_col] = rowid + row = self.create_row(**creation_kwargs) + return row + + +# --------------- Query Builders --------------- +def select_where(table, select_columns=None, cursor=None, _extra='', **conditions): + """ + Select rows from the given table matching the conditions + """ + criteria, criteria_values = _format_conditions(conditions) + col_str = _format_selectkeys(select_columns) + + if conditions: + where_str = "WHERE {}".format(criteria) + else: + where_str = "" + + cursor = cursor or conn.cursor() + cursor.execute( + 'SELECT {} FROM {} {} {}'.format(col_str, table, where_str, _extra), + criteria_values + ) + return cursor.fetchall() + + +def update_where(table, valuedict, cursor=None, **conditions): + """ + Update rows in the given table matching the conditions + """ + key_str, key_values = _format_updatestr(valuedict) + criteria, criteria_values = _format_conditions(conditions) + + if conditions: + where_str = "WHERE {}".format(criteria) + else: + where_str = "" + + cursor = cursor or conn.cursor() + cursor.execute( + 'UPDATE {} SET {} {} RETURNING *'.format(table, key_str, where_str), + tuple((*key_values, *criteria_values)) + ) + return cursor.fetchall() + + +def delete_where(table, cursor=None, **conditions): + """ + Delete rows in the given table matching the conditions + """ + criteria, criteria_values = _format_conditions(conditions) + + cursor = cursor or conn.cursor() + cursor.execute( + 'DELETE FROM {} WHERE {}'.format(table, criteria), + criteria_values + ) + return cursor.fetchall() + + +def insert(table, cursor=None, allow_replace=False, **values): + """ + Insert the given values into the table + """ + keys, values = zip(*values.items()) + + key_str = _format_insertkeys(keys) + value_str, values = _format_insertvalues(values) + + action = 'REPLACE' if allow_replace else 'INSERT' + + cursor = cursor or conn.cursor() + cursor.execute( + '{} INTO {} {} VALUES {} RETURNING *'.format(action, table, key_str, value_str), + values + ) + return cursor.fetchone() + + +def insert_many(table, *value_tuples, insert_keys=None, cursor=None): + """ + Insert all the given values into the table + """ + key_str = _format_insertkeys(insert_keys) + value_strs, value_tuples = zip(*(_format_insertvalues(value_tuple) for value_tuple in value_tuples)) + + value_str = ", ".join(value_strs) + values = tuple(chain(*value_tuples)) + + cursor = cursor or conn.cursor() + cursor.execute( + 'INSERT INTO {} {} VALUES {} RETURNING *'.format(table, key_str, value_str), + values + ) + return cursor.fetchall() + + +def upsert(table, constraint, cursor=None, **values): + """ + Insert or on conflict update. + """ + valuedict = values + keys, values = zip(*values.items()) + + key_str = _format_insertkeys(keys) + value_str, values = _format_insertvalues(values) + update_key_str, update_key_values = _format_updatestr(valuedict) + + if not isinstance(constraint, str): + constraint = ", ".join(constraint) + + cursor = cursor or conn.cursor() + cursor.execute( + 'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format( + table, key_str, value_str, constraint, update_key_str + ), + tuple((*values, *update_key_values)) + ) + return cursor.fetchone() + + +# --------------- Query Formatting Tools --------------- +# Replace char used by the connection for query formatting +_replace_char: str = '%s' + + +class fieldConstants(Enum): + """ + A collection of database field constants to use for selection conditions. + """ + NULL = "IS NULL" + NOTNULL = "IS NOT NULL" + + +def _format_conditions(conditions): + """ + Formats a dictionary of conditions into a string suitable for 'WHERE' clauses. + Supports `IN` type conditionals. + """ + if not conditions: + return ("", tuple()) + + values = [] + conditional_strings = [] + for key, item in conditions.items(): + if isinstance(item, (list, tuple)): + conditional_strings.append("{} IN ({})".format(key, ", ".join([_replace_char] * len(item)))) + values.extend(item) + elif isinstance(item, fieldConstants): + conditional_strings.append("{} {}".format(key, item.value)) + else: + conditional_strings.append("{}={}".format(key, _replace_char)) + values.append(item) + + return (' AND '.join(conditional_strings), values) + + +def _format_selectkeys(keys): + """ + Formats a list of keys into a string suitable for `SELECT`. + """ + if not keys: + return "*" + else: + return ", ".join(keys) + + +def _format_insertkeys(keys): + """ + Formats a list of keys into a string suitable for `INSERT` + """ + if not keys: + return "" + else: + return "({})".format(", ".join(keys)) + + +def _format_insertvalues(values): + """ + Formats a list of values into a string suitable for `INSERT` + """ + value_str = "({})".format(", ".join(_replace_char for value in values)) + return (value_str, values) + + +def _format_updatestr(valuedict): + """ + Formats a dictionary of keys and values into a string suitable for 'SET' clauses. + """ + if not valuedict: + return ("", tuple()) + keys, values = zip(*valuedict.items()) + + set_str = ", ".join("{} = {}".format(key, _replace_char) for key in keys) + + return (set_str, values) diff --git a/bot/data/tables.py b/bot/data/tables.py new file mode 100644 index 00000000..8fd7786f --- /dev/null +++ b/bot/data/tables.py @@ -0,0 +1,8 @@ +from .data import RowTable, Table + +raw_users = Table('Users') +users = RowTable( + 'users', + ('userid', 'tracked_time', 'coins'), + 'userid', +) diff --git a/bot/dev-main.py b/bot/dev-main.py deleted file mode 100644 index e69de29b..00000000 diff --git a/bot/dev_main.py b/bot/dev_main.py new file mode 100644 index 00000000..92a215af --- /dev/null +++ b/bot/dev_main.py @@ -0,0 +1,7 @@ +import logging +import meta + +meta.logger.setLevel(logging.DEBUG) +logging.getLogger("discord").setLevel(logging.INFO) + +import main # noqa diff --git a/bot/main.py b/bot/main.py index e69de29b..c601218f 100644 --- a/bot/main.py +++ b/bot/main.py @@ -0,0 +1,12 @@ +from meta import client, conf, log + +import data # noqa + +import modules # noqa + +# Initialise all modules +client.initialise_modules() + +# Log readyness and execute +log("Initial setup complete, logging in", context='SETUP') +client.run(conf.bot['TOKEN']) diff --git a/bot/meta/__init__.py b/bot/meta/__init__.py new file mode 100644 index 00000000..dd852d4f --- /dev/null +++ b/bot/meta/__init__.py @@ -0,0 +1,3 @@ +from .client import client +from .config import conf +from .logger import log, logger diff --git a/bot/meta/client.py b/bot/meta/client.py new file mode 100644 index 00000000..683ef7f7 --- /dev/null +++ b/bot/meta/client.py @@ -0,0 +1,13 @@ +from cmdClient.cmdClient import cmdClient + +from .config import Conf + +from constants import CONFIG_FILE + +# Initialise config +conf = Conf(CONFIG_FILE) + +# Initialise client +owners = [int(owner) for owner in conf.bot.getlist('owners')] +client = cmdClient(prefix=conf.bot['prefix'], owners=owners) +client.conf = conf diff --git a/bot/meta/config.py b/bot/meta/config.py new file mode 100644 index 00000000..a94d2b1a --- /dev/null +++ b/bot/meta/config.py @@ -0,0 +1,59 @@ +import configparser as cfgp + + +conf = None # type: Conf + +CONF_FILE = "bot/bot.conf" + + +class Conf: + def __init__(self, configfile, section_name="DEFAULT"): + self.configfile = configfile + + self.config = cfgp.ConfigParser( + converters={ + "intlist": self._getintlist, + "list": self._getlist, + } + ) + self.config.read(configfile) + + self.section_name = section_name if section_name in self.config else 'DEFAULT' + + self.default = self.config["DEFAULT"] + self.section = self.config[self.section_name] + self.bot = self.section + + # Config file recursion, read in configuration files specified in every "ALSO_READ" key. + more_to_read = self.section.getlist("ALSO_READ", []) + read = set() + while more_to_read: + to_read = more_to_read.pop(0) + read.add(to_read) + self.config.read(to_read) + new_paths = [path for path in self.section.getlist("ALSO_READ", []) + if path not in read and path not in more_to_read] + more_to_read.extend(new_paths) + + global conf + conf = self + + def __getitem__(self, key): + return self.section[key].strip() + + def __getattr__(self, section): + return self.config[section] + + def get(self, name, fallback=None): + result = self.section.get(name, fallback) + return result.strip() if result else result + + def _getintlist(self, value): + return [int(item.strip()) for item in value.split(',')] + + def _getlist(self, value): + return [item.strip() for item in value.split(',')] + + def write(self): + with open(self.configfile, 'w') as conffile: + self.config.write(conffile) diff --git a/bot/meta/logger.py b/bot/meta/logger.py new file mode 100644 index 00000000..dfb618ad --- /dev/null +++ b/bot/meta/logger.py @@ -0,0 +1,72 @@ +import sys +import logging +import asyncio +from discord import AllowedMentions + +from cmdClient.logger import cmd_log_handler + +from utils.lib import mail, split_text + +from .client import client +from .config import conf + + +# Setup the logger +logger = logging.getLogger() +log_fmt = logging.Formatter(fmt='[{asctime}][{levelname:^8}] {message}', datefmt='%d/%m | %H:%M:%S', style='{') +term_handler = logging.StreamHandler(sys.stdout) +term_handler.setFormatter(log_fmt) +logger.addHandler(term_handler) +logger.setLevel(logging.INFO) + + +# Define the context log format and attach it to the command logger as well +@cmd_log_handler +def log(message, context="GLOBAL", level=logging.INFO, post=True): + # Add prefixes to lines for better parsing capability + lines = message.splitlines() + if len(lines) > 1: + lines = [ + '┌ ' * (i == 0) + '│ ' * (0 < i < len(lines) - 1) + '└ ' * (i == len(lines) - 1) + line + for i, line in enumerate(lines) + ] + else: + lines = ['─ ' + message] + + for line in lines: + logger.log(level, '\b[{}] {}'.format( + str(context).center(22, '='), + line + )) + + # Fire and forget to the channel logger, if it is set up + if post and client.is_ready(): + asyncio.ensure_future(live_log(message, context, level)) + + +# Live logger that posts to the logging channels +async def live_log(message, context, level): + if level >= logging.INFO: + log_chid = conf.bot.getint('log_channel') + + # Generate the log messages + header = "[{}][{}]".format(logging.getLevelName(level), str(context)) + if len(message) > 1900: + blocks = split_text(message, blocksize=1900, code=False) + else: + blocks = [message] + + if len(blocks) > 1: + blocks = [ + "```md\n{}[{}/{}]\n{}\n```".format(header, i+1, len(blocks), block) for i, block in enumerate(blocks) + ] + else: + blocks = ["```md\n{}\n{}\n```".format(header, blocks[0])] + + # Post the log messages + if log_chid: + [await mail(client, log_chid, content=block, allowed_mentions=AllowedMentions.none()) for block in blocks] + + +# Attach logger to client, for convenience +client.log = log diff --git a/bot/modules/__init__.py b/bot/modules/__init__.py index e69de29b..903d03c5 100644 --- a/bot/modules/__init__.py +++ b/bot/modules/__init__.py @@ -0,0 +1 @@ +from .sysadmin import * diff --git a/bot/modules/sysadmin/__init__.py b/bot/modules/sysadmin/__init__.py index e69de29b..5bccdb52 100644 --- a/bot/modules/sysadmin/__init__.py +++ b/bot/modules/sysadmin/__init__.py @@ -0,0 +1 @@ +from .exec_cmds import * diff --git a/bot/modules/sysadmin/exec_cmds.py b/bot/modules/sysadmin/exec_cmds.py new file mode 100644 index 00000000..d3af7f24 --- /dev/null +++ b/bot/modules/sysadmin/exec_cmds.py @@ -0,0 +1,123 @@ +import sys +from io import StringIO +import traceback +import asyncio + +from cmdClient import cmd, checks + +""" +Exec level commands to manage the bot. + +Commands provided: + async: + Executes provided code in an async executor + exec: + Executes code using standard python exec + eval: + Executes code and awaits it if required +""" + + +@cmd("reboot") +@checks.is_owner() +async def cmd_reboot(ctx): + """ + Usage``: + reboot + Description: + Update the timer status save file and reboot the client. + """ + ctx.client.interface.update_save("reboot") + ctx.client.interface.shutdown() + await ctx.reply("Saved state. Rebooting now!") + await ctx.client.close() + + +@cmd("async") +@checks.is_owner() +async def cmd_async(ctx): + """ + Usage: + {prefix}async + Description: + Runs as an asynchronous coroutine and prints the output or error. + """ + if ctx.arg_str == "": + await ctx.error_reply("You must give me something to run!") + return + output, error = await _async(ctx) + await ctx.reply( + "**Async input:**\ + \n```py\n{}\n```\ + \n**Output {}:** \ + \n```py\n{}\n```".format(ctx.arg_str, + "error" if error else "", + output)) + + +@cmd("eval") +@checks.is_owner() +async def cmd_eval(ctx): + """ + Usage: + {prefix}eval + Description: + Runs in current environment using eval() and prints the output or error. + """ + if ctx.arg_str == "": + await ctx.error_reply("You must give me something to run!") + return + output, error = await _eval(ctx) + await ctx.reply( + "**Eval input:**\ + \n```py\n{}\n```\ + \n**Output {}:** \ + \n```py\n{}\n```".format(ctx.arg_str, + "error" if error else "", + output) + ) + + +async def _eval(ctx): + output = None + try: + output = eval(ctx.arg_str) + except Exception: + return (str(traceback.format_exc()), 1) + if asyncio.iscoroutine(output): + output = await output + return (output, 0) + + +async def _async(ctx): + env = { + 'ctx': ctx, + 'client': ctx.client, + 'message': ctx.msg, + 'arg_str': ctx.arg_str + } + env.update(globals()) + old_stdout = sys.stdout + redirected_output = sys.stdout = StringIO() + result = None + exec_string = "async def _temp_exec():\n" + exec_string += '\n'.join(' ' * 4 + line for line in ctx.arg_str.split('\n')) + try: + exec(exec_string, env) + result = (redirected_output.getvalue(), 0) + except Exception: + result = (str(traceback.format_exc()), 1) + return result + _temp_exec = env['_temp_exec'] + try: + returnval = await _temp_exec() + value = redirected_output.getvalue() + if returnval is None: + result = (value, 0) + else: + result = (value + '\n' + str(returnval), 0) + except Exception: + result = (str(traceback.format_exc()), 1) + finally: + sys.stdout = old_stdout + return result diff --git a/bot/utils/lib.py b/bot/utils/lib.py new file mode 100644 index 00000000..e44123e3 --- /dev/null +++ b/bot/utils/lib.py @@ -0,0 +1,443 @@ +import datetime +import iso8601 +import re + +import discord + + +def prop_tabulate(prop_list, value_list, indent=True): + """ + Turns a list of properties and corresponding list of values into + a pretty string with one `prop: value` pair each line, + padded so that the colons in each line are lined up. + Handles empty props by using an extra couple of spaces instead of a `:`. + + Parameters + ---------- + prop_list: List[str] + List of short names to put on the right side of the list. + Empty props are considered to be "newlines" for the corresponding value. + value_list: List[str] + List of values corresponding to the properties above. + indent: bool + Whether to add padding so the properties are right-adjusted. + + Returns: str + """ + max_len = max(len(prop) for prop in prop_list) + return "".join(["`{}{}{}`\t{}{}".format("​ " * (max_len - len(prop)) if indent else "", + prop, + ":" if len(prop) else "​ " * 2, + value_list[i], + '' if str(value_list[i]).endswith("```") else '\n') + for i, prop in enumerate(prop_list)]) + + +def paginate_list(item_list, block_length=20, style="markdown", title=None): + """ + Create pretty codeblock pages from a list of strings. + + Parameters + ---------- + item_list: List[str] + List of strings to paginate. + block_length: int + Maximum number of strings per page. + style: str + Codeblock style to use. + Title formatting assumes the `markdown` style, and numbered lists work well with this. + However, `markdown` sometimes messes up formatting in the list. + title: str + Optional title to add to the top of each page. + + Returns: List[str] + List of pages, each formatted into a codeblock, + and containing at most `block_length` of the provided strings. + """ + lines = ["{0:<5}{1:<5}".format("{}.".format(i + 1), str(line)) for i, line in enumerate(item_list)] + page_blocks = [lines[i:i + block_length] for i in range(0, len(lines), block_length)] + pages = [] + for i, block in enumerate(page_blocks): + pagenum = "Page {}/{}".format(i + 1, len(page_blocks)) + if title: + header = "{} ({})".format(title, pagenum) if len(page_blocks) > 1 else title + else: + header = pagenum + header_line = "=" * len(header) + full_header = "{}\n{}\n".format(header, header_line) if len(page_blocks) > 1 or title else "" + pages.append("```{}\n{}{}```".format(style, full_header, "\n".join(block))) + return pages + + +def timestamp_utcnow(): + """ + Return the current integer UTC timestamp. + """ + return int(datetime.datetime.timestamp(datetime.datetime.utcnow())) + + +def split_text(text, blocksize=2000, code=True, syntax="", maxheight=50): + """ + Break the text into blocks of maximum length blocksize + If possible, break across nearby newlines. Otherwise just break at blocksize chars + + Parameters + ---------- + text: str + Text to break into blocks. + blocksize: int + Maximum character length for each block. + code: bool + Whether to wrap each block in codeblocks (these are counted in the blocksize). + syntax: str + The markdown formatting language to use for the codeblocks, if applicable. + maxheight: int + The maximum number of lines in each block + + Returns: List[str] + List of blocks, + each containing at most `block_size` characters, + of height at most `maxheight`. + """ + # Adjust blocksize to account for the codeblocks if required + blocksize = blocksize - 8 - len(syntax) if code else blocksize + + # Build the blocks + blocks = [] + while True: + # If the remaining text is already small enough, append it + if len(text) <= blocksize: + blocks.append(text) + break + text = text.strip('\n') + + # Find the last newline in the prototype block + split_on = text[0:blocksize].rfind('\n') + split_on = blocksize if split_on < blocksize // 5 else split_on + + # Add the block and truncate the text + blocks.append(text[0:split_on]) + text = text[split_on:] + + # Add the codeblock ticks and the code syntax header, if required + if code: + blocks = ["```{}\n{}\n```".format(syntax, block) for block in blocks] + + return blocks + + +def strfdelta(delta, sec=False, minutes=True, short=False): + """ + Convert a datetime.timedelta object into an easily readable duration string. + + Parameters + ---------- + delta: datetime.timedelta + The timedelta object to convert into a readable string. + sec: bool + Whether to include the seconds from the timedelta object in the string. + minutes: bool + Whether to include the minutes from the timedelta object in the string. + short: bool + Whether to abbreviate the units of time ("hour" to "h", "minute" to "m", "second" to "s"). + + Returns: str + A string containing a time from the datetime.timedelta object, in a readable format. + Time units will be abbreviated if short was set to True. + """ + + output = [[delta.days, 'd' if short else ' day'], + [delta.seconds // 3600, 'h' if short else ' hour']] + if minutes: + output.append([delta.seconds // 60 % 60, 'm' if short else ' minute']) + if sec: + output.append([delta.seconds % 60, 's' if short else ' second']) + for i in range(len(output)): + if output[i][0] != 1 and not short: + output[i][1] += 's' + reply_msg = [] + if output[0][0] != 0: + reply_msg.append("{}{} ".format(output[0][0], output[0][1])) + if output[0][0] != 0 or output[1][0] != 0 or len(output) == 2: + reply_msg.append("{}{} ".format(output[1][0], output[1][1])) + for i in range(2, len(output) - 1): + reply_msg.append("{}{} ".format(output[i][0], output[i][1])) + if not short and reply_msg: + reply_msg.append("and ") + reply_msg.append("{}{}".format(output[-1][0], output[-1][1])) + return "".join(reply_msg) + + +def parse_dur(time_str): + """ + Parses a user provided time duration string into a timedelta object. + + Parameters + ---------- + time_str: str + The time string to parse. String can include days, hours, minutes, and seconds. + + Returns: int + The number of seconds the duration represents. + """ + funcs = {'d': lambda x: x * 24 * 60 * 60, + 'h': lambda x: x * 60 * 60, + 'm': lambda x: x * 60, + 's': lambda x: x} + time_str = time_str.strip(" ,") + found = re.findall(r'(\d+)\s?(\w+?)', time_str) + seconds = 0 + for bit in found: + if bit[1] in funcs: + seconds += funcs[bit[1]](int(bit[0])) + return seconds + + +def substitute_ranges(ranges_str, max_match=20, max_range=1000, separator=','): + """ + Substitutes a user provided list of numbers and ranges, + and replaces the ranges by the corresponding list of numbers. + + Parameters + ---------- + ranges_str: str + The string to ranges in. + max_match: int + The maximum number of ranges to replace. + Any ranges exceeding this will be ignored. + max_range: int + The maximum length of range to replace. + Attempting to replace a range longer than this will raise a `ValueError`. + """ + def _repl(match): + n1 = int(match.group(1)) + n2 = int(match.group(2)) + if n2 - n1 > max_range: + raise ValueError("Provided range exceeds the allowed maximum.") + return separator.join(str(i) for i in range(n1, n2 + 1)) + + return re.sub(r'(\d+)\s*-\s*(\d+)', _repl, ranges_str, max_match) + + +def msg_string(msg, mask_link=False, line_break=False, tz=None, clean=True): + """ + Format a message into a string with various information, such as: + the timestamp of the message, author, message content, and attachments. + + Parameters + ---------- + msg: Message + The message to format. + mask_link: bool + Whether to mask the URLs of any attachments. + line_break: bool + Whether a line break should be used in the string. + tz: Timezone + The timezone to use in the formatted message. + clean: bool + Whether to use the clean content of the original message. + + Returns: str + A formatted string containing various information: + User timezone, message author, message content, attachments + """ + timestr = "%I:%M %p, %d/%m/%Y" + if tz: + time = iso8601.parse_date(msg.timestamp.isoformat()).astimezone(tz).strftime(timestr) + else: + time = msg.timestamp.strftime(timestr) + user = str(msg.author) + attach_list = [attach["url"] for attach in msg.attachments if "url" in attach] + if mask_link: + attach_list = ["[Link]({})".format(url) for url in attach_list] + attachments = "\nAttachments: {}".format(", ".join(attach_list)) if attach_list else "" + return "`[{time}]` **{user}:** {line_break}{message} {attachments}".format( + time=time, + user=user, + line_break="\n" if line_break else "", + message=msg.clean_content if clean else msg.content, + attachments=attachments + ) + + +def convdatestring(datestring): + """ + Convert a date string into a datetime.timedelta object. + + Parameters + ---------- + datestring: str + The string to convert to a datetime.timedelta object. + + Returns: datetime.timedelta + A datetime.timedelta object formed from the string provided. + """ + datestring = datestring.strip(' ,') + datearray = [] + funcs = {'d': lambda x: x * 24 * 60 * 60, + 'h': lambda x: x * 60 * 60, + 'm': lambda x: x * 60, + 's': lambda x: x} + currentnumber = '' + for char in datestring: + if char.isdigit(): + currentnumber += char + else: + if currentnumber == '': + continue + datearray.append((int(currentnumber), char)) + currentnumber = '' + seconds = 0 + if currentnumber: + seconds += int(currentnumber) + for i in datearray: + if i[1] in funcs: + seconds += funcs[i[1]](i[0]) + return datetime.timedelta(seconds=seconds) + + +class _rawChannel(discord.abc.Messageable): + """ + Raw messageable class representing an arbitrary channel, + not necessarially seen by the gateway. + """ + def __init__(self, state, id): + self._state = state + self.id = id + + async def _get_channel(self): + return discord.Object(self.id) + + +async def mail(client: discord.Client, channelid: int, **msg_args): + """ + Mails a message to a channelid which may be invisible to the gateway. + + Parameters: + client: discord.Client + The client to use for mailing. + Must at least have static authentication and have a valid `_connection`. + channelid: int + The channel id to mail to. + msg_args: Any + Message keyword arguments which are passed transparently to `_rawChannel.send(...)`. + """ + # Create the raw channel + channel = _rawChannel(client._connection, channelid) + return await channel.send(**msg_args) + + +def emb_add_fields(embed, emb_fields): + """ + Append embed fields to an embed. + Parameters + ---------- + embed: discord.Embed + The embed to add the field to. + emb_fields: tuple + The values to add to a field. + name: str + The name of the field. + value: str + The value of the field. + inline: bool + Whether the embed field should be inline or not. + """ + for field in emb_fields: + embed.add_field(name=str(field[0]), value=str(field[1]), inline=bool(field[2])) + + +def join_list(string, nfs=False): + """ + Join a list together, separated with commas, plus add "and" to the beginning of the last value. + Parameters + ---------- + string: list + The list to join together. + nfs: bool + (no fullstops) + Whether to exclude fullstops/periods from the output messages. + If not provided, fullstops will be appended to the output. + """ + if len(string) > 1: + return "{}{} and {}{}".format((", ").join(string[:-1]), + "," if len(string) > 2 else "", string[-1], "" if nfs else ".") + else: + return "{}{}".format("".join(string), "" if nfs else ".") + + +def format_activity(user): + """ + Format a user's activity string, depending on the type of activity. + Currently supported types are: + - Nothing + - Custom status + - Playing (with rich presence support) + - Streaming + - Listening (with rich presence support) + - Watching + - Unknown + Parameters + ---------- + user: discord.Member + The user to format the status of. + If the user has no activity, "Nothing" will be returned. + + Returns: str + A formatted string with various information about the user's current activity like the name, + and any extra information about the activity (such as current song artists for Spotify) + """ + if not user.activity: + return "Nothing" + + AT = user.activity.type + a = user.activity + if str(AT) == "ActivityType.custom": + return "Status: {}".format(a) + + if str(AT) == "ActivityType.playing": + string = "Playing {}".format(a.name) + try: + string += " ({})".format(a.details) + except Exception: + pass + + return string + + if str(AT) == "ActivityType.streaming": + return "Streaming {}".format(a.name) + + if str(AT) == "ActivityType.listening": + try: + string = "Listening to `{}`".format(a.title) + if len(a.artists) > 1: + string += " by {}".format(join_list(string=a.artists)) + else: + string += " by **{}**".format(a.artist) + except Exception: + string = "Listening to `{}`".format(a.name) + return string + + if str(AT) == "ActivityType.watching": + return "Watching `{}`".format(a.name) + + if str(AT) == "ActivityType.unknown": + return "Unknown" + + +def shard_of(shard_count: int, guildid: int): + """ + Calculate the shard number of a given guild. + """ + return (guildid >> 22) % shard_count if shard_count and shard_count > 0 else 0 + + +def jumpto(guildid: int, channeldid: int, messageid: int): + """ + Build a jump link for a message given its location. + """ + return 'https://discord.com/channels/{}/{}/{}'.format( + guildid, + channeldid, + messageid + ) diff --git a/run.py b/run.py index e69de29b..56a190ed 100644 --- a/run.py +++ b/run.py @@ -0,0 +1,6 @@ +import sys +import os + +sys.path.insert(0, os.path.join(os.getcwd(), "bot")) + +from bot import dev_main