rewrite: Implement shard IPC server.
This commit is contained in:
200
bot/meta/ipc/client.py
Normal file
200
bot/meta/ipc/client.py
Normal file
@@ -0,0 +1,200 @@
|
||||
from typing import Optional
|
||||
import asyncio
|
||||
import logging
|
||||
import pickle
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppClient:
|
||||
routes = {} # route_name -> Callable[Any, Awaitable[Any]]
|
||||
|
||||
def __init__(self, appid, client_address, server_address):
|
||||
self.appid = appid
|
||||
self.address = client_address
|
||||
self.server_address = server_address
|
||||
|
||||
self.peers = {appid: client_address} # appid -> address
|
||||
|
||||
self._listener: Optional[asyncio.Server] = None # Local client server
|
||||
self._server = None # Connection to the registry server
|
||||
|
||||
self.register_route('new_peer')(self.new_peer)
|
||||
self.register_route('drop_peer')(self.drop_peer)
|
||||
self.register_route('peer_list')(self.peer_list)
|
||||
|
||||
def register_route(self, name=None):
|
||||
def wrapper(coro):
|
||||
route = AppRoute(coro, name)
|
||||
self.routes[route.name] = route
|
||||
return route
|
||||
return wrapper
|
||||
|
||||
async def server_connection(self):
|
||||
"""Establish a connection to the registry server"""
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection(**self.server_address)
|
||||
|
||||
payload = ('connect', (), {'appid': self.appid, 'address': self.address})
|
||||
writer.write(pickle.dumps(payload))
|
||||
writer.write(b'\n')
|
||||
await writer.drain()
|
||||
|
||||
data = await reader.readline()
|
||||
peers = pickle.loads(data)
|
||||
self.peers = peers
|
||||
self._server = (reader, writer)
|
||||
except Exception:
|
||||
logger.exception("Could not connect to registry server. Trying again in 30 seconds.")
|
||||
await asyncio.sleep(30)
|
||||
asyncio.create_task(self.server_connection())
|
||||
else:
|
||||
logger.info("Connected to the registry server, launching keepalive.")
|
||||
asyncio.create_task(self._server_keepalive())
|
||||
|
||||
async def _server_keepalive(self):
|
||||
if self._server is None:
|
||||
raise ValueError("Cannot keepalive non-existent server!")
|
||||
reader, write = self._server
|
||||
try:
|
||||
await reader.read()
|
||||
except Exception:
|
||||
logger.exception("Lost connection to address server. Reconnecting...")
|
||||
else:
|
||||
# Connection ended or broke
|
||||
logger.info("Lost connection to address server. Reconnecting...")
|
||||
await asyncio.sleep(30)
|
||||
asyncio.create_task(self.server_connection())
|
||||
|
||||
async def new_peer(self, appid, address):
|
||||
self.peers[appid] = address
|
||||
|
||||
async def peer_list(self, peers):
|
||||
self.peers = peers
|
||||
|
||||
async def drop_peer(self, appid):
|
||||
self.peers.pop(appid, None)
|
||||
|
||||
async def close(self):
|
||||
# Close connection to the server
|
||||
# TODO
|
||||
...
|
||||
|
||||
async def request(self, appid, payload: 'AppPayload'):
|
||||
try:
|
||||
if appid not in self.peers:
|
||||
raise ValueError(f"Peer '{appid}' not found.")
|
||||
logger.debug(f"Sending request to app '{appid}' with payload {payload}")
|
||||
|
||||
address = self.peers[appid]
|
||||
reader, writer = await asyncio.open_connection(**address)
|
||||
|
||||
writer.write(payload.encoded())
|
||||
await writer.drain()
|
||||
writer.write_eof()
|
||||
result = await reader.read()
|
||||
writer.close()
|
||||
decoded = payload.route.decode(result)
|
||||
return decoded
|
||||
except Exception:
|
||||
logging.exception(f"Failed to send request to {appid}'")
|
||||
return None
|
||||
|
||||
async def requestall(self, payload):
|
||||
results = await asyncio.gather(*(self.request(appid, payload) for appid in self.peers))
|
||||
return dict(zip(self.peers.keys(), results))
|
||||
|
||||
async def handle_request(self, reader, writer):
|
||||
data = await reader.read()
|
||||
loaded = pickle.loads(data)
|
||||
route, args, kwargs = loaded
|
||||
|
||||
logger.debug(f"AppClient {self.appid} handling request on route '{route}' with args {args} and kwargs {kwargs}")
|
||||
|
||||
if route in self.routes:
|
||||
try:
|
||||
await self.routes[route].run((reader, writer), args, kwargs)
|
||||
except Exception:
|
||||
logger.exception(f"Fatal exception during route '{route}'. This should never happen!")
|
||||
else:
|
||||
logger.warning(f"Appclient '{self.appid}' recieved unknown route {route}. Ignoring.")
|
||||
writer.write_eof()
|
||||
|
||||
async def connect(self):
|
||||
"""
|
||||
Start the local peer server.
|
||||
Connect to the address server.
|
||||
"""
|
||||
# Start the client server
|
||||
self._listener = await asyncio.start_server(self.handle_request, **self.address, start_serving=True)
|
||||
|
||||
logger.info(f"Serving on {self.address}")
|
||||
await self.server_connection()
|
||||
|
||||
|
||||
class AppPayload:
|
||||
__slots__ = ('route', 'args', 'kwargs')
|
||||
|
||||
def __init__(self, route, *args, **kwargs):
|
||||
self.route = route
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __await__(self):
|
||||
return self.route.execute(*self.args, **self.kwargs).__await__()
|
||||
|
||||
def encoded(self):
|
||||
return pickle.dumps((self.route.name, self.args, self.kwargs))
|
||||
|
||||
|
||||
class AppRoute:
|
||||
__slots__ = ('func', 'name')
|
||||
|
||||
def __init__(self, func, name=None):
|
||||
self.func = func
|
||||
self.name = name or func.__name__
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return AppPayload(self, *args, **kwargs)
|
||||
|
||||
def encode(self, output):
|
||||
return pickle.dumps(output)
|
||||
|
||||
def decode(self, encoded):
|
||||
# TODO: Handle exceptions here somehow
|
||||
if len(encoded) > 0:
|
||||
return pickle.loads(encoded)
|
||||
else:
|
||||
return ''
|
||||
|
||||
def encoder(self, func):
|
||||
self.encode = func
|
||||
|
||||
def decoder(self, func):
|
||||
self.decode = func
|
||||
|
||||
async def execute(self, *args, **kwargs):
|
||||
"""
|
||||
Execute the underlying function, with the given arguments.
|
||||
"""
|
||||
return await self.func(*args, **kwargs)
|
||||
|
||||
async def run(self, connection, args, kwargs):
|
||||
"""
|
||||
Run the route, with the given arguments, using the given connection.
|
||||
"""
|
||||
# TODO: ContextVar here for logging? Or in handle_request?
|
||||
# Get encoded result
|
||||
# TODO: handle exceptions in the execution process
|
||||
try:
|
||||
result = await self.execute(*args, **kwargs)
|
||||
payload = self.encode(result)
|
||||
except Exception:
|
||||
logger.exception(f"Exception occured running route '{self.name}' with args: {args} and kwargs: {kwargs}")
|
||||
payload = b''
|
||||
_, writer = connection
|
||||
writer.write(payload)
|
||||
await writer.drain()
|
||||
writer.close()
|
||||
Reference in New Issue
Block a user