# File: subathon-tracker-bot/src/meta/bot.py

import logging
from typing import TYPE_CHECKING, Any, Literal, Optional, overload
from twitchio.authentication import UserTokenPayload
from twitchio.ext import commands
from twitchio import Scopes, eventsub
from data import Database, ORDER
from botdata import BotData, UserAuth, BotChannel, VersionHistory
from constants import BOTUSER_SCOPES, CHANNEL_SCOPES, SCHEMA_VERSIONS
from .config import Conf
from .context import Context
if TYPE_CHECKING:
    # Imported for type annotations only (the ``get_component`` overload in
    # ``Bot``); it is not imported at runtime, presumably to avoid an import
    # cycle with the modules package — TODO confirm.
    from modules.profiles.profiles.twitch.component import ProfilesComponent

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class Bot(commands.Bot):
    """Subathon tracker bot built on ``twitchio.ext.commands.Bot``.

    On top of the twitchio bot this class holds the project configuration and
    database connection, checks component schema versions during setup,
    subscribes to chat messages for ``autojoin`` channels, and persists OAuth
    tokens and their scopes in the database.
    """

    def __init__(self, *args, config: Conf, dbconn: Database, setup=None, **kwargs):
        """Create the bot.

        ``client_id``, ``client_secret``, ``bot_id`` and ``prefix`` default to
        the values in ``config.bot``; an explicitly passed keyword wins. All
        other arguments are forwarded to ``commands.Bot``.

        :param config: Project configuration.
        :param dbconn: Database connection; the ``BotData`` registry is loaded
            from it.
        :param setup: Optional async callable, awaited with this bot inside
            :meth:`setup_hook` after the schema version checks.
        """
        bot_conf = config.bot
        for key in ('client_id', 'client_secret', 'bot_id', 'prefix'):
            kwargs.setdefault(key, bot_conf[key])
        super().__init__(*args, **kwargs)

        # EventSub transport: webhooks when an eventsub secret is configured,
        # websockets otherwise.
        self.using_webhooks = bool(bot_conf.get('eventsub_secret', None))

        self.config = config
        self.dbconn = dbconn
        self.data: BotData = dbconn.load_registry(BotData())
        self._setup_hook = setup
        # Channels whose chat subscription succeeded, keyed by channel user id.
        self.joined: dict[str, BotChannel] = {}

    # Make the type checker happy about fetching components by name
    # TODO: Move to stubs
    @property
    def profiles(self):
        """The ``ProfilesComponent`` instance, looked up via :meth:`get_component`."""
        return self.get_component('ProfilesComponent')

    @overload
    def get_component(self, name: Literal['ProfilesComponent']) -> 'ProfilesComponent':
        ...

    @overload
    def get_component(self, name: str) -> Optional[commands.Component]:
        ...

    def get_component(self, name: str) -> Optional[commands.Component]:
        """Return the component registered under ``name``, if any.

        Delegates to ``commands.Bot.get_component``; the overloads only narrow
        the static return type for known component names.
        """
        return super().get_component(name)

    def get_context(self, payload, *, cls: Any = None) -> Context:
        """Build a command context for ``payload``.

        Uses the project :class:`Context` unless ``cls`` supplies another class.
        """
        cls = cls or Context
        return cls(payload, bot=self)

    async def event_ready(self):
        """Handle the ``ready`` event by logging the bot's user id."""
        logger.info("Logged in as %s", self.bot_id)

    async def version_check(self, component: str, req_version: int) -> bool:
        """Confirm the database records ``component`` at ``req_version``.

        Reads the most recent ``VersionHistory`` row (ordered by ``_timestamp``)
        for the component; a component with no history counts as version 0.
        Typically run when a component is loaded.

        :returns: ``True`` when the recorded version matches.
        :raises ValueError: if the recorded version differs from ``req_version``.
        """
        rows = await VersionHistory.fetch_where(component=component).order_by('_timestamp', ORDER.DESC).limit(1)
        version = rows[0].to_version if rows else 0
        if version != req_version:
            raise ValueError(f"Component {component} failed version check. Has version '{version}', required version '{req_version}'")
        logger.debug(
            "Component %s passed version check with version %s",
            component,
            version
        )
        return True

    async def setup_hook(self):
        """Startup hook.

        Initialises the data registry, checks every entry of
        ``SCHEMA_VERSIONS``, runs the optional ``setup`` callback, subscribes
        to all ``autojoin`` channels, and logs the authorisation URLs for the
        bot account and for other users.
        """
        await self.data.init()
        for component, req in SCHEMA_VERSIONS.items():
            await self.version_check(component, req)
        if self._setup_hook is not None:
            await self._setup_hook(self)

        # Join every channel currently marked for autojoin.
        channels = await BotChannel.fetch_where(autojoin=True)
        await self.join_channels(*channels)

        # The bot account authorises BOTUSER_SCOPES; everyone else CHANNEL_SCOPES.
        bot_url = self.get_auth_url(BOTUSER_SCOPES)
        logger.info("Bot account authorisation url: %s", bot_url)
        user_url = self.get_auth_url(CHANNEL_SCOPES)
        logger.info("User account authorisation url: %s", user_url)
        logger.info("Finished setup")

    def get_auth_url(self, scopes: Optional[Scopes] = None):
        """Return the OAuth authorisation URL for ``scopes``.

        Defaults to just the ``channel:bot`` scope.
        """
        if scopes is None:
            scopes = Scopes((Scopes.channel_bot,))
        return self._adapter.get_authorization_url(scopes=scopes)

    async def join_channels(self, *channels: BotChannel):
        """Subscribe to chat messages in each of the given channels.

        Creates a ``ChatMessageSubscription`` per channel (as the bot user) over
        webhooks or websockets, according to ``self.using_webhooks``. Successful
        channels are recorded in ``self.joined`` and a ``channel_joined`` event
        is dispatched. Failures are logged and the remaining channels are still
        attempted.
        """
        # TODO: If channels are already joined, unsubscribe
        for channel in channels:
            sub = None
            try:
                sub = eventsub.ChatMessageSubscription(
                    broadcaster_user_id=channel.userid,
                    user_id=self.bot_id,
                )
                if self.using_webhooks:
                    resp = await self.subscribe_webhook(sub)
                else:
                    resp = await self.subscribe_websocket(sub)
                logger.info("Subscribed to %s with %s response %s", channel.userid, sub, resp)
                self.joined[channel.userid] = channel
                self.safe_dispatch('channel_joined', payload=channel)
            except Exception:
                # Best effort: one failing channel must not stop the others.
                logger.exception("Failed to subscribe to %s with %s", channel.userid, sub)

    async def event_oauth_authorized(self, payload: UserTokenPayload):
        """Handle a completed OAuth flow.

        Stores the new token pair via :meth:`add_token`. When the grant includes
        ``channel:bot``, ensures a ``BotChannel`` row exists for the user and
        joins it if the row is marked ``autojoin``.
        """
        # Never log the payload itself: it carries the access and refresh tokens.
        logger.debug("Oauth flow authorization received")
        resp = await self.add_token(payload.access_token, payload.refresh_token)
        if resp.user_id is None:
            # add_token has already logged the validation response.
            logger.warning("Oauth flow received with no user_id.")
            return

        # If the scopes authorised included channel:bot, ensure a BotChannel exists
        # and join it if needed.
        if Scopes.channel_bot.value in resp.scopes:
            bot_channel = await BotChannel.fetch_or_create(
                resp.user_id,
                autojoin=True,
            )
            if bot_channel.autojoin:
                await self.join_channels(bot_channel)
        logger.info(
            "Oauth flow authorization complete for user '%s' with scopes: %s",
            resp.user_id,
            ', '.join(resp.scopes),
        )

    async def add_token(self, token: str, refresh: str):
        """Register a user token with twitchio and persist it.

        The parent ``add_token`` validates the token and reports its user id and
        scopes. When a user id is present, the token pair is stored in
        ``UserAuth`` (updated if it changed) and the user's stored scopes are
        replaced with the validated ones.

        :returns: The validation response from ``commands.Bot.add_token``.
        """
        resp = await super().add_token(token, refresh)
        if resp.user_id is None:
            logger.warning(
                "Added a token with no user_id. Response was: %s",
                repr(resp)
            )
            return resp

        userid = resp.user_id
        new_scopes = resp.scopes or []

        # TODO: wrap these writes in a transaction so a failure halfway rolls back.
        row = await UserAuth.fetch_or_create(userid, token=token, refresh_token=refresh)
        if row.token != token or row.refresh_token != refresh:
            await row.update(token=token, refresh_token=refresh)

        await self.data.user_auth_scopes.delete_where(userid=userid)
        # A token may carry no scopes; skip issuing an insert with no value rows.
        if new_scopes:
            await self.data.user_auth_scopes.insert_many(
                ('userid', 'scope'),
                *((userid, scope) for scope in new_scopes)
            )
        logger.info("Updated auth token for user '%s' with scopes: %s", resp.user_id, ', '.join(new_scopes))
        return resp

    async def load_tokens(self, path: str | None = None):
        """Re-register every stored ``UserAuth`` token via :meth:`add_token`.

        ``path`` is accepted but not used (presumably kept so the signature
        matches the parent ``load_tokens`` — confirm). A token that fails to
        load is logged and skipped.
        """
        for row in await UserAuth.fetch_where():
            try:
                await self.add_token(row.token, row.refresh_token)
            except Exception:
                # NOTE(review): str(row) may include token fields depending on
                # UserAuth's repr — consider logging only the user id.
                logger.exception("Failed to add token for %s", row)