import contextlib
import logging
from enum import Enum
from itertools import chain

import psycopg2 as psy
from cachetools import LRUCache

from meta import log, conf
from constants import DATA_VERSION

from .custom_cursor import DictLoggingCursor


# Set up the module-wide database connection (done at import time).
log("Establishing connection.", "DB_INIT", level=logging.DEBUG)
conn = psy.connect(conf.bot['database'], cursor_factory=DictLoggingCursor)

# Check the database version matches the version this code requires.
# `closing` guarantees the cursor is released even if the check raises.
with conn, contextlib.closing(conn.cursor()) as cursor:
    log("Checking db version.", "DB_INIT")
    # Get the most recent entry in the version table, compare against the desired version
    cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
    version_row = cursor.fetchone()
    if version_row is None:
        # An empty version table previously crashed with an opaque TypeError on unpacking.
        raise Exception(
            ("Database has no VersionHistory entries, required version is {}. "
             "Please migrate database.").format(DATA_VERSION)
        )
    current_version, _, _ = version_row
    if current_version != DATA_VERSION:
        raise Exception(
            ("Database version is {}, required version is {}. "
             "Please migrate database.").format(current_version, DATA_VERSION)
        )
log("Established connection.", "DB_INIT")


# --------------- Data Interface Classes ---------------
class Table:
    """
    Transparent interface to a single table structure in the database.
    Contains standard methods to access the table.
    Intended to be subclassed to provide more derivative access for specific tables.

    Every method delegates to the module-level query builder of the same name,
    passing this table's name, and runs inside ``with self.conn:`` (psycopg2
    commits on successful exit and rolls back if the query raises).
    """
    conn = conn

    def __init__(self, name):
        self.name = name

    def select_where(self, *args, **kwargs):
        with self.conn:
            return select_where(self.name, *args, **kwargs)

    def select_one_where(self, *args, **kwargs):
        """Return the first row matching the conditions, or None if there is none."""
        with self.conn:
            rows = self.select_where(*args, **kwargs)
            return rows[0] if rows else None

    def update_where(self, *args, **kwargs):
        with self.conn:
            return update_where(self.name, *args, **kwargs)

    def delete_where(self, *args, **kwargs):
        with self.conn:
            return delete_where(self.name, *args, **kwargs)

    def insert(self, *args, **kwargs):
        with self.conn:
            return insert(self.name, *args, **kwargs)

    def insert_many(self, *args, **kwargs):
        with self.conn:
            return insert_many(self.name, *args, **kwargs)

    def upsert(self, *args, **kwargs):
        with self.conn:
            return upsert(self.name, *args, **kwargs)


