rewrite: New bot framework.

This commit is contained in:
2022-11-02 07:24:57 +02:00
parent 069c032e02
commit b27ee447b3
8 changed files with 303 additions and 591 deletions

70
bot/main.py Normal file
View File

@@ -0,0 +1,70 @@
import asyncio
import logging
import discord
from discord.ext import commands
from meta import LionBot, conf, sharding
from meta.logger import log_context, log_action
from data import Database
from constants import DATA_VERSION
# from data import tables
# Note: This MUST be imported after core, due to table definition orders
# from settings import AppSettings
# Load and attach app specific data
if sharding.sharded:
appname = f"{conf.data['appid']}_{sharding.shard_count}_{sharding.shard_number}"
else:
appname = conf.data['appid']
log_context.set(f"APP: {appname}")
# client.appdata = core.data.meta.fetch_or_create(appname)
# client.data = tables
# client.settings = AppSettings(conf.bot['data_appid'])
# Initialise all modules
# client.initialise_modules()
for name in conf.config.options('LOGGING_LEVELS', no_defaults=True):
logging.getLogger(name).setLevel(conf.logging_levels[name])
logger = logging.getLogger(__name__)
db = Database(conf.data['args'])
async def main():
log_action.set("Initialising")
logger.info("Initialising StudyLion")
intents = discord.Intents.all()
intents.members = True
intents.message_content = True
async with await db.connect():
version = await db.version()
if version.version != DATA_VERSION:
error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate."
logger.critical(error)
raise RuntimeError(error)
async with LionBot(
command_prefix=commands.when_mentioned,
intents=intents,
appname=appname,
db=db,
config=conf,
initial_extensions=['modules'],
web_client=None,
testing_guilds=[889875661848723456]
) as lionbot:
log_action.set("Launching")
await lionbot.start(conf.bot['TOKEN'])
asyncio.run(main())

40
bot/meta/LionBot.py Normal file
View File

@@ -0,0 +1,40 @@
from typing import List, Optional, Dict
import discord
from discord.ext import commands
from aiohttp import ClientSession
from data import Database
from .config import Conf
class LionBot(commands.Bot):
def __init__(
self,
*args,
appname: str,
db: Database,
config: Conf,
initial_extensions: List[str],
web_client: ClientSession,
testing_guilds: List[int] = [],
**kwargs,
):
super().__init__(*args, **kwargs)
self.web_client = web_client
self.testing_guilds = testing_guilds
self.initial_extensions = initial_extensions
self.db = db
self.appname = appname
# self.appdata = appdata
self.config = config
async def setup_hook(self) -> None:
for extension in self.initial_extensions:
await self.load_extension(extension)
for guildid in self.testing_guilds:
guild = discord.Object(guildid)
self.tree.copy_global_to(guild=guild)
await self.tree.sync(guild=guild)

View File

@@ -1,9 +1,5 @@
from .logger import log, logger from .LionBot import LionBot
from . import interactions
from . import patches
from .client import client
from .config import conf from .config import conf
from .args import args from .args import args
from . import sharding from . import sharding
from . import logger

View File

