@@ -5,12 +5,16 @@ from datetime import timedelta
import discord
from discord . ext import commands as cmds
from discord import app_commands as appcmds
import twitchio
from twitchio . ext import commands
from data . queries import ORDER
from meta import LionCog , LionBot , CrocBot
from meta import LionCog , LionBot , CrocBot , LionContext
from modules . profiles . community import Community
from modules . profiles . profile import UserProfile
from utils . lib import utc_now
from . import logger
from . data import CounterData
@@ -25,6 +29,11 @@ class PERIOD(Enum):
YEAR = ( ' this year ' , ' y ' , ' year ' , ' yearly ' )
class ORIGIN ( Enum ) :
DISCORD = ' discord '
TWITCH = ' twitch '
def counter_cmd_factory (
counter : str ,
response : str ,
@@ -32,10 +41,16 @@ def counter_cmd_factory(
context : Optional [ str ] = None
) :
context = context or f " cmd: { counter } "
async def counter_cmd ( cog , ctx : commands . Context , * , args : Optional [ str ] = None ) :
userid = int ( ctx . author . id )
channelid = int ( ( await ctx . channel . user ( ) ) . id )
period , start_time = await cog . parse_period ( channelid , ' ' , default = default_period )
async def counter_cmd (
cog ,
ctx : commands . Context | LionContext ,
origin : ORIGIN ,
author : UserProfile ,
community : Community ,
args : Optional [ str ]
) :
userid = author . profileid
period , start_time = await cog . parse_period ( community , ' ' , default = default_period )
args = ( args or ' ' ) . strip ( " " )
splits = args . split ( maxsplit = 1 )
@@ -69,13 +84,25 @@ def counter_cmd_factory(
)
)
async def lb_cmd ( cog , ctx : commands . Context , * , args : str = ' ' ) :
user = await ctx . channel . user ( )
await ctx . reply ( await cog . formatted_lb ( counter , args , int ( user . id ) ) )
async def lb_cmd (
cog ,
ctx : commands . Context | LionContext ,
origin : ORIGIN ,
author : UserProfile ,
community : Community ,
args : Optional [ str ]
) :
await ctx . reply ( await cog . formatted_lb ( counter , args , community , origin ) )
async def undo_cmd ( cog , ctx : commands . Context ) :
userid = int ( ctx . author . id )
channelid = int ( ( await ctx . channel . user ( ) ) . id )
async def undo_cmd (
cog ,
ctx : commands . Context | LionContext ,
origin : ORIGIN ,
author : UserProfile ,
community : Community ,
args : Optional [ str ]
) :
userid = author . profileid
_counter = await cog . fetch_counter ( counter )
query = cog . data . CounterEntry . fetch_where (
counterid = _counter . counterid ,
@@ -113,6 +140,9 @@ class CounterCog(LionCog):
await self . load_counters ( )
self . loaded . set ( )
profiles = self . bot . get_cog ( ' ProfileCog ' )
profiles . add_profile_migrator ( self . migrate_profiles , name = ' counters ' )
async def cog_unload ( self ) :
self . _unload_twitch_methods ( self . crocbot )
@@ -124,18 +154,49 @@ class CounterCog(LionCog):
counter . name ,
row . response
)
cmds = [ ]
main _cmd = commands . command ( name = row . name ) ( counter_cb )
cmds . append ( main_cmd )
if row . lbname :
lb_cmd = commands . command ( name = row . lb name) ( lb_cb )
cmds . append ( lb_cmd )
if row . undoname :
undo_cmd = commands . command ( name = row . undoname ) ( undo_cb )
cmds . append ( undo_cm d)
twitch_ cmds = [ ]
disc _cmds = [ ]
twitch_ cmds. append (
commands . command (
name = row . name
) ( self . twitch_callback ( counter_cb ) )
)
disc_cmds . append (
cmds . hybrid_comman d(
name = row . name
) ( self . discord_callback ( counter_cb ) )
)
for cmd in cmds :
if row . lbname :
twitch_cmds . append (
commands . command (
name = row . lbname
) ( self . twitch_callback ( lb_cb ) )
)
disc_cmds . append (
cmds . hybrid_command (
name = row . lbname
) ( self . discord_callback ( lb_cb ) )
)
if row . undoname :
twitch_cmds . append (
commands . command (
name = row . undoname
) ( self . twitch_callback ( undo_cb ) )
)
disc_cmds . append (
cmds . hybrid_command (
name = row . undoname
) ( self . discord_callback ( undo_cb ) )
)
for cmd in twitch_cmds :
self . add_twitch_command ( self . crocbot , cmd )
for cmd in disc_cmds :
# cmd.cog = self
self . bot . remove_command ( cmd . name )
self . bot . add_command ( cmd )
print ( f " Adding command: { cmd } " )
logger . info ( f " (Re)Loaded { len ( rows ) } counter commands! " )
@@ -152,6 +213,87 @@ class CounterCog(LionCog):
f " Loaded { len ( self . counters ) } counters. "
)
async def migrate_profiles ( self , source_profile : UserProfile , target_profile : UserProfile ) :
"""
Move source profile entries to target profile entries
"""
results = [ " (Counters) " ]
rows = await self . data . CounterEntry . table . update_where ( userid = source_profile . profileid ) . set ( userid = target_profile . profileid )
if rows :
results . append (
f " Migrated { len ( rows ) } counter entries from source profile. "
)
else :
results . append (
" No counter entries to migrate in source profile. "
)
return ' ' . join ( results )
async def user_profile_migration ( self ) :
"""
Manual single-use migration method from the old userid format to the new profileid format.
"""
async with self . bot . db . connection ( ) as conn :
self . bot . db . conn = conn
async with conn . transaction ( ) :
entries = await self . data . CounterEntry . fetch_where ( )
for entry in entries :
if entry . userid > 1000 :
# Assume userid is a twitch userid
profile = await UserProfile . fetch_from_twitchid ( self . bot , entry . userid )
if not profile :
# Need to create
users = await self . crocbot . fetch_users ( ids = [ entry . userid ] )
if not users :
continue
user = users [ 0 ]
profile = await UserProfile . create_from_twitch ( self . bot , user )
await entry . update ( userid = profile . profileid )
logger . info ( " Completed single-shot user profile migration " )
# General API
def twitch_callback ( self , callback ) :
"""
Generate a Twitch command callback from the given general callback.
General callback must be of the form
callback(cog, ctx: GeneralContext, origin: ORIGIN, author: Profile, comm: Community, args: Optional[str])
Return will be a command callback of the form
callback(cog, ctx: Context, *, args: Optional[str] = None)
"""
async def command_callback ( cog : CounterCog , ctx : commands . Context , * , args : Optional [ str ] = None ) :
profiles = cog . bot . get_cog ( ' ProfileCog ' )
# Compute author profile
author = await profiles . fetch_profile_twitch ( ctx . author )
# Compute community profile
community = await profiles . fetch_community_twitch ( await ctx . channel . user ( ) )
return await callback ( cog , ctx , ORIGIN . TWITCH , author , community , args )
return command_callback
def discord_callback ( self , callback ) :
"""
Generate a Discord command callback from the given general callback.
General callback must be of the form
callback(cog, ctx: GeneralContext, origin: ORIGIN, author: Profile, comm: Community, args: Optional[str])
Return will be a command callback of the form
callback(cog, ctx: LionContext, *, args: Optional[str] = None)
"""
cog = self
async def command_callback ( ctx : LionContext , * , args : Optional [ str ] = None ) :
profiles = cog . bot . get_cog ( ' ProfileCog ' )
# Compute author profile
author = await profiles . fetch_profile_discord ( ctx . author )
# Compute community profile
community = await profiles . fetch_community_discord ( ctx . guild )
return await callback ( cog , ctx , ORIGIN . DISCORD , author , community , args )
return command_callback
# Counters API
async def fetch_counter ( self , counter : str ) - > CounterData . Counter :
@@ -218,6 +360,14 @@ class CounterCog(LionCog):
results = await query
return results [ 0 ] [ ' counter_total ' ] if results else 0
# Manage commands
@commands.command ( )
async def countermigration ( self , ctx : commands . Context , * , args : Optional [ str ] = None ) :
if not ( ctx . author . is_mod or ctx . author . is_broadcaster ) :
return
await self . user_profile_migration ( )
await ctx . reply ( " Counter userid->profileid migration done. " )
# Counters commands
@commands.command ( )
async def counter ( self , ctx : commands . Context , name : str , subcmd : Optional [ str ] , * , args : Optional [ str ] = None ) :
@@ -225,6 +375,10 @@ class CounterCog(LionCog):
return
name = name . lower ( )
profiles = self . bot . get_cog ( ' ProfileCog ' )
author = await profiles . fetch_profile_twitch ( ctx . author )
userid = author . profileid
community = await profiles . fetch_community_twitch ( await ctx . channel . user ( ) )
if subcmd is None or subcmd == ' show ' :
# Show
@@ -241,15 +395,14 @@ class CounterCog(LionCog):
return
await self . add_to_counter (
name ,
int ( ctx . author . id ) ,
userid ,
value ,
context = ' cmd: counter add '
)
total = await self . totals ( name )
await ctx . reply ( f " ' { name } ' counter is now: { total } " )
elif subcmd == ' lb ' :
use r = await ctx . channel . user ( )
lbstr = await self . formatted_lb ( name , args or ' ' , int ( user . id ) )
lbst r = await self . formatted_lb ( name , args or ' ' , community )
await ctx . reply ( lbstr )
elif subcmd == ' clear ' :
await self . reset_counter ( name )
@@ -292,7 +445,7 @@ class CounterCog(LionCog):
else :
await ctx . reply ( f " Unrecognised subcommand { subcmd } . Supported subcommands: ' show ' , ' add ' , ' lb ' , ' clear ' , ' alias ' . " )
async def parse_period ( self , userid : int , periodstr : str , default = PERIOD . STREAM ) :
async def parse_period ( self , community : Community , periodstr : str , default = PERIOD . STREAM ) :
if periodstr :
period = next ( ( period for period in PERIOD if periodstr . lower ( ) in period . value ) , None )
if period is None :
@@ -306,9 +459,13 @@ class CounterCog(LionCog):
if period is PERIOD . ALL :
start_time = None
elif period is PERIOD . STREAM :
stream s = await self . crocbot . fe tch_streams ( user_ids = [ userid ] )
if streams :
stream = streams [ 0 ]
twitche s = await community . twi tch_channels ( )
stream = None
if twitches :
twitch = twitches [ 0 ]
streams = await self . crocbot . fetch_streams ( user_ids = [ int ( twitch . channelid ) ] )
stream = streams [ 0 ] if streams else None
if stream :
start_time = stream . started_at
else :
period = PERIOD . ALL
@@ -327,21 +484,33 @@ class CounterCog(LionCog):
return ( period , start_time )
async def formatted_lb ( self , counter : str , periodstr : str , channelid : int ) :
async def formatted_lb (
self ,
counter : str ,
periodstr : str ,
community : Community ,
origin : ORIGIN = ORIGIN . TWITCH
) :
period , start_time = await self . parse_period ( channelid , periodstr )
period , start_time = await self . parse_period ( community , periodstr )
lb = await self . leaderboard ( counter , start_time = start_time )
if lb :
userids = list ( lb . keys ( ) )
users = await self . crocbot . fetch_users ( ids = userids )
name_map = { user . id : user . display_name for user in users }
name_map = { }
for userid in lb . keys ( ) :
profile = await UserProfile . fetch ( self . bot , userid )
name = await profile . get_name ( )
name_map [ userid ] = name
parts = [ ]
for userid , total in lb . items ( ) :
items = list ( lb . items ( ) )
prefix = ' top 10 ' if len ( items ) > 10 else ' '
items = items [ : 10 ]
for userid , total in items :
name = name_map . get ( userid , str ( userid ) )
part = f " { name } : { total } "
parts . append ( part )
lbstr = ' ; ' . join ( parts )
return f " { counter } { period . value [ - 1 ] } leaderboard --- { lbstr } "
return f " { counter } { period . value [ - 1 ] } { prefix } leaderboard --- { lbstr } "
else :
return f " { counter } { period . value [ - 1 ] } leaderboard is empty! "