rewrite: Initial rewrite skeleton.

Remove modules that will no longer be required.
Move pending modules to pending-rewrite folders.
This commit is contained in:
2022-09-17 17:06:13 +10:00
parent a7f7dd6e7b
commit a5147323b5
162 changed files with 1 additions and 866 deletions

View File

@@ -0,0 +1,88 @@
import types
from cmdClient import Context
from cmdClient.logger import log
class LionContext(Context):
"""
Subclass to allow easy attachment of custom hooks and structure to contexts.
"""
__slots__ = ()
@classmethod
def util(cls, util_func):
"""
Decorator to make a utility function available as a Context instance method.
Extends the default Context method to add logging and to return the utility function.
"""
super().util(util_func)
log(f"Attached context utility function: {util_func.__name__}")
return util_func
@classmethod
def wrappable_util(cls, util_func):
"""
Decorator to add a Wrappable utility function as a Context instance method.
"""
wrappable = Wrappable(util_func)
super().util(wrappable)
log(f"Attached wrappable context utility function: {util_func.__name__}")
return wrappable
class Wrappable:
__slots__ = ('_func', 'wrappers')
def __init__(self, func):
self._func = func
self.wrappers = None
@property
def __name__(self):
return self._func.__name__
def add_wrapper(self, func, name=None):
self.wrappers = self.wrappers or {}
name = name or func.__name__
self.wrappers[name] = func
log(
f"Added wrapper '{name}' to Wrappable '{self._func.__name__}'.",
context="Wrapping"
)
def remove_wrapper(self, name):
if not self.wrappers or name not in self.wrappers:
raise ValueError(
f"Cannot remove non-existent wrapper '{name}' from Wrappable '{self._func.__name__}'"
)
self.wrappers.pop(name)
log(
f"Removed wrapper '{name}' from Wrappable '{self._func.__name__}'.",
context="Wrapping"
)
def __call__(self, *args, **kwargs):
if self.wrappers:
return self._wrapped(iter(self.wrappers.values()))(*args, **kwargs)
else:
return self._func(*args, **kwargs)
def _wrapped(self, iter_wraps):
next_wrap = next(iter_wraps, None)
if next_wrap:
def _func(*args, **kwargs):
return next_wrap(self._wrapped(iter_wraps), *args, **kwargs)
else:
_func = self._func
return _func
def __get__(self, instance, cls=None):
if instance is None:
return self
else:
return types.MethodType(self, instance)
# Override the original Context.reply with a wrappable utility
reply = LionContext.wrappable_util(Context.reply)

View File

@@ -0,0 +1,186 @@
import asyncio
import traceback
import logging
import discord
from cmdClient import Command, Module, FailedCheck
from cmdClient.lib import SafeCancellation
from meta import log
class LionCommand(Command):
"""
Subclass to allow easy attachment of custom hooks and structure to commands.
"""
allow_before_ready = False
class LionModule(Module):
"""
Custom module for Lion systems.
Adds command wrappers and various event handlers.
"""
name = "Base Lion Module"
def __init__(self, name, baseCommand=LionCommand):
super().__init__(name, baseCommand)
self.unload_tasks = []
def unload_task(self, func):
"""
Decorator adding an unload task for deactivating the module.
Should sync unsaved transactions and finalise user interaction.
If possible, should also remove attached data and handlers.
"""
self.unload_tasks.append(func)
log("Adding unload task '{}'.".format(func.__name__), context=self.name)
return func
async def unload(self, client):
"""
Run the unloading tasks.
"""
log("Unloading module.", context=self.name, post=False)
for task in self.unload_tasks:
log("Running unload task '{}'".format(task.__name__),
context=self.name, post=False)
await task(client)
async def launch(self, client):
"""
Launch hook.
Executed in `client.on_ready`.
Must set `ready` to `True`, otherwise all commands will hang.
Overrides the parent launcher to not post the log as a discord message.
"""
if not self.ready:
log("Running launch tasks.", context=self.name, post=False)
for task in self.launch_tasks:
log("Running launch task '{}'.".format(task.__name__),
context=self.name, post=False)
await task(client)
self.ready = True
else:
log("Already launched, skipping launch.", context=self.name, post=False)
async def pre_command(self, ctx):
"""
Lion pre-command hook.
"""
if not self.ready and not ctx.cmd.allow_before_ready:
try:
await ctx.embed_reply(
"I am currently restarting! Please try again in a couple of minutes."
)
except discord.HTTPException:
pass
raise SafeCancellation(details="Module '{}' is not ready.".format(self.name))
# Check global user blacklist
if ctx.author.id in ctx.client.user_blacklist():
raise SafeCancellation(details='User is blacklisted.')
if ctx.guild:
# Check that the channel and guild still exists
if not ctx.client.get_guild(ctx.guild.id) or not ctx.guild.get_channel(ctx.ch.id):
raise SafeCancellation(details='Command channel is no longer reachable.')
# Check global guild blacklist
if ctx.guild.id in ctx.client.guild_blacklist():
raise SafeCancellation(details='Guild is blacklisted.')
# Check guild's own member blacklist
if ctx.author.id in ctx.client.objects['ignored_members'][ctx.guild.id]:
raise SafeCancellation(details='User is ignored in this guild.')
# Check channel permissions are sane
if not ctx.ch.permissions_for(ctx.guild.me).send_messages:
raise SafeCancellation(details='I cannot send messages in this channel.')
if not ctx.ch.permissions_for(ctx.guild.me).embed_links:
await ctx.reply("I need permission to send embeds in this channel before I can run any commands!")
raise SafeCancellation(details='I cannot send embeds in this channel.')
# Ensure Lion exists and cached data is up to date
ctx.alion.update_saved_data(ctx.author)
# Start typing
await ctx.ch.trigger_typing()
async def on_exception(self, ctx, exception):
try:
raise exception
except (FailedCheck, SafeCancellation):
# cmdClient generated and handled exceptions
raise exception
except (asyncio.CancelledError, asyncio.TimeoutError):
# Standard command and task exceptions, cmdClient will also handle these
raise exception
except discord.Forbidden:
# Unknown uncaught Forbidden
try:
# Attempt a general error reply
await ctx.reply("I don't have enough channel or server permissions to complete that command here!")
except discord.Forbidden:
# We can't send anything at all. Exit quietly, but log.
full_traceback = traceback.format_exc()
log(("Caught an unhandled 'Forbidden' while "
"executing command '{cmdname}' from module '{module}' "
"from user '{message.author}' (uid:{message.author.id}) "
"in guild '{message.guild}' (gid:{guildid}) "
"in channel '{message.channel}' (cid:{message.channel.id}).\n"
"Message Content:\n"
"{content}\n"
"{traceback}\n\n"
"{flat_ctx}").format(
cmdname=ctx.cmd.name,
module=ctx.cmd.module.name,
message=ctx.msg,
guildid=ctx.guild.id if ctx.guild else None,
content='\n'.join('\t' + line for line in ctx.msg.content.splitlines()),
traceback=full_traceback,
flat_ctx=ctx.flatten()
),
context="mid:{}".format(ctx.msg.id),
level=logging.WARNING)
except Exception as e:
# Unknown exception!
full_traceback = traceback.format_exc()
only_error = "".join(traceback.TracebackException.from_exception(e).format_exception_only())
log(("Caught an unhandled exception while "
"executing command '{cmdname}' from module '{module}' "
"from user '{message.author}' (uid:{message.author.id}) "
"in guild '{message.guild}' (gid:{guildid}) "
"in channel '{message.channel}' (cid:{message.channel.id}).\n"
"Message Content:\n"
"{content}\n"
"{traceback}\n\n"
"{flat_ctx}").format(
cmdname=ctx.cmd.name,
module=ctx.cmd.module.name,
message=ctx.msg,
guildid=ctx.guild.id if ctx.guild else None,
content='\n'.join('\t' + line for line in ctx.msg.content.splitlines()),
traceback=full_traceback,
flat_ctx=ctx.flatten()
),
context="mid:{}".format(ctx.msg.id),
level=logging.ERROR)
error_embed = discord.Embed(title="Something went wrong!")
error_embed.description = (
"An unexpected error occurred while processing your command!\n"
"Our development team has been notified, and the issue should be fixed soon.\n"
)
if logging.getLogger().getEffectiveLevel() < logging.INFO:
error_embed.add_field(
name="Exception",
value="`{}`".format(only_error)
)
await ctx.reply(embed=error_embed)

View File

@@ -0,0 +1,5 @@
from . import data # noqa
from .module import module
from .lion import Lion
from . import blacklists

View File

@@ -0,0 +1,92 @@
"""
Guild, user, and member blacklists.
"""
from collections import defaultdict
import cachetools.func
from data import tables
from meta import client
from .module import module
@cachetools.func.ttl_cache(ttl=300)
def guild_blacklist():
"""
Get the guild blacklist
"""
rows = tables.global_guild_blacklist.select_where()
return set(row['guildid'] for row in rows)
@cachetools.func.ttl_cache(ttl=300)
def user_blacklist():
"""
Get the global user blacklist.
"""
rows = tables.global_user_blacklist.select_where()
return set(row['userid'] for row in rows)
@module.init_task
def load_ignored_members(client):
"""
Load the ignored members.
"""
ignored = defaultdict(set)
rows = tables.ignored_members.select_where()
for row in rows:
ignored[row['guildid']].add(row['userid'])
client.objects['ignored_members'] = ignored
if rows:
client.log(
"Loaded {} ignored members across {} guilds.".format(
len(rows),
len(ignored)
),
context="MEMBER_BLACKLIST"
)
@module.init_task
def attach_client_blacklists(client):
client.guild_blacklist = guild_blacklist
client.user_blacklist = user_blacklist
@module.launch_task
async def leave_blacklisted_guilds(client):
"""
Launch task to leave any blacklisted guilds we are in.
"""
to_leave = [
guild for guild in client.guilds
if guild.id in guild_blacklist()
]
for guild in to_leave:
await guild.leave()
if to_leave:
client.log(
"Left {} blacklisted guilds!".format(len(to_leave)),
context="GUILD_BLACKLIST"
)
@client.add_after_event('guild_join')
async def check_guild_blacklist(client, guild):
"""
Guild join event handler to check whether the guild is blacklisted.
If so, leaves the guild.
"""
# First refresh the blacklist cache
if guild.id in guild_blacklist():
await guild.leave()
client.log(
"Automatically left blacklisted guild '{}' (gid:{}) upon join.".format(guild.name, guild.id),
context="GUILD_BLACKLIST"
)

View File

