rewrite: New bot framework.
This commit is contained in:
70
bot/main.py
Normal file
70
bot/main.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
|
from meta import LionBot, conf, sharding
|
||||||
|
from meta.logger import log_context, log_action
|
||||||
|
|
||||||
|
from data import Database
|
||||||
|
|
||||||
|
from constants import DATA_VERSION
|
||||||
|
|
||||||
|
# from data import tables
|
||||||
|
|
||||||
|
# Note: This MUST be imported after core, due to table definition orders
|
||||||
|
# from settings import AppSettings
|
||||||
|
|
||||||
|
# Load and attach app specific data
|
||||||
|
if sharding.sharded:
|
||||||
|
appname = f"{conf.data['appid']}_{sharding.shard_count}_{sharding.shard_number}"
|
||||||
|
else:
|
||||||
|
appname = conf.data['appid']
|
||||||
|
log_context.set(f"APP: {appname}")
|
||||||
|
|
||||||
|
# client.appdata = core.data.meta.fetch_or_create(appname)
|
||||||
|
|
||||||
|
# client.data = tables
|
||||||
|
|
||||||
|
# client.settings = AppSettings(conf.bot['data_appid'])
|
||||||
|
|
||||||
|
# Initialise all modules
|
||||||
|
# client.initialise_modules()
|
||||||
|
|
||||||
|
for name in conf.config.options('LOGGING_LEVELS', no_defaults=True):
|
||||||
|
logging.getLogger(name).setLevel(conf.logging_levels[name])
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
db = Database(conf.data['args'])
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
log_action.set("Initialising")
|
||||||
|
logger.info("Initialising StudyLion")
|
||||||
|
|
||||||
|
intents = discord.Intents.all()
|
||||||
|
intents.members = True
|
||||||
|
intents.message_content = True
|
||||||
|
|
||||||
|
async with await db.connect():
|
||||||
|
version = await db.version()
|
||||||
|
if version.version != DATA_VERSION:
|
||||||
|
error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate."
|
||||||
|
logger.critical(error)
|
||||||
|
raise RuntimeError(error)
|
||||||
|
async with LionBot(
|
||||||
|
command_prefix=commands.when_mentioned,
|
||||||
|
intents=intents,
|
||||||
|
appname=appname,
|
||||||
|
db=db,
|
||||||
|
config=conf,
|
||||||
|
initial_extensions=['modules'],
|
||||||
|
web_client=None,
|
||||||
|
testing_guilds=[889875661848723456]
|
||||||
|
) as lionbot:
|
||||||
|
log_action.set("Launching")
|
||||||
|
await lionbot.start(conf.bot['TOKEN'])
|
||||||
|
|
||||||
|
asyncio.run(main())
|
||||||
40
bot/meta/LionBot.py
Normal file
40
bot/meta/LionBot.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
from typing import List, Optional, Dict
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands
|
||||||
|
from aiohttp import ClientSession
|
||||||
|
|
||||||
|
from data import Database
|
||||||
|
|
||||||
|
from .config import Conf
|
||||||
|
|
||||||
|
|
||||||
|
class LionBot(commands.Bot):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
appname: str,
|
||||||
|
db: Database,
|
||||||
|
config: Conf,
|
||||||
|
initial_extensions: List[str],
|
||||||
|
web_client: ClientSession,
|
||||||
|
testing_guilds: List[int] = [],
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.web_client = web_client
|
||||||
|
self.testing_guilds = testing_guilds
|
||||||
|
self.initial_extensions = initial_extensions
|
||||||
|
self.db = db
|
||||||
|
self.appname = appname
|
||||||
|
# self.appdata = appdata
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
async def setup_hook(self) -> None:
|
||||||
|
for extension in self.initial_extensions:
|
||||||
|
await self.load_extension(extension)
|
||||||
|
|
||||||
|
for guildid in self.testing_guilds:
|
||||||
|
guild = discord.Object(guildid)
|
||||||
|
self.tree.copy_global_to(guild=guild)
|
||||||
|
await self.tree.sync(guild=guild)
|
||||||
@@ -1,9 +1,5 @@
|
|||||||
from .logger import log, logger
|
from .LionBot import LionBot
|
||||||
|
|
||||||
from . import interactions
|
|
||||||
from . import patches
|
|
||||||
|
|
||||||
from .client import client
|
|
||||||
from .config import conf
|
from .config import conf
|
||||||
from .args import args
|
from .args import args
|
||||||
from . import sharding
|
from . import sharding
|
||||||
|
from . import logger
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ class configEmoji(PartialEmoji):
|
|||||||
animated, name, id = emojistr.split(':')
|
animated, name, id = emojistr.split(':')
|
||||||
return cls(
|
return cls(
|
||||||
name=name,
|
name=name,
|
||||||
fallback=PartialEmoji(name=fallback),
|
fallback=PartialEmoji(name=fallback) if fallback is not None else None,
|
||||||
animated=bool(animated),
|
animated=bool(animated),
|
||||||
id=int(id)
|
id=int(id)
|
||||||
)
|
)
|
||||||
@@ -60,11 +60,26 @@ class MapDotProxy:
|
|||||||
return self._map.__getitem__(key)
|
return self._map.__getitem__(key)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigParser(cfgp.ConfigParser):
|
||||||
|
"""
|
||||||
|
Extension of base ConfigParser allowing optional
|
||||||
|
section option retrieval without defaults.
|
||||||
|
"""
|
||||||
|
def options(self, section, no_defaults=False, **kwargs):
|
||||||
|
if no_defaults:
|
||||||
|
try:
|
||||||
|
return list(self._sections[section].keys())
|
||||||
|
except KeyError:
|
||||||
|
raise cfgp.NoSectionError(section)
|
||||||
|
else:
|
||||||
|
return super().options(section, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class Conf:
|
class Conf:
|
||||||
def __init__(self, configfile, section_name="DEFAULT"):
|
def __init__(self, configfile, section_name="DEFAULT"):
|
||||||
self.configfile = configfile
|
self.configfile = configfile
|
||||||
|
|
||||||
self.config = cfgp.ConfigParser(
|
self.config = ConfigParser(
|
||||||
converters={
|
converters={
|
||||||
"intlist": self._getintlist,
|
"intlist": self._getintlist,
|
||||||
"list": self._getlist,
|
"list": self._getlist,
|
||||||
@@ -102,7 +117,7 @@ class Conf:
|
|||||||
return self.section[key].strip()
|
return self.section[key].strip()
|
||||||
|
|
||||||
def __getattr__(self, section):
|
def __getattr__(self, section):
|
||||||
return self.config[section]
|
return self.config[section.upper()]
|
||||||
|
|
||||||
def get(self, name, fallback=None):
|
def get(self, name, fallback=None):
|
||||||
result = self.section.get(name, fallback)
|
result = self.section.get(name, fallback)
|
||||||
@@ -119,4 +134,4 @@ class Conf:
|
|||||||
self.config.write(conffile)
|
self.config.write(conffile)
|
||||||
|
|
||||||
|
|
||||||
conf = Conf(args.config)
|
conf = Conf(args.config, 'STUDYLION')
|
||||||
146
bot/meta/logger.py
Normal file
146
bot/meta/logger.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from discord import AllowedMentions
|
||||||
|
|
||||||
|
from .config import conf
|
||||||
|
from . import sharding
|
||||||
|
|
||||||
|
|
||||||
|
log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT')
|
||||||
|
log_action: ContextVar[str] = ContextVar('logging_action', default='UNKNOWN ACTION')
|
||||||
|
|
||||||
|
|
||||||
|
RESET_SEQ = "\033[0m"
|
||||||
|
COLOR_SEQ = "\033[3%dm"
|
||||||
|
BOLD_SEQ = "\033[1m"
|
||||||
|
"]]]"
|
||||||
|
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
|
||||||
|
|
||||||
|
def colour_escape(fmt: str) -> str:
|
||||||
|
cmap = {
|
||||||
|
'%(black)': COLOR_SEQ % BLACK,
|
||||||
|
'%(red)': COLOR_SEQ % RED,
|
||||||
|
'%(green)': COLOR_SEQ % GREEN,
|
||||||
|
'%(yellow)': COLOR_SEQ % YELLOW,
|
||||||
|
'%(blue)': COLOR_SEQ % BLUE,
|
||||||
|
'%(magenta)': COLOR_SEQ % MAGENTA,
|
||||||
|
'%(cyan)': COLOR_SEQ % CYAN,
|
||||||
|
'%(white)': COLOR_SEQ % WHITE,
|
||||||
|
'%(reset)': RESET_SEQ,
|
||||||
|
'%(bold)': BOLD_SEQ,
|
||||||
|
}
|
||||||
|
for key, value in cmap.items():
|
||||||
|
fmt = fmt.replace(key, value)
|
||||||
|
return fmt
|
||||||
|
|
||||||
|
|
||||||
|
log_format = ('[%(green)%(asctime)-19s%(reset)][%(red)%(levelname)-8s%(reset)]' +
|
||||||
|
'[%(cyan)SHARD {:02}%(reset)]'.format(sharding.shard_number) +
|
||||||
|
'[%(cyan)%(context)-22s%(reset)]' +
|
||||||
|
'[%(cyan)%(action)-22s%(reset)]' +
|
||||||
|
' %(bold)%(cyan)%(name)s:%(reset)' +
|
||||||
|
' %(white)%(message)s%(reset)')
|
||||||
|
log_format = colour_escape(log_format)
|
||||||
|
|
||||||
|
|
||||||
|
# Setup the logger
|
||||||
|
logger = logging.getLogger()
|
||||||
|
log_fmt = logging.Formatter(
|
||||||
|
fmt=log_format,
|
||||||
|
datefmt='%Y-%m-%d %H:%M:%S'
|
||||||
|
)
|
||||||
|
logger.setLevel(logging.NOTSET)
|
||||||
|
|
||||||
|
|
||||||
|
class LessThanFilter(logging.Filter):
|
||||||
|
def __init__(self, exclusive_maximum, name=""):
|
||||||
|
super(LessThanFilter, self).__init__(name)
|
||||||
|
self.max_level = exclusive_maximum
|
||||||
|
|
||||||
|
def filter(self, record):
|
||||||
|
# non-zero return means we log this message
|
||||||
|
return 1 if record.levelno < self.max_level else 0
|
||||||
|
|
||||||
|
|
||||||
|
class ContextInjection(logging.Filter):
|
||||||
|
def filter(self, record):
|
||||||
|
if not hasattr(record, 'context'):
|
||||||
|
record.context = log_context.get()
|
||||||
|
if not hasattr(record, 'action'):
|
||||||
|
record.action = log_action.get()
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
logging_handler_out = logging.StreamHandler(sys.stdout)
|
||||||
|
logging_handler_out.setLevel(logging.DEBUG)
|
||||||
|
logging_handler_out.setFormatter(log_fmt)
|
||||||
|
logging_handler_out.addFilter(LessThanFilter(logging.WARNING))
|
||||||
|
logging_handler_out.addFilter(ContextInjection())
|
||||||
|
logger.addHandler(logging_handler_out)
|
||||||
|
|
||||||
|
logging_handler_err = logging.StreamHandler(sys.stderr)
|
||||||
|
logging_handler_err.setLevel(logging.WARNING)
|
||||||
|
logging_handler_err.setFormatter(log_fmt)
|
||||||
|
logging_handler_err.addFilter(ContextInjection())
|
||||||
|
logger.addHandler(logging_handler_err)
|
||||||
|
|
||||||
|
# TODO: Add an async handler for posting
|
||||||
|
# Subclass this, create a DiscordChannelHandler, taking a Client and a channel as an argument
|
||||||
|
# Then we can handle error channels etc differently
|
||||||
|
# The formatting can be handled with a custom handler as well
|
||||||
|
|
||||||
|
|
||||||
|
# Define the context log format and attach it to the command logger as well
|
||||||
|
def log(message, context="GLOBAL", level=logging.INFO, post=True):
|
||||||
|
# Add prefixes to lines for better parsing capability
|
||||||
|
lines = message.splitlines()
|
||||||
|
if len(lines) > 1:
|
||||||
|
lines = [
|
||||||
|
'┌ ' * (i == 0) + '│ ' * (0 < i < len(lines) - 1) + '└ ' * (i == len(lines) - 1) + line
|
||||||
|
for i, line in enumerate(lines)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
lines = ['─ ' + message]
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
logger.log(level, '\b[{}] {}'.format(
|
||||||
|
str(context).center(22, '='),
|
||||||
|
line
|
||||||
|
))
|
||||||
|
|
||||||
|
# Fire and forget to the channel logger, if it is set up
|
||||||
|
if post and client.is_ready():
|
||||||
|
asyncio.ensure_future(live_log(message, context, level))
|
||||||
|
|
||||||
|
|
||||||
|
# Live logger that posts to the logging channels
|
||||||
|
async def live_log(message, context, level):
|
||||||
|
if level >= logging.INFO:
|
||||||
|
if level >= logging.WARNING:
|
||||||
|
log_chid = conf.bot.getint('error_channel') or conf.bot.getint('log_channel')
|
||||||
|
else:
|
||||||
|
log_chid = conf.bot.getint('log_channel')
|
||||||
|
|
||||||
|
# Generate the log messages
|
||||||
|
if sharding.sharded:
|
||||||
|
header = f"[{logging.getLevelName(level)}][SHARD {sharding.shard_number}][{context}]"
|
||||||
|
else:
|
||||||
|
header = f"[{logging.getLevelName(level)}][{context}]"
|
||||||
|
|
||||||
|
if len(message) > 1900:
|
||||||
|
blocks = split_text(message, blocksize=1900, code=False)
|
||||||
|
else:
|
||||||
|
blocks = [message]
|
||||||
|
|
||||||
|
if len(blocks) > 1:
|
||||||
|
blocks = [
|
||||||
|
"```md\n{}[{}/{}]\n{}\n```".format(header, i+1, len(blocks), block) for i, block in enumerate(blocks)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
blocks = ["```md\n{}\n{}\n```".format(header, blocks[0])]
|
||||||
|
|
||||||
|
# Post the log messages
|
||||||
|
if log_chid:
|
||||||
|
[await mail(client, log_chid, content=block, allowed_mentions=AllowedMentions.none()) for block in blocks]
|
||||||
@@ -1,9 +1,35 @@
|
|||||||
from .args import args
|
from .args import args
|
||||||
from .config import conf
|
from .config import conf
|
||||||
|
|
||||||
|
from psycopg import sql
|
||||||
|
from data.conditions import Condition, Joiner
|
||||||
|
|
||||||
|
|
||||||
shard_number = args.shard or 0
|
shard_number = args.shard or 0
|
||||||
|
|
||||||
shard_count = conf.bot.getint('shard_count', 1)
|
shard_count = conf.bot.getint('shard_count', 1)
|
||||||
|
|
||||||
sharded = (shard_count > 0)
|
sharded = (shard_count > 0)
|
||||||
|
|
||||||
|
|
||||||
|
def SHARDID(shard_id: int, guild_column: str = 'guildid', shard_count: int = shard_count) -> Condition:
|
||||||
|
"""
|
||||||
|
Condition constructor for filtering by shard id.
|
||||||
|
|
||||||
|
Example Usage
|
||||||
|
-------------
|
||||||
|
Query.where(_shard_condition('guildid', 10, 1))
|
||||||
|
"""
|
||||||
|
return Condition(
|
||||||
|
sql.SQL("({guildid} >> 22) %% {shard_count}").format(
|
||||||
|
guildid=sql.Identifier(guild_column),
|
||||||
|
shard_count=sql.Literal(shard_count)
|
||||||
|
),
|
||||||
|
Joiner.EQUALS,
|
||||||
|
sql.Placeholder(),
|
||||||
|
(shard_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Pre-built Condition for filtering by current shard.
|
||||||
|
THIS_SHARD = SHARDID(shard_number)
|
||||||
|
|||||||
@@ -1,28 +0,0 @@
|
|||||||
from meta import client, conf, log, sharding
|
|
||||||
|
|
||||||
from data import tables
|
|
||||||
|
|
||||||
import core # noqa
|
|
||||||
|
|
||||||
# Note: This MUST be imported after core, due to table definition orders
|
|
||||||
from settings import AppSettings
|
|
||||||
|
|
||||||
import modules # noqa
|
|
||||||
|
|
||||||
# Load and attach app specific data
|
|
||||||
if sharding.sharded:
|
|
||||||
appname = f"{conf.bot['data_appid']}_{sharding.shard_count}_{sharding.shard_number}"
|
|
||||||
else:
|
|
||||||
appname = conf.bot['data_appid']
|
|
||||||
client.appdata = core.data.meta.fetch_or_create(appname)
|
|
||||||
|
|
||||||
client.data = tables
|
|
||||||
|
|
||||||
client.settings = AppSettings(conf.bot['data_appid'])
|
|
||||||
|
|
||||||
# Initialise all modules
|
|
||||||
client.initialise_modules()
|
|
||||||
|
|
||||||
# Log readyness and execute
|
|
||||||
log("Initial setup complete, logging in", context='SETUP')
|
|
||||||
client.run(conf.bot['TOKEN'])
|
|
||||||
@@ -1,553 +0,0 @@
|
|||||||
import datetime
|
|
||||||
import iso8601
|
|
||||||
import re
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
import discord
|
|
||||||
from psycopg2.extensions import QuotedString
|
|
||||||
|
|
||||||
from cmdClient.lib import SafeCancellation
|
|
||||||
|
|
||||||
|
|
||||||
multiselect_regex = re.compile(
|
|
||||||
r"^([0-9, -]+)$",
|
|
||||||
re.DOTALL | re.IGNORECASE | re.VERBOSE
|
|
||||||
)
|
|
||||||
tick = '✅'
|
|
||||||
cross = '❌'
|
|
||||||
|
|
||||||
|
|
||||||
def prop_tabulate(prop_list, value_list, indent=True, colon=True):
|
|
||||||
"""
|
|
||||||
Turns a list of properties and corresponding list of values into
|
|
||||||
a pretty string with one `prop: value` pair each line,
|
|
||||||
padded so that the colons in each line are lined up.
|
|
||||||
Handles empty props by using an extra couple of spaces instead of a `:`.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
prop_list: List[str]
|
|
||||||
List of short names to put on the right side of the list.
|
|
||||||
Empty props are considered to be "newlines" for the corresponding value.
|
|
||||||
value_list: List[str]
|
|
||||||
List of values corresponding to the properties above.
|
|
||||||
indent: bool
|
|
||||||
Whether to add padding so the properties are right-adjusted.
|
|
||||||
|
|
||||||
Returns: str
|
|
||||||
"""
|
|
||||||
max_len = max(len(prop) for prop in prop_list)
|
|
||||||
return "".join(["`{}{}{}`\t{}{}".format(" " * (max_len - len(prop)) if indent else "",
|
|
||||||
prop,
|
|
||||||
(":" if len(prop) else " " * 2) if colon else '',
|
|
||||||
value_list[i],
|
|
||||||
'' if str(value_list[i]).endswith("```") else '\n')
|
|
||||||
for i, prop in enumerate(prop_list)])
|
|
||||||
|
|
||||||
|
|
||||||
def paginate_list(item_list, block_length=20, style="markdown", title=None):
|
|
||||||
"""
|
|
||||||
Create pretty codeblock pages from a list of strings.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
item_list: List[str]
|
|
||||||
List of strings to paginate.
|
|
||||||
block_length: int
|
|
||||||
Maximum number of strings per page.
|
|
||||||
style: str
|
|
||||||
Codeblock style to use.
|
|
||||||
Title formatting assumes the `markdown` style, and numbered lists work well with this.
|
|
||||||
However, `markdown` sometimes messes up formatting in the list.
|
|
||||||
title: str
|
|
||||||
Optional title to add to the top of each page.
|
|
||||||
|
|
||||||
Returns: List[str]
|
|
||||||
List of pages, each formatted into a codeblock,
|
|
||||||
and containing at most `block_length` of the provided strings.
|
|
||||||
"""
|
|
||||||
lines = ["{0:<5}{1:<5}".format("{}.".format(i + 1), str(line)) for i, line in enumerate(item_list)]
|
|
||||||
page_blocks = [lines[i:i + block_length] for i in range(0, len(lines), block_length)]
|
|
||||||
pages = []
|
|
||||||
for i, block in enumerate(page_blocks):
|
|
||||||
pagenum = "Page {}/{}".format(i + 1, len(page_blocks))
|
|
||||||
if title:
|
|
||||||
header = "{} ({})".format(title, pagenum) if len(page_blocks) > 1 else title
|
|
||||||
else:
|
|
||||||
header = pagenum
|
|
||||||
header_line = "=" * len(header)
|
|
||||||
full_header = "{}\n{}\n".format(header, header_line) if len(page_blocks) > 1 or title else ""
|
|
||||||
pages.append("```{}\n{}{}```".format(style, full_header, "\n".join(block)))
|
|
||||||
return pages
|
|
||||||
|
|
||||||
|
|
||||||
def timestamp_utcnow():
|
|
||||||
"""
|
|
||||||
Return the current integer UTC timestamp.
|
|
||||||
"""
|
|
||||||
return int(datetime.datetime.timestamp(datetime.datetime.utcnow()))
|
|
||||||
|
|
||||||
|
|
||||||
def split_text(text, blocksize=2000, code=True, syntax="", maxheight=50):
|
|
||||||
"""
|
|
||||||
Break the text into blocks of maximum length blocksize
|
|
||||||
If possible, break across nearby newlines. Otherwise just break at blocksize chars
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
text: str
|
|
||||||
Text to break into blocks.
|
|
||||||
blocksize: int
|
|
||||||
Maximum character length for each block.
|
|
||||||
code: bool
|
|
||||||
Whether to wrap each block in codeblocks (these are counted in the blocksize).
|
|
||||||
syntax: str
|
|
||||||
The markdown formatting language to use for the codeblocks, if applicable.
|
|
||||||
maxheight: int
|
|
||||||
The maximum number of lines in each block
|
|
||||||
|
|
||||||
Returns: List[str]
|
|
||||||
List of blocks,
|
|
||||||
each containing at most `block_size` characters,
|
|
||||||
of height at most `maxheight`.
|
|
||||||
"""
|
|
||||||
# Adjust blocksize to account for the codeblocks if required
|
|
||||||
blocksize = blocksize - 8 - len(syntax) if code else blocksize
|
|
||||||
|
|
||||||
# Build the blocks
|
|
||||||
blocks = []
|
|
||||||
while True:
|
|
||||||
# If the remaining text is already small enough, append it
|
|
||||||
if len(text) <= blocksize:
|
|
||||||
blocks.append(text)
|
|
||||||
break
|
|
||||||
text = text.strip('\n')
|
|
||||||
|
|
||||||
# Find the last newline in the prototype block
|
|
||||||
split_on = text[0:blocksize].rfind('\n')
|
|
||||||
split_on = blocksize if split_on < blocksize // 5 else split_on
|
|
||||||
|
|
||||||
# Add the block and truncate the text
|
|
||||||
blocks.append(text[0:split_on])
|
|
||||||
text = text[split_on:]
|
|
||||||
|
|
||||||
# Add the codeblock ticks and the code syntax header, if required
|
|
||||||
if code:
|
|
||||||
blocks = ["```{}\n{}\n```".format(syntax, block) for block in blocks]
|
|
||||||
|
|
||||||
return blocks
|
|
||||||
|
|
||||||
|
|
||||||
def strfdelta(delta, sec=False, minutes=True, short=False):
|
|
||||||
"""
|
|
||||||
Convert a datetime.timedelta object into an easily readable duration string.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
delta: datetime.timedelta
|
|
||||||
The timedelta object to convert into a readable string.
|
|
||||||
sec: bool
|
|
||||||
Whether to include the seconds from the timedelta object in the string.
|
|
||||||
minutes: bool
|
|
||||||
Whether to include the minutes from the timedelta object in the string.
|
|
||||||
short: bool
|
|
||||||
Whether to abbreviate the units of time ("hour" to "h", "minute" to "m", "second" to "s").
|
|
||||||
|
|
||||||
Returns: str
|
|
||||||
A string containing a time from the datetime.timedelta object, in a readable format.
|
|
||||||
Time units will be abbreviated if short was set to True.
|
|
||||||
"""
|
|
||||||
|
|
||||||
output = [[delta.days, 'd' if short else ' day'],
|
|
||||||
[delta.seconds // 3600, 'h' if short else ' hour']]
|
|
||||||
if minutes:
|
|
||||||
output.append([delta.seconds // 60 % 60, 'm' if short else ' minute'])
|
|
||||||
if sec:
|
|
||||||
output.append([delta.seconds % 60, 's' if short else ' second'])
|
|
||||||
for i in range(len(output)):
|
|
||||||
if output[i][0] != 1 and not short:
|
|
||||||
output[i][1] += 's'
|
|
||||||
reply_msg = []
|
|
||||||
if output[0][0] != 0:
|
|
||||||
reply_msg.append("{}{} ".format(output[0][0], output[0][1]))
|
|
||||||
if output[0][0] != 0 or output[1][0] != 0 or len(output) == 2:
|
|
||||||
reply_msg.append("{}{} ".format(output[1][0], output[1][1]))
|
|
||||||
for i in range(2, len(output) - 1):
|
|
||||||
reply_msg.append("{}{} ".format(output[i][0], output[i][1]))
|
|
||||||
if not short and reply_msg:
|
|
||||||
reply_msg.append("and ")
|
|
||||||
reply_msg.append("{}{}".format(output[-1][0], output[-1][1]))
|
|
||||||
return "".join(reply_msg)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_dur(time_str):
|
|
||||||
"""
|
|
||||||
Parses a user provided time duration string into a timedelta object.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
time_str: str
|
|
||||||
The time string to parse. String can include days, hours, minutes, and seconds.
|
|
||||||
|
|
||||||
Returns: int
|
|
||||||
The number of seconds the duration represents.
|
|
||||||
"""
|
|
||||||
funcs = {'d': lambda x: x * 24 * 60 * 60,
|
|
||||||
'h': lambda x: x * 60 * 60,
|
|
||||||
'm': lambda x: x * 60,
|
|
||||||
's': lambda x: x}
|
|
||||||
time_str = time_str.strip(" ,")
|
|
||||||
found = re.findall(r'(\d+)\s?(\w+?)', time_str)
|
|
||||||
seconds = 0
|
|
||||||
for bit in found:
|
|
||||||
if bit[1] in funcs:
|
|
||||||
seconds += funcs[bit[1]](int(bit[0]))
|
|
||||||
return seconds
|
|
||||||
|
|
||||||
|
|
||||||
def strfdur(duration, short=True, show_days=False):
|
|
||||||
"""
|
|
||||||
Convert a duration given in seconds to a number of hours, minutes, and seconds.
|
|
||||||
"""
|
|
||||||
days = duration // (3600 * 24) if show_days else 0
|
|
||||||
hours = duration // 3600
|
|
||||||
if days:
|
|
||||||
hours %= 24
|
|
||||||
minutes = duration // 60 % 60
|
|
||||||
seconds = duration % 60
|
|
||||||
|
|
||||||
parts = []
|
|
||||||
if days:
|
|
||||||
unit = 'd' if short else (' days' if days != 1 else ' day')
|
|
||||||
parts.append('{}{}'.format(days, unit))
|
|
||||||
if hours:
|
|
||||||
unit = 'h' if short else (' hours' if hours != 1 else ' hour')
|
|
||||||
parts.append('{}{}'.format(hours, unit))
|
|
||||||
if minutes:
|
|
||||||
unit = 'm' if short else (' minutes' if minutes != 1 else ' minute')
|
|
||||||
parts.append('{}{}'.format(minutes, unit))
|
|
||||||
if seconds or duration == 0:
|
|
||||||
unit = 's' if short else (' seconds' if seconds != 1 else ' second')
|
|
||||||
parts.append('{}{}'.format(seconds, unit))
|
|
||||||
|
|
||||||
if short:
|
|
||||||
return ' '.join(parts)
|
|
||||||
else:
|
|
||||||
return ', '.join(parts)
|
|
||||||
|
|
||||||
|
|
||||||
def substitute_ranges(ranges_str, max_match=20, max_range=1000, separator=','):
|
|
||||||
"""
|
|
||||||
Substitutes a user provided list of numbers and ranges,
|
|
||||||
and replaces the ranges by the corresponding list of numbers.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
ranges_str: str
|
|
||||||
The string to ranges in.
|
|
||||||
max_match: int
|
|
||||||
The maximum number of ranges to replace.
|
|
||||||
Any ranges exceeding this will be ignored.
|
|
||||||
max_range: int
|
|
||||||
The maximum length of range to replace.
|
|
||||||
Attempting to replace a range longer than this will raise a `ValueError`.
|
|
||||||
"""
|
|
||||||
def _repl(match):
|
|
||||||
n1 = int(match.group(1))
|
|
||||||
n2 = int(match.group(2))
|
|
||||||
if n2 - n1 > max_range:
|
|
||||||
raise SafeCancellation("Provided range is too large!")
|
|
||||||
return separator.join(str(i) for i in range(n1, n2 + 1))
|
|
||||||
|
|
||||||
return re.sub(r'(\d+)\s*-\s*(\d+)', _repl, ranges_str, max_match)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_ranges(ranges_str, ignore_errors=False, separator=',', **kwargs):
|
|
||||||
"""
|
|
||||||
Parses a user provided range string into a list of numbers.
|
|
||||||
Extra keyword arguments are transparently passed to the underlying parser `substitute_ranges`.
|
|
||||||
"""
|
|
||||||
substituted = substitute_ranges(ranges_str, separator=separator, **kwargs)
|
|
||||||
numbers = (item.strip() for item in substituted.split(','))
|
|
||||||
numbers = [item for item in numbers if item]
|
|
||||||
integers = [int(item) for item in numbers if item.isdigit()]
|
|
||||||
|
|
||||||
if not ignore_errors and len(integers) != len(numbers):
|
|
||||||
raise SafeCancellation(
|
|
||||||
"Couldn't parse the provided selection!\n"
|
|
||||||
"Please provide comma separated numbers and ranges, e.g. `1, 5, 6-9`."
|
|
||||||
)
|
|
||||||
|
|
||||||
return integers
|
|
||||||
|
|
||||||
|
|
||||||
def msg_string(msg, mask_link=False, line_break=False, tz=None, clean=True):
|
|
||||||
"""
|
|
||||||
Format a message into a string with various information, such as:
|
|
||||||
the timestamp of the message, author, message content, and attachments.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
msg: Message
|
|
||||||
The message to format.
|
|
||||||
mask_link: bool
|
|
||||||
Whether to mask the URLs of any attachments.
|
|
||||||
line_break: bool
|
|
||||||
Whether a line break should be used in the string.
|
|
||||||
tz: Timezone
|
|
||||||
The timezone to use in the formatted message.
|
|
||||||
clean: bool
|
|
||||||
Whether to use the clean content of the original message.
|
|
||||||
|
|
||||||
Returns: str
|
|
||||||
A formatted string containing various information:
|
|
||||||
User timezone, message author, message content, attachments
|
|
||||||
"""
|
|
||||||
timestr = "%I:%M %p, %d/%m/%Y"
|
|
||||||
if tz:
|
|
||||||
time = iso8601.parse_date(msg.timestamp.isoformat()).astimezone(tz).strftime(timestr)
|
|
||||||
else:
|
|
||||||
time = msg.timestamp.strftime(timestr)
|
|
||||||
user = str(msg.author)
|
|
||||||
attach_list = [attach["url"] for attach in msg.attachments if "url" in attach]
|
|
||||||
if mask_link:
|
|
||||||
attach_list = ["[Link]({})".format(url) for url in attach_list]
|
|
||||||
attachments = "\nAttachments: {}".format(", ".join(attach_list)) if attach_list else ""
|
|
||||||
return "`[{time}]` **{user}:** {line_break}{message} {attachments}".format(
|
|
||||||
time=time,
|
|
||||||
user=user,
|
|
||||||
line_break="\n" if line_break else "",
|
|
||||||
message=msg.clean_content if clean else msg.content,
|
|
||||||
attachments=attachments
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def convdatestring(datestring):
|
|
||||||
"""
|
|
||||||
Convert a date string into a datetime.timedelta object.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
datestring: str
|
|
||||||
The string to convert to a datetime.timedelta object.
|
|
||||||
|
|
||||||
Returns: datetime.timedelta
|
|
||||||
A datetime.timedelta object formed from the string provided.
|
|
||||||
"""
|
|
||||||
datestring = datestring.strip(' ,')
|
|
||||||
datearray = []
|
|
||||||
funcs = {'d': lambda x: x * 24 * 60 * 60,
|
|
||||||
'h': lambda x: x * 60 * 60,
|
|
||||||
'm': lambda x: x * 60,
|
|
||||||
's': lambda x: x}
|
|
||||||
currentnumber = ''
|
|
||||||
for char in datestring:
|
|
||||||
if char.isdigit():
|
|
||||||
currentnumber += char
|
|
||||||
else:
|
|
||||||
if currentnumber == '':
|
|
||||||
continue
|
|
||||||
datearray.append((int(currentnumber), char))
|
|
||||||
currentnumber = ''
|
|
||||||
seconds = 0
|
|
||||||
if currentnumber:
|
|
||||||
seconds += int(currentnumber)
|
|
||||||
for i in datearray:
|
|
||||||
if i[1] in funcs:
|
|
||||||
seconds += funcs[i[1]](i[0])
|
|
||||||
return datetime.timedelta(seconds=seconds)
|
|
||||||
|
|
||||||
|
|
||||||
class _rawChannel(discord.abc.Messageable):
|
|
||||||
"""
|
|
||||||
Raw messageable class representing an arbitrary channel,
|
|
||||||
not necessarially seen by the gateway.
|
|
||||||
"""
|
|
||||||
def __init__(self, state, id):
|
|
||||||
self._state = state
|
|
||||||
self.id = id
|
|
||||||
|
|
||||||
async def _get_channel(self):
|
|
||||||
return discord.Object(self.id)
|
|
||||||
|
|
||||||
|
|
||||||
async def mail(client: discord.Client, channelid: int, **msg_args):
|
|
||||||
"""
|
|
||||||
Mails a message to a channelid which may be invisible to the gateway.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
client: discord.Client
|
|
||||||
The client to use for mailing.
|
|
||||||
Must at least have static authentication and have a valid `_connection`.
|
|
||||||
channelid: int
|
|
||||||
The channel id to mail to.
|
|
||||||
msg_args: Any
|
|
||||||
Message keyword arguments which are passed transparently to `_rawChannel.send(...)`.
|
|
||||||
"""
|
|
||||||
# Create the raw channel
|
|
||||||
channel = _rawChannel(client._connection, channelid)
|
|
||||||
return await channel.send(**msg_args)
|
|
||||||
|
|
||||||
|
|
||||||
def emb_add_fields(embed, emb_fields):
|
|
||||||
"""
|
|
||||||
Append embed fields to an embed.
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
embed: discord.Embed
|
|
||||||
The embed to add the field to.
|
|
||||||
emb_fields: tuple
|
|
||||||
The values to add to a field.
|
|
||||||
name: str
|
|
||||||
The name of the field.
|
|
||||||
value: str
|
|
||||||
The value of the field.
|
|
||||||
inline: bool
|
|
||||||
Whether the embed field should be inline or not.
|
|
||||||
"""
|
|
||||||
for field in emb_fields:
|
|
||||||
embed.add_field(name=str(field[0]), value=str(field[1]), inline=bool(field[2]))
|
|
||||||
|
|
||||||
|
|
||||||
def join_list(string, nfs=False):
|
|
||||||
"""
|
|
||||||
Join a list together, separated with commas, plus add "and" to the beginning of the last value.
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
string: list
|
|
||||||
The list to join together.
|
|
||||||
nfs: bool
|
|
||||||
(no fullstops)
|
|
||||||
Whether to exclude fullstops/periods from the output messages.
|
|
||||||
If not provided, fullstops will be appended to the output.
|
|
||||||
"""
|
|
||||||
if len(string) > 1:
|
|
||||||
return "{}{} and {}{}".format((", ").join(string[:-1]),
|
|
||||||
"," if len(string) > 2 else "", string[-1], "" if nfs else ".")
|
|
||||||
else:
|
|
||||||
return "{}{}".format("".join(string), "" if nfs else ".")
|
|
||||||
|
|
||||||
|
|
||||||
def format_activity(user):
|
|
||||||
"""
|
|
||||||
Format a user's activity string, depending on the type of activity.
|
|
||||||
Currently supported types are:
|
|
||||||
- Nothing
|
|
||||||
- Custom status
|
|
||||||
- Playing (with rich presence support)
|
|
||||||
- Streaming
|
|
||||||
- Listening (with rich presence support)
|
|
||||||
- Watching
|
|
||||||
- Unknown
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
user: discord.Member
|
|
||||||
The user to format the status of.
|
|
||||||
If the user has no activity, "Nothing" will be returned.
|
|
||||||
|
|
||||||
Returns: str
|
|
||||||
A formatted string with various information about the user's current activity like the name,
|
|
||||||
and any extra information about the activity (such as current song artists for Spotify)
|
|
||||||
"""
|
|
||||||
if not user.activity:
|
|
||||||
return "Nothing"
|
|
||||||
|
|
||||||
AT = user.activity.type
|
|
||||||
a = user.activity
|
|
||||||
if str(AT) == "ActivityType.custom":
|
|
||||||
return "Status: {}".format(a)
|
|
||||||
|
|
||||||
if str(AT) == "ActivityType.playing":
|
|
||||||
string = "Playing {}".format(a.name)
|
|
||||||
try:
|
|
||||||
string += " ({})".format(a.details)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return string
|
|
||||||
|
|
||||||
if str(AT) == "ActivityType.streaming":
|
|
||||||
return "Streaming {}".format(a.name)
|
|
||||||
|
|
||||||
if str(AT) == "ActivityType.listening":
|
|
||||||
try:
|
|
||||||
string = "Listening to `{}`".format(a.title)
|
|
||||||
if len(a.artists) > 1:
|
|
||||||
string += " by {}".format(join_list(string=a.artists))
|
|
||||||
else:
|
|
||||||
string += " by **{}**".format(a.artist)
|
|
||||||
except Exception:
|
|
||||||
string = "Listening to `{}`".format(a.name)
|
|
||||||
return string
|
|
||||||
|
|
||||||
if str(AT) == "ActivityType.watching":
|
|
||||||
return "Watching `{}`".format(a.name)
|
|
||||||
|
|
||||||
if str(AT) == "ActivityType.unknown":
|
|
||||||
return "Unknown"
|
|
||||||
|
|
||||||
|
|
||||||
def shard_of(shard_count: int, guildid: int):
|
|
||||||
"""
|
|
||||||
Calculate the shard number of a given guild.
|
|
||||||
"""
|
|
||||||
return (guildid >> 22) % shard_count if shard_count and shard_count > 0 else 0
|
|
||||||
|
|
||||||
|
|
||||||
def jumpto(guildid: int, channeldid: int, messageid: int):
|
|
||||||
"""
|
|
||||||
Build a jump link for a message given its location.
|
|
||||||
"""
|
|
||||||
return 'https://discord.com/channels/{}/{}/{}'.format(
|
|
||||||
guildid,
|
|
||||||
channeldid,
|
|
||||||
messageid
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DotDict(dict):
|
|
||||||
"""
|
|
||||||
Dict-type allowing dot access to keys.
|
|
||||||
"""
|
|
||||||
__getattr__ = dict.get
|
|
||||||
__setattr__ = dict.__setitem__
|
|
||||||
__delattr__ = dict.__delitem__
|
|
||||||
|
|
||||||
|
|
||||||
class FieldEnum(str, Enum):
|
|
||||||
"""
|
|
||||||
String enum with description conforming to the ISQLQuote protocol.
|
|
||||||
Allows processing by psycog
|
|
||||||
"""
|
|
||||||
def __new__(cls, value, desc):
|
|
||||||
obj = str.__new__(cls, value)
|
|
||||||
obj._value_ = value
|
|
||||||
obj.desc = desc
|
|
||||||
return obj
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return '<%s.%s>' % (self.__class__.__name__, self.name)
|
|
||||||
|
|
||||||
def __bool__(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
def __conform__(self, proto):
|
|
||||||
return QuotedString(self.value)
|
|
||||||
|
|
||||||
|
|
||||||
def utc_now():
|
|
||||||
"""
|
|
||||||
Return the current timezone-aware utc timestamp.
|
|
||||||
"""
|
|
||||||
return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)
|
|
||||||
|
|
||||||
|
|
||||||
def multiple_replace(string, rep_dict):
|
|
||||||
if rep_dict:
|
|
||||||
pattern = re.compile(
|
|
||||||
"|".join([re.escape(k) for k in sorted(rep_dict, key=len, reverse=True)]),
|
|
||||||
flags=re.DOTALL
|
|
||||||
)
|
|
||||||
return pattern.sub(lambda x: str(rep_dict[x.group(0)]), string)
|
|
||||||
else:
|
|
||||||
return string
|
|
||||||
Reference in New Issue
Block a user