Compare commits
4 Commits
67fbe1b658
...
2e02f39c29
| Author | SHA1 | Date | |
|---|---|---|---|
| 2e02f39c29 | |||
| a3f1fe1322 | |||
| 3f19ec4a17 | |||
| 85f7465283 |
18
src/bot.py
18
src/bot.py
@@ -4,7 +4,7 @@ import websockets
|
|||||||
|
|
||||||
from twitchio.web import AiohttpAdapter
|
from twitchio.web import AiohttpAdapter
|
||||||
|
|
||||||
from meta import Bot, conf, setup_main_logger, args
|
from meta import Bot, conf, setup_main_logger, args, sockets
|
||||||
from data import Database
|
from data import Database
|
||||||
|
|
||||||
from modules import twitch_setup
|
from modules import twitch_setup
|
||||||
@@ -12,11 +12,16 @@ from modules import twitch_setup
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ProxyAiohttpAdapter(AiohttpAdapter):
|
||||||
|
def _find_redirect(self, request):
|
||||||
|
return self.redirect_url
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
db = Database(conf.data['args'])
|
db = Database(conf.data['args'])
|
||||||
|
|
||||||
async with db.open():
|
async with db.open():
|
||||||
adapter = AiohttpAdapter(
|
adapter = ProxyAiohttpAdapter(
|
||||||
host=conf.bot.get('wshost', None),
|
host=conf.bot.get('wshost', None),
|
||||||
port=conf.bot.getint('wsport', None),
|
port=conf.bot.getint('wsport', None),
|
||||||
domain=conf.bot.get('wsdomain', None),
|
domain=conf.bot.get('wsdomain', None),
|
||||||
@@ -30,10 +35,11 @@ async def main():
|
|||||||
setup=twitch_setup,
|
setup=twitch_setup,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
async with websockets.serve(sockets.root_handler, '', conf.wserver.getint('port')):
|
||||||
await bot.start()
|
try:
|
||||||
finally:
|
await bot.start()
|
||||||
await bot.close()
|
finally:
|
||||||
|
await bot.close()
|
||||||
|
|
||||||
|
|
||||||
def _main():
|
def _main():
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from .args import args
|
from .args import args
|
||||||
from .bot import Bot
|
from .bot import Bot
|
||||||
|
from .context import Context
|
||||||
from .config import Conf, conf
|
from .config import Conf, conf
|
||||||
from .logger import setup_main_logger, log_context, log_action_stack, log_app, set_logging_context, logging_context, with_log_ctx, persist_task
|
from .logger import setup_main_logger, log_context, log_action_stack, log_app, set_logging_context, logging_context, with_log_ctx, persist_task
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import TYPE_CHECKING, Any, Literal, Optional, overload
|
||||||
|
|
||||||
from twitchio.authentication import UserTokenPayload
|
from twitchio.authentication import UserTokenPayload
|
||||||
from twitchio.ext import commands
|
from twitchio.ext import commands
|
||||||
@@ -10,6 +10,10 @@ from botdata import BotData, UserAuth, BotChannel, VersionHistory
|
|||||||
from constants import BOTUSER_SCOPES, CHANNEL_SCOPES, SCHEMA_VERSIONS
|
from constants import BOTUSER_SCOPES, CHANNEL_SCOPES, SCHEMA_VERSIONS
|
||||||
|
|
||||||
from .config import Conf
|
from .config import Conf
|
||||||
|
from .context import Context
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from modules.profiles.profiles.twitch.component import ProfilesComponent
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -24,6 +28,7 @@ class Bot(commands.Bot):
|
|||||||
|
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
# Whether we should do eventsub via webhooks or websockets
|
||||||
if config.bot.get('eventsub_secret', None):
|
if config.bot.get('eventsub_secret', None):
|
||||||
self.using_webhooks = True
|
self.using_webhooks = True
|
||||||
else:
|
else:
|
||||||
@@ -36,6 +41,28 @@ class Bot(commands.Bot):
|
|||||||
|
|
||||||
self.joined: dict[str, BotChannel] = {}
|
self.joined: dict[str, BotChannel] = {}
|
||||||
|
|
||||||
|
# Make the type checker happy about fetching components by name
|
||||||
|
# TODO: Move to stubs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def profiles(self):
|
||||||
|
return self.get_component('ProfilesComponent')
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_component(self, name: Literal['ProfilesComponent']) -> 'ProfilesComponent':
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_component(self, name: str) -> Optional[commands.Component]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_component(self, name: str) -> Optional[commands.Component]:
|
||||||
|
return super().get_component(name)
|
||||||
|
|
||||||
|
def get_context(self, payload, *, cls: Any = None) -> Context:
|
||||||
|
cls = cls or Context
|
||||||
|
return cls(payload, bot=self)
|
||||||
|
|
||||||
async def event_ready(self):
|
async def event_ready(self):
|
||||||
# logger.info(f"Logged in as {self.nick}. User id is {self.user_id}")
|
# logger.info(f"Logged in as {self.nick}. User id is {self.user_id}")
|
||||||
logger.info("Logged in as %s", self.bot_id)
|
logger.info("Logged in as %s", self.bot_id)
|
||||||
@@ -151,17 +178,15 @@ class Bot(commands.Bot):
|
|||||||
|
|
||||||
# Save the token and scopes to data
|
# Save the token and scopes to data
|
||||||
# Wrap this in a transaction so if it fails halfway we rollback correctly
|
# Wrap this in a transaction so if it fails halfway we rollback correctly
|
||||||
async with self.dbconn.connection() as conn:
|
# TODO
|
||||||
self.dbconn.conn = conn
|
row = await UserAuth.fetch_or_create(userid, token=token, refresh_token=refresh)
|
||||||
async with conn.transaction():
|
if row.token != token or row.refresh_token != refresh:
|
||||||
row = await UserAuth.fetch_or_create(userid, token=token, refresh_token=refresh)
|
await row.update(token=token, refresh_token=refresh)
|
||||||
if row.token != token or row.refresh_token != refresh:
|
await self.data.user_auth_scopes.delete_where(userid=userid)
|
||||||
await row.update(token=token, refresh_token=refresh)
|
await self.data.user_auth_scopes.insert_many(
|
||||||
await self.data.user_auth_scopes.delete_where(userid=userid)
|
('userid', 'scope'),
|
||||||
await self.data.user_auth_scopes.insert_many(
|
*((userid, scope) for scope in new_scopes)
|
||||||
('userid', 'scope'),
|
)
|
||||||
*((userid, scope) for scope in new_scopes)
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("Updated auth token for user '%s' with scopes: %s", resp.user_id, ', '.join(new_scopes))
|
logger.info("Updated auth token for user '%s' with scopes: %s", resp.user_id, ', '.join(new_scopes))
|
||||||
return resp
|
return resp
|
||||||
|
|||||||
4
src/meta/context.py
Normal file
4
src/meta/context.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from twitchio.ext import commands as cmds
|
||||||
|
|
||||||
|
class Context(cmds.Context):
|
||||||
|
...
|
||||||
@@ -14,6 +14,7 @@ import aiohttp
|
|||||||
|
|
||||||
from .config import conf
|
from .config import conf
|
||||||
from utils.ratelimits import Bucket, BucketOverFull, BucketFull
|
from utils.ratelimits import Bucket, BucketOverFull, BucketFull
|
||||||
|
from utils.lib import utc_now
|
||||||
|
|
||||||
|
|
||||||
log_logger = logging.getLogger(__name__)
|
log_logger = logging.getLogger(__name__)
|
||||||
@@ -365,6 +366,8 @@ class WebHookHandler(logging.StreamHandler):
|
|||||||
await self._send(batched)
|
await self._send(batched)
|
||||||
|
|
||||||
async def _send(self, message, as_file=False):
|
async def _send(self, message, as_file=False):
|
||||||
|
import discord
|
||||||
|
from discord import File
|
||||||
try:
|
try:
|
||||||
self.bucket.request()
|
self.bucket.request()
|
||||||
except BucketOverFull:
|
except BucketOverFull:
|
||||||
|
|||||||
68
src/meta/sockets.py
Normal file
68
src/meta/sockets.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
from abc import ABC
|
||||||
|
from collections import defaultdict
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Channel(ABC):
|
||||||
|
"""
|
||||||
|
A channel is a stateful connection handler for a group of connected websockets.
|
||||||
|
"""
|
||||||
|
name = "Root Channel"
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self.connections = set()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def empty(self):
|
||||||
|
return not self.connections
|
||||||
|
|
||||||
|
async def on_connection(self, websocket: websockets.WebSocketServerProtocol, event: dict[str, Any]):
|
||||||
|
logger.info(f"Channel '{self.name}' attached new connection {websocket=} {event=}")
|
||||||
|
self.connections.add(websocket)
|
||||||
|
|
||||||
|
async def del_connection(self, websocket: websockets.WebSocketServerProtocol):
|
||||||
|
logger.info(f"Channel '{self.name}' dropped connection {websocket=}")
|
||||||
|
self.connections.discard(websocket)
|
||||||
|
|
||||||
|
async def handle_message(self, websocket: websockets.WebSocketServerProtocol, message):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def send_event(self, event, websocket=None):
|
||||||
|
message = json.dumps(event)
|
||||||
|
if not websocket:
|
||||||
|
for ws in self.connections:
|
||||||
|
await ws.send(message)
|
||||||
|
else:
|
||||||
|
await websocket.send(message)
|
||||||
|
|
||||||
|
channels = {}
|
||||||
|
|
||||||
|
def register_channel(name, channel: Channel):
|
||||||
|
channels[name] = channel
|
||||||
|
|
||||||
|
|
||||||
|
async def root_handler(websocket: websockets.WebSocketServerProtocol):
|
||||||
|
message = await websocket.recv()
|
||||||
|
event = json.loads(message)
|
||||||
|
|
||||||
|
if event.get('type', None) != 'init':
|
||||||
|
raise ValueError("Received Websocket connection with no init.")
|
||||||
|
|
||||||
|
if (channel_name := event.get('channel', None)) not in channels:
|
||||||
|
raise ValueError(f"Received Init for unhandled channel {channel_name=}")
|
||||||
|
channel = channels[channel_name]
|
||||||
|
|
||||||
|
try:
|
||||||
|
await channel.on_connection(websocket, event)
|
||||||
|
async for message in websocket:
|
||||||
|
await channel.handle_message(websocket, message)
|
||||||
|
finally:
|
||||||
|
await channel.del_connection(websocket)
|
||||||
Reference in New Issue
Block a user