rewrite (voice sessions): Voice session tracker.
This commit is contained in:
11
src/tracking/voice/__init__.py
Normal file
11
src/tracking/voice/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import logging
|
||||
from babel.translator import LocalBabel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
babel = LocalBabel('voice-tracker')
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
from .cog import VoiceTrackerCog
|
||||
|
||||
await bot.add_cog(VoiceTrackerCog(bot))
|
||||
700
src/tracking/voice/cog.py
Normal file
700
src/tracking/voice/cog.py
Normal file
@@ -0,0 +1,700 @@
|
||||
from typing import Optional
|
||||
import asyncio
|
||||
import datetime as dt
|
||||
from collections import defaultdict
|
||||
|
||||
import discord
|
||||
from discord.ext import commands as cmds
|
||||
from discord import app_commands as appcmds
|
||||
|
||||
from meta import LionBot, LionCog, LionContext
|
||||
from meta.errors import UserInputError
|
||||
from meta.logger import log_wrap, logging_context
|
||||
from meta.sharding import THIS_SHARD
|
||||
from utils.lib import utc_now, error_embed
|
||||
from core.lion_guild import VoiceMode
|
||||
|
||||
from wards import low_management
|
||||
|
||||
from . import babel, logger
|
||||
from .data import VoiceTrackerData
|
||||
from .settings import VoiceTrackerSettings, VoiceTrackerConfigUI
|
||||
|
||||
from .session import VoiceSession, TrackedVoiceState
|
||||
|
||||
_p = babel._p
|
||||
|
||||
|
||||
class VoiceTrackerCog(LionCog):
|
||||
"""
|
||||
LionCog module controlling and configuring the voice tracking subsystem.
|
||||
"""
|
||||
|
||||
def __init__(self, bot: LionBot):
|
||||
self.bot = bot
|
||||
self.data = bot.db.load_registry(VoiceTrackerData())
|
||||
self.settings = VoiceTrackerSettings()
|
||||
self.babel = babel
|
||||
|
||||
# State
|
||||
self.handle_events = False
|
||||
self.tracking_lock = asyncio.Lock()
|
||||
|
||||
self.untracked_channels = self.settings.UntrackedChannels._cache
|
||||
|
||||
async def cog_load(self):
|
||||
await self.data.init()
|
||||
|
||||
self.bot.core.guild_config.register_model_setting(self.settings.HourlyReward)
|
||||
self.bot.core.guild_config.register_model_setting(self.settings.HourlyLiveBonus)
|
||||
self.bot.core.guild_config.register_model_setting(self.settings.UntrackedChannels)
|
||||
self.bot.core.guild_config.register_model_setting(self.settings.DailyVoiceCap)
|
||||
|
||||
# Update the tracked voice channel cache
|
||||
await self.settings.UntrackedChannels.setup(self.bot)
|
||||
|
||||
configcog = self.bot.get_cog('ConfigCog')
|
||||
if configcog is None:
|
||||
logger.critical(
|
||||
"Attempting to load VoiceTrackerCog before ConfigCog! Cannot crossload configuration group."
|
||||
)
|
||||
else:
|
||||
self.crossload_group(self.configure_group, configcog.configure_group)
|
||||
|
||||
if self.bot.is_ready():
|
||||
await self.initialise()
|
||||
|
||||
async def cog_unload(self):
|
||||
# TODO: Shutdown task to trigger updates on all ongoing sessions
|
||||
# Simultaneously!
|
||||
...
|
||||
|
||||
def get_session(self, guildid, userid) -> VoiceSession:
|
||||
"""
|
||||
Get the VoiceSession for the given member.
|
||||
|
||||
Creates it if it does not exist.
|
||||
"""
|
||||
return VoiceSession.get(self.bot, guildid, userid)
|
||||
|
||||
@LionCog.listener('on_ready')
|
||||
@log_wrap(action='Init Voice Sessions')
|
||||
async def initialise(self):
|
||||
"""
|
||||
(Re)-initialise voice tracking using current voice channel members as source of truth.
|
||||
|
||||
Ends ongoing sessions for members who are not in the given voice channel.
|
||||
"""
|
||||
# First take the tracking lock
|
||||
# Ensures current event handling completes before re-initialisation
|
||||
async with self.tracking_lock:
|
||||
logger.info("Reloading ongoing voice sessions")
|
||||
|
||||
logger.debug("Disabling voice state event handling.")
|
||||
self.handle_events = False
|
||||
# Read and save the tracked voice states of all visible voice channels
|
||||
voice_members = {} # (guildid, userid) -> TrackedVoiceState
|
||||
voice_guilds = set()
|
||||
for guild in self.bot.guilds:
|
||||
for channel in guild.voice_channels:
|
||||
for member in channel.members:
|
||||
voice_members[(guild.id, member.id)] = TrackedVoiceState.from_voice_state(member.voice)
|
||||
voice_guilds.add(guild.id)
|
||||
|
||||
logger.debug(f"Cached {len(voice_members)} members from voice channels.")
|
||||
self.handle_events = True
|
||||
logger.debug("Re-enabled voice state event handling.")
|
||||
|
||||
# Iterate through members with current ongoing sessions
|
||||
# End or update sessions as needed, based on saved tracked state
|
||||
ongoing_rows = await self.data.VoiceSessionsOngoing.fetch_where(
|
||||
guildid=[guild.id for guild in self.bot.guilds]
|
||||
)
|
||||
logger.debug(
|
||||
f"Loaded {len(ongoing_rows)} ongoing sessions from data. Splitting into complete and incomplete."
|
||||
)
|
||||
complete = []
|
||||
incomplete = []
|
||||
incomplete_guildids = set()
|
||||
|
||||
# Compute time to end complete sessions
|
||||
now = utc_now()
|
||||
last_update = max((row.last_update for row in ongoing_rows), default=now)
|
||||
end_at = min(last_update + dt.timedelta(seconds=3600), now)
|
||||
|
||||
for row in ongoing_rows:
|
||||
key = (row.guildid, row.userid)
|
||||
state = voice_members.get(key, None)
|
||||
untracked = self.untracked_channels.get(row.guildid, [])
|
||||
if (
|
||||
state
|
||||
and state.channelid == row.channelid
|
||||
and state.channelid not in untracked
|
||||
and (ch := self.bot.get_channel(state.channelid)) is not None
|
||||
and (not ch.category_id or ch.category_id not in untracked)
|
||||
):
|
||||
# Mark session as ongoing
|
||||
incomplete.append((row, state))
|
||||
incomplete_guildids.add(row.guildid)
|
||||
voice_members.pop(key)
|
||||
else:
|
||||
# Mark session as complete
|
||||
complete.append((row.guildid, row.userid, end_at))
|
||||
|
||||
# Load required guild data into cache
|
||||
active_guildids = incomplete_guildids.union(voice_guilds)
|
||||
if active_guildids:
|
||||
await self.bot.core.data.Guild.fetch_where(guildid=tuple(active_guildids))
|
||||
lguilds = {guildid: await self.bot.core.lions.fetch_guild(guildid) for guildid in active_guildids}
|
||||
|
||||
# Calculate tracked_today for members with ongoing sessions
|
||||
active_members = set((row.guildid, row.userid) for row, _ in incomplete)
|
||||
active_members.update(voice_members.keys())
|
||||
if active_members:
|
||||
tracked_today_data = await self.data.VoiceSessions.multiple_voice_tracked_since(
|
||||
*((guildid, userid, lguilds[guildid].today) for guildid, userid in active_members)
|
||||
)
|
||||
else:
|
||||
tracked_today_data = []
|
||||
tracked_today = {(row['guildid'], row['userid']): row['tracked'] for row in tracked_today_data}
|
||||
|
||||
if incomplete:
|
||||
# Note that study_time_since _includes_ ongoing sessions in its calculation
|
||||
# So expiry times are "time left today until cap" or "tomorrow + cap"
|
||||
to_load = [] # (session_data, expiry_time)
|
||||
to_update = [] # (guildid, userid, update_at, stream, video, hourly_rate)
|
||||
for session_data, state in incomplete:
|
||||
# Calculate expiry times
|
||||
lguild = lguilds[session_data.guildid]
|
||||
cap = lguild.config.get('daily_voice_cap').value
|
||||
tracked = tracked_today[(session_data.guildid, session_data.userid)]
|
||||
if tracked >= cap:
|
||||
# Already over cap
|
||||
complete.append(
|
||||
session_data.guildid,
|
||||
session_data.userid,
|
||||
max(now + dt.timedelta(seconds=tracked - cap), session_data.last_update)
|
||||
)
|
||||
else:
|
||||
tomorrow = lguild.today + dt.timedelta(days=1)
|
||||
expiry = now + dt.timedelta(seconds=(cap - tracked))
|
||||
if expiry > tomorrow:
|
||||
expiry = tomorrow + dt.timedelta(seconds=cap)
|
||||
to_load.append((session_data, expiry))
|
||||
|
||||
# TODO: Probably better to do this by batch
|
||||
# Could force all bonus calculators to accept list of members
|
||||
hourly_rate = await self._calculate_rate(session_data.guildid, session_data.userid, state)
|
||||
to_update.append((
|
||||
session_data.guildid,
|
||||
session_data.userid,
|
||||
now,
|
||||
state.stream,
|
||||
state.video,
|
||||
hourly_rate
|
||||
))
|
||||
# Run the updates, note that session_data uses registry pattern so will also update
|
||||
if to_update:
|
||||
await self.data.VoiceSessionsOngoing.update_voice_sessions_at(*to_update)
|
||||
|
||||
# Load the sessions
|
||||
for data, expiry in to_load:
|
||||
VoiceSession.from_ongoing(self.bot, data, expiry)
|
||||
|
||||
logger.info(f"Resumed {len(to_load)} ongoing voice sessions.")
|
||||
|
||||
if complete:
|
||||
logger.info(f"Ending {len(complete)} out-of-date or expired study sessions.")
|
||||
|
||||
# Complete sessions just need a mass end_voice_session_at()
|
||||
await self.data.VoiceSessionsOngoing.close_voice_sessions_at(*complete)
|
||||
|
||||
# Then iterate through the saved states from tracked voice channels
|
||||
# Start sessions if they don't already exist
|
||||
if voice_members:
|
||||
expiries = {} # (guildid, memberid) -> expiry time
|
||||
to_create = [] # (guildid, userid, channelid, start_time, last_update, live_stream, live_video, rate)
|
||||
for (guildid, userid), state in voice_members.items():
|
||||
untracked = self.untracked_channels.get(guildid, [])
|
||||
channel = self.bot.get_channel(state.channelid)
|
||||
if (
|
||||
channel
|
||||
and channel.id not in untracked
|
||||
and (not channel.category_id or channel.category_id not in untracked)
|
||||
):
|
||||
# State is from member in tracked voice channel
|
||||
# Calculate expiry
|
||||
lguild = lguilds[guildid]
|
||||
cap = lguild.config.get('daily_voice_cap').value
|
||||
tracked = tracked_today[(guildid, userid)]
|
||||
if tracked < cap:
|
||||
tomorrow = lguild.today + dt.timedelta(days=1)
|
||||
expiry = now + dt.timedelta(seconds=(cap - tracked))
|
||||
if expiry > tomorrow:
|
||||
expiry = tomorrow + dt.timedelta(seconds=cap)
|
||||
expiries[(guildid, userid)] = expiry
|
||||
|
||||
hourly_rate = await self._calculate_rate(guildid, userid, state)
|
||||
to_create.append((
|
||||
guildid, userid,
|
||||
state.channelid,
|
||||
now, now,
|
||||
state.stream, state.video,
|
||||
hourly_rate
|
||||
))
|
||||
# Bulk create the ongoing sessions
|
||||
if to_create:
|
||||
rows = await self.data.VoiceSessionsOngoing.table.insert_many(
|
||||
('guildid', 'userid', 'channelid', 'start_time', 'last_update', 'live_stream',
|
||||
'live_video', 'hourly_coins'),
|
||||
*to_create
|
||||
).with_adapter(self.data.VoiceSessionsOngoing._make_rows)
|
||||
for row in rows:
|
||||
VoiceSession.from_ongoing(self.bot, row, expiries[(row.guildid, row.userid)])
|
||||
logger.info(f"Started {len(rows)} new voice sessions from voice channels!")
|
||||
|
||||
@LionCog.listener("on_voice_state_update")
|
||||
@log_wrap(action='Voice Track')
|
||||
@log_wrap(action='Voice Event')
|
||||
async def session_voice_tracker(self, member, before, after):
|
||||
"""
|
||||
Spawns the correct tasks from members joining, leaving, and changing live state.
|
||||
"""
|
||||
# TODO: Logging context
|
||||
if not self.handle_events:
|
||||
# Rely on initialisation to handle current state
|
||||
return
|
||||
|
||||
# Check user blacklist
|
||||
blacklists = self.bot.get_cog('Blacklists')
|
||||
if member.id in blacklists.user_blacklist:
|
||||
# TODO: Make sure we cancel user sessions when they get blacklisted
|
||||
# Should we dispatch an event for the blacklist?
|
||||
return
|
||||
|
||||
# Serialise state before waiting on the lock
|
||||
bstate = TrackedVoiceState.from_voice_state(before)
|
||||
astate = TrackedVoiceState.from_voice_state(after)
|
||||
if bstate == astate:
|
||||
# If tracked state did not change, ignore event
|
||||
return
|
||||
|
||||
# Take tracking lock
|
||||
async with self.tracking_lock:
|
||||
# Fetch tracked member session state
|
||||
session = self.get_session(member.guild.id, member.id)
|
||||
tstate = session.state
|
||||
untracked = self.untracked_channels.get(member.guild.id, [])
|
||||
|
||||
if (bstate.channelid != astate.channelid):
|
||||
# Leaving/Moving/Joining channels
|
||||
if (leaving := bstate.channelid):
|
||||
# Leaving channel
|
||||
if session.activity:
|
||||
# Leaving channel during active session
|
||||
if tstate.channelid != leaving:
|
||||
# Active session channel does not match leaving channel
|
||||
logger.warning(
|
||||
"Voice event does not match session information! "
|
||||
f"Member '{member.name}' <uid:{member.id}> "
|
||||
f"of guild '{member.guild.name}' <gid:{member.guild.id}> "
|
||||
f"left channel '#{before.channel.name}' <cid:{leaving}> "
|
||||
f"during voice session in channel <cid:{tstate.channelid}>!"
|
||||
)
|
||||
# Close (or cancel) active session
|
||||
logger.info(
|
||||
f"Closing session for member `{member.name}' <uid:{member.id}> "
|
||||
f"in guild '{member.guild.name}' <gid: {member.guild.id}> "
|
||||
" because they left the channel."
|
||||
)
|
||||
await session.close()
|
||||
elif (
|
||||
leaving not in untracked and
|
||||
not (before.channel.category_id and before.channel.category_id in untracked)
|
||||
):
|
||||
# Leaving tracked channel without an active session?
|
||||
logger.warning(
|
||||
"Voice event does not match session information! "
|
||||
f"Member '{member.name}' <uid:{member.id}> "
|
||||
f"of guild '{member.guild.name}' <gid:{member.guild.id}> "
|
||||
f"left tracked channel '#{before.channel.name}' <cid:{leaving}> "
|
||||
f"with no matching voice session!"
|
||||
)
|
||||
|
||||
if (joining := astate.channelid):
|
||||
# Joining channel
|
||||
if session.activity:
|
||||
# Member has an active voice session, should be impossible!
|
||||
logger.warning(
|
||||
"Voice event does not match session information! "
|
||||
f"Member '{member.name}' <uid:{member.id}> "
|
||||
f"of guild '{member.guild.name}' <gid:{member.guild.id}> "
|
||||
f"joined channel '#{after.channel.name}' <cid:{joining}> "
|
||||
f"during voice session in channel <cid:{tstate.channelid}>!"
|
||||
)
|
||||
await session.close()
|
||||
if (
|
||||
joining not in untracked and
|
||||
not (after.channel.category_id and after.channel.category_id in untracked)
|
||||
):
|
||||
# If the channel they are joining is tracked, schedule a session start for them
|
||||
delay, start, expiry = await self._session_boundaries_for(member.guild.id, member.id)
|
||||
hourly_rate = await self._calculate_rate(member.guild.id, member.id, astate)
|
||||
|
||||
logger.debug(
|
||||
f"Scheduling voice session for member `{member.name}' <uid:{member.id}> "
|
||||
f"in guild '{member.guild.name}' <gid: member.guild.id> "
|
||||
f"in channel '{after.channel.name}' <cid: {after.channel.id}>. "
|
||||
f"Session will start at {start}, expire at {expiry}, and confirm in {delay}."
|
||||
)
|
||||
await session.schedule_start(delay, start, expiry, astate, hourly_rate)
|
||||
elif session.activity:
|
||||
# If the channelid did not change, the live state must have
|
||||
# Recalculate the economy rate, and update the session
|
||||
# Touch the ongoing session with the new state
|
||||
hourly_rate = await self._calculate_rate(member.guild.id, member.id, astate)
|
||||
await session.update(new_state=astate, new_rate=hourly_rate)
|
||||
|
||||
@LionCog.listener("on_guild_setting_update_untracked_channels")
|
||||
async def update_untracked_channels(self, guildid, setting):
|
||||
"""
|
||||
Close sessions in untracked channels, and recalculate previously untracked sessions
|
||||
"""
|
||||
if not self.handle_events:
|
||||
return
|
||||
|
||||
async with self.tracking_lock:
|
||||
lguild = await self.bot.core.lions.fetch_guild(guildid)
|
||||
guild = self.bot.get_guild(guildid)
|
||||
if not guild:
|
||||
# Left guild while waiting on lock
|
||||
return
|
||||
cap = lguild.config.get('daily_voice_cap').value
|
||||
untracked = self.untracked_channels.get(guildid, [])
|
||||
now = utc_now()
|
||||
|
||||
# Iterate through active sessions, close any that are in untracked channels
|
||||
active = VoiceSession._active_sessions_.get(guildid, {})
|
||||
for session in list(active.values()):
|
||||
if session.state.channelid in untracked:
|
||||
await session.close()
|
||||
|
||||
# Iterate through voice members, open new sessions if needed
|
||||
expiries = {}
|
||||
to_create = []
|
||||
for channel in guild.voice_channels:
|
||||
if channel.id in untracked:
|
||||
continue
|
||||
for member in channel.members:
|
||||
if self.get_session(guildid, member.id).activity:
|
||||
# Already have an active session for this member
|
||||
continue
|
||||
userid = member.id
|
||||
state = TrackedVoiceState.from_voice_state(member.voice)
|
||||
|
||||
# TODO: Take into account tracked_today time?
|
||||
# TODO: Make a per-guild refresh function to stay DRY
|
||||
tomorrow = lguild.today + dt.timedelta(days=1)
|
||||
expiry = now + dt.timedelta(seconds=cap)
|
||||
if expiry > tomorrow:
|
||||
expiry = tomorrow + dt.timedelta(seconds=cap)
|
||||
expiries[(guildid, userid)] = expiry
|
||||
|
||||
hourly_rate = await self._calculate_rate(guildid, userid, state)
|
||||
to_create.append((
|
||||
guildid, userid,
|
||||
state.channelid,
|
||||
now, now,
|
||||
state.stream, state.video,
|
||||
hourly_rate
|
||||
))
|
||||
|
||||
if to_create:
|
||||
rows = await self.data.VoiceSessionsOngoing.table.insert_many(
|
||||
('guildid', 'userid', 'channelid', 'start_time', 'last_update', 'live_stream',
|
||||
'live_video', 'hourly_coins'),
|
||||
*to_create
|
||||
).with_adapter(self.data.VoiceSessionsOngoing._make_rows)
|
||||
for row in rows:
|
||||
VoiceSession.from_ongoing(self.bot, row, expiries[(row.guildid, row.userid)])
|
||||
logger.info(
|
||||
f"Started {len(rows)} new voice sessions from voice members "
|
||||
f"in previously untracked channels of guild '{guild.name}' <gid:{guildid}>."
|
||||
)
|
||||
|
||||
@LionCog.listener("on_guild_setting_update_hourly_reward")
|
||||
async def update_hourly_reward(self, guildid, setting):
|
||||
if not self.handle_events:
|
||||
return
|
||||
|
||||
async with self.tracking_lock:
|
||||
sessions = VoiceSession._active_sessions_.get(guildid, {})
|
||||
for session in list(sessions.values()):
|
||||
hourly_rate = await self._calculate_rate(session.guildid, session.userid, session.state)
|
||||
await session.update(new_rate=hourly_rate)
|
||||
|
||||
@LionCog.listener("on_guild_setting_update_hourly_live_bonus")
|
||||
async def update_hourly_live_bonus(self, guildid, setting):
|
||||
if not self.handle_events:
|
||||
return
|
||||
|
||||
async with self.tracking_lock:
|
||||
sessions = VoiceSession._active_sessions_.get(guildid)
|
||||
for session in sessions:
|
||||
hourly_rate = await self._calculate_rate(session.guildid, session.userid, session.state)
|
||||
await session.update(new_rate=hourly_rate)
|
||||
|
||||
@LionCog.listener("on_guild_setting_update_daily_voice_cap")
|
||||
async def update_daily_voice_cap(self, guildid, setting):
|
||||
# TODO: Guild daily_voice_cap setting triggers session expiry recalculation for all sessions
|
||||
...
|
||||
|
||||
@LionCog.listener("on_guild_setting_update_timezone")
|
||||
@log_wrap(action='Voice Track')
|
||||
@log_wrap(action='Timezone Update')
|
||||
async def update_timezone(self, guildid, setting):
|
||||
# TODO: Guild timezone setting triggers studied_today cache rebuild
|
||||
logger.info("Received dispatch event for timezone change!")
|
||||
|
||||
async def _calculate_rate(self, guildid, userid, state):
|
||||
"""
|
||||
Calculate the economy hourly rate for the given member in the given state.
|
||||
|
||||
Takes into account economy bonuses.
|
||||
"""
|
||||
lguild = await self.bot.core.lions.fetch_guild(guildid)
|
||||
hourly_rate = lguild.config.get('hourly_reward').value
|
||||
if state.live:
|
||||
hourly_rate += lguild.config.get('hourly_live_bonus').value
|
||||
|
||||
economy = self.bot.get_cog('Economy')
|
||||
if economy is not None:
|
||||
bonus = await economy.fetch_economy_bonus(guildid, userid)
|
||||
hourly_rate *= bonus
|
||||
else:
|
||||
logger.warning("Economy cog not loaded! Voice tracker cannot account for economy bonuses.")
|
||||
|
||||
return hourly_rate
|
||||
|
||||
async def _session_boundaries_for(self, guildid: int, userid: int) -> tuple[int, dt.datetime, dt.datetime]:
|
||||
"""
|
||||
Compute when the next session for this member should start and expire.
|
||||
|
||||
Assumes the member does not have a currently active session!
|
||||
Takes into account the daily voice cap, and the member's study time so far today.
|
||||
Days are based on the guild timezone, not the member timezone.
|
||||
(Otherwise could be abused through timezone-shifting.)
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[int, dt.datetime, dt.datetime]:
|
||||
(start delay, start time, expiry time)
|
||||
|
||||
"""
|
||||
lguild = await self.bot.core.lions.fetch_guild(guildid)
|
||||
now = lguild.now
|
||||
tomorrow = now + dt.timedelta(days=1)
|
||||
|
||||
studied_today = await self.fetch_tracked_today(guildid, userid)
|
||||
cap = lguild.config.get('daily_voice_cap').value
|
||||
|
||||
if studied_today >= cap - 60:
|
||||
start_time = tomorrow
|
||||
delay = (tomorrow - now).total_seconds()
|
||||
else:
|
||||
start_time = now
|
||||
delay = 60
|
||||
|
||||
expiry = start_time + dt.timedelta(seconds=cap)
|
||||
if expiry >= tomorrow:
|
||||
expiry = tomorrow + dt.timedelta(seconds=cap)
|
||||
|
||||
return (delay, start_time, expiry)
|
||||
|
||||
async def fetch_tracked_today(self, guildid, userid) -> int:
|
||||
"""
|
||||
Fetch how long the given member has tracked on voice today, using the guild timezone.
|
||||
|
||||
Applies cache wherever possible.
|
||||
"""
|
||||
# TODO: Design caching scheme for this.
|
||||
lguild = await self.bot.core.lions.fetch_guild(guildid)
|
||||
return await self.data.VoiceSessions.study_time_since(guildid, userid, lguild.today)
|
||||
|
||||
@LionCog.listener("on_guild_join")
|
||||
@log_wrap(action='Join Guild Voice Sessions')
|
||||
async def join_guild_sessions(self, guild: discord.Guild):
|
||||
"""
|
||||
Initialise and start required new sessions from voice channel members when we join a guild.
|
||||
"""
|
||||
if not self.handle_events:
|
||||
return
|
||||
|
||||
async with self.tracking_lock:
|
||||
guildid = guild.id
|
||||
lguild = await self.bot.core.lions.fetch_guild(guildid)
|
||||
cap = lguild.config.get('daily_voice_cap').value
|
||||
untracked = self.untracked_channels.get(guildid, [])
|
||||
now = utc_now()
|
||||
|
||||
expiries = {}
|
||||
to_create = []
|
||||
for channel in guild.voice_channels:
|
||||
if channel.id in untracked:
|
||||
continue
|
||||
for member in channel.members:
|
||||
userid = member.id
|
||||
state = TrackedVoiceState.from_voice_state(member.voice)
|
||||
|
||||
tomorrow = lguild.today + dt.timedelta(days=1)
|
||||
expiry = now + dt.timedelta(seconds=cap)
|
||||
if expiry > tomorrow:
|
||||
expiry = tomorrow + dt.timedelta(seconds=cap)
|
||||
expiries[(guildid, userid)] = expiry
|
||||
|
||||
hourly_rate = await self._calculate_rate(guildid, userid, state)
|
||||
to_create.append((
|
||||
guildid, userid,
|
||||
state.channelid,
|
||||
now, now,
|
||||
state.stream, state.video,
|
||||
hourly_rate
|
||||
))
|
||||
|
||||
if to_create:
|
||||
rows = await self.data.VoiceSessionsOngoing.table.insert_many(
|
||||
('guildid', 'userid', 'channelid', 'start_time', 'last_update', 'live_stream',
|
||||
'live_video', 'hourly_coins'),
|
||||
*to_create
|
||||
).with_adapter(self.data.VoiceSessionsOngoing._make_rows)
|
||||
for row in rows:
|
||||
VoiceSession.from_ongoing(self.bot, row, expiries[(row.guildid, row.userid)])
|
||||
logger.info(
|
||||
f"Started {len(rows)} new voice sessions from voice members "
|
||||
f"in new guild '{guild.name}' <gid:{guildid}>."
|
||||
)
|
||||
|
||||
@LionCog.listener("on_guild_remove")
|
||||
@log_wrap(action='Leave Guild Voice Sessions')
|
||||
async def leave_guild_sessions(self, guild):
|
||||
"""
|
||||
Terminate ongoing sessions when we leave a guild.
|
||||
"""
|
||||
if not self.handle_events:
|
||||
return
|
||||
|
||||
async with self.tracking_lock:
|
||||
sessions = VoiceSession._active_sessions_.pop(guild.id, {})
|
||||
VoiceSession._sessions_.pop(guild.id, None)
|
||||
now = utc_now()
|
||||
to_close = [] # (guildid, userid, _at)
|
||||
for session in sessions.vallues():
|
||||
if session.start_task is not None:
|
||||
session.start_task.cancel()
|
||||
if session.expiry_task is not None:
|
||||
session.expiry_task.cancel()
|
||||
to_close.append(session.guildid, session.userid, now)
|
||||
await self.data.VoiceSessionsOngoing.close_voice_sessions_at(*to_close)
|
||||
logger.info(
|
||||
f"Closed {len(to_close)} voice sessions after leaving guild '{guild.name}' <gid:{guild.id}>"
|
||||
)
|
||||
|
||||
# ----- Configuration Commands -----
|
||||
@LionCog.placeholder_group
|
||||
@cmds.hybrid_group('configure', with_app_command=False)
|
||||
async def configure_group(self, ctx: LionContext):
|
||||
# Placeholder group method, not used.
|
||||
pass
|
||||
|
||||
@configure_group.command(
|
||||
name=_p('cmd:configure_voice_tracking', "voice_tracking"),
|
||||
description=_p(
|
||||
'cmd:configure_voice_tracking|desc',
|
||||
"Voice tracking configuration panel"
|
||||
)
|
||||
)
|
||||
@appcmds.rename(
|
||||
hourly_reward=VoiceTrackerSettings.HourlyReward._display_name,
|
||||
hourly_live_bonus=VoiceTrackerSettings.HourlyLiveBonus._display_name,
|
||||
daily_voice_cap=VoiceTrackerSettings.DailyVoiceCap._display_name,
|
||||
)
|
||||
@appcmds.describe(
|
||||
hourly_reward=VoiceTrackerSettings.HourlyReward._desc,
|
||||
hourly_live_bonus=VoiceTrackerSettings.HourlyLiveBonus._desc,
|
||||
daily_voice_cap=VoiceTrackerSettings.DailyVoiceCap._desc,
|
||||
)
|
||||
@cmds.check(low_management)
|
||||
async def configure_voice_tracking_cmd(self, ctx: LionContext,
|
||||
hourly_reward: Optional[int] = None, # TODO: Change these to Ranges
|
||||
hourly_live_bonus: Optional[int] = None,
|
||||
daily_voice_cap: Optional[int] = None):
|
||||
"""
|
||||
Guild configuration command to control the voice tracking configuration.
|
||||
"""
|
||||
# TODO: daily_voice_cap could technically be a string, but simplest to represent it as hours
|
||||
t = self.bot.translator.t
|
||||
|
||||
# Type checking guards
|
||||
if not ctx.guild:
|
||||
return
|
||||
if not ctx.interaction:
|
||||
return
|
||||
|
||||
# Retrieve settings, initialising from cache where possible
|
||||
setting_hourly_reward = ctx.lguild.config.get('hourly_reward')
|
||||
setting_hourly_live_bonus = ctx.lguild.config.get('hourly_live_bonus')
|
||||
setting_daily_voice_cap = ctx.lguild.config.get('daily_voice_cap')
|
||||
|
||||
modified = []
|
||||
if hourly_reward is not None and hourly_reward != setting_hourly_reward._data:
|
||||
setting_hourly_reward.data = hourly_reward
|
||||
await setting_hourly_reward.write()
|
||||
modified.append(setting_hourly_reward)
|
||||
|
||||
if hourly_live_bonus is not None and hourly_live_bonus != setting_hourly_live_bonus._data:
|
||||
setting_hourly_live_bonus.data = hourly_live_bonus
|
||||
await setting_hourly_live_bonus.write()
|
||||
modified.append(setting_hourly_live_bonus)
|
||||
|
||||
if daily_voice_cap is not None and daily_voice_cap * 3600 != setting_daily_voice_cap._data:
|
||||
setting_daily_voice_cap.data = daily_voice_cap * 3600
|
||||
await setting_daily_voice_cap.write()
|
||||
modified.append(setting_daily_voice_cap)
|
||||
|
||||
# Send update ack
|
||||
if modified:
|
||||
if ctx.lguild.guild_mode.voice is VoiceMode.VOICE:
|
||||
description = t(_p(
|
||||
'cmd:configure_voice_tracking|mode:voice|resp:success|desc',
|
||||
"Members will now be rewarded {coin}**{base} (+ {bonus})** per hour they spend (live) "
|
||||
"in a voice channel, up to a total of **{cap}** hours per server day."
|
||||
)).format(
|
||||
coin=self.bot.config.emojis.coin,
|
||||
base=setting_hourly_reward.value,
|
||||
bonus=setting_hourly_live_bonus.value,
|
||||
cap=int(setting_daily_voice_cap.value // 3600)
|
||||
)
|
||||
else:
|
||||
description = t(_p(
|
||||
'cmd:configure_voice_tracking|mode:study|resp:success|desc',
|
||||
"Members will now be rewarded {coin}**{base}** per hour of study "
|
||||
"in this server, with a bonus of {coin}**{bonus}** if they stream of display video, "
|
||||
"up to a total of **{cap}** hours per server day."
|
||||
)).format(
|
||||
coin=self.bot.config.emojis.coin,
|
||||
base=setting_hourly_reward.value,
|
||||
bonus=setting_hourly_live_bonus.value,
|
||||
cap=int(setting_daily_voice_cap.value // 3600)
|
||||
)
|
||||
await ctx.reply(
|
||||
embed=discord.Embed(
|
||||
colour=discord.Colour.brand_green(),
|
||||
description=description
|
||||
)
|
||||
)
|
||||
|
||||
if ctx.channel.id not in VoiceTrackerConfigUI._listening or not modified:
|
||||
# Launch setting group UI
|
||||
configui = VoiceTrackerConfigUI(self.bot, self.settings, ctx.guild.id, ctx.channel.id)
|
||||
await configui.run(ctx.interaction)
|
||||
await configui.wait()
|
||||
273
src/tracking/voice/data.py
Normal file
273
src/tracking/voice/data.py
Normal file
@@ -0,0 +1,273 @@
|
||||
import datetime as dt
|
||||
from itertools import chain
|
||||
from psycopg import sql
|
||||
|
||||
from data import RowModel, Registry, Table
|
||||
from data.columns import Integer, String, Timestamp, Bool
|
||||
|
||||
from core.data import CoreData
|
||||
|
||||
|
||||
class VoiceTrackerData(Registry):
|
||||
# Tracked Channels
|
||||
# Current sessions
|
||||
# Session history
|
||||
# Untracked channels table
|
||||
class TrackedChannel(RowModel):
|
||||
"""
|
||||
Reference model describing channels which have been used in tracking.
|
||||
TODO: Refactor into central tracking data?
|
||||
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE tracked_channels(
|
||||
channelid BIGINT PRIMARY KEY,
|
||||
guildid BIGINT NOT NULL,
|
||||
deleted BOOLEAN DEFAULT FALSE,
|
||||
_timestamp TIMESTAMPTZ NOT NULL DEFAULT (now() AT TIME ZONE 'utc'),
|
||||
FOREIGN KEY (guildid) REFERENCES guild_config (guildid) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX tracked_channels_guilds ON tracked_channels (guildid);
|
||||
"""
|
||||
_tablename_ = "tracked_channels"
|
||||
_cache_ = {}
|
||||
|
||||
channelid = Integer(primary=True)
|
||||
guildid = Integer()
|
||||
deleted = Bool()
|
||||
_timestamp = Timestamp()
|
||||
|
||||
class VoiceSessionsOngoing(RowModel):
|
||||
"""
|
||||
Model describing currently active voice sessions.
|
||||
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE voice_sessions_ongoing(
|
||||
guildid BIGINT NOT NULL,
|
||||
userid BIGINT NOT NULL,
|
||||
channelid BIGINT REFERENCES tracked_channels (channelid),
|
||||
rating INTEGER,
|
||||
tag TEXT,
|
||||
start_time TIMESTAMPTZ DEFAULT (now() AT TIME ZONE 'UTC'),
|
||||
live_duration INTEGER NOT NULL DEFAULT 0,
|
||||
video_duration INTEGER NOT NULL DEFAULT 0,
|
||||
stream_duration INTEGER NOT NULL DEFAULT 0,
|
||||
coins_earned INTEGER NOT NULL DEFAULT 0,
|
||||
last_update TIMESTAMPTZ DEFAULT (now() AT TIME ZONE 'UTC'),
|
||||
live_stream BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
live_video BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
hourly_coins FLOAT NOT NULL DEFAULT 0,
|
||||
FOREIGN KEY (guildid, userid) REFERENCES members (guildid, userid) ON DELETE CASCADE
|
||||
);
|
||||
CREATE UNIQUE INDEX voice_sessions_ongoing_members ON voice_sessions_ongoing (guildid, userid);
|
||||
"""
|
||||
_tablename_ = "voice_sessions_ongoing"
|
||||
|
||||
guildid = Integer(primary=True)
|
||||
userid = Integer(primary=True)
|
||||
channelid = Integer()
|
||||
rating = Integer()
|
||||
tag = String()
|
||||
start_time = Timestamp()
|
||||
live_duration = Integer()
|
||||
video_duration = Integer()
|
||||
stream_duration = Integer()
|
||||
coins_earned = Integer()
|
||||
last_update = Integer()
|
||||
live_stream = Bool()
|
||||
live_video = Bool()
|
||||
hourly_coins = Integer()
|
||||
|
||||
@classmethod
|
||||
async def close_study_session_at(cls, guildid: int, userid: int, _at: dt.datetime) -> int:
|
||||
conn = await cls._connector.get_connection()
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
"SELECT close_study_session_at(%s, %s, %s)",
|
||||
(guildid, userid, _at)
|
||||
)
|
||||
member_data = await cursor.fetchone()
|
||||
|
||||
@classmethod
|
||||
async def close_voice_sessions_at(cls, *arg_tuples):
|
||||
query = sql.SQL("""
|
||||
SELECT
|
||||
close_study_session_at(t.guildid, t.userid, t.at)
|
||||
FROM
|
||||
(VALUES {})
|
||||
AS
|
||||
t (guildid, userid, at);
|
||||
""").format(
|
||||
sql.SQL(', ').join(
|
||||
sql.SQL("({}, {}, {})").format(
|
||||
sql.Placeholder(), sql.Placeholder(), sql.Placeholder(),
|
||||
)
|
||||
for _ in arg_tuples
|
||||
)
|
||||
)
|
||||
conn = await cls._connector.get_connection()
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
query,
|
||||
tuple(chain(*arg_tuples))
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def update_voice_session_at(
|
||||
cls, guildid: int, userid: int, _at: dt.datetime,
|
||||
stream: bool, video: bool, rate: float
|
||||
) -> int:
|
||||
conn = await cls._connector.get_connection()
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
"SELECT * FROM update_voice_session(%s, %s, %s, %s, %s, %s)",
|
||||
(guildid, userid, _at, stream, video, rate)
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
return cls._make_rows(*rows)
|
||||
|
||||
@classmethod
|
||||
async def update_voice_sessions_at(cls, *arg_tuples):
|
||||
query = sql.SQL("""
|
||||
UPDATE
|
||||
voice_sessions_ongoing
|
||||
SET
|
||||
stream_duration = (
|
||||
CASE WHEN live_stream
|
||||
THEN stream_duration + EXTRACT(EPOCH FROM (t.at - last_update))
|
||||
ELSE stream_duration
|
||||
END
|
||||
),
|
||||
video_duration = (
|
||||
CASE WHEN live_video
|
||||
THEN video_duration + EXTRACT(EPOCH FROM (t.at - last_update))
|
||||
ELSE video_duration
|
||||
END
|
||||
),
|
||||
live_duration = (
|
||||
CASE WHEN live_stream OR live_video
|
||||
THEN live_duration + EXTRACT(EPOCH FROM (t.at - last_update))
|
||||
ELSE live_duration
|
||||
END
|
||||
),
|
||||
coins_earned = (
|
||||
coins_earned + LEAST((EXTRACT(EPOCH FROM (t.at - last_update)) * hourly_coins) / 3600, 2147483647)
|
||||
),
|
||||
last_update = t.at,
|
||||
live_stream = t.stream,
|
||||
live_video = t.video,
|
||||
hourly_coins = t.rate
|
||||
FROM
|
||||
(VALUES {})
|
||||
AS
|
||||
t(_guildid, _userid, at, stream, video, rate)
|
||||
WHERE
|
||||
guildid = t._guildid
|
||||
AND
|
||||
userid = t._userid
|
||||
RETURNING *;
|
||||
""").format(
|
||||
sql.SQL(', ').join(
|
||||
sql.SQL("({}, {}, {}, {}, {}, {})").format(
|
||||
sql.Placeholder(), sql.Placeholder(), sql.Placeholder(),
|
||||
sql.Placeholder(), sql.Placeholder(), sql.Placeholder(),
|
||||
)
|
||||
for _ in arg_tuples
|
||||
)
|
||||
)
|
||||
conn = await cls._connector.get_connection()
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
query,
|
||||
tuple(chain(*arg_tuples))
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
return cls._make_rows(*rows)
|
||||
|
||||
class VoiceSessions(RowModel):
|
||||
"""
|
||||
Model describing completed voice sessions.
|
||||
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE voice_sessions(
|
||||
sessionid SERIAL PRIMARY KEY,
|
||||
guildid BIGINT NOT NULL,
|
||||
userid BIGINT NOT NULL,
|
||||
channelid BIGINT REFERENCES tracked_channels (channelid),
|
||||
rating INTEGER,
|
||||
tag TEXT,
|
||||
start_time TIMESTAMPTZ NOT NULL,
|
||||
duration INTEGER NOT NULL,
|
||||
live_duration INTEGER DEFAULT 0,
|
||||
stream_duration INTEGER DEFAULT 0,
|
||||
video_duration INTEGER DEFAULT 0,
|
||||
transactionid INTEGER REFERENCES coin_transactions (transactionid) ON UPDATE CASCADE ON DELETE CASCADE,
|
||||
FOREIGN KEY (guildid, userid) REFERENCES members (guildid, userid) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX session_history_members ON session_history (guildid, userid, start_time);
|
||||
"""
|
||||
_tablename_ = "voice_sessions"
|
||||
|
||||
sessionid = Integer(primary=True)
|
||||
guildid = Integer()
|
||||
userid = Integer()
|
||||
channelid = Integer()
|
||||
rating = Integer()
|
||||
tag = String()
|
||||
start_time = Timestamp()
|
||||
duration = Integer()
|
||||
live_duration = Integer()
|
||||
stream_duration = Integer()
|
||||
video_duration = Integer()
|
||||
transactionid = Integer()
|
||||
|
||||
@classmethod
|
||||
async def study_time_since(cls, guildid: int, userid: int, _start) -> int:
|
||||
conn = await cls._connector.get_connection()
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
"SELECT study_time_since(%s, %s, %s) AS result",
|
||||
(guildid, userid, _start)
|
||||
)
|
||||
result = await cursor.fetchone()
|
||||
return (result['result'] or 0) if result else 0
|
||||
|
||||
@classmethod
|
||||
async def multiple_voice_tracked_since(cls, *arg_tuples):
|
||||
query = sql.SQL("""
|
||||
SELECT
|
||||
t.guildid AS guildid,
|
||||
t.userid AS userid,
|
||||
COALESCE(study_time_since(t.guildid, t.userid, t.at), 0) AS tracked
|
||||
FROM
|
||||
(VALUES {})
|
||||
AS
|
||||
t (guildid, userid, at);
|
||||
""").format(
|
||||
sql.SQL(', ').join(
|
||||
sql.SQL("({}, {}, {})").format(
|
||||
sql.Placeholder(), sql.Placeholder(), sql.Placeholder(),
|
||||
)
|
||||
for _ in arg_tuples
|
||||
)
|
||||
)
|
||||
conn = await cls._connector.get_connection()
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
query,
|
||||
tuple(chain(*arg_tuples))
|
||||
)
|
||||
return await cursor.fetchall()
|
||||
|
||||
"""
|
||||
Schema
|
||||
------
|
||||
CREATE TABLE untracked_channels(
|
||||
guildid BIGINT NOT NULL,
|
||||
channelid BIGINT NOT NULL
|
||||
);
|
||||
CREATE INDEX untracked_channels_guilds ON untracked_channels (guildid);
|
||||
"""
|
||||
untracked_channels = Table('untracked_channels')
|
||||
248
src/tracking/voice/session.py
Normal file
248
src/tracking/voice/session.py
Normal file
@@ -0,0 +1,248 @@
|
||||
from typing import Optional
|
||||
from enum import IntEnum
|
||||
from collections import defaultdict
|
||||
import datetime as dt
|
||||
import asyncio
|
||||
|
||||
import discord
|
||||
from cachetools import TTLCache
|
||||
|
||||
from utils.lib import utc_now
|
||||
from meta import LionBot
|
||||
from data import WeakCache
|
||||
from .data import VoiceTrackerData
|
||||
|
||||
from . import logger
|
||||
|
||||
|
||||
class TrackedVoiceState:
|
||||
__slots__ = (
|
||||
'channelid',
|
||||
'video',
|
||||
'stream'
|
||||
)
|
||||
|
||||
def __init__(self, channelid: Optional[int], video: bool, stream: bool):
|
||||
self.channelid = channelid
|
||||
self.video = video
|
||||
self.stream = stream
|
||||
|
||||
def __eq__(self, other: 'TrackedVoiceState'):
|
||||
equal = other.channelid == self.channelid
|
||||
equal = equal and other.video == self.video
|
||||
equal = equal and other.stream == self.stream
|
||||
|
||||
def __bool__(self):
|
||||
"""Whether this is an active state"""
|
||||
return bool(self.channelid)
|
||||
|
||||
@property
|
||||
def live(self):
|
||||
return self.video or self.stream
|
||||
|
||||
@classmethod
|
||||
def from_voice_state(cls, state: discord.VoiceState):
|
||||
if state is not None:
|
||||
return cls(
|
||||
state.channel.id if state.channel else None,
|
||||
state.self_video,
|
||||
state.self_stream
|
||||
)
|
||||
else:
|
||||
return cls(None, False, False)
|
||||
|
||||
|
||||
class SessionState(IntEnum):
|
||||
ONGOING = 2
|
||||
PENDING = 1
|
||||
INACTIVE = 0
|
||||
|
||||
|
||||
class VoiceSession:
|
||||
"""
|
||||
High-level tracked voice state in the LionBot paradigm.
|
||||
|
||||
To ensure cache integrity and event safety,
|
||||
this state may lag behind the `member.voice` obtained from Discord API.
|
||||
However, the state must always match the stored state (in data).
|
||||
"""
|
||||
__slots__ = (
|
||||
'bot',
|
||||
'guildid', 'userid',
|
||||
'registry',
|
||||
'start_task', 'expiry_task',
|
||||
'data', 'state', 'hourly_rate',
|
||||
'__weakref__'
|
||||
)
|
||||
|
||||
_sessions_ = defaultdict(lambda: WeakCache(TTLCache(5000, ttl=60*60))) # Registry mapping
|
||||
_active_sessions_ = defaultdict(dict) # Maintains strong references to active sessions
|
||||
|
||||
def __init__(self, bot: LionBot, guildid: int, userid: int, data=None):
|
||||
self.bot = bot
|
||||
self.guildid = guildid
|
||||
self.userid = userid
|
||||
self.registry: VoiceTrackerData = self.bot.get_cog('VoiceTrackerCog').data
|
||||
|
||||
self.start_task = None # Task triggering a delayed session start
|
||||
self.expiry_task = None # Task triggering a session expiry from reaching the daily cap
|
||||
self.data: Optional[VoiceTrackerData.VoiceSessionsOngoing] = data # Ongoing session data
|
||||
|
||||
# TrackedVoiceState set when session is active
|
||||
# Must match data when session in ongoing
|
||||
self.state: Optional[TrackedVoiceState] = None
|
||||
self.hourly_rate: Optional[float] = None
|
||||
|
||||
@property
|
||||
def activity(self):
|
||||
if self.data is not None:
|
||||
return SessionState.ONGOING
|
||||
elif self.start_task is not None:
|
||||
return SessionState.PENDING
|
||||
else:
|
||||
return SessionState.INACTIVE
|
||||
|
||||
@classmethod
|
||||
def get(cls, bot: LionBot, guildid: int, userid: int) -> 'VoiceSession':
|
||||
"""
|
||||
Fetch the VoiceSession for the given member. Respects cache.
|
||||
Creates the session if it doesn't already exist.
|
||||
"""
|
||||
session = cls._sessions_[guildid].get(userid, None)
|
||||
if session is None:
|
||||
session = cls(bot, guildid, userid)
|
||||
cls._sessions_[guildid][userid] = session
|
||||
return session
|
||||
|
||||
@classmethod
|
||||
def from_ongoing(cls, bot: LionBot, data: VoiceTrackerData.VoiceSessionsOngoing, expires_at: dt.datetime):
|
||||
"""
|
||||
Create a VoiceSession from ongoing data and expiry time.
|
||||
"""
|
||||
self = cls.get(bot, data.guildid, data.userid)
|
||||
if self.activity:
|
||||
raise ValueError("Initialising a session which is already running!")
|
||||
self.data = data
|
||||
self.state = TrackedVoiceState(data.channelid, data.live_video, data.live_stream)
|
||||
self.hourly_rate = data.hourly_coins
|
||||
self.schedule_expiry(expires_at)
|
||||
self._active_sessions_[self.guildid][self.userid] = self
|
||||
return self
|
||||
|
||||
async def schedule_start(self, delay, start_time, expire_time, state, hourly_rate):
|
||||
"""
|
||||
Schedule the voice session to start at the given target time,
|
||||
with the given state and hourly rate.
|
||||
"""
|
||||
self.state = state
|
||||
self.hourly_rate = hourly_rate
|
||||
|
||||
self.start_task = asyncio.create_task(self._start_after(delay, start_time))
|
||||
self.schedule_expiry(expire_time)
|
||||
|
||||
async def _start_after(self, delay: int, start_time: dt.datetime):
|
||||
"""
|
||||
Start a new voice session with the given state and hourly rate.
|
||||
|
||||
Creates the tracked_channel if required.
|
||||
"""
|
||||
self._active_sessions_[self.guildid][self.userid] = self
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
logger.info(
|
||||
f"Starting voice session for member <uid:{self.userid}> in guild <gid:{self.guildid}> "
|
||||
f"and channel <cid:{self.state.channelid}>."
|
||||
)
|
||||
|
||||
conn = await self.bot.db.get_connection()
|
||||
async with conn.transaction():
|
||||
# Create the tracked channel if required
|
||||
await self.registry.TrackedChannel.fetch_or_create(
|
||||
self.state.channelid, guildid=self.guildid, deleted=False
|
||||
)
|
||||
|
||||
# Insert an ongoing_session with the correct state, set data
|
||||
state = self.state
|
||||
self.data = await self.registry.VoiceSessionsOngoing.create(
|
||||
guildid=self.guildid,
|
||||
userid=self.userid,
|
||||
channelid=state.channelid,
|
||||
start_time=start_time,
|
||||
last_update=start_time,
|
||||
live_stream=state.stream,
|
||||
live_video=state.video,
|
||||
hourly_coins=self.hourly_rate
|
||||
)
|
||||
self.start_task = None
|
||||
|
||||
def schedule_expiry(self, expire_time):
|
||||
"""
|
||||
(Re-)schedule expiry for an ongoing session.
|
||||
"""
|
||||
if not self.activity:
|
||||
raise ValueError("Cannot schedule expiry for an inactive session!")
|
||||
if self.expiry_task is not None and not self.expiry_task.done():
|
||||
self.expiry_task.cancel()
|
||||
|
||||
delay = (expire_time - utc_now()).total_seconds()
|
||||
self.expiry_task = asyncio.create_task(self._expire_after(delay))
|
||||
|
||||
async def _expire_after(self, delay: int):
|
||||
"""
|
||||
Expire a session which has exceeded the daily voice cap.
|
||||
"""
|
||||
# TODO: Logging, and guild logging, and user notification (?)
|
||||
await asyncio.sleep(delay)
|
||||
logger.info(
|
||||
f"Expiring voice session for member <uid:{self.userid}> in guild <gid:{self.guildid}> "
|
||||
f"and channel <cid:{self.state.channelid}>."
|
||||
)
|
||||
await self.close()
|
||||
|
||||
async def update(self, new_state: Optional[TrackedVoiceState] = None, new_rate: Optional[int] = None):
|
||||
"""
|
||||
Update the session state with the provided voice state or hourly rate.
|
||||
Also applies to pending states.
|
||||
|
||||
Raises ValueError if the state does not match the saved session (i.e. wrong channel)
|
||||
"""
|
||||
if not self.activity:
|
||||
raise ValueError("Cannot update inactive session!")
|
||||
elif (new_state is not None and new_state != self.state) or (new_rate != self.hourly_rate):
|
||||
if new_state is not None:
|
||||
self.state = new_state
|
||||
if new_rate is not None:
|
||||
self.hourly_rate = new_rate
|
||||
|
||||
if self.data:
|
||||
await self.data.update_voice_session_at(
|
||||
guildid=self.guildid,
|
||||
userid=self.userid,
|
||||
_at=utc_now(),
|
||||
stream=self.state.stream,
|
||||
video=self.state.video,
|
||||
rate=self.hourly_rate
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
Close the session, or cancel the pending session. Idempotent.
|
||||
"""
|
||||
if self.activity is SessionState.ONGOING:
|
||||
# End the ongoing session
|
||||
await self.data.close_study_session_at(self.guildid, self.userid, utc_now())
|
||||
|
||||
if self.start_task is not None:
|
||||
self.start_task.cancel()
|
||||
self.start_task = None
|
||||
|
||||
if self.expiry_task is not None:
|
||||
self.expiry_task.cancel()
|
||||
self.expiry_task = None
|
||||
|
||||
self.data = None
|
||||
self.state = None
|
||||
self.hourly_rate = None
|
||||
|
||||
# Always release strong reference to session (to allow garbage collection)
|
||||
self._active_sessions_[self.guildid].pop(self.userid)
|
||||
433
src/tracking/voice/settings.py
Normal file
433
src/tracking/voice/settings.py
Normal file
@@ -0,0 +1,433 @@
|
||||
from typing import Optional
|
||||
from collections import defaultdict
|
||||
import discord
|
||||
from discord.ui.select import select, Select, ChannelSelect
|
||||
from discord.ui.button import button, Button, ButtonStyle
|
||||
|
||||
from settings.groups import SettingGroup
|
||||
from settings.data import ModelData, ListData
|
||||
from settings.setting_types import ChannelListSetting, IntegerSetting, DurationSetting
|
||||
|
||||
from meta import conf, LionBot
|
||||
from meta.sharding import THIS_SHARD
|
||||
from meta.logger import log_wrap
|
||||
from utils.ui import LeoUI
|
||||
|
||||
from core.data import CoreData
|
||||
from core.lion_guild import VoiceMode
|
||||
from babel.translator import ctx_translator
|
||||
|
||||
from . import babel, logger
|
||||
from .data import VoiceTrackerData
|
||||
|
||||
_p = babel._p
|
||||
|
||||
|
||||
# untracked channels
|
||||
# hourly_reward
|
||||
# hourly_live_bonus
|
||||
# daily_voice_cap
|
||||
|
||||
|
||||
class VoiceTrackerSettings(SettingGroup):
|
||||
class UntrackedChannels(ListData, ChannelListSetting):
|
||||
# TODO: Factor out into combined tracking settings?
|
||||
setting_id = 'untracked_channels'
|
||||
_event = 'guild_setting_update_untracked_channels'
|
||||
|
||||
_display_name = _p('guildset:untracked_channels', "untracked_channels")
|
||||
_desc = _p(
|
||||
'guildset:untracked_channels|desc',
|
||||
"Channels which will be ignored for statistics tracking."
|
||||
)
|
||||
_long_desc = _p(
|
||||
'guildset:untracked_channels|long_desc',
|
||||
"Activity in these channels will not count towards a member's statistics. "
|
||||
"If a category is selected, all channels under the category will be untracked."
|
||||
)
|
||||
|
||||
_default = None
|
||||
|
||||
_table_interface = VoiceTrackerData.untracked_channels
|
||||
_id_column = 'guildid'
|
||||
_data_column = 'channelid'
|
||||
_order_column = 'channelid'
|
||||
|
||||
_cache = {}
|
||||
|
||||
@property
|
||||
def set_str(self):
|
||||
t = ctx_translator.get().t
|
||||
return t(_p(
|
||||
'guildset:untracked_channels|set',
|
||||
"Channel selector below."
|
||||
))
|
||||
|
||||
@property
|
||||
def update_message(self):
|
||||
t = ctx_translator.get().t
|
||||
return t(_p(
|
||||
'guildset:untracked_channels|response',
|
||||
"Activity in the following channels will now be ignored: {channels}"
|
||||
)).format(
|
||||
channels=self.formatted
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='Cache Untracked Channels')
|
||||
async def setup(cls, bot):
|
||||
"""
|
||||
Pre-load untracked channels for every guild on the current shard.
|
||||
"""
|
||||
data: VoiceTrackerData = bot.db.registries['VoiceTrackerData']
|
||||
# TODO: Filter by joining on guild_config with last_left = NULL
|
||||
# Otherwise we are also caching all the guilds we left
|
||||
rows = await data.untracked_channels.select_where(THIS_SHARD)
|
||||
new_cache = defaultdict(list)
|
||||
count = 0
|
||||
for row in rows:
|
||||
new_cache[row['guildid']].append(row['channelid'])
|
||||
count += 1
|
||||
cls._cache.clear()
|
||||
cls._cache.update(new_cache)
|
||||
logger.info(f"Loaded {count} untracked channels on this shard.")
|
||||
|
||||
class HourlyReward(ModelData, IntegerSetting):
|
||||
setting_id = 'hourly_reward'
|
||||
_event = 'guild_setting_update_hourly_reward'
|
||||
|
||||
_display_name = _p('guildset:hourly_reward', "hourly_reward")
|
||||
_desc = _p(
|
||||
'guildset:hourly_reward|mode:voice|desc',
|
||||
"LionCoins given per hour in a voice channel."
|
||||
)
|
||||
|
||||
_default = 50
|
||||
_min = 0
|
||||
_max = 2**15
|
||||
|
||||
_model = CoreData.Guild
|
||||
_column = CoreData.Guild.study_hourly_reward.name
|
||||
|
||||
@classmethod
|
||||
def _format_data(cls, parent_id, data, **kwargs):
|
||||
t = ctx_translator.get().t
|
||||
if data is not None:
|
||||
return t(_p(
|
||||
'guildset:hourly_reward|formatted',
|
||||
"{coin}**{amount}** per hour."
|
||||
)).format(
|
||||
coin=conf.emojis.coin,
|
||||
amount=data
|
||||
)
|
||||
|
||||
@property
|
||||
def set_str(self):
|
||||
# TODO: Dynamic retrieval of command id
|
||||
return '</configure voice_tracking:1038560947666694144>'
|
||||
|
||||
class HourlyReward_Voice(HourlyReward):
|
||||
"""
|
||||
Voice-mode specialised version of HourlyReward
|
||||
"""
|
||||
_desc = _p(
|
||||
'guildset:hourly_reward|mode:voice|desc',
|
||||
"LionCoins given per hour in a voice channel."
|
||||
)
|
||||
_long_desc = _p(
|
||||
'guildset:hourly_reward|mode:voice|long_desc',
|
||||
"Number of LionCoins to each member per hour that they stay in a tracked voice channel."
|
||||
)
|
||||
|
||||
@property
|
||||
def set_str(self):
|
||||
# TODO: Dynamic retrieval of command id
|
||||
return '</configure voice_tracking:1038560947666694144>'
|
||||
|
||||
@property
|
||||
def update_message(self):
|
||||
t = ctx_translator.get().t
|
||||
return t(_p(
|
||||
'guildset:hourly_reward|mode:voice|response',
|
||||
"Members will be given {coin}**{amount}** per hour in a voice channel!"
|
||||
)).format(
|
||||
coin=conf.emojis.coin,
|
||||
amount=self.data
|
||||
)
|
||||
|
||||
class HourlyReward_Study(HourlyReward):
|
||||
"""
|
||||
Study-mode specialised version of HourlyReward.
|
||||
"""
|
||||
_desc = _p(
|
||||
'guildset:hourly_reward|mode:study|desc',
|
||||
"LionCoins given per hour of study."
|
||||
)
|
||||
_long_desc = _p(
|
||||
'guildset:hourly_reward|mode:study|long_desc',
|
||||
"Number of LionCoins given per hour of study, up to the daily hour cap."
|
||||
)
|
||||
|
||||
@property
|
||||
def update_message(self):
|
||||
t = ctx_translator.get().t
|
||||
return t(_p(
|
||||
'guildset:hourly_reward|mode:study|response',
|
||||
"Members will be given {coin}**{amount}** per hour that they study!"
|
||||
)).format(
|
||||
coin=conf.emojis.coin,
|
||||
amount=self.data
|
||||
)
|
||||
|
||||
class HourlyLiveBonus(ModelData, IntegerSetting):
|
||||
"""
|
||||
Guild setting describing the per-hour LionCoin bonus given to "live" members during tracking.
|
||||
"""
|
||||
setting_id = 'hourly_live_bonus'
|
||||
_event = 'guild_setting_update_hourly_live_bonus'
|
||||
|
||||
_display_name = _p('guildset:hourly_live_bonus', "hourly_live_bonus")
|
||||
_desc = _p(
|
||||
'guildset:hourly_live_bonus|desc',
|
||||
"Bonus Lioncoins given per hour when a member streams or video-chats."
|
||||
)
|
||||
|
||||
_long_desc = _p(
|
||||
'guildset:hourly_live_bonus|long_desc',
|
||||
"When a member streams or video-chats in a channel they will be given this bonus *additionally* "
|
||||
"to the `hourly_reward`."
|
||||
)
|
||||
|
||||
_default = 150
|
||||
_min = 0
|
||||
_max = 2**15
|
||||
|
||||
_model = CoreData.Guild
|
||||
_column = CoreData.Guild.study_hourly_live_bonus.name
|
||||
|
||||
@classmethod
|
||||
def _format_data(cls, parent_id, data, **kwargs):
|
||||
t = ctx_translator.get().t
|
||||
if data is not None:
|
||||
return t(_p(
|
||||
'guildset:hourly_live_bonus|formatted',
|
||||
"{coin}**{amount}** bonus per hour when live."
|
||||
)).format(
|
||||
coin=conf.emojis.coin,
|
||||
amount=data
|
||||
)
|
||||
|
||||
@property
|
||||
def set_str(self):
|
||||
# TODO: Dynamic retrieval of command id
|
||||
return '</configure voice_tracking:1038560947666694144>'
|
||||
|
||||
@property
|
||||
def update_message(self):
|
||||
t = ctx_translator.get().t
|
||||
return t(_p(
|
||||
'guildset:hourly_live_bonus|response',
|
||||
"Live members will now *additionally* be given {coin}**{amount}** per hour."
|
||||
)).format(
|
||||
coin=conf.emojis.coin,
|
||||
amount=self.data
|
||||
)
|
||||
|
||||
class DailyVoiceCap(ModelData, DurationSetting):
|
||||
setting_id = 'daily_voice_cap'
|
||||
_event = 'guild_setting_update_daily_voice_cap'
|
||||
|
||||
_display_name = _p('guildset:daily_voice_cap', "daily_voice_cap")
|
||||
_desc = _p(
|
||||
'guildset:daily_voice_cap|desc',
|
||||
"Maximum number of hours per day to count for each member."
|
||||
)
|
||||
_long_desc = _p(
|
||||
'guildset:daily_voice_cap|long_desc',
|
||||
"Time spend in voice channels over this amount will not be tracked towards the member's statistics. "
|
||||
"Tracking will resume at the start of the next day. "
|
||||
"The start of the day is determined by the configured guild timezone."
|
||||
)
|
||||
|
||||
_default = 16 * 60 * 60
|
||||
_default_multiplier = 60 * 60
|
||||
|
||||
_max = 60 * 60 * 25
|
||||
|
||||
_model = CoreData.Guild
|
||||
_column = CoreData.Guild.daily_study_cap.name
|
||||
|
||||
@property
|
||||
def set_str(self):
|
||||
# TODO: Dynamic retrieval of command id
|
||||
return '</configure voice_tracking:1038560947666694144>'
|
||||
|
||||
@property
|
||||
def update_message(self):
|
||||
t = ctx_translator.get().t
|
||||
return t(_p(
|
||||
'guildset:daily_voice_cap|response',
|
||||
"Members will be tracked for at most {duration} per day. "
|
||||
"(**NOTE:** This will not affect members currently in voice channels.)"
|
||||
)).format(
|
||||
duration=self.formatted
|
||||
)
|
||||
|
||||
|
||||
class VoiceTrackerConfigUI(LeoUI):
|
||||
# TODO: Bulk edit
|
||||
# TODO: Cohesive exit
|
||||
# TODO: Back to main configuration panel
|
||||
|
||||
_listening = {}
|
||||
|
||||
def __init__(self, bot: LionBot, settings: VoiceTrackerSettings, guildid: int, channelid: int, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.bot = bot
|
||||
self.settings = settings
|
||||
self.guildid = guildid
|
||||
self.channelid = channelid
|
||||
|
||||
self._original: Optional[discord.Interaction] = None
|
||||
self._message: Optional[discord.Message] = None
|
||||
|
||||
self.hourly_reward: Optional[VoiceTrackerSettings.HourlyReward] = None
|
||||
self.hourly_live_bonus: Optional[VoiceTrackerSettings.HourlyLiveBonus] = None
|
||||
self.daily_voice_cap: Optional[VoiceTrackerSettings.DailyVoiceCap] = None
|
||||
self.untracked_channels: Optional[VoiceTrackerSettings.UntrackedChannels] = None
|
||||
|
||||
self.embed: Optional[discord.Embed] = None
|
||||
|
||||
@property
|
||||
def instances(self):
|
||||
return (self.hourly_reward, self.hourly_live_bonus, self.daily_voice_cap, self.untracked_channels)
|
||||
|
||||
async def cleanup(self):
|
||||
self._listening.pop(self.channelid, None)
|
||||
for instance in self.instances:
|
||||
instance.deregister_callback(self.id)
|
||||
try:
|
||||
if self._original is not None:
|
||||
await self._original.delete_original_response()
|
||||
self._original = None
|
||||
if self._message is not None:
|
||||
await self._message.delete()
|
||||
self._message = None
|
||||
except discord.HTTPException:
|
||||
# Interaction is likely expired or invalid, or some form of comms issue
|
||||
pass
|
||||
|
||||
@button(label='CLOSE')
|
||||
async def close_button(self, interaction: discord.Interaction, pressed):
|
||||
await interaction.response.defer()
|
||||
await self.close()
|
||||
|
||||
async def refresh_close_button(self):
|
||||
t = self.bot.translator.t
|
||||
self.close_button.label = t(_p('ui:voice_tracker_config|button:close|label', "Close"))
|
||||
|
||||
@button(label='RESET')
|
||||
async def reset_button(self, interaction: discord.Interaction, pressed):
|
||||
await interaction.response.defer()
|
||||
|
||||
for instance in self.instances:
|
||||
instance.data = None
|
||||
await instance.write()
|
||||
|
||||
await self.reload()
|
||||
|
||||
async def refresh_reset_button(self):
|
||||
t = self.bot.translator.t
|
||||
self.reset_button.label = t(_p('ui:voice_tracker_config|button:reset|label', "Reset"))
|
||||
|
||||
@select(cls=ChannelSelect, placeholder='UNTRACKED_CHANNEL_MENU', min_values=0, max_values=25)
|
||||
async def untracked_channels_menu(self, interaction: discord.Interaction, selected):
|
||||
await interaction.response.defer()
|
||||
self.untracked_channels.value = selected.values
|
||||
await self.untracked_channels.write()
|
||||
await self.reload()
|
||||
|
||||
async def refresh_untracked_channels_menu(self):
|
||||
t = self.bot.translator.t
|
||||
self.untracked_channels_menu.placeholder = t(_p(
|
||||
'ui:voice_tracker_config|menu:untracked_channels|placeholder',
|
||||
"Set Untracked Channels"
|
||||
))
|
||||
|
||||
async def run(self, interaction: discord.Interaction):
|
||||
if old := self._listening.get(self.channelid, None):
|
||||
await old.close()
|
||||
|
||||
await self.refresh()
|
||||
|
||||
if interaction.response.is_done():
|
||||
# Use followup to respond
|
||||
self._mesage = await interaction.followup.send(embed=self.embed, view=self)
|
||||
else:
|
||||
# Use interaction response to respond
|
||||
self._original = interaction
|
||||
await interaction.response.send_message(embed=self.embed, view=self)
|
||||
|
||||
for instance in self.instances:
|
||||
instance.register_callback(self.id)(self.reload)
|
||||
|
||||
self._listening[self.channelid] = self
|
||||
|
||||
async def refresh(self):
|
||||
# TODO: Check if listening works for subclasses
|
||||
await self.refresh_close_button()
|
||||
await self.refresh_reset_button()
|
||||
await self.refresh_untracked_channels_menu()
|
||||
|
||||
lguild = await self.bot.core.lions.fetch_guild(self.guildid)
|
||||
|
||||
if lguild.guild_mode.voice is VoiceMode.VOICE:
|
||||
self.hourly_reward = await self.settings.HourlyReward_Voice.get(self.guildid)
|
||||
else:
|
||||
self.hourly_reward = await self.settings.HourlyReward_Study.get(self.guildid)
|
||||
|
||||
self.hourly_live_bonus = lguild.config.get('hourly_live_bonus')
|
||||
self.daily_voice_cap = lguild.config.get('daily_voice_cap')
|
||||
self.untracked_channels = await self.settings.UntrackedChannels.get(self.guildid)
|
||||
|
||||
self._layout = [
|
||||
(self.untracked_channels_menu,),
|
||||
(self.reset_button, self.close_button)
|
||||
]
|
||||
|
||||
self.embed = await self.make_embed()
|
||||
|
||||
async def redraw(self):
|
||||
try:
|
||||
if self._message:
|
||||
await self._message.edit(embed=self.embed, view=self)
|
||||
elif self._original:
|
||||
await self._original.edit_original_response(embed=self.embed, view=self)
|
||||
except discord.HTTPException:
|
||||
await self.close()
|
||||
|
||||
async def reload(self, *args, **kwargs):
|
||||
await self.refresh()
|
||||
await self.redraw()
|
||||
|
||||
async def make_embed(self):
|
||||
t = self.bot.translator.t
|
||||
lguild = await self.bot.core.lions.fetch_guild(self.guildid)
|
||||
mode = lguild.guild_mode
|
||||
if mode.voice is VoiceMode.VOICE:
|
||||
title = t(_p(
|
||||
'ui:voice_tracker_config|mode:voice|embed|title',
|
||||
"Voice Tracker Configuration Panel"
|
||||
))
|
||||
else:
|
||||
title = t(_p(
|
||||
'ui:voice_tracker_config|mode:study|embed|title',
|
||||
"Study Tracker Configuration Panel"
|
||||
))
|
||||
embed = discord.Embed(
|
||||
colour=discord.Colour.orange(),
|
||||
title=title
|
||||
)
|
||||
for setting in self.instances:
|
||||
embed.add_field(**setting.embed_field, inline=False)
|
||||
return embed
|
||||
Reference in New Issue
Block a user