class Row:
    """
    Object wrapper around a single row of a `RowTable`.

    Column values are read as attributes. Assigning to a column attribute writes
    the change to the database immediately, unless a `batch_update()` block is
    active, in which case changes are collected and written when it exits.
    """
    __slots__ = ('table', 'data', '_pending')

    conn = conn

    def __init__(self, table, data, *args, **kwargs):
        # Bypass our own __setattr__, which reads `self.table.columns` before `table` exists.
        super().__setattr__('table', table)
        self.data = data
        self._pending = None

    @property
    def rowid(self):
        """The value of this row's id column."""
        return self.data[self.table.id_col]

    def __repr__(self):
        return "Row[{}]({})".format(
            self.table.name,
            ', '.join("{}={!r}".format(field, getattr(self, field)) for field in self.table.columns)
        )

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails, i.e. for column names.
        if key in self.table.columns:
            # Uncommitted batch values shadow the stored data.
            if self._pending and key in self._pending:
                return self._pending[key]
            else:
                return self.data[key]
        else:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        if key in self.table.columns:
            if self._pending is None:
                self.update(**{key: value})
            else:
                self._pending[key] = value
        else:
            super().__setattr__(key, value)

    @contextlib.contextmanager
    def batch_update(self):
        """
        Collect column assignments made inside the ``with`` block and write them in
        a single UPDATE when the block exits. Note that the write happens in a
        ``finally``, so collected changes are written even if the block raises.

        Yields the dict of pending changes.
        Raises ValueError if a batch update is already active for this row.
        """
        # `is not None`: an active batch with no changes yet is an empty (falsy) dict.
        if self._pending is not None:
            raise ValueError("Nested batch updates for {}!".format(self.__class__.__name__))
        self._pending = {}
        try:
            yield self._pending
        finally:
            # Leave batch mode before writing so a failed write cannot leave the row
            # stuck collecting changes that are never flushed.
            pending, self._pending = self._pending, None
            if pending:
                # An empty SET clause would be invalid SQL, so only write real changes.
                self.update(**pending)

    def _refresh(self):
        """Reload `data` from the database. Raises ValueError if the row was deleted."""
        row = self.table.select_one_where(**{self.table.id_col: self.rowid})
        if not row:
            raise ValueError("Refreshing a {} which no longer exists!".format(type(self).__name__))
        self.data = row

    def update(self, **values):
        """
        Write the given column values to this row and replace `data` with the
        updated row returned by the database.

        Does nothing if no values are given.
        Raises ValueError if the row no longer exists.
        """
        if not values:
            return
        rows = self.table.update_where(values, **{self.table.id_col: self.rowid})
        if not rows:
            raise ValueError("Updating a {} which no longer exists!".format(type(self).__name__))
        self.data = rows[0]

    # NOTE(review): `Row` does not define `_table`, so these helpers only work on a
    # subclass that sets it; `_select_where` also ignores `_extra`.
    @classmethod
    def _select_where(cls, _extra=None, **conditions):
        return select_where(cls._table, **conditions)

    @classmethod
    def _insert(cls, **values):
        return insert(cls._table, **values)

    @classmethod
    def _update_where(cls, values, **conditions):
        return update_where(cls._table, values, **conditions)


class RowTable(Table):
    """
    A `Table` whose rows are exposed as `Row` objects keyed on `id_col`.

    When `use_cache` is set, `Row` objects are kept in a cache (the given `cache`,
    or a new `LRUCache(cache_size)`), and the update/delete/upsert overrides below
    keep cached rows in sync with the returned data.
    """
    __slots__ = (
        'name',
        'columns',
        'id_col',
        'row_cache'
    )

    conn = conn

    def __init__(self, name, columns, id_col, use_cache=True, cache=None, cache_size=1000):
        self.name = name
        self.columns = columns
        self.id_col = id_col
        if use_cache:
            # `is None` rather than `or`: an empty cache is falsy and would otherwise be
            # silently replaced by a fresh one, breaking caller-shared caches.
            self.row_cache = cache if cache is not None else LRUCache(cache_size)
        else:
            self.row_cache = None

    # Extend original Table update methods to modify the cached rows
    def update_where(self, *args, **kwargs):
        data = super().update_where(*args, **kwargs)
        if self.row_cache is not None:
            for data_row in data:
                cached_row = self.row_cache.get(data_row[self.id_col], None)
                if cached_row is not None:
                    cached_row.data = data_row
        return data

    def delete_where(self, *args, **kwargs):
        data = super().delete_where(*args, **kwargs)
        if self.row_cache is not None:
            # Evict every deleted row from the cache.
            for data_row in data:
                self.row_cache.pop(data_row[self.id_col], None)
        return data

    def upsert(self, *args, **kwargs):
        data = super().upsert(*args, **kwargs)
        if self.row_cache is not None:
            cached_row = self.row_cache.get(data[self.id_col], None)
            if cached_row is not None:
                cached_row.data = data
        return data

    # New methods to fetch and create rows
    def _make_rows(self, *data_rows):
        """
        Create or retrieve Row objects for each provided data row.
        If the rows already exist in cache, updates the cached row.
        """
        if self.row_cache is None:
            return [Row(self, data_row) for data_row in data_rows]

        rows = []
        for data_row in data_rows:
            rowid = data_row[self.id_col]
            cached_row = self.row_cache.get(rowid, None)
            if cached_row is not None:
                cached_row.data = data_row
                row = cached_row
            else:
                row = Row(self, data_row)
                self.row_cache[rowid] = row
            rows.append(row)
        return rows

    def create_row(self, *args, **kwargs):
        """Insert a new row and return it as a (cached) Row object."""
        data = self.insert(*args, **kwargs)
        return self._make_rows(data)[0]

    def fetch_rows_where(self, *args, **kwargs):
        """Select matching rows and return them as (cached) Row objects."""
        # TODO: Handle list of rowids here?
        data = self.select_where(*args, **kwargs)
        return self._make_rows(*data)

    def fetch(self, rowid):
        """
        Fetch the row with the given id, retrieving from cache where possible.
        Returns None if no such row exists.
        """
        row = self.row_cache.get(rowid, None) if self.row_cache is not None else None
        if row is None:
            rows = self.fetch_rows_where(**{self.id_col: rowid})
            row = rows[0] if rows else None
        return row

    def fetch_or_create(self, rowid=None, **kwargs):
        """
        Helper method to fetch a row with the given id or fields,
        or create it if it doesn't exist.

        When `rowid` is given, only the id is used for lookup; `kwargs` are then
        used (together with the id) only when creating the row.
        """
        if rowid is not None:
            row = self.fetch(rowid)
        else:
            data = self.select_where(**kwargs)
            row = self._make_rows(data[0])[0] if data else None

        if row is None:
            creation_kwargs = dict(kwargs)
            if rowid is not None:
                creation_kwargs[self.id_col] = rowid
            row = self.create_row(**creation_kwargs)
        return row