@@ -0,0 +1,128 @@
from psycopg2.extras import execute_values
from cachetools import TTLCache
from data import RowTable, Table
meta = RowTable(
'AppData',
('appid', 'last_study_badge_scan'),
'appid',
attach_as='meta',
)
# TODO: Consider converting to RowTable for per-shard config caching
app_config = Table('AppConfig')
user_config = RowTable(
'user_config',
('userid', 'timezone', 'topgg_vote_reminder', 'avatar_hash', 'gems'),
'userid',
cache=TTLCache(5000, ttl=60*5)
)
guild_config = RowTable(
'guild_config',
('guildid', 'admin_role', 'mod_role', 'event_log_channel', 'mod_log_channel', 'alert_channel',
'studyban_role', 'max_study_bans',
'min_workout_length', 'workout_reward',
'max_tasks', 'task_reward', 'task_reward_limit',
'study_hourly_reward', 'study_hourly_live_bonus', 'daily_study_cap',
'renting_price', 'renting_category', 'renting_cap', 'renting_role', 'renting_sync_perms',
'accountability_category', 'accountability_lobby', 'accountability_bonus',
'accountability_reward', 'accountability_price',
'video_studyban', 'video_grace_period',
'greeting_channel', 'greeting_message', 'returning_message',
'starting_funds', 'persist_roles',
'pomodoro_channel',
'name'),
'guildid',
cache=TTLCache(2500, ttl=60*5)
)
unranked_roles = Table('unranked_roles')
donator_roles = Table('donator_roles')
lions = RowTable(
'members',
('guildid', 'userid',
'tracked_time', 'coins',
'workout_count', 'last_workout_start',
'revision_mute_count',
'last_study_badgeid',
'video_warned',
'display_name',
'_timestamp'
),
('guildid', 'userid'),
cache=TTLCache(5000, ttl=60*5),
attach_as='lions'
)
@lions.save_query
def add_pending(pending):
"""
pending:
List of tuples of the form `(guildid, userid, pending_coins)`.
"""
with lions.conn:
cursor = lions.conn.cursor()
data = execute_values(
cursor,
"""
UPDATE members
SET
coins = LEAST(coins + t.coin_diff, 2147483647)
FROM
(VALUES %s)
AS
t (guildid, userid, coin_diff)
WHERE
members.guildid = t.guildid
AND
members.userid = t.userid
RETURNING *
""",
pending,
fetch=True
)
return lions._make_rows(*data)
lion_ranks = Table('member_ranks', attach_as='lion_ranks')
@lions.save_query
def get_member_rank(guildid, userid, untracked):
"""
Get the time and coin ranking for the given member, ignoring the provided untracked members.
"""
with lions.conn as conn:
with conn.cursor() as curs:
curs.execute(
"""
SELECT
time_rank, coin_rank
FROM (
SELECT
userid,
row_number() OVER (ORDER BY total_tracked_time DESC, userid ASC) AS time_rank,
row_number() OVER (ORDER BY total_coins DESC, userid ASC) AS coin_rank
FROM members_totals
WHERE
guildid=%s AND userid NOT IN %s
) AS guild_ranks WHERE userid=%s
""",
(guildid, tuple(untracked), userid)
)
return curs.fetchone() or (None, None)
global_guild_blacklist = Table('global_guild_blacklist')
global_user_blacklist = Table('global_user_blacklist')
ignored_members = Table('ignored_members')

View File

@@ -0,0 +1,347 @@
import pytz
import discord
from functools import reduce
from datetime import datetime, timedelta
from meta import client
from data import tables as tb
from settings import UserSettings, GuildSettings
from LionContext import LionContext
class Lion:
"""
Class representing a guild Member.
Mostly acts as a transparent interface to the corresponding Row,
but also adds some transaction caching logic to `coins` and `tracked_time`.
"""
__slots__ = ('guildid', 'userid', '_pending_coins', '_member')
# Members with pending transactions
_pending = {} # userid -> User
# Lion cache. Currently lions don't expire
_lions = {} # (guildid, userid) -> Lion
# Extra methods supplying an economy bonus
_economy_bonuses = []
def __init__(self, guildid, userid):
self.guildid = guildid
self.userid = userid
self._pending_coins = 0
self._member = None
self._lions[self.key] = self
@classmethod
def fetch(cls, guildid, userid):
"""
Fetch a Lion with the given member.
If they don't exist, creates them.
If possible, retrieves the user from the user cache.
"""
key = (guildid, userid)
if key in cls._lions:
return cls._lions[key]
else:
# TODO: Debug log
lion = tb.lions.fetch(key)
if not lion:
tb.user_config.fetch_or_create(userid)
tb.guild_config.fetch_or_create(guildid)
tb.lions.create_row(
guildid=guildid,
userid=userid,
coins=GuildSettings(guildid).starting_funds.value
)
return cls(guildid, userid)
@property
def key(self):
return (self.guildid, self.userid)
@property
def guild(self) -> discord.Guild:
return client.get_guild(self.guildid)
@property
def member(self) -> discord.Member:
"""
The discord `Member` corresponding to this user.
May be `None` if the member is no longer in the guild or the caches aren't populated.
Not guaranteed to be `None` if the member is not in the guild.
"""
if self._member is None:
guild = client.get_guild(self.guildid)
if guild:
self._member = guild.get_member(self.userid)
return self._member
@property
def data(self):
"""
The Row corresponding to this member.
"""
return tb.lions.fetch(self.key)
@property
def user_data(self):
"""
The Row corresponding to this user.
"""
return tb.user_config.fetch_or_create(self.userid)
@property
def guild_data(self):
"""
The Row corresponding to this guild.
"""
return tb.guild_config.fetch_or_create(self.guildid)
@property
def settings(self):
"""
The UserSettings interface for this member.
"""
return UserSettings(self.userid)
@property
def guild_settings(self):
"""
The GuildSettings interface for this member.
"""
return GuildSettings(self.guildid)
@property
def ctx(self) -> LionContext:
"""
Manufacture a `LionContext` with the lion member as an author.
Useful for accessing member context utilities.
Be aware that `author` may be `None` if the member was not cached.
"""
return LionContext(client, guild=self.guild, author=self.member)
@property
def time(self):
"""
Amount of time the user has spent studying, accounting for a current session.
"""
# Base time from cached member data
time = self.data.tracked_time
# Add current session time if it exists
if session := self.session:
time += session.duration
return int(time)
@property
def coins(self):
"""
Number of coins the user has, accounting for the pending value and current session.
"""
# Base coin amount from cached member data
coins = self.data.coins
# Add pending coin amount
coins += self._pending_coins
# Add current session coins if applicable
if session := self.session:
coins += session.coins_earned
return int(coins)
@property
def economy_bonus(self):
"""
Economy multiplier
"""
return reduce(
lambda x, y: x * y,
[func(self) for func in self._economy_bonuses]
)
@classmethod
def register_economy_bonus(cls, func):
cls._economy_bonuses.append(func)
@classmethod
def unregister_economy_bonus(cls, func):
cls._economy_bonuses.remove(func)
@property
def session(self):
"""
The current study session the user is in, if any.
"""
if 'sessions' not in client.objects:
raise ValueError("Cannot retrieve session before Study module is initialised!")
return client.objects['sessions'][self.guildid].get(self.userid, None)
@property
def timezone(self):
"""
The user's configured timezone.
Shortcut to `Lion.settings.timezone.value`.
"""
return self.settings.timezone.value
@property
def day_start(self):
"""
A timezone aware datetime representing the start of the user's day (in their configured timezone).
NOTE: This might not be accurate over DST boundaries.
"""
now = datetime.now(tz=self.timezone)
return now.replace(hour=0, minute=0, second=0, microsecond=0)
@property
def day_timestamp(self):
"""
EPOCH timestamp representing the current day for the user.
NOTE: This is the timestamp of the start of the current UTC day with the same date as the user's day.
This is *not* the start of the current user's day, either in UTC or their own timezone.
This may also not be the start of the current day in UTC (consider 23:00 for a user in UTC-2).
"""
now = datetime.now(tz=self.timezone)
day_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
return int(day_start.replace(tzinfo=pytz.utc).timestamp())
@property
def week_timestamp(self):
"""
EPOCH timestamp representing the current week for the user.
"""
now = datetime.now(tz=self.timezone)
day_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
week_start = day_start - timedelta(days=day_start.weekday())
return int(week_start.replace(tzinfo=pytz.utc).timestamp())
@property
def month_timestamp(self):
"""
EPOCH timestamp representing the current month for the user.
"""
now = datetime.now(tz=self.timezone)
month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
return int(month_start.replace(tzinfo=pytz.utc).timestamp())
@property
def remaining_in_day(self):
return ((self.day_start + timedelta(days=1)) - datetime.now(self.timezone)).total_seconds()
@property
def studied_today(self):
"""
The amount of time, in seconds, that the member has studied today.
Extracted from the session history.
"""
return tb.session_history.queries.study_time_since(self.guildid, self.userid, self.day_start)
@property
def remaining_study_today(self):
"""
Maximum remaining time (in seconds) this member can study today.
May not account for DST boundaries and leap seconds.
"""
studied_today = self.studied_today
study_cap = self.guild_settings.daily_study_cap.value
remaining_in_day = self.remaining_in_day
if remaining_in_day >= (study_cap - studied_today):
remaining = study_cap - studied_today
else:
remaining = remaining_in_day + study_cap
return remaining
@property
def profile_tags(self):
"""
Returns a list of profile tags, or the default tags.
"""
tags = tb.profile_tags.queries.get_tags_for(self.guildid, self.userid)
prefix = self.ctx.best_prefix
return tags or [
f"Use {prefix}setprofile",
"and add your tags",
"to this section",
f"See {prefix}help setprofile for more"
]
@property
def name(self):
"""
Returns the best local name possible.
"""
if self.member:
name = self.member.display_name
elif self.data.display_name:
name = self.data.display_name
else:
name = str(self.userid)
return name
def update_saved_data(self, member: discord.Member):
"""
Update the stored discord data from the givem member.
Intended to be used when we get member data from events that may not be available in cache.
"""
if self.guild_data.name != member.guild.name:
self.guild_data.name = member.guild.name
if self.user_data.avatar_hash != member.avatar:
self.user_data.avatar_hash = member.avatar
if self.data.display_name != member.display_name:
self.data.display_name = member.display_name
def localize(self, naive_utc_dt):
"""
Localise the provided naive UTC datetime into the user's timezone.
"""
timezone = self.settings.timezone.value
return naive_utc_dt.replace(tzinfo=pytz.UTC).astimezone(timezone)
def addCoins(self, amount, flush=True, bonus=False):
"""
Add coins to the user, optionally store the transaction in pending.
"""
self._pending_coins += amount * (self.economy_bonus if bonus else 1)
self._pending[self.key] = self
if flush:
self.flush()
def flush(self):
"""
Flush any pending transactions to the database.
"""
self.sync(self)
@classmethod
def sync(cls, *lions):
"""
Flush pending transactions to the database.
Also refreshes the Row cache for updated lions.
"""
lions = lions or list(cls._pending.values())
if lions:
# Build userid to pending coin map
pending = [
(lion.guildid, lion.userid, int(lion._pending_coins))
for lion in lions
]
# Write to database
tb.lions.queries.add_pending(pending)
# Cleanup pending users
for lion in lions:
lion._pending_coins -= int(lion._pending_coins)
cls._pending.pop(lion.key, None)

View File

