diff --git a/bot/data/connection.py b/bot/data/connection.py index 1f35eda2..20baf0ec 100644 --- a/bot/data/connection.py +++ b/bot/data/connection.py @@ -38,3 +38,10 @@ with conn: log("Established connection.", "DB_INIT") + + +def reset_connection(): + log("Re-establishing connection.", "DB_INIT", level=logging.DEBUG) + global conn + conn = psy.connect(conf.bot['database'], cursor_factory=DictLoggingCursor) + log("Re-established connection.", "DB_INIT") diff --git a/bot/data/interfaces.py b/bot/data/interfaces.py index 266b5a03..42810e72 100644 --- a/bot/data/interfaces.py +++ b/bot/data/interfaces.py @@ -1,9 +1,14 @@ from __future__ import annotations +import logging +import traceback import contextlib from cachetools import LRUCache from typing import Mapping +import psycopg2 +import asyncio +from meta import log, client from utils.lib import DotDict from .connection import conn @@ -14,6 +19,25 @@ from .queries import insert, insert_many, select_where, update_where, upsert, de tables: Mapping[str, Table] = DotDict() +def _connection_guard(func): + """ + Query decorator that performs a client shutdown when the database isn't responding. + """ + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except (psycopg2.OperationalError, psycopg2.InterfaceError): + log("Critical error performing database query. Shutting down. " + "Exception traceback follows.\n{}".format( + traceback.format_exc() + ), + context="DATABASE_QUERY", + level=logging.ERROR) + asyncio.create_task(client.close()) + raise Exception("Critical error, database connection closed. Restarting client.") + return wrapper + + class Table: """ Transparent interface to a single table structure in the database. @@ -27,6 +51,7 @@ class Table: self.name = name tables[attach_as or name] = self + @_connection_guard def select_where(self, *args, **kwargs): with self.conn: return select_where(self.name, *args, **kwargs) @@ -35,26 +60,32 @@ class Table: rows = self.select_where(*args, **kwargs) return rows[0] if rows else None + @_connection_guard def update_where(self, *args, **kwargs): with self.conn: return update_where(self.name, *args, **kwargs) + @_connection_guard def delete_where(self, *args, **kwargs): with self.conn: return delete_where(self.name, *args, **kwargs) + @_connection_guard def insert(self, *args, **kwargs): with self.conn: return insert(self.name, *args, **kwargs) + @_connection_guard def insert_many(self, *args, **kwargs): with self.conn: return insert_many(self.name, *args, **kwargs) + @_connection_guard def update_many(self, *args, **kwargs): with self.conn: return update_many(self.name, *args, **kwargs) + @_connection_guard def upsert(self, *args, **kwargs): with self.conn: return upsert(self.name, *args, **kwargs)