6 Commits

17 changed files with 733 additions and 272 deletions

View File

@@ -287,14 +287,13 @@ CREATE TABLE tasklist(
deleted_at TIMESTAMPTZ, deleted_at TIMESTAMPTZ,
completed_at TIMESTAMPTZ, completed_at TIMESTAMPTZ,
created_at TIMESTAMPTZ, created_at TIMESTAMPTZ,
last_updated_at TIMESTAMPTZ, last_updated_at TIMESTAMPTZ
duration INTEGER
); );
CREATE INDEX tasklist_users ON tasklist (userid); CREATE INDEX tasklist_users ON tasklist (userid);
ALTER TABLE tasklist ALTER TABLE tasklist
ADD CONSTRAINT fk_tasklist_users ADD CONSTRAINT fk_tasklist_users
FOREIGN KEY (userid) FOREIGN KEY (userid)
REFERENCES user_profiles (profileid) REFERENCES user_config (userid)
ON DELETE CASCADE ON DELETE CASCADE
NOT VALID; NOT VALID;
ALTER TABLE tasklist ALTER TABLE tasklist
@@ -318,20 +317,6 @@ CREATE TABLE tasklist_reward_history(
reward_count INTEGER reward_count INTEGER
); );
CREATE INDEX tasklist_reward_history_users ON tasklist_reward_history (userid, reward_time); CREATE INDEX tasklist_reward_history_users ON tasklist_reward_history (userid, reward_time);
CREATE TABLE tasklist_current(
taskid INTEGER PRIMARY KEY REFERENCES tasklist (taskid) ON DELETE CASCADE ON UPDATE CASCADE,
started_at TIMESTAMPTZ NOT NULL
);
CREATE TABLE tasklist_planner(
taskid INTEGER PRIMARY KEY REFERENCES tasklist (taskid) ON DELETE CASCADE ON UPDATE CASCADE,
sortkey INTEGER
);
-- }}} -- }}}
-- Reminder data {{{ -- Reminder data {{{

View File

@@ -92,10 +92,6 @@ class LionBot(Bot):
def core(self): def core(self):
return self.get_cog('CoreCog') return self.get_cog('CoreCog')
@property
def profiles(self):
return self.get_cog('ProfileCog')
async def _handle_global_dispatch(self, event_name: str, *args, **kwargs): async def _handle_global_dispatch(self, event_name: str, *args, **kwargs):
self.dispatch(event_name, *args, **kwargs) self.dispatch(event_name, *args, **kwargs)

View File

@@ -13,8 +13,6 @@ if TYPE_CHECKING:
from core.lion_member import LionMember from core.lion_member import LionMember
from core.lion_user import LionUser from core.lion_user import LionUser
from core.lion_guild import LionGuild from core.lion_guild import LionGuild
from modules.profiles.profile import UserProfile
from modules.profiles.community import Community
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -56,8 +54,6 @@ class LionContext(Context['LionBot']):
lguild: 'LionGuild' lguild: 'LionGuild'
lmember: 'LionMember' lmember: 'LionMember'
alion: 'LionUser | LionMember' alion: 'LionUser | LionMember'
profile: 'UserProfile'
community: 'Community'
def __repr__(self): def __repr__(self):
parts = {} parts = {}

View File

@@ -5,15 +5,20 @@ from datetime import timedelta
import discord import discord
from discord.ext import commands as cmds from discord.ext import commands as cmds
from discord import app_commands as appcmds
import twitchio import twitchio
from twitchio.ext import commands from twitchio.ext import commands
from data.queries import ORDER from data.queries import ORDER
from meta import LionCog, LionBot, CrocBot from meta import LionCog, LionBot, CrocBot, LionContext
from utils.lib import utc_now from modules.profiles.community import Community
from modules.profiles.profile import UserProfile
from utils.lib import utc_now, paginate_list, pager
from . import logger from . import logger
from .data import CounterData from .data import CounterData
from .graphics.weekly import counter_weekly_card, counter_monthly_card
class PERIOD(Enum): class PERIOD(Enum):
@@ -25,6 +30,11 @@ class PERIOD(Enum):
YEAR = ('this year', 'y', 'year', 'yearly') YEAR = ('this year', 'y', 'year', 'yearly')
class ORIGIN(Enum):
DISCORD = 'discord'
TWITCH = 'twitch'
def counter_cmd_factory( def counter_cmd_factory(
counter: str, counter: str,
response: str, response: str,
@@ -32,10 +42,16 @@ def counter_cmd_factory(
context: Optional[str] = None context: Optional[str] = None
): ):
context = context or f"cmd: {counter}" context = context or f"cmd: {counter}"
async def counter_cmd(cog, ctx: commands.Context, *, args: Optional[str] = None): async def counter_cmd(
userid = int(ctx.author.id) cog,
channelid = int((await ctx.channel.user()).id) ctx: commands.Context | LionContext,
period, start_time = await cog.parse_period(channelid, '', default=default_period) origin: ORIGIN,
author: UserProfile,
community: Community,
args: Optional[str]
):
userid = author.profileid
period, start_time = await cog.parse_period(community, '', default=default_period)
args = (args or '').strip(" 󠀀 ") args = (args or '').strip(" 󠀀 ")
splits = args.split(maxsplit=1) splits = args.split(maxsplit=1)
@@ -69,13 +85,25 @@ def counter_cmd_factory(
) )
) )
async def lb_cmd(cog, ctx: commands.Context, *, args: str = ''): async def lb_cmd(
user = await ctx.channel.user() cog,
await ctx.reply(await cog.formatted_lb(counter, args, int(user.id))) ctx: commands.Context | LionContext,
origin: ORIGIN,
author: UserProfile,
community: Community,
args: Optional[str]
):
await cog.show_lb(ctx, counter, args, author, community, origin)
async def undo_cmd(cog, ctx: commands.Context): async def undo_cmd(
userid = int(ctx.author.id) cog,
channelid = int((await ctx.channel.user()).id) ctx: commands.Context | LionContext,
origin: ORIGIN,
author: UserProfile,
community: Community,
args: Optional[str]
):
userid = author.profileid
_counter = await cog.fetch_counter(counter) _counter = await cog.fetch_counter(counter)
query = cog.data.CounterEntry.fetch_where( query = cog.data.CounterEntry.fetch_where(
counterid=_counter.counterid, counterid=_counter.counterid,
@@ -107,12 +135,16 @@ class CounterCog(LionCog):
async def cog_load(self): async def cog_load(self):
self._load_twitch_methods(self.crocbot) self._load_twitch_methods(self.crocbot)
await self.load_counter_commands()
await self.data.init() await self.data.init()
await self.load_counter_commands()
await self.load_counters() await self.load_counters()
self.loaded.set() self.loaded.set()
profiles = self.bot.get_cog('ProfileCog')
profiles.add_profile_migrator(self.migrate_profiles, name='counters')
async def cog_unload(self): async def cog_unload(self):
self._unload_twitch_methods(self.crocbot) self._unload_twitch_methods(self.crocbot)
@@ -124,18 +156,48 @@ class CounterCog(LionCog):
counter.name, counter.name,
row.response row.response
) )
cmds = [] twitch_cmds = []
main_cmd = commands.command(name=row.name)(counter_cb) disc_cmds = []
cmds.append(main_cmd) twitch_cmds.append(
if row.lbname: commands.command(
lb_cmd = commands.command(name=row.lbname)(lb_cb) name=row.name
cmds.append(lb_cmd) )(self.twitch_callback(counter_cb))
if row.undoname: )
undo_cmd = commands.command(name=row.undoname)(undo_cb) disc_cmds.append(
cmds.append(undo_cmd) cmds.hybrid_command(
name=row.name
)(self.discord_callback(counter_cb))
)
for cmd in cmds: if row.lbname:
twitch_cmds.append(
commands.command(
name=row.lbname
)(self.twitch_callback(lb_cb))
)
disc_cmds.append(
cmds.hybrid_command(
name=row.lbname
)(self.discord_callback(lb_cb))
)
if row.undoname:
twitch_cmds.append(
commands.command(
name=row.undoname
)(self.twitch_callback(undo_cb))
)
disc_cmds.append(
cmds.hybrid_command(
name=row.undoname
)(self.discord_callback(undo_cb))
)
for cmd in twitch_cmds:
self.add_twitch_command(self.crocbot, cmd) self.add_twitch_command(self.crocbot, cmd)
for cmd in disc_cmds:
# cmd.cog = self
self.bot.add_command(cmd)
print(f"Adding command: {cmd}")
logger.info(f"(Re)Loaded {len(rows)} counter commands!") logger.info(f"(Re)Loaded {len(rows)} counter commands!")
@@ -152,6 +214,87 @@ class CounterCog(LionCog):
f"Loaded {len(self.counters)} counters." f"Loaded {len(self.counters)} counters."
) )
async def migrate_profiles(self, source_profile: UserProfile, target_profile: UserProfile):
"""
Move source profile entries to target profile entries
"""
results = ["(Counters)"]
rows = await self.data.CounterEntry.table.update_where(userid=source_profile.profileid).set(userid=target_profile.profileid)
if rows:
results.append(
f"Migrated {len(rows)} counter entries from source profile."
)
else:
results.append(
"No counter entries to migrate in source profile."
)
return ' '.join(results)
async def user_profile_migration(self):
"""
Manual single-use migration method from the old userid format to the new profileid format.
"""
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction():
entries = await self.data.CounterEntry.fetch_where()
for entry in entries:
if entry.userid > 1000:
# Assume userid is a twitch userid
profile = await UserProfile.fetch_from_twitchid(self.bot, entry.userid)
if not profile:
# Need to create
users = await self.crocbot.fetch_users(ids=[entry.userid])
if not users:
continue
user = users[0]
profile = await UserProfile.create_from_twitch(self.bot, user)
await entry.update(userid=profile.profileid)
logger.info("Completed single-shot user profile migration")
# General API
def twitch_callback(self, callback):
"""
Generate a Twitch command callback from the given general callback.
General callback must be of the form
callback(cog, ctx: GeneralContext, origin: ORIGIN, author: Profile, comm: Community, args: Optional[str])
Return will be a command callback of the form
callback(cog, ctx: Context, *, args: Optional[str] = None)
"""
async def command_callback(cog: CounterCog, ctx: commands.Context, *, args: Optional[str] = None):
profiles = cog.bot.get_cog('ProfileCog')
# Compute author profile
author = await profiles.fetch_profile_twitch(ctx.author)
# Compute community profile
community = await profiles.fetch_community_twitch(await ctx.channel.user())
return await callback(cog, ctx, ORIGIN.TWITCH, author, community, args)
return command_callback
def discord_callback(self, callback):
"""
Generate a Discord command callback from the given general callback.
General callback must be of the form
callback(cog, ctx: GeneralContext, origin: ORIGIN, author: Profile, comm: Community, args: Optional[str])
Return will be a command callback of the form
callback(cog, ctx: LionContext, *, args: Optional[str] = None)
"""
cog = self
async def command_callback(ctx: LionContext, *, args: Optional[str] = None):
profiles = cog.bot.get_cog('ProfileCog')
# Compute author profile
author = await profiles.fetch_profile_discord(ctx.author)
# Compute community profile
community = await profiles.fetch_community_discord(ctx.guild)
return await callback(cog, ctx, ORIGIN.DISCORD, author, community, args)
return command_callback
# Counters API # Counters API
async def fetch_counter(self, counter: str) -> CounterData.Counter: async def fetch_counter(self, counter: str) -> CounterData.Counter:
@@ -218,6 +361,14 @@ class CounterCog(LionCog):
results = await query results = await query
return results[0]['counter_total'] if results else 0 return results[0]['counter_total'] if results else 0
# Manage commands
@commands.command()
async def countermigration(self, ctx: commands.Context, *, args: Optional[str]=None):
if not (ctx.author.is_mod or ctx.author.is_broadcaster):
return
await self.user_profile_migration()
await ctx.reply("Counter userid->profileid migration done.")
# Counters commands # Counters commands
@commands.command() @commands.command()
async def counter(self, ctx: commands.Context, name: str, subcmd: Optional[str], *, args: Optional[str]=None): async def counter(self, ctx: commands.Context, name: str, subcmd: Optional[str], *, args: Optional[str]=None):
@@ -225,6 +376,10 @@ class CounterCog(LionCog):
return return
name = name.lower() name = name.lower()
profiles = self.bot.get_cog('ProfileCog')
author = await profiles.fetch_profile_twitch(ctx.author)
userid = author.profileid
community = await profiles.fetch_community_twitch(await ctx.channel.user())
if subcmd is None or subcmd == 'show': if subcmd is None or subcmd == 'show':
# Show # Show
@@ -241,16 +396,14 @@ class CounterCog(LionCog):
return return
await self.add_to_counter( await self.add_to_counter(
name, name,
int(ctx.author.id), userid,
value, value,
context='cmd: counter add' context='cmd: counter add'
) )
total = await self.totals(name) total = await self.totals(name)
await ctx.reply(f"'{name}' counter is now: {total}") await ctx.reply(f"'{name}' counter is now: {total}")
elif subcmd == 'lb': elif subcmd == 'lb':
user = await ctx.channel.user() await self.show_lb(ctx, name, args or '', author, community, origin=ORIGIN.TWITCH)
lbstr = await self.formatted_lb(name, args or '', int(user.id))
await ctx.reply(lbstr)
elif subcmd == 'clear': elif subcmd == 'clear':
await self.reset_counter(name) await self.reset_counter(name)
await ctx.reply(f"'{name}' counter reset.") await ctx.reply(f"'{name}' counter reset.")
@@ -292,7 +445,7 @@ class CounterCog(LionCog):
else: else:
await ctx.reply(f"Unrecognised subcommand {subcmd}. Supported subcommands: 'show', 'add', 'lb', 'clear', 'alias'.") await ctx.reply(f"Unrecognised subcommand {subcmd}. Supported subcommands: 'show', 'add', 'lb', 'clear', 'alias'.")
async def parse_period(self, userid: int, periodstr: str, default=PERIOD.STREAM): async def parse_period(self, community: Community, periodstr: str, default=PERIOD.STREAM):
if periodstr: if periodstr:
period = next((period for period in PERIOD if periodstr.lower() in period.value), None) period = next((period for period in PERIOD if periodstr.lower() in period.value), None)
if period is None: if period is None:
@@ -306,9 +459,13 @@ class CounterCog(LionCog):
if period is PERIOD.ALL: if period is PERIOD.ALL:
start_time = None start_time = None
elif period is PERIOD.STREAM: elif period is PERIOD.STREAM:
streams = await self.crocbot.fetch_streams(user_ids=[userid]) twitches = await community.twitch_channels()
if streams: stream = None
stream = streams[0] if twitches:
twitch = twitches[0]
streams = await self.crocbot.fetch_streams(user_ids=[int(twitch.channelid)])
stream = streams[0] if streams else None
if stream:
start_time = stream.started_at start_time = stream.started_at
else: else:
period = PERIOD.ALL period = PERIOD.ALL
@@ -327,21 +484,104 @@ class CounterCog(LionCog):
return (period, start_time) return (period, start_time)
async def formatted_lb(self, counter: str, periodstr: str, channelid: int): @cmds.hybrid_command(
name='counterlb',
description="Show the leaderboard for the given counter."
)
async def counterlb_dcmd(self, ctx: LionContext, counter: str, period: Optional[str] = None):
profiles = self.bot.get_cog('ProfileCog')
author = await profiles.fetch_profile_discord(ctx.author)
community = await profiles.fetch_community_discord(ctx.guild)
await self.show_lb(ctx, counter, period, author, community, ORIGIN.DISCORD)
period, start_time = await self.parse_period(channelid, periodstr) @cmds.hybrid_command(
name='counterstats',
description="Show your stats for the given counter."
)
async def counterstats_dcmd(self, ctx: LionContext, counter: str, period: Optional[str]=None):
profiles = self.bot.get_cog('ProfileCog')
author = await profiles.fetch_profile_discord(ctx.author)
community = await profiles.fetch_community_discord(ctx.guild)
if period and period.lower() in ('monthly', 'month'):
card = await counter_monthly_card(
self.bot,
userid=ctx.author.id,
profile=author,
counter=await self.fetch_counter(counter),
guildid=ctx.guild.id,
offset=0,
)
await card.render()
await ctx.reply(file=card.as_file('stats.png'))
else:
card = await counter_weekly_card(
self.bot,
userid=ctx.author.id,
profile=author,
counter=await self.fetch_counter(counter),
guildid=ctx.guild.id,
offset=0,
)
await card.render()
await ctx.reply(file=card.as_file('stats.png'))
async def show_lb(
self,
ctx: commands.Context | LionContext,
counter: str,
periodstr: str,
caller: UserProfile,
community: Community,
origin: ORIGIN = ORIGIN.TWITCH
):
period, start_time = await self.parse_period(community, periodstr)
lb = await self.leaderboard(counter, start_time=start_time) lb = await self.leaderboard(counter, start_time=start_time)
if lb: name_map = {}
userids = list(lb.keys()) for userid in lb.keys():
users = await self.crocbot.fetch_users(ids=userids) profile = await UserProfile.fetch(self.bot, userid)
name_map = {user.id: user.display_name for user in users} name = await profile.get_name()
name_map[userid] = name
if not lb:
await ctx.reply(
f"{counter} {period.value[-1]} leaderboard is empty!"
)
elif origin is ORIGIN.TWITCH:
parts = [] parts = []
for userid, total in lb.items(): items = list(lb.items())
prefix = 'top 10 ' if len(items) > 10 else ''
items = items[:10]
for userid, total in items:
name = name_map.get(userid, str(userid)) name = name_map.get(userid, str(userid))
part = f"{name}: {total}" part = f"{name}: {total}"
parts.append(part) parts.append(part)
lbstr = '; '.join(parts) lbstr = '; '.join(parts)
return f"{counter} {period.value[-1]} leaderboard --- {lbstr}" await ctx.reply(f"{counter} {period.value[-1]} {prefix}leaderboard --- {lbstr}")
else: elif origin is ORIGIN.DISCORD:
return f"{counter} {period.value[-1]} leaderboard is empty!" title = f"'{counter}' {period.value[-1]} leaderboard"
lb_strings = []
author_index = None
max_name_len = min((30, max(len(name) for name in name_map.values())))
for i, (uid, total) in enumerate(lb.items()):
if author_index is None and uid == caller.profileid:
author_index = i
lb_strings.append(
"{:<{}}\t{:<9}".format(
name_map[uid],
max_name_len,
total,
)
)
page_len = 20
pages = paginate_list(lb_strings, block_length=page_len, title=title)
start_page = author_index // page_len if author_index is not None else 0
await pager(
ctx,
pages,
start_at=start_page
)

View File

View File

@@ -0,0 +1,222 @@
import itertools
from typing import Optional
from datetime import timedelta, datetime
import calendar
from meta import LionBot
from gui.cards import WeeklyStatsCard, MonthlyStatsCard
from gui.base import CardMode
from modules.profiles.profile import UserProfile
from babel import LocalBabel
from modules.statistics.lib import apply_month_offset
from ..data import CounterData
babel = LocalBabel('counters')
_ = babel._
async def counter_monthly_card(
bot: LionBot,
userid: int,
profile: UserProfile,
counter: CounterData.Counter,
guildid: int,
offset: int,
):
cog = bot.get_cog('CounterCog')
data: CounterData = cog.data
if guildid:
lion = await bot.core.lions.fetch_member(guildid, userid)
user = await lion.fetch_member()
else:
lion = await bot.core.lions.fetch_user(userid)
user = await bot.fetch_user(userid)
today = lion.today
month_start = today.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
target = apply_month_offset(month_start, offset)
target_end = (target + timedelta(days=40)).replace(day=1, hour=0, minute=0) - timedelta(days=1)
months = [target]
for i in range(0, 3):
months.append((months[-1] - timedelta(days=1)).replace(day=1))
months.reverse()
rows = await data.CounterEntry.fetch_where(
data.CounterEntry.counterid == counter.counterid,
data.CounterEntry.userid == profile.profileid,
data.CounterEntry.created_at <= target_end,
data.CounterEntry.created_at >= months[0],
)
events = [(row.created_at, row.value) for row in rows]
month_lengths = [
(calendar.monthrange(month.year, month.month)[1]) for month in months
]
month_dates = []
for month, length in zip(months, month_lengths):
for day in range(1, length + 1):
month_dates.append(datetime(month.year, month.month, day, tzinfo=month.tzinfo))
monthly_flat = events_to_dayfreq(events, month_dates)
print(monthly_flat)
monthly = []
i = 0
for length in month_lengths:
this_month = monthly_flat[i : i+length]
i += length
monthly.append(this_month)
skin = await bot.get_cog('CustomSkinCog').get_skinargs_for(
guildid, userid, MonthlyStatsCard.card_id
)
skin |= {
'title_text': f"{counter.name.upper()}",
'this_month_text': f"THIS MONTH: {{amount}} {counter.name.upper()}",
'last_month_text': f"LAST MONTH: {{amount}} {counter.name.upper()}"
}
if user:
username = (user.display_name, '')
else:
username = (await profile.get_name(), '')
card = MonthlyStatsCard(
user=username,
timezone=str(lion.timezone),
now=lion.now.timestamp(),
month=int(target.timestamp()),
monthly=monthly,
current_streak=-1,
longest_streak=-1,
skin=skin | {'mode': CardMode.TEXT}
)
return card
async def counter_weekly_card(
bot: LionBot,
userid: int,
profile: UserProfile,
counter: CounterData.Counter,
guildid: int,
offset: int,
):
cog = bot.get_cog('CounterCog')
data: CounterData = cog.data
if guildid:
lion = await bot.core.lions.fetch_member(guildid, userid)
user = await lion.fetch_member()
else:
lion = await bot.core.lions.fetch_user(userid)
user = await bot.fetch_user(userid)
today = lion.today
week_start = today - timedelta(days=today.weekday()) - timedelta(weeks=offset)
days = [week_start + timedelta(i) for i in range(-7, 8 if offset else (today.weekday() + 2))]
rows = await data.CounterEntry.fetch_where(
data.CounterEntry.counterid == counter.counterid,
data.CounterEntry.userid == profile.profileid,
data.CounterEntry.created_at <= days[-1],
data.CounterEntry.created_at >= days[0],
)
events = [(row.created_at, row.value) for row in rows]
daily = events_to_dayfreq(events, days)
sessions = events_to_sessions(next(zip(*events), []))
skin = await bot.get_cog('CustomSkinCog').get_skinargs_for(
guildid, userid, WeeklyStatsCard.card_id
)
skin |= {
'title_text': f"{counter.name.upper()}",
'this_week_text': f"THIS WEEK: {{amount}} {counter.name.upper()}",
'last_week_text': f"LAST WEEK: {{amount}} {counter.name.upper()}"
}
if user:
username = (user.display_name, '')
else:
username = (await profile.get_name(), '')
card = WeeklyStatsCard(
user=username,
timezone=str(lion.timezone),
now=lion.now.timestamp(),
week=week_start.timestamp(),
daily=tuple(map(int, daily)),
sessions=sessions,
skin=skin | {'mode': CardMode.TEXT}
)
return card
def events_to_dayfreq(events: list[tuple[datetime, int]], days: list[datetime]) -> list[int]:
if not days:
return []
last_day = 0
dayts = 0
daymap = {}
for day in sorted(days, reverse=True):
dayts = day.timestamp()
last_day = last_day or (day + timedelta(days=1)).timestamp()
daymap[dayts] = 0
first_day = dayts
for tim, count in events:
timts = tim.timestamp()
if not first_day < timts < last_day:
continue
for day_start in daymap:
if timts > day_start:
daymap[day_start] += count
break
return list(reversed(daymap.values()))
def events_to_sessions(event_times: list[datetime]) -> list[tuple[int, int]]:
"""
Convert a provided list of event times to a session list.
"""
sessions = []
session_start = None
session_end = None
SESSION_GAP = 60 * 30
SESSION_RADIUS = 60 * 30
for time in sorted(event_times):
if session_start and session_end and (time - session_end).total_seconds() - SESSION_RADIUS > SESSION_GAP:
session = (int(session_start.timestamp()), int(session_end.timestamp()))
sessions.append(session)
session_start = None
session_end = None
if session_start is None:
session_start = time - timedelta(seconds=SESSION_RADIUS)
session_end = time + timedelta(seconds=SESSION_RADIUS)
if session_start and session_end:
session = (int(session_start.timestamp()), int(session_end.timestamp()))
sessions.append(session)
return sessions

View File

View File

View File

@@ -202,7 +202,7 @@ class NowDoingCog(LionCog):
await self.data.Task.table.delete_where(userid=userid) await self.data.Task.table.delete_where(userid=userid)
task = await self.data.Task.create( task = await self.data.Task.create(
userid=userid, userid=userid,
name=ctx.author.display_name, name=await profile.get_name(),
task=args, task=args,
started_at=existing.started_at if (existing and edit) else utc_now(), started_at=existing.started_at if (existing and edit) else utc_now(),
) )
@@ -272,7 +272,7 @@ class NowDoingCog(LionCog):
await self.data.Task.table.delete_where(userid=userid) await self.data.Task.table.delete_where(userid=userid)
task = await self.data.Task.create( task = await self.data.Task.create(
userid=userid, userid=userid,
name=ctx.author.display_name, name=await profile.get_name(),
task=args, task=args,
started_at=utc_now(), started_at=utc_now(),
) )

View File

@@ -40,17 +40,6 @@ class ProfileCog(LionCog):
async def cog_check(self, ctx): async def cog_check(self, ctx):
return True return True
async def bot_check_once(self, ctx: LionContext):
"""
Inject the contextual UserProfile and Community into the LionContext.
Creates the profile and community if they do not exist.
"""
if ctx.guild:
ctx.community = await self.fetch_community_discord(ctx.guild)
ctx.profile = await self.fetch_profile_discord(ctx.author)
return True
# Profile API # Profile API
def add_profile_migrator(self, migrator, name=None): def add_profile_migrator(self, migrator, name=None):
name = name or migrator.__name__ name = name or migrator.__name__

View File

@@ -31,6 +31,30 @@ class UserProfile:
def __repr__(self): def __repr__(self):
return f"<UserProfile profileid={self.profileid} profile={self.profile_row}>" return f"<UserProfile profileid={self.profileid} profile={self.profile_row}>"
async def get_name(self):
# TODO: Store a preferred name in the profile preferences
# TODO Should have a multi-fetch system
name = None
twitches = await self.twitch_accounts()
if twitches:
users = await self.bot.crocbot.fetch_users(
ids=[int(twitches[0].userid)]
)
if users:
user = users[0]
name = user.display_name
if not name:
discords = await self.discord_accounts()
if discords:
user = await self.bot.fetch_user(discords[0].userid)
name = user.display_name
if not name:
name = 'Unknown'
return name
async def attach_discord(self, user: discord.User | discord.Member): async def attach_discord(self, user: discord.User | discord.Member):
""" """
Attach a new discord user to this profile. Attach a new discord user to this profile.
@@ -62,13 +86,21 @@ class UserProfile:
""" """
Fetch the Discord accounts associated to this profile. Fetch the Discord accounts associated to this profile.
""" """
return await self.data.DiscordProfileRow.fetch_where(profileid=self.profileid) return await self.data.DiscordProfileRow.fetch_where(
profileid=self.profileid
).order_by(
'created_at'
)
async def twitch_accounts(self) -> list[ProfileData.TwitchProfileRow]: async def twitch_accounts(self) -> list[ProfileData.TwitchProfileRow]:
""" """
Fetch the Twitch accounts associated to this profile. Fetch the Twitch accounts associated to this profile.
""" """
return await self.data.TwitchProfileRow.fetch_where(profileid=self.profileid) return await self.data.TwitchProfileRow.fetch_where(
profileid=self.profileid
).order_by(
'created_at'
)
@classmethod @classmethod
async def fetch(cls, bot: LionBot, profile_id: int) -> Self: async def fetch(cls, bot: LionBot, profile_id: int) -> Self:

View File

@@ -8,9 +8,16 @@ async def get_leaderboard_card(
bot: LionBot, highlightid: int, guildid: int, bot: LionBot, highlightid: int, guildid: int,
mode: CardMode, mode: CardMode,
entry_data: list[tuple[int, int, int]], # userid, position, time entry_data: list[tuple[int, int, int]], # userid, position, time
name_map: dict[int, str] = {},
extra_skin_args = {},
): ):
""" """
Render a leaderboard card with given parameters. Render a leaderboard card with given parameters.
Parameters
----------
name_map: dict[int, str]
Map of userid -> name, used first before cache or fetch.
""" """
guild = bot.get_guild(guildid) guild = bot.get_guild(guildid)
if guild is None: if guild is None:
@@ -20,8 +27,12 @@ async def get_leaderboard_card(
avatars = {} avatars = {}
names = {} names = {}
missing = [] missing = []
for userid, _, _ in entry_data: for userid, _, _ in entry_data:
if guild and (member := guild.get_member(userid)): if (name := name_map.get(userid, None)):
avatars[userid] = None
names[userid] = name
elif guild and (member := guild.get_member(userid)):
avatars[userid] = member.avatar.key if member.avatar else None avatars[userid] = member.avatar.key if member.avatar else None
names[userid] = member.display_name names[userid] = member.display_name
elif (user := bot.get_user(userid)): elif (user := bot.get_user(userid)):
@@ -65,7 +76,7 @@ async def get_leaderboard_card(
guildid, None, LeaderboardCard.card_id guildid, None, LeaderboardCard.card_id
) )
card = LeaderboardCard( card = LeaderboardCard(
skin=skin | {'mode': mode}, skin=skin | {'mode': mode} | extra_skin_args,
server_name=guild.name, server_name=guild.name,
entries=entries, entries=entries,
highlight=highlight highlight=highlight

View File

@@ -7,12 +7,9 @@ from discord.ext import commands as cmds
from discord import app_commands as appcmds from discord import app_commands as appcmds
from discord.app_commands.transformers import AppCommandOptionType as cmdopt from discord.app_commands.transformers import AppCommandOptionType as cmdopt
from data.queries import JOINTYPE
from meta import LionBot, LionCog, LionContext from meta import LionBot, LionCog, LionContext
from meta.CrocBot import CrocBot
from meta.logger import log_wrap from meta.logger import log_wrap
from meta.errors import UserInputError from meta.errors import UserInputError
from modules.profiles.profile import UserProfile
from utils.lib import utc_now, error_embed from utils.lib import utc_now, error_embed
from utils.ui import ChoicedEnum, Transformed, AButton from utils.ui import ChoicedEnum, Transformed, AButton
@@ -129,7 +126,6 @@ class TasklistCog(LionCog):
def __init__(self, bot: LionBot): def __init__(self, bot: LionBot):
self.bot = bot self.bot = bot
self.crocbot: CrocBot = bot.crocbot
self.data = bot.db.load_registry(TasklistData()) self.data = bot.db.load_registry(TasklistData())
self.babel = babel self.babel = babel
self.settings = TasklistSettings() self.settings = TasklistSettings()
@@ -142,84 +138,10 @@ class TasklistCog(LionCog):
self.bot.core.guild_config.register_model_setting(self.settings.task_reward_limit) self.bot.core.guild_config.register_model_setting(self.settings.task_reward_limit)
self.bot.add_view(TasklistCaller(self.bot)) self.bot.add_view(TasklistCaller(self.bot))
self.bot.profiles.add_profile_migrator(self.migrate_profiles, name='tasklist-migrator')
configcog = self.bot.get_cog('ConfigCog') configcog = self.bot.get_cog('ConfigCog')
self.crossload_group(self.configure_group, configcog.config_group) self.crossload_group(self.configure_group, configcog.config_group)
self._load_twitch_methods(self.crocbot) @LionCog.listener('on_tasks_completed')
async def cog_unload(self):
self.live_tasklists.clear()
if profiles := self.bot.get_cog('ProfileCog'):
profiles.del_profile_migrator('tasklist-migrator')
self._unload_twitch_methods(self.crocbot)
@log_wrap(action="Tasklist Profile Migration")
async def migrate_profiles(self, source_profile: UserProfile, target_profile: UserProfile):
"""
Re-assign all tasklist tasks from source profile to target profile.
TODO: Probably wants some elegant handling of the cached or running tasklists.
"""
results = ["(Tasklist)"]
sourceid = source_profile.profileid
targetid = target_profile.profileid
updated = await self.data.Task.table.update_where(userid=sourceid).set(userid=targetid)
if updated:
results.append(
f"Migrated {len(updated)} task row(s) from source profile."
)
for channel_lists in self.live_tasklists.get(sourceid, []):
for tasklist in list(channel_lists.values()):
await tasklist.close()
self.bot.dispatch('tasklist_update', profileid=targetid, summon=False)
else:
results.append(
"No tasks found in source profile, nothing to migrate!"
)
return ' '.join(results)
async def user_profile_migration(self):
"""
Manual one-shot migration method from old Discord userids to the new profileids.
"""
# First collect all the distinct userids from the tasklist
# Then create a map of userids to profileids, creating the profiles if required
# Then do updates, we can just inefficiently do updates on each distinct userid
# As long as the userids and profileids never overlap, this is fine. Fine for a one-shot
# Extract all the userids that exist in the table
rows = await self.data.Task.table.select_where().select(
userid="DISTINCT(userid)"
).with_no_adapter()
# Fetch or create discord user profiles for them
profile_map = {}
for row in rows:
userid = row['userid']
if userid > 100000:
# Assume a Discord snowflake
profile = await UserProfile.fetch_from_discordid(self.bot, userid)
if not profile:
try:
user = self.bot.get_user(userid)
if user is None:
user = await self.bot.fetch_user(userid)
except discord.HTTPException:
logger.info(f"Skipping user {userid}")
continue
profile = await UserProfile.create_from_discord(self.bot, user)
profile_map[userid] = profile
# Now iterate through
for userid, profile in profile_map.items():
logger.info(f"Migrating userid {userid} to profile {profile}")
await self.data.Task.table.update_where(userid=userid).set(userid=profile.profileid)
# Temporarily disabling integration with userid driven Economy
# @LionCog.listener('on_tasks_completed')
@log_wrap(action="reward tasks completed") @log_wrap(action="reward tasks completed")
async def reward_tasks_completed(self, member: discord.Member, *taskids: int): async def reward_tasks_completed(self, member: discord.Member, *taskids: int):
async with self.bot.db.connection() as conn: async with self.bot.db.connection() as conn:
@@ -248,9 +170,6 @@ class TasklistCog(LionCog):
) )
async def is_tasklist_channel(self, channel) -> bool: async def is_tasklist_channel(self, channel) -> bool:
"""
Check whether a given Discord channel is a tasklist channel
"""
if not channel.guild: if not channel.guild:
return True return True
channels = (await self.settings.tasklist_channels.get(channel.guild.id)).value channels = (await self.settings.tasklist_channels.get(channel.guild.id)).value
@@ -267,16 +186,12 @@ class TasklistCog(LionCog):
return (channel in channels) or (channel.id in private_channels) or (channel.category in channels) return (channel in channels) or (channel.id in private_channels) or (channel.category in channels)
async def call_tasklist(self, interaction: discord.Interaction): async def call_tasklist(self, interaction: discord.Interaction):
"""
Given a Discord channel interaction, summon the interacting user's tasklist.
"""
await interaction.response.defer(thinking=True, ephemeral=True) await interaction.response.defer(thinking=True, ephemeral=True)
channel = interaction.channel channel = interaction.channel
guild = channel.guild guild = channel.guild
profile = await self.bot.profiles.fetch_profile_discord(interaction.user) userid = interaction.user.id
profileid = profile.profileid
tasklist = await Tasklist.fetch(self.bot, self.data, profileid) tasklist = await Tasklist.fetch(self.bot, self.data, userid)
if await self.is_tasklist_channel(channel): if await self.is_tasklist_channel(channel):
# Check we have permissions to send a regular message here # Check we have permissions to send a regular message here
@@ -298,7 +213,7 @@ class TasklistCog(LionCog):
) )
await interaction.edit_original_response(embed=error) await interaction.edit_original_response(embed=error)
else: else:
tasklistui = TasklistUI.fetch(tasklist, channel, guild, caller=interaction.user, timeout=None) tasklistui = TasklistUI.fetch(tasklist, channel, guild, timeout=None)
await tasklistui.summon(force=True) await tasklistui.summon(force=True)
await interaction.delete_original_response() await interaction.delete_original_response()
else: else:
@@ -307,14 +222,14 @@ class TasklistCog(LionCog):
await tasklistui.run(interaction) await tasklistui.run(interaction)
@LionCog.listener('on_tasklist_update') @LionCog.listener('on_tasklist_update')
async def update_listening_tasklists(self, profileid, channel=None, summon=True): async def update_listening_tasklists(self, userid, channel=None, summon=True):
""" """
Propagate a tasklist update to all persistent tasklist UIs for this user. Propagate a tasklist update to all persistent tasklist UIs for this user.
If channel is given, also summons the UI if the channel is a tasklist channel. If channel is given, also summons the UI if the channel is a tasklist channel.
""" """
# Do the given channel first, and summon if requested # Do the given channel first, and summon if requested
if channel and (tui := TasklistUI._live_[profileid].get(channel.id, None)) is not None: if channel and (tui := TasklistUI._live_[userid].get(channel.id, None)) is not None:
try: try:
if summon and await self.is_tasklist_channel(channel): if summon and await self.is_tasklist_channel(channel):
await tui.summon() await tui.summon()
@@ -325,7 +240,7 @@ class TasklistCog(LionCog):
await tui.close() await tui.close()
# Now do the rest of the listening channels # Now do the rest of the listening channels
listening = TasklistUI._live_[profileid] listening = TasklistUI._live_[userid]
for cid, ui in list(listening.items()): for cid, ui in list(listening.items()):
if channel and channel.id == cid: if channel and channel.id == cid:
# We already did this channel # We already did this channel
@@ -360,7 +275,7 @@ class TasklistCog(LionCog):
async def tasklist_group(self, ctx: LionContext): async def tasklist_group(self, ctx: LionContext):
raise NotImplementedError raise NotImplementedError
async def _task_acmpl(self, profileid: int, partial: str, multi=False) -> list[appcmds.Choice]: async def _task_acmpl(self, userid: int, partial: str, multi=False) -> list[appcmds.Choice]:
""" """
Generate a list of task Choices matching a given partial string. Generate a list of task Choices matching a given partial string.
@@ -369,7 +284,7 @@ class TasklistCog(LionCog):
t = self.bot.translator.t t = self.bot.translator.t
# Should usually be cached, so this won't trigger repetitive db access # Should usually be cached, so this won't trigger repetitive db access
tasklist = await Tasklist.fetch(self.bot, self.data, profileid) tasklist = await Tasklist.fetch(self.bot, self.data, userid)
# Special case for an empty tasklist # Special case for an empty tasklist
if not tasklist.tasklist: if not tasklist.tasklist:
@@ -477,17 +392,13 @@ class TasklistCog(LionCog):
""" """
Shared autocomplete for single task parameters. Shared autocomplete for single task parameters.
""" """
profile = await self.bot.profiles.fetch_profile_discord(interaction.user) return await self._task_acmpl(interaction.user.id, partial, multi=False)
profileid = profile.profileid
return await self._task_acmpl(profileid, partial, multi=False)
async def tasks_acmpl(self, interaction: discord.Interaction, partial: str) -> list[appcmds.Choice]: async def tasks_acmpl(self, interaction: discord.Interaction, partial: str) -> list[appcmds.Choice]:
""" """
Shared autocomplete for multiple task parameters. Shared autocomplete for multiple task parameters.
""" """
profile = await self.bot.profiles.fetch_profile_discord(interaction.user) return await self._task_acmpl(interaction.user.id, partial, multi=True)
profileid = profile.profileid
return await self._task_acmpl(profileid, partial, multi=True)
@tasklist_group.command( @tasklist_group.command(
name=_p('cmd:tasks_new', "new"), name=_p('cmd:tasks_new', "new"),
@@ -511,7 +422,7 @@ class TasklistCog(LionCog):
if not ctx.interaction: if not ctx.interaction:
return return
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.profile.profileid) tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
await ctx.interaction.response.defer(thinking=True, ephemeral=True) await ctx.interaction.response.defer(thinking=True, ephemeral=True)
# Fetch parent task if required # Fetch parent task if required
@@ -542,9 +453,9 @@ class TasklistCog(LionCog):
) )
await ctx.interaction.edit_original_response( await ctx.interaction.edit_original_response(
embed=embed, embed=embed,
view=None if ctx.channel.id in TasklistUI._live_[ctx.profile.profileid] else TasklistCaller(self.bot) view=None if ctx.channel.id in TasklistUI._live_[ctx.author.id] else TasklistCaller(self.bot)
) )
self.bot.dispatch('tasklist_update', profileid=ctx.profile.profileid, channel=ctx.channel) self.bot.dispatch('tasklist_update', userid=ctx.author.id, channel=ctx.channel)
tasklist_new_cmd.autocomplete('parent')(task_acmpl) tasklist_new_cmd.autocomplete('parent')(task_acmpl)
@@ -612,7 +523,7 @@ class TasklistCog(LionCog):
raise UserInputError(error) raise UserInputError(error)
# Contents successfully parsed, update the tasklist. # Contents successfully parsed, update the tasklist.
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.profile.profileid) tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
taskinfo = tasklist.parse_tasklist(lines) taskinfo = tasklist.parse_tasklist(lines)
@@ -661,9 +572,9 @@ class TasklistCog(LionCog):
) )
await ctx.interaction.edit_original_response( await ctx.interaction.edit_original_response(
embed=embed, embed=embed,
view=None if ctx.channel.id in TasklistUI._live_[ctx.profile.profileid] else TasklistCaller(self.bot) view=None if ctx.channel.id in TasklistUI._live_[ctx.author.id] else TasklistCaller(self.bot)
) )
self.bot.dispatch('tasklist_update', profileid=ctx.profile.profileid, channel=ctx.channel) self.bot.dispatch('tasklist_update', userid=ctx.author.id, channel=ctx.channel)
@tasklist_group.command( @tasklist_group.command(
name=_p('cmd:tasks_edit', "edit"), name=_p('cmd:tasks_edit', "edit"),
@@ -689,7 +600,7 @@ class TasklistCog(LionCog):
t = self.bot.translator.t t = self.bot.translator.t
if not ctx.interaction: if not ctx.interaction:
return return
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.profile.profileid) tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
# Fetch task to edit # Fetch task to edit
tid = tasklist.parse_label(taskstr) if taskstr else None tid = tasklist.parse_label(taskstr) if taskstr else None
@@ -740,12 +651,12 @@ class TasklistCog(LionCog):
await interaction.response.send_message( await interaction.response.send_message(
embed=embed, embed=embed,
view=( view=(
discord.utils.MISSING if ctx.channel.id in TasklistUI._live_[ctx.profile.profileid] discord.utils.MISSING if ctx.channel.id in TasklistUI._live_[ctx.author.id]
else TasklistCaller(self.bot) else TasklistCaller(self.bot)
), ),
ephemeral=True ephemeral=True
) )
self.bot.dispatch('tasklist_update', profileid=ctx.profile.profileid, channel=ctx.channel) self.bot.dispatch('tasklist_update', userid=ctx.author.id, channel=ctx.channel)
if new_content or new_parent: if new_content or new_parent:
# Manual edit route # Manual edit route
@@ -777,17 +688,17 @@ class TasklistCog(LionCog):
async def tasklist_clear_cmd(self, ctx: LionContext): async def tasklist_clear_cmd(self, ctx: LionContext):
t = ctx.bot.translator.t t = ctx.bot.translator.t
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.profile.profileid) tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
await tasklist.update_tasklist(deleted_at=utc_now()) await tasklist.update_tasklist(deleted_at=utc_now())
await ctx.reply( await ctx.reply(
t(_p( t(_p(
'cmd:tasks_clear|resp:success', 'cmd:tasks_clear|resp:success',
"Your tasklist has been cleared." "Your tasklist has been cleared."
)), )),
view=None if ctx.channel.id in TasklistUI._live_[ctx.profile.profileid] else TasklistCaller(self.bot), view=None if ctx.channel.id in TasklistUI._live_[ctx.author.id] else TasklistCaller(self.bot),
ephemeral=True ephemeral=True
) )
self.bot.dispatch('tasklist_update', profileid=ctx.profile.profileid, channel=ctx.channel) self.bot.dispatch('tasklist_update', userid=ctx.author.id, channel=ctx.channel)
@tasklist_group.command( @tasklist_group.command(
name=_p('cmd:tasks_remove', "remove"), name=_p('cmd:tasks_remove', "remove"),
@@ -837,7 +748,7 @@ class TasklistCog(LionCog):
await ctx.interaction.response.defer(thinking=True, ephemeral=True) await ctx.interaction.response.defer(thinking=True, ephemeral=True)
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.profile.profileid) tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
conditions = [] conditions = []
if taskidstr: if taskidstr:
@@ -873,7 +784,7 @@ class TasklistCog(LionCog):
elif completed is False: elif completed is False:
conditions.append(self.data.Task.completed_at == NULL) conditions.append(self.data.Task.completed_at == NULL)
tasks = await self.data.Task.fetch_where(*conditions, userid=ctx.profile.profileid) tasks = await self.data.Task.fetch_where(*conditions, userid=ctx.author.id)
if not tasks: if not tasks:
await ctx.interaction.edit_original_response( await ctx.interaction.edit_original_response(
embed=error_embed(t(_p( embed=error_embed(t(_p(
@@ -902,9 +813,9 @@ class TasklistCog(LionCog):
) )
await ctx.interaction.edit_original_response( await ctx.interaction.edit_original_response(
embed=embed, embed=embed,
view=None if ctx.channel.id in TasklistUI._live_[ctx.profile.profileid] else TasklistCaller(self.bot) view=None if ctx.channel.id in TasklistUI._live_[ctx.author.id] else TasklistCaller(self.bot)
) )
self.bot.dispatch('tasklist_update', profileid=ctx.profile.profileid, channel=ctx.channel) self.bot.dispatch('tasklist_update', userid=ctx.author.id, channel=ctx.channel)
tasklist_remove_cmd.autocomplete('taskidstr')(tasks_acmpl) tasklist_remove_cmd.autocomplete('taskidstr')(tasks_acmpl)
@@ -933,7 +844,7 @@ class TasklistCog(LionCog):
await ctx.interaction.response.defer(thinking=True, ephemeral=True) await ctx.interaction.response.defer(thinking=True, ephemeral=True)
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.profile.profileid) tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
try: try:
taskids = tasklist.parse_labels(taskidstr) taskids = tasklist.parse_labels(taskidstr)
@@ -978,9 +889,9 @@ class TasklistCog(LionCog):
) )
await ctx.interaction.edit_original_response( await ctx.interaction.edit_original_response(
embed=embed, embed=embed,
view=None if ctx.channel.id in TasklistUI._live_[ctx.profile.profileid] else TasklistCaller(self.bot) view=None if ctx.channel.id in TasklistUI._live_[ctx.author.id] else TasklistCaller(self.bot)
) )
self.bot.dispatch('tasklist_update', profileid=ctx.profile.profileid, channel=ctx.channel) self.bot.dispatch('tasklist_update', userid=ctx.author.id, channel=ctx.channel)
tasklist_tick_cmd.autocomplete('taskidstr')(tasks_acmpl) tasklist_tick_cmd.autocomplete('taskidstr')(tasks_acmpl)
@@ -1009,7 +920,7 @@ class TasklistCog(LionCog):
await ctx.interaction.response.defer(thinking=True, ephemeral=True) await ctx.interaction.response.defer(thinking=True, ephemeral=True)
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.profile.profileid) tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
try: try:
taskids = tasklist.parse_labels(taskidstr) taskids = tasklist.parse_labels(taskidstr)
@@ -1051,9 +962,9 @@ class TasklistCog(LionCog):
) )
await ctx.interaction.edit_original_response( await ctx.interaction.edit_original_response(
embed=embed, embed=embed,
view=None if ctx.channel.id in TasklistUI._live_[ctx.profile.profileid] else TasklistCaller(self.bot) view=None if ctx.channel.id in TasklistUI._live_[ctx.author.id] else TasklistCaller(self.bot)
) )
self.bot.dispatch('tasklist_update', profileid=ctx.profile.profileid, channel=ctx.channel) self.bot.dispatch('tasklist_update', userid=ctx.author.id, channel=ctx.channel)
tasklist_untick_cmd.autocomplete('taskidstr')(tasks_acmpl) tasklist_untick_cmd.autocomplete('taskidstr')(tasks_acmpl)

View File

@@ -5,7 +5,6 @@ from data.columns import Integer, String, Timestamp, Bool
class TasklistData(Registry): class TasklistData(Registry):
class Task(RowModel): class Task(RowModel):
""" """
Row model describing a single task in a tasklist. Row model describing a single task in a tasklist.
@@ -15,17 +14,21 @@ class TasklistData(Registry):
CREATE TABLE tasklist( CREATE TABLE tasklist(
taskid SERIAL PRIMARY KEY, taskid SERIAL PRIMARY KEY,
userid BIGINT NOT NULL REFERENCES user_config ON DELETE CASCADE, userid BIGINT NOT NULL REFERENCES user_config ON DELETE CASCADE,
profileid INTEGER NOT NULL REFERENCES user_profiles ON DELETE CASCADE ON UPDATE CASCADE,
parentid INTEGER REFERENCES tasklist (taskid) ON DELETE SET NULL, parentid INTEGER REFERENCES tasklist (taskid) ON DELETE SET NULL,
content TEXT NOT NULL, content TEXT NOT NULL,
rewarded BOOL DEFAULT FALSE, rewarded BOOL DEFAULT FALSE,
deleted_at TIMESTAMPTZ, deleted_at TIMESTAMPTZ,
completed_at TIMESTAMPTZ, completed_at TIMESTAMPTZ,
created_at TIMESTAMPTZ, created_at TIMESTAMPTZ,
duration INTEGER,
last_updated_at TIMESTAMPTZ last_updated_at TIMESTAMPTZ
); );
CREATE INDEX tasklist_users ON tasklist (userid); CREATE INDEX tasklist_users ON tasklist (userid);
CREATE TABLE tasklist_channels(
guildid BIGINT NOT NULL REFERENCES guild_config (guildid) ON DELETE CASCADE,
channelid BIGINT NOT NULL
);
CREATE INDEX tasklist_channels_guilds ON tasklist_channels (guildid);
""" """
_tablename_ = "tasklist" _tablename_ = "tasklist"
@@ -38,26 +41,5 @@ class TasklistData(Registry):
created_at = Timestamp() created_at = Timestamp()
deleted_at = Timestamp() deleted_at = Timestamp()
last_updated_at = Timestamp() last_updated_at = Timestamp()
duration = Integer()
"""
Schema
------
CREATE TABLE tasklist_channels(
guildid BIGINT NOT NULL REFERENCES guild_config (guildid) ON DELETE CASCADE,
channelid BIGINT NOT NULL
);
CREATE INDEX tasklist_channels_guilds ON tasklist_channels (guildid);
"""
channels = Table('tasklist_channels') channels = Table('tasklist_channels')
"""
Schema
------
CREATE TABLE current_tasks(
taskid PRIMARY KEY REFERENCES tasklist (taskid) ON DELETE CASCADE ON UPDATE CASCADE,
last_started_at TIMESTAMPTZ NOT NULL
);
"""
current_tasks = Table('current_tasks')

View File

@@ -1,23 +0,0 @@
ALTER TABLE tasklist
DROP CONSTRAINT fk_tasklist_users;
ALTER TABLE tasklist
ADD CONSTRAINT fk_tasklist_users
FOREIGN KEY (userid)
REFERENCES user_profiles (profileid)
ON DELETE CASCADE
ON UPDATE CASCADE
NOT VALID;
ALTER TABLE tasklist
ADD COLUMN duration INTEGER;
CREATE TABLE tasklist_current(
taskid INTEGER PRIMARY KEY REFERENCES tasklist (taskid) ON DELETE CASCADE ON UPDATE CASCADE,
started_at TIMESTAMPTZ NOT NULL
);
CREATE TABLE tasklist_planner(
taskid INTEGER PRIMARY KEY REFERENCES tasklist (taskid) ON DELETE CASCADE ON UPDATE CASCADE,
sortkey INTEGER
);

View File

@@ -232,18 +232,13 @@ class TasklistUI(BasePager):
def __init__(self, def __init__(self,
tasklist: Tasklist, tasklist: Tasklist,
channel: discord.abc.Messageable, channel: discord.abc.Messageable, guild: Optional[discord.Guild] = None, **kwargs):
guild: Optional[discord.Guild] = None,
caller: Optional[discord.User | discord.Member] = None,
**kwargs):
kwargs.setdefault('timeout', 600) kwargs.setdefault('timeout', 600)
super().__init__(**kwargs) super().__init__(**kwargs)
self.bot = tasklist.bot self.bot = tasklist.bot
self.tasklist = tasklist self.tasklist = tasklist
self.labelled = tasklist.labelled self.labelled = tasklist.labelled
self.caller = caller
# NOTE: This is now a profiled
self.userid = tasklist.userid self.userid = tasklist.userid
self.channel = channel self.channel = channel
self.guild = guild self.guild = guild
@@ -454,10 +449,9 @@ class TasklistUI(BasePager):
cascade=True, cascade=True,
completed_at=utc_now() completed_at=utc_now()
) )
# TODO: Removed economy integration if self.guild:
# if self.guild: if (member := self.guild.get_member(self.userid)):
# if (member := self.guild.get_member(self.userid)): self.bot.dispatch('tasks_completed', member, *(t.taskid for t in to_complete))
# self.bot.dispatch('tasks_completed', member, *(t.taskid for t in to_complete))
if to_uncomplete: if to_uncomplete:
await self.tasklist.update_tasks( await self.tasklist.update_tasks(
*(t.taskid for t in to_uncomplete), *(t.taskid for t in to_uncomplete),
@@ -481,7 +475,7 @@ class TasklistUI(BasePager):
if shared_root: if shared_root:
self._subtree_root = labelled[shared_root].taskid self._subtree_root = labelled[shared_root].taskid
self.bot.dispatch('tasklist_update', profileid=self.userid, channel=self.channel, summon=False) self.bot.dispatch('tasklist_update', userid=self.userid, channel=self.channel, summon=False)
async def _delete_menu(self, interaction: discord.Interaction, selected: Select, subtree: bool): async def _delete_menu(self, interaction: discord.Interaction, selected: Select, subtree: bool):
await interaction.response.defer() await interaction.response.defer()
@@ -492,7 +486,7 @@ class TasklistUI(BasePager):
cascade=True, cascade=True,
deleted_at=utc_now() deleted_at=utc_now()
) )
self.bot.dispatch('tasklist_update', profileid=self.userid, channel=self.channel, summon=False) self.bot.dispatch('tasklist_update', userid=self.userid, channel=self.channel, summon=False)
async def _edit_menu(self, interaction: discord.Interaction, selected: Select, subtree: bool): async def _edit_menu(self, interaction: discord.Interaction, selected: Select, subtree: bool):
if not selected.values: if not selected.values:
@@ -519,7 +513,7 @@ class TasklistUI(BasePager):
self._last_parentid = new_parentid self._last_parentid = new_parentid
if not subtree: if not subtree:
self._subtree_root = new_parentid self._subtree_root = new_parentid
self.bot.dispatch('tasklist_update', profileid=self.userid, channel=self.channel, summon=False) self.bot.dispatch('tasklist_update', userid=self.userid, channel=self.channel, summon=False)
await interaction.response.send_modal(editor) await interaction.response.send_modal(editor)
@@ -612,7 +606,7 @@ class TasklistUI(BasePager):
self._subtree_root = pid self._subtree_root = pid
await interaction.response.defer() await interaction.response.defer()
await self.tasklist.create_task(new_task, parentid=pid) await self.tasklist.create_task(new_task, parentid=pid)
self.bot.dispatch('tasklist_update', profileid=self.userid, channel=self.channel, summon=False) self.bot.dispatch('tasklist_update', userid=self.userid, channel=self.channel, summon=False)
await press.response.send_modal(editor) await press.response.send_modal(editor)
@@ -673,7 +667,7 @@ class TasklistUI(BasePager):
@editor.add_callback @editor.add_callback
async def editor_callback(interaction: discord.Interaction): async def editor_callback(interaction: discord.Interaction):
self.bot.dispatch('tasklist_update', profileid=self.userid, channel=self.channel, summon=False) self.bot.dispatch('tasklist_update', userid=self.userid, channel=self.channel, summon=False)
if sum(len(line) for line in editor.lines.values()) + len(editor.lines) >= 4000: if sum(len(line) for line in editor.lines.values()) + len(editor.lines) >= 4000:
await press.response.send_message( await press.response.send_message(
@@ -704,7 +698,7 @@ class TasklistUI(BasePager):
await self.tasklist.update_tasklist( await self.tasklist.update_tasklist(
deleted_at=utc_now(), deleted_at=utc_now(),
) )
self.bot.dispatch('tasklist_update', profileid=self.userid, channel=self.channel, summon=False) self.bot.dispatch('tasklist_update', userid=self.userid, channel=self.channel, summon=False)
async def clear_button_refresh(self): async def clear_button_refresh(self):
self.clear_button.label = self.bot.translator.t(_p( self.clear_button.label = self.bot.translator.t(_p(
@@ -777,12 +771,11 @@ class TasklistUI(BasePager):
# ----- UI Flow ----- # ----- UI Flow -----
def access_check(self, userid): def access_check(self, userid):
return userid in (self.userid, self.caller.id if self.caller else None) return userid == self.userid
async def interaction_check(self, interaction: discord.Interaction): async def interaction_check(self, interaction: discord.Interaction):
t = self.bot.translator.t t = self.bot.translator.t
interaction_profile = await self.bot.profiles.fetch_profile_discord(interaction.user) if not self.access_check(interaction.user.id):
if not self.access_check(interaction_profile.profileid):
embed = discord.Embed( embed = discord.Embed(
description=t(_p( description=t(_p(
'ui:tasklist|error:wrong_user', 'ui:tasklist|error:wrong_user',
@@ -819,7 +812,10 @@ class TasklistUI(BasePager):
total = len(tasks) total = len(tasks)
completed = sum(t.completed_at is not None for t in tasks) completed = sum(t.completed_at is not None for t in tasks)
user = self.caller if self.guild:
user = self.guild.get_member(self.userid)
else:
user = self.bot.get_user(self.userid)
user_name = user.name if user else str(self.userid) user_name = user.name if user else str(self.userid)
user_colour = user.colour if user else discord.Color.orange() user_colour = user.colour if user else discord.Color.orange()

View File

@@ -7,6 +7,7 @@ import iso8601 # type: ignore
import pytz import pytz
import re import re
import json import json
import asyncio
from contextvars import Context from contextvars import Context
import discord import discord
@@ -918,3 +919,126 @@ def write_records(records: list[dict[str, Any]], stream: StringIO):
for record in records: for record in records:
stream.write(','.join(map(str, record.values()))) stream.write(','.join(map(str, record.values())))
stream.write('\n') stream.write('\n')
async def pager(ctx, pages, locked=True, start_at=0, add_cancel=False, **kwargs):
"""
Shows the user each page from the provided list `pages` one at a time,
providing reactions to page back and forth between pages.
This is done asynchronously, and returns after displaying the first page.
Parameters
----------
pages: List(Union(str, discord.Embed))
A list of either strings or embeds to display as the pages.
locked: bool
Whether only the `ctx.author` should be able to use the paging reactions.
kwargs: ...
Remaining keyword arguments are transparently passed to the reply context method.
Returns: discord.Message
This is the output message, returned for easy deletion.
"""
cancel_emoji = cross
# Handle broken input
if len(pages) == 0:
raise ValueError("Pager cannot page with no pages!")
# Post first page. Method depends on whether the page is an embed or not.
if isinstance(pages[start_at], discord.Embed):
out_msg = await ctx.reply(embed=pages[start_at], **kwargs)
else:
out_msg = await ctx.reply(pages[start_at], **kwargs)
# Run the paging loop if required
if len(pages) > 1:
task = asyncio.create_task(_pager(ctx, out_msg, pages, locked, start_at, add_cancel, **kwargs))
# ctx.tasks.append(task)
elif add_cancel:
await out_msg.add_reaction(cancel_emoji)
# Return the output message
return out_msg
async def _pager(ctx, out_msg, pages, locked, start_at, add_cancel, **kwargs):
"""
Asynchronous initialiser and loop for the `pager` utility above.
"""
# Page number
page = start_at
# Add reactions to the output message
next_emoji = ""
prev_emoji = ""
cancel_emoji = cross
try:
await out_msg.add_reaction(prev_emoji)
if add_cancel:
await out_msg.add_reaction(cancel_emoji)
await out_msg.add_reaction(next_emoji)
except discord.Forbidden:
# We don't have permission to add paging emojis
# Die as gracefully as we can
if ctx.guild:
perms = ctx.channel.permissions_for(ctx.guild.me)
if not perms.add_reactions:
await ctx.error_reply(
"Cannot page results because I do not have the `add_reactions` permission!"
)
elif not perms.read_message_history:
await ctx.error_reply(
"Cannot page results because I do not have the `read_message_history` permission!"
)
else:
await ctx.error_reply(
"Cannot page results due to insufficient permissions!"
)
else:
await ctx.error_reply(
"Cannot page results!"
)
return
# Check function to determine whether a reaction is valid
def check(reaction, user):
result = reaction.message.id == out_msg.id
result = result and str(reaction.emoji) in [next_emoji, prev_emoji]
result = result and not (user.id == ctx.bot.user.id)
result = result and not (locked and user != ctx.author)
return result
# Begin loop
while True:
# Wait for a valid reaction, break if we time out
try:
reaction, user = await ctx.bot.wait_for('reaction_add', check=check, timeout=300)
except asyncio.TimeoutError:
break
# Attempt to remove the user's reaction, silently ignore errors
asyncio.ensure_future(out_msg.remove_reaction(reaction.emoji, user))
# Change the page number
page += 1 if reaction.emoji == next_emoji else -1
page %= len(pages)
# Edit the message with the new page
active_page = pages[page]
if isinstance(active_page, discord.Embed):
await out_msg.edit(embed=active_page, **kwargs)
else:
await out_msg.edit(content=active_page, **kwargs)
# Clean up by removing the reactions
try:
await out_msg.clear_reactions()
except discord.Forbidden:
try:
await out_msg.remove_reaction(next_emoji, ctx.client.user)
await out_msg.remove_reaction(prev_emoji, ctx.client.user)
except discord.NotFound:
pass
except discord.NotFound:
pass