@@ -0,0 +1,80 @@
import logging
import asyncio
from meta import client, conf
from settings import GuildSettings, UserSettings
from LionModule import LionModule
from .lion import Lion
module = LionModule("Core")
async def _lion_sync_loop():
while True:
while not client.is_ready():
await asyncio.sleep(1)
client.log(
"Running lion data sync.",
context="CORE",
level=logging.DEBUG,
post=False
)
Lion.sync()
await asyncio.sleep(conf.bot.getint("lion_sync_period"))
@module.init_task
def setting_initialisation(client):
"""
Execute all Setting initialisation tasks from GuildSettings and UserSettings.
"""
for setting in GuildSettings.settings.values():
setting.init_task(client)
for setting in UserSettings.settings.values():
setting.init_task(client)
@module.launch_task
async def preload_guild_configuration(client):
"""
Loads the plain guild configuration for all guilds the client is part of into data.
"""
guildids = [guild.id for guild in client.guilds]
if guildids:
rows = client.data.guild_config.fetch_rows_where(guildid=guildids)
client.log(
"Preloaded guild configuration for {} guilds.".format(len(rows)),
context="CORE_LOADING"
)
@module.launch_task
async def preload_studying_members(client):
"""
Loads the member data for all members who are currently in voice channels.
"""
userids = list(set(member.id for guild in client.guilds for ch in guild.voice_channels for member in ch.members))
if userids:
users = client.data.user_config.fetch_rows_where(userid=userids)
members = client.data.lions.fetch_rows_where(userid=userids)
client.log(
"Preloaded data for {} user with {} members.".format(len(users), len(members)),
context="CORE_LOADING"
)
# Removing the sync loop in favour of the studybadge sync.
# @module.launch_task
# async def launch_lion_sync_loop(client):
# asyncio.create_task(_lion_sync_loop())
@module.unload_task
async def final_lion_sync(client):
Lion.sync()

View File

@@ -0,0 +1,9 @@
import logging
import meta
meta.logger.setLevel(logging.DEBUG)
logging.getLogger("discord").setLevel(logging.INFO)
from utils import interactive # noqa
import main # noqa

View File

@@ -0,0 +1,28 @@
from meta import client, conf, log, sharding
from data import tables
import core # noqa
# Note: This MUST be imported after core, due to table definition orders
from settings import AppSettings
import modules # noqa
# Load and attach app specific data
if sharding.sharded:
appname = f"{conf.bot['data_appid']}_{sharding.shard_count}_{sharding.shard_number}"
else:
appname = conf.bot['data_appid']
client.appdata = core.data.meta.fetch_or_create(appname)
client.data = tables
client.settings = AppSettings(conf.bot['data_appid'])
# Initialise all modules
client.initialise_modules()
# Log readyness and execute
log("Initial setup complete, logging in", context='SETUP')
client.run(conf.bot['TOKEN'])

View File

@@ -0,0 +1,6 @@
from .base import * # noqa
from .setting_types import * # noqa
from .user_settings import UserSettings, UserSetting # noqa
from .guild_settings import GuildSettings, GuildSetting # noqa
from .app_settings import AppSettings

View File

@@ -0,0 +1,5 @@
import settings
from utils.lib import DotDict
class AppSettings(settings.ObjectSettings):
settings = DotDict()

View File

@@ -0,0 +1,514 @@
import json
import discord
from cmdClient.cmdClient import cmdClient
from cmdClient.lib import SafeCancellation
from cmdClient.Check import Check
from utils.lib import prop_tabulate, DotDict
from LionContext import LionContext as Context
from meta import client
from data import Table, RowTable
class Setting:
"""
Abstract base class describing a stored configuration setting.
A setting consists of logic to load the setting from storage,
present it in a readable form, understand user entered values,
and write it again in storage.
Additionally, the setting has attributes attached describing
the setting in a user-friendly manner for display purposes.
"""
attr_name: str = None # Internal attribute name for the setting
_default: ... = None # Default data value for the setting.. this may be None if the setting overrides 'default'.
write_ward: Check = None # Check that must be passed to write the setting. Not implemented internally.
# Configuration interface descriptions
display_name: str = None # User readable name of the setting
desc: str = None # User readable brief description of the setting
long_desc: str = None # User readable long description of the setting
accepts: str = None # User readable description of the acceptable values
def __init__(self, id, data: ..., **kwargs):
self.client: cmdClient = client
self.id = id
self._data = data
# Configuration embeds
@property
def embed(self):
"""
Discord Embed showing an information summary about the setting.
"""
embed = discord.Embed(
title="Configuration options for `{}`".format(self.display_name),
)
fields = ("Current value", "Default value", "Accepted input")
values = (self.formatted or "Not Set",
self._format_data(self.id, self.default) or "None",
self.accepts)
table = prop_tabulate(fields, values)
embed.description = "{}\n{}".format(self.long_desc.format(self=self, client=self.client), table)
return embed
async def widget(self, ctx: Context, **kwargs):
"""
Show the setting widget for this setting.
By default this displays the setting embed.
Settings may override this if they need more complex widget context or logic.
"""
return await ctx.reply(embed=self.embed)
@property
def summary(self):
"""
Formatted summary of the data.
May be implemented in `_format_data(..., summary=True, ...)` or overidden.
"""
return self._format_data(self.id, self.data, summary=True)
@property
def success_response(self):
"""
Response message sent when the setting has successfully been updated.
"""
return "Setting Updated!"
# Instance generation
@classmethod
def get(cls, id: int, **kwargs):
"""
Return a setting instance initialised from the stored value.
"""
data = cls._reader(id, **kwargs)
return cls(id, data, **kwargs)
@classmethod
async def parse(cls, id: int, ctx: Context, userstr: str, **kwargs):
"""
Return a setting instance initialised from a parsed user string.
"""
data = await cls._parse_userstr(ctx, id, userstr, **kwargs)
return cls(id, data, **kwargs)
# Main interface
@property
def data(self):
"""
Retrieves the current internal setting data if it is set, otherwise the default data
"""
return self._data if self._data is not None else self.default
@data.setter
def data(self, new_data):
"""
Sets the internal setting data and writes the changes.
"""
self._data = new_data
self.write()
@property
def default(self):
"""
Retrieves the default value for this setting.
Settings should override this if the default depends on the object id.
"""
return self._default
@property
def value(self):
"""
Discord-aware object or objects associated with the setting.
"""
return self._data_to_value(self.id, self.data)
@value.setter
def value(self, new_value):
"""
Setter which reads the discord-aware object, converts it to data, and writes it.
"""
self._data = self._data_from_value(self.id, new_value)
self.write()
@property
def formatted(self):
"""
User-readable form of the setting.
"""
return self._format_data(self.id, self.data)
def write(self, **kwargs):
"""
Write value to the database.
For settings which override this,
ensure you handle deletion of values when internal data is None.
"""
self._writer(self.id, self._data, **kwargs)
# Raw converters
@classmethod
def _data_from_value(cls, id: int, value, **kwargs):
"""
Convert a high-level setting value to internal data.
Must be overriden by the setting.
Be aware of None values, these should always pass through as None
to provide an unsetting interface.
"""
raise NotImplementedError
@classmethod
def _data_to_value(cls, id: int, data: ..., **kwargs):
"""
Convert internal data to high-level setting value.
Must be overriden by the setting.
"""
raise NotImplementedError
@classmethod
async def _parse_userstr(cls, ctx: Context, id: int, userstr: str, **kwargs):
"""
Parse user provided input into internal data.
Must be overriden by the setting if the setting is user-configurable.
"""
raise NotImplementedError
@classmethod
def _format_data(cls, id: int, data: ..., **kwargs):
"""
Convert internal data into a formatted user-readable string.
Must be overriden by the setting if the setting is user-viewable.
"""
raise NotImplementedError
# Database access classmethods
@classmethod
def _reader(cls, id: int, **kwargs):
"""
Read a setting from storage and return setting data or None.
Must be overriden by the setting.
"""
raise NotImplementedError
@classmethod
def _writer(cls, id: int, data: ..., **kwargs):
"""
Write provided setting data to storage.
Must be overriden by the setting unless the `write` method is overidden.
If the data is None, the setting is empty and should be unset.
"""
raise NotImplementedError
@classmethod
async def command(cls, ctx, id, flags=()):
"""
Standardised command viewing/setting interface for the setting.
"""
if not ctx.args and not ctx.msg.attachments:
# View config embed for provided cls
await cls.get(id).widget(ctx, flags=flags)
else:
# Check the write ward
if cls.write_ward and not await cls.write_ward.run(ctx):
await ctx.error_reply(cls.write_ward.msg)
else:
# Attempt to set config cls
try:
cls = await cls.parse(id, ctx, ctx.args)
except UserInputError as e:
await ctx.reply(embed=discord.Embed(
description="{} {}".format('', e.msg),
Colour=discord.Colour.red()
))
else:
cls.write()
await ctx.reply(embed=discord.Embed(
description="{} {}".format('', cls.success_response),
Colour=discord.Colour.green()
))
@classmethod
def init_task(self, client):
"""
Initialisation task to be excuted during client initialisation.
May be used for e.g. populating a cache or required client setup.
Main application must execute the initialisation task before the setting is used.
Further, the task must always be executable, if the setting is loaded.
Conditional initalisation should go in the relevant module's init tasks.
"""
return None
class ObjectSettings:
"""
Abstract class representing a linked collection of settings for a single object.
Initialised settings are provided as instance attributes in the form of properties.
"""
__slots__ = ('id', 'params')
settings: DotDict = None
def __init__(self, id, **kwargs):
self.id = id
self.params = tuple(kwargs.items())
@classmethod
def _setting_property(cls, setting):
def wrapped_setting(self):
return setting.get(self.id, **dict(self.params))
return wrapped_setting
@classmethod
def attach_setting(cls, setting: Setting):
name = setting.attr_name or setting.__name__
setattr(cls, name, property(cls._setting_property(setting)))
cls.settings[name] = setting
return setting
def tabulated(self):
"""
Convenience method to provide a complete setting property-table.
"""
formatted = {
setting.display_name: setting.get(self.id, **dict(self.params)).formatted
for name, setting in self.settings.items()
}
return prop_tabulate(*zip(*formatted.items()))
class ColumnData:
"""
Mixin for settings stored in a single row and column of a Table.
Intended to be used with tables where the only primary key is the object id.
"""
# Table storing the desired data
_table_interface: Table = None
# Name of the column storing the setting object id
_id_column: str = None
# Name of the column with the desired data
_data_column: str = None
# Whether to use create a row if not found (only applies to TableRow)
_create_row = False
# Whether to upsert or update for updates
_upsert: bool = True
# High level data cache to use, set to None to disable cache.
_cache = None # Map[id -> value]
@classmethod
def _reader(cls, id: int, use_cache=True, **kwargs):
"""
Read in the requested entry associated to the id.
Supports reading cached values from a `RowTable`.
"""
if cls._cache is not None and id in cls._cache and use_cache:
return cls._cache[id]
table = cls._table_interface
if isinstance(table, RowTable) and cls._id_column == table.id_col:
if cls._create_row:
row = table.fetch_or_create(id)
else:
row = table.fetch(id)
data = row.data[cls._data_column] if row else None
else:
params = {
"select_columns": (cls._data_column,),
cls._id_column: id
}
row = table.select_one_where(**params)
data = row[cls._data_column] if row else None
if cls._cache is not None:
cls._cache[id] = data
return data
@classmethod
def _writer(cls, id: int, data: ..., **kwargs):
"""
Write the provided entry to the table, allowing replacements.
"""
table = cls._table_interface
params = {
cls._id_column: id
}
values = {
cls._data_column: data
}
# Update data
if cls._upsert:
# Upsert data
table.upsert(
constraint=cls._id_column,
**params,
**values
)
else:
# Update data
table.update_where(values, **params)
if cls._cache is not None:
cls._cache[id] = data
class ListData:
"""
Mixin for list types implemented on a Table.
Implements a reader and writer.
This assumes the list is the only data stored in the table,
and removes list entries by deleting rows.
"""
# Table storing the setting data
_table_interface: Table = None
# Name of the column storing the id
_id_column: str = None
# Name of the column storing the data to read
_data_column: str = None
# Name of column storing the order index to use, if any. Assumed to be Serial on writing.
_order_column: str = None
_order_type: str = "ASC"
# High level data cache to use, set to None to disable cache.
_cache = None # Map[id -> value]
@classmethod
def _reader(cls, id: int, use_cache=True, **kwargs):
"""
Read in all entries associated to the given id.
"""
if cls._cache is not None and id in cls._cache and use_cache:
return cls._cache[id]
table = cls._table_interface # type: Table
params = {
"select_columns": [cls._data_column],
cls._id_column: id
}
if cls._order_column:
params['_extra'] = "ORDER BY {} {}".format(cls._order_column, cls._order_type)
rows = table.select_where(**params)
data = [row[cls._data_column] for row in rows]
if cls._cache is not None:
cls._cache[id] = data
return data
@classmethod
def _writer(cls, id: int, data: ..., add_only=False, remove_only=False, **kwargs):
"""
Write the provided list to storage.
"""
# TODO: Transaction lock on the table so this is atomic
# Or just use the connection context manager
table = cls._table_interface # type: Table
# Handle None input as an empty list
if data is None:
data = []
current = cls._reader(id, **kwargs)
if not cls._order_column and (add_only or remove_only):
to_insert = [item for item in data if item not in current] if not remove_only else []
to_remove = data if remove_only else (
[item for item in current if item not in data] if not add_only else []
)
# Handle required deletions
if to_remove:
params = {
cls._id_column: id,
cls._data_column: to_remove
}
table.delete_where(**params)
# Handle required insertions
if to_insert:
columns = (cls._id_column, cls._data_column)
values = [(id, value) for value in to_insert]
table.insert_many(*values, insert_keys=columns)
if cls._cache is not None:
new_current = [item for item in current + to_insert if item not in to_remove]
cls._cache[id] = new_current
else:
# Remove all and add all to preserve order
# TODO: This really really should be atomic if anything else reads this
delete_params = {cls._id_column: id}
table.delete_where(**delete_params)
if data:
columns = (cls._id_column, cls._data_column)
values = [(id, value) for value in data]
table.insert_many(*values, insert_keys=columns)
if cls._cache is not None:
cls._cache[id] = data
class KeyValueData:
"""
Mixin for settings implemented in a Key-Value table.
The underlying table should have a Unique constraint on the `(_id_column, _key_column)` pair.
"""
_table_interface: Table = None
_id_column: str = None
_key_column: str = None
_value_column: str = None
_key: str = None
@classmethod
def _reader(cls, id: ..., **kwargs):
params = {
"select_columns": (cls._value_column, ),
cls._id_column: id,
cls._key_column: cls._key
}
row = cls._table_interface.select_one_where(**params)
data = row[cls._value_column] if row else None
if data is not None:
data = json.loads(data)
return data
@classmethod
def _writer(cls, id: ..., data: ..., **kwargs):
params = {
cls._id_column: id,
cls._key_column: cls._key
}
if data is not None:
values = {
cls._value_column: json.dumps(data)
}
cls._table_interface.upsert(
constraint=f"{cls._id_column}, {cls._key_column}",
**params,
**values
)
else:
cls._table_interface.delete_where(**params)
class UserInputError(SafeCancellation):
pass

