diff --git a/src/modules/counters/cog.py b/src/modules/counters/cog.py index 2508208f..df01057a 100644 --- a/src/modules/counters/cog.py +++ b/src/modules/counters/cog.py @@ -5,12 +5,16 @@ from datetime import timedelta import discord from discord.ext import commands as cmds +from discord import app_commands as appcmds + import twitchio from twitchio.ext import commands from data.queries import ORDER -from meta import LionCog, LionBot, CrocBot +from meta import LionCog, LionBot, CrocBot, LionContext +from modules.profiles.community import Community +from modules.profiles.profile import UserProfile from utils.lib import utc_now from . import logger from .data import CounterData @@ -25,6 +29,11 @@ class PERIOD(Enum): YEAR = ('this year', 'y', 'year', 'yearly') +class ORIGIN(Enum): + DISCORD = 'discord' + TWITCH = 'twitch' + + def counter_cmd_factory( counter: str, response: str, @@ -32,10 +41,16 @@ def counter_cmd_factory( context: Optional[str] = None ): context = context or f"cmd: {counter}" - async def counter_cmd(cog, ctx: commands.Context, *, args: Optional[str] = None): - userid = int(ctx.author.id) - channelid = int((await ctx.channel.user()).id) - period, start_time = await cog.parse_period(channelid, '', default=default_period) + async def counter_cmd( + cog, + ctx: commands.Context | LionContext, + origin: ORIGIN, + author: UserProfile, + community: Community, + args: Optional[str] + ): + userid = author.profileid + period, start_time = await cog.parse_period(community, '', default=default_period) args = (args or '').strip(" 󠀀 ") splits = args.split(maxsplit=1) @@ -69,13 +84,25 @@ def counter_cmd_factory( ) ) - async def lb_cmd(cog, ctx: commands.Context, *, args: str = ''): - user = await ctx.channel.user() - await ctx.reply(await cog.formatted_lb(counter, args, int(user.id))) + async def lb_cmd( + cog, + ctx: commands.Context | LionContext, + origin: ORIGIN, + author: UserProfile, + community: Community, + args: Optional[str] + ): + await ctx.reply(await cog.formatted_lb(counter, args, community, origin)) - async def undo_cmd(cog, ctx: commands.Context): - userid = int(ctx.author.id) - channelid = int((await ctx.channel.user()).id) + async def undo_cmd( + cog, + ctx: commands.Context | LionContext, + origin: ORIGIN, + author: UserProfile, + community: Community, + args: Optional[str] + ): + userid = author.profileid _counter = await cog.fetch_counter(counter) query = cog.data.CounterEntry.fetch_where( counterid=_counter.counterid, @@ -113,6 +140,9 @@ class CounterCog(LionCog): await self.load_counters() self.loaded.set() + profiles = self.bot.get_cog('ProfileCog') + profiles.add_profile_migrator(self.migrate_profiles, name='counters') + async def cog_unload(self): self._unload_twitch_methods(self.crocbot) @@ -124,18 +154,48 @@ class CounterCog(LionCog): counter.name, row.response ) - cmds = [] - main_cmd = commands.command(name=row.name)(counter_cb) - cmds.append(main_cmd) - if row.lbname: - lb_cmd = commands.command(name=row.lbname)(lb_cb) - cmds.append(lb_cmd) - if row.undoname: - undo_cmd = commands.command(name=row.undoname)(undo_cb) - cmds.append(undo_cmd) + twitch_cmds = [] + disc_cmds = [] + twitch_cmds.append( + commands.command( + name=row.name + )(self.twitch_callback(counter_cb)) + ) + disc_cmds.append( + cmds.hybrid_command( + name=row.name + )(self.discord_callback(counter_cb)) + ) - for cmd in cmds: + if row.lbname: + twitch_cmds.append( + commands.command( + name=row.lbname + )(self.twitch_callback(lb_cb)) + ) + disc_cmds.append( + cmds.hybrid_command( + name=row.lbname + )(self.discord_callback(lb_cb)) + ) + if row.undoname: + twitch_cmds.append( + commands.command( + name=row.undoname + )(self.twitch_callback(undo_cb)) + ) + disc_cmds.append( + cmds.hybrid_command( + name=row.undoname + )(self.discord_callback(undo_cb)) + ) + + for cmd in twitch_cmds: self.add_twitch_command(self.crocbot, cmd) + for cmd in disc_cmds: + # cmd.cog = self + self.bot.add_command(cmd) + print(f"Adding command: {cmd}") logger.info(f"(Re)Loaded {len(rows)} counter commands!") @@ -152,6 +212,87 @@ class CounterCog(LionCog): f"Loaded {len(self.counters)} counters." ) + async def migrate_profiles(self, source_profile: UserProfile, target_profile: UserProfile): + """ + Move source profile entries to target profile entries + """ + results = ["(Counters)"] + + rows = await self.data.CounterEntry.table.update_where(userid=source_profile.profileid).set(userid=target_profile.profileid) + if rows: + results.append( + f"Migrated {len(rows)} counter entries from source profile." + ) + else: + results.append( + "No counter entries to migrate in source profile." + ) + + return ' '.join(results) + + async def user_profile_migration(self): + """ + Manual single-use migration method from the old userid format to the new profileid format. + """ + async with self.bot.db.connection() as conn: + self.bot.db.conn = conn + async with conn.transaction(): + entries = await self.data.CounterEntry.fetch_where() + for entry in entries: + if entry.userid > 1000: + # Assume userid is a twitch userid + profile = await UserProfile.fetch_from_twitchid(self.bot, entry.userid) + if not profile: + # Need to create + users = await self.crocbot.fetch_users(ids=[entry.userid]) + if not users: + continue + user = users[0] + profile = await UserProfile.create_from_twitch(self.bot, user) + await entry.update(userid=profile.profileid) + logger.info("Completed single-shot user profile migration") + + # General API + def twitch_callback(self, callback): + """ + Generate a Twitch command callback from the given general callback. + + General callback must be of the form + callback(cog, ctx: GeneralContext, origin: ORIGIN, author: Profile, comm: Community, args: Optional[str]) + + Return will be a command callback of the form + callback(cog, ctx: Context, *, args: Optional[str] = None) + """ + async def command_callback(cog: CounterCog, ctx: commands.Context, *, args: Optional[str] = None): + profiles = cog.bot.get_cog('ProfileCog') + # Compute author profile + author = await profiles.fetch_profile_twitch(ctx.author) + # Compute community profile + community = await profiles.fetch_community_twitch(await ctx.channel.user()) + return await callback(cog, ctx, ORIGIN.TWITCH, author, community, args) + return command_callback + + def discord_callback(self, callback): + """ + Generate a Discord command callback from the given general callback. + + General callback must be of the form + callback(cog, ctx: GeneralContext, origin: ORIGIN, author: Profile, comm: Community, args: Optional[str]) + + Return will be a command callback of the form + callback(cog, ctx: LionContext, *, args: Optional[str] = None) + """ + cog = self + async def command_callback(ctx: LionContext, *, args: Optional[str] = None): + profiles = cog.bot.get_cog('ProfileCog') + # Compute author profile + author = await profiles.fetch_profile_discord(ctx.author) + # Compute community profile + community = await profiles.fetch_community_discord(ctx.guild) + return await callback(cog, ctx, ORIGIN.DISCORD, author, community, args) + + return command_callback + # Counters API async def fetch_counter(self, counter: str) -> CounterData.Counter: @@ -218,6 +359,14 @@ class CounterCog(LionCog): results = await query return results[0]['counter_total'] if results else 0 + # Manage commands + @commands.command() + async def countermigration(self, ctx: commands.Context, *, args: Optional[str]=None): + if not (ctx.author.is_mod or ctx.author.is_broadcaster): + return + await self.user_profile_migration() + await ctx.reply("Counter userid->profileid migration done.") + # Counters commands @commands.command() async def counter(self, ctx: commands.Context, name: str, subcmd: Optional[str], *, args: Optional[str]=None): @@ -225,6 +374,10 @@ class CounterCog(LionCog): return name = name.lower() + profiles = self.bot.get_cog('ProfileCog') + author = await profiles.fetch_profile_twitch(ctx.author) + userid = author.profileid + community = await profiles.fetch_community_twitch(await ctx.channel.user()) if subcmd is None or subcmd == 'show': # Show @@ -241,15 +394,14 @@ class CounterCog(LionCog): return await self.add_to_counter( name, - int(ctx.author.id), + userid, value, context='cmd: counter add' ) total = await self.totals(name) await ctx.reply(f"'{name}' counter is now: {total}") elif subcmd == 'lb': - user = await ctx.channel.user() - lbstr = await self.formatted_lb(name, args or '', int(user.id)) + lbstr = await self.formatted_lb(name, args or '', community) await ctx.reply(lbstr) elif subcmd == 'clear': await self.reset_counter(name) @@ -292,7 +444,7 @@ class CounterCog(LionCog): else: await ctx.reply(f"Unrecognised subcommand {subcmd}. Supported subcommands: 'show', 'add', 'lb', 'clear', 'alias'.") - async def parse_period(self, userid: int, periodstr: str, default=PERIOD.STREAM): + async def parse_period(self, community: Community, periodstr: str, default=PERIOD.STREAM): if periodstr: period = next((period for period in PERIOD if periodstr.lower() in period.value), None) if period is None: @@ -306,9 +458,13 @@ class CounterCog(LionCog): if period is PERIOD.ALL: start_time = None elif period is PERIOD.STREAM: - streams = await self.crocbot.fetch_streams(user_ids=[userid]) - if streams: - stream = streams[0] + twitches = await community.twitch_channels() + stream = None + if twitches: + twitch = twitches[0] + streams = await self.crocbot.fetch_streams(user_ids=[int(twitch.channelid)]) + stream = streams[0] if streams else None + if stream: start_time = stream.started_at else: period = PERIOD.ALL @@ -327,21 +483,33 @@ class CounterCog(LionCog): return (period, start_time) - async def formatted_lb(self, counter: str, periodstr: str, channelid: int): + async def formatted_lb( + self, + counter: str, + periodstr: str, + community: Community, + origin: ORIGIN = ORIGIN.TWITCH + ): - period, start_time = await self.parse_period(channelid, periodstr) + period, start_time = await self.parse_period(community, periodstr) lb = await self.leaderboard(counter, start_time=start_time) if lb: - userids = list(lb.keys()) - users = await self.crocbot.fetch_users(ids=userids) - name_map = {user.id: user.display_name for user in users} + name_map = {} + for userid in lb.keys(): + profile = await UserProfile.fetch(self.bot, userid) + name = await profile.get_name() + name_map[userid] = name + parts = [] - for userid, total in lb.items(): + items = list(lb.items()) + prefix = 'top 10 ' if len(items) > 10 else '' + items = items[:10] + for userid, total in items: name = name_map.get(userid, str(userid)) part = f"{name}: {total}" parts.append(part) lbstr = '; '.join(parts) - return f"{counter} {period.value[-1]} leaderboard --- {lbstr}" + return f"{counter} {period.value[-1]} {prefix}leaderboard --- {lbstr}" else: return f"{counter} {period.value[-1]} leaderboard is empty!" diff --git a/src/modules/nowdoing/cog.py b/src/modules/nowdoing/cog.py index bead18b5..550cf395 100644 --- a/src/modules/nowdoing/cog.py +++ b/src/modules/nowdoing/cog.py @@ -202,7 +202,7 @@ class NowDoingCog(LionCog): await self.data.Task.table.delete_where(userid=userid) task = await self.data.Task.create( userid=userid, - name=ctx.author.display_name, + name=await profile.get_name(), task=args, started_at=existing.started_at if (existing and edit) else utc_now(), ) @@ -272,7 +272,7 @@ class NowDoingCog(LionCog): await self.data.Task.table.delete_where(userid=userid) task = await self.data.Task.create( userid=userid, - name=ctx.author.display_name, + name=await profile.get_name(), task=args, started_at=utc_now(), ) diff --git a/src/modules/profiles/profile.py b/src/modules/profiles/profile.py index aaf66a96..d3ac4cd6 100644 --- a/src/modules/profiles/profile.py +++ b/src/modules/profiles/profile.py @@ -31,6 +31,30 @@ class UserProfile: def __repr__(self): return f"" + async def get_name(self): + # TODO: Store a preferred name in the profile preferences + # TODO Should have a multi-fetch system + name = None + twitches = await self.twitch_accounts() + if twitches: + users = await self.bot.crocbot.fetch_users( + ids=[int(twitch.userid) for twitch in twitches] + ) + if users: + user = users[0] + name = user.display_name + + if not name: + discords = await self.discord_accounts() + if discords: + user = await self.bot.fetch_user(discords[0].userid) + name = user.display_name + + if not name: + name = 'Unknown' + + return name + async def attach_discord(self, user: discord.User | discord.Member): """ Attach a new discord user to this profile.