Data system refactor and core redesign for public.
Redesigned data and core systems to be public-capable.
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from .data import *
|
||||
# from . import tables
|
||||
# from . import queries
|
||||
from .connection import conn # noqa
|
||||
from .formatters import UpdateValue, UpdateValueAdd # noqa
|
||||
from .interfaces import Table, RowTable, Row, tables # noqa
|
||||
from .queries import insert, insert_many, select_where, update_where, upsert, delete_where # noqa
|
||||
from .conditions import Condition, NOT, Constant, NULL, NOTNULL # noqa
|
||||
|
||||
59
bot/data/conditions.py
Normal file
59
bot/data/conditions.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from .connection import _replace_char
|
||||
|
||||
|
||||
class Condition:
|
||||
"""
|
||||
ABC representing a selection condition.
|
||||
"""
|
||||
__slots__ = ()
|
||||
|
||||
def apply(self, key, values, conditions):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class NOT(Condition):
|
||||
__slots__ = ('value',)
|
||||
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def apply(self, key, values, conditions):
|
||||
item = self.value
|
||||
if isinstance(item, (list, tuple)):
|
||||
if item:
|
||||
conditions.append("{} NOT IN ({})".format(key, ", ".join([_replace_char] * len(item))))
|
||||
values.extend(item)
|
||||
else:
|
||||
raise ValueError("Cannot check an empty iterable!")
|
||||
else:
|
||||
conditions.append("{}!={}".format(key, _replace_char))
|
||||
values.append(item)
|
||||
|
||||
|
||||
class GEQ(Condition):
|
||||
__slots__ = ('value',)
|
||||
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def apply(self, key, values, conditions):
|
||||
item = self.value
|
||||
if isinstance(item, (list, tuple)):
|
||||
raise ValueError("Cannot apply GEQ condition to a list!")
|
||||
else:
|
||||
conditions.append("{} >= {}".format(key, _replace_char))
|
||||
values.append(item)
|
||||
|
||||
|
||||
class Constant(Condition):
|
||||
__slots__ = ('value',)
|
||||
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def apply(self, key, values, conditions):
|
||||
conditions.append("{} {}".format(key, self.value))
|
||||
|
||||
|
||||
NULL = Constant('IS NULL')
|
||||
NOTNULL = Constant('IS NOT NULL')
|
||||
40
bot/data/connection.py
Normal file
40
bot/data/connection.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import logging
|
||||
|
||||
import psycopg2 as psy
|
||||
|
||||
from meta import log, conf
|
||||
from constants import DATA_VERSION
|
||||
from .cursor import DictLoggingCursor
|
||||
|
||||
|
||||
# Set up database connection
|
||||
log("Establishing connection.", "DB_INIT", level=logging.DEBUG)
|
||||
conn = psy.connect(conf.bot['database'], cursor_factory=DictLoggingCursor)
|
||||
|
||||
# Replace char used by the connection for query formatting
|
||||
_replace_char: str = '%s'
|
||||
|
||||
# conn.set_trace_callback(lambda message: log(message, context="DB_CONNECTOR", level=logging.DEBUG))
|
||||
# sq.register_adapter(datetime, lambda dt: dt.timestamp())
|
||||
|
||||
|
||||
# Check the version matches the required version
|
||||
with conn:
|
||||
log("Checking db version.", "DB_INIT")
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get last entry in version table, compare against desired version
|
||||
cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
|
||||
current_version, _, _ = cursor.fetchone()
|
||||
|
||||
if current_version != DATA_VERSION:
|
||||
# Complain
|
||||
raise Exception(
|
||||
("Database version is {}, required version is {}. "
|
||||
"Please migrate database.").format(current_version, DATA_VERSION)
|
||||
)
|
||||
|
||||
cursor.close()
|
||||
|
||||
|
||||
log("Established connection.", "DB_INIT")
|
||||
505
bot/data/data.py
505
bot/data/data.py
@@ -1,505 +0,0 @@
|
||||
import logging
|
||||
import contextlib
|
||||
from itertools import chain
|
||||
from enum import Enum
|
||||
|
||||
import psycopg2 as psy
|
||||
from cachetools import LRUCache
|
||||
|
||||
from utils.lib import DotDict
|
||||
from meta import log, conf
|
||||
from constants import DATA_VERSION
|
||||
from .custom_cursor import DictLoggingCursor
|
||||
|
||||
|
||||
# Set up database connection
|
||||
log("Establishing connection.", "DB_INIT", level=logging.DEBUG)
|
||||
conn = psy.connect(conf.bot['database'], cursor_factory=DictLoggingCursor)
|
||||
|
||||
# conn.set_trace_callback(lambda message: log(message, context="DB_CONNECTOR", level=logging.DEBUG))
|
||||
# sq.register_adapter(datetime, lambda dt: dt.timestamp())
|
||||
|
||||
|
||||
# Check the version matches the required version
|
||||
with conn:
|
||||
log("Checking db version.", "DB_INIT")
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get last entry in version table, compare against desired version
|
||||
cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
|
||||
current_version, _, _ = cursor.fetchone()
|
||||
|
||||
if current_version != DATA_VERSION:
|
||||
# Complain
|
||||
raise Exception(
|
||||
("Database version is {}, required version is {}. "
|
||||
"Please migrate database.").format(current_version, DATA_VERSION)
|
||||
)
|
||||
|
||||
cursor.close()
|
||||
|
||||
|
||||
log("Established connection.", "DB_INIT")
|
||||
|
||||
|
||||
# --------------- Data Interface Classes ---------------
|
||||
class Table:
|
||||
"""
|
||||
Transparent interface to a single table structure in the database.
|
||||
Contains standard methods to access the table.
|
||||
Intended to be subclassed to provide more derivative access for specific tables.
|
||||
"""
|
||||
conn = conn
|
||||
queries = DotDict()
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def select_where(self, *args, **kwargs):
|
||||
with self.conn:
|
||||
return select_where(self.name, *args, **kwargs)
|
||||
|
||||
def select_one_where(self, *args, **kwargs):
|
||||
with self.conn:
|
||||
rows = self.select_where(*args, **kwargs)
|
||||
return rows[0] if rows else None
|
||||
|
||||
def update_where(self, *args, **kwargs):
|
||||
with self.conn:
|
||||
return update_where(self.name, *args, **kwargs)
|
||||
|
||||
def delete_where(self, *args, **kwargs):
|
||||
with self.conn:
|
||||
return delete_where(self.name, *args, **kwargs)
|
||||
|
||||
def insert(self, *args, **kwargs):
|
||||
with self.conn:
|
||||
return insert(self.name, *args, **kwargs)
|
||||
|
||||
def insert_many(self, *args, **kwargs):
|
||||
with self.conn:
|
||||
return insert_many(self.name, *args, **kwargs)
|
||||
|
||||
def upsert(self, *args, **kwargs):
|
||||
with self.conn:
|
||||
return upsert(self.name, *args, **kwargs)
|
||||
|
||||
def save_query(self, func):
|
||||
"""
|
||||
Decorator to add a saved query to the table.
|
||||
"""
|
||||
self.queries[func.__name__] = func
|
||||
|
||||
|
||||
class Row:
|
||||
__slots__ = ('table', 'data', '_pending')
|
||||
|
||||
conn = conn
|
||||
|
||||
def __init__(self, table, data, *args, **kwargs):
|
||||
super().__setattr__('table', table)
|
||||
self.data = data
|
||||
self._pending = None
|
||||
|
||||
@property
|
||||
def rowid(self):
|
||||
return self.data[self.table.id_col]
|
||||
|
||||
def __repr__(self):
|
||||
return "Row[{}]({})".format(
|
||||
self.table.name,
|
||||
', '.join("{}={!r}".format(field, getattr(self, field)) for field in self.table.columns)
|
||||
)
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key in self.table.columns:
|
||||
if self._pending and key in self._pending:
|
||||
return self._pending[key]
|
||||
else:
|
||||
return self.data[key]
|
||||
else:
|
||||
raise AttributeError(key)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key in self.table.columns:
|
||||
if self._pending is None:
|
||||
self.update(**{key: value})
|
||||
else:
|
||||
self._pending[key] = value
|
||||
else:
|
||||
super().__setattr__(key, value)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def batch_update(self):
|
||||
if self._pending:
|
||||
raise ValueError("Nested batch updates for {}!".format(self.__class__.__name__))
|
||||
|
||||
self._pending = {}
|
||||
try:
|
||||
yield self._pending
|
||||
finally:
|
||||
self.update(**self._pending)
|
||||
self._pending = None
|
||||
|
||||
def _refresh(self):
|
||||
row = self.table.select_one_where(**{self.table.id_col: self.rowid})
|
||||
if not row:
|
||||
raise ValueError("Refreshing a {} which no longer exists!".format(type(self).__name__))
|
||||
self.data = row
|
||||
|
||||
def update(self, **values):
|
||||
rows = self.table.update_where(values, **{self.table.id_col: self.rowid})
|
||||
self.data = rows[0]
|
||||
|
||||
@classmethod
|
||||
def _select_where(cls, _extra=None, **conditions):
|
||||
return select_where(cls._table, **conditions)
|
||||
|
||||
@classmethod
|
||||
def _insert(cls, **values):
|
||||
return insert(cls._table, **values)
|
||||
|
||||
@classmethod
|
||||
def _update_where(cls, values, **conditions):
|
||||
return update_where(cls._table, values, **conditions)
|
||||
|
||||
|
||||
class RowTable(Table):
|
||||
__slots__ = (
|
||||
'name',
|
||||
'columns',
|
||||
'id_col',
|
||||
'row_cache'
|
||||
)
|
||||
|
||||
conn = conn
|
||||
|
||||
def __init__(self, name, columns, id_col, use_cache=True, cache=None, cache_size=1000):
|
||||
self.name = name
|
||||
self.columns = columns
|
||||
self.id_col = id_col
|
||||
self.row_cache = (cache or LRUCache(cache_size)) if use_cache else None
|
||||
|
||||
# Extend original Table update methods to modify the cached rows
|
||||
def update_where(self, *args, **kwargs):
|
||||
data = super().update_where(*args, **kwargs)
|
||||
if self.row_cache is not None:
|
||||
for data_row in data:
|
||||
cached_row = self.row_cache.get(data_row[self.id_col], None)
|
||||
if cached_row is not None:
|
||||
cached_row.data = data_row
|
||||
return data
|
||||
|
||||
def delete_where(self, *args, **kwargs):
|
||||
data = super().delete_where(*args, **kwargs)
|
||||
if self.row_cache is not None:
|
||||
for data_row in data:
|
||||
self.row_cache.pop(data_row[self.id_col], None)
|
||||
return data
|
||||
|
||||
def upsert(self, *args, **kwargs):
|
||||
data = super().upsert(*args, **kwargs)
|
||||
if self.row_cache is not None:
|
||||
cached_row = self.row_cache.get(data[self.id_col], None)
|
||||
if cached_row is not None:
|
||||
cached_row.data = data
|
||||
return data
|
||||
|
||||
# New methods to fetch and create rows
|
||||
def _make_rows(self, *data_rows):
|
||||
"""
|
||||
Create or retrieve Row objects for each provided data row.
|
||||
If the rows already exist in cache, updates the cached row.
|
||||
"""
|
||||
if self.row_cache is not None:
|
||||
rows = []
|
||||
for data_row in data_rows:
|
||||
rowid = data_row[self.id_col]
|
||||
|
||||
cached_row = self.row_cache.get(rowid, None)
|
||||
if cached_row is not None:
|
||||
cached_row.data = data_row
|
||||
row = cached_row
|
||||
else:
|
||||
row = Row(self, data_row)
|
||||
self.row_cache[rowid] = row
|
||||
rows.append(row)
|
||||
else:
|
||||
rows = [Row(self, data_row) for data_row in data_rows]
|
||||
return rows
|
||||
|
||||
def create_row(self, *args, **kwargs):
|
||||
data = self.insert(*args, **kwargs)
|
||||
return self._make_rows(data)[0]
|
||||
|
||||
def fetch_rows_where(self, *args, **kwargs):
|
||||
# TODO: Handle list of rowids here?
|
||||
data = self.select_where(*args, **kwargs)
|
||||
return self._make_rows(*data)
|
||||
|
||||
def fetch(self, rowid):
|
||||
"""
|
||||
Fetch the row with the given id, retrieving from cache where possible.
|
||||
"""
|
||||
row = self.row_cache.get(rowid, None) if self.row_cache is not None else None
|
||||
if row is None:
|
||||
rows = self.fetch_rows_where(**{self.id_col: rowid})
|
||||
row = rows[0] if rows else None
|
||||
return row
|
||||
|
||||
def fetch_or_create(self, rowid=None, **kwargs):
|
||||
"""
|
||||
Helper method to fetch a row with the given id or fields, or create it if it doesn't exist.
|
||||
"""
|
||||
if rowid is not None:
|
||||
row = self.fetch(rowid)
|
||||
else:
|
||||
data = self.select_where(**kwargs)
|
||||
row = self._make_rows(data[0])[0] if data else None
|
||||
|
||||
if row is None:
|
||||
creation_kwargs = kwargs
|
||||
if rowid is not None:
|
||||
creation_kwargs[self.id_col] = rowid
|
||||
row = self.create_row(**creation_kwargs)
|
||||
return row
|
||||
|
||||
|
||||
# --------------- Query Builders ---------------
|
||||
def select_where(table, select_columns=None, cursor=None, _extra='', **conditions):
|
||||
"""
|
||||
Select rows from the given table matching the conditions
|
||||
"""
|
||||
criteria, criteria_values = _format_conditions(conditions)
|
||||
col_str = _format_selectkeys(select_columns)
|
||||
|
||||
if conditions:
|
||||
where_str = "WHERE {}".format(criteria)
|
||||
else:
|
||||
where_str = ""
|
||||
|
||||
cursor = cursor or conn.cursor()
|
||||
cursor.execute(
|
||||
'SELECT {} FROM {} {} {}'.format(col_str, table, where_str, _extra),
|
||||
criteria_values
|
||||
)
|
||||
return cursor.fetchall()
|
||||
|
||||
|
||||
def update_where(table, valuedict, cursor=None, **conditions):
|
||||
"""
|
||||
Update rows in the given table matching the conditions
|
||||
"""
|
||||
key_str, key_values = _format_updatestr(valuedict)
|
||||
criteria, criteria_values = _format_conditions(conditions)
|
||||
|
||||
if conditions:
|
||||
where_str = "WHERE {}".format(criteria)
|
||||
else:
|
||||
where_str = ""
|
||||
|
||||
cursor = cursor or conn.cursor()
|
||||
cursor.execute(
|
||||
'UPDATE {} SET {} {} RETURNING *'.format(table, key_str, where_str),
|
||||
tuple((*key_values, *criteria_values))
|
||||
)
|
||||
return cursor.fetchall()
|
||||
|
||||
|
||||
def delete_where(table, cursor=None, **conditions):
|
||||
"""
|
||||
Delete rows in the given table matching the conditions
|
||||
"""
|
||||
criteria, criteria_values = _format_conditions(conditions)
|
||||
|
||||
cursor = cursor or conn.cursor()
|
||||
cursor.execute(
|
||||
'DELETE FROM {} WHERE {}'.format(table, criteria),
|
||||
criteria_values
|
||||
)
|
||||
return cursor.fetchall()
|
||||
|
||||
|
||||
def insert(table, cursor=None, allow_replace=False, **values):
|
||||
"""
|
||||
Insert the given values into the table
|
||||
"""
|
||||
keys, values = zip(*values.items())
|
||||
|
||||
key_str = _format_insertkeys(keys)
|
||||
value_str, values = _format_insertvalues(values)
|
||||
|
||||
action = 'REPLACE' if allow_replace else 'INSERT'
|
||||
|
||||
cursor = cursor or conn.cursor()
|
||||
cursor.execute(
|
||||
'{} INTO {} {} VALUES {} RETURNING *'.format(action, table, key_str, value_str),
|
||||
values
|
||||
)
|
||||
return cursor.fetchone()
|
||||
|
||||
|
||||
def insert_many(table, *value_tuples, insert_keys=None, cursor=None):
|
||||
"""
|
||||
Insert all the given values into the table
|
||||
"""
|
||||
key_str = _format_insertkeys(insert_keys)
|
||||
value_strs, value_tuples = zip(*(_format_insertvalues(value_tuple) for value_tuple in value_tuples))
|
||||
|
||||
value_str = ", ".join(value_strs)
|
||||
values = tuple(chain(*value_tuples))
|
||||
|
||||
cursor = cursor or conn.cursor()
|
||||
cursor.execute(
|
||||
'INSERT INTO {} {} VALUES {} RETURNING *'.format(table, key_str, value_str),
|
||||
values
|
||||
)
|
||||
return cursor.fetchall()
|
||||
|
||||
|
||||
def upsert(table, constraint, cursor=None, **values):
|
||||
"""
|
||||
Insert or on conflict update.
|
||||
"""
|
||||
valuedict = values
|
||||
keys, values = zip(*values.items())
|
||||
|
||||
key_str = _format_insertkeys(keys)
|
||||
value_str, values = _format_insertvalues(values)
|
||||
update_key_str, update_key_values = _format_updatestr(valuedict)
|
||||
|
||||
if not isinstance(constraint, str):
|
||||
constraint = ", ".join(constraint)
|
||||
|
||||
cursor = cursor or conn.cursor()
|
||||
cursor.execute(
|
||||
'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format(
|
||||
table, key_str, value_str, constraint, update_key_str
|
||||
),
|
||||
tuple((*values, *update_key_values))
|
||||
)
|
||||
return cursor.fetchone()
|
||||
|
||||
|
||||
# --------------- Query Formatting Tools ---------------
|
||||
# Replace char used by the connection for query formatting
|
||||
_replace_char: str = '%s'
|
||||
|
||||
|
||||
class fieldConstants(Enum):
|
||||
"""
|
||||
A collection of database field constants to use for selection conditions.
|
||||
"""
|
||||
NULL = "IS NULL"
|
||||
NOTNULL = "IS NOT NULL"
|
||||
|
||||
|
||||
class _updateField:
|
||||
__slots__ = ()
|
||||
_EMPTY = object() # Return value for `value` indicating no value should be added
|
||||
|
||||
def key_field(self, key):
|
||||
raise NotImplementedError
|
||||
|
||||
def value_field(self, key):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class UpdateValue(_updateField):
|
||||
__slots__ = ('key_str', 'value')
|
||||
|
||||
def __init__(self, key_str, value=_updateField._EMPTY):
|
||||
self.key_str = key_str
|
||||
self.value = value
|
||||
|
||||
def key_field(self, key):
|
||||
return self.key_str.format(key=key, value=_replace_char, replace=_replace_char)
|
||||
|
||||
def value_field(self, key):
|
||||
return self.value
|
||||
|
||||
|
||||
class UpdateValueAdd(_updateField):
|
||||
__slots__ = ('value',)
|
||||
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def key_field(self, key):
|
||||
return "{key} = {key} + {replace}".format(key=key, replace=_replace_char)
|
||||
|
||||
def value_field(self, key):
|
||||
return self.value
|
||||
|
||||
|
||||
def _format_conditions(conditions):
|
||||
"""
|
||||
Formats a dictionary of conditions into a string suitable for 'WHERE' clauses.
|
||||
Supports `IN` type conditionals.
|
||||
"""
|
||||
if not conditions:
|
||||
return ("", tuple())
|
||||
|
||||
values = []
|
||||
conditional_strings = []
|
||||
for key, item in conditions.items():
|
||||
if isinstance(item, (list, tuple)):
|
||||
conditional_strings.append("{} IN ({})".format(key, ", ".join([_replace_char] * len(item))))
|
||||
values.extend(item)
|
||||
elif isinstance(item, fieldConstants):
|
||||
conditional_strings.append("{} {}".format(key, item.value))
|
||||
else:
|
||||
conditional_strings.append("{}={}".format(key, _replace_char))
|
||||
values.append(item)
|
||||
|
||||
return (' AND '.join(conditional_strings), values)
|
||||
|
||||
|
||||
def _format_selectkeys(keys):
|
||||
"""
|
||||
Formats a list of keys into a string suitable for `SELECT`.
|
||||
"""
|
||||
if not keys:
|
||||
return "*"
|
||||
else:
|
||||
return ", ".join(keys)
|
||||
|
||||
|
||||
def _format_insertkeys(keys):
|
||||
"""
|
||||
Formats a list of keys into a string suitable for `INSERT`
|
||||
"""
|
||||
if not keys:
|
||||
return ""
|
||||
else:
|
||||
return "({})".format(", ".join(keys))
|
||||
|
||||
|
||||
def _format_insertvalues(values):
|
||||
"""
|
||||
Formats a list of values into a string suitable for `INSERT`
|
||||
"""
|
||||
value_str = "({})".format(", ".join(_replace_char for value in values))
|
||||
return (value_str, values)
|
||||
|
||||
|
||||
def _format_updatestr(valuedict):
|
||||
"""
|
||||
Formats a dictionary of keys and values into a string suitable for 'SET' clauses.
|
||||
"""
|
||||
if not valuedict:
|
||||
return ("", tuple())
|
||||
|
||||
key_fields = []
|
||||
values = []
|
||||
for key, value in valuedict.items():
|
||||
if isinstance(value, _updateField):
|
||||
key_fields.append(value.key_field(key))
|
||||
v = value.value_field(key)
|
||||
if v is not _updateField._EMPTY:
|
||||
values.append(value.value_field(key))
|
||||
else:
|
||||
key_fields.append("{} = {}".format(key, _replace_char))
|
||||
values.append(value)
|
||||
|
||||
return (', '.join(key_fields), values)
|
||||
113
bot/data/formatters.py
Normal file
113
bot/data/formatters.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from .connection import _replace_char
|
||||
from .conditions import Condition
|
||||
|
||||
|
||||
class _updateField:
|
||||
__slots__ = ()
|
||||
_EMPTY = object() # Return value for `value` indicating no value should be added
|
||||
|
||||
def key_field(self, key):
|
||||
raise NotImplementedError
|
||||
|
||||
def value_field(self, key):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class UpdateValue(_updateField):
|
||||
__slots__ = ('key_str', 'value')
|
||||
|
||||
def __init__(self, key_str, value=_updateField._EMPTY):
|
||||
self.key_str = key_str
|
||||
self.value = value
|
||||
|
||||
def key_field(self, key):
|
||||
return self.key_str.format(key=key, value=_replace_char, replace=_replace_char)
|
||||
|
||||
def value_field(self, key):
|
||||
return self.value
|
||||
|
||||
|
||||
class UpdateValueAdd(_updateField):
|
||||
__slots__ = ('value',)
|
||||
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def key_field(self, key):
|
||||
return "{key} = {key} + {replace}".format(key=key, replace=_replace_char)
|
||||
|
||||
def value_field(self, key):
|
||||
return self.value
|
||||
|
||||
|
||||
def _format_conditions(conditions):
|
||||
"""
|
||||
Formats a dictionary of conditions into a string suitable for 'WHERE' clauses.
|
||||
Supports `IN` type conditionals.
|
||||
"""
|
||||
if not conditions:
|
||||
return ("", tuple())
|
||||
|
||||
values = []
|
||||
conditional_strings = []
|
||||
for key, item in conditions.items():
|
||||
if isinstance(item, (list, tuple)):
|
||||
conditional_strings.append("{} IN ({})".format(key, ", ".join([_replace_char] * len(item))))
|
||||
values.extend(item)
|
||||
elif isinstance(item, Condition):
|
||||
item.apply(key, values, conditional_strings)
|
||||
else:
|
||||
conditional_strings.append("{}={}".format(key, _replace_char))
|
||||
values.append(item)
|
||||
|
||||
return (' AND '.join(conditional_strings), values)
|
||||
|
||||
|
||||
def _format_selectkeys(keys):
|
||||
"""
|
||||
Formats a list of keys into a string suitable for `SELECT`.
|
||||
"""
|
||||
if not keys:
|
||||
return "*"
|
||||
else:
|
||||
return ", ".join(keys)
|
||||
|
||||
|
||||
def _format_insertkeys(keys):
|
||||
"""
|
||||
Formats a list of keys into a string suitable for `INSERT`
|
||||
"""
|
||||
if not keys:
|
||||
return ""
|
||||
else:
|
||||
return "({})".format(", ".join(keys))
|
||||
|
||||
|
||||
def _format_insertvalues(values):
|
||||
"""
|
||||
Formats a list of values into a string suitable for `INSERT`
|
||||
"""
|
||||
value_str = "({})".format(", ".join(_replace_char for value in values))
|
||||
return (value_str, values)
|
||||
|
||||
|
||||
def _format_updatestr(valuedict):
|
||||
"""
|
||||
Formats a dictionary of keys and values into a string suitable for 'SET' clauses.
|
||||
"""
|
||||
if not valuedict:
|
||||
return ("", tuple())
|
||||
|
||||
key_fields = []
|
||||
values = []
|
||||
for key, value in valuedict.items():
|
||||
if isinstance(value, _updateField):
|
||||
key_fields.append(value.key_field(key))
|
||||
v = value.value_field(key)
|
||||
if v is not _updateField._EMPTY:
|
||||
values.append(value.value_field(key))
|
||||
else:
|
||||
key_fields.append("{} = {}".format(key, _replace_char))
|
||||
values.append(value)
|
||||
|
||||
return (', '.join(key_fields), values)
|
||||
282
bot/data/interfaces.py
Normal file
282
bot/data/interfaces.py
Normal file
@@ -0,0 +1,282 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from cachetools import LRUCache
|
||||
from typing import Mapping
|
||||
|
||||
from utils.lib import DotDict
|
||||
|
||||
from .connection import conn
|
||||
from .queries import insert, insert_many, select_where, update_where, upsert, delete_where, update_many
|
||||
|
||||
|
||||
# Global cache of interfaces
|
||||
tables: Mapping[str, Table] = DotDict()
|
||||
|
||||
|
||||
class Table:
|
||||
"""
|
||||
Transparent interface to a single table structure in the database.
|
||||
Contains standard methods to access the table.
|
||||
Intended to be subclassed to provide more derivative access for specific tables.
|
||||
"""
|
||||
conn = conn
|
||||
queries = DotDict()
|
||||
|
||||
def __init__(self, name, attach_as=None):
|
||||
self.name = name
|
||||
tables[attach_as or name] = self
|
||||
|
||||
def select_where(self, *args, **kwargs):
|
||||
with self.conn:
|
||||
return select_where(self.name, *args, **kwargs)
|
||||
|
||||
def select_one_where(self, *args, **kwargs):
|
||||
rows = self.select_where(*args, **kwargs)
|
||||
return rows[0] if rows else None
|
||||
|
||||
def update_where(self, *args, **kwargs):
|
||||
with self.conn:
|
||||
return update_where(self.name, *args, **kwargs)
|
||||
|
||||
def delete_where(self, *args, **kwargs):
|
||||
with self.conn:
|
||||
return delete_where(self.name, *args, **kwargs)
|
||||
|
||||
def insert(self, *args, **kwargs):
|
||||
with self.conn:
|
||||
return insert(self.name, *args, **kwargs)
|
||||
|
||||
def insert_many(self, *args, **kwargs):
|
||||
with self.conn:
|
||||
return insert_many(self.name, *args, **kwargs)
|
||||
|
||||
def update_many(self, *args, **kwargs):
|
||||
with self.conn:
|
||||
return update_many(self.name, *args, **kwargs)
|
||||
|
||||
def upsert(self, *args, **kwargs):
|
||||
with self.conn:
|
||||
return upsert(self.name, *args, **kwargs)
|
||||
|
||||
def save_query(self, func):
|
||||
"""
|
||||
Decorator to add a saved query to the table.
|
||||
"""
|
||||
self.queries[func.__name__] = func
|
||||
|
||||
|
||||
class Row:
|
||||
__slots__ = ('table', 'data', '_pending')
|
||||
|
||||
conn = conn
|
||||
|
||||
def __init__(self, table, data, *args, **kwargs):
|
||||
super().__setattr__('table', table)
|
||||
self.data = data
|
||||
self._pending = None
|
||||
|
||||
@property
|
||||
def rowid(self):
|
||||
return self.table.id_from_row(self.data)
|
||||
|
||||
def __repr__(self):
|
||||
return "Row[{}]({})".format(
|
||||
self.table.name,
|
||||
', '.join("{}={!r}".format(field, getattr(self, field)) for field in self.table.columns)
|
||||
)
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key in self.table.columns:
|
||||
if self._pending and key in self._pending:
|
||||
return self._pending[key]
|
||||
else:
|
||||
return self.data[key]
|
||||
else:
|
||||
raise AttributeError(key)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key in self.table.columns:
|
||||
if self._pending is None:
|
||||
self.update(**{key: value})
|
||||
else:
|
||||
self._pending[key] = value
|
||||
else:
|
||||
super().__setattr__(key, value)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def batch_update(self):
|
||||
if self._pending:
|
||||
raise ValueError("Nested batch updates for {}!".format(self.__class__.__name__))
|
||||
|
||||
self._pending = {}
|
||||
try:
|
||||
yield self._pending
|
||||
finally:
|
||||
self.update(**self._pending)
|
||||
self._pending = None
|
||||
|
||||
def _refresh(self):
|
||||
row = self.table.select_one_where(self.table.dict_from_id(self.rowid))
|
||||
if not row:
|
||||
raise ValueError("Refreshing a {} which no longer exists!".format(type(self).__name__))
|
||||
self.data = row
|
||||
|
||||
def update(self, **values):
|
||||
rows = self.table.update_where(values, **self.table.dict_from_id(self.rowid))
|
||||
self.data = rows[0]
|
||||
|
||||
@classmethod
|
||||
def _select_where(cls, _extra=None, **conditions):
|
||||
return select_where(cls._table, **conditions)
|
||||
|
||||
@classmethod
|
||||
def _insert(cls, **values):
|
||||
return insert(cls._table, **values)
|
||||
|
||||
@classmethod
|
||||
def _update_where(cls, values, **conditions):
|
||||
return update_where(cls._table, values, **conditions)
|
||||
|
||||
|
||||
class RowTable(Table):
|
||||
__slots__ = (
|
||||
'name',
|
||||
'columns',
|
||||
'id_col',
|
||||
'multi_key',
|
||||
'row_cache'
|
||||
)
|
||||
|
||||
conn = conn
|
||||
|
||||
def __init__(self, name, columns, id_col, use_cache=True, cache=None, cache_size=1000, **kwargs):
|
||||
super().__init__(name, **kwargs)
|
||||
self.name = name
|
||||
self.columns = columns
|
||||
self.id_col = id_col
|
||||
self.multi_key = isinstance(id_col, tuple)
|
||||
self.row_cache = (cache or LRUCache(cache_size)) if use_cache else None
|
||||
|
||||
def id_from_row(self, row):
|
||||
if self.multi_key:
|
||||
return tuple(row[key] for key in self.id_col)
|
||||
else:
|
||||
return row[self.id_col]
|
||||
|
||||
def dict_from_id(self, rowid):
|
||||
if self.multi_key:
|
||||
return dict(zip(self.id_col, rowid))
|
||||
else:
|
||||
return {self.id_col: rowid}
|
||||
|
||||
# Extend original Table update methods to modify the cached rows
|
||||
def insert(self, *args, **kwargs):
|
||||
data = super().insert(*args, **kwargs)
|
||||
if self.row_cache is not None:
|
||||
self.row_cache[self.id_from_row(data)] = Row(self, data)
|
||||
return data
|
||||
|
||||
def insert_many(self, *args, **kwargs):
|
||||
data = super().insert_many(*args, **kwargs)
|
||||
if self.row_cache is not None:
|
||||
for data_row in data:
|
||||
cached_row = self.row_cache.get(self.id_from_row(data_row), None)
|
||||
if cached_row is not None:
|
||||
cached_row.data = data_row
|
||||
return data
|
||||
|
||||
def update_where(self, *args, **kwargs):
|
||||
data = super().update_where(*args, **kwargs)
|
||||
if self.row_cache is not None:
|
||||
for data_row in data:
|
||||
cached_row = self.row_cache.get(self.id_from_row(data_row), None)
|
||||
if cached_row is not None:
|
||||
cached_row.data = data_row
|
||||
return data
|
||||
|
||||
def update_many(self, *args, **kwargs):
|
||||
data = super().update_many(*args, **kwargs)
|
||||
if self.row_cache is not None:
|
||||
for data_row in data:
|
||||
cached_row = self.row_cache.get(self.id_from_row(data_row), None)
|
||||
if cached_row is not None:
|
||||
cached_row.data = data_row
|
||||
return data
|
||||
|
||||
def delete_where(self, *args, **kwargs):
|
||||
data = super().delete_where(*args, **kwargs)
|
||||
if self.row_cache is not None:
|
||||
for data_row in data:
|
||||
self.row_cache.pop(self.id_from_row(data_row), None)
|
||||
return data
|
||||
|
||||
def upsert(self, *args, **kwargs):
|
||||
data = super().upsert(*args, **kwargs)
|
||||
if self.row_cache is not None:
|
||||
rowid = self.id_from_row(data)
|
||||
cached_row = self.row_cache.get(rowid, None)
|
||||
if cached_row is not None:
|
||||
cached_row.data = data
|
||||
else:
|
||||
self.row_cache[rowid] = Row(self, data)
|
||||
return data
|
||||
|
||||
# New methods to fetch and create rows
|
||||
def _make_rows(self, *data_rows):
|
||||
"""
|
||||
Create or retrieve Row objects for each provided data row.
|
||||
If the rows already exist in cache, updates the cached row.
|
||||
"""
|
||||
if self.row_cache is not None:
|
||||
rows = []
|
||||
for data_row in data_rows:
|
||||
rowid = self.id_from_row(data_row)
|
||||
|
||||
cached_row = self.row_cache.get(rowid, None)
|
||||
if cached_row is not None:
|
||||
cached_row.data = data_row
|
||||
row = cached_row
|
||||
else:
|
||||
row = Row(self, data_row)
|
||||
self.row_cache[rowid] = row
|
||||
rows.append(row)
|
||||
else:
|
||||
rows = [Row(self, data_row) for data_row in data_rows]
|
||||
return rows
|
||||
|
||||
def create_row(self, *args, **kwargs):
|
||||
data = self.insert(*args, **kwargs)
|
||||
return self._make_rows(data)[0]
|
||||
|
||||
def fetch_rows_where(self, *args, **kwargs):
|
||||
# TODO: Handle list of rowids here?
|
||||
data = self.select_where(*args, **kwargs)
|
||||
return self._make_rows(*data)
|
||||
|
||||
def fetch(self, rowid):
|
||||
"""
|
||||
Fetch the row with the given id, retrieving from cache where possible.
|
||||
"""
|
||||
row = self.row_cache.get(rowid, None) if self.row_cache is not None else None
|
||||
if row is None:
|
||||
rows = self.fetch_rows_where(**self.dict_from_id(rowid))
|
||||
row = rows[0] if rows else None
|
||||
return row
|
||||
|
||||
def fetch_or_create(self, rowid=None, **kwargs):
|
||||
"""
|
||||
Helper method to fetch a row with the given id or fields, or create it if it doesn't exist.
|
||||
"""
|
||||
if rowid is not None:
|
||||
row = self.fetch(rowid)
|
||||
else:
|
||||
data = self.select_where(**kwargs)
|
||||
row = self._make_rows(data[0])[0] if data else None
|
||||
|
||||
if row is None:
|
||||
creation_kwargs = kwargs
|
||||
if rowid is not None:
|
||||
creation_kwargs.update(self.dict_from_id(rowid))
|
||||
row = self.create_row(**creation_kwargs)
|
||||
return row
|
||||
149
bot/data/queries.py
Normal file
149
bot/data/queries.py
Normal file
@@ -0,0 +1,149 @@
|
||||
from itertools import chain
|
||||
from psycopg2.extras import execute_values
|
||||
|
||||
from .connection import conn
|
||||
from .formatters import (_format_updatestr, _format_conditions, _format_insertkeys,
|
||||
_format_selectkeys, _format_insertvalues)
|
||||
|
||||
|
||||
def select_where(table, select_columns=None, cursor=None, _extra='', **conditions):
|
||||
"""
|
||||
Select rows from the given table matching the conditions
|
||||
"""
|
||||
criteria, criteria_values = _format_conditions(conditions)
|
||||
col_str = _format_selectkeys(select_columns)
|
||||
|
||||
if criteria:
|
||||
where_str = "WHERE {}".format(criteria)
|
||||
else:
|
||||
where_str = ""
|
||||
|
||||
cursor = cursor or conn.cursor()
|
||||
cursor.execute(
|
||||
'SELECT {} FROM {} {} {}'.format(col_str, table, where_str, _extra),
|
||||
criteria_values
|
||||
)
|
||||
return cursor.fetchall()
|
||||
|
||||
|
||||
def update_where(table, valuedict, cursor=None, **conditions):
|
||||
"""
|
||||
Update rows in the given table matching the conditions
|
||||
"""
|
||||
key_str, key_values = _format_updatestr(valuedict)
|
||||
criteria, criteria_values = _format_conditions(conditions)
|
||||
|
||||
if criteria:
|
||||
where_str = "WHERE {}".format(criteria)
|
||||
else:
|
||||
where_str = ""
|
||||
|
||||
cursor = cursor or conn.cursor()
|
||||
cursor.execute(
|
||||
'UPDATE {} SET {} {} RETURNING *'.format(table, key_str, where_str),
|
||||
tuple((*key_values, *criteria_values))
|
||||
)
|
||||
return cursor.fetchall()
|
||||
|
||||
|
||||
def delete_where(table, cursor=None, **conditions):
|
||||
"""
|
||||
Delete rows in the given table matching the conditions
|
||||
"""
|
||||
criteria, criteria_values = _format_conditions(conditions)
|
||||
|
||||
if criteria:
|
||||
where_str = "WHERE {}".format(criteria)
|
||||
else:
|
||||
where_str = ""
|
||||
|
||||
cursor = cursor or conn.cursor()
|
||||
cursor.execute(
|
||||
'DELETE FROM {} {} RETURNING *'.format(table, where_str),
|
||||
criteria_values
|
||||
)
|
||||
return cursor.fetchall()
|
||||
|
||||
|
||||
def insert(table, cursor=None, allow_replace=False, **values):
|
||||
"""
|
||||
Insert the given values into the table
|
||||
"""
|
||||
keys, values = zip(*values.items())
|
||||
|
||||
key_str = _format_insertkeys(keys)
|
||||
value_str, values = _format_insertvalues(values)
|
||||
|
||||
action = 'REPLACE' if allow_replace else 'INSERT'
|
||||
|
||||
cursor = cursor or conn.cursor()
|
||||
cursor.execute(
|
||||
'{} INTO {} {} VALUES {} RETURNING *'.format(action, table, key_str, value_str),
|
||||
values
|
||||
)
|
||||
return cursor.fetchone()
|
||||
|
||||
|
||||
def insert_many(table, *value_tuples, insert_keys=None, cursor=None):
|
||||
"""
|
||||
Insert all the given values into the table
|
||||
"""
|
||||
key_str = _format_insertkeys(insert_keys)
|
||||
value_strs, value_tuples = zip(*(_format_insertvalues(value_tuple) for value_tuple in value_tuples))
|
||||
|
||||
value_str = ", ".join(value_strs)
|
||||
values = tuple(chain(*value_tuples))
|
||||
|
||||
cursor = cursor or conn.cursor()
|
||||
cursor.execute(
|
||||
'INSERT INTO {} {} VALUES {} RETURNING *'.format(table, key_str, value_str),
|
||||
values
|
||||
)
|
||||
return cursor.fetchall()
|
||||
|
||||
|
||||
def upsert(table, constraint, cursor=None, **values):
|
||||
"""
|
||||
Insert or on conflict update.
|
||||
"""
|
||||
valuedict = values
|
||||
keys, values = zip(*values.items())
|
||||
|
||||
key_str = _format_insertkeys(keys)
|
||||
value_str, values = _format_insertvalues(values)
|
||||
update_key_str, update_key_values = _format_updatestr(valuedict)
|
||||
|
||||
if not isinstance(constraint, str):
|
||||
constraint = ", ".join(constraint)
|
||||
|
||||
cursor = cursor or conn.cursor()
|
||||
cursor.execute(
|
||||
'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format(
|
||||
table, key_str, value_str, constraint, update_key_str
|
||||
),
|
||||
tuple((*values, *update_key_values))
|
||||
)
|
||||
return cursor.fetchone()
|
||||
|
||||
|
||||
def update_many(table, *values, set_keys=None, where_keys=None, cursor=None):
|
||||
cursor = cursor or conn.cursor()
|
||||
|
||||
return execute_values(
|
||||
cursor,
|
||||
"""
|
||||
UPDATE {table}
|
||||
SET {set_clause}
|
||||
FROM (VALUES %s)
|
||||
AS {temp_table}
|
||||
WHERE {where_clause}
|
||||
RETURNING *
|
||||
""".format(
|
||||
table=table,
|
||||
set_clause=', '.join("{0} = _t.{0}".format(key) for key in set_keys),
|
||||
where_clause=' AND '.join("{1}.{0} = _t.{0}".format(key, table) for key in where_keys),
|
||||
temp_table="_t ({})".format(', '.join(set_keys + where_keys))
|
||||
),
|
||||
values,
|
||||
fetch=True
|
||||
)
|
||||
Reference in New Issue
Block a user