View File

@@ -0,0 +1,197 @@
import datetime
import asyncio
import discord
import settings
from utils.lib import DotDict
from utils import seekers # noqa
from wards import guild_admin
from data import tables as tb
class GuildSettings(settings.ObjectSettings):
settings = DotDict()
class GuildSetting(settings.ColumnData, settings.Setting):
_table_interface = tb.guild_config
_id_column = 'guildid'
_create_row = True
category = None
write_ward = guild_admin
@GuildSettings.attach_setting
class event_log(settings.Channel, GuildSetting):
category = "Meta"
attr_name = 'event_log'
_data_column = 'event_log_channel'
display_name = "event_log"
desc = "Bot event logging channel."
long_desc = (
"Channel to post 'events', such as workouts completing or members renting a room."
)
_chan_type = discord.ChannelType.text
@property
def success_response(self):
if self.value:
return "The event log is now {}.".format(self.formatted)
else:
return "The event log has been unset."
def log(self, description="", colour=discord.Color.orange(), **kwargs):
channel = self.value
if channel:
embed = discord.Embed(
description=description,
colour=colour,
timestamp=datetime.datetime.utcnow(),
**kwargs
)
asyncio.create_task(channel.send(embed=embed))
@GuildSettings.attach_setting
class admin_role(settings.Role, GuildSetting):
category = "Guild Roles"
attr_name = 'admin_role'
_data_column = 'admin_role'
display_name = "admin_role"
desc = "Server administrator role."
long_desc = (
"Server administrator role.\n"
"Allows usage of the administrative commands, such as `config`.\n"
"These commands may also be used by anyone with the discord adminitrator permission."
)
# TODO Expand on what these are.
@property
def success_response(self):
if self.value:
return "The administrator role is now {}.".format(self.formatted)
else:
return "The administrator role has been unset."
@GuildSettings.attach_setting
class mod_role(settings.Role, GuildSetting):
category = "Guild Roles"
attr_name = 'mod_role'
_data_column = 'mod_role'
display_name = "mod_role"
desc = "Server moderator role."
long_desc = (
"Server moderator role.\n"
"Allows usage of the modistrative commands."
)
# TODO Expand on what these are.
@property
def success_response(self):
if self.value:
return "The moderator role is now {}.".format(self.formatted)
else:
return "The moderator role has been unset."
@GuildSettings.attach_setting
class unranked_roles(settings.RoleList, settings.ListData, settings.Setting):
category = "Guild Roles"
attr_name = 'unranked_roles'
_table_interface = tb.unranked_roles
_id_column = 'guildid'
_data_column = 'roleid'
write_ward = guild_admin
display_name = "unranked_roles"
desc = "Roles to exclude from the leaderboards."
_force_unique = True
long_desc = (
"Roles to be excluded from the `top` and `topcoins` leaderboards."
)
# Flat cache, no need to expire objects
_cache = {}
@property
def success_response(self):
if self.value:
return "The following roles will be excluded from the leaderboard:\n{}".format(self.formatted)
else:
return "The excluded roles have been removed."
@GuildSettings.attach_setting
class donator_roles(settings.RoleList, settings.ListData, settings.Setting):
category = "Hidden"
attr_name = 'donator_roles'
_table_interface = tb.donator_roles
_id_column = 'guildid'
_data_column = 'roleid'
write_ward = guild_admin
display_name = "donator_roles"
desc = "Donator badge roles."
_force_unique = True
long_desc = (
"Members with these roles will be considered donators and have access to premium features."
)
# Flat cache, no need to expire objects
_cache = {}
@property
def success_response(self):
if self.value:
return "The donator badges are now:\n{}".format(self.formatted)
else:
return "The donator badges have been removed."
@GuildSettings.attach_setting
class alert_channel(settings.Channel, GuildSetting):
category = "Meta"
attr_name = 'alert_channel'
_data_column = 'alert_channel'
display_name = "alert_channel"
desc = "Channel to display global user alerts."
long_desc = (
"This channel will be used for group notifications, "
"for example group timers and anti-cheat messages, "
"as well as for critical alerts to users that have their direct messages disapbled.\n"
"It should be visible to all members."
)
_chan_type = discord.ChannelType.text
@property
def success_response(self):
if self.value:
return "The alert channel is now {}.".format(self.formatted)
else:
return "The alert channel has been unset."

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,42 @@
import datetime
import settings
from utils.lib import DotDict
from data import tables as tb
class UserSettings(settings.ObjectSettings):
settings = DotDict()
class UserSetting(settings.ColumnData, settings.Setting):
_table_interface = tb.user_config
_id_column = 'userid'
_create_row = True
write_ward = None
@UserSettings.attach_setting
class timezone(settings.Timezone, UserSetting):
attr_name = 'timezone'
_data_column = 'timezone'
_default = 'UTC'
display_name = 'timezone'
desc = "Timezone to display prompts in."
long_desc = (
"Timezone used for displaying certain prompts (e.g. selecting an accountability room)."
)
@property
def success_response(self):
if self.value:
return (
"Your personal timezone is now {}.\n"
"Your current time is **{}**."
).format(self.formatted, datetime.datetime.now(tz=self.value).strftime("%H:%M"))
else:
return "Your personal timezone has been unset."

View File

