Compare commits

..

4 Commits

Author SHA1 Message Date
2e02f39c29 feat: Extend Context and add Profile module. 2025-09-03 20:19:21 +10:00
a3f1fe1322 feat: Add simplistic websocket channels. 2025-09-03 20:18:48 +10:00
3f19ec4a17 fix: Logger imports.
Add missing discord and utils imports.
2025-09-03 20:17:59 +10:00
85f7465283 fix: Incorrect redirect_url from behind proxy.
Override the redirect URI calculator for AiohttpAdapter.
Required due to changes in Twitchio v3.1
2025-09-03 20:16:56 +10:00
6 changed files with 125 additions and 18 deletions

View File

@@ -4,7 +4,7 @@ import websockets
from twitchio.web import AiohttpAdapter from twitchio.web import AiohttpAdapter
from meta import Bot, conf, setup_main_logger, args from meta import Bot, conf, setup_main_logger, args, sockets
from data import Database from data import Database
from modules import twitch_setup from modules import twitch_setup
@@ -12,11 +12,16 @@ from modules import twitch_setup
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ProxyAiohttpAdapter(AiohttpAdapter):
def _find_redirect(self, request):
return self.redirect_url
async def main(): async def main():
db = Database(conf.data['args']) db = Database(conf.data['args'])
async with db.open(): async with db.open():
adapter = AiohttpAdapter( adapter = ProxyAiohttpAdapter(
host=conf.bot.get('wshost', None), host=conf.bot.get('wshost', None),
port=conf.bot.getint('wsport', None), port=conf.bot.getint('wsport', None),
domain=conf.bot.get('wsdomain', None), domain=conf.bot.get('wsdomain', None),
@@ -30,10 +35,11 @@ async def main():
setup=twitch_setup, setup=twitch_setup,
) )
try: async with websockets.serve(sockets.root_handler, '', conf.wserver.getint('port')):
await bot.start() try:
finally: await bot.start()
await bot.close() finally:
await bot.close()
def _main(): def _main():

View File

@@ -1,4 +1,5 @@
from .args import args from .args import args
from .bot import Bot from .bot import Bot
from .context import Context
from .config import Conf, conf from .config import Conf, conf
from .logger import setup_main_logger, log_context, log_action_stack, log_app, set_logging_context, logging_context, with_log_ctx, persist_task from .logger import setup_main_logger, log_context, log_action_stack, log_app, set_logging_context, logging_context, with_log_ctx, persist_task

View File

@@ -1,5 +1,5 @@
import logging import logging
from typing import Optional from typing import TYPE_CHECKING, Any, Literal, Optional, overload
from twitchio.authentication import UserTokenPayload from twitchio.authentication import UserTokenPayload
from twitchio.ext import commands from twitchio.ext import commands
@@ -10,6 +10,10 @@ from botdata import BotData, UserAuth, BotChannel, VersionHistory
from constants import BOTUSER_SCOPES, CHANNEL_SCOPES, SCHEMA_VERSIONS from constants import BOTUSER_SCOPES, CHANNEL_SCOPES, SCHEMA_VERSIONS
from .config import Conf from .config import Conf
from .context import Context
if TYPE_CHECKING:
from modules.profiles.profiles.twitch.component import ProfilesComponent
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -24,6 +28,7 @@ class Bot(commands.Bot):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# Whether we should do eventsub via webhooks or websockets
if config.bot.get('eventsub_secret', None): if config.bot.get('eventsub_secret', None):
self.using_webhooks = True self.using_webhooks = True
else: else:
@@ -36,6 +41,28 @@ class Bot(commands.Bot):
self.joined: dict[str, BotChannel] = {} self.joined: dict[str, BotChannel] = {}
# Make the type checker happy about fetching components by name
# TODO: Move to stubs
@property
def profiles(self):
return self.get_component('ProfilesComponent')
@overload
def get_component(self, name: Literal['ProfilesComponent']) -> 'ProfilesComponent':
...
@overload
def get_component(self, name: str) -> Optional[commands.Component]:
...
def get_component(self, name: str) -> Optional[commands.Component]:
return super().get_component(name)
def get_context(self, payload, *, cls: Any = None) -> Context:
cls = cls or Context
return cls(payload, bot=self)
async def event_ready(self): async def event_ready(self):
# logger.info(f"Logged in as {self.nick}. User id is {self.user_id}") # logger.info(f"Logged in as {self.nick}. User id is {self.user_id}")
logger.info("Logged in as %s", self.bot_id) logger.info("Logged in as %s", self.bot_id)
@@ -151,17 +178,15 @@ class Bot(commands.Bot):
# Save the token and scopes to data # Save the token and scopes to data
# Wrap this in a transaction so if it fails halfway we rollback correctly # Wrap this in a transaction so if it fails halfway we rollback correctly
async with self.dbconn.connection() as conn: # TODO
self.dbconn.conn = conn row = await UserAuth.fetch_or_create(userid, token=token, refresh_token=refresh)
async with conn.transaction(): if row.token != token or row.refresh_token != refresh:
row = await UserAuth.fetch_or_create(userid, token=token, refresh_token=refresh) await row.update(token=token, refresh_token=refresh)
if row.token != token or row.refresh_token != refresh: await self.data.user_auth_scopes.delete_where(userid=userid)
await row.update(token=token, refresh_token=refresh) await self.data.user_auth_scopes.insert_many(
await self.data.user_auth_scopes.delete_where(userid=userid) ('userid', 'scope'),
await self.data.user_auth_scopes.insert_many( *((userid, scope) for scope in new_scopes)
('userid', 'scope'), )
*((userid, scope) for scope in new_scopes)
)
logger.info("Updated auth token for user '%s' with scopes: %s", resp.user_id, ', '.join(new_scopes)) logger.info("Updated auth token for user '%s' with scopes: %s", resp.user_id, ', '.join(new_scopes))
return resp return resp

4
src/meta/context.py Normal file
View File

@@ -0,0 +1,4 @@
from twitchio.ext import commands as cmds
class Context(cmds.Context):
...

View File

@@ -14,6 +14,7 @@ import aiohttp
from .config import conf from .config import conf
from utils.ratelimits import Bucket, BucketOverFull, BucketFull from utils.ratelimits import Bucket, BucketOverFull, BucketFull
from utils.lib import utc_now
log_logger = logging.getLogger(__name__) log_logger = logging.getLogger(__name__)
@@ -365,6 +366,8 @@ class WebHookHandler(logging.StreamHandler):
await self._send(batched) await self._send(batched)
async def _send(self, message, as_file=False): async def _send(self, message, as_file=False):
import discord
from discord import File
try: try:
self.bucket.request() self.bucket.request()
except BucketOverFull: except BucketOverFull:

68
src/meta/sockets.py Normal file
View File

@@ -0,0 +1,68 @@
from abc import ABC
from collections import defaultdict
import json
from typing import Any
import logging
import websockets
logger = logging.getLogger(__name__)
class Channel(ABC):
"""
A channel is a stateful connection handler for a group of connected websockets.
"""
name = "Root Channel"
def __init__(self, **kwargs):
self.connections = set()
@property
def empty(self):
return not self.connections
async def on_connection(self, websocket: websockets.WebSocketServerProtocol, event: dict[str, Any]):
logger.info(f"Channel '{self.name}' attached new connection {websocket=} {event=}")
self.connections.add(websocket)
async def del_connection(self, websocket: websockets.WebSocketServerProtocol):
logger.info(f"Channel '{self.name}' dropped connection {websocket=}")
self.connections.discard(websocket)
async def handle_message(self, websocket: websockets.WebSocketServerProtocol, message):
raise NotImplementedError
async def send_event(self, event, websocket=None):
message = json.dumps(event)
if not websocket:
for ws in self.connections:
await ws.send(message)
else:
await websocket.send(message)
channels = {}
def register_channel(name, channel: Channel):
channels[name] = channel
async def root_handler(websocket: websockets.WebSocketServerProtocol):
message = await websocket.recv()
event = json.loads(message)
if event.get('type', None) != 'init':
raise ValueError("Received Websocket connection with no init.")
if (channel_name := event.get('channel', None)) not in channels:
raise ValueError(f"Received Init for unhandled channel {channel_name=}")
channel = channels[channel_name]
try:
await channel.on_connection(websocket, event)
async for message in websocket:
await channel.handle_message(websocket, message)
finally:
await channel.del_connection(websocket)