import logging
from typing import TYPE_CHECKING, Any, Literal, Optional, overload

from twitchio.authentication import UserTokenPayload
from twitchio.ext import commands
from twitchio import Scopes, eventsub

from data import Database, ORDER
from botdata import BotData, UserAuth, BotChannel, VersionHistory
from constants import BOTUSER_SCOPES, CHANNEL_SCOPES, SCHEMA_VERSIONS

from .config import Conf
from .context import Context

if TYPE_CHECKING:
    from modules.profiles.profiles.twitch.component import ProfilesComponent


logger = logging.getLogger(__name__)


class Bot(commands.Bot):
    """
    Project bot built on ``commands.Bot``.

    Holds the configuration, the database connection and its ``BotData``
    registry, and tracks the channels whose chat-message subscription
    succeeded in ``self.joined`` (keyed by channel userid).
    """

    def __init__(self, *args, config: Conf, dbconn: Database, setup=None, **kwargs):
        """
        Create the bot.

        Parameters
        ----------
        config:
            Configuration object. ``config.bot`` supplies defaults for
            ``client_id``, ``client_secret``, ``bot_id`` and ``prefix``;
            values passed explicitly in ``kwargs`` take precedence.
        dbconn:
            Database connection; the ``BotData`` registry is loaded onto it.
        setup:
            Optional coroutine function awaited with the bot as its only
            argument during ``setup_hook``.
        """
        kwargs.setdefault('client_id', config.bot['client_id'])
        kwargs.setdefault('client_secret', config.bot['client_secret'])
        kwargs.setdefault('bot_id', config.bot['bot_id'])
        kwargs.setdefault('prefix', config.bot['prefix'])

        super().__init__(*args, **kwargs)

        # Whether we should do eventsub via webhooks or websockets
        self.using_webhooks = kwargs.get('using_webhooks', False)

        self.config = config
        self.dbconn = dbconn
        self.data: BotData = dbconn.load_registry(BotData())
        self._setup_hook = setup

        # Channels we have successfully subscribed to, keyed by userid
        self.joined: dict[str, BotChannel] = {}

    # Make the type checker happy about fetching components by name
    # TODO: Move to stubs

    @property
    def profiles(self):
        """The component registered under the name 'ProfilesComponent'."""
        return self.get_component('ProfilesComponent')

    @overload
    def get_component(self, name: Literal['ProfilesComponent']) -> 'ProfilesComponent':
        ...

    @overload
    def get_component(self, name: str) -> Optional[commands.Component]:
        ...

    def get_component(self, name: str) -> Optional[commands.Component]:
        """Look up a component by name; delegates to the base class."""
        return super().get_component(name)

    def get_context(self, payload, *, cls: Any = None) -> Context:
        """Build a context for ``payload``, using ``Context`` unless ``cls`` is given."""
        cls = cls or Context
        return cls(payload, bot=self)

    async def event_ready(self) -> None:
        """Log the bot id once the bot is ready."""
        logger.info("Logged in as %s", self.bot_id)

    async def version_check(self, component: str, req_version: int) -> bool:
        """
        Confirm that ``component`` is recorded at ``req_version``.

        Reads the most recent ``VersionHistory`` row for the component
        (ordered by ``_timestamp`` descending). A component with no rows is
        treated as version 0. Typically done upon loading a component.

        Returns
        -------
        bool
            ``True`` when the versions match.

        Raises
        ------
        ValueError
            When the recorded version differs from ``req_version``.
        """
        rows = await VersionHistory.fetch_where(
            component=component
        ).order_by('_timestamp', ORDER.DESC).limit(1)

        version = rows[0].to_version if rows else 0

        if version != req_version:
            raise ValueError(f"Component {component} failed version check. Has version '{version}', required version '{req_version}'")

        logger.debug(
            "Component %s passed version check with version %s",
            component, version
        )
        return True

    async def setup_hook(self) -> None:
        """
        Prepare the bot before it starts handling events.

        Initialises the data registry, runs the version check for every
        entry in ``SCHEMA_VERSIONS``, awaits the user-supplied setup
        callback (if any), joins all channels flagged ``autojoin`` and logs
        the authorisation urls for the bot account and for other users.
        """
        await self.data.init()

        for component, req in SCHEMA_VERSIONS.items():
            await self.version_check(component, req)

        if self._setup_hook is not None:
            await self._setup_hook(self)

        # Get all current bot channels
        channels = await BotChannel.fetch_where(autojoin=True)

        # Join the channels
        await self.join_channels(*channels)

        # Build bot account's own url
        bot_url = self.get_auth_url(BOTUSER_SCOPES)
        logger.info("Bot account authorisation url: %s", bot_url)

        # Build everyone else's url
        user_url = self.get_auth_url(CHANNEL_SCOPES)
        logger.info("User account authorisation url: %s", user_url)

        logger.info("Finished setup")

    def get_auth_url(self, scopes: Optional[Scopes] = None):
        """
        Return the adapter's authorisation url for ``scopes``.

        When ``scopes`` is ``None`` only the ``channel_bot`` scope is requested.
        """
        if scopes is None:
            scopes = Scopes((Scopes.channel_bot,))
        return self._adapter.get_authorization_url(scopes=scopes)

    async def join_channels(self, *channels: BotChannel) -> None:
        """
        Subscribe to chat messages in the given channel(s).

        The subscription is made over webhooks or websockets depending on
        ``self.using_webhooks``. Each channel that subscribes successfully
        is stored in ``self.joined`` and a ``channel_joined`` event is
        dispatched with the channel as payload. Failures are logged per
        channel and do not stop the remaining channels from being processed.
        """
        # TODO: If channels are already joined, unsubscribe
        for channel in channels:
            # Pre-bound so the failure log below can reference it even if
            # building the subscription itself raised.
            sub = None
            try:
                sub = eventsub.ChatMessageSubscription(
                    broadcaster_user_id=channel.userid,
                    user_id=self.bot_id,
                )
                if self.using_webhooks:
                    resp = await self.subscribe_webhook(sub)
                else:
                    resp = await self.subscribe_websocket(sub)
                logger.info("Subscribed to %s with %s response %s", channel.userid, sub, resp)
                self.joined[channel.userid] = channel
                self.safe_dispatch('channel_joined', payload=channel)
            except Exception:
                # Best effort: one failing channel must not block the others.
                logger.exception("Failed to subscribe to %s with %s", channel.userid, sub)

    async def event_oauth_authorized(self, payload: UserTokenPayload) -> None:
        """
        Handle a completed OAuth flow.

        Stores the new token via ``add_token``. If the validated token has
        no user id the event is logged and dropped. If the authorised scopes
        include ``channel:bot``, a ``BotChannel`` row is fetched or created
        for the user and joined when its ``autojoin`` flag is set.
        """
        logger.debug("Oauth flow authorization with payload %s", repr(payload))

        # Save the token and scopes and update internal authorisations
        resp = await self.add_token(payload.access_token, payload.refresh_token)
        if resp.user_id is None:
            logger.warning(
                "Oauth flow received with no user_id. Payload was: %s",
                repr(payload)
            )
            return

        # If the scopes authorised included channel:bot, ensure a BotChannel exists
        # And join it if needed
        if Scopes.channel_bot.value in resp.scopes:
            bot_channel = await BotChannel.fetch_or_create(
                resp.user_id,
                autojoin=True,
            )
            if bot_channel.autojoin:
                await self.join_channels(bot_channel)

        logger.info("Oauth flow authorization complete for payload %s", repr(payload))

    async def add_token(self, token: str, refresh: str):
        """
        Register a token pair with the client and persist it.

        The base-class ``add_token`` is called first. When its response
        carries no user id, nothing is persisted and the response is
        returned as is. Otherwise the ``UserAuth`` row for the user is
        created or updated and the stored scopes for that user are replaced
        with the scopes from the response.

        Returns the response of the base-class ``add_token``.
        """
        # Update the tokens in internal cache
        # This also validates the token
        # And hopefully gets the userid and scopes
        resp = await super().add_token(token, refresh)
        if resp.user_id is None:
            logger.warning(
                "Added a token with no user_id. Response was: %s",
                repr(resp)
            )
            return resp

        userid = resp.user_id
        new_scopes = resp.scopes

        # Save the token and scopes to data
        # Wrap this in a transaction so if it fails halfway we rollback correctly
        # TODO
        row = await UserAuth.fetch_or_create(userid, token=token, refresh_token=refresh)
        if row.token != token or row.refresh_token != refresh:
            await row.update(token=token, refresh_token=refresh)

        # Replace the stored scopes wholesale: drop the old rows, insert the new.
        await self.data.user_auth_scopes.delete_where(userid=userid)
        scope_rows = [(userid, scope) for scope in new_scopes]
        if scope_rows:
            # Only insert when there is something to insert, so a token with
            # no scopes never issues a zero-row insert.
            # NOTE(review): insert_many's handling of zero rows is not visible
            # from this file; skipping the call is harmless either way.
            await self.data.user_auth_scopes.insert_many(
                ('userid', 'scope'),
                *scope_rows
            )

        logger.info("Updated auth token for user '%s' with scopes: %s", resp.user_id, ', '.join(new_scopes))
        return resp

    async def load_tokens(self, path: str | None = None) -> None:
        """
        Load every stored ``UserAuth`` token pair into the client.

        ``path`` is accepted but not used here (presumably kept to match the
        base-class signature). A token that fails to load is logged and
        skipped so the remaining tokens are still processed.
        """
        for row in await UserAuth.fetch_where():
            try:
                await self.add_token(row.token, row.refresh_token)
            except Exception:
                # NOTE(review): the string form of `row` may include the token
                # values -- confirm this is acceptable to write to the log.
                logger.exception("Failed to add token for %s", row)