# --------------- Query Builders ---------------
# NOTE: table and column names are interpolated directly into the SQL; only
# values are parameterised. Never pass untrusted input as a table/column name.
def _execute(query, values, cursor=None, fetch_one=False):
    """
    Execute `query` with the parameters `values` and return its result.

    Uses `cursor` when given; otherwise opens a new cursor on the module
    connection and closes it once the result has been fetched.
    Returns one row (or None) when `fetch_one` is set, else the list of all rows.
    """
    if cursor is not None:
        cursor.execute(query, values)
        return cursor.fetchone() if fetch_one else cursor.fetchall()
    with contextlib.closing(conn.cursor()) as own_cursor:
        own_cursor.execute(query, values)
        return own_cursor.fetchone() if fetch_one else own_cursor.fetchall()


def select_where(table, select_columns=None, cursor=None, _extra='', **conditions):
    """
    Select rows from the given table matching the conditions.

    `select_columns` restricts the returned columns (default: all); `_extra` is
    raw SQL appended after the WHERE clause. Returns the list of matching rows.
    """
    criteria, criteria_values = _format_conditions(conditions)
    col_str = _format_selectkeys(select_columns)
    where_str = "WHERE {}".format(criteria) if conditions else ""
    return _execute(
        'SELECT {} FROM {} {} {}'.format(col_str, table, where_str, _extra),
        criteria_values,
        cursor,
    )


def update_where(table, valuedict, cursor=None, **conditions):
    """
    Update rows in the given table matching the conditions.
    Returns the list of updated rows.
    """
    key_str, key_values = _format_updatestr(valuedict)
    criteria, criteria_values = _format_conditions(conditions)
    where_str = "WHERE {}".format(criteria) if conditions else ""
    return _execute(
        'UPDATE {} SET {} {} RETURNING *'.format(table, key_str, where_str),
        tuple((*key_values, *criteria_values)),
        cursor,
    )


def delete_where(table, cursor=None, **conditions):
    """
    Delete rows in the given table matching the conditions.
    Returns the list of deleted rows.

    Raises ValueError if no conditions are given (an unconditional delete is
    almost certainly a mistake, and previously produced invalid SQL).
    """
    if not conditions:
        raise ValueError("delete_where requires at least one condition.")
    criteria, criteria_values = _format_conditions(conditions)
    # RETURNING * is required: the rows are fetched and callers use them
    # (e.g. RowTable evicts them from its cache).
    return _execute(
        'DELETE FROM {} WHERE {} RETURNING *'.format(table, criteria),
        criteria_values,
        cursor,
    )


def insert(table, cursor=None, allow_replace=False, **values):
    """
    Insert the given values into the table and return the inserted row.
    With no values, a row of column defaults is inserted.
    """
    # NOTE(review): `REPLACE INTO` is not valid PostgreSQL (sqlite/MySQL leftover);
    # `allow_replace=True` will fail at the database. Consider `upsert` instead.
    action = 'REPLACE' if allow_replace else 'INSERT'
    if values:
        keys, values = zip(*values.items())
        key_str = _format_insertkeys(keys)
        value_str, values = _format_insertvalues(values)
        target = '{} VALUES {}'.format(key_str, value_str)
    else:
        target = 'DEFAULT VALUES'
        values = ()
    return _execute(
        '{} INTO {} {} RETURNING *'.format(action, table, target),
        values,
        cursor,
        fetch_one=True,
    )


