rewrite: Initial rewrite skeleton.
Remove modules that will no longer be required. Move pending modules to pending-rewrite folders.
This commit is contained in:
88
bot/pending-rewrite/LionContext.py
Normal file
88
bot/pending-rewrite/LionContext.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import types
|
||||
|
||||
from cmdClient import Context
|
||||
from cmdClient.logger import log
|
||||
|
||||
|
||||
class LionContext(Context):
|
||||
"""
|
||||
Subclass to allow easy attachment of custom hooks and structure to contexts.
|
||||
"""
|
||||
__slots__ = ()
|
||||
|
||||
@classmethod
|
||||
def util(cls, util_func):
|
||||
"""
|
||||
Decorator to make a utility function available as a Context instance method.
|
||||
Extends the default Context method to add logging and to return the utility function.
|
||||
"""
|
||||
super().util(util_func)
|
||||
log(f"Attached context utility function: {util_func.__name__}")
|
||||
return util_func
|
||||
|
||||
@classmethod
|
||||
def wrappable_util(cls, util_func):
|
||||
"""
|
||||
Decorator to add a Wrappable utility function as a Context instance method.
|
||||
"""
|
||||
wrappable = Wrappable(util_func)
|
||||
super().util(wrappable)
|
||||
log(f"Attached wrappable context utility function: {util_func.__name__}")
|
||||
return wrappable
|
||||
|
||||
|
||||
class Wrappable:
|
||||
__slots__ = ('_func', 'wrappers')
|
||||
|
||||
def __init__(self, func):
|
||||
self._func = func
|
||||
self.wrappers = None
|
||||
|
||||
@property
|
||||
def __name__(self):
|
||||
return self._func.__name__
|
||||
|
||||
def add_wrapper(self, func, name=None):
|
||||
self.wrappers = self.wrappers or {}
|
||||
name = name or func.__name__
|
||||
self.wrappers[name] = func
|
||||
log(
|
||||
f"Added wrapper '{name}' to Wrappable '{self._func.__name__}'.",
|
||||
context="Wrapping"
|
||||
)
|
||||
|
||||
def remove_wrapper(self, name):
|
||||
if not self.wrappers or name not in self.wrappers:
|
||||
raise ValueError(
|
||||
f"Cannot remove non-existent wrapper '{name}' from Wrappable '{self._func.__name__}'"
|
||||
)
|
||||
self.wrappers.pop(name)
|
||||
log(
|
||||
f"Removed wrapper '{name}' from Wrappable '{self._func.__name__}'.",
|
||||
context="Wrapping"
|
||||
)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if self.wrappers:
|
||||
return self._wrapped(iter(self.wrappers.values()))(*args, **kwargs)
|
||||
else:
|
||||
return self._func(*args, **kwargs)
|
||||
|
||||
def _wrapped(self, iter_wraps):
|
||||
next_wrap = next(iter_wraps, None)
|
||||
if next_wrap:
|
||||
def _func(*args, **kwargs):
|
||||
return next_wrap(self._wrapped(iter_wraps), *args, **kwargs)
|
||||
else:
|
||||
_func = self._func
|
||||
return _func
|
||||
|
||||
def __get__(self, instance, cls=None):
|
||||
if instance is None:
|
||||
return self
|
||||
else:
|
||||
return types.MethodType(self, instance)
|
||||
|
||||
|
||||
# Override the original Context.reply with a wrappable utility
|
||||
reply = LionContext.wrappable_util(Context.reply)
|
||||
186
bot/pending-rewrite/LionModule.py
Normal file
186
bot/pending-rewrite/LionModule.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
import logging
|
||||
import discord
|
||||
|
||||
from cmdClient import Command, Module, FailedCheck
|
||||
from cmdClient.lib import SafeCancellation
|
||||
|
||||
from meta import log
|
||||
|
||||
|
||||
class LionCommand(Command):
|
||||
"""
|
||||
Subclass to allow easy attachment of custom hooks and structure to commands.
|
||||
"""
|
||||
allow_before_ready = False
|
||||
|
||||
|
||||
class LionModule(Module):
|
||||
"""
|
||||
Custom module for Lion systems.
|
||||
|
||||
Adds command wrappers and various event handlers.
|
||||
"""
|
||||
name = "Base Lion Module"
|
||||
|
||||
def __init__(self, name, baseCommand=LionCommand):
|
||||
super().__init__(name, baseCommand)
|
||||
|
||||
self.unload_tasks = []
|
||||
|
||||
def unload_task(self, func):
|
||||
"""
|
||||
Decorator adding an unload task for deactivating the module.
|
||||
Should sync unsaved transactions and finalise user interaction.
|
||||
If possible, should also remove attached data and handlers.
|
||||
"""
|
||||
self.unload_tasks.append(func)
|
||||
log("Adding unload task '{}'.".format(func.__name__), context=self.name)
|
||||
return func
|
||||
|
||||
async def unload(self, client):
|
||||
"""
|
||||
Run the unloading tasks.
|
||||
"""
|
||||
log("Unloading module.", context=self.name, post=False)
|
||||
for task in self.unload_tasks:
|
||||
log("Running unload task '{}'".format(task.__name__),
|
||||
context=self.name, post=False)
|
||||
await task(client)
|
||||
|
||||
async def launch(self, client):
|
||||
"""
|
||||
Launch hook.
|
||||
Executed in `client.on_ready`.
|
||||
Must set `ready` to `True`, otherwise all commands will hang.
|
||||
Overrides the parent launcher to not post the log as a discord message.
|
||||
"""
|
||||
if not self.ready:
|
||||
log("Running launch tasks.", context=self.name, post=False)
|
||||
|
||||
for task in self.launch_tasks:
|
||||
log("Running launch task '{}'.".format(task.__name__),
|
||||
context=self.name, post=False)
|
||||
await task(client)
|
||||
|
||||
self.ready = True
|
||||
else:
|
||||
log("Already launched, skipping launch.", context=self.name, post=False)
|
||||
|
||||
async def pre_command(self, ctx):
|
||||
"""
|
||||
Lion pre-command hook.
|
||||
"""
|
||||
if not self.ready and not ctx.cmd.allow_before_ready:
|
||||
try:
|
||||
await ctx.embed_reply(
|
||||
"I am currently restarting! Please try again in a couple of minutes."
|
||||
)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
raise SafeCancellation(details="Module '{}' is not ready.".format(self.name))
|
||||
|
||||
# Check global user blacklist
|
||||
if ctx.author.id in ctx.client.user_blacklist():
|
||||
raise SafeCancellation(details='User is blacklisted.')
|
||||
|
||||
if ctx.guild:
|
||||
# Check that the channel and guild still exists
|
||||
if not ctx.client.get_guild(ctx.guild.id) or not ctx.guild.get_channel(ctx.ch.id):
|
||||
raise SafeCancellation(details='Command channel is no longer reachable.')
|
||||
|
||||
# Check global guild blacklist
|
||||
if ctx.guild.id in ctx.client.guild_blacklist():
|
||||
raise SafeCancellation(details='Guild is blacklisted.')
|
||||
|
||||
# Check guild's own member blacklist
|
||||
if ctx.author.id in ctx.client.objects['ignored_members'][ctx.guild.id]:
|
||||
raise SafeCancellation(details='User is ignored in this guild.')
|
||||
|
||||
# Check channel permissions are sane
|
||||
if not ctx.ch.permissions_for(ctx.guild.me).send_messages:
|
||||
raise SafeCancellation(details='I cannot send messages in this channel.')
|
||||
if not ctx.ch.permissions_for(ctx.guild.me).embed_links:
|
||||
await ctx.reply("I need permission to send embeds in this channel before I can run any commands!")
|
||||
raise SafeCancellation(details='I cannot send embeds in this channel.')
|
||||
|
||||
# Ensure Lion exists and cached data is up to date
|
||||
ctx.alion.update_saved_data(ctx.author)
|
||||
|
||||
# Start typing
|
||||
await ctx.ch.trigger_typing()
|
||||
|
||||
async def on_exception(self, ctx, exception):
|
||||
try:
|
||||
raise exception
|
||||
except (FailedCheck, SafeCancellation):
|
||||
# cmdClient generated and handled exceptions
|
||||
raise exception
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||
# Standard command and task exceptions, cmdClient will also handle these
|
||||
raise exception
|
||||
except discord.Forbidden:
|
||||
# Unknown uncaught Forbidden
|
||||
try:
|
||||
# Attempt a general error reply
|
||||
await ctx.reply("I don't have enough channel or server permissions to complete that command here!")
|
||||
except discord.Forbidden:
|
||||
# We can't send anything at all. Exit quietly, but log.
|
||||
full_traceback = traceback.format_exc()
|
||||
log(("Caught an unhandled 'Forbidden' while "
|
||||
"executing command '{cmdname}' from module '{module}' "
|
||||
"from user '{message.author}' (uid:{message.author.id}) "
|
||||
"in guild '{message.guild}' (gid:{guildid}) "
|
||||
"in channel '{message.channel}' (cid:{message.channel.id}).\n"
|
||||
"Message Content:\n"
|
||||
"{content}\n"
|
||||
"{traceback}\n\n"
|
||||
"{flat_ctx}").format(
|
||||
cmdname=ctx.cmd.name,
|
||||
module=ctx.cmd.module.name,
|
||||
message=ctx.msg,
|
||||
guildid=ctx.guild.id if ctx.guild else None,
|
||||
content='\n'.join('\t' + line for line in ctx.msg.content.splitlines()),
|
||||
traceback=full_traceback,
|
||||
flat_ctx=ctx.flatten()
|
||||
),
|
||||
context="mid:{}".format(ctx.msg.id),
|
||||
level=logging.WARNING)
|
||||
except Exception as e:
|
||||
# Unknown exception!
|
||||
full_traceback = traceback.format_exc()
|
||||
only_error = "".join(traceback.TracebackException.from_exception(e).format_exception_only())
|
||||
|
||||
log(("Caught an unhandled exception while "
|
||||
"executing command '{cmdname}' from module '{module}' "
|
||||
"from user '{message.author}' (uid:{message.author.id}) "
|
||||
"in guild '{message.guild}' (gid:{guildid}) "
|
||||
"in channel '{message.channel}' (cid:{message.channel.id}).\n"
|
||||
"Message Content:\n"
|
||||
"{content}\n"
|
||||
"{traceback}\n\n"
|
||||
"{flat_ctx}").format(
|
||||
cmdname=ctx.cmd.name,
|
||||
module=ctx.cmd.module.name,
|
||||
message=ctx.msg,
|
||||
guildid=ctx.guild.id if ctx.guild else None,
|
||||
content='\n'.join('\t' + line for line in ctx.msg.content.splitlines()),
|
||||
traceback=full_traceback,
|
||||
flat_ctx=ctx.flatten()
|
||||
),
|
||||
context="mid:{}".format(ctx.msg.id),
|
||||
level=logging.ERROR)
|
||||
|
||||
error_embed = discord.Embed(title="Something went wrong!")
|
||||
error_embed.description = (
|
||||
"An unexpected error occurred while processing your command!\n"
|
||||
"Our development team has been notified, and the issue should be fixed soon.\n"
|
||||
)
|
||||
if logging.getLogger().getEffectiveLevel() < logging.INFO:
|
||||
error_embed.add_field(
|
||||
name="Exception",
|
||||
value="`{}`".format(only_error)
|
||||
)
|
||||
|
||||
await ctx.reply(embed=error_embed)
|
||||
5
bot/pending-rewrite/core/__init__.py
Normal file
5
bot/pending-rewrite/core/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from . import data # noqa
|
||||
|
||||
from .module import module
|
||||
from .lion import Lion
|
||||
from . import blacklists
|
||||
92
bot/pending-rewrite/core/blacklists.py
Normal file
92
bot/pending-rewrite/core/blacklists.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""
|
||||
Guild, user, and member blacklists.
|
||||
"""
|
||||
from collections import defaultdict
|
||||
import cachetools.func
|
||||
|
||||
from data import tables
|
||||
from meta import client
|
||||
|
||||
from .module import module
|
||||
|
||||
|
||||
@cachetools.func.ttl_cache(ttl=300)
|
||||
def guild_blacklist():
|
||||
"""
|
||||
Get the guild blacklist
|
||||
"""
|
||||
rows = tables.global_guild_blacklist.select_where()
|
||||
return set(row['guildid'] for row in rows)
|
||||
|
||||
|
||||
@cachetools.func.ttl_cache(ttl=300)
|
||||
def user_blacklist():
|
||||
"""
|
||||
Get the global user blacklist.
|
||||
"""
|
||||
rows = tables.global_user_blacklist.select_where()
|
||||
return set(row['userid'] for row in rows)
|
||||
|
||||
|
||||
@module.init_task
|
||||
def load_ignored_members(client):
|
||||
"""
|
||||
Load the ignored members.
|
||||
"""
|
||||
ignored = defaultdict(set)
|
||||
rows = tables.ignored_members.select_where()
|
||||
|
||||
for row in rows:
|
||||
ignored[row['guildid']].add(row['userid'])
|
||||
|
||||
client.objects['ignored_members'] = ignored
|
||||
|
||||
if rows:
|
||||
client.log(
|
||||
"Loaded {} ignored members across {} guilds.".format(
|
||||
len(rows),
|
||||
len(ignored)
|
||||
),
|
||||
context="MEMBER_BLACKLIST"
|
||||
)
|
||||
|
||||
|
||||
@module.init_task
|
||||
def attach_client_blacklists(client):
|
||||
client.guild_blacklist = guild_blacklist
|
||||
client.user_blacklist = user_blacklist
|
||||
|
||||
|
||||
@module.launch_task
|
||||
async def leave_blacklisted_guilds(client):
|
||||
"""
|
||||
Launch task to leave any blacklisted guilds we are in.
|
||||
"""
|
||||
to_leave = [
|
||||
guild for guild in client.guilds
|
||||
if guild.id in guild_blacklist()
|
||||
]
|
||||
|
||||
for guild in to_leave:
|
||||
await guild.leave()
|
||||
|
||||
if to_leave:
|
||||
client.log(
|
||||
"Left {} blacklisted guilds!".format(len(to_leave)),
|
||||
context="GUILD_BLACKLIST"
|
||||
)
|
||||
|
||||
|
||||
@client.add_after_event('guild_join')
|
||||
async def check_guild_blacklist(client, guild):
|
||||
"""
|
||||
Guild join event handler to check whether the guild is blacklisted.
|
||||
If so, leaves the guild.
|
||||
"""
|
||||
# First refresh the blacklist cache
|
||||
if guild.id in guild_blacklist():
|
||||
await guild.leave()
|
||||
client.log(
|
||||
"Automatically left blacklisted guild '{}' (gid:{}) upon join.".format(guild.name, guild.id),
|
||||
context="GUILD_BLACKLIST"
|
||||
)
|
||||
128
bot/pending-rewrite/core/data.py
Normal file
128
bot/pending-rewrite/core/data.py
Normal file
@@ -0,0 +1,128 @@
|
||||
from psycopg2.extras import execute_values
|
||||
|
||||
from cachetools import TTLCache
|
||||
from data import RowTable, Table
|
||||
|
||||
|
||||
meta = RowTable(
|
||||
'AppData',
|
||||
('appid', 'last_study_badge_scan'),
|
||||
'appid',
|
||||
attach_as='meta',
|
||||
)
|
||||
|
||||
# TODO: Consider converting to RowTable for per-shard config caching
|
||||
app_config = Table('AppConfig')
|
||||
|
||||
|
||||
user_config = RowTable(
|
||||
'user_config',
|
||||
('userid', 'timezone', 'topgg_vote_reminder', 'avatar_hash', 'gems'),
|
||||
'userid',
|
||||
cache=TTLCache(5000, ttl=60*5)
|
||||
)
|
||||
|
||||
|
||||
guild_config = RowTable(
|
||||
'guild_config',
|
||||
('guildid', 'admin_role', 'mod_role', 'event_log_channel', 'mod_log_channel', 'alert_channel',
|
||||
'studyban_role', 'max_study_bans',
|
||||
'min_workout_length', 'workout_reward',
|
||||
'max_tasks', 'task_reward', 'task_reward_limit',
|
||||
'study_hourly_reward', 'study_hourly_live_bonus', 'daily_study_cap',
|
||||
'renting_price', 'renting_category', 'renting_cap', 'renting_role', 'renting_sync_perms',
|
||||
'accountability_category', 'accountability_lobby', 'accountability_bonus',
|
||||
'accountability_reward', 'accountability_price',
|
||||
'video_studyban', 'video_grace_period',
|
||||
'greeting_channel', 'greeting_message', 'returning_message',
|
||||
'starting_funds', 'persist_roles',
|
||||
'pomodoro_channel',
|
||||
'name'),
|
||||
'guildid',
|
||||
cache=TTLCache(2500, ttl=60*5)
|
||||
)
|
||||
|
||||
unranked_roles = Table('unranked_roles')
|
||||
|
||||
donator_roles = Table('donator_roles')
|
||||
|
||||
|
||||
lions = RowTable(
|
||||
'members',
|
||||
('guildid', 'userid',
|
||||
'tracked_time', 'coins',
|
||||
'workout_count', 'last_workout_start',
|
||||
'revision_mute_count',
|
||||
'last_study_badgeid',
|
||||
'video_warned',
|
||||
'display_name',
|
||||
'_timestamp'
|
||||
),
|
||||
('guildid', 'userid'),
|
||||
cache=TTLCache(5000, ttl=60*5),
|
||||
attach_as='lions'
|
||||
)
|
||||
|
||||
|
||||
@lions.save_query
|
||||
def add_pending(pending):
|
||||
"""
|
||||
pending:
|
||||
List of tuples of the form `(guildid, userid, pending_coins)`.
|
||||
"""
|
||||
with lions.conn:
|
||||
cursor = lions.conn.cursor()
|
||||
data = execute_values(
|
||||
cursor,
|
||||
"""
|
||||
UPDATE members
|
||||
SET
|
||||
coins = LEAST(coins + t.coin_diff, 2147483647)
|
||||
FROM
|
||||
(VALUES %s)
|
||||
AS
|
||||
t (guildid, userid, coin_diff)
|
||||
WHERE
|
||||
members.guildid = t.guildid
|
||||
AND
|
||||
members.userid = t.userid
|
||||
RETURNING *
|
||||
""",
|
||||
pending,
|
||||
fetch=True
|
||||
)
|
||||
return lions._make_rows(*data)
|
||||
|
||||
|
||||
lion_ranks = Table('member_ranks', attach_as='lion_ranks')
|
||||
|
||||
|
||||
@lions.save_query
|
||||
def get_member_rank(guildid, userid, untracked):
|
||||
"""
|
||||
Get the time and coin ranking for the given member, ignoring the provided untracked members.
|
||||
"""
|
||||
with lions.conn as conn:
|
||||
with conn.cursor() as curs:
|
||||
curs.execute(
|
||||
"""
|
||||
SELECT
|
||||
time_rank, coin_rank
|
||||
FROM (
|
||||
SELECT
|
||||
userid,
|
||||
row_number() OVER (ORDER BY total_tracked_time DESC, userid ASC) AS time_rank,
|
||||
row_number() OVER (ORDER BY total_coins DESC, userid ASC) AS coin_rank
|
||||
FROM members_totals
|
||||
WHERE
|
||||
guildid=%s AND userid NOT IN %s
|
||||
) AS guild_ranks WHERE userid=%s
|
||||
""",
|
||||
(guildid, tuple(untracked), userid)
|
||||
)
|
||||
return curs.fetchone() or (None, None)
|
||||
|
||||
|
||||
global_guild_blacklist = Table('global_guild_blacklist')
|
||||
global_user_blacklist = Table('global_user_blacklist')
|
||||
ignored_members = Table('ignored_members')
|
||||
347
bot/pending-rewrite/core/lion.py
Normal file
347
bot/pending-rewrite/core/lion.py
Normal file
@@ -0,0 +1,347 @@
|
||||
import pytz
|
||||
import discord
|
||||
from functools import reduce
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from meta import client
|
||||
from data import tables as tb
|
||||
from settings import UserSettings, GuildSettings
|
||||
|
||||
from LionContext import LionContext
|
||||
|
||||
|
||||
class Lion:
|
||||
"""
|
||||
Class representing a guild Member.
|
||||
Mostly acts as a transparent interface to the corresponding Row,
|
||||
but also adds some transaction caching logic to `coins` and `tracked_time`.
|
||||
"""
|
||||
__slots__ = ('guildid', 'userid', '_pending_coins', '_member')
|
||||
|
||||
# Members with pending transactions
|
||||
_pending = {} # userid -> User
|
||||
|
||||
# Lion cache. Currently lions don't expire
|
||||
_lions = {} # (guildid, userid) -> Lion
|
||||
|
||||
# Extra methods supplying an economy bonus
|
||||
_economy_bonuses = []
|
||||
|
||||
def __init__(self, guildid, userid):
|
||||
self.guildid = guildid
|
||||
self.userid = userid
|
||||
|
||||
self._pending_coins = 0
|
||||
|
||||
self._member = None
|
||||
|
||||
self._lions[self.key] = self
|
||||
|
||||
@classmethod
|
||||
def fetch(cls, guildid, userid):
|
||||
"""
|
||||
Fetch a Lion with the given member.
|
||||
If they don't exist, creates them.
|
||||
If possible, retrieves the user from the user cache.
|
||||
"""
|
||||
key = (guildid, userid)
|
||||
if key in cls._lions:
|
||||
return cls._lions[key]
|
||||
else:
|
||||
# TODO: Debug log
|
||||
lion = tb.lions.fetch(key)
|
||||
if not lion:
|
||||
tb.user_config.fetch_or_create(userid)
|
||||
tb.guild_config.fetch_or_create(guildid)
|
||||
tb.lions.create_row(
|
||||
guildid=guildid,
|
||||
userid=userid,
|
||||
coins=GuildSettings(guildid).starting_funds.value
|
||||
)
|
||||
return cls(guildid, userid)
|
||||
|
||||
@property
|
||||
def key(self):
|
||||
return (self.guildid, self.userid)
|
||||
|
||||
@property
|
||||
def guild(self) -> discord.Guild:
|
||||
return client.get_guild(self.guildid)
|
||||
|
||||
@property
|
||||
def member(self) -> discord.Member:
|
||||
"""
|
||||
The discord `Member` corresponding to this user.
|
||||
May be `None` if the member is no longer in the guild or the caches aren't populated.
|
||||
Not guaranteed to be `None` if the member is not in the guild.
|
||||
"""
|
||||
if self._member is None:
|
||||
guild = client.get_guild(self.guildid)
|
||||
if guild:
|
||||
self._member = guild.get_member(self.userid)
|
||||
return self._member
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
"""
|
||||
The Row corresponding to this member.
|
||||
"""
|
||||
return tb.lions.fetch(self.key)
|
||||
|
||||
@property
|
||||
def user_data(self):
|
||||
"""
|
||||
The Row corresponding to this user.
|
||||
"""
|
||||
return tb.user_config.fetch_or_create(self.userid)
|
||||
|
||||
@property
|
||||
def guild_data(self):
|
||||
"""
|
||||
The Row corresponding to this guild.
|
||||
"""
|
||||
return tb.guild_config.fetch_or_create(self.guildid)
|
||||
|
||||
@property
|
||||
def settings(self):
|
||||
"""
|
||||
The UserSettings interface for this member.
|
||||
"""
|
||||
return UserSettings(self.userid)
|
||||
|
||||
@property
|
||||
def guild_settings(self):
|
||||
"""
|
||||
The GuildSettings interface for this member.
|
||||
"""
|
||||
return GuildSettings(self.guildid)
|
||||
|
||||
@property
|
||||
def ctx(self) -> LionContext:
|
||||
"""
|
||||
Manufacture a `LionContext` with the lion member as an author.
|
||||
Useful for accessing member context utilities.
|
||||
Be aware that `author` may be `None` if the member was not cached.
|
||||
"""
|
||||
return LionContext(client, guild=self.guild, author=self.member)
|
||||
|
||||
@property
|
||||
def time(self):
|
||||
"""
|
||||
Amount of time the user has spent studying, accounting for a current session.
|
||||
"""
|
||||
# Base time from cached member data
|
||||
time = self.data.tracked_time
|
||||
|
||||
# Add current session time if it exists
|
||||
if session := self.session:
|
||||
time += session.duration
|
||||
|
||||
return int(time)
|
||||
|
||||
@property
|
||||
def coins(self):
|
||||
"""
|
||||
Number of coins the user has, accounting for the pending value and current session.
|
||||
"""
|
||||
# Base coin amount from cached member data
|
||||
coins = self.data.coins
|
||||
|
||||
# Add pending coin amount
|
||||
coins += self._pending_coins
|
||||
|
||||
# Add current session coins if applicable
|
||||
if session := self.session:
|
||||
coins += session.coins_earned
|
||||
|
||||
return int(coins)
|
||||
|
||||
@property
|
||||
def economy_bonus(self):
|
||||
"""
|
||||
Economy multiplier
|
||||
"""
|
||||
return reduce(
|
||||
lambda x, y: x * y,
|
||||
[func(self) for func in self._economy_bonuses]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def register_economy_bonus(cls, func):
|
||||
cls._economy_bonuses.append(func)
|
||||
|
||||
@classmethod
|
||||
def unregister_economy_bonus(cls, func):
|
||||
cls._economy_bonuses.remove(func)
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
"""
|
||||
The current study session the user is in, if any.
|
||||
"""
|
||||
if 'sessions' not in client.objects:
|
||||
raise ValueError("Cannot retrieve session before Study module is initialised!")
|
||||
return client.objects['sessions'][self.guildid].get(self.userid, None)
|
||||
|
||||
@property
|
||||
def timezone(self):
|
||||
"""
|
||||
The user's configured timezone.
|
||||
Shortcut to `Lion.settings.timezone.value`.
|
||||
"""
|
||||
return self.settings.timezone.value
|
||||
|
||||
@property
|
||||
def day_start(self):
|
||||
"""
|
||||
A timezone aware datetime representing the start of the user's day (in their configured timezone).
|
||||
NOTE: This might not be accurate over DST boundaries.
|
||||
"""
|
||||
now = datetime.now(tz=self.timezone)
|
||||
return now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
@property
|
||||
def day_timestamp(self):
|
||||
"""
|
||||
EPOCH timestamp representing the current day for the user.
|
||||
NOTE: This is the timestamp of the start of the current UTC day with the same date as the user's day.
|
||||
This is *not* the start of the current user's day, either in UTC or their own timezone.
|
||||
This may also not be the start of the current day in UTC (consider 23:00 for a user in UTC-2).
|
||||
"""
|
||||
now = datetime.now(tz=self.timezone)
|
||||
day_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
return int(day_start.replace(tzinfo=pytz.utc).timestamp())
|
||||
|
||||
@property
|
||||
def week_timestamp(self):
|
||||
"""
|
||||
EPOCH timestamp representing the current week for the user.
|
||||
"""
|
||||
now = datetime.now(tz=self.timezone)
|
||||
day_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
week_start = day_start - timedelta(days=day_start.weekday())
|
||||
return int(week_start.replace(tzinfo=pytz.utc).timestamp())
|
||||
|
||||
@property
|
||||
def month_timestamp(self):
|
||||
"""
|
||||
EPOCH timestamp representing the current month for the user.
|
||||
"""
|
||||
now = datetime.now(tz=self.timezone)
|
||||
month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
return int(month_start.replace(tzinfo=pytz.utc).timestamp())
|
||||
|
||||
@property
|
||||
def remaining_in_day(self):
|
||||
return ((self.day_start + timedelta(days=1)) - datetime.now(self.timezone)).total_seconds()
|
||||
|
||||
@property
|
||||
def studied_today(self):
|
||||
"""
|
||||
The amount of time, in seconds, that the member has studied today.
|
||||
Extracted from the session history.
|
||||
"""
|
||||
return tb.session_history.queries.study_time_since(self.guildid, self.userid, self.day_start)
|
||||
|
||||
@property
|
||||
def remaining_study_today(self):
|
||||
"""
|
||||
Maximum remaining time (in seconds) this member can study today.
|
||||
|
||||
May not account for DST boundaries and leap seconds.
|
||||
"""
|
||||
studied_today = self.studied_today
|
||||
study_cap = self.guild_settings.daily_study_cap.value
|
||||
|
||||
remaining_in_day = self.remaining_in_day
|
||||
if remaining_in_day >= (study_cap - studied_today):
|
||||
remaining = study_cap - studied_today
|
||||
else:
|
||||
remaining = remaining_in_day + study_cap
|
||||
|
||||
return remaining
|
||||
|
||||
@property
|
||||
def profile_tags(self):
|
||||
"""
|
||||
Returns a list of profile tags, or the default tags.
|
||||
"""
|
||||
tags = tb.profile_tags.queries.get_tags_for(self.guildid, self.userid)
|
||||
prefix = self.ctx.best_prefix
|
||||
return tags or [
|
||||
f"Use {prefix}setprofile",
|
||||
"and add your tags",
|
||||
"to this section",
|
||||
f"See {prefix}help setprofile for more"
|
||||
]
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""
|
||||
Returns the best local name possible.
|
||||
"""
|
||||
if self.member:
|
||||
name = self.member.display_name
|
||||
elif self.data.display_name:
|
||||
name = self.data.display_name
|
||||
else:
|
||||
name = str(self.userid)
|
||||
|
||||
return name
|
||||
|
||||
def update_saved_data(self, member: discord.Member):
|
||||
"""
|
||||
Update the stored discord data from the givem member.
|
||||
Intended to be used when we get member data from events that may not be available in cache.
|
||||
"""
|
||||
if self.guild_data.name != member.guild.name:
|
||||
self.guild_data.name = member.guild.name
|
||||
if self.user_data.avatar_hash != member.avatar:
|
||||
self.user_data.avatar_hash = member.avatar
|
||||
if self.data.display_name != member.display_name:
|
||||
self.data.display_name = member.display_name
|
||||
|
||||
def localize(self, naive_utc_dt):
|
||||
"""
|
||||
Localise the provided naive UTC datetime into the user's timezone.
|
||||
"""
|
||||
timezone = self.settings.timezone.value
|
||||
return naive_utc_dt.replace(tzinfo=pytz.UTC).astimezone(timezone)
|
||||
|
||||
def addCoins(self, amount, flush=True, bonus=False):
|
||||
"""
|
||||
Add coins to the user, optionally store the transaction in pending.
|
||||
"""
|
||||
self._pending_coins += amount * (self.economy_bonus if bonus else 1)
|
||||
self._pending[self.key] = self
|
||||
if flush:
|
||||
self.flush()
|
||||
|
||||
def flush(self):
|
||||
"""
|
||||
Flush any pending transactions to the database.
|
||||
"""
|
||||
self.sync(self)
|
||||
|
||||
@classmethod
|
||||
def sync(cls, *lions):
|
||||
"""
|
||||
Flush pending transactions to the database.
|
||||
Also refreshes the Row cache for updated lions.
|
||||
"""
|
||||
lions = lions or list(cls._pending.values())
|
||||
|
||||
if lions:
|
||||
# Build userid to pending coin map
|
||||
pending = [
|
||||
(lion.guildid, lion.userid, int(lion._pending_coins))
|
||||
for lion in lions
|
||||
]
|
||||
|
||||
# Write to database
|
||||
tb.lions.queries.add_pending(pending)
|
||||
|
||||
# Cleanup pending users
|
||||
for lion in lions:
|
||||
lion._pending_coins -= int(lion._pending_coins)
|
||||
cls._pending.pop(lion.key, None)
|
||||
80
bot/pending-rewrite/core/module.py
Normal file
80
bot/pending-rewrite/core/module.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from meta import client, conf
|
||||
from settings import GuildSettings, UserSettings
|
||||
|
||||
from LionModule import LionModule
|
||||
|
||||
from .lion import Lion
|
||||
|
||||
|
||||
module = LionModule("Core")
|
||||
|
||||
|
||||
async def _lion_sync_loop():
|
||||
while True:
|
||||
while not client.is_ready():
|
||||
await asyncio.sleep(1)
|
||||
|
||||
client.log(
|
||||
"Running lion data sync.",
|
||||
context="CORE",
|
||||
level=logging.DEBUG,
|
||||
post=False
|
||||
)
|
||||
|
||||
Lion.sync()
|
||||
await asyncio.sleep(conf.bot.getint("lion_sync_period"))
|
||||
|
||||
|
||||
@module.init_task
|
||||
def setting_initialisation(client):
|
||||
"""
|
||||
Execute all Setting initialisation tasks from GuildSettings and UserSettings.
|
||||
"""
|
||||
for setting in GuildSettings.settings.values():
|
||||
setting.init_task(client)
|
||||
|
||||
for setting in UserSettings.settings.values():
|
||||
setting.init_task(client)
|
||||
|
||||
|
||||
@module.launch_task
|
||||
async def preload_guild_configuration(client):
|
||||
"""
|
||||
Loads the plain guild configuration for all guilds the client is part of into data.
|
||||
"""
|
||||
guildids = [guild.id for guild in client.guilds]
|
||||
if guildids:
|
||||
rows = client.data.guild_config.fetch_rows_where(guildid=guildids)
|
||||
client.log(
|
||||
"Preloaded guild configuration for {} guilds.".format(len(rows)),
|
||||
context="CORE_LOADING"
|
||||
)
|
||||
|
||||
|
||||
@module.launch_task
|
||||
async def preload_studying_members(client):
|
||||
"""
|
||||
Loads the member data for all members who are currently in voice channels.
|
||||
"""
|
||||
userids = list(set(member.id for guild in client.guilds for ch in guild.voice_channels for member in ch.members))
|
||||
if userids:
|
||||
users = client.data.user_config.fetch_rows_where(userid=userids)
|
||||
members = client.data.lions.fetch_rows_where(userid=userids)
|
||||
client.log(
|
||||
"Preloaded data for {} user with {} members.".format(len(users), len(members)),
|
||||
context="CORE_LOADING"
|
||||
)
|
||||
|
||||
|
||||
# Removing the sync loop in favour of the studybadge sync.
|
||||
# @module.launch_task
|
||||
# async def launch_lion_sync_loop(client):
|
||||
# asyncio.create_task(_lion_sync_loop())
|
||||
|
||||
|
||||
@module.unload_task
|
||||
async def final_lion_sync(client):
|
||||
Lion.sync()
|
||||
9
bot/pending-rewrite/dev_main.py
Normal file
9
bot/pending-rewrite/dev_main.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import logging
|
||||
import meta
|
||||
|
||||
meta.logger.setLevel(logging.DEBUG)
|
||||
logging.getLogger("discord").setLevel(logging.INFO)
|
||||
|
||||
from utils import interactive # noqa
|
||||
|
||||
import main # noqa
|
||||
28
bot/pending-rewrite/main.py
Normal file
28
bot/pending-rewrite/main.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from meta import client, conf, log, sharding
|
||||
|
||||
from data import tables
|
||||
|
||||
import core # noqa
|
||||
|
||||
# Note: This MUST be imported after core, due to table definition orders
|
||||
from settings import AppSettings
|
||||
|
||||
import modules # noqa
|
||||
|
||||
# Load and attach app specific data
|
||||
if sharding.sharded:
|
||||
appname = f"{conf.bot['data_appid']}_{sharding.shard_count}_{sharding.shard_number}"
|
||||
else:
|
||||
appname = conf.bot['data_appid']
|
||||
client.appdata = core.data.meta.fetch_or_create(appname)
|
||||
|
||||
client.data = tables
|
||||
|
||||
client.settings = AppSettings(conf.bot['data_appid'])
|
||||
|
||||
# Initialise all modules
|
||||
client.initialise_modules()
|
||||
|
||||
# Log readyness and execute
|
||||
log("Initial setup complete, logging in", context='SETUP')
|
||||
client.run(conf.bot['TOKEN'])
|
||||
6
bot/pending-rewrite/settings/__init__.py
Normal file
6
bot/pending-rewrite/settings/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .base import * # noqa
|
||||
from .setting_types import * # noqa
|
||||
|
||||
from .user_settings import UserSettings, UserSetting # noqa
|
||||
from .guild_settings import GuildSettings, GuildSetting # noqa
|
||||
from .app_settings import AppSettings
|
||||
5
bot/pending-rewrite/settings/app_settings.py
Normal file
5
bot/pending-rewrite/settings/app_settings.py
Normal file
@@ -0,0 +1,5 @@
|
||||
import settings
|
||||
from utils.lib import DotDict
|
||||
|
||||
class AppSettings(settings.ObjectSettings):
|
||||
settings = DotDict()
|
||||
514
bot/pending-rewrite/settings/base.py
Normal file
514
bot/pending-rewrite/settings/base.py
Normal file
@@ -0,0 +1,514 @@
|
||||
import json
|
||||
import discord
|
||||
from cmdClient.cmdClient import cmdClient
|
||||
from cmdClient.lib import SafeCancellation
|
||||
from cmdClient.Check import Check
|
||||
|
||||
from utils.lib import prop_tabulate, DotDict
|
||||
|
||||
from LionContext import LionContext as Context
|
||||
|
||||
from meta import client
|
||||
from data import Table, RowTable
|
||||
|
||||
|
||||
class Setting:
|
||||
"""
|
||||
Abstract base class describing a stored configuration setting.
|
||||
A setting consists of logic to load the setting from storage,
|
||||
present it in a readable form, understand user entered values,
|
||||
and write it again in storage.
|
||||
Additionally, the setting has attributes attached describing
|
||||
the setting in a user-friendly manner for display purposes.
|
||||
"""
|
||||
attr_name: str = None # Internal attribute name for the setting
|
||||
_default: ... = None # Default data value for the setting.. this may be None if the setting overrides 'default'.
|
||||
|
||||
write_ward: Check = None # Check that must be passed to write the setting. Not implemented internally.
|
||||
|
||||
# Configuration interface descriptions
|
||||
display_name: str = None # User readable name of the setting
|
||||
desc: str = None # User readable brief description of the setting
|
||||
long_desc: str = None # User readable long description of the setting
|
||||
accepts: str = None # User readable description of the acceptable values
|
||||
|
||||
def __init__(self, id, data: ..., **kwargs):
|
||||
self.client: cmdClient = client
|
||||
self.id = id
|
||||
self._data = data
|
||||
|
||||
# Configuration embeds
|
||||
@property
|
||||
def embed(self):
|
||||
"""
|
||||
Discord Embed showing an information summary about the setting.
|
||||
"""
|
||||
embed = discord.Embed(
|
||||
title="Configuration options for `{}`".format(self.display_name),
|
||||
)
|
||||
fields = ("Current value", "Default value", "Accepted input")
|
||||
values = (self.formatted or "Not Set",
|
||||
self._format_data(self.id, self.default) or "None",
|
||||
self.accepts)
|
||||
table = prop_tabulate(fields, values)
|
||||
embed.description = "{}\n{}".format(self.long_desc.format(self=self, client=self.client), table)
|
||||
return embed
|
||||
|
||||
async def widget(self, ctx: Context, **kwargs):
|
||||
"""
|
||||
Show the setting widget for this setting.
|
||||
By default this displays the setting embed.
|
||||
Settings may override this if they need more complex widget context or logic.
|
||||
"""
|
||||
return await ctx.reply(embed=self.embed)
|
||||
|
||||
@property
|
||||
def summary(self):
|
||||
"""
|
||||
Formatted summary of the data.
|
||||
May be implemented in `_format_data(..., summary=True, ...)` or overidden.
|
||||
"""
|
||||
return self._format_data(self.id, self.data, summary=True)
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
"""
|
||||
Response message sent when the setting has successfully been updated.
|
||||
"""
|
||||
return "Setting Updated!"
|
||||
|
||||
# Instance generation
|
||||
@classmethod
|
||||
def get(cls, id: int, **kwargs):
|
||||
"""
|
||||
Return a setting instance initialised from the stored value.
|
||||
"""
|
||||
data = cls._reader(id, **kwargs)
|
||||
return cls(id, data, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def parse(cls, id: int, ctx: Context, userstr: str, **kwargs):
|
||||
"""
|
||||
Return a setting instance initialised from a parsed user string.
|
||||
"""
|
||||
data = await cls._parse_userstr(ctx, id, userstr, **kwargs)
|
||||
return cls(id, data, **kwargs)
|
||||
|
||||
# Main interface
|
||||
@property
|
||||
def data(self):
|
||||
"""
|
||||
Retrieves the current internal setting data if it is set, otherwise the default data
|
||||
"""
|
||||
return self._data if self._data is not None else self.default
|
||||
|
||||
@data.setter
|
||||
def data(self, new_data):
|
||||
"""
|
||||
Sets the internal setting data and writes the changes.
|
||||
"""
|
||||
self._data = new_data
|
||||
self.write()
|
||||
|
||||
@property
|
||||
def default(self):
|
||||
"""
|
||||
Retrieves the default value for this setting.
|
||||
Settings should override this if the default depends on the object id.
|
||||
"""
|
||||
return self._default
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
"""
|
||||
Discord-aware object or objects associated with the setting.
|
||||
"""
|
||||
return self._data_to_value(self.id, self.data)
|
||||
|
||||
@value.setter
|
||||
def value(self, new_value):
|
||||
"""
|
||||
Setter which reads the discord-aware object, converts it to data, and writes it.
|
||||
"""
|
||||
self._data = self._data_from_value(self.id, new_value)
|
||||
self.write()
|
||||
|
||||
@property
|
||||
def formatted(self):
|
||||
"""
|
||||
User-readable form of the setting.
|
||||
"""
|
||||
return self._format_data(self.id, self.data)
|
||||
|
||||
def write(self, **kwargs):
|
||||
"""
|
||||
Write value to the database.
|
||||
For settings which override this,
|
||||
ensure you handle deletion of values when internal data is None.
|
||||
"""
|
||||
self._writer(self.id, self._data, **kwargs)
|
||||
|
||||
# Raw converters
|
||||
@classmethod
|
||||
def _data_from_value(cls, id: int, value, **kwargs):
|
||||
"""
|
||||
Convert a high-level setting value to internal data.
|
||||
Must be overriden by the setting.
|
||||
Be aware of None values, these should always pass through as None
|
||||
to provide an unsetting interface.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _data_to_value(cls, id: int, data: ..., **kwargs):
|
||||
"""
|
||||
Convert internal data to high-level setting value.
|
||||
Must be overriden by the setting.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
async def _parse_userstr(cls, ctx: Context, id: int, userstr: str, **kwargs):
|
||||
"""
|
||||
Parse user provided input into internal data.
|
||||
Must be overriden by the setting if the setting is user-configurable.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _format_data(cls, id: int, data: ..., **kwargs):
|
||||
"""
|
||||
Convert internal data into a formatted user-readable string.
|
||||
Must be overriden by the setting if the setting is user-viewable.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
# Database access classmethods
|
||||
@classmethod
|
||||
def _reader(cls, id: int, **kwargs):
|
||||
"""
|
||||
Read a setting from storage and return setting data or None.
|
||||
Must be overriden by the setting.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _writer(cls, id: int, data: ..., **kwargs):
|
||||
"""
|
||||
Write provided setting data to storage.
|
||||
Must be overriden by the setting unless the `write` method is overidden.
|
||||
If the data is None, the setting is empty and should be unset.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
async def command(cls, ctx, id, flags=()):
|
||||
"""
|
||||
Standardised command viewing/setting interface for the setting.
|
||||
"""
|
||||
if not ctx.args and not ctx.msg.attachments:
|
||||
# View config embed for provided cls
|
||||
await cls.get(id).widget(ctx, flags=flags)
|
||||
else:
|
||||
# Check the write ward
|
||||
if cls.write_ward and not await cls.write_ward.run(ctx):
|
||||
await ctx.error_reply(cls.write_ward.msg)
|
||||
else:
|
||||
# Attempt to set config cls
|
||||
try:
|
||||
cls = await cls.parse(id, ctx, ctx.args)
|
||||
except UserInputError as e:
|
||||
await ctx.reply(embed=discord.Embed(
|
||||
description="{} {}".format('❌', e.msg),
|
||||
Colour=discord.Colour.red()
|
||||
))
|
||||
else:
|
||||
cls.write()
|
||||
await ctx.reply(embed=discord.Embed(
|
||||
description="{} {}".format('✅', cls.success_response),
|
||||
Colour=discord.Colour.green()
|
||||
))
|
||||
|
||||
@classmethod
|
||||
def init_task(self, client):
|
||||
"""
|
||||
Initialisation task to be excuted during client initialisation.
|
||||
May be used for e.g. populating a cache or required client setup.
|
||||
|
||||
Main application must execute the initialisation task before the setting is used.
|
||||
Further, the task must always be executable, if the setting is loaded.
|
||||
Conditional initalisation should go in the relevant module's init tasks.
|
||||
"""
|
||||
return None
|
||||
|
||||
|
||||
class ObjectSettings:
|
||||
"""
|
||||
Abstract class representing a linked collection of settings for a single object.
|
||||
Initialised settings are provided as instance attributes in the form of properties.
|
||||
"""
|
||||
__slots__ = ('id', 'params')
|
||||
|
||||
settings: DotDict = None
|
||||
|
||||
def __init__(self, id, **kwargs):
|
||||
self.id = id
|
||||
self.params = tuple(kwargs.items())
|
||||
|
||||
@classmethod
|
||||
def _setting_property(cls, setting):
|
||||
def wrapped_setting(self):
|
||||
return setting.get(self.id, **dict(self.params))
|
||||
return wrapped_setting
|
||||
|
||||
@classmethod
|
||||
def attach_setting(cls, setting: Setting):
|
||||
name = setting.attr_name or setting.__name__
|
||||
setattr(cls, name, property(cls._setting_property(setting)))
|
||||
cls.settings[name] = setting
|
||||
return setting
|
||||
|
||||
def tabulated(self):
|
||||
"""
|
||||
Convenience method to provide a complete setting property-table.
|
||||
"""
|
||||
formatted = {
|
||||
setting.display_name: setting.get(self.id, **dict(self.params)).formatted
|
||||
for name, setting in self.settings.items()
|
||||
}
|
||||
return prop_tabulate(*zip(*formatted.items()))
|
||||
|
||||
|
||||
class ColumnData:
|
||||
"""
|
||||
Mixin for settings stored in a single row and column of a Table.
|
||||
Intended to be used with tables where the only primary key is the object id.
|
||||
"""
|
||||
# Table storing the desired data
|
||||
_table_interface: Table = None
|
||||
|
||||
# Name of the column storing the setting object id
|
||||
_id_column: str = None
|
||||
|
||||
# Name of the column with the desired data
|
||||
_data_column: str = None
|
||||
|
||||
# Whether to use create a row if not found (only applies to TableRow)
|
||||
_create_row = False
|
||||
|
||||
# Whether to upsert or update for updates
|
||||
_upsert: bool = True
|
||||
|
||||
# High level data cache to use, set to None to disable cache.
|
||||
_cache = None # Map[id -> value]
|
||||
|
||||
@classmethod
|
||||
def _reader(cls, id: int, use_cache=True, **kwargs):
|
||||
"""
|
||||
Read in the requested entry associated to the id.
|
||||
Supports reading cached values from a `RowTable`.
|
||||
"""
|
||||
if cls._cache is not None and id in cls._cache and use_cache:
|
||||
return cls._cache[id]
|
||||
|
||||
table = cls._table_interface
|
||||
if isinstance(table, RowTable) and cls._id_column == table.id_col:
|
||||
if cls._create_row:
|
||||
row = table.fetch_or_create(id)
|
||||
else:
|
||||
row = table.fetch(id)
|
||||
data = row.data[cls._data_column] if row else None
|
||||
else:
|
||||
params = {
|
||||
"select_columns": (cls._data_column,),
|
||||
cls._id_column: id
|
||||
}
|
||||
row = table.select_one_where(**params)
|
||||
data = row[cls._data_column] if row else None
|
||||
|
||||
if cls._cache is not None:
|
||||
cls._cache[id] = data
|
||||
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def _writer(cls, id: int, data: ..., **kwargs):
|
||||
"""
|
||||
Write the provided entry to the table, allowing replacements.
|
||||
"""
|
||||
table = cls._table_interface
|
||||
params = {
|
||||
cls._id_column: id
|
||||
}
|
||||
values = {
|
||||
cls._data_column: data
|
||||
}
|
||||
|
||||
# Update data
|
||||
if cls._upsert:
|
||||
# Upsert data
|
||||
table.upsert(
|
||||
constraint=cls._id_column,
|
||||
**params,
|
||||
**values
|
||||
)
|
||||
else:
|
||||
# Update data
|
||||
table.update_where(values, **params)
|
||||
|
||||
if cls._cache is not None:
|
||||
cls._cache[id] = data
|
||||
|
||||
|
||||
class ListData:
|
||||
"""
|
||||
Mixin for list types implemented on a Table.
|
||||
Implements a reader and writer.
|
||||
This assumes the list is the only data stored in the table,
|
||||
and removes list entries by deleting rows.
|
||||
"""
|
||||
# Table storing the setting data
|
||||
_table_interface: Table = None
|
||||
|
||||
# Name of the column storing the id
|
||||
_id_column: str = None
|
||||
|
||||
# Name of the column storing the data to read
|
||||
_data_column: str = None
|
||||
|
||||
# Name of column storing the order index to use, if any. Assumed to be Serial on writing.
|
||||
_order_column: str = None
|
||||
_order_type: str = "ASC"
|
||||
|
||||
# High level data cache to use, set to None to disable cache.
|
||||
_cache = None # Map[id -> value]
|
||||
|
||||
@classmethod
|
||||
def _reader(cls, id: int, use_cache=True, **kwargs):
|
||||
"""
|
||||
Read in all entries associated to the given id.
|
||||
"""
|
||||
if cls._cache is not None and id in cls._cache and use_cache:
|
||||
return cls._cache[id]
|
||||
|
||||
table = cls._table_interface # type: Table
|
||||
params = {
|
||||
"select_columns": [cls._data_column],
|
||||
cls._id_column: id
|
||||
}
|
||||
if cls._order_column:
|
||||
params['_extra'] = "ORDER BY {} {}".format(cls._order_column, cls._order_type)
|
||||
|
||||
rows = table.select_where(**params)
|
||||
data = [row[cls._data_column] for row in rows]
|
||||
|
||||
if cls._cache is not None:
|
||||
cls._cache[id] = data
|
||||
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def _writer(cls, id: int, data: ..., add_only=False, remove_only=False, **kwargs):
|
||||
"""
|
||||
Write the provided list to storage.
|
||||
"""
|
||||
# TODO: Transaction lock on the table so this is atomic
|
||||
# Or just use the connection context manager
|
||||
|
||||
table = cls._table_interface # type: Table
|
||||
|
||||
# Handle None input as an empty list
|
||||
if data is None:
|
||||
data = []
|
||||
|
||||
current = cls._reader(id, **kwargs)
|
||||
if not cls._order_column and (add_only or remove_only):
|
||||
to_insert = [item for item in data if item not in current] if not remove_only else []
|
||||
to_remove = data if remove_only else (
|
||||
[item for item in current if item not in data] if not add_only else []
|
||||
)
|
||||
|
||||
# Handle required deletions
|
||||
if to_remove:
|
||||
params = {
|
||||
cls._id_column: id,
|
||||
cls._data_column: to_remove
|
||||
}
|
||||
table.delete_where(**params)
|
||||
|
||||
# Handle required insertions
|
||||
if to_insert:
|
||||
columns = (cls._id_column, cls._data_column)
|
||||
values = [(id, value) for value in to_insert]
|
||||
table.insert_many(*values, insert_keys=columns)
|
||||
|
||||
if cls._cache is not None:
|
||||
new_current = [item for item in current + to_insert if item not in to_remove]
|
||||
cls._cache[id] = new_current
|
||||
else:
|
||||
# Remove all and add all to preserve order
|
||||
# TODO: This really really should be atomic if anything else reads this
|
||||
delete_params = {cls._id_column: id}
|
||||
table.delete_where(**delete_params)
|
||||
|
||||
if data:
|
||||
columns = (cls._id_column, cls._data_column)
|
||||
values = [(id, value) for value in data]
|
||||
table.insert_many(*values, insert_keys=columns)
|
||||
|
||||
if cls._cache is not None:
|
||||
cls._cache[id] = data
|
||||
|
||||
|
||||
class KeyValueData:
|
||||
"""
|
||||
Mixin for settings implemented in a Key-Value table.
|
||||
The underlying table should have a Unique constraint on the `(_id_column, _key_column)` pair.
|
||||
"""
|
||||
_table_interface: Table = None
|
||||
|
||||
_id_column: str = None
|
||||
|
||||
_key_column: str = None
|
||||
|
||||
_value_column: str = None
|
||||
|
||||
_key: str = None
|
||||
|
||||
@classmethod
|
||||
def _reader(cls, id: ..., **kwargs):
|
||||
params = {
|
||||
"select_columns": (cls._value_column, ),
|
||||
cls._id_column: id,
|
||||
cls._key_column: cls._key
|
||||
}
|
||||
|
||||
row = cls._table_interface.select_one_where(**params)
|
||||
data = row[cls._value_column] if row else None
|
||||
|
||||
if data is not None:
|
||||
data = json.loads(data)
|
||||
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def _writer(cls, id: ..., data: ..., **kwargs):
|
||||
params = {
|
||||
cls._id_column: id,
|
||||
cls._key_column: cls._key
|
||||
}
|
||||
if data is not None:
|
||||
values = {
|
||||
cls._value_column: json.dumps(data)
|
||||
}
|
||||
cls._table_interface.upsert(
|
||||
constraint=f"{cls._id_column}, {cls._key_column}",
|
||||
**params,
|
||||
**values
|
||||
)
|
||||
else:
|
||||
cls._table_interface.delete_where(**params)
|
||||
|
||||
|
||||
class UserInputError(SafeCancellation):
|
||||
pass
|
||||
197
bot/pending-rewrite/settings/guild_settings.py
Normal file
197
bot/pending-rewrite/settings/guild_settings.py
Normal file
@@ -0,0 +1,197 @@
|
||||
import datetime
|
||||
import asyncio
|
||||
import discord
|
||||
|
||||
import settings
|
||||
from utils.lib import DotDict
|
||||
from utils import seekers # noqa
|
||||
|
||||
from wards import guild_admin
|
||||
from data import tables as tb
|
||||
|
||||
|
||||
class GuildSettings(settings.ObjectSettings):
|
||||
settings = DotDict()
|
||||
|
||||
|
||||
class GuildSetting(settings.ColumnData, settings.Setting):
|
||||
_table_interface = tb.guild_config
|
||||
_id_column = 'guildid'
|
||||
_create_row = True
|
||||
|
||||
category = None
|
||||
|
||||
write_ward = guild_admin
|
||||
|
||||
|
||||
@GuildSettings.attach_setting
|
||||
class event_log(settings.Channel, GuildSetting):
|
||||
category = "Meta"
|
||||
|
||||
attr_name = 'event_log'
|
||||
_data_column = 'event_log_channel'
|
||||
|
||||
display_name = "event_log"
|
||||
desc = "Bot event logging channel."
|
||||
|
||||
long_desc = (
|
||||
"Channel to post 'events', such as workouts completing or members renting a room."
|
||||
)
|
||||
|
||||
_chan_type = discord.ChannelType.text
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
if self.value:
|
||||
return "The event log is now {}.".format(self.formatted)
|
||||
else:
|
||||
return "The event log has been unset."
|
||||
|
||||
def log(self, description="", colour=discord.Color.orange(), **kwargs):
|
||||
channel = self.value
|
||||
if channel:
|
||||
embed = discord.Embed(
|
||||
description=description,
|
||||
colour=colour,
|
||||
timestamp=datetime.datetime.utcnow(),
|
||||
**kwargs
|
||||
)
|
||||
asyncio.create_task(channel.send(embed=embed))
|
||||
|
||||
|
||||
@GuildSettings.attach_setting
|
||||
class admin_role(settings.Role, GuildSetting):
|
||||
category = "Guild Roles"
|
||||
|
||||
attr_name = 'admin_role'
|
||||
_data_column = 'admin_role'
|
||||
|
||||
display_name = "admin_role"
|
||||
desc = "Server administrator role."
|
||||
|
||||
long_desc = (
|
||||
"Server administrator role.\n"
|
||||
"Allows usage of the administrative commands, such as `config`.\n"
|
||||
"These commands may also be used by anyone with the discord adminitrator permission."
|
||||
)
|
||||
# TODO Expand on what these are.
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
if self.value:
|
||||
return "The administrator role is now {}.".format(self.formatted)
|
||||
else:
|
||||
return "The administrator role has been unset."
|
||||
|
||||
|
||||
@GuildSettings.attach_setting
|
||||
class mod_role(settings.Role, GuildSetting):
|
||||
category = "Guild Roles"
|
||||
|
||||
attr_name = 'mod_role'
|
||||
_data_column = 'mod_role'
|
||||
|
||||
display_name = "mod_role"
|
||||
desc = "Server moderator role."
|
||||
|
||||
long_desc = (
|
||||
"Server moderator role.\n"
|
||||
"Allows usage of the modistrative commands."
|
||||
)
|
||||
# TODO Expand on what these are.
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
if self.value:
|
||||
return "The moderator role is now {}.".format(self.formatted)
|
||||
else:
|
||||
return "The moderator role has been unset."
|
||||
|
||||
|
||||
@GuildSettings.attach_setting
|
||||
class unranked_roles(settings.RoleList, settings.ListData, settings.Setting):
|
||||
category = "Guild Roles"
|
||||
|
||||
attr_name = 'unranked_roles'
|
||||
|
||||
_table_interface = tb.unranked_roles
|
||||
_id_column = 'guildid'
|
||||
_data_column = 'roleid'
|
||||
|
||||
write_ward = guild_admin
|
||||
display_name = "unranked_roles"
|
||||
desc = "Roles to exclude from the leaderboards."
|
||||
|
||||
_force_unique = True
|
||||
|
||||
long_desc = (
|
||||
"Roles to be excluded from the `top` and `topcoins` leaderboards."
|
||||
)
|
||||
|
||||
# Flat cache, no need to expire objects
|
||||
_cache = {}
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
if self.value:
|
||||
return "The following roles will be excluded from the leaderboard:\n{}".format(self.formatted)
|
||||
else:
|
||||
return "The excluded roles have been removed."
|
||||
|
||||
|
||||
@GuildSettings.attach_setting
|
||||
class donator_roles(settings.RoleList, settings.ListData, settings.Setting):
|
||||
category = "Hidden"
|
||||
|
||||
attr_name = 'donator_roles'
|
||||
|
||||
_table_interface = tb.donator_roles
|
||||
_id_column = 'guildid'
|
||||
_data_column = 'roleid'
|
||||
|
||||
write_ward = guild_admin
|
||||
display_name = "donator_roles"
|
||||
desc = "Donator badge roles."
|
||||
|
||||
_force_unique = True
|
||||
|
||||
long_desc = (
|
||||
"Members with these roles will be considered donators and have access to premium features."
|
||||
)
|
||||
|
||||
# Flat cache, no need to expire objects
|
||||
_cache = {}
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
if self.value:
|
||||
return "The donator badges are now:\n{}".format(self.formatted)
|
||||
else:
|
||||
return "The donator badges have been removed."
|
||||
|
||||
|
||||
@GuildSettings.attach_setting
|
||||
class alert_channel(settings.Channel, GuildSetting):
|
||||
category = "Meta"
|
||||
|
||||
attr_name = 'alert_channel'
|
||||
_data_column = 'alert_channel'
|
||||
|
||||
display_name = "alert_channel"
|
||||
desc = "Channel to display global user alerts."
|
||||
|
||||
long_desc = (
|
||||
"This channel will be used for group notifications, "
|
||||
"for example group timers and anti-cheat messages, "
|
||||
"as well as for critical alerts to users that have their direct messages disapbled.\n"
|
||||
"It should be visible to all members."
|
||||
)
|
||||
|
||||
_chan_type = discord.ChannelType.text
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
if self.value:
|
||||
return "The alert channel is now {}.".format(self.formatted)
|
||||
else:
|
||||
return "The alert channel has been unset."
|
||||
1119
bot/pending-rewrite/settings/setting_types.py
Normal file
1119
bot/pending-rewrite/settings/setting_types.py
Normal file
File diff suppressed because it is too large
Load Diff
42
bot/pending-rewrite/settings/user_settings.py
Normal file
42
bot/pending-rewrite/settings/user_settings.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import datetime
|
||||
|
||||
import settings
|
||||
from utils.lib import DotDict
|
||||
|
||||
from data import tables as tb
|
||||
|
||||
|
||||
class UserSettings(settings.ObjectSettings):
|
||||
settings = DotDict()
|
||||
|
||||
|
||||
class UserSetting(settings.ColumnData, settings.Setting):
|
||||
_table_interface = tb.user_config
|
||||
_id_column = 'userid'
|
||||
_create_row = True
|
||||
|
||||
write_ward = None
|
||||
|
||||
|
||||
@UserSettings.attach_setting
|
||||
class timezone(settings.Timezone, UserSetting):
|
||||
attr_name = 'timezone'
|
||||
_data_column = 'timezone'
|
||||
|
||||
_default = 'UTC'
|
||||
|
||||
display_name = 'timezone'
|
||||
desc = "Timezone to display prompts in."
|
||||
long_desc = (
|
||||
"Timezone used for displaying certain prompts (e.g. selecting an accountability room)."
|
||||
)
|
||||
|
||||
@property
|
||||
def success_response(self):
|
||||
if self.value:
|
||||
return (
|
||||
"Your personal timezone is now {}.\n"
|
||||
"Your current time is **{}**."
|
||||
).format(self.formatted, datetime.datetime.now(tz=self.value).strftime("%H:%M"))
|
||||
else:
|
||||
return "Your personal timezone has been unset."
|
||||
157
bot/pending-rewrite/utils/ctx_addons.py
Normal file
157
bot/pending-rewrite/utils/ctx_addons.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import asyncio
|
||||
import discord
|
||||
from LionContext import LionContext as Context
|
||||
from cmdClient.lib import SafeCancellation
|
||||
|
||||
from data import tables
|
||||
from core import Lion
|
||||
from . import lib
|
||||
from settings import GuildSettings, UserSettings
|
||||
|
||||
|
||||
@Context.util
|
||||
async def embed_reply(ctx, desc, colour=discord.Colour.orange(), **kwargs):
|
||||
"""
|
||||
Simple helper to embed replies.
|
||||
All arguments are passed to the embed constructor.
|
||||
`desc` is passed as the `description` kwarg.
|
||||
"""
|
||||
embed = discord.Embed(description=desc, colour=colour, **kwargs)
|
||||
try:
|
||||
return await ctx.reply(embed=embed, reference=ctx.msg.to_reference(fail_if_not_exists=False))
|
||||
except discord.Forbidden:
|
||||
if not ctx.guild or ctx.ch.permissions_for(ctx.guild.me).send_messages:
|
||||
await ctx.reply("Command failed, I don't have permission to send embeds in this channel!")
|
||||
raise SafeCancellation
|
||||
|
||||
|
||||
@Context.util
|
||||
async def error_reply(ctx, error_str, send_args={}, **kwargs):
|
||||
"""
|
||||
Notify the user of a user level error.
|
||||
Typically, this will occur in a red embed, posted in the command channel.
|
||||
"""
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.red(),
|
||||
description=error_str,
|
||||
**kwargs
|
||||
)
|
||||
message = None
|
||||
try:
|
||||
message = await ctx.ch.send(
|
||||
embed=embed,
|
||||
reference=ctx.msg.to_reference(fail_if_not_exists=False),
|
||||
**send_args
|
||||
)
|
||||
ctx.sent_messages.append(message)
|
||||
return message
|
||||
except discord.Forbidden:
|
||||
if not ctx.guild or ctx.ch.permissions_for(ctx.guild.me).send_messages:
|
||||
await ctx.reply("Command failed, I don't have permission to send embeds in this channel!")
|
||||
raise SafeCancellation
|
||||
|
||||
|
||||
@Context.util
|
||||
async def offer_delete(ctx: Context, *to_delete, timeout=300):
|
||||
"""
|
||||
Offers to delete the provided messages via a reaction on the last message.
|
||||
Removes the reaction if the offer times out.
|
||||
|
||||
If any exceptions occur, handles them silently and returns.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
to_delete: List[Message]
|
||||
The messages to delete.
|
||||
|
||||
timeout: int
|
||||
Time in seconds after which to remove the delete offer reaction.
|
||||
"""
|
||||
# Get the delete emoji from the config
|
||||
emoji = lib.cross
|
||||
|
||||
# Return if there are no messages to delete
|
||||
if not to_delete:
|
||||
return
|
||||
|
||||
# The message to add the reaction to
|
||||
react_msg = to_delete[-1]
|
||||
|
||||
# Build the reaction check function
|
||||
if ctx.guild:
|
||||
modrole = ctx.guild_settings.mod_role.value if ctx.guild else None
|
||||
|
||||
def check(reaction, user):
|
||||
if not (reaction.message.id == react_msg.id and reaction.emoji == emoji):
|
||||
return False
|
||||
if user == ctx.guild.me:
|
||||
return False
|
||||
return ((user == ctx.author)
|
||||
or (user.permissions_in(ctx.ch).manage_messages)
|
||||
or (modrole and modrole in user.roles))
|
||||
else:
|
||||
def check(reaction, user):
|
||||
return user == ctx.author and reaction.message.id == react_msg.id and reaction.emoji == emoji
|
||||
|
||||
try:
|
||||
# Add the reaction to the message
|
||||
await react_msg.add_reaction(emoji)
|
||||
|
||||
# Wait for the user to press the reaction
|
||||
reaction, user = await ctx.client.wait_for("reaction_add", check=check, timeout=timeout)
|
||||
|
||||
# Since the check was satisfied, the reaction is correct. Delete the messages, ignoring any exceptions
|
||||
deleted = False
|
||||
# First try to bulk delete if we have the permissions
|
||||
if ctx.guild and ctx.ch.permissions_for(ctx.guild.me).manage_messages:
|
||||
try:
|
||||
await ctx.ch.delete_messages(to_delete)
|
||||
deleted = True
|
||||
except Exception:
|
||||
deleted = False
|
||||
|
||||
# If we couldn't bulk delete, delete them one by one
|
||||
if not deleted:
|
||||
try:
|
||||
asyncio.gather(*[message.delete() for message in to_delete], return_exceptions=True)
|
||||
except Exception:
|
||||
pass
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
# Timed out waiting for the reaction, attempt to remove the delete reaction
|
||||
try:
|
||||
await react_msg.remove_reaction(emoji, ctx.client.user)
|
||||
except Exception:
|
||||
pass
|
||||
except discord.Forbidden:
|
||||
pass
|
||||
except discord.NotFound:
|
||||
pass
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
|
||||
def context_property(func):
|
||||
setattr(Context, func.__name__, property(func))
|
||||
return func
|
||||
|
||||
|
||||
@context_property
|
||||
def best_prefix(ctx):
|
||||
return ctx.client.prefix
|
||||
|
||||
|
||||
@context_property
|
||||
def guild_settings(ctx):
|
||||
if ctx.guild:
|
||||
tables.guild_config.fetch_or_create(ctx.guild.id)
|
||||
return GuildSettings(ctx.guild.id if ctx.guild else 0)
|
||||
|
||||
|
||||
@context_property
|
||||
def author_settings(ctx):
|
||||
return UserSettings(ctx.author.id)
|
||||
|
||||
|
||||
@context_property
|
||||
def alion(ctx):
|
||||
return Lion.fetch(ctx.guild.id if ctx.guild else 0, ctx.author.id)
|
||||
461
bot/pending-rewrite/utils/interactive.py
Normal file
461
bot/pending-rewrite/utils/interactive.py
Normal file
@@ -0,0 +1,461 @@
|
||||
import asyncio
|
||||
import discord
|
||||
from LionContext import LionContext as Context
|
||||
from cmdClient.lib import UserCancelled, ResponseTimedOut
|
||||
|
||||
from .lib import paginate_list
|
||||
|
||||
# TODO: Interactive locks
|
||||
cancel_emoji = '❌'
|
||||
number_emojis = (
|
||||
'1️⃣', '2️⃣', '3️⃣', '4️⃣', '5️⃣', '6️⃣', '7️⃣', '8️⃣', '9️⃣'
|
||||
)
|
||||
|
||||
|
||||
async def discord_shield(coro):
|
||||
try:
|
||||
await coro
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
|
||||
@Context.util
|
||||
async def cancellable(ctx, msg, add_reaction=True, cancel_message=None, timeout=300):
|
||||
"""
|
||||
Add a cancellation reaction to the given message.
|
||||
Pressing the reaction triggers cancellation of the original context, and a UserCancelled-style error response.
|
||||
"""
|
||||
# TODO: Not consistent with the exception driven flow, make a decision here?
|
||||
# Add reaction
|
||||
if add_reaction and cancel_emoji not in (str(r.emoji) for r in msg.reactions):
|
||||
try:
|
||||
await msg.add_reaction(cancel_emoji)
|
||||
except discord.HTTPException:
|
||||
return
|
||||
|
||||
# Define cancellation function
|
||||
async def _cancel():
|
||||
try:
|
||||
await ctx.client.wait_for(
|
||||
'reaction_add',
|
||||
timeout=timeout,
|
||||
check=lambda r, u: (u == ctx.author
|
||||
and r.message == msg
|
||||
and str(r.emoji) == cancel_emoji)
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
else:
|
||||
await ctx.client.active_command_response_cleaner(ctx)
|
||||
if cancel_message:
|
||||
await ctx.error_reply(cancel_message)
|
||||
else:
|
||||
try:
|
||||
await ctx.msg.add_reaction(cancel_emoji)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
[task.cancel() for task in ctx.tasks]
|
||||
|
||||
# Launch cancellation task
|
||||
task = asyncio.create_task(_cancel())
|
||||
ctx.tasks.append(task)
|
||||
return task
|
||||
|
||||
|
||||
@Context.util
|
||||
async def listen_for(ctx, allowed_input=None, timeout=120, lower=True, check=None):
|
||||
"""
|
||||
Listen for a one of a particular set of input strings,
|
||||
sent in the current channel by `ctx.author`.
|
||||
When found, return the message containing them.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
allowed_input: Union(List(str), None)
|
||||
List of strings to listen for.
|
||||
Allowed to be `None` precisely when a `check` function is also supplied.
|
||||
timeout: int
|
||||
Number of seconds to wait before timing out.
|
||||
lower: bool
|
||||
Whether to shift the allowed and message strings to lowercase before checking.
|
||||
check: Function(message) -> bool
|
||||
Alternative custom check function.
|
||||
|
||||
Returns: discord.Message
|
||||
The message that was matched.
|
||||
|
||||
Raises
|
||||
------
|
||||
cmdClient.lib.ResponseTimedOut:
|
||||
Raised when no messages matching the given criteria are detected in `timeout` seconds.
|
||||
"""
|
||||
# Generate the check if it hasn't been provided
|
||||
if not check:
|
||||
# Quick check the arguments are sane
|
||||
if not allowed_input:
|
||||
raise ValueError("allowed_input and check cannot both be None")
|
||||
|
||||
# Force a lower on the allowed inputs
|
||||
allowed_input = [s.lower() for s in allowed_input]
|
||||
|
||||
# Create the check function
|
||||
def check(message):
|
||||
result = (message.author == ctx.author)
|
||||
result = result and (message.channel == ctx.ch)
|
||||
result = result and ((message.content.lower() if lower else message.content) in allowed_input)
|
||||
return result
|
||||
|
||||
# Wait for a matching message, catch and transform the timeout
|
||||
try:
|
||||
message = await ctx.client.wait_for('message', check=check, timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
raise ResponseTimedOut("Session timed out waiting for user response.") from None
|
||||
|
||||
return message
|
||||
|
||||
|
||||
@Context.util
|
||||
async def selector(ctx, header, select_from, timeout=120, max_len=20):
|
||||
"""
|
||||
Interactive routine to prompt the `ctx.author` to select an item from a list.
|
||||
Returns the list index that was selected.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
header: str
|
||||
String to put at the top of each page of selection options.
|
||||
Intended to be information about the list the user is selecting from.
|
||||
select_from: List(str)
|
||||
The list of strings to select from.
|
||||
timeout: int
|
||||
The number of seconds to wait before throwing `ResponseTimedOut`.
|
||||
max_len: int
|
||||
The maximum number of items to display on each page.
|
||||
Decrease this if the items are long, to avoid going over the char limit.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int:
|
||||
The index of the list entry selected by the user.
|
||||
|
||||
Raises
|
||||
------
|
||||
cmdClient.lib.UserCancelled:
|
||||
Raised if the user manually cancels the selection.
|
||||
cmdClient.lib.ResponseTimedOut:
|
||||
Raised if the user fails to respond to the selector within `timeout` seconds.
|
||||
"""
|
||||
# Handle improper arguments
|
||||
if len(select_from) == 0:
|
||||
raise ValueError("Selection list passed to `selector` cannot be empty.")
|
||||
|
||||
# Generate the selector pages
|
||||
footer = "Please reply with the number of your selection, or press {} to cancel.".format(cancel_emoji)
|
||||
list_pages = paginate_list(select_from, block_length=max_len)
|
||||
pages = ["\n".join([header, page, footer]) for page in list_pages]
|
||||
|
||||
# Post the pages in a paged message
|
||||
out_msg = await ctx.pager(pages, add_cancel=True)
|
||||
cancel_task = await ctx.cancellable(out_msg, add_reaction=False, timeout=None)
|
||||
|
||||
if len(select_from) <= 5:
|
||||
for i, _ in enumerate(select_from):
|
||||
asyncio.create_task(discord_shield(out_msg.add_reaction(number_emojis[i])))
|
||||
|
||||
# Build response tasks
|
||||
valid_input = [str(i+1) for i in range(0, len(select_from))] + ['c', 'C']
|
||||
listen_task = asyncio.create_task(ctx.listen_for(valid_input, timeout=None))
|
||||
emoji_task = asyncio.create_task(ctx.client.wait_for(
|
||||
'reaction_add',
|
||||
check=lambda r, u: (u == ctx.author
|
||||
and r.message == out_msg
|
||||
and str(r.emoji) in number_emojis)
|
||||
))
|
||||
# Wait for the response tasks
|
||||
done, pending = await asyncio.wait(
|
||||
(listen_task, emoji_task),
|
||||
timeout=timeout,
|
||||
return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
|
||||
# Cleanup
|
||||
try:
|
||||
await out_msg.delete()
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
# Handle different return cases
|
||||
if listen_task in done:
|
||||
emoji_task.cancel()
|
||||
|
||||
result_msg = listen_task.result()
|
||||
try:
|
||||
await result_msg.delete()
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
if result_msg.content.lower() == 'c':
|
||||
raise UserCancelled("Selection cancelled!")
|
||||
result = int(result_msg.content) - 1
|
||||
elif emoji_task in done:
|
||||
listen_task.cancel()
|
||||
|
||||
reaction, _ = emoji_task.result()
|
||||
result = number_emojis.index(str(reaction.emoji))
|
||||
elif cancel_task in done:
|
||||
# Manually cancelled case.. the current task should have been cancelled
|
||||
# Raise UserCancelled in case the task wasn't cancelled for some reason
|
||||
raise UserCancelled("Selection cancelled!")
|
||||
elif not done:
|
||||
# Timeout case
|
||||
raise ResponseTimedOut("Selector timed out waiting for a response.")
|
||||
|
||||
# Finally cancel the canceller and return the provided index
|
||||
cancel_task.cancel()
|
||||
return result
|
||||
|
||||
|
||||
@Context.util
|
||||
async def pager(ctx, pages, locked=True, start_at=0, add_cancel=False, **kwargs):
|
||||
"""
|
||||
Shows the user each page from the provided list `pages` one at a time,
|
||||
providing reactions to page back and forth between pages.
|
||||
This is done asynchronously, and returns after displaying the first page.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pages: List(Union(str, discord.Embed))
|
||||
A list of either strings or embeds to display as the pages.
|
||||
locked: bool
|
||||
Whether only the `ctx.author` should be able to use the paging reactions.
|
||||
kwargs: ...
|
||||
Remaining keyword arguments are transparently passed to the reply context method.
|
||||
|
||||
Returns: discord.Message
|
||||
This is the output message, returned for easy deletion.
|
||||
"""
|
||||
# Handle broken input
|
||||
if len(pages) == 0:
|
||||
raise ValueError("Pager cannot page with no pages!")
|
||||
|
||||
# Post first page. Method depends on whether the page is an embed or not.
|
||||
if isinstance(pages[start_at], discord.Embed):
|
||||
out_msg = await ctx.reply(embed=pages[start_at], **kwargs)
|
||||
else:
|
||||
out_msg = await ctx.reply(pages[start_at], **kwargs)
|
||||
|
||||
# Run the paging loop if required
|
||||
if len(pages) > 1:
|
||||
task = asyncio.create_task(_pager(ctx, out_msg, pages, locked, start_at, add_cancel, **kwargs))
|
||||
ctx.tasks.append(task)
|
||||
elif add_cancel:
|
||||
await out_msg.add_reaction(cancel_emoji)
|
||||
|
||||
# Return the output message
|
||||
return out_msg
|
||||
|
||||
|
||||
async def _pager(ctx, out_msg, pages, locked, start_at, add_cancel, **kwargs):
|
||||
"""
|
||||
Asynchronous initialiser and loop for the `pager` utility above.
|
||||
"""
|
||||
# Page number
|
||||
page = start_at
|
||||
|
||||
# Add reactions to the output message
|
||||
next_emoji = "▶"
|
||||
prev_emoji = "◀"
|
||||
|
||||
try:
|
||||
await out_msg.add_reaction(prev_emoji)
|
||||
if add_cancel:
|
||||
await out_msg.add_reaction(cancel_emoji)
|
||||
await out_msg.add_reaction(next_emoji)
|
||||
except discord.Forbidden:
|
||||
# We don't have permission to add paging emojis
|
||||
# Die as gracefully as we can
|
||||
if ctx.guild:
|
||||
perms = ctx.ch.permissions_for(ctx.guild.me)
|
||||
if not perms.add_reactions:
|
||||
await ctx.error_reply(
|
||||
"Cannot page results because I do not have the `add_reactions` permission!"
|
||||
)
|
||||
elif not perms.read_message_history:
|
||||
await ctx.error_reply(
|
||||
"Cannot page results because I do not have the `read_message_history` permission!"
|
||||
)
|
||||
else:
|
||||
await ctx.error_reply(
|
||||
"Cannot page results due to insufficient permissions!"
|
||||
)
|
||||
else:
|
||||
await ctx.error_reply(
|
||||
"Cannot page results!"
|
||||
)
|
||||
return
|
||||
|
||||
# Check function to determine whether a reaction is valid
|
||||
def reaction_check(reaction, user):
|
||||
result = reaction.message.id == out_msg.id
|
||||
result = result and str(reaction.emoji) in [next_emoji, prev_emoji]
|
||||
result = result and not (user.id == ctx.client.user.id)
|
||||
result = result and not (locked and user != ctx.author)
|
||||
return result
|
||||
|
||||
# Check function to determine if message has a page number
|
||||
def message_check(message):
|
||||
result = message.channel.id == ctx.ch.id
|
||||
result = result and not (locked and message.author != ctx.author)
|
||||
result = result and message.content.lower().startswith('p')
|
||||
result = result and message.content[1:].isdigit()
|
||||
result = result and 1 <= int(message.content[1:]) <= len(pages)
|
||||
return result
|
||||
|
||||
# Begin loop
|
||||
while True:
|
||||
# Wait for a valid reaction or message, break if we time out
|
||||
reaction_task = asyncio.create_task(
|
||||
ctx.client.wait_for('reaction_add', check=reaction_check)
|
||||
)
|
||||
message_task = asyncio.create_task(
|
||||
ctx.client.wait_for('message', check=message_check)
|
||||
)
|
||||
done, pending = await asyncio.wait(
|
||||
(reaction_task, message_task),
|
||||
timeout=300,
|
||||
return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
if done:
|
||||
if reaction_task in done:
|
||||
# Cancel the message task and collect the reaction result
|
||||
message_task.cancel()
|
||||
reaction, user = reaction_task.result()
|
||||
|
||||
# Attempt to remove the user's reaction, silently ignore errors
|
||||
asyncio.ensure_future(out_msg.remove_reaction(reaction.emoji, user))
|
||||
|
||||
# Change the page number
|
||||
page += 1 if reaction.emoji == next_emoji else -1
|
||||
page %= len(pages)
|
||||
elif message_task in done:
|
||||
# Cancel the reaction task and collect the message result
|
||||
reaction_task.cancel()
|
||||
message = message_task.result()
|
||||
|
||||
# Attempt to delete the user's message, silently ignore errors
|
||||
asyncio.ensure_future(message.delete())
|
||||
|
||||
# Move to the correct page
|
||||
page = int(message.content[1:]) - 1
|
||||
|
||||
# Edit the message with the new page
|
||||
active_page = pages[page]
|
||||
if isinstance(active_page, discord.Embed):
|
||||
await out_msg.edit(embed=active_page, **kwargs)
|
||||
else:
|
||||
await out_msg.edit(content=active_page, **kwargs)
|
||||
else:
|
||||
# No tasks finished, so we must have timed out, or had an exception.
|
||||
# Break the loop and clean up
|
||||
break
|
||||
|
||||
# Clean up by removing the reactions
|
||||
try:
|
||||
await out_msg.clear_reactions()
|
||||
except discord.Forbidden:
|
||||
try:
|
||||
await out_msg.remove_reaction(next_emoji, ctx.client.user)
|
||||
await out_msg.remove_reaction(prev_emoji, ctx.client.user)
|
||||
except discord.NotFound:
|
||||
pass
|
||||
except discord.NotFound:
|
||||
pass
|
||||
|
||||
|
||||
@Context.util
|
||||
async def input(ctx, msg="", timeout=120):
|
||||
"""
|
||||
Listen for a response in the current channel, from ctx.author.
|
||||
Returns the response from ctx.author, if it is provided.
|
||||
Parameters
|
||||
----------
|
||||
msg: string
|
||||
Allows a custom input message to be provided.
|
||||
Will use default message if not provided.
|
||||
timeout: int
|
||||
Number of seconds to wait before timing out.
|
||||
Raises
|
||||
------
|
||||
cmdClient.lib.ResponseTimedOut:
|
||||
Raised when ctx.author does not provide a response before the function times out.
|
||||
"""
|
||||
# Deliver prompt
|
||||
offer_msg = await ctx.reply(msg or "Please enter your input.")
|
||||
|
||||
# Criteria for the input message
|
||||
def checks(m):
|
||||
return m.author == ctx.author and m.channel == ctx.ch
|
||||
|
||||
# Listen for the reply
|
||||
try:
|
||||
result_msg = await ctx.client.wait_for("message", check=checks, timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
raise ResponseTimedOut("Session timed out waiting for user response.") from None
|
||||
|
||||
result = result_msg.content
|
||||
|
||||
# Attempt to delete the prompt and reply messages
|
||||
try:
|
||||
await offer_msg.delete()
|
||||
await result_msg.delete()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@Context.util
|
||||
async def ask(ctx, msg, timeout=30, use_msg=None, del_on_timeout=False):
|
||||
"""
|
||||
Ask ctx.author a yes/no question.
|
||||
Returns 0 if ctx.author answers no
|
||||
Returns 1 if ctx.author answers yes
|
||||
Parameters
|
||||
----------
|
||||
msg: string
|
||||
Adds the question to the message string.
|
||||
Requires an input.
|
||||
timeout: int
|
||||
Number of seconds to wait before timing out.
|
||||
use_msg: string
|
||||
A completely custom string to use instead of the default string.
|
||||
del_on_timeout: bool
|
||||
Whether to delete the question if it times out.
|
||||
Raises
|
||||
------
|
||||
Nothing
|
||||
"""
|
||||
out = "{} {}".format(msg, "`y(es)`/`n(o)`")
|
||||
|
||||
offer_msg = use_msg or await ctx.reply(out)
|
||||
if use_msg and msg:
|
||||
await use_msg.edit(content=msg)
|
||||
|
||||
result_msg = await ctx.listen_for(["y", "yes", "n", "no"], timeout=timeout)
|
||||
|
||||
if result_msg is None:
|
||||
if del_on_timeout:
|
||||
try:
|
||||
await offer_msg.delete()
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
result = result_msg.content.lower()
|
||||
try:
|
||||
if not use_msg:
|
||||
await offer_msg.delete()
|
||||
await result_msg.delete()
|
||||
except Exception:
|
||||
pass
|
||||
if result in ["n", "no"]:
|
||||
return 0
|
||||
return 1
|
||||
553
bot/pending-rewrite/utils/lib.py
Normal file
553
bot/pending-rewrite/utils/lib.py
Normal file
@@ -0,0 +1,553 @@
|
||||
import datetime
|
||||
import iso8601
|
||||
import re
|
||||
from enum import Enum
|
||||
|
||||
import discord
|
||||
from psycopg2.extensions import QuotedString
|
||||
|
||||
from cmdClient.lib import SafeCancellation
|
||||
|
||||
|
||||
multiselect_regex = re.compile(
|
||||
r"^([0-9, -]+)$",
|
||||
re.DOTALL | re.IGNORECASE | re.VERBOSE
|
||||
)
|
||||
tick = '✅'
|
||||
cross = '❌'
|
||||
|
||||
|
||||
def prop_tabulate(prop_list, value_list, indent=True, colon=True):
|
||||
"""
|
||||
Turns a list of properties and corresponding list of values into
|
||||
a pretty string with one `prop: value` pair each line,
|
||||
padded so that the colons in each line are lined up.
|
||||
Handles empty props by using an extra couple of spaces instead of a `:`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
prop_list: List[str]
|
||||
List of short names to put on the right side of the list.
|
||||
Empty props are considered to be "newlines" for the corresponding value.
|
||||
value_list: List[str]
|
||||
List of values corresponding to the properties above.
|
||||
indent: bool
|
||||
Whether to add padding so the properties are right-adjusted.
|
||||
|
||||
Returns: str
|
||||
"""
|
||||
max_len = max(len(prop) for prop in prop_list)
|
||||
return "".join(["`{}{}{}`\t{}{}".format(" " * (max_len - len(prop)) if indent else "",
|
||||
prop,
|
||||
(":" if len(prop) else " " * 2) if colon else '',
|
||||
value_list[i],
|
||||
'' if str(value_list[i]).endswith("```") else '\n')
|
||||
for i, prop in enumerate(prop_list)])
|
||||
|
||||
|
||||
def paginate_list(item_list, block_length=20, style="markdown", title=None):
|
||||
"""
|
||||
Create pretty codeblock pages from a list of strings.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
item_list: List[str]
|
||||
List of strings to paginate.
|
||||
block_length: int
|
||||
Maximum number of strings per page.
|
||||
style: str
|
||||
Codeblock style to use.
|
||||
Title formatting assumes the `markdown` style, and numbered lists work well with this.
|
||||
However, `markdown` sometimes messes up formatting in the list.
|
||||
title: str
|
||||
Optional title to add to the top of each page.
|
||||
|
||||
Returns: List[str]
|
||||
List of pages, each formatted into a codeblock,
|
||||
and containing at most `block_length` of the provided strings.
|
||||
"""
|
||||
lines = ["{0:<5}{1:<5}".format("{}.".format(i + 1), str(line)) for i, line in enumerate(item_list)]
|
||||
page_blocks = [lines[i:i + block_length] for i in range(0, len(lines), block_length)]
|
||||
pages = []
|
||||
for i, block in enumerate(page_blocks):
|
||||
pagenum = "Page {}/{}".format(i + 1, len(page_blocks))
|
||||
if title:
|
||||
header = "{} ({})".format(title, pagenum) if len(page_blocks) > 1 else title
|
||||
else:
|
||||
header = pagenum
|
||||
header_line = "=" * len(header)
|
||||
full_header = "{}\n{}\n".format(header, header_line) if len(page_blocks) > 1 or title else ""
|
||||
pages.append("```{}\n{}{}```".format(style, full_header, "\n".join(block)))
|
||||
return pages
|
||||
|
||||
|
||||
def timestamp_utcnow():
|
||||
"""
|
||||
Return the current integer UTC timestamp.
|
||||
"""
|
||||
return int(datetime.datetime.timestamp(datetime.datetime.utcnow()))
|
||||
|
||||
|
||||
def split_text(text, blocksize=2000, code=True, syntax="", maxheight=50):
|
||||
"""
|
||||
Break the text into blocks of maximum length blocksize
|
||||
If possible, break across nearby newlines. Otherwise just break at blocksize chars
|
||||
|
||||
Parameters
|
||||
----------
|
||||
text: str
|
||||
Text to break into blocks.
|
||||
blocksize: int
|
||||
Maximum character length for each block.
|
||||
code: bool
|
||||
Whether to wrap each block in codeblocks (these are counted in the blocksize).
|
||||
syntax: str
|
||||
The markdown formatting language to use for the codeblocks, if applicable.
|
||||
maxheight: int
|
||||
The maximum number of lines in each block
|
||||
|
||||
Returns: List[str]
|
||||
List of blocks,
|
||||
each containing at most `block_size` characters,
|
||||
of height at most `maxheight`.
|
||||
"""
|
||||
# Adjust blocksize to account for the codeblocks if required
|
||||
blocksize = blocksize - 8 - len(syntax) if code else blocksize
|
||||
|
||||
# Build the blocks
|
||||
blocks = []
|
||||
while True:
|
||||
# If the remaining text is already small enough, append it
|
||||
if len(text) <= blocksize:
|
||||
blocks.append(text)
|
||||
break
|
||||
text = text.strip('\n')
|
||||
|
||||
# Find the last newline in the prototype block
|
||||
split_on = text[0:blocksize].rfind('\n')
|
||||
split_on = blocksize if split_on < blocksize // 5 else split_on
|
||||
|
||||
# Add the block and truncate the text
|
||||
blocks.append(text[0:split_on])
|
||||
text = text[split_on:]
|
||||
|
||||
# Add the codeblock ticks and the code syntax header, if required
|
||||
if code:
|
||||
blocks = ["```{}\n{}\n```".format(syntax, block) for block in blocks]
|
||||
|
||||
return blocks
|
||||
|
||||
|
||||
def strfdelta(delta, sec=False, minutes=True, short=False):
|
||||
"""
|
||||
Convert a datetime.timedelta object into an easily readable duration string.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
delta: datetime.timedelta
|
||||
The timedelta object to convert into a readable string.
|
||||
sec: bool
|
||||
Whether to include the seconds from the timedelta object in the string.
|
||||
minutes: bool
|
||||
Whether to include the minutes from the timedelta object in the string.
|
||||
short: bool
|
||||
Whether to abbreviate the units of time ("hour" to "h", "minute" to "m", "second" to "s").
|
||||
|
||||
Returns: str
|
||||
A string containing a time from the datetime.timedelta object, in a readable format.
|
||||
Time units will be abbreviated if short was set to True.
|
||||
"""
|
||||
|
||||
output = [[delta.days, 'd' if short else ' day'],
|
||||
[delta.seconds // 3600, 'h' if short else ' hour']]
|
||||
if minutes:
|
||||
output.append([delta.seconds // 60 % 60, 'm' if short else ' minute'])
|
||||
if sec:
|
||||
output.append([delta.seconds % 60, 's' if short else ' second'])
|
||||
for i in range(len(output)):
|
||||
if output[i][0] != 1 and not short:
|
||||
output[i][1] += 's'
|
||||
reply_msg = []
|
||||
if output[0][0] != 0:
|
||||
reply_msg.append("{}{} ".format(output[0][0], output[0][1]))
|
||||
if output[0][0] != 0 or output[1][0] != 0 or len(output) == 2:
|
||||
reply_msg.append("{}{} ".format(output[1][0], output[1][1]))
|
||||
for i in range(2, len(output) - 1):
|
||||
reply_msg.append("{}{} ".format(output[i][0], output[i][1]))
|
||||
if not short and reply_msg:
|
||||
reply_msg.append("and ")
|
||||
reply_msg.append("{}{}".format(output[-1][0], output[-1][1]))
|
||||
return "".join(reply_msg)
|
||||
|
||||
|
||||
def parse_dur(time_str):
|
||||
"""
|
||||
Parses a user provided time duration string into a timedelta object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
time_str: str
|
||||
The time string to parse. String can include days, hours, minutes, and seconds.
|
||||
|
||||
Returns: int
|
||||
The number of seconds the duration represents.
|
||||
"""
|
||||
funcs = {'d': lambda x: x * 24 * 60 * 60,
|
||||
'h': lambda x: x * 60 * 60,
|
||||
'm': lambda x: x * 60,
|
||||
's': lambda x: x}
|
||||
time_str = time_str.strip(" ,")
|
||||
found = re.findall(r'(\d+)\s?(\w+?)', time_str)
|
||||
seconds = 0
|
||||
for bit in found:
|
||||
if bit[1] in funcs:
|
||||
seconds += funcs[bit[1]](int(bit[0]))
|
||||
return seconds
|
||||
|
||||
|
||||
def strfdur(duration, short=True, show_days=False):
|
||||
"""
|
||||
Convert a duration given in seconds to a number of hours, minutes, and seconds.
|
||||
"""
|
||||
days = duration // (3600 * 24) if show_days else 0
|
||||
hours = duration // 3600
|
||||
if days:
|
||||
hours %= 24
|
||||
minutes = duration // 60 % 60
|
||||
seconds = duration % 60
|
||||
|
||||
parts = []
|
||||
if days:
|
||||
unit = 'd' if short else (' days' if days != 1 else ' day')
|
||||
parts.append('{}{}'.format(days, unit))
|
||||
if hours:
|
||||
unit = 'h' if short else (' hours' if hours != 1 else ' hour')
|
||||
parts.append('{}{}'.format(hours, unit))
|
||||
if minutes:
|
||||
unit = 'm' if short else (' minutes' if minutes != 1 else ' minute')
|
||||
parts.append('{}{}'.format(minutes, unit))
|
||||
if seconds or duration == 0:
|
||||
unit = 's' if short else (' seconds' if seconds != 1 else ' second')
|
||||
parts.append('{}{}'.format(seconds, unit))
|
||||
|
||||
if short:
|
||||
return ' '.join(parts)
|
||||
else:
|
||||
return ', '.join(parts)
|
||||
|
||||
|
||||
def substitute_ranges(ranges_str, max_match=20, max_range=1000, separator=','):
|
||||
"""
|
||||
Substitutes a user provided list of numbers and ranges,
|
||||
and replaces the ranges by the corresponding list of numbers.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ranges_str: str
|
||||
The string to ranges in.
|
||||
max_match: int
|
||||
The maximum number of ranges to replace.
|
||||
Any ranges exceeding this will be ignored.
|
||||
max_range: int
|
||||
The maximum length of range to replace.
|
||||
Attempting to replace a range longer than this will raise a `ValueError`.
|
||||
"""
|
||||
def _repl(match):
|
||||
n1 = int(match.group(1))
|
||||
n2 = int(match.group(2))
|
||||
if n2 - n1 > max_range:
|
||||
raise SafeCancellation("Provided range is too large!")
|
||||
return separator.join(str(i) for i in range(n1, n2 + 1))
|
||||
|
||||
return re.sub(r'(\d+)\s*-\s*(\d+)', _repl, ranges_str, max_match)
|
||||
|
||||
|
||||
def parse_ranges(ranges_str, ignore_errors=False, separator=',', **kwargs):
|
||||
"""
|
||||
Parses a user provided range string into a list of numbers.
|
||||
Extra keyword arguments are transparently passed to the underlying parser `substitute_ranges`.
|
||||
"""
|
||||
substituted = substitute_ranges(ranges_str, separator=separator, **kwargs)
|
||||
numbers = (item.strip() for item in substituted.split(','))
|
||||
numbers = [item for item in numbers if item]
|
||||
integers = [int(item) for item in numbers if item.isdigit()]
|
||||
|
||||
if not ignore_errors and len(integers) != len(numbers):
|
||||
raise SafeCancellation(
|
||||
"Couldn't parse the provided selection!\n"
|
||||
"Please provide comma separated numbers and ranges, e.g. `1, 5, 6-9`."
|
||||
)
|
||||
|
||||
return integers
|
||||
|
||||
|
||||
def msg_string(msg, mask_link=False, line_break=False, tz=None, clean=True):
|
||||
"""
|
||||
Format a message into a string with various information, such as:
|
||||
the timestamp of the message, author, message content, and attachments.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
msg: Message
|
||||
The message to format.
|
||||
mask_link: bool
|
||||
Whether to mask the URLs of any attachments.
|
||||
line_break: bool
|
||||
Whether a line break should be used in the string.
|
||||
tz: Timezone
|
||||
The timezone to use in the formatted message.
|
||||
clean: bool
|
||||
Whether to use the clean content of the original message.
|
||||
|
||||
Returns: str
|
||||
A formatted string containing various information:
|
||||
User timezone, message author, message content, attachments
|
||||
"""
|
||||
timestr = "%I:%M %p, %d/%m/%Y"
|
||||
if tz:
|
||||
time = iso8601.parse_date(msg.timestamp.isoformat()).astimezone(tz).strftime(timestr)
|
||||
else:
|
||||
time = msg.timestamp.strftime(timestr)
|
||||
user = str(msg.author)
|
||||
attach_list = [attach["url"] for attach in msg.attachments if "url" in attach]
|
||||
if mask_link:
|
||||
attach_list = ["[Link]({})".format(url) for url in attach_list]
|
||||
attachments = "\nAttachments: {}".format(", ".join(attach_list)) if attach_list else ""
|
||||
return "`[{time}]` **{user}:** {line_break}{message} {attachments}".format(
|
||||
time=time,
|
||||
user=user,
|
||||
line_break="\n" if line_break else "",
|
||||
message=msg.clean_content if clean else msg.content,
|
||||
attachments=attachments
|
||||
)
|
||||
|
||||
|
||||
def convdatestring(datestring):
|
||||
"""
|
||||
Convert a date string into a datetime.timedelta object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
datestring: str
|
||||
The string to convert to a datetime.timedelta object.
|
||||
|
||||
Returns: datetime.timedelta
|
||||
A datetime.timedelta object formed from the string provided.
|
||||
"""
|
||||
datestring = datestring.strip(' ,')
|
||||
datearray = []
|
||||
funcs = {'d': lambda x: x * 24 * 60 * 60,
|
||||
'h': lambda x: x * 60 * 60,
|
||||
'm': lambda x: x * 60,
|
||||
's': lambda x: x}
|
||||
currentnumber = ''
|
||||
for char in datestring:
|
||||
if char.isdigit():
|
||||
currentnumber += char
|
||||
else:
|
||||
if currentnumber == '':
|
||||
continue
|
||||
datearray.append((int(currentnumber), char))
|
||||
currentnumber = ''
|
||||
seconds = 0
|
||||
if currentnumber:
|
||||
seconds += int(currentnumber)
|
||||
for i in datearray:
|
||||
if i[1] in funcs:
|
||||
seconds += funcs[i[1]](i[0])
|
||||
return datetime.timedelta(seconds=seconds)
|
||||
|
||||
|
||||
class _rawChannel(discord.abc.Messageable):
|
||||
"""
|
||||
Raw messageable class representing an arbitrary channel,
|
||||
not necessarially seen by the gateway.
|
||||
"""
|
||||
def __init__(self, state, id):
|
||||
self._state = state
|
||||
self.id = id
|
||||
|
||||
async def _get_channel(self):
|
||||
return discord.Object(self.id)
|
||||
|
||||
|
||||
async def mail(client: discord.Client, channelid: int, **msg_args):
|
||||
"""
|
||||
Mails a message to a channelid which may be invisible to the gateway.
|
||||
|
||||
Parameters:
|
||||
client: discord.Client
|
||||
The client to use for mailing.
|
||||
Must at least have static authentication and have a valid `_connection`.
|
||||
channelid: int
|
||||
The channel id to mail to.
|
||||
msg_args: Any
|
||||
Message keyword arguments which are passed transparently to `_rawChannel.send(...)`.
|
||||
"""
|
||||
# Create the raw channel
|
||||
channel = _rawChannel(client._connection, channelid)
|
||||
return await channel.send(**msg_args)
|
||||
|
||||
|
||||
def emb_add_fields(embed, emb_fields):
|
||||
"""
|
||||
Append embed fields to an embed.
|
||||
Parameters
|
||||
----------
|
||||
embed: discord.Embed
|
||||
The embed to add the field to.
|
||||
emb_fields: tuple
|
||||
The values to add to a field.
|
||||
name: str
|
||||
The name of the field.
|
||||
value: str
|
||||
The value of the field.
|
||||
inline: bool
|
||||
Whether the embed field should be inline or not.
|
||||
"""
|
||||
for field in emb_fields:
|
||||
embed.add_field(name=str(field[0]), value=str(field[1]), inline=bool(field[2]))
|
||||
|
||||
|
||||
def join_list(string, nfs=False):
|
||||
"""
|
||||
Join a list together, separated with commas, plus add "and" to the beginning of the last value.
|
||||
Parameters
|
||||
----------
|
||||
string: list
|
||||
The list to join together.
|
||||
nfs: bool
|
||||
(no fullstops)
|
||||
Whether to exclude fullstops/periods from the output messages.
|
||||
If not provided, fullstops will be appended to the output.
|
||||
"""
|
||||
if len(string) > 1:
|
||||
return "{}{} and {}{}".format((", ").join(string[:-1]),
|
||||
"," if len(string) > 2 else "", string[-1], "" if nfs else ".")
|
||||
else:
|
||||
return "{}{}".format("".join(string), "" if nfs else ".")
|
||||
|
||||
|
||||
def format_activity(user):
|
||||
"""
|
||||
Format a user's activity string, depending on the type of activity.
|
||||
Currently supported types are:
|
||||
- Nothing
|
||||
- Custom status
|
||||
- Playing (with rich presence support)
|
||||
- Streaming
|
||||
- Listening (with rich presence support)
|
||||
- Watching
|
||||
- Unknown
|
||||
Parameters
|
||||
----------
|
||||
user: discord.Member
|
||||
The user to format the status of.
|
||||
If the user has no activity, "Nothing" will be returned.
|
||||
|
||||
Returns: str
|
||||
A formatted string with various information about the user's current activity like the name,
|
||||
and any extra information about the activity (such as current song artists for Spotify)
|
||||
"""
|
||||
if not user.activity:
|
||||
return "Nothing"
|
||||
|
||||
AT = user.activity.type
|
||||
a = user.activity
|
||||
if str(AT) == "ActivityType.custom":
|
||||
return "Status: {}".format(a)
|
||||
|
||||
if str(AT) == "ActivityType.playing":
|
||||
string = "Playing {}".format(a.name)
|
||||
try:
|
||||
string += " ({})".format(a.details)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return string
|
||||
|
||||
if str(AT) == "ActivityType.streaming":
|
||||
return "Streaming {}".format(a.name)
|
||||
|
||||
if str(AT) == "ActivityType.listening":
|
||||
try:
|
||||
string = "Listening to `{}`".format(a.title)
|
||||
if len(a.artists) > 1:
|
||||
string += " by {}".format(join_list(string=a.artists))
|
||||
else:
|
||||
string += " by **{}**".format(a.artist)
|
||||
except Exception:
|
||||
string = "Listening to `{}`".format(a.name)
|
||||
return string
|
||||
|
||||
if str(AT) == "ActivityType.watching":
|
||||
return "Watching `{}`".format(a.name)
|
||||
|
||||
if str(AT) == "ActivityType.unknown":
|
||||
return "Unknown"
|
||||
|
||||
|
||||
def shard_of(shard_count: int, guildid: int):
|
||||
"""
|
||||
Calculate the shard number of a given guild.
|
||||
"""
|
||||
return (guildid >> 22) % shard_count if shard_count and shard_count > 0 else 0
|
||||
|
||||
|
||||
def jumpto(guildid: int, channeldid: int, messageid: int):
|
||||
"""
|
||||
Build a jump link for a message given its location.
|
||||
"""
|
||||
return 'https://discord.com/channels/{}/{}/{}'.format(
|
||||
guildid,
|
||||
channeldid,
|
||||
messageid
|
||||
)
|
||||
|
||||
|
||||
class DotDict(dict):
|
||||
"""
|
||||
Dict-type allowing dot access to keys.
|
||||
"""
|
||||
__getattr__ = dict.get
|
||||
__setattr__ = dict.__setitem__
|
||||
__delattr__ = dict.__delitem__
|
||||
|
||||
|
||||
class FieldEnum(str, Enum):
|
||||
"""
|
||||
String enum with description conforming to the ISQLQuote protocol.
|
||||
Allows processing by psycog
|
||||
"""
|
||||
def __new__(cls, value, desc):
|
||||
obj = str.__new__(cls, value)
|
||||
obj._value_ = value
|
||||
obj.desc = desc
|
||||
return obj
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s.%s>' % (self.__class__.__name__, self.name)
|
||||
|
||||
def __bool__(self):
|
||||
return True
|
||||
|
||||
def __conform__(self, proto):
|
||||
return QuotedString(self.value)
|
||||
|
||||
|
||||
def utc_now():
|
||||
"""
|
||||
Return the current timezone-aware utc timestamp.
|
||||
"""
|
||||
return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)
|
||||
|
||||
|
||||
def multiple_replace(string, rep_dict):
|
||||
if rep_dict:
|
||||
pattern = re.compile(
|
||||
"|".join([re.escape(k) for k in sorted(rep_dict, key=len, reverse=True)]),
|
||||
flags=re.DOTALL
|
||||
)
|
||||
return pattern.sub(lambda x: str(rep_dict[x.group(0)]), string)
|
||||
else:
|
||||
return string
|
||||
92
bot/pending-rewrite/utils/ratelimits.py
Normal file
92
bot/pending-rewrite/utils/ratelimits.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import time
|
||||
from cmdClient.lib import SafeCancellation
|
||||
|
||||
from cachetools import TTLCache
|
||||
|
||||
|
||||
class BucketFull(Exception):
|
||||
"""
|
||||
Throw when a requested Bucket is already full
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class BucketOverFull(BucketFull):
|
||||
"""
|
||||
Throw when a requested Bucket is overfull
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Bucket:
|
||||
__slots__ = ('max_level', 'empty_time', 'leak_rate', '_level', '_last_checked', '_last_full')
|
||||
|
||||
def __init__(self, max_level, empty_time):
|
||||
self.max_level = max_level
|
||||
self.empty_time = empty_time
|
||||
self.leak_rate = max_level / empty_time
|
||||
|
||||
self._level = 0
|
||||
self._last_checked = time.time()
|
||||
|
||||
self._last_full = False
|
||||
|
||||
@property
|
||||
def overfull(self):
|
||||
self._leak()
|
||||
return self._level > self.max_level
|
||||
|
||||
def _leak(self):
|
||||
if self._level:
|
||||
elapsed = time.time() - self._last_checked
|
||||
self._level = max(0, self._level - (elapsed * self.leak_rate))
|
||||
|
||||
self._last_checked = time.time()
|
||||
|
||||
def request(self):
|
||||
self._leak()
|
||||
if self._level + 1 > self.max_level + 1:
|
||||
raise BucketOverFull
|
||||
elif self._level + 1 > self.max_level:
|
||||
self._level += 1
|
||||
if self._last_full:
|
||||
raise BucketOverFull
|
||||
else:
|
||||
self._last_full = True
|
||||
raise BucketFull
|
||||
else:
|
||||
self._last_full = False
|
||||
self._level += 1
|
||||
|
||||
|
||||
class RateLimit:
|
||||
def __init__(self, max_level, empty_time, error=None, cache=TTLCache(1000, 60 * 60)):
|
||||
self.max_level = max_level
|
||||
self.empty_time = empty_time
|
||||
|
||||
self.error = error or "Too many requests, please slow down!"
|
||||
self.buckets = cache
|
||||
|
||||
def request_for(self, key):
|
||||
if not (bucket := self.buckets.get(key, None)):
|
||||
bucket = self.buckets[key] = Bucket(self.max_level, self.empty_time)
|
||||
|
||||
try:
|
||||
bucket.request()
|
||||
except BucketOverFull:
|
||||
raise SafeCancellation(details="Bucket overflow")
|
||||
except BucketFull:
|
||||
raise SafeCancellation(self.error, details="Bucket full")
|
||||
|
||||
def ward(self, member=True, key=None):
|
||||
"""
|
||||
Command ratelimit decorator.
|
||||
"""
|
||||
key = key or ((lambda ctx: (ctx.guild.id, ctx.author.id)) if member else (lambda ctx: ctx.author.id))
|
||||
|
||||
def decorator(func):
|
||||
async def wrapper(ctx, *args, **kwargs):
|
||||
self.request_for(key(ctx))
|
||||
return await func(ctx, *args, **kwargs)
|
||||
return wrapper
|
||||
return decorator
|
||||
427
bot/pending-rewrite/utils/seekers.py
Normal file
427
bot/pending-rewrite/utils/seekers.py
Normal file
@@ -0,0 +1,427 @@
|
||||
import asyncio
|
||||
import discord
|
||||
|
||||
from LionContext import LionContext as Context
|
||||
from cmdClient.lib import InvalidContext, UserCancelled, ResponseTimedOut, SafeCancellation
|
||||
from . import interactive as _interactive
|
||||
|
||||
|
||||
@Context.util
|
||||
async def find_role(ctx, userstr, create=False, interactive=False, collection=None, allow_notfound=True):
|
||||
"""
|
||||
Find a guild role given a partial matching string,
|
||||
allowing custom role collections and several behavioural switches.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
userstr: str
|
||||
String obtained from a user, expected to partially match a role in the collection.
|
||||
The string will be tested against both the id and the name of the role.
|
||||
create: bool
|
||||
Whether to offer to create the role if it does not exist.
|
||||
The bot will only offer to create the role if it has the `manage_channels` permission.
|
||||
interactive: bool
|
||||
Whether to offer the user a list of roles to choose from,
|
||||
or pick the first matching role.
|
||||
collection: List[Union[discord.Role, discord.Object]]
|
||||
Collection of roles to search amongst.
|
||||
If none, uses the guild role list.
|
||||
allow_notfound: bool
|
||||
Whether to return `None` when there are no matches, instead of raising `SafeCancellation`.
|
||||
Overriden by `create`, if it is set.
|
||||
|
||||
Returns
|
||||
-------
|
||||
discord.Role:
|
||||
If a valid role is found.
|
||||
None:
|
||||
If no valid role has been found.
|
||||
|
||||
Raises
|
||||
------
|
||||
cmdClient.lib.UserCancelled:
|
||||
If the user cancels interactive role selection.
|
||||
cmdClient.lib.ResponseTimedOut:
|
||||
If the user fails to respond to interactive role selection within `60` seconds`
|
||||
cmdClient.lib.SafeCancellation:
|
||||
If `allow_notfound` is `False`, and the search returned no matches.
|
||||
"""
|
||||
# Handle invalid situations and input
|
||||
if not ctx.guild:
|
||||
raise InvalidContext("Attempt to use find_role outside of a guild.")
|
||||
|
||||
if userstr == "":
|
||||
raise ValueError("User string passed to find_role was empty.")
|
||||
|
||||
# Create the collection to search from args or guild roles
|
||||
collection = collection if collection is not None else ctx.guild.roles
|
||||
|
||||
# If the unser input was a number or possible role mention, get it out
|
||||
userstr = userstr.strip()
|
||||
roleid = userstr.strip('<#@&!> ')
|
||||
roleid = int(roleid) if roleid.isdigit() else None
|
||||
searchstr = userstr.lower()
|
||||
|
||||
# Find the role
|
||||
role = None
|
||||
|
||||
# Check method to determine whether a role matches
|
||||
def check(role):
|
||||
return (role.id == roleid) or (searchstr in role.name.lower())
|
||||
|
||||
# Get list of matching roles
|
||||
roles = list(filter(check, collection))
|
||||
|
||||
if len(roles) == 0:
|
||||
# Nope
|
||||
role = None
|
||||
elif len(roles) == 1:
|
||||
# Select our lucky winner
|
||||
role = roles[0]
|
||||
else:
|
||||
# We have multiple matching roles!
|
||||
if interactive:
|
||||
# Interactive prompt with the list of roles, handle `Object`s
|
||||
role_names = [
|
||||
role.name if isinstance(role, discord.Role) else str(role.id) for role in roles
|
||||
]
|
||||
|
||||
try:
|
||||
selected = await ctx.selector(
|
||||
"`{}` roles found matching `{}`!".format(len(roles), userstr),
|
||||
role_names,
|
||||
timeout=60
|
||||
)
|
||||
except UserCancelled:
|
||||
raise UserCancelled("User cancelled role selection.") from None
|
||||
except ResponseTimedOut:
|
||||
raise ResponseTimedOut("Role selection timed out.") from None
|
||||
|
||||
role = roles[selected]
|
||||
else:
|
||||
# Just select the first one
|
||||
role = roles[0]
|
||||
|
||||
# Handle non-existence of the role
|
||||
if role is None:
|
||||
msgstr = "Couldn't find a role matching `{}`!".format(userstr)
|
||||
if create:
|
||||
# Inform the user
|
||||
msg = await ctx.error_reply(msgstr)
|
||||
if ctx.guild.me.guild_permissions.manage_roles:
|
||||
# Offer to create it
|
||||
resp = await ctx.ask("Would you like to create this role?", timeout=30)
|
||||
if resp:
|
||||
# They accepted, create the role
|
||||
# Before creation, check if the role name is too long
|
||||
if len(userstr) > 100:
|
||||
await ctx.error_reply("Could not create a role with a name over 100 characters long!")
|
||||
else:
|
||||
role = await ctx.guild.create_role(
|
||||
name=userstr,
|
||||
reason="Interactive role creation for {} (uid:{})".format(ctx.author, ctx.author.id)
|
||||
)
|
||||
await msg.delete()
|
||||
await ctx.reply("You have created the role `{}`!".format(userstr))
|
||||
|
||||
# If we still don't have a role, cancel unless allow_notfound is set
|
||||
if role is None and not allow_notfound:
|
||||
raise SafeCancellation
|
||||
elif not allow_notfound:
|
||||
raise SafeCancellation(msgstr)
|
||||
else:
|
||||
await ctx.error_reply(msgstr)
|
||||
|
||||
return role
|
||||
|
||||
|
||||
@Context.util
|
||||
async def find_channel(ctx, userstr, interactive=False, collection=None, chan_type=None, type_name=None):
|
||||
"""
|
||||
Find a guild channel given a partial matching string,
|
||||
allowing custom channel collections and several behavioural switches.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
userstr: str
|
||||
String obtained from a user, expected to partially match a channel in the collection.
|
||||
The string will be tested against both the id and the name of the channel.
|
||||
interactive: bool
|
||||
Whether to offer the user a list of channels to choose from,
|
||||
or pick the first matching channel.
|
||||
collection: List(discord.Channel)
|
||||
Collection of channels to search amongst.
|
||||
If none, uses the full guild channel list.
|
||||
chan_type: discord.ChannelType
|
||||
Type of channel to restrict the collection to.
|
||||
type_name: str
|
||||
Optional name to use for the channel type if it is not found.
|
||||
Used particularly with custom collections.
|
||||
|
||||
Returns
|
||||
-------
|
||||
discord.Channel:
|
||||
If a valid channel is found.
|
||||
None:
|
||||
If no valid channel has been found.
|
||||
|
||||
Raises
|
||||
------
|
||||
cmdClient.lib.UserCancelled:
|
||||
If the user cancels interactive channel selection.
|
||||
cmdClient.lib.ResponseTimedOut:
|
||||
If the user fails to respond to interactive channel selection within `60` seconds`
|
||||
"""
|
||||
# Handle invalid situations and input
|
||||
if not ctx.guild:
|
||||
raise InvalidContext("Attempt to use find_channel outside of a guild.")
|
||||
|
||||
if userstr == "":
|
||||
raise ValueError("User string passed to find_channel was empty.")
|
||||
|
||||
# Create the collection to search from args or guild channels
|
||||
collection = collection if collection else ctx.guild.channels
|
||||
if chan_type is not None:
|
||||
if chan_type == discord.ChannelType.text:
|
||||
# Hack to support news channels as text channels
|
||||
collection = [chan for chan in collection if isinstance(chan, discord.TextChannel)]
|
||||
else:
|
||||
collection = [chan for chan in collection if chan.type == chan_type]
|
||||
|
||||
# If the user input was a number or possible channel mention, extract it
|
||||
chanid = userstr.strip('<#@&!>')
|
||||
chanid = int(chanid) if chanid.isdigit() else None
|
||||
searchstr = userstr.lower()
|
||||
|
||||
# Find the channel
|
||||
chan = None
|
||||
|
||||
# Check method to determine whether a channel matches
|
||||
def check(chan):
|
||||
return (chan.id == chanid) or (searchstr in chan.name.lower())
|
||||
|
||||
# Get list of matching roles
|
||||
channels = list(filter(check, collection))
|
||||
|
||||
if len(channels) == 0:
|
||||
# Nope
|
||||
chan = None
|
||||
elif len(channels) == 1:
|
||||
# Select our lucky winner
|
||||
chan = channels[0]
|
||||
else:
|
||||
# We have multiple matching channels!
|
||||
if interactive:
|
||||
# Interactive prompt with the list of channels
|
||||
chan_names = [chan.name for chan in channels]
|
||||
|
||||
try:
|
||||
selected = await ctx.selector(
|
||||
"`{}` channels found matching `{}`!".format(len(channels), userstr),
|
||||
chan_names,
|
||||
timeout=60
|
||||
)
|
||||
except UserCancelled:
|
||||
raise UserCancelled("User cancelled channel selection.") from None
|
||||
except ResponseTimedOut:
|
||||
raise ResponseTimedOut("Channel selection timed out.") from None
|
||||
|
||||
chan = channels[selected]
|
||||
else:
|
||||
# Just select the first one
|
||||
chan = channels[0]
|
||||
|
||||
if chan is None:
|
||||
typestr = type_name
|
||||
addendum = ""
|
||||
if chan_type and not type_name:
|
||||
chan_type_strings = {
|
||||
discord.ChannelType.category: "category",
|
||||
discord.ChannelType.text: "text channel",
|
||||
discord.ChannelType.voice: "voice channel",
|
||||
discord.ChannelType.stage_voice: "stage channel",
|
||||
}
|
||||
typestr = chan_type_strings.get(chan_type, None)
|
||||
if typestr and chanid:
|
||||
actual = ctx.guild.get_channel(chanid)
|
||||
if actual and actual.type in chan_type_strings:
|
||||
addendum = "\n{} appears to be a {} instead.".format(
|
||||
actual.mention,
|
||||
chan_type_strings[actual.type]
|
||||
)
|
||||
typestr = typestr or "channel"
|
||||
|
||||
await ctx.error_reply("Couldn't find a {} matching `{}`!{}".format(typestr, userstr, addendum))
|
||||
|
||||
return chan
|
||||
|
||||
|
||||
@Context.util
|
||||
async def find_member(ctx, userstr, interactive=False, collection=None, silent=False):
|
||||
"""
|
||||
Find a guild member given a partial matching string,
|
||||
allowing custom member collections.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
userstr: str
|
||||
String obtained from a user, expected to partially match a member in the collection.
|
||||
The string will be tested against both the userid, full user name and user nickname.
|
||||
interactive: bool
|
||||
Whether to offer the user a list of members to choose from,
|
||||
or pick the first matching channel.
|
||||
collection: List(discord.Member)
|
||||
Collection of members to search amongst.
|
||||
If none, uses the full guild member list.
|
||||
silent: bool
|
||||
Whether to reply with an error when there are no matches.
|
||||
|
||||
Returns
|
||||
-------
|
||||
discord.Member:
|
||||
If a valid member is found.
|
||||
None:
|
||||
If no valid member has been found.
|
||||
|
||||
Raises
|
||||
------
|
||||
cmdClient.lib.UserCancelled:
|
||||
If the user cancels interactive member selection.
|
||||
cmdClient.lib.ResponseTimedOut:
|
||||
If the user fails to respond to interactive member selection within `60` seconds`
|
||||
"""
|
||||
# Handle invalid situations and input
|
||||
if not ctx.guild:
|
||||
raise InvalidContext("Attempt to use find_member outside of a guild.")
|
||||
|
||||
if userstr == "":
|
||||
raise ValueError("User string passed to find_member was empty.")
|
||||
|
||||
# Create the collection to search from args or guild members
|
||||
collection = collection if collection else ctx.guild.members
|
||||
|
||||
# If the user input was a number or possible member mention, extract it
|
||||
userid = userstr.strip('<#@&!>')
|
||||
userid = int(userid) if userid.isdigit() else None
|
||||
searchstr = userstr.lower()
|
||||
|
||||
# Find the member
|
||||
member = None
|
||||
|
||||
# Check method to determine whether a member matches
|
||||
def check(member):
|
||||
return (
|
||||
member.id == userid
|
||||
or searchstr in member.display_name.lower()
|
||||
or searchstr in str(member).lower()
|
||||
)
|
||||
|
||||
# Get list of matching roles
|
||||
members = list(filter(check, collection))
|
||||
|
||||
if len(members) == 0:
|
||||
# Nope
|
||||
member = None
|
||||
elif len(members) == 1:
|
||||
# Select our lucky winner
|
||||
member = members[0]
|
||||
else:
|
||||
# We have multiple matching members!
|
||||
if interactive:
|
||||
# Interactive prompt with the list of members
|
||||
member_names = [
|
||||
"{} {}".format(
|
||||
member.nick if member.nick else (member if members.count(member) > 1
|
||||
else member.name),
|
||||
("<{}>".format(member)) if member.nick else ""
|
||||
) for member in members
|
||||
]
|
||||
|
||||
try:
|
||||
selected = await ctx.selector(
|
||||
"`{}` members found matching `{}`!".format(len(members), userstr),
|
||||
member_names,
|
||||
timeout=60
|
||||
)
|
||||
except UserCancelled:
|
||||
raise UserCancelled("User cancelled member selection.") from None
|
||||
except ResponseTimedOut:
|
||||
raise ResponseTimedOut("Member selection timed out.") from None
|
||||
|
||||
member = members[selected]
|
||||
else:
|
||||
# Just select the first one
|
||||
member = members[0]
|
||||
|
||||
if member is None and not silent:
|
||||
await ctx.error_reply("Couldn't find a member matching `{}`!".format(userstr))
|
||||
|
||||
return member
|
||||
|
||||
|
||||
@Context.util
|
||||
async def find_message(ctx, msgid, chlist=None, ignore=[]):
|
||||
"""
|
||||
Searches for the given message id in the guild channels.
|
||||
|
||||
Parameters
|
||||
-------
|
||||
msgid: int
|
||||
The `id` of the message to search for.
|
||||
chlist: Optional[List[discord.TextChannel]]
|
||||
List of channels to search in.
|
||||
If `None`, searches all the text channels that the `ctx.author` can read.
|
||||
ignore: list
|
||||
A list of channelids to explicitly ignore in the search.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[discord.Message]:
|
||||
If a message is found, returns the message.
|
||||
Otherwise, returns `None`.
|
||||
"""
|
||||
if not ctx.guild:
|
||||
raise InvalidContext("Cannot use this seeker outside of a guild!")
|
||||
|
||||
msgid = int(msgid)
|
||||
|
||||
# Build the channel list to search
|
||||
if chlist is None:
|
||||
chlist = [ch for ch in ctx.guild.text_channels if ch.permissions_for(ctx.author).read_messages]
|
||||
|
||||
# Remove any channels we are ignoring
|
||||
chlist = [ch for ch in chlist if ch.id not in ignore]
|
||||
|
||||
tasks = set()
|
||||
|
||||
i = 0
|
||||
while True:
|
||||
done = set((task for task in tasks if task.done()))
|
||||
tasks = tasks.difference(done)
|
||||
|
||||
results = [task.result() for task in done]
|
||||
|
||||
result = next((result for result in results if result is not None), None)
|
||||
if result:
|
||||
[task.cancel() for task in tasks]
|
||||
return result
|
||||
|
||||
if i < len(chlist):
|
||||
task = asyncio.create_task(_search_in_channel(chlist[i], msgid))
|
||||
tasks.add(task)
|
||||
i += 1
|
||||
elif len(tasks) == 0:
|
||||
return None
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
async def _search_in_channel(channel: discord.TextChannel, msgid: int):
|
||||
if not isinstance(channel, discord.TextChannel):
|
||||
return
|
||||
try:
|
||||
message = await channel.fetch_message(msgid)
|
||||
except Exception:
|
||||
return None
|
||||
else:
|
||||
return message
|
||||
40
bot/pending-rewrite/wards.py
Normal file
40
bot/pending-rewrite/wards.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from cmdClient import check
|
||||
from cmdClient.checks import in_guild
|
||||
|
||||
from meta import client
|
||||
|
||||
from data import tables
|
||||
|
||||
|
||||
def is_guild_admin(member):
|
||||
if member.id in client.owners:
|
||||
return True
|
||||
|
||||
# First check guild admin permissions
|
||||
admin = member.guild_permissions.administrator
|
||||
|
||||
# Then check the admin role, if it is set
|
||||
if not admin:
|
||||
admin_role_id = tables.guild_config.fetch_or_create(member.guild.id).admin_role
|
||||
admin = admin_role_id and (admin_role_id in (r.id for r in member.roles))
|
||||
return admin
|
||||
|
||||
|
||||
@check(
|
||||
name="ADMIN",
|
||||
msg=("You need to be a server admin to do this!"),
|
||||
requires=[in_guild]
|
||||
)
|
||||
async def guild_admin(ctx, *args, **kwargs):
|
||||
return is_guild_admin(ctx.author)
|
||||
|
||||
|
||||
@check(
|
||||
name="MODERATOR",
|
||||
msg=("You need to be a server moderator to do this!"),
|
||||
requires=[in_guild],
|
||||
parents=(guild_admin,)
|
||||
)
|
||||
async def guild_moderator(ctx, *args, **kwargs):
|
||||
modrole = ctx.guild_settings.mod_role.value
|
||||
return (modrole and (modrole in ctx.author.roles))
|
||||
Reference in New Issue
Block a user