Core user data and leaderboard commands.
Added flexibility to data `update_where`. Added interactive utils, with improved pager. Added user data table, with caching and transactional interface. Added `topcoins` command to `Economy` Added `top` command to `Study`
This commit is contained in:
2
bot/core/__init__.py
Normal file
2
bot/core/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from . import tables
|
||||
from .user import User
|
||||
31
bot/core/tables.py
Normal file
31
bot/core/tables.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from psycopg2.extras import execute_values
|
||||
|
||||
from cachetools import TTLCache
|
||||
from data import RowTable, Table
|
||||
|
||||
|
||||
users = RowTable(
|
||||
'lions',
|
||||
('userid', 'tracked_time', 'coins'),
|
||||
'userid',
|
||||
cache=TTLCache(5000, ttl=60*5)
|
||||
)
|
||||
|
||||
|
||||
@users.save_query
|
||||
def add_coins(userid_coins):
|
||||
with users.conn:
|
||||
cursor = users.conn.cursor()
|
||||
data = execute_values(
|
||||
cursor,
|
||||
"""
|
||||
UPDATE lions
|
||||
SET coins = coins + t.diff
|
||||
FROM (VALUES %s) AS t (userid, diff)
|
||||
WHERE lions.userid = t.userid
|
||||
RETURNING *
|
||||
""",
|
||||
userid_coins,
|
||||
fetch=True
|
||||
)
|
||||
return users._make_rows(*data)
|
||||
105
bot/core/user.py
Normal file
105
bot/core/user.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from . import tables as tb
|
||||
from meta import conf, client
|
||||
|
||||
|
||||
class User:
|
||||
"""
|
||||
Class representing a "Lion", i.e. a member of the managed guild.
|
||||
Mostly acts as a transparent interface to the corresponding Row,
|
||||
but also adds some transaction caching logic to `coins`.
|
||||
"""
|
||||
__slots__ = ('userid', '_pending_coins', '_member')
|
||||
|
||||
# Users with pending transactions
|
||||
_pending = {} # userid -> User
|
||||
|
||||
# User cache. Currently users don't expire
|
||||
_users = {} # userid -> User
|
||||
|
||||
def __init__(self, userid):
|
||||
self.userid = userid
|
||||
self._pending_coins = 0
|
||||
|
||||
self._users[self.userid] = self
|
||||
|
||||
@classmethod
|
||||
def fetch(cls, userid):
|
||||
"""
|
||||
Fetch a User with the given userid.
|
||||
If they don't exist, creates them.
|
||||
If possible, retrieves the user from the user cache.
|
||||
"""
|
||||
if userid in cls._users:
|
||||
return cls._users[userid]
|
||||
else:
|
||||
tb.users.fetch_or_create(userid)
|
||||
return cls(userid)
|
||||
|
||||
@property
|
||||
def member(self):
|
||||
"""
|
||||
The discord `Member` corresponding to this user.
|
||||
May be `None` if the member is no longer in the guild or the caches aren't populated.
|
||||
Not guaranteed to be `None` if the member is not in the guild.
|
||||
"""
|
||||
if self._member is None:
|
||||
self._member = client.get_guild(conf.meta.getint('managed_guild_id')).get_member(self.userid)
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
"""
|
||||
The Row corresponding to this user.
|
||||
"""
|
||||
return tb.users.fetch(self.userid)
|
||||
|
||||
@property
|
||||
def time(self):
|
||||
"""
|
||||
Amount of time the user has spent.. studying?
|
||||
"""
|
||||
return self.data.tracked_time
|
||||
|
||||
@property
|
||||
def coins(self):
|
||||
"""
|
||||
Number of coins the user has, accounting for the pending value.
|
||||
"""
|
||||
return self.data.coins + self._pending_coins
|
||||
|
||||
def addCoins(self, amount, flush=True):
|
||||
"""
|
||||
Add coins to the user, optionally store the transaction in pending.
|
||||
"""
|
||||
self._pending_coins += amount
|
||||
if self._pending_coins != 0:
|
||||
self._pending[self.userid] = self
|
||||
else:
|
||||
self._pending.pop(self.userid, None)
|
||||
if flush:
|
||||
self.flush()
|
||||
|
||||
def flush(self):
|
||||
"""
|
||||
Flush any pending transactions to the database.
|
||||
"""
|
||||
self.sync(self)
|
||||
|
||||
@classmethod
|
||||
def sync(cls, *users):
|
||||
"""
|
||||
Flush pending transactions to the database.
|
||||
Also refreshes the Row cache for updated users.
|
||||
"""
|
||||
users = users or list(cls._pending.values())
|
||||
|
||||
if users:
|
||||
# Build userid to pending coin map
|
||||
userid_coins = [(user.userid, user._pending_coins) for user in users]
|
||||
|
||||
# Write to database
|
||||
tb.users.queries.add_coins(userid_coins)
|
||||
|
||||
# Cleanup pending users
|
||||
for user in users:
|
||||
user._pending_coins = 0
|
||||
cls._pending.pop(user.userid, None)
|
||||
@@ -1,3 +1,3 @@
|
||||
from .data import *
|
||||
from . import tables
|
||||
# from . import tables
|
||||
# from . import queries
|
||||
|
||||
@@ -6,6 +6,7 @@ from enum import Enum
|
||||
import psycopg2 as psy
|
||||
from cachetools import LRUCache
|
||||
|
||||
from utils.lib import DotDict
|
||||
from meta import log, conf
|
||||
from constants import DATA_VERSION
|
||||
from .custom_cursor import DictLoggingCursor
|
||||
@@ -49,6 +50,7 @@ class Table:
|
||||
Intended to be subclassed to provide more derivative access for specific tables.
|
||||
"""
|
||||
conn = conn
|
||||
queries = DotDict()
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
@@ -82,6 +84,12 @@ class Table:
|
||||
with self.conn:
|
||||
return upsert(self.name, *args, **kwargs)
|
||||
|
||||
def save_query(self, func):
|
||||
"""
|
||||
Decorator to add a saved query to the table.
|
||||
"""
|
||||
self.queries[func.__name__] = func
|
||||
|
||||
|
||||
class Row:
|
||||
__slots__ = ('table', 'data', '_pending')
|
||||
@@ -386,6 +394,44 @@ class fieldConstants(Enum):
|
||||
NOTNULL = "IS NOT NULL"
|
||||
|
||||
|
||||
class _updateField:
|
||||
__slots__ = ()
|
||||
_EMPTY = object() # Return value for `value` indicating no value should be added
|
||||
|
||||
def key_field(self, key):
|
||||
raise NotImplementedError
|
||||
|
||||
def value_field(self, key):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class UpdateValue(_updateField):
|
||||
__slots__ = ('key_str', 'value')
|
||||
|
||||
def __init__(self, key_str, value=_updateField._EMPTY):
|
||||
self.key_str = key_str
|
||||
self.value = value
|
||||
|
||||
def key_field(self, key):
|
||||
return self.key_str.format(key=key, value=_replace_char, replace=_replace_char)
|
||||
|
||||
def value_field(self, key):
|
||||
return self.value
|
||||
|
||||
|
||||
class UpdateValueAdd(_updateField):
|
||||
__slots__ = ('value',)
|
||||
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def key_field(self, key):
|
||||
return "{key} = {key} + {replace}".format(key=key, replace=_replace_char)
|
||||
|
||||
def value_field(self, key):
|
||||
return self.value
|
||||
|
||||
|
||||
def _format_conditions(conditions):
|
||||
"""
|
||||
Formats a dictionary of conditions into a string suitable for 'WHERE' clauses.
|
||||
@@ -443,8 +489,17 @@ def _format_updatestr(valuedict):
|
||||
"""
|
||||
if not valuedict:
|
||||
return ("", tuple())
|
||||
keys, values = zip(*valuedict.items())
|
||||
|
||||
set_str = ", ".join("{} = {}".format(key, _replace_char) for key in keys)
|
||||
key_fields = []
|
||||
values = []
|
||||
for key, value in valuedict.items():
|
||||
if isinstance(value, _updateField):
|
||||
key_fields.append(value.key_field(key))
|
||||
v = value.value_field(key)
|
||||
if v is not _updateField._EMPTY:
|
||||
values.append(value.value_field(key))
|
||||
else:
|
||||
key_fields.append("{} = {}".format(key, _replace_char))
|
||||
values.append(value)
|
||||
|
||||
return (set_str, values)
|
||||
return (', '.join(key_fields), values)
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
from .data import RowTable, Table
|
||||
|
||||
raw_users = Table('Users')
|
||||
users = RowTable(
|
||||
'users',
|
||||
('userid', 'tracked_time', 'coins'),
|
||||
'userid',
|
||||
)
|
||||
@@ -4,4 +4,6 @@ import meta
|
||||
meta.logger.setLevel(logging.DEBUG)
|
||||
logging.getLogger("discord").setLevel(logging.INFO)
|
||||
|
||||
from utils import interactive # noqa
|
||||
|
||||
import main # noqa
|
||||
|
||||
@@ -2,6 +2,8 @@ from meta import client, conf, log
|
||||
|
||||
import data # noqa
|
||||
|
||||
import core # noqa
|
||||
|
||||
import modules # noqa
|
||||
|
||||
# Initialise all modules
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from discord import Intents
|
||||
from cmdClient.cmdClient import cmdClient
|
||||
|
||||
from .config import Conf
|
||||
@@ -9,5 +10,5 @@ conf = Conf(CONFIG_FILE)
|
||||
|
||||
# Initialise client
|
||||
owners = [int(owner) for owner in conf.bot.getlist('owners')]
|
||||
client = cmdClient(prefix=conf.bot['prefix'], owners=owners)
|
||||
client = cmdClient(prefix=conf.bot['prefix'], owners=owners, intents=Intents.all())
|
||||
client.conf = conf
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
from .sysadmin import *
|
||||
from .economy import *
|
||||
from .study import *
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
from . import module
|
||||
from . import commands
|
||||
|
||||
101
bot/modules/economy/commands.py
Normal file
101
bot/modules/economy/commands.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from core import User
|
||||
from core.tables import users
|
||||
|
||||
from utils import interactive # noqa
|
||||
|
||||
from .module import module
|
||||
|
||||
|
||||
first_emoji = "🥇"
|
||||
second_emoji = "🥈"
|
||||
third_emoji = "🥉"
|
||||
|
||||
|
||||
# TODO: in_guild ward
|
||||
@module.cmd(
|
||||
"topcoin",
|
||||
short_help="View the LionCoin leaderboard.",
|
||||
aliases=('topc', 'ctop')
|
||||
)
|
||||
async def cmd_topcoin(ctx):
|
||||
"""
|
||||
Usage``:
|
||||
{prefix}topcoin
|
||||
{prefix}topcoin 100
|
||||
Description:
|
||||
Display the LionCoin leaderboard, or top 100.
|
||||
|
||||
Use the paging reactions or send `p<n>` to switch pages (e.g. `p11` to switch to page 11).
|
||||
"""
|
||||
# Handle args
|
||||
if ctx.args and not ctx.args == "100":
|
||||
return await ctx.error_reply(
|
||||
"**Usage:**`{prefix}topcoin` or `{prefix}topcoin100`.".format(prefix=ctx.client.prefix)
|
||||
)
|
||||
top100 = ctx.args == "100"
|
||||
|
||||
# Flush any pending coin transactions
|
||||
User.sync()
|
||||
|
||||
# Fetch the leaderboard
|
||||
user_data = users.select_where(
|
||||
select_columns=('userid', 'coins'),
|
||||
_extra="ORDER BY coins DESC " + ("LIMIT 100" if top100 else "")
|
||||
)
|
||||
|
||||
# Quit early if the leaderboard is empty
|
||||
if not user_data:
|
||||
return await ctx.reply("No leaderboard entries yet!")
|
||||
|
||||
# Extract entries
|
||||
author_index = None
|
||||
entries = []
|
||||
for i, (userid, coins) in enumerate(user_data):
|
||||
member = ctx.guild.get_member(userid)
|
||||
name = member.display_name if member else str(userid)
|
||||
name = name.replace('*', ' ').replace('_', ' ')
|
||||
|
||||
num_str = "{}.".format(i+1)
|
||||
|
||||
coin_str = "{} LC".format(coins)
|
||||
|
||||
if ctx.author.id == userid:
|
||||
author_index = i
|
||||
|
||||
entries.append((num_str, name, coin_str))
|
||||
|
||||
# Extract blocks
|
||||
blocks = [entries[i:i+20] for i in range(0, len(entries), 20)]
|
||||
block_count = len(blocks)
|
||||
|
||||
# Build strings
|
||||
header = "LionCoin Top 100" if top100 else "LionCoin Leaderboard"
|
||||
if block_count > 1:
|
||||
header += " (Page {{page}}/{})".format(block_count)
|
||||
|
||||
# Build pages
|
||||
pages = []
|
||||
for i, block in enumerate(blocks):
|
||||
max_num_l, max_name_l, max_coin_l = [max(len(e[i]) for e in block) for i in (0, 1, 2)]
|
||||
body = '\n'.join(
|
||||
"{:>{}} {:<{}} \t {:>{}} {} {}".format(
|
||||
entry[0], max_num_l,
|
||||
entry[1], max_name_l + 2,
|
||||
entry[2], max_coin_l + 1,
|
||||
first_emoji if i == 0 and j == 0 else (
|
||||
second_emoji if i == 0 and j == 1 else (
|
||||
third_emoji if i == 0 and j == 2 else ''
|
||||
)
|
||||
),
|
||||
"⮜" if author_index is not None and author_index == i * 20 + j else ""
|
||||
)
|
||||
for j, entry in enumerate(block)
|
||||
)
|
||||
title = header.format(page=i+1)
|
||||
line = '='*len(title)
|
||||
pages.append(
|
||||
"```md\n{}\n{}\n{}```".format(title, line, body)
|
||||
)
|
||||
|
||||
# Finally, page the results
|
||||
await ctx.pager(pages, start_at=(author_index or 0)//20 if not top100 else 0)
|
||||
4
bot/modules/economy/module.py
Normal file
4
bot/modules/economy/module.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from cmdClient import Module
|
||||
|
||||
|
||||
module = Module("Economy")
|
||||
@@ -0,0 +1,2 @@
|
||||
from .module import module
|
||||
from . import commands
|
||||
|
||||
111
bot/modules/study/commands.py
Normal file
111
bot/modules/study/commands.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import datetime as dt
|
||||
|
||||
from core import User
|
||||
from core.tables import users
|
||||
|
||||
from utils import interactive # noqa
|
||||
|
||||
from .module import module
|
||||
|
||||
|
||||
first_emoji = "🥇"
|
||||
second_emoji = "🥈"
|
||||
third_emoji = "🥉"
|
||||
|
||||
|
||||
# TODO: in_guild ward
|
||||
@module.cmd(
|
||||
"top",
|
||||
short_help="View the Study Time leaderboard.",
|
||||
aliases=('ttop', 'toptime')
|
||||
)
|
||||
async def cmd_top(ctx):
|
||||
"""
|
||||
Usage``:
|
||||
{prefix}top
|
||||
{prefix}top 100
|
||||
Description:
|
||||
Display the study time leaderboard, or the top 100.
|
||||
|
||||
Use the paging reactions or send `p<n>` to switch pages (e.g. `p11` to switch to page 11).
|
||||
"""
|
||||
# Handle args
|
||||
if ctx.args and not ctx.args == "100":
|
||||
return await ctx.error_reply(
|
||||
"**Usage:**`{prefix}top` or `{prefix}top100`.".format(prefix=ctx.client.prefix)
|
||||
)
|
||||
top100 = ctx.args == "100"
|
||||
|
||||
# Flush any pending coin transactions
|
||||
User.sync()
|
||||
|
||||
# Fetch the leaderboard
|
||||
user_data = users.select_where(
|
||||
select_columns=('userid', 'tracked_time'),
|
||||
_extra="ORDER BY tracked_time DESC " + ("LIMIT 100" if top100 else "")
|
||||
)
|
||||
|
||||
# Quit early if the leaderboard is empty
|
||||
if not user_data:
|
||||
return await ctx.reply("No leaderboard entries yet!")
|
||||
|
||||
# Extract entries
|
||||
author_index = None
|
||||
entries = []
|
||||
for i, (userid, time) in enumerate(user_data):
|
||||
member = ctx.guild.get_member(userid)
|
||||
name = member.display_name if member else str(userid)
|
||||
name = name.replace('*', ' ').replace('_', ' ')
|
||||
|
||||
num_str = "{}.".format(i+1)
|
||||
|
||||
hours = time // 3600
|
||||
minutes = time // 60 % 60
|
||||
seconds = time % 60
|
||||
|
||||
time_str = "{}:{:02}:{:02}".format(
|
||||
hours,
|
||||
minutes,
|
||||
seconds
|
||||
)
|
||||
|
||||
if ctx.author.id == userid:
|
||||
author_index = i
|
||||
|
||||
entries.append((num_str, name, time_str))
|
||||
|
||||
# Extract blocks
|
||||
blocks = [entries[i:i+20] for i in range(0, len(entries), 20)]
|
||||
block_count = len(blocks)
|
||||
|
||||
# Build strings
|
||||
header = "Study Time Top 100" if top100 else "Study Time Leaderboard"
|
||||
if block_count > 1:
|
||||
header += " (Page {{page}}/{})".format(block_count)
|
||||
|
||||
# Build pages
|
||||
pages = []
|
||||
for i, block in enumerate(blocks):
|
||||
max_num_l, max_name_l, max_time_l = [max(len(e[i]) for e in block) for i in (0, 1, 2)]
|
||||
body = '\n'.join(
|
||||
"{:>{}} {:<{}} \t {:>{}} {} {}".format(
|
||||
entry[0], max_num_l,
|
||||
entry[1], max_name_l + 2,
|
||||
entry[2], max_time_l + 1,
|
||||
first_emoji if i == 0 and j == 0 else (
|
||||
second_emoji if i == 0 and j == 1 else (
|
||||
third_emoji if i == 0 and j == 2 else ''
|
||||
)
|
||||
),
|
||||
"⮜" if author_index is not None and author_index == i * 20 + j else ""
|
||||
)
|
||||
for j, entry in enumerate(block)
|
||||
)
|
||||
title = header.format(page=i+1)
|
||||
line = '='*len(title)
|
||||
pages.append(
|
||||
"```md\n{}\n{}\n{}```".format(title, line, body)
|
||||
)
|
||||
|
||||
# Finally, page the results
|
||||
await ctx.pager(pages, start_at=(author_index or 0)//20 if not top100 else 0)
|
||||
4
bot/modules/study/module.py
Normal file
4
bot/modules/study/module.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from cmdClient import Module
|
||||
|
||||
|
||||
module = Module("Study_Stats")
|
||||
461
bot/utils/interactive.py
Normal file
461
bot/utils/interactive.py
Normal file
@@ -0,0 +1,461 @@
|
||||
import asyncio
|
||||
import discord
|
||||
from cmdClient import Context
|
||||
from cmdClient.lib import UserCancelled, ResponseTimedOut
|
||||
|
||||
from .lib import paginate_list
|
||||
|
||||
# TODO: Interactive locks
|
||||
cancel_emoji = '❌'
|
||||
number_emojis = (
|
||||
'1️⃣', '2️⃣', '3️⃣', '4️⃣', '5️⃣', '6️⃣', '7️⃣', '8️⃣', '9️⃣'
|
||||
)
|
||||
|
||||
|
||||
async def discord_shield(coro):
|
||||
try:
|
||||
await coro
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
|
||||
@Context.util
|
||||
async def cancellable(ctx, msg, add_reaction=True, cancel_message=None, timeout=300):
|
||||
"""
|
||||
Add a cancellation reaction to the given message.
|
||||
Pressing the reaction triggers cancellation of the original context, and a UserCancelled-style error response.
|
||||
"""
|
||||
# TODO: Not consistent with the exception driven flow, make a decision here?
|
||||
# Add reaction
|
||||
if add_reaction and cancel_emoji not in (str(r.emoji) for r in msg.reactions):
|
||||
try:
|
||||
await msg.add_reaction(cancel_emoji)
|
||||
except discord.HTTPException:
|
||||
return
|
||||
|
||||
# Define cancellation function
|
||||
async def _cancel():
|
||||
try:
|
||||
await ctx.client.wait_for(
|
||||
'reaction_add',
|
||||
timeout=timeout,
|
||||
check=lambda r, u: (u == ctx.author
|
||||
and r.message == msg
|
||||
and str(r.emoji) == cancel_emoji)
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
else:
|
||||
await ctx.client.active_command_response_cleaner(ctx)
|
||||
if cancel_message:
|
||||
await ctx.error_reply(cancel_message)
|
||||
else:
|
||||
try:
|
||||
await ctx.msg.add_reaction(cancel_emoji)
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
[task.cancel() for task in ctx.tasks]
|
||||
|
||||
# Launch cancellation task
|
||||
task = asyncio.create_task(_cancel())
|
||||
ctx.tasks.append(task)
|
||||
return task
|
||||
|
||||
|
||||
@Context.util
|
||||
async def listen_for(ctx, allowed_input=None, timeout=120, lower=True, check=None):
|
||||
"""
|
||||
Listen for a one of a particular set of input strings,
|
||||
sent in the current channel by `ctx.author`.
|
||||
When found, return the message containing them.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
allowed_input: Union(List(str), None)
|
||||
List of strings to listen for.
|
||||
Allowed to be `None` precisely when a `check` function is also supplied.
|
||||
timeout: int
|
||||
Number of seconds to wait before timing out.
|
||||
lower: bool
|
||||
Whether to shift the allowed and message strings to lowercase before checking.
|
||||
check: Function(message) -> bool
|
||||
Alternative custom check function.
|
||||
|
||||
Returns: discord.Message
|
||||
The message that was matched.
|
||||
|
||||
Raises
|
||||
------
|
||||
cmdClient.lib.ResponseTimedOut:
|
||||
Raised when no messages matching the given criteria are detected in `timeout` seconds.
|
||||
"""
|
||||
# Generate the check if it hasn't been provided
|
||||
if not check:
|
||||
# Quick check the arguments are sane
|
||||
if not allowed_input:
|
||||
raise ValueError("allowed_input and check cannot both be None")
|
||||
|
||||
# Force a lower on the allowed inputs
|
||||
allowed_input = [s.lower() for s in allowed_input]
|
||||
|
||||
# Create the check function
|
||||
def check(message):
|
||||
result = (message.author == ctx.author)
|
||||
result = result and (message.channel == ctx.ch)
|
||||
result = result and ((message.content.lower() if lower else message.content) in allowed_input)
|
||||
return result
|
||||
|
||||
# Wait for a matching message, catch and transform the timeout
|
||||
try:
|
||||
message = await ctx.client.wait_for('message', check=check, timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
raise ResponseTimedOut("Session timed out waiting for user response.") from None
|
||||
|
||||
return message
|
||||
|
||||
|
||||
@Context.util
|
||||
async def selector(ctx, header, select_from, timeout=120, max_len=20):
|
||||
"""
|
||||
Interactive routine to prompt the `ctx.author` to select an item from a list.
|
||||
Returns the list index that was selected.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
header: str
|
||||
String to put at the top of each page of selection options.
|
||||
Intended to be information about the list the user is selecting from.
|
||||
select_from: List(str)
|
||||
The list of strings to select from.
|
||||
timeout: int
|
||||
The number of seconds to wait before throwing `ResponseTimedOut`.
|
||||
max_len: int
|
||||
The maximum number of items to display on each page.
|
||||
Decrease this if the items are long, to avoid going over the char limit.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int:
|
||||
The index of the list entry selected by the user.
|
||||
|
||||
Raises
|
||||
------
|
||||
cmdClient.lib.UserCancelled:
|
||||
Raised if the user manually cancels the selection.
|
||||
cmdClient.lib.ResponseTimedOut:
|
||||
Raised if the user fails to respond to the selector within `timeout` seconds.
|
||||
"""
|
||||
# Handle improper arguments
|
||||
if len(select_from) == 0:
|
||||
raise ValueError("Selection list passed to `selector` cannot be empty.")
|
||||
|
||||
# Generate the selector pages
|
||||
footer = "Please reply with the number of your selection, or press {} to cancel.".format(cancel_emoji)
|
||||
list_pages = paginate_list(select_from, block_length=max_len)
|
||||
pages = ["\n".join([header, page, footer]) for page in list_pages]
|
||||
|
||||
# Post the pages in a paged message
|
||||
out_msg = await ctx.pager(pages, add_cancel=True)
|
||||
cancel_task = await ctx.cancellable(out_msg, add_reaction=False, timeout=None)
|
||||
|
||||
if len(select_from) <= 5:
|
||||
for i, _ in enumerate(select_from):
|
||||
asyncio.create_task(discord_shield(out_msg.add_reaction(number_emojis[i])))
|
||||
|
||||
# Build response tasks
|
||||
valid_input = [str(i+1) for i in range(0, len(select_from))] + ['c', 'C']
|
||||
listen_task = asyncio.create_task(ctx.listen_for(valid_input, timeout=None))
|
||||
emoji_task = asyncio.create_task(ctx.client.wait_for(
|
||||
'reaction_add',
|
||||
check=lambda r, u: (u == ctx.author
|
||||
and r.message == out_msg
|
||||
and str(r.emoji) in number_emojis)
|
||||
))
|
||||
# Wait for the response tasks
|
||||
done, pending = await asyncio.wait(
|
||||
(listen_task, emoji_task),
|
||||
timeout=timeout,
|
||||
return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
|
||||
# Cleanup
|
||||
try:
|
||||
await out_msg.delete()
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
|
||||
# Handle different return cases
|
||||
if listen_task in done:
|
||||
emoji_task.cancel()
|
||||
|
||||
result_msg = listen_task.result()
|
||||
try:
|
||||
await result_msg.delete()
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
if result_msg.content.lower() == 'c':
|
||||
raise UserCancelled("Selection cancelled!")
|
||||
result = int(result_msg.content) - 1
|
||||
elif emoji_task in done:
|
||||
listen_task.cancel()
|
||||
|
||||
reaction, _ = emoji_task.result()
|
||||
result = number_emojis.index(str(reaction.emoji))
|
||||
elif cancel_task in done:
|
||||
# Manually cancelled case.. the current task should have been cancelled
|
||||
# Raise UserCancelled in case the task wasn't cancelled for some reason
|
||||
raise UserCancelled("Selection cancelled!")
|
||||
elif not done:
|
||||
# Timeout case
|
||||
raise ResponseTimedOut("Selector timed out waiting for a response.")
|
||||
|
||||
# Finally cancel the canceller and return the provided index
|
||||
cancel_task.cancel()
|
||||
return result
|
||||
|
||||
|
||||
@Context.util
|
||||
async def pager(ctx, pages, locked=True, start_at=0, add_cancel=False, **kwargs):
|
||||
"""
|
||||
Shows the user each page from the provided list `pages` one at a time,
|
||||
providing reactions to page back and forth between pages.
|
||||
This is done asynchronously, and returns after displaying the first page.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pages: List(Union(str, discord.Embed))
|
||||
A list of either strings or embeds to display as the pages.
|
||||
locked: bool
|
||||
Whether only the `ctx.author` should be able to use the paging reactions.
|
||||
kwargs: ...
|
||||
Remaining keyword arguments are transparently passed to the reply context method.
|
||||
|
||||
Returns: discord.Message
|
||||
This is the output message, returned for easy deletion.
|
||||
"""
|
||||
# Handle broken input
|
||||
if len(pages) == 0:
|
||||
raise ValueError("Pager cannot page with no pages!")
|
||||
|
||||
# Post first page. Method depends on whether the page is an embed or not.
|
||||
if isinstance(pages[start_at], discord.Embed):
|
||||
out_msg = await ctx.reply(embed=pages[start_at], **kwargs)
|
||||
else:
|
||||
out_msg = await ctx.reply(pages[start_at], **kwargs)
|
||||
|
||||
# Run the paging loop if required
|
||||
if len(pages) > 1:
|
||||
task = asyncio.create_task(_pager(ctx, out_msg, pages, locked, start_at, add_cancel, **kwargs))
|
||||
ctx.tasks.append(task)
|
||||
elif add_cancel:
|
||||
await out_msg.add_reaction(cancel_emoji)
|
||||
|
||||
# Return the output message
|
||||
return out_msg
|
||||
|
||||
|
||||
async def _pager(ctx, out_msg, pages, locked, start_at, add_cancel, **kwargs):
|
||||
"""
|
||||
Asynchronous initialiser and loop for the `pager` utility above.
|
||||
"""
|
||||
# Page number
|
||||
page = start_at
|
||||
|
||||
# Add reactions to the output message
|
||||
next_emoji = "▶"
|
||||
prev_emoji = "◀"
|
||||
|
||||
try:
|
||||
await out_msg.add_reaction(prev_emoji)
|
||||
if add_cancel:
|
||||
await out_msg.add_reaction(cancel_emoji)
|
||||
await out_msg.add_reaction(next_emoji)
|
||||
except discord.Forbidden:
|
||||
# We don't have permission to add paging emojis
|
||||
# Die as gracefully as we can
|
||||
if ctx.guild:
|
||||
perms = ctx.ch.permissions_for(ctx.guild.me)
|
||||
if not perms.add_reactions:
|
||||
await ctx.error_reply(
|
||||
"Cannot page results because I do not have the `add_reactions` permission!"
|
||||
)
|
||||
elif not perms.read_message_history:
|
||||
await ctx.error_reply(
|
||||
"Cannot page results because I do not have the `read_message_history` permission!"
|
||||
)
|
||||
else:
|
||||
await ctx.error_reply(
|
||||
"Cannot page results due to insufficient permissions!"
|
||||
)
|
||||
else:
|
||||
await ctx.error_reply(
|
||||
"Cannot page results!"
|
||||
)
|
||||
return
|
||||
|
||||
# Check function to determine whether a reaction is valid
|
||||
def reaction_check(reaction, user):
|
||||
result = reaction.message.id == out_msg.id
|
||||
result = result and str(reaction.emoji) in [next_emoji, prev_emoji]
|
||||
result = result and not (user.id == ctx.client.user.id)
|
||||
result = result and not (locked and user != ctx.author)
|
||||
return result
|
||||
|
||||
# Check function to determine if message has a page number
|
||||
def message_check(message):
|
||||
result = message.channel.id == ctx.ch.id
|
||||
result = result and not (locked and message.author != ctx.author)
|
||||
result = result and message.content.lower().startswith('p')
|
||||
result = result and message.content[1:].isdigit()
|
||||
result = result and 1 <= int(message.content[1:]) <= len(pages)
|
||||
return result
|
||||
|
||||
# Begin loop
|
||||
while True:
|
||||
# Wait for a valid reaction or message, break if we time out
|
||||
reaction_task = asyncio.create_task(
|
||||
ctx.client.wait_for('reaction_add', check=reaction_check)
|
||||
)
|
||||
message_task = asyncio.create_task(
|
||||
ctx.client.wait_for('message', check=message_check)
|
||||
)
|
||||
done, pending = await asyncio.wait(
|
||||
(reaction_task, message_task),
|
||||
timeout=300,
|
||||
return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
if done:
|
||||
if reaction_task in done:
|
||||
# Cancel the message task and collect the reaction result
|
||||
message_task.cancel()
|
||||
reaction, user = reaction_task.result()
|
||||
|
||||
# Attempt to remove the user's reaction, silently ignore errors
|
||||
asyncio.ensure_future(out_msg.remove_reaction(reaction.emoji, user))
|
||||
|
||||
# Change the page number
|
||||
page += 1 if reaction.emoji == next_emoji else -1
|
||||
page %= len(pages)
|
||||
elif message_task in done:
|
||||
# Cancel the reaction task and collect the message result
|
||||
reaction_task.cancel()
|
||||
message = message_task.result()
|
||||
|
||||
# Attempt to delete the user's message, silently ignore errors
|
||||
asyncio.ensure_future(message.delete())
|
||||
|
||||
# Move to the correct page
|
||||
page = int(message.content[1:]) - 1
|
||||
|
||||
# Edit the message with the new page
|
||||
active_page = pages[page]
|
||||
if isinstance(active_page, discord.Embed):
|
||||
await out_msg.edit(embed=active_page, **kwargs)
|
||||
else:
|
||||
await out_msg.edit(content=active_page, **kwargs)
|
||||
else:
|
||||
# No tasks finished, so we must have timed out, or had an exception.
|
||||
# Break the loop and clean up
|
||||
break
|
||||
|
||||
# Clean up by removing the reactions
|
||||
try:
|
||||
await out_msg.clear_reactions()
|
||||
except discord.Forbidden:
|
||||
try:
|
||||
await out_msg.remove_reaction(next_emoji, ctx.client.user)
|
||||
await out_msg.remove_reaction(prev_emoji, ctx.client.user)
|
||||
except discord.NotFound:
|
||||
pass
|
||||
except discord.NotFound:
|
||||
pass
|
||||
|
||||
|
||||
@Context.util
|
||||
async def input(ctx, msg="", timeout=120):
|
||||
"""
|
||||
Listen for a response in the current channel, from ctx.author.
|
||||
Returns the response from ctx.author, if it is provided.
|
||||
Parameters
|
||||
----------
|
||||
msg: string
|
||||
Allows a custom input message to be provided.
|
||||
Will use default message if not provided.
|
||||
timeout: int
|
||||
Number of seconds to wait before timing out.
|
||||
Raises
|
||||
------
|
||||
cmdClient.lib.ResponseTimedOut:
|
||||
Raised when ctx.author does not provide a response before the function times out.
|
||||
"""
|
||||
# Deliver prompt
|
||||
offer_msg = await ctx.reply(msg or "Please enter your input.")
|
||||
|
||||
# Criteria for the input message
|
||||
def checks(m):
|
||||
return m.author == ctx.author and m.channel == ctx.ch
|
||||
|
||||
# Listen for the reply
|
||||
try:
|
||||
result_msg = await ctx.client.wait_for("message", check=checks, timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
raise ResponseTimedOut("Session timed out waiting for user response.") from None
|
||||
|
||||
result = result_msg.content
|
||||
|
||||
# Attempt to delete the prompt and reply messages
|
||||
try:
|
||||
await offer_msg.delete()
|
||||
await result_msg.delete()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@Context.util
|
||||
async def ask(ctx, msg, timeout=30, use_msg=None, del_on_timeout=False):
|
||||
"""
|
||||
Ask ctx.author a yes/no question.
|
||||
Returns 0 if ctx.author answers no
|
||||
Returns 1 if ctx.author answers yes
|
||||
Parameters
|
||||
----------
|
||||
msg: string
|
||||
Adds the question to the message string.
|
||||
Requires an input.
|
||||
timeout: int
|
||||
Number of seconds to wait before timing out.
|
||||
use_msg: string
|
||||
A completely custom string to use instead of the default string.
|
||||
del_on_timeout: bool
|
||||
Whether to delete the question if it times out.
|
||||
Raises
|
||||
------
|
||||
Nothing
|
||||
"""
|
||||
out = "{} {}".format(msg, "`y(es)`/`n(o)`")
|
||||
|
||||
offer_msg = use_msg or await ctx.reply(out)
|
||||
if use_msg:
|
||||
await use_msg.edit(content=msg)
|
||||
|
||||
result_msg = await ctx.listen_for(["y", "yes", "n", "no"], timeout=timeout)
|
||||
|
||||
if result_msg is None:
|
||||
if del_on_timeout:
|
||||
try:
|
||||
await offer_msg.delete()
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
result = result_msg.content.lower()
|
||||
try:
|
||||
if not use_msg:
|
||||
await offer_msg.delete()
|
||||
await result_msg.delete()
|
||||
except Exception:
|
||||
pass
|
||||
if result in ["n", "no"]:
|
||||
return 0
|
||||
return 1
|
||||
@@ -441,3 +441,12 @@ def jumpto(guildid: int, channeldid: int, messageid: int):
|
||||
channeldid,
|
||||
messageid
|
||||
)
|
||||
|
||||
|
||||
class DotDict(dict):
|
||||
"""
|
||||
Dict-type allowing dot access to keys.
|
||||
"""
|
||||
__getattr__ = dict.get
|
||||
__setattr__ = dict.__setitem__
|
||||
__delattr__ = dict.__delitem__
|
||||
|
||||
Reference in New Issue
Block a user