def insert_many(table, *value_tuples, insert_keys=None, cursor=None):
    """
    Insert all the given value tuples into the table in one statement.
    Returns the list of inserted rows (empty if no tuples were given).
    """
    if not value_tuples:
        return []
    key_str = _format_insertkeys(insert_keys)
    value_strs, value_tuples = zip(*(_format_insertvalues(value_tuple) for value_tuple in value_tuples))
    value_str = ", ".join(value_strs)
    values = tuple(chain.from_iterable(value_tuples))
    return _execute(
        'INSERT INTO {} {} VALUES {} RETURNING *'.format(table, key_str, value_str),
        values,
        cursor,
    )


def upsert(table, constraint, cursor=None, **values):
    """
    Insert or on conflict update.

    `constraint` is a column name or an iterable of column names forming the
    conflict target. On conflict, all given values are written to the existing row.
    Returns the inserted or updated row.
    """
    valuedict = values
    keys, values = zip(*values.items())
    key_str = _format_insertkeys(keys)
    value_str, values = _format_insertvalues(values)
    update_key_str, update_key_values = _format_updatestr(valuedict)

    if not isinstance(constraint, str):
        constraint = ", ".join(constraint)

    return _execute(
        'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format(
            table, key_str, value_str, constraint, update_key_str
        ),
        tuple((*values, *update_key_values)),
        cursor,
        fetch_one=True,
    )


# --------------- Query Formatting Tools ---------------
# Replace char used by the connection for query formatting
_replace_char: str = '%s'


class fieldConstants(Enum):
    """
    A collection of database field constants to use for selection conditions.
    """
    NULL = "IS NULL"
    NOTNULL = "IS NOT NULL"


def _format_conditions(conditions):
    """
    Formats a dictionary of conditions into a string suitable for 'WHERE' clauses.

    A list/tuple value becomes `key IN (...)` (an empty one matches nothing),
    a `fieldConstants` member becomes `key IS [NOT] NULL`, and anything else
    becomes `key = value`. Conditions are joined with AND.
    Returns (condition_string, parameter_values).
    """
    if not conditions:
        return ("", tuple())

    values = []
    conditional_strings = []
    for key, item in conditions.items():
        if isinstance(item, (list, tuple)):
            if item:
                conditional_strings.append("{} IN ({})".format(key, ", ".join([_replace_char] * len(item))))
                values.extend(item)
            else:
                # `IN ()` is a syntax error in PostgreSQL; an empty set matches no rows.
                conditional_strings.append("FALSE")
        elif isinstance(item, fieldConstants):
            conditional_strings.append("{} {}".format(key, item.value))
        else:
            conditional_strings.append("{}={}".format(key, _replace_char))
            values.append(item)

    return (' AND '.join(conditional_strings), values)


def _format_selectkeys(keys):
    """
    Formats a list of keys into a string suitable for `SELECT` (`*` when empty).
    """
    if not keys:
        return "*"
    return ", ".join(keys)


def _format_insertkeys(keys):
    """
    Formats a list of keys into a string suitable for `INSERT` ("" when empty).
    """
    if not keys:
        return ""
    return "({})".format(", ".join(keys))


def _format_insertvalues(values):
    """
    Formats a list of values into a placeholder string suitable for `INSERT`.
    Returns (placeholder_string, values) with `values` passed through unchanged.
    """
    value_str = "({})".format(", ".join(_replace_char for _ in values))
    return (value_str, values)


def _format_updatestr(valuedict):
    """
    Formats a dictionary of keys and values into a string suitable for 'SET' clauses.
    Returns (set_string, values).
    """
    if not valuedict:
        return ("", tuple())

    keys, values = zip(*valuedict.items())
    set_str = ", ".join("{} = {}".format(key, _replace_char) for key in keys)
    return (set_str, values)