From b471e78a7576499edd49059d077e9008f81a68db Mon Sep 17 00:00:00 2001 From: Conatum Date: Fri, 18 Nov 2022 11:01:56 +0200 Subject: [PATCH] rewrite: Add support for database schemas. --- bot/data/columns.py | 5 ++--- bot/data/models.py | 8 +++++--- bot/data/queries.py | 14 +++++++------- bot/data/table.py | 27 ++++++++++++++++++--------- 4 files changed, 32 insertions(+), 22 deletions(-) diff --git a/bot/data/columns.py b/bot/data/columns.py index 60626f17..10298a82 100644 --- a/bot/data/columns.py +++ b/bot/data/columns.py @@ -4,6 +4,7 @@ from datetime import datetime from .base import RawExpr, Expression from .conditions import Condition, Joiner +from .table import Table class ColumnExpr(RawExpr): @@ -111,7 +112,6 @@ class Column(ColumnExpr, Generic[T]): self.references = references self.name: str = name # type: ignore self.owner: Optional['RowModel'] = None - self.tablename: Optional[str] = None self._type = type self.expr = sql.Identifier(name) if name else sql.SQL('') @@ -122,8 +122,7 @@ class Column(ColumnExpr, Generic[T]): if self.owner is None: self.name = self.name or name self.owner = owner - self.tablename = owner._tablename_ - self.expr = sql.Identifier(self.tablename, self.name) + self.expr = sql.Identifier(self.owner._schema_, self.owner._tablename_, self.name) @overload def __get__(self: 'Column[T]', obj: None, objtype: "None | Type['RowModel']") -> 'Column[T]': diff --git a/bot/data/models.py b/bot/data/models.py index 2ebb470b..8fef0f10 100644 --- a/bot/data/models.py +++ b/bot/data/models.py @@ -63,7 +63,7 @@ class RowTable(Table, Generic[RowT]): def fetch_rows_where(self, *args, **kwargs) -> q.Select[list[RowT]]: # TODO: Handle list of rowids here? return q.Select( - self.name, + self.identifier, row_adapter=self.model._make_rows, connector=self.connector ).where(*args, **kwargs) @@ -118,6 +118,7 @@ class WeakCache(MutableMapping): class RowModel: __slots__ = ('data',) + _schema_: str = 'public' _tablename_: Optional[str] = None _columns_: dict[str, Column] = {} @@ -147,7 +148,7 @@ class RowModel: cls._columns_ = columns if not cls._key_: cls._key_ = tuple(column.name for column in columns.values() if column.primary) - cls.table = RowTable(cls._tablename_, cls) + cls.table = RowTable(cls._tablename_, cls, schema=cls._schema_) if cls._cache_ is None: cls._cache_ = WeakValueDictionary() @@ -199,7 +200,8 @@ class RowModel: return tuple(self.data[key] for key in self._key_) def __repr__(self): - return "{}({})".format( + return "{}.{}({})".format( + self.table.schema, self.table.name, ', '.join(repr(column.__get__(self)) for column in self._columns_.values()) ) diff --git a/bot/data/queries.py b/bot/data/queries.py index 5d7bc0a0..35db24be 100644 --- a/bot/data/queries.py +++ b/bot/data/queries.py @@ -113,13 +113,13 @@ class TableQuery(Query[QueryResult]): ABC for an executable query statement expected to be run on a single table. """ __slots__ = ( - 'table', + 'tableid', 'condition', '_extra', '_limit', '_order', '_joins' ) - def __init__(self, table, *args, **kwargs): + def __init__(self, tableid, *args, **kwargs): super().__init__(*args, **kwargs) - self.table: str = table + self.tableid: sql.Identifier = tableid def options(self, **kwargs): """ @@ -367,7 +367,7 @@ class Insert(ExtraMixin, TableQuery[QueryResult]): # TODO: Check efficiency of inserting multiple values like this # Also implement a Copy query base = sql.SQL("INSERT INTO {table} ({columns}) VALUES {values_str}").format( - table=sql.Identifier(self.table), + table=self.tableid, columns=columns, values_str=values_str ) @@ -428,7 +428,7 @@ class Select(WhereMixin, ExtraMixin, OrderMixin, LimitMixin, JoinMixin, TableQue base = sql.SQL("SELECT {columns} FROM {table}").format( columns=columns, - table=sql.Identifier(self.table) + table=self.tableid ) sections = [ @@ -453,7 +453,7 @@ class Delete(WhereMixin, ExtraMixin, TableQuery[QueryResult]): def build(self): base = sql.SQL("DELETE FROM {table}").format( - table=sql.Identifier(self.table), + table=self.tableid, ) sections = [ RawExpr(base), @@ -500,7 +500,7 @@ class Update(LimitMixin, WhereMixin, ExtraMixin, TableQuery[QueryResult]): set_expr, set_values = RawExpr.join(*self._set, joiner=sql.SQL(', ')).as_tuple() base = sql.SQL("UPDATE {table} SET {set}").format( - table=sql.Identifier(self.table), + table=self.tableid, set=set_expr ) sections = [ diff --git a/bot/data/table.py b/bot/data/table.py index 567f49ce..e20647e7 100644 --- a/bot/data/table.py +++ b/bot/data/table.py @@ -1,5 +1,6 @@ from typing import Optional from psycopg.rows import DictRow +from psycopg import sql from . import queries as q from .connector import Connector @@ -12,10 +13,18 @@ class Table: Contains standard methods to access the table. """ - def __init__(self, name, *args, **kwargs): + def __init__(self, name, *args, schema='public', **kwargs): self.name: str = name + self.schema: str = schema self.connector: Connector = None + @property + def identifier(self): + if self.schema == 'public': + return sql.Identifier(self.name) + else: + return sql.Identifier(self.schema, self.name) + def bind(self, connector: Connector): self.connector = connector return self @@ -38,49 +47,49 @@ class Table: def select_where(self, *args, **kwargs) -> q.Select[tuple[DictRow, ...]]: return q.Select( - self.name, + self.identifier, row_adapter=self._many_query_adapter, connector=self.connector ).where(*args, **kwargs) def select_one_where(self, *args, **kwargs) -> q.Select[DictRow]: return q.Select( - self.name, + self.identifier, row_adapter=self._single_query_adapter, connector=self.connector ).where(*args, **kwargs) def update_where(self, *args, **kwargs) -> q.Update[tuple[DictRow, ...]]: return q.Update( - self.name, + self.identifier, row_adapter=self._many_query_adapter, connector=self.connector ).where(*args, **kwargs) def delete_where(self, *args, **kwargs) -> q.Delete[tuple[DictRow, ...]]: return q.Delete( - self.name, + self.identifier, row_adapter=self._many_query_adapter, connector=self.connector ).where(*args, **kwargs) def insert(self, **column_values) -> q.Insert[DictRow]: return q.Insert( - self.name, + self.identifier, row_adapter=self._single_query_adapter, connector=self.connector ).insert(column_values.keys(), column_values.values()) def insert_many(self, *args, **kwargs) -> q.Insert[tuple[DictRow, ...]]: return q.Insert( - self.name, + self.identifier, row_adapter=self._many_query_adapter, connector=self.connector ).insert(*args, **kwargs) # def update_many(self, *args, **kwargs): # with self.conn: -# return update_many(self.name, *args, **kwargs) +# return update_many(self.identifier, *args, **kwargs) # def upsert(self, *args, **kwargs): -# return upsert(self.name, *args, **kwargs) +# return upsert(self.identifier, *args, **kwargs)