@@ -27,7 +27,7 @@ class configEmoji(PartialEmoji):
animated, name, id = emojistr.split(':') animated, name, id = emojistr.split(':')
return cls( return cls(
name=name, name=name,
fallback=PartialEmoji(name=fallback), fallback=PartialEmoji(name=fallback) if fallback is not None else None,
animated=bool(animated), animated=bool(animated),
id=int(id) id=int(id)
) )
@@ -60,11 +60,26 @@ class MapDotProxy:
return self._map.__getitem__(key) return self._map.__getitem__(key)
class ConfigParser(cfgp.ConfigParser):
"""
Extension of base ConfigParser allowing optional
section option retrieval without defaults.
"""
def options(self, section, no_defaults=False, **kwargs):
if no_defaults:
try:
return list(self._sections[section].keys())
except KeyError:
raise cfgp.NoSectionError(section)
else:
return super().options(section, **kwargs)
class Conf: class Conf:
def __init__(self, configfile, section_name="DEFAULT"): def __init__(self, configfile, section_name="DEFAULT"):
self.configfile = configfile self.configfile = configfile
self.config = cfgp.ConfigParser( self.config = ConfigParser(
converters={ converters={
"intlist": self._getintlist, "intlist": self._getintlist,
"list": self._getlist, "list": self._getlist,
@@ -102,7 +117,7 @@ class Conf:
return self.section[key].strip() return self.section[key].strip()
def __getattr__(self, section): def __getattr__(self, section):
return self.config[section] return self.config[section.upper()]
def get(self, name, fallback=None): def get(self, name, fallback=None):
result = self.section.get(name, fallback) result = self.section.get(name, fallback)
@@ -119,4 +134,4 @@ class Conf:
self.config.write(conffile) self.config.write(conffile)
conf = Conf(args.config) conf = Conf(args.config, 'STUDYLION')

146
bot/meta/logger.py Normal file
View File

@@ -0,0 +1,146 @@
import sys
import logging
import asyncio
from contextvars import ContextVar
from discord import AllowedMentions
from .config import conf
from . import sharding
log_context: ContextVar[str] = ContextVar('logging_context', default='CTX: ROOT CONTEXT')
log_action: ContextVar[str] = ContextVar('logging_action', default='UNKNOWN ACTION')
RESET_SEQ = "\033[0m"
COLOR_SEQ = "\033[3%dm"
BOLD_SEQ = "\033[1m"
"]]]"
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
def colour_escape(fmt: str) -> str:
cmap = {
'%(black)': COLOR_SEQ % BLACK,
'%(red)': COLOR_SEQ % RED,
'%(green)': COLOR_SEQ % GREEN,
'%(yellow)': COLOR_SEQ % YELLOW,
'%(blue)': COLOR_SEQ % BLUE,
'%(magenta)': COLOR_SEQ % MAGENTA,
'%(cyan)': COLOR_SEQ % CYAN,
'%(white)': COLOR_SEQ % WHITE,
'%(reset)': RESET_SEQ,
'%(bold)': BOLD_SEQ,
}
for key, value in cmap.items():
fmt = fmt.replace(key, value)
return fmt
log_format = ('[%(green)%(asctime)-19s%(reset)][%(red)%(levelname)-8s%(reset)]' +
'[%(cyan)SHARD {:02}%(reset)]'.format(sharding.shard_number) +
'[%(cyan)%(context)-22s%(reset)]' +
'[%(cyan)%(action)-22s%(reset)]' +
' %(bold)%(cyan)%(name)s:%(reset)' +
' %(white)%(message)s%(reset)')
log_format = colour_escape(log_format)
# Setup the logger
logger = logging.getLogger()
log_fmt = logging.Formatter(
fmt=log_format,
datefmt='%Y-%m-%d %H:%M:%S'
)
logger.setLevel(logging.NOTSET)
class LessThanFilter(logging.Filter):
def __init__(self, exclusive_maximum, name=""):
super(LessThanFilter, self).__init__(name)
self.max_level = exclusive_maximum
def filter(self, record):
# non-zero return means we log this message
return 1 if record.levelno < self.max_level else 0
class ContextInjection(logging.Filter):
def filter(self, record):
if not hasattr(record, 'context'):
record.context = log_context.get()
if not hasattr(record, 'action'):
record.action = log_action.get()
return True
logging_handler_out = logging.StreamHandler(sys.stdout)
logging_handler_out.setLevel(logging.DEBUG)
logging_handler_out.setFormatter(log_fmt)
logging_handler_out.addFilter(LessThanFilter(logging.WARNING))
logging_handler_out.addFilter(ContextInjection())
logger.addHandler(logging_handler_out)
logging_handler_err = logging.StreamHandler(sys.stderr)
logging_handler_err.setLevel(logging.WARNING)
logging_handler_err.setFormatter(log_fmt)
logging_handler_err.addFilter(ContextInjection())
logger.addHandler(logging_handler_err)
# TODO: Add an async handler for posting
# Subclass this, create a DiscordChannelHandler, taking a Client and a channel as an argument
# Then we can handle error channels etc differently
# The formatting can be handled with a custom handler as well
# Define the context log format and attach it to the command logger as well
def log(message, context="GLOBAL", level=logging.INFO, post=True):
# Add prefixes to lines for better parsing capability
lines = message.splitlines()
if len(lines) > 1:
lines = [
'' * (i == 0) + '' * (0 < i < len(lines) - 1) + '' * (i == len(lines) - 1) + line
for i, line in enumerate(lines)
]
else:
lines = ['' + message]
for line in lines:
logger.log(level, '\b[{}] {}'.format(
str(context).center(22, '='),
line
))
# Fire and forget to the channel logger, if it is set up
if post and client.is_ready():
asyncio.ensure_future(live_log(message, context, level))
# Live logger that posts to the logging channels
async def live_log(message, context, level):
if level >= logging.INFO:
if level >= logging.WARNING:
log_chid = conf.bot.getint('error_channel') or conf.bot.getint('log_channel')
else:
log_chid = conf.bot.getint('log_channel')
# Generate the log messages
if sharding.sharded:
header = f"[{logging.getLevelName(level)}][SHARD {sharding.shard_number}][{context}]"
else:
header = f"[{logging.getLevelName(level)}][{context}]"
if len(message) > 1900:
blocks = split_text(message, blocksize=1900, code=False)
else:
blocks = [message]
if len(blocks) > 1:
blocks = [
"```md\n{}[{}/{}]\n{}\n```".format(header, i+1, len(blocks), block) for i, block in enumerate(blocks)
]
else:
blocks = ["```md\n{}\n{}\n```".format(header, blocks[0])]
# Post the log messages
if log_chid:
[await mail(client, log_chid, content=block, allowed_mentions=AllowedMentions.none()) for block in blocks]

View File

@@ -1,9 +1,35 @@
from .args import args from .args import args
from .config import conf from .config import conf
from psycopg import sql
from data.conditions import Condition, Joiner
shard_number = args.shard or 0 shard_number = args.shard or 0
shard_count = conf.bot.getint('shard_count', 1) shard_count = conf.bot.getint('shard_count', 1)
sharded = (shard_count > 0) sharded = (shard_count > 0)
def SHARDID(shard_id: int, guild_column: str = 'guildid', shard_count: int = shard_count) -> Condition:
"""
Condition constructor for filtering by shard id.
Example Usage
-------------
Query.where(_shard_condition('guildid', 10, 1))
"""
return Condition(
sql.SQL("({guildid} >> 22) %% {shard_count}").format(
guildid=sql.Identifier(guild_column),
shard_count=sql.Literal(shard_count)
),
Joiner.EQUALS,
sql.Placeholder(),
(shard_id,)
)
# Pre-built Condition for filtering by current shard.
THIS_SHARD = SHARDID(shard_number)

View File

@@ -1,28 +0,0 @@
from meta import client, conf, log, sharding
from data import tables
import core # noqa
# Note: This MUST be imported after core, due to table definition orders
from settings import AppSettings
import modules # noqa
# Load and attach app specific data
if sharding.sharded:
appname = f"{conf.bot['data_appid']}_{sharding.shard_count}_{sharding.shard_number}"
else:
appname = conf.bot['data_appid']
client.appdata = core.data.meta.fetch_or_create(appname)
client.data = tables
client.settings = AppSettings(conf.bot['data_appid'])
# Initialise all modules
client.initialise_modules()
# Log readyness and execute
log("Initial setup complete, logging in", context='SETUP')
client.run(conf.bot['TOKEN'])

View File

@@ -1,553 +0,0 @@
import datetime
import iso8601
import re
from enum import Enum
import discord
from psycopg2.extensions import QuotedString
from cmdClient.lib import SafeCancellation
multiselect_regex = re.compile(
r"^([0-9, -]+)$",
re.DOTALL | re.IGNORECASE | re.VERBOSE
)
tick = ''
cross = ''
def prop_tabulate(prop_list, value_list, indent=True, colon=True):
"""
Turns a list of properties and corresponding list of values into
a pretty string with one `prop: value` pair each line,
padded so that the colons in each line are lined up.
Handles empty props by using an extra couple of spaces instead of a `:`.
Parameters
----------
prop_list: List[str]
List of short names to put on the right side of the list.
Empty props are considered to be "newlines" for the corresponding value.
value_list: List[str]
List of values corresponding to the properties above.
indent: bool
Whether to add padding so the properties are right-adjusted.
Returns: str
"""
max_len = max(len(prop) for prop in prop_list)
return "".join(["`{}{}{}`\t{}{}".format(" " * (max_len - len(prop)) if indent else "",
prop,
(":" if len(prop) else " " * 2) if colon else '',
value_list[i],
'' if str(value_list[i]).endswith("```") else '\n')
for i, prop in enumerate(prop_list)])
def paginate_list(item_list, block_length=20, style="markdown", title=None):
"""
Create pretty codeblock pages from a list of strings.
Parameters
----------
item_list: List[str]
List of strings to paginate.
block_length: int
Maximum number of strings per page.
style: str
Codeblock style to use.
Title formatting assumes the `markdown` style, and numbered lists work well with this.
However, `markdown` sometimes messes up formatting in the list.
title: str
Optional title to add to the top of each page.
Returns: List[str]
List of pages, each formatted into a codeblock,
and containing at most `block_length` of the provided strings.
"""
lines = ["{0:<5}{1:<5}".format("{}.".format(i + 1), str(line)) for i, line in enumerate(item_list)]
page_blocks = [lines[i:i + block_length] for i in range(0, len(lines), block_length)]
pages = []
for i, block in enumerate(page_blocks):
pagenum = "Page {}/{}".format(i + 1, len(page_blocks))
if title:
header = "{} ({})".format(title, pagenum) if len(page_blocks) > 1 else title
else:
header = pagenum
header_line = "=" * len(header)
full_header = "{}\n{}\n".format(header, header_line) if len(page_blocks) > 1 or title else ""
pages.append("```{}\n{}{}```".format(style, full_header, "\n".join(block)))
return pages
def timestamp_utcnow():
"""
Return the current integer UTC timestamp.
"""
return int(datetime.datetime.timestamp(datetime.datetime.utcnow()))
def split_text(text, blocksize=2000, code=True, syntax="", maxheight=50):
"""
Break the text into blocks of maximum length blocksize
If possible, break across nearby newlines. Otherwise just break at blocksize chars
Parameters
----------
text: str
Text to break into blocks.
blocksize: int
Maximum character length for each block.
code: bool
Whether to wrap each block in codeblocks (these are counted in the blocksize).
syntax: str
The markdown formatting language to use for the codeblocks, if applicable.
maxheight: int
The maximum number of lines in each block
Returns: List[str]
List of blocks,
each containing at most `block_size` characters,
of height at most `maxheight`.
"""
# Adjust blocksize to account for the codeblocks if required
blocksize = blocksize - 8 - len(syntax) if code else blocksize
# Build the blocks
blocks = []
while True:
# If the remaining text is already small enough, append it
if len(text) <= blocksize:
blocks.append(text)
break
text = text.strip('\n')
# Find the last newline in the prototype block
split_on = text[0:blocksize].rfind('\n')
split_on = blocksize if split_on < blocksize // 5 else split_on
# Add the block and truncate the text
blocks.append(text[0:split_on])
text = text[split_on:]
# Add the codeblock ticks and the code syntax header, if required
if code:
blocks = ["```{}\n{}\n```".format(syntax, block) for block in blocks]
return blocks
def strfdelta(delta, sec=False, minutes=True, short=False):
"""
Convert a datetime.timedelta object into an easily readable duration string.
Parameters
----------
delta: datetime.timedelta
The timedelta object to convert into a readable string.
sec: bool
Whether to include the seconds from the timedelta object in the string.
minutes: bool
Whether to include the minutes from the timedelta object in the string.
short: bool
Whether to abbreviate the units of time ("hour" to "h", "minute" to "m", "second" to "s").
Returns: str
A string containing a time from the datetime.timedelta object, in a readable format.
Time units will be abbreviated if short was set to True.
"""
output = [[delta.days, 'd' if short else ' day'],
[delta.seconds // 3600, 'h' if short else ' hour']]
if minutes:
output.append([delta.seconds // 60 % 60, 'm' if short else ' minute'])
if sec:
output.append([delta.seconds % 60, 's' if short else ' second'])
for i in range(len(output)):
if output[i][0] != 1 and not short:
output[i][1] += 's'
reply_msg = []
if output[0][0] != 0:
reply_msg.append("{}{} ".format(output[0][0], output[0][1]))
if output[0][0] != 0 or output[1][0] != 0 or len(output) == 2:
reply_msg.append("{}{} ".format(output[1][0], output[1][1]))
for i in range(2, len(output) - 1):
reply_msg.append("{}{} ".format(output[i][0], output[i][1]))
if not short and reply_msg:
reply_msg.append("and ")
reply_msg.append("{}{}".format(output[-1][0], output[-1][1]))
return "".join(reply_msg)
def parse_dur(time_str):
"""
Parses a user provided time duration string into a timedelta object.
Parameters
----------
time_str: str
The time string to parse. String can include days, hours, minutes, and seconds.
Returns: int
The number of seconds the duration represents.
"""
funcs = {'d': lambda x: x * 24 * 60 * 60,
'h': lambda x: x * 60 * 60,
'm': lambda x: x * 60,
's': lambda x: x}
time_str = time_str.strip(" ,")
found = re.findall(r'(\d+)\s?(\w+?)', time_str)
seconds = 0
for bit in found:
if bit[1] in funcs:
seconds += funcs[bit[1]](int(bit[0]))
return seconds
def strfdur(duration, short=True, show_days=False):
"""
Convert a duration given in seconds to a number of hours, minutes, and seconds.
"""
days = duration // (3600 * 24) if show_days else 0
hours = duration // 3600
if days:
hours %= 24
minutes = duration // 60 % 60
seconds = duration % 60
parts = []
if days:
unit = 'd' if short else (' days' if days != 1 else ' day')
parts.append('{}{}'.format(days, unit))
if hours:
unit = 'h' if short else (' hours' if hours != 1 else ' hour')
parts.append('{}{}'.format(hours, unit))
if minutes:
unit = 'm' if short else (' minutes' if minutes != 1 else ' minute')
parts.append('{}{}'.format(minutes, unit))
if seconds or duration == 0:
unit = 's' if short else (' seconds' if seconds != 1 else ' second')
parts.append('{}{}'.format(seconds, unit))
if short:
return ' '.join(parts)
else:
return ', '.join(parts)
def substitute_ranges(ranges_str, max_match=20, max_range=1000, separator=','):
"""
Substitutes a user provided list of numbers and ranges,
and replaces the ranges by the corresponding list of numbers.
Parameters
----------
ranges_str: str
The string to ranges in.
max_match: int
The maximum number of ranges to replace.
Any ranges exceeding this will be ignored.
max_range: int
The maximum length of range to replace.
Attempting to replace a range longer than this will raise a `ValueError`.
"""
def _repl(match):
n1 = int(match.group(1))
n2 = int(match.group(2))
if n2 - n1 > max_range:
raise SafeCancellation("Provided range is too large!")
return separator.join(str(i) for i in range(n1, n2 + 1))
return re.sub(r'(\d+)\s*-\s*(\d+)', _repl, ranges_str, max_match)
def parse_ranges(ranges_str, ignore_errors=False, separator=',', **kwargs):
"""
Parses a user provided range string into a list of numbers.
Extra keyword arguments are transparently passed to the underlying parser `substitute_ranges`.
"""
substituted = substitute_ranges(ranges_str, separator=separator, **kwargs)
numbers = (item.strip() for item in substituted.split(','))
numbers = [item for item in numbers if item]
integers = [int(item) for item in numbers if item.isdigit()]
if not ignore_errors and len(integers) != len(numbers):
raise SafeCancellation(
"Couldn't parse the provided selection!\n"
"Please provide comma separated numbers and ranges, e.g. `1, 5, 6-9`."
)
return integers
def msg_string(msg, mask_link=False, line_break=False, tz=None, clean=True):
"""
Format a message into a string with various information, such as:
the timestamp of the message, author, message content, and attachments.
Parameters
----------
msg: Message
The message to format.
mask_link: bool
Whether to mask the URLs of any attachments.
line_break: bool
Whether a line break should be used in the string.
tz: Timezone
The timezone to use in the formatted message.
clean: bool
Whether to use the clean content of the original message.
Returns: str
A formatted string containing various information:
User timezone, message author, message content, attachments
"""
timestr = "%I:%M %p, %d/%m/%Y"
if tz:
time = iso8601.parse_date(msg.timestamp.isoformat()).astimezone(tz).strftime(timestr)
else:
time = msg.timestamp.strftime(timestr)
user = str(msg.author)
attach_list = [attach["url"] for attach in msg.attachments if "url" in attach]
if mask_link:
attach_list = ["[Link]({})".format(url) for url in attach_list]
attachments = "\nAttachments: {}".format(", ".join(attach_list)) if attach_list else ""
return "`[{time}]` **{user}:** {line_break}{message} {attachments}".format(
time=time,
user=user,
line_break="\n" if line_break else "",
message=msg.clean_content if clean else msg.content,
attachments=attachments
)
def convdatestring(datestring):
"""
Convert a date string into a datetime.timedelta object.
Parameters
----------
datestring: str
The string to convert to a datetime.timedelta object.
Returns: datetime.timedelta
A datetime.timedelta object formed from the string provided.
"""
datestring = datestring.strip(' ,')
datearray = []
funcs = {'d': lambda x: x * 24 * 60 * 60,
'h': lambda x: x * 60 * 60,
'm': lambda x: x * 60,
's': lambda x: x}
currentnumber = ''
for char in datestring:
if char.isdigit():
currentnumber += char
else:
if currentnumber == '':
continue
datearray.append((int(currentnumber), char))
currentnumber = ''
seconds = 0
if currentnumber:
seconds += int(currentnumber)
for i in datearray:
if i[1] in funcs:
seconds += funcs[i[1]](i[0])
return datetime.timedelta(seconds=seconds)
class _rawChannel(discord.abc.Messageable):
"""
Raw messageable class representing an arbitrary channel,
not necessarially seen by the gateway.
"""
def __init__(self, state, id):
self._state = state
self.id = id
async def _get_channel(self):
return discord.Object(self.id)
async def mail(client: discord.Client, channelid: int, **msg_args):
"""
Mails a message to a channelid which may be invisible to the gateway.
Parameters:
client: discord.Client
The client to use for mailing.
Must at least have static authentication and have a valid `_connection`.
channelid: int
The channel id to mail to.
msg_args: Any
Message keyword arguments which are passed transparently to `_rawChannel.send(...)`.
"""
# Create the raw channel
channel = _rawChannel(client._connection, channelid)
return await channel.send(**msg_args)
def emb_add_fields(embed, emb_fields):
"""
Append embed fields to an embed.
Parameters
----------
embed: discord.Embed
The embed to add the field to.
emb_fields: tuple
The values to add to a field.
name: str
The name of the field.
value: str
The value of the field.
inline: bool
Whether the embed field should be inline or not.
"""
for field in emb_fields:
embed.add_field(name=str(field[0]), value=str(field[1]), inline=bool(field[2]))
def join_list(string, nfs=False):
"""
Join a list together, separated with commas, plus add "and" to the beginning of the last value.
Parameters
----------
string: list
The list to join together.
nfs: bool
(no fullstops)
Whether to exclude fullstops/periods from the output messages.
If not provided, fullstops will be appended to the output.
"""
if len(string) > 1:
return "{}{} and {}{}".format((", ").join(string[:-1]),
"," if len(string) > 2 else "", string[-1], "" if nfs else ".")
else:
return "{}{}".format("".join(string), "" if nfs else ".")
def format_activity(user):
"""
Format a user's activity string, depending on the type of activity.
Currently supported types are:
- Nothing
- Custom status
- Playing (with rich presence support)
- Streaming
- Listening (with rich presence support)
- Watching
- Unknown
Parameters
----------
user: discord.Member
The user to format the status of.
If the user has no activity, "Nothing" will be returned.
Returns: str
A formatted string with various information about the user's current activity like the name,
and any extra information about the activity (such as current song artists for Spotify)
"""
if not user.activity:
return "Nothing"
AT = user.activity.type
a = user.activity
if str(AT) == "ActivityType.custom":
return "Status: {}".format(a)
if str(AT) == "ActivityType.playing":
string = "Playing {}".format(a.name)
try:
string += " ({})".format(a.details)
except Exception:
pass
return string
if str(AT) == "ActivityType.streaming":
return "Streaming {}".format(a.name)
if str(AT) == "ActivityType.listening":
try:
string = "Listening to `{}`".format(a.title)
if len(a.artists) > 1:
string += " by {}".format(join_list(string=a.artists))
else:
string += " by **{}**".format(a.artist)
except Exception:
string = "Listening to `{}`".format(a.name)
return string
if str(AT) == "ActivityType.watching":
return "Watching `{}`".format(a.name)
if str(AT) == "ActivityType.unknown":
return "Unknown"
def shard_of(shard_count: int, guildid: int):
"""
Calculate the shard number of a given guild.
"""
return (guildid >> 22) % shard_count if shard_count and shard_count > 0 else 0
def jumpto(guildid: int, channeldid: int, messageid: int):
"""
Build a jump link for a message given its location.
"""
return 'https://discord.com/channels/{}/{}/{}'.format(
guildid,
channeldid,
messageid
)
class DotDict(dict):
"""
Dict-type allowing dot access to keys.
"""
__getattr__ = dict.get
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__
class FieldEnum(str, Enum):
"""
String enum with description conforming to the ISQLQuote protocol.
Allows processing by psycog
"""
def __new__(cls, value, desc):
obj = str.__new__(cls, value)
obj._value_ = value
obj.desc = desc
return obj
def __repr__(self):
return '<%s.%s>' % (self.__class__.__name__, self.name)
def __bool__(self):
return True
def __conform__(self, proto):
return QuotedString(self.value)
def utc_now():
"""
Return the current timezone-aware utc timestamp.
"""
return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)
def multiple_replace(string, rep_dict):
if rep_dict:
pattern = re.compile(
"|".join([re.escape(k) for k in sorted(rep_dict, key=len, reverse=True)]),
flags=re.DOTALL
)
return pattern.sub(lambda x: str(rep_dict[x.group(0)]), string)
else:
return string