@@ -0,0 +1,157 @@
import asyncio
import discord
from LionContext import LionContext as Context
from cmdClient.lib import SafeCancellation
from data import tables
from core import Lion
from . import lib
from settings import GuildSettings, UserSettings
@Context.util
async def embed_reply(ctx, desc, colour=discord.Colour.orange(), **kwargs):
"""
Simple helper to embed replies.
All arguments are passed to the embed constructor.
`desc` is passed as the `description` kwarg.
"""
embed = discord.Embed(description=desc, colour=colour, **kwargs)
try:
return await ctx.reply(embed=embed, reference=ctx.msg.to_reference(fail_if_not_exists=False))
except discord.Forbidden:
if not ctx.guild or ctx.ch.permissions_for(ctx.guild.me).send_messages:
await ctx.reply("Command failed, I don't have permission to send embeds in this channel!")
raise SafeCancellation
@Context.util
async def error_reply(ctx, error_str, send_args={}, **kwargs):
"""
Notify the user of a user level error.
Typically, this will occur in a red embed, posted in the command channel.
"""
embed = discord.Embed(
colour=discord.Colour.red(),
description=error_str,
**kwargs
)
message = None
try:
message = await ctx.ch.send(
embed=embed,
reference=ctx.msg.to_reference(fail_if_not_exists=False),
**send_args
)
ctx.sent_messages.append(message)
return message
except discord.Forbidden:
if not ctx.guild or ctx.ch.permissions_for(ctx.guild.me).send_messages:
await ctx.reply("Command failed, I don't have permission to send embeds in this channel!")
raise SafeCancellation
@Context.util
async def offer_delete(ctx: Context, *to_delete, timeout=300):
"""
Offers to delete the provided messages via a reaction on the last message.
Removes the reaction if the offer times out.
If any exceptions occur, handles them silently and returns.
Parameters
----------
to_delete: List[Message]
The messages to delete.
timeout: int
Time in seconds after which to remove the delete offer reaction.
"""
# Get the delete emoji from the config
emoji = lib.cross
# Return if there are no messages to delete
if not to_delete:
return
# The message to add the reaction to
react_msg = to_delete[-1]
# Build the reaction check function
if ctx.guild:
modrole = ctx.guild_settings.mod_role.value if ctx.guild else None
def check(reaction, user):
if not (reaction.message.id == react_msg.id and reaction.emoji == emoji):
return False
if user == ctx.guild.me:
return False
return ((user == ctx.author)
or (user.permissions_in(ctx.ch).manage_messages)
or (modrole and modrole in user.roles))
else:
def check(reaction, user):
return user == ctx.author and reaction.message.id == react_msg.id and reaction.emoji == emoji
try:
# Add the reaction to the message
await react_msg.add_reaction(emoji)
# Wait for the user to press the reaction
reaction, user = await ctx.client.wait_for("reaction_add", check=check, timeout=timeout)
# Since the check was satisfied, the reaction is correct. Delete the messages, ignoring any exceptions
deleted = False
# First try to bulk delete if we have the permissions
if ctx.guild and ctx.ch.permissions_for(ctx.guild.me).manage_messages:
try:
await ctx.ch.delete_messages(to_delete)
deleted = True
except Exception:
deleted = False
# If we couldn't bulk delete, delete them one by one
if not deleted:
try:
asyncio.gather(*[message.delete() for message in to_delete], return_exceptions=True)
except Exception:
pass
except (asyncio.TimeoutError, asyncio.CancelledError):
# Timed out waiting for the reaction, attempt to remove the delete reaction
try:
await react_msg.remove_reaction(emoji, ctx.client.user)
except Exception:
pass
except discord.Forbidden:
pass
except discord.NotFound:
pass
except discord.HTTPException:
pass
def context_property(func):
setattr(Context, func.__name__, property(func))
return func
@context_property
def best_prefix(ctx):
return ctx.client.prefix
@context_property
def guild_settings(ctx):
if ctx.guild:
tables.guild_config.fetch_or_create(ctx.guild.id)
return GuildSettings(ctx.guild.id if ctx.guild else 0)
@context_property
def author_settings(ctx):
return UserSettings(ctx.author.id)
@context_property
def alion(ctx):
return Lion.fetch(ctx.guild.id if ctx.guild else 0, ctx.author.id)

View File

@@ -0,0 +1,461 @@
import asyncio
import discord
from LionContext import LionContext as Context
from cmdClient.lib import UserCancelled, ResponseTimedOut
from .lib import paginate_list
# TODO: Interactive locks
cancel_emoji = ''
number_emojis = (
'1', '2', '3', '4', '5', '6', '7', '8', '9'
)
async def discord_shield(coro):
try:
await coro
except discord.HTTPException:
pass
@Context.util
async def cancellable(ctx, msg, add_reaction=True, cancel_message=None, timeout=300):
"""
Add a cancellation reaction to the given message.
Pressing the reaction triggers cancellation of the original context, and a UserCancelled-style error response.
"""
# TODO: Not consistent with the exception driven flow, make a decision here?
# Add reaction
if add_reaction and cancel_emoji not in (str(r.emoji) for r in msg.reactions):
try:
await msg.add_reaction(cancel_emoji)
except discord.HTTPException:
return
# Define cancellation function
async def _cancel():
try:
await ctx.client.wait_for(
'reaction_add',
timeout=timeout,
check=lambda r, u: (u == ctx.author
and r.message == msg
and str(r.emoji) == cancel_emoji)
)
except asyncio.TimeoutError:
pass
else:
await ctx.client.active_command_response_cleaner(ctx)
if cancel_message:
await ctx.error_reply(cancel_message)
else:
try:
await ctx.msg.add_reaction(cancel_emoji)
except discord.HTTPException:
pass
[task.cancel() for task in ctx.tasks]
# Launch cancellation task
task = asyncio.create_task(_cancel())
ctx.tasks.append(task)
return task
@Context.util
async def listen_for(ctx, allowed_input=None, timeout=120, lower=True, check=None):
"""
Listen for a one of a particular set of input strings,
sent in the current channel by `ctx.author`.
When found, return the message containing them.
Parameters
----------
allowed_input: Union(List(str), None)
List of strings to listen for.
Allowed to be `None` precisely when a `check` function is also supplied.
timeout: int
Number of seconds to wait before timing out.
lower: bool
Whether to shift the allowed and message strings to lowercase before checking.
check: Function(message) -> bool
Alternative custom check function.
Returns: discord.Message
The message that was matched.
Raises
------
cmdClient.lib.ResponseTimedOut:
Raised when no messages matching the given criteria are detected in `timeout` seconds.
"""
# Generate the check if it hasn't been provided
if not check:
# Quick check the arguments are sane
if not allowed_input:
raise ValueError("allowed_input and check cannot both be None")
# Force a lower on the allowed inputs
allowed_input = [s.lower() for s in allowed_input]
# Create the check function
def check(message):
result = (message.author == ctx.author)
result = result and (message.channel == ctx.ch)
result = result and ((message.content.lower() if lower else message.content) in allowed_input)
return result
# Wait for a matching message, catch and transform the timeout
try:
message = await ctx.client.wait_for('message', check=check, timeout=timeout)
except asyncio.TimeoutError:
raise ResponseTimedOut("Session timed out waiting for user response.") from None
return message
@Context.util
async def selector(ctx, header, select_from, timeout=120, max_len=20):
"""
Interactive routine to prompt the `ctx.author` to select an item from a list.
Returns the list index that was selected.
Parameters
----------
header: str
String to put at the top of each page of selection options.
Intended to be information about the list the user is selecting from.
select_from: List(str)
The list of strings to select from.
timeout: int
The number of seconds to wait before throwing `ResponseTimedOut`.
max_len: int
The maximum number of items to display on each page.
Decrease this if the items are long, to avoid going over the char limit.
Returns
-------
int:
The index of the list entry selected by the user.
Raises
------
cmdClient.lib.UserCancelled:
Raised if the user manually cancels the selection.
cmdClient.lib.ResponseTimedOut:
Raised if the user fails to respond to the selector within `timeout` seconds.
"""
# Handle improper arguments
if len(select_from) == 0:
raise ValueError("Selection list passed to `selector` cannot be empty.")
# Generate the selector pages
footer = "Please reply with the number of your selection, or press {} to cancel.".format(cancel_emoji)
list_pages = paginate_list(select_from, block_length=max_len)
pages = ["\n".join([header, page, footer]) for page in list_pages]
# Post the pages in a paged message
out_msg = await ctx.pager(pages, add_cancel=True)
cancel_task = await ctx.cancellable(out_msg, add_reaction=False, timeout=None)
if len(select_from) <= 5:
for i, _ in enumerate(select_from):
asyncio.create_task(discord_shield(out_msg.add_reaction(number_emojis[i])))
# Build response tasks
valid_input = [str(i+1) for i in range(0, len(select_from))] + ['c', 'C']
listen_task = asyncio.create_task(ctx.listen_for(valid_input, timeout=None))
emoji_task = asyncio.create_task(ctx.client.wait_for(
'reaction_add',
check=lambda r, u: (u == ctx.author
and r.message == out_msg
and str(r.emoji) in number_emojis)
))
# Wait for the response tasks
done, pending = await asyncio.wait(
(listen_task, emoji_task),
timeout=timeout,
return_when=asyncio.FIRST_COMPLETED
)
# Cleanup
try:
await out_msg.delete()
except discord.HTTPException:
pass
# Handle different return cases
if listen_task in done:
emoji_task.cancel()
result_msg = listen_task.result()
try:
await result_msg.delete()
except discord.HTTPException:
pass
if result_msg.content.lower() == 'c':
raise UserCancelled("Selection cancelled!")
result = int(result_msg.content) - 1
elif emoji_task in done:
listen_task.cancel()
reaction, _ = emoji_task.result()
result = number_emojis.index(str(reaction.emoji))
elif cancel_task in done:
# Manually cancelled case.. the current task should have been cancelled
# Raise UserCancelled in case the task wasn't cancelled for some reason
raise UserCancelled("Selection cancelled!")
elif not done:
# Timeout case
raise ResponseTimedOut("Selector timed out waiting for a response.")
# Finally cancel the canceller and return the provided index
cancel_task.cancel()
return result
@Context.util
async def pager(ctx, pages, locked=True, start_at=0, add_cancel=False, **kwargs):
"""
Shows the user each page from the provided list `pages` one at a time,
providing reactions to page back and forth between pages.
This is done asynchronously, and returns after displaying the first page.
Parameters
----------
pages: List(Union(str, discord.Embed))
A list of either strings or embeds to display as the pages.
locked: bool
Whether only the `ctx.author` should be able to use the paging reactions.
kwargs: ...
Remaining keyword arguments are transparently passed to the reply context method.
Returns: discord.Message
This is the output message, returned for easy deletion.
"""
# Handle broken input
if len(pages) == 0:
raise ValueError("Pager cannot page with no pages!")
# Post first page. Method depends on whether the page is an embed or not.
if isinstance(pages[start_at], discord.Embed):
out_msg = await ctx.reply(embed=pages[start_at], **kwargs)
else:
out_msg = await ctx.reply(pages[start_at], **kwargs)
# Run the paging loop if required
if len(pages) > 1:
task = asyncio.create_task(_pager(ctx, out_msg, pages, locked, start_at, add_cancel, **kwargs))
ctx.tasks.append(task)
elif add_cancel:
await out_msg.add_reaction(cancel_emoji)
# Return the output message
return out_msg
async def _pager(ctx, out_msg, pages, locked, start_at, add_cancel, **kwargs):
"""
Asynchronous initialiser and loop for the `pager` utility above.
"""
# Page number
page = start_at
# Add reactions to the output message
next_emoji = ""
prev_emoji = ""
try:
await out_msg.add_reaction(prev_emoji)
if add_cancel:
await out_msg.add_reaction(cancel_emoji)
await out_msg.add_reaction(next_emoji)
except discord.Forbidden:
# We don't have permission to add paging emojis
# Die as gracefully as we can
if ctx.guild:
perms = ctx.ch.permissions_for(ctx.guild.me)
if not perms.add_reactions:
await ctx.error_reply(
"Cannot page results because I do not have the `add_reactions` permission!"
)
elif not perms.read_message_history:
await ctx.error_reply(
"Cannot page results because I do not have the `read_message_history` permission!"
)
else:
await ctx.error_reply(
"Cannot page results due to insufficient permissions!"
)
else:
await ctx.error_reply(
"Cannot page results!"
)
return
# Check function to determine whether a reaction is valid
def reaction_check(reaction, user):
result = reaction.message.id == out_msg.id
result = result and str(reaction.emoji) in [next_emoji, prev_emoji]
result = result and not (user.id == ctx.client.user.id)
result = result and not (locked and user != ctx.author)
return result
# Check function to determine if message has a page number
def message_check(message):
result = message.channel.id == ctx.ch.id
result = result and not (locked and message.author != ctx.author)
result = result and message.content.lower().startswith('p')
result = result and message.content[1:].isdigit()
result = result and 1 <= int(message.content[1:]) <= len(pages)
return result
# Begin loop
while True:
# Wait for a valid reaction or message, break if we time out
reaction_task = asyncio.create_task(
ctx.client.wait_for('reaction_add', check=reaction_check)
)
message_task = asyncio.create_task(
ctx.client.wait_for('message', check=message_check)
)
done, pending = await asyncio.wait(
(reaction_task, message_task),
timeout=300,
return_when=asyncio.FIRST_COMPLETED
)
if done:
if reaction_task in done:
# Cancel the message task and collect the reaction result
message_task.cancel()
reaction, user = reaction_task.result()
# Attempt to remove the user's reaction, silently ignore errors
asyncio.ensure_future(out_msg.remove_reaction(reaction.emoji, user))
# Change the page number
page += 1 if reaction.emoji == next_emoji else -1
page %= len(pages)
elif message_task in done:
# Cancel the reaction task and collect the message result
reaction_task.cancel()
message = message_task.result()
# Attempt to delete the user's message, silently ignore errors
asyncio.ensure_future(message.delete())
# Move to the correct page
page = int(message.content[1:]) - 1
# Edit the message with the new page
active_page = pages[page]
if isinstance(active_page, discord.Embed):
await out_msg.edit(embed=active_page, **kwargs)
else:
await out_msg.edit(content=active_page, **kwargs)
else:
# No tasks finished, so we must have timed out, or had an exception.
# Break the loop and clean up
break
# Clean up by removing the reactions
try:
await out_msg.clear_reactions()
except discord.Forbidden:
try:
await out_msg.remove_reaction(next_emoji, ctx.client.user)
await out_msg.remove_reaction(prev_emoji, ctx.client.user)
except discord.NotFound:
pass
except discord.NotFound:
pass
@Context.util
async def input(ctx, msg="", timeout=120):
"""
Listen for a response in the current channel, from ctx.author.
Returns the response from ctx.author, if it is provided.
Parameters
----------
msg: string
Allows a custom input message to be provided.
Will use default message if not provided.
timeout: int
Number of seconds to wait before timing out.
Raises
------
cmdClient.lib.ResponseTimedOut:
Raised when ctx.author does not provide a response before the function times out.
"""
# Deliver prompt
offer_msg = await ctx.reply(msg or "Please enter your input.")
# Criteria for the input message
def checks(m):
return m.author == ctx.author and m.channel == ctx.ch
# Listen for the reply
try:
result_msg = await ctx.client.wait_for("message", check=checks, timeout=timeout)
except asyncio.TimeoutError:
raise ResponseTimedOut("Session timed out waiting for user response.") from None
result = result_msg.content
# Attempt to delete the prompt and reply messages
try:
await offer_msg.delete()
await result_msg.delete()
except Exception:
pass
return result
@Context.util
async def ask(ctx, msg, timeout=30, use_msg=None, del_on_timeout=False):
"""
Ask ctx.author a yes/no question.
Returns 0 if ctx.author answers no
Returns 1 if ctx.author answers yes
Parameters
----------
msg: string
Adds the question to the message string.
Requires an input.
timeout: int
Number of seconds to wait before timing out.
use_msg: string
A completely custom string to use instead of the default string.
del_on_timeout: bool
Whether to delete the question if it times out.
Raises
------
Nothing
"""
out = "{} {}".format(msg, "`y(es)`/`n(o)`")
offer_msg = use_msg or await ctx.reply(out)
if use_msg and msg:
await use_msg.edit(content=msg)
result_msg = await ctx.listen_for(["y", "yes", "n", "no"], timeout=timeout)
if result_msg is None:
if del_on_timeout:
try:
await offer_msg.delete()
except Exception:
pass
return None
result = result_msg.content.lower()
try:
if not use_msg:
await offer_msg.delete()
await result_msg.delete()
except Exception:
pass
if result in ["n", "no"]:
return 0
return 1

View File

@@ -0,0 +1,553 @@
import datetime
import iso8601
import re
from enum import Enum
import discord
from psycopg2.extensions import QuotedString
from cmdClient.lib import SafeCancellation
multiselect_regex = re.compile(
r"^([0-9, -]+)$",
re.DOTALL | re.IGNORECASE | re.VERBOSE
)
tick = ''
cross = ''
def prop_tabulate(prop_list, value_list, indent=True, colon=True):
"""
Turns a list of properties and corresponding list of values into
a pretty string with one `prop: value` pair each line,
padded so that the colons in each line are lined up.
Handles empty props by using an extra couple of spaces instead of a `:`.
Parameters
----------
prop_list: List[str]
List of short names to put on the right side of the list.
Empty props are considered to be "newlines" for the corresponding value.
value_list: List[str]
List of values corresponding to the properties above.
indent: bool
Whether to add padding so the properties are right-adjusted.
Returns: str
"""
max_len = max(len(prop) for prop in prop_list)
return "".join(["`{}{}{}`\t{}{}".format(" " * (max_len - len(prop)) if indent else "",
prop,
(":" if len(prop) else " " * 2) if colon else '',
value_list[i],
'' if str(value_list[i]).endswith("```") else '\n')
for i, prop in enumerate(prop_list)])
def paginate_list(item_list, block_length=20, style="markdown", title=None):
"""
Create pretty codeblock pages from a list of strings.
Parameters
----------
item_list: List[str]
List of strings to paginate.
block_length: int
Maximum number of strings per page.
style: str
Codeblock style to use.
Title formatting assumes the `markdown` style, and numbered lists work well with this.
However, `markdown` sometimes messes up formatting in the list.
title: str
Optional title to add to the top of each page.
Returns: List[str]
List of pages, each formatted into a codeblock,
and containing at most `block_length` of the provided strings.
"""
lines = ["{0:<5}{1:<5}".format("{}.".format(i + 1), str(line)) for i, line in enumerate(item_list)]
page_blocks = [lines[i:i + block_length] for i in range(0, len(lines), block_length)]
pages = []
for i, block in enumerate(page_blocks):
pagenum = "Page {}/{}".format(i + 1, len(page_blocks))
if title:
header = "{} ({})".format(title, pagenum) if len(page_blocks) > 1 else title
else:
header = pagenum
header_line = "=" * len(header)
full_header = "{}\n{}\n".format(header, header_line) if len(page_blocks) > 1 or title else ""
pages.append("```{}\n{}{}```".format(style, full_header, "\n".join(block)))
return pages
def timestamp_utcnow():
"""
Return the current integer UTC timestamp.
"""
return int(datetime.datetime.timestamp(datetime.datetime.utcnow()))
def split_text(text, blocksize=2000, code=True, syntax="", maxheight=50):
"""
Break the text into blocks of maximum length blocksize
If possible, break across nearby newlines. Otherwise just break at blocksize chars
Parameters
----------
text: str
Text to break into blocks.
blocksize: int
Maximum character length for each block.
code: bool
Whether to wrap each block in codeblocks (these are counted in the blocksize).
syntax: str
The markdown formatting language to use for the codeblocks, if applicable.
maxheight: int
The maximum number of lines in each block
Returns: List[str]
List of blocks,
each containing at most `block_size` characters,
of height at most `maxheight`.
"""
# Adjust blocksize to account for the codeblocks if required
blocksize = blocksize - 8 - len(syntax) if code else blocksize
# Build the blocks
blocks = []
while True:
# If the remaining text is already small enough, append it
if len(text) <= blocksize:
blocks.append(text)
break
text = text.strip('\n')
# Find the last newline in the prototype block
split_on = text[0:blocksize].rfind('\n')
split_on = blocksize if split_on < blocksize // 5 else split_on
# Add the block and truncate the text
blocks.append(text[0:split_on])
text = text[split_on:]
# Add the codeblock ticks and the code syntax header, if required
if code:
blocks = ["```{}\n{}\n```".format(syntax, block) for block in blocks]
return blocks
def strfdelta(delta, sec=False, minutes=True, short=False):
"""
Convert a datetime.timedelta object into an easily readable duration string.
Parameters
----------
delta: datetime.timedelta
The timedelta object to convert into a readable string.
sec: bool
Whether to include the seconds from the timedelta object in the string.
minutes: bool
Whether to include the minutes from the timedelta object in the string.
short: bool
Whether to abbreviate the units of time ("hour" to "h", "minute" to "m", "second" to "s").
Returns: str
A string containing a time from the datetime.timedelta object, in a readable format.
Time units will be abbreviated if short was set to True.
"""
output = [[delta.days, 'd' if short else ' day'],
[delta.seconds // 3600, 'h' if short else ' hour']]
if minutes:
output.append([delta.seconds // 60 % 60, 'm' if short else ' minute'])
if sec:
output.append([delta.seconds % 60, 's' if short else ' second'])
for i in range(len(output)):
if output[i][0] != 1 and not short:
output[i][1] += 's'
reply_msg = []
if output[0][0] != 0:
reply_msg.append("{}{} ".format(output[0][0], output[0][1]))
if output[0][0] != 0 or output[1][0] != 0 or len(output) == 2:
reply_msg.append("{}{} ".format(output[1][0], output[1][1]))
for i in range(2, len(output) - 1):
reply_msg.append("{}{} ".format(output[i][0], output[i][1]))
if not short and reply_msg:
reply_msg.append("and ")
reply_msg.append("{}{}".format(output[-1][0], output[-1][1]))
return "".join(reply_msg)
def parse_dur(time_str):
"""
Parses a user provided time duration string into a timedelta object.
Parameters
----------
time_str: str
The time string to parse. String can include days, hours, minutes, and seconds.
Returns: int
The number of seconds the duration represents.
"""
funcs = {'d': lambda x: x * 24 * 60 * 60,
'h': lambda x: x * 60 * 60,
'm': lambda x: x * 60,
's': lambda x: x}
time_str = time_str.strip(" ,")
found = re.findall(r'(\d+)\s?(\w+?)', time_str)
seconds = 0
for bit in found:
if bit[1] in funcs:
seconds += funcs[bit[1]](int(bit[0]))
return seconds
def strfdur(duration, short=True, show_days=False):
"""
Convert a duration given in seconds to a number of hours, minutes, and seconds.
"""
days = duration // (3600 * 24) if show_days else 0
hours = duration // 3600
if days:
hours %= 24
minutes = duration // 60 % 60
seconds = duration % 60
parts = []
if days:
unit = 'd' if short else (' days' if days != 1 else ' day')
parts.append('{}{}'.format(days, unit))
if hours:
unit = 'h' if short else (' hours' if hours != 1 else ' hour')
parts.append('{}{}'.format(hours, unit))
if minutes:
unit = 'm' if short else (' minutes' if minutes != 1 else ' minute')
parts.append('{}{}'.format(minutes, unit))
if seconds or duration == 0:
unit = 's' if short else (' seconds' if seconds != 1 else ' second')
parts.append('{}{}'.format(seconds, unit))
if short:
return ' '.join(parts)
else:
return ', '.join(parts)
def substitute_ranges(ranges_str, max_match=20, max_range=1000, separator=','):
"""
Substitutes a user provided list of numbers and ranges,
and replaces the ranges by the corresponding list of numbers.
Parameters
----------
ranges_str: str
The string to ranges in.
max_match: int
The maximum number of ranges to replace.
Any ranges exceeding this will be ignored.
max_range: int
The maximum length of range to replace.
Attempting to replace a range longer than this will raise a `ValueError`.
"""
def _repl(match):
n1 = int(match.group(1))
n2 = int(match.group(2))
if n2 - n1 > max_range:
raise SafeCancellation("Provided range is too large!")
return separator.join(str(i) for i in range(n1, n2 + 1))
return re.sub(r'(\d+)\s*-\s*(\d+)', _repl, ranges_str, max_match)
def parse_ranges(ranges_str, ignore_errors=False, separator=',', **kwargs):
"""
Parses a user provided range string into a list of numbers.
Extra keyword arguments are transparently passed to the underlying parser `substitute_ranges`.
"""
substituted = substitute_ranges(ranges_str, separator=separator, **kwargs)
numbers = (item.strip() for item in substituted.split(','))
numbers = [item for item in numbers if item]
integers = [int(item) for item in numbers if item.isdigit()]
if not ignore_errors and len(integers) != len(numbers):
raise SafeCancellation(
"Couldn't parse the provided selection!\n"
"Please provide comma separated numbers and ranges, e.g. `1, 5, 6-9`."
)
return integers
def msg_string(msg, mask_link=False, line_break=False, tz=None, clean=True):
"""
Format a message into a string with various information, such as:
the timestamp of the message, author, message content, and attachments.
Parameters
----------
msg: Message
The message to format.
mask_link: bool
Whether to mask the URLs of any attachments.
line_break: bool
Whether a line break should be used in the string.
tz: Timezone
The timezone to use in the formatted message.
clean: bool
Whether to use the clean content of the original message.
Returns: str
A formatted string containing various information:
User timezone, message author, message content, attachments
"""
timestr = "%I:%M %p, %d/%m/%Y"
if tz:
time = iso8601.parse_date(msg.timestamp.isoformat()).astimezone(tz).strftime(timestr)
else:
time = msg.timestamp.strftime(timestr)
user = str(msg.author)
attach_list = [attach["url"] for attach in msg.attachments if "url" in attach]
if mask_link:
attach_list = ["[Link]({})".format(url) for url in attach_list]
attachments = "\nAttachments: {}".format(", ".join(attach_list)) if attach_list else ""
return "`[{time}]` **{user}:** {line_break}{message} {attachments}".format(
time=time,
user=user,
line_break="\n" if line_break else "",
message=msg.clean_content if clean else msg.content,
attachments=attachments
)
def convdatestring(datestring):
"""
Convert a date string into a datetime.timedelta object.
Parameters
----------
datestring: str
The string to convert to a datetime.timedelta object.
Returns: datetime.timedelta
A datetime.timedelta object formed from the string provided.
"""
datestring = datestring.strip(' ,')
datearray = []
funcs = {'d': lambda x: x * 24 * 60 * 60,
'h': lambda x: x * 60 * 60,
'm': lambda x: x * 60,
's': lambda x: x}
currentnumber = ''
for char in datestring:
if char.isdigit():
currentnumber += char
else:
if currentnumber == '':
continue
datearray.append((int(currentnumber), char))
currentnumber = ''
seconds = 0
if currentnumber:
seconds += int(currentnumber)
for i in datearray:
if i[1] in funcs:
seconds += funcs[i[1]](i[0])
return datetime.timedelta(seconds=seconds)
class _rawChannel(discord.abc.Messageable):
"""
Raw messageable class representing an arbitrary channel,
not necessarially seen by the gateway.
"""
def __init__(self, state, id):
self._state = state
self.id = id
async def _get_channel(self):
return discord.Object(self.id)
async def mail(client: discord.Client, channelid: int, **msg_args):
"""
Mails a message to a channelid which may be invisible to the gateway.
Parameters:
client: discord.Client
The client to use for mailing.
Must at least have static authentication and have a valid `_connection`.
channelid: int
The channel id to mail to.
msg_args: Any
Message keyword arguments which are passed transparently to `_rawChannel.send(...)`.
"""
# Create the raw channel
channel = _rawChannel(client._connection, channelid)
return await channel.send(**msg_args)
def emb_add_fields(embed, emb_fields):
"""
Append embed fields to an embed.
Parameters
----------
embed: discord.Embed
The embed to add the field to.
emb_fields: tuple
The values to add to a field.
name: str
The name of the field.
value: str
The value of the field.
inline: bool
Whether the embed field should be inline or not.
"""
for field in emb_fields:
embed.add_field(name=str(field[0]), value=str(field[1]), inline=bool(field[2]))
def join_list(string, nfs=False):
"""
Join a list together, separated with commas, plus add "and" to the beginning of the last value.
Parameters
----------
string: list
The list to join together.
nfs: bool
(no fullstops)
Whether to exclude fullstops/periods from the output messages.
If not provided, fullstops will be appended to the output.
"""
if len(string) > 1:
return "{}{} and {}{}".format((", ").join(string[:-1]),
"," if len(string) > 2 else "", string[-1], "" if nfs else ".")
else:
return "{}{}".format("".join(string), "" if nfs else ".")
def format_activity(user):
"""
Format a user's activity string, depending on the type of activity.
Currently supported types are:
- Nothing
- Custom status
- Playing (with rich presence support)
- Streaming
- Listening (with rich presence support)
- Watching
- Unknown
Parameters
----------
user: discord.Member
The user to format the status of.
If the user has no activity, "Nothing" will be returned.
Returns: str
A formatted string with various information about the user's current activity like the name,
and any extra information about the activity (such as current song artists for Spotify)
"""
if not user.activity:
return "Nothing"
AT = user.activity.type
a = user.activity
if str(AT) == "ActivityType.custom":
return "Status: {}".format(a)
if str(AT) == "ActivityType.playing":
string = "Playing {}".format(a.name)
try:
string += " ({})".format(a.details)
except Exception:
pass
return string
if str(AT) == "ActivityType.streaming":
return "Streaming {}".format(a.name)
if str(AT) == "ActivityType.listening":
try:
string = "Listening to `{}`".format(a.title)
if len(a.artists) > 1:
string += " by {}".format(join_list(string=a.artists))
else:
string += " by **{}**".format(a.artist)
except Exception:
string = "Listening to `{}`".format(a.name)
return string
if str(AT) == "ActivityType.watching":
return "Watching `{}`".format(a.name)
if str(AT) == "ActivityType.unknown":
return "Unknown"
def shard_of(shard_count: int, guildid: int):
"""
Calculate the shard number of a given guild.
"""
return (guildid >> 22) % shard_count if shard_count and shard_count > 0 else 0
def jumpto(guildid: int, channeldid: int, messageid: int):
"""
Build a jump link for a message given its location.
"""
return 'https://discord.com/channels/{}/{}/{}'.format(
guildid,
channeldid,
messageid
)
class DotDict(dict):
"""
Dict-type allowing dot access to keys.
"""
__getattr__ = dict.get
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__
class FieldEnum(str, Enum):
"""
String enum with description conforming to the ISQLQuote protocol.
Allows processing by psycog
"""
def __new__(cls, value, desc):
obj = str.__new__(cls, value)
obj._value_ = value
obj.desc = desc
return obj
def __repr__(self):
return '<%s.%s>' % (self.__class__.__name__, self.name)
def __bool__(self):
return True
def __conform__(self, proto):
return QuotedString(self.value)
def utc_now():
"""
Return the current timezone-aware utc timestamp.
"""
return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)
def multiple_replace(string, rep_dict):
if rep_dict:
pattern = re.compile(
"|".join([re.escape(k) for k in sorted(rep_dict, key=len, reverse=True)]),
flags=re.DOTALL
)
return pattern.sub(lambda x: str(rep_dict[x.group(0)]), string)
else:
return string

View File

@@ -0,0 +1,92 @@
import time
from cmdClient.lib import SafeCancellation
from cachetools import TTLCache
class BucketFull(Exception):
"""
Throw when a requested Bucket is already full
"""
pass
class BucketOverFull(BucketFull):
"""
Throw when a requested Bucket is overfull
"""
pass
class Bucket:
__slots__ = ('max_level', 'empty_time', 'leak_rate', '_level', '_last_checked', '_last_full')
def __init__(self, max_level, empty_time):
self.max_level = max_level
self.empty_time = empty_time
self.leak_rate = max_level / empty_time
self._level = 0
self._last_checked = time.time()
self._last_full = False
@property
def overfull(self):
self._leak()
return self._level > self.max_level
def _leak(self):
if self._level:
elapsed = time.time() - self._last_checked
self._level = max(0, self._level - (elapsed * self.leak_rate))
self._last_checked = time.time()
def request(self):
self._leak()
if self._level + 1 > self.max_level + 1:
raise BucketOverFull
elif self._level + 1 > self.max_level:
self._level += 1
if self._last_full:
raise BucketOverFull
else:
self._last_full = True
raise BucketFull
else:
self._last_full = False
self._level += 1
class RateLimit:
def __init__(self, max_level, empty_time, error=None, cache=TTLCache(1000, 60 * 60)):
self.max_level = max_level
self.empty_time = empty_time
self.error = error or "Too many requests, please slow down!"
self.buckets = cache
def request_for(self, key):
if not (bucket := self.buckets.get(key, None)):
bucket = self.buckets[key] = Bucket(self.max_level, self.empty_time)
try:
bucket.request()
except BucketOverFull:
raise SafeCancellation(details="Bucket overflow")
except BucketFull:
raise SafeCancellation(self.error, details="Bucket full")
def ward(self, member=True, key=None):
"""
Command ratelimit decorator.
"""
key = key or ((lambda ctx: (ctx.guild.id, ctx.author.id)) if member else (lambda ctx: ctx.author.id))
def decorator(func):
async def wrapper(ctx, *args, **kwargs):
self.request_for(key(ctx))
return await func(ctx, *args, **kwargs)
return wrapper
return decorator

View File

@@ -0,0 +1,427 @@
import asyncio
import discord
from LionContext import LionContext as Context
from cmdClient.lib import InvalidContext, UserCancelled, ResponseTimedOut, SafeCancellation
from . import interactive as _interactive
@Context.util
async def find_role(ctx, userstr, create=False, interactive=False, collection=None, allow_notfound=True):
"""
Find a guild role given a partial matching string,
allowing custom role collections and several behavioural switches.
Parameters
----------
userstr: str
String obtained from a user, expected to partially match a role in the collection.
The string will be tested against both the id and the name of the role.
create: bool
Whether to offer to create the role if it does not exist.
The bot will only offer to create the role if it has the `manage_channels` permission.
interactive: bool
Whether to offer the user a list of roles to choose from,
or pick the first matching role.
collection: List[Union[discord.Role, discord.Object]]
Collection of roles to search amongst.
If none, uses the guild role list.
allow_notfound: bool
Whether to return `None` when there are no matches, instead of raising `SafeCancellation`.
Overriden by `create`, if it is set.
Returns
-------
discord.Role:
If a valid role is found.
None:
If no valid role has been found.
Raises
------
cmdClient.lib.UserCancelled:
If the user cancels interactive role selection.
cmdClient.lib.ResponseTimedOut:
If the user fails to respond to interactive role selection within `60` seconds`
cmdClient.lib.SafeCancellation:
If `allow_notfound` is `False`, and the search returned no matches.
"""
# Handle invalid situations and input
if not ctx.guild:
raise InvalidContext("Attempt to use find_role outside of a guild.")
if userstr == "":
raise ValueError("User string passed to find_role was empty.")
# Create the collection to search from args or guild roles
collection = collection if collection is not None else ctx.guild.roles
# If the unser input was a number or possible role mention, get it out
userstr = userstr.strip()
roleid = userstr.strip('<#@&!> ')
roleid = int(roleid) if roleid.isdigit() else None
searchstr = userstr.lower()
# Find the role
role = None
# Check method to determine whether a role matches
def check(role):
return (role.id == roleid) or (searchstr in role.name.lower())
# Get list of matching roles
roles = list(filter(check, collection))
if len(roles) == 0:
# Nope
role = None
elif len(roles) == 1:
# Select our lucky winner
role = roles[0]
else:
# We have multiple matching roles!
if interactive:
# Interactive prompt with the list of roles, handle `Object`s
role_names = [
role.name if isinstance(role, discord.Role) else str(role.id) for role in roles
]
try:
selected = await ctx.selector(
"`{}` roles found matching `{}`!".format(len(roles), userstr),
role_names,
timeout=60
)
except UserCancelled:
raise UserCancelled("User cancelled role selection.") from None
except ResponseTimedOut:
raise ResponseTimedOut("Role selection timed out.") from None
role = roles[selected]
else:
# Just select the first one
role = roles[0]
# Handle non-existence of the role
if role is None:
msgstr = "Couldn't find a role matching `{}`!".format(userstr)
if create:
# Inform the user
msg = await ctx.error_reply(msgstr)
if ctx.guild.me.guild_permissions.manage_roles:
# Offer to create it
resp = await ctx.ask("Would you like to create this role?", timeout=30)
if resp:
# They accepted, create the role
# Before creation, check if the role name is too long
if len(userstr) > 100:
await ctx.error_reply("Could not create a role with a name over 100 characters long!")
else:
role = await ctx.guild.create_role(
name=userstr,
reason="Interactive role creation for {} (uid:{})".format(ctx.author, ctx.author.id)
)
await msg.delete()
await ctx.reply("You have created the role `{}`!".format(userstr))
# If we still don't have a role, cancel unless allow_notfound is set
if role is None and not allow_notfound:
raise SafeCancellation
elif not allow_notfound:
raise SafeCancellation(msgstr)
else:
await ctx.error_reply(msgstr)
return role
@Context.util
async def find_channel(ctx, userstr, interactive=False, collection=None, chan_type=None, type_name=None):
"""
Find a guild channel given a partial matching string,
allowing custom channel collections and several behavioural switches.
Parameters
----------
userstr: str
String obtained from a user, expected to partially match a channel in the collection.
The string will be tested against both the id and the name of the channel.
interactive: bool
Whether to offer the user a list of channels to choose from,
or pick the first matching channel.
collection: List(discord.Channel)
Collection of channels to search amongst.
If none, uses the full guild channel list.
chan_type: discord.ChannelType
Type of channel to restrict the collection to.
type_name: str
Optional name to use for the channel type if it is not found.
Used particularly with custom collections.
Returns
-------
discord.Channel:
If a valid channel is found.
None:
If no valid channel has been found.
Raises
------
cmdClient.lib.UserCancelled:
If the user cancels interactive channel selection.
cmdClient.lib.ResponseTimedOut:
If the user fails to respond to interactive channel selection within `60` seconds`
"""
# Handle invalid situations and input
if not ctx.guild:
raise InvalidContext("Attempt to use find_channel outside of a guild.")
if userstr == "":
raise ValueError("User string passed to find_channel was empty.")
# Create the collection to search from args or guild channels
collection = collection if collection else ctx.guild.channels
if chan_type is not None:
if chan_type == discord.ChannelType.text:
# Hack to support news channels as text channels
collection = [chan for chan in collection if isinstance(chan, discord.TextChannel)]
else:
collection = [chan for chan in collection if chan.type == chan_type]
# If the user input was a number or possible channel mention, extract it
chanid = userstr.strip('<#@&!>')
chanid = int(chanid) if chanid.isdigit() else None
searchstr = userstr.lower()
# Find the channel
chan = None
# Check method to determine whether a channel matches
def check(chan):
return (chan.id == chanid) or (searchstr in chan.name.lower())
# Get list of matching roles
channels = list(filter(check, collection))
if len(channels) == 0:
# Nope
chan = None
elif len(channels) == 1:
# Select our lucky winner
chan = channels[0]
else:
# We have multiple matching channels!
if interactive:
# Interactive prompt with the list of channels
chan_names = [chan.name for chan in channels]
try:
selected = await ctx.selector(
"`{}` channels found matching `{}`!".format(len(channels), userstr),
chan_names,
timeout=60
)
except UserCancelled:
raise UserCancelled("User cancelled channel selection.") from None
except ResponseTimedOut:
raise ResponseTimedOut("Channel selection timed out.") from None
chan = channels[selected]
else:
# Just select the first one
chan = channels[0]
if chan is None:
typestr = type_name
addendum = ""
if chan_type and not type_name:
chan_type_strings = {
discord.ChannelType.category: "category",
discord.ChannelType.text: "text channel",
discord.ChannelType.voice: "voice channel",
discord.ChannelType.stage_voice: "stage channel",
}
typestr = chan_type_strings.get(chan_type, None)
if typestr and chanid:
actual = ctx.guild.get_channel(chanid)
if actual and actual.type in chan_type_strings:
addendum = "\n{} appears to be a {} instead.".format(
actual.mention,
chan_type_strings[actual.type]
)
typestr = typestr or "channel"
await ctx.error_reply("Couldn't find a {} matching `{}`!{}".format(typestr, userstr, addendum))
return chan
@Context.util
async def find_member(ctx, userstr, interactive=False, collection=None, silent=False):
"""
Find a guild member given a partial matching string,
allowing custom member collections.
Parameters
----------
userstr: str
String obtained from a user, expected to partially match a member in the collection.
The string will be tested against both the userid, full user name and user nickname.
interactive: bool
Whether to offer the user a list of members to choose from,
or pick the first matching channel.
collection: List(discord.Member)
Collection of members to search amongst.
If none, uses the full guild member list.
silent: bool
Whether to reply with an error when there are no matches.
Returns
-------
discord.Member:
If a valid member is found.
None:
If no valid member has been found.
Raises
------
cmdClient.lib.UserCancelled:
If the user cancels interactive member selection.
cmdClient.lib.ResponseTimedOut:
If the user fails to respond to interactive member selection within `60` seconds`
"""
# Handle invalid situations and input
if not ctx.guild:
raise InvalidContext("Attempt to use find_member outside of a guild.")
if userstr == "":
raise ValueError("User string passed to find_member was empty.")
# Create the collection to search from args or guild members
collection = collection if collection else ctx.guild.members
# If the user input was a number or possible member mention, extract it
userid = userstr.strip('<#@&!>')
userid = int(userid) if userid.isdigit() else None
searchstr = userstr.lower()
# Find the member
member = None
# Check method to determine whether a member matches
def check(member):
return (
member.id == userid
or searchstr in member.display_name.lower()
or searchstr in str(member).lower()
)
# Get list of matching roles
members = list(filter(check, collection))
if len(members) == 0:
# Nope
member = None
elif len(members) == 1:
# Select our lucky winner
member = members[0]
else:
# We have multiple matching members!
if interactive:
# Interactive prompt with the list of members
member_names = [
"{} {}".format(
member.nick if member.nick else (member if members.count(member) > 1
else member.name),
("<{}>".format(member)) if member.nick else ""
) for member in members
]
try:
selected = await ctx.selector(
"`{}` members found matching `{}`!".format(len(members), userstr),
member_names,
timeout=60
)
except UserCancelled:
raise UserCancelled("User cancelled member selection.") from None
except ResponseTimedOut:
raise ResponseTimedOut("Member selection timed out.") from None
member = members[selected]
else:
# Just select the first one
member = members[0]
if member is None and not silent:
await ctx.error_reply("Couldn't find a member matching `{}`!".format(userstr))
return member
@Context.util
async def find_message(ctx, msgid, chlist=None, ignore=[]):
"""
Searches for the given message id in the guild channels.
Parameters
-------
msgid: int
The `id` of the message to search for.
chlist: Optional[List[discord.TextChannel]]
List of channels to search in.
If `None`, searches all the text channels that the `ctx.author` can read.
ignore: list
A list of channelids to explicitly ignore in the search.
Returns
-------
Optional[discord.Message]:
If a message is found, returns the message.
Otherwise, returns `None`.
"""
if not ctx.guild:
raise InvalidContext("Cannot use this seeker outside of a guild!")
msgid = int(msgid)
# Build the channel list to search
if chlist is None:
chlist = [ch for ch in ctx.guild.text_channels if ch.permissions_for(ctx.author).read_messages]
# Remove any channels we are ignoring
chlist = [ch for ch in chlist if ch.id not in ignore]
tasks = set()
i = 0
while True:
done = set((task for task in tasks if task.done()))
tasks = tasks.difference(done)
results = [task.result() for task in done]
result = next((result for result in results if result is not None), None)
if result:
[task.cancel() for task in tasks]
return result
if i < len(chlist):
task = asyncio.create_task(_search_in_channel(chlist[i], msgid))
tasks.add(task)
i += 1
elif len(tasks) == 0:
return None
await asyncio.sleep(0.1)
async def _search_in_channel(channel: discord.TextChannel, msgid: int):
if not isinstance(channel, discord.TextChannel):
return
try:
message = await channel.fetch_message(msgid)
except Exception:
return None
else:
return message

View File

@@ -0,0 +1,40 @@
from cmdClient import check
from cmdClient.checks import in_guild
from meta import client
from data import tables
def is_guild_admin(member):
if member.id in client.owners:
return True
# First check guild admin permissions
admin = member.guild_permissions.administrator
# Then check the admin role, if it is set
if not admin:
admin_role_id = tables.guild_config.fetch_or_create(member.guild.id).admin_role
admin = admin_role_id and (admin_role_id in (r.id for r in member.roles))
return admin
@check(
name="ADMIN",
msg=("You need to be a server admin to do this!"),
requires=[in_guild]
)
async def guild_admin(ctx, *args, **kwargs):
return is_guild_admin(ctx.author)
@check(
name="MODERATOR",
msg=("You need to be a server moderator to do this!"),
requires=[in_guild],
parents=(guild_admin,)
)
async def guild_moderator(ctx, *args, **kwargs):
modrole = ctx.guild_settings.mod_role.value
return (modrole and (modrole in ctx.author.roles))