197 lines
7.2 KiB
Python
197 lines
7.2 KiB
Python
import logging
|
|
from typing import TYPE_CHECKING, Any, Literal, Optional, overload
|
|
|
|
from twitchio.authentication import UserTokenPayload
|
|
from twitchio.ext import commands
|
|
from twitchio import Scopes, eventsub
|
|
|
|
from data import Database, ORDER
|
|
from botdata import BotData, UserAuth, BotChannel, VersionHistory
|
|
from constants import BOTUSER_SCOPES, CHANNEL_SCOPES, SCHEMA_VERSIONS
|
|
|
|
from .config import Conf
|
|
from .context import Context
|
|
|
|
if TYPE_CHECKING:
|
|
from modules.profiles.profiles.twitch.component import ProfilesComponent
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Bot(commands.Bot):
    """
    Project-specific ``commands.Bot`` subclass.

    Fills in client credentials from the configuration, loads the ``BotData``
    registry from the database, verifies component schema versions during
    setup, subscribes to chat messages for stored channels, and persists
    OAuth tokens together with their scopes.
    """

    def __init__(self, *args, config: Conf, dbconn: Database, setup=None, **kwargs):
        """
        Build the bot from configuration and a database connection.

        ``client_id``, ``client_secret``, ``bot_id`` and ``prefix`` are taken
        from ``config.bot`` unless the caller passes them as keyword arguments.

        config: Configuration object; ``config.bot`` supplies the defaults above.
        dbconn: Database used to load the ``BotData`` registry.
        setup: Optional callable, awaited as ``setup(bot)`` inside ``setup_hook``.
        kwargs: Forwarded to ``commands.Bot``. ``using_webhooks`` (default False)
            is also read here to choose the eventsub transport.
        """
        # Explicit keyword arguments win over the configured values
        for key in ('client_id', 'client_secret', 'bot_id', 'prefix'):
            kwargs.setdefault(key, config.bot[key])

        super().__init__(*args, **kwargs)

        # Whether we should do eventsub via webhooks or websockets
        self.using_webhooks = kwargs.get('using_webhooks', False)

        self.config = config
        self.dbconn = dbconn
        self.data: BotData = dbconn.load_registry(BotData())
        self._setup_hook = setup

        # Channels we have successfully subscribed to, keyed by channel userid
        self.joined: dict[str, BotChannel] = {}

    # Make the type checker happy about fetching components by name
    # TODO: Move to stubs

    @property
    def profiles(self):
        """
        Shortcut for ``get_component('ProfilesComponent')``.
        """
        return self.get_component('ProfilesComponent')

    @overload
    def get_component(self, name: Literal['ProfilesComponent']) -> 'ProfilesComponent':
        ...

    @overload
    def get_component(self, name: str) -> Optional[commands.Component]:
        ...

    def get_component(self, name: str) -> Optional[commands.Component]:
        """
        Fetch a loaded component by name.

        Delegates to the parent implementation; the overloads above only
        narrow the return type for known component names.
        """
        return super().get_component(name)

    def get_context(self, payload, *, cls: Any = None) -> Context:
        """
        Build a command context for ``payload``.

        Uses ``cls`` when given, otherwise the project ``Context`` class.
        """
        context_cls = cls or Context
        return context_cls(payload, bot=self)

    async def event_ready(self):
        """
        Log the bot id when the ready event fires.
        """
        logger.info("Logged in as %s", self.bot_id)

    async def version_check(self, component: str, req_version: int) -> bool:
        """
        Confirm that the database lists ``component`` at ``req_version``.

        The current version is the ``to_version`` of the newest
        ``VersionHistory`` row for the component (ordered by ``_timestamp``).
        A component with no rows is treated as version 0.
        Typically done upon loading a component.

        Returns True when the versions match.
        Raises ValueError when they differ.
        """
        rows = await (
            VersionHistory.fetch_where(component=component)
            .order_by('_timestamp', ORDER.DESC)
            .limit(1)
        )
        version = rows[0].to_version if rows else 0

        if version != req_version:
            raise ValueError(
                f"Component {component} failed version check. "
                f"Has version '{version}', required version '{req_version}'"
            )

        logger.debug(
            "Component %s passed version check with version %s",
            component,
            version
        )
        return True

    async def setup_hook(self):
        """
        Prepare the bot before it starts handling events.

        Initialises the data registry, checks every entry of
        ``SCHEMA_VERSIONS``, runs the optional user-supplied setup callable,
        joins all channels marked ``autojoin``, and logs the authorisation
        urls for the bot account and for channel owners.
        """
        await self.data.init()
        for component, req in SCHEMA_VERSIONS.items():
            await self.version_check(component, req)

        if self._setup_hook is not None:
            await self._setup_hook(self)

        # Get all current bot channels and join them
        channels = await BotChannel.fetch_where(autojoin=True)
        await self.join_channels(*channels)

        # Build bot account's own url
        logger.info("Bot account authorisation url: %s", self.get_auth_url(BOTUSER_SCOPES))

        # Build everyone else's url
        logger.info("User account authorisation url: %s", self.get_auth_url(CHANNEL_SCOPES))

        logger.info("Finished setup")

    def get_auth_url(self, scopes: Optional[Scopes] = None):
        """
        Return the authorisation url for the given scopes.

        When ``scopes`` is None, only ``Scopes.channel_bot`` is requested.
        """
        if scopes is None:
            scopes = Scopes((Scopes.channel_bot,))

        return self._adapter.get_authorization_url(scopes=scopes)

    async def join_channels(self, *channels: BotChannel):
        """
        Subscribe to chat messages in the given channel(s).

        Each channel gets a ``ChatMessageSubscription``, registered through
        webhooks or websockets depending on ``self.using_webhooks``.
        Channels that subscribe successfully are stored in ``self.joined``
        and a ``channel_joined`` event is dispatched for them.
        A failure for one channel is logged and the remaining channels are
        still attempted.
        """
        # TODO: If channels are already joined, unsubscribe
        for channel in channels:
            # Pre-bound so the failure log below works even if construction raises
            sub = None
            try:
                sub = eventsub.ChatMessageSubscription(
                    broadcaster_user_id=channel.userid,
                    user_id=self.bot_id,
                )
                if self.using_webhooks:
                    resp = await self.subscribe_webhook(sub)
                else:
                    resp = await self.subscribe_websocket(sub)
                logger.info("Subscribed to %s with %s response %s", channel.userid, sub, resp)
                self.joined[channel.userid] = channel
                self.safe_dispatch('channel_joined', payload=channel)
            except Exception:
                logger.exception("Failed to subscribe to %s with %s", channel.userid, sub)

    async def event_oauth_authorized(self, payload: UserTokenPayload):
        """
        Handle a completed OAuth flow.

        Stores the new token through ``add_token``. If the token has no
        associated user id, a warning is logged and nothing else happens.
        When the granted scopes include ``channel:bot``, a ``BotChannel`` row
        is fetched or created for the user and the channel is joined if its
        ``autojoin`` flag is set.
        """
        logger.debug("Oauth flow authorization with payload %s", repr(payload))
        # Save the token and scopes and update internal authorisations
        resp = await self.add_token(payload.access_token, payload.refresh_token)
        if resp.user_id is None:
            logger.warning(
                "Oauth flow received with no user_id. Payload was: %s",
                repr(payload)
            )
            return

        # If the scopes authorised included channel:bot, ensure a BotChannel exists
        # And join it if needed
        if Scopes.channel_bot.value in resp.scopes:
            bot_channel = await BotChannel.fetch_or_create(
                resp.user_id,
                autojoin=True,
            )
            if bot_channel.autojoin:
                await self.join_channels(bot_channel)

        logger.info("Oauth flow authorization complete for payload %s", repr(payload))

    async def add_token(self, token: str, refresh: str):
        """
        Register a token pair with the client and persist it.

        The parent ``add_token`` is called first and its response is always
        returned. When that response carries a user id, the ``UserAuth`` row
        for the user is created or updated and the stored scopes for the user
        are replaced with the scopes from the response. Without a user id a
        warning is logged and nothing is written.
        """
        # Update the tokens in internal cache
        # This also validates the token
        # And hopefully gets the userid and scopes
        resp = await super().add_token(token, refresh)
        if resp.user_id is None:
            logger.warning(
                "Added a token with no user_id. Response was: %s",
                repr(resp)
            )
            return resp

        userid = resp.user_id
        # Materialise once so the insert and the log line see the same scopes
        scope_names = list(resp.scopes)

        # Save the token and scopes to data
        # TODO: Wrap this in a transaction so if it fails halfway we rollback correctly
        row = await UserAuth.fetch_or_create(userid, token=token, refresh_token=refresh)
        if row.token != token or row.refresh_token != refresh:
            await row.update(token=token, refresh_token=refresh)

        await self.data.user_auth_scopes.delete_where(userid=userid)
        # Skip the bulk insert when there are no scopes, so we never issue an empty insert
        if scope_names:
            await self.data.user_auth_scopes.insert_many(
                ('userid', 'scope'),
                *((userid, scope) for scope in scope_names)
            )

        logger.info("Updated auth token for user '%s' with scopes: %s", userid, ', '.join(scope_names))
        return resp

    async def load_tokens(self, path: str | None = None):
        """
        Load every stored token from the ``UserAuth`` table into the client.

        Each row is passed to ``add_token``; a failure for one row is logged
        and the remaining rows are still processed.

        path: Unused here. Presumably kept for compatibility with the parent
            signature; confirm before removing.
        """
        for row in await UserAuth.fetch_where():
            try:
                await self.add_token(row.token, row.refresh_token)
            except Exception:
                logger.exception("Failed to add token for %s", row)
|