diff --git a/requirements.txt b/requirements.txt index 48f9b57..77b57d1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ twitchio psycopg[pool] cachetools discord.py +websockets diff --git a/src/bot.py b/src/bot.py index ca898e5..6d46439 100644 --- a/src/bot.py +++ b/src/bot.py @@ -4,7 +4,7 @@ import websockets from twitchio.web import AiohttpAdapter -from meta import Bot, conf, setup_main_logger, args +from meta import Bot, conf, setup_main_logger, args, sockets from data import Database from modules import twitch_setup @@ -30,10 +30,11 @@ async def main(): setup=twitch_setup, ) - try: - await bot.start() - finally: - await bot.close() + async with websockets.serve(sockets.root_handler, '', conf.wserver.getint('port')): + try: + await bot.start() + finally: + await bot.close() def _main(): diff --git a/src/meta/sockets.py b/src/meta/sockets.py new file mode 100644 index 0000000..c89c538 --- /dev/null +++ b/src/meta/sockets.py @@ -0,0 +1,68 @@ +from abc import ABC +from collections import defaultdict +import json +from typing import Any +import logging + +import websockets + + +logger = logging.getLogger(__name__) + + + +class Channel(ABC): + """ + A channel is a stateful connection handler for a group of connected websockets. + """ + name = "Root Channel" + + def __init__(self, **kwargs): + self.connections = set() + + @property + def empty(self): + return not self.connections + + async def on_connection(self, websocket: websockets.WebSocketServerProtocol, event: dict[str, Any]): + logger.info(f"Channel '{self.name}' attached new connection {websocket=} {event=}") + self.connections.add(websocket) + + async def del_connection(self, websocket: websockets.WebSocketServerProtocol): + logger.info(f"Channel '{self.name}' dropped connection {websocket=}") + self.connections.discard(websocket) + + async def handle_message(self, websocket: websockets.WebSocketServerProtocol, message): + raise NotImplementedError + + async def send_event(self, event, websocket=None): + message = json.dumps(event) + if not websocket: + for ws in self.connections: + await ws.send(message) + else: + await websocket.send(message) + +channels = {} + +def register_channel(name, channel: Channel): + channels[name] = channel + + +async def root_handler(websocket: websockets.WebSocketServerProtocol): + message = await websocket.recv() + event = json.loads(message) + + if event.get('type', None) != 'init': + raise ValueError("Received Websocket connection with no init.") + + if (channel_name := event.get('channel', None)) not in channels: + raise ValueError(f"Received Init for unhandled channel {channel_name=}") + channel = channels[channel_name] + + try: + await channel.on_connection(websocket, event) + async for message in websocket: + await channel.handle_message(websocket, message) + finally: + await channel.del_connection(websocket)