rewrite: Add support for database schemas.

This commit is contained in:
2022-11-18 11:01:56 +02:00
parent 916de8dd4c
commit b471e78a75
4 changed files with 32 additions and 22 deletions

View File

@@ -4,6 +4,7 @@ from datetime import datetime
from .base import RawExpr, Expression from .base import RawExpr, Expression
from .conditions import Condition, Joiner from .conditions import Condition, Joiner
from .table import Table
class ColumnExpr(RawExpr): class ColumnExpr(RawExpr):
@@ -111,7 +112,6 @@ class Column(ColumnExpr, Generic[T]):
self.references = references self.references = references
self.name: str = name # type: ignore self.name: str = name # type: ignore
self.owner: Optional['RowModel'] = None self.owner: Optional['RowModel'] = None
self.tablename: Optional[str] = None
self._type = type self._type = type
self.expr = sql.Identifier(name) if name else sql.SQL('') self.expr = sql.Identifier(name) if name else sql.SQL('')
@@ -122,8 +122,7 @@ class Column(ColumnExpr, Generic[T]):
if self.owner is None: if self.owner is None:
self.name = self.name or name self.name = self.name or name
self.owner = owner self.owner = owner
self.tablename = owner._tablename_ self.expr = sql.Identifier(self.owner._schema_, self.owner._tablename_, self.name)
self.expr = sql.Identifier(self.tablename, self.name)
@overload @overload
def __get__(self: 'Column[T]', obj: None, objtype: "None | Type['RowModel']") -> 'Column[T]': def __get__(self: 'Column[T]', obj: None, objtype: "None | Type['RowModel']") -> 'Column[T]':

View File

@@ -63,7 +63,7 @@ class RowTable(Table, Generic[RowT]):
def fetch_rows_where(self, *args, **kwargs) -> q.Select[list[RowT]]: def fetch_rows_where(self, *args, **kwargs) -> q.Select[list[RowT]]:
# TODO: Handle list of rowids here? # TODO: Handle list of rowids here?
return q.Select( return q.Select(
self.name, self.identifier,
row_adapter=self.model._make_rows, row_adapter=self.model._make_rows,
connector=self.connector connector=self.connector
).where(*args, **kwargs) ).where(*args, **kwargs)
@@ -118,6 +118,7 @@ class WeakCache(MutableMapping):
class RowModel: class RowModel:
__slots__ = ('data',) __slots__ = ('data',)
_schema_: str = 'public'
_tablename_: Optional[str] = None _tablename_: Optional[str] = None
_columns_: dict[str, Column] = {} _columns_: dict[str, Column] = {}
@@ -147,7 +148,7 @@ class RowModel:
cls._columns_ = columns cls._columns_ = columns
if not cls._key_: if not cls._key_:
cls._key_ = tuple(column.name for column in columns.values() if column.primary) cls._key_ = tuple(column.name for column in columns.values() if column.primary)
cls.table = RowTable(cls._tablename_, cls) cls.table = RowTable(cls._tablename_, cls, schema=cls._schema_)
if cls._cache_ is None: if cls._cache_ is None:
cls._cache_ = WeakValueDictionary() cls._cache_ = WeakValueDictionary()
@@ -199,7 +200,8 @@ class RowModel:
return tuple(self.data[key] for key in self._key_) return tuple(self.data[key] for key in self._key_)
def __repr__(self): def __repr__(self):
return "{}({})".format( return "{}.{}({})".format(
self.table.schema,
self.table.name, self.table.name,
', '.join(repr(column.__get__(self)) for column in self._columns_.values()) ', '.join(repr(column.__get__(self)) for column in self._columns_.values())
) )

View File

@@ -113,13 +113,13 @@ class TableQuery(Query[QueryResult]):
ABC for an executable query statement expected to be run on a single table. ABC for an executable query statement expected to be run on a single table.
""" """
__slots__ = ( __slots__ = (
'table', 'tableid',
'condition', '_extra', '_limit', '_order', '_joins' 'condition', '_extra', '_limit', '_order', '_joins'
) )
def __init__(self, table, *args, **kwargs): def __init__(self, tableid, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.table: str = table self.tableid: sql.Identifier = tableid
def options(self, **kwargs): def options(self, **kwargs):
""" """
@@ -367,7 +367,7 @@ class Insert(ExtraMixin, TableQuery[QueryResult]):
# TODO: Check efficiency of inserting multiple values like this # TODO: Check efficiency of inserting multiple values like this
# Also implement a Copy query # Also implement a Copy query
base = sql.SQL("INSERT INTO {table} ({columns}) VALUES {values_str}").format( base = sql.SQL("INSERT INTO {table} ({columns}) VALUES {values_str}").format(
table=sql.Identifier(self.table), table=self.tableid,
columns=columns, columns=columns,
values_str=values_str values_str=values_str
) )
@@ -428,7 +428,7 @@ class Select(WhereMixin, ExtraMixin, OrderMixin, LimitMixin, JoinMixin, TableQue
base = sql.SQL("SELECT {columns} FROM {table}").format( base = sql.SQL("SELECT {columns} FROM {table}").format(
columns=columns, columns=columns,
table=sql.Identifier(self.table) table=self.tableid
) )
sections = [ sections = [
@@ -453,7 +453,7 @@ class Delete(WhereMixin, ExtraMixin, TableQuery[QueryResult]):
def build(self): def build(self):
base = sql.SQL("DELETE FROM {table}").format( base = sql.SQL("DELETE FROM {table}").format(
table=sql.Identifier(self.table), table=self.tableid,
) )
sections = [ sections = [
RawExpr(base), RawExpr(base),
@@ -500,7 +500,7 @@ class Update(LimitMixin, WhereMixin, ExtraMixin, TableQuery[QueryResult]):
set_expr, set_values = RawExpr.join(*self._set, joiner=sql.SQL(', ')).as_tuple() set_expr, set_values = RawExpr.join(*self._set, joiner=sql.SQL(', ')).as_tuple()
base = sql.SQL("UPDATE {table} SET {set}").format( base = sql.SQL("UPDATE {table} SET {set}").format(
table=sql.Identifier(self.table), table=self.tableid,
set=set_expr set=set_expr
) )
sections = [ sections = [

View File

@@ -1,5 +1,6 @@
from typing import Optional from typing import Optional
from psycopg.rows import DictRow from psycopg.rows import DictRow
from psycopg import sql
from . import queries as q from . import queries as q
from .connector import Connector from .connector import Connector
@@ -12,10 +13,18 @@ class Table:
Contains standard methods to access the table. Contains standard methods to access the table.
""" """
def __init__(self, name, *args, **kwargs): def __init__(self, name, *args, schema='public', **kwargs):
self.name: str = name self.name: str = name
self.schema: str = schema
self.connector: Connector = None self.connector: Connector = None
@property
def identifier(self):
if self.schema == 'public':
return sql.Identifier(self.name)
else:
return sql.Identifier(self.schema, self.name)
def bind(self, connector: Connector): def bind(self, connector: Connector):
self.connector = connector self.connector = connector
return self return self
@@ -38,49 +47,49 @@ class Table:
def select_where(self, *args, **kwargs) -> q.Select[tuple[DictRow, ...]]: def select_where(self, *args, **kwargs) -> q.Select[tuple[DictRow, ...]]:
return q.Select( return q.Select(
self.name, self.identifier,
row_adapter=self._many_query_adapter, row_adapter=self._many_query_adapter,
connector=self.connector connector=self.connector
).where(*args, **kwargs) ).where(*args, **kwargs)
def select_one_where(self, *args, **kwargs) -> q.Select[DictRow]: def select_one_where(self, *args, **kwargs) -> q.Select[DictRow]:
return q.Select( return q.Select(
self.name, self.identifier,
row_adapter=self._single_query_adapter, row_adapter=self._single_query_adapter,
connector=self.connector connector=self.connector
).where(*args, **kwargs) ).where(*args, **kwargs)
def update_where(self, *args, **kwargs) -> q.Update[tuple[DictRow, ...]]: def update_where(self, *args, **kwargs) -> q.Update[tuple[DictRow, ...]]:
return q.Update( return q.Update(
self.name, self.identifier,
row_adapter=self._many_query_adapter, row_adapter=self._many_query_adapter,
connector=self.connector connector=self.connector
).where(*args, **kwargs) ).where(*args, **kwargs)
def delete_where(self, *args, **kwargs) -> q.Delete[tuple[DictRow, ...]]: def delete_where(self, *args, **kwargs) -> q.Delete[tuple[DictRow, ...]]:
return q.Delete( return q.Delete(
self.name, self.identifier,
row_adapter=self._many_query_adapter, row_adapter=self._many_query_adapter,
connector=self.connector connector=self.connector
).where(*args, **kwargs) ).where(*args, **kwargs)
def insert(self, **column_values) -> q.Insert[DictRow]: def insert(self, **column_values) -> q.Insert[DictRow]:
return q.Insert( return q.Insert(
self.name, self.identifier,
row_adapter=self._single_query_adapter, row_adapter=self._single_query_adapter,
connector=self.connector connector=self.connector
).insert(column_values.keys(), column_values.values()) ).insert(column_values.keys(), column_values.values())
def insert_many(self, *args, **kwargs) -> q.Insert[tuple[DictRow, ...]]: def insert_many(self, *args, **kwargs) -> q.Insert[tuple[DictRow, ...]]:
return q.Insert( return q.Insert(
self.name, self.identifier,
row_adapter=self._many_query_adapter, row_adapter=self._many_query_adapter,
connector=self.connector connector=self.connector
).insert(*args, **kwargs) ).insert(*args, **kwargs)
# def update_many(self, *args, **kwargs): # def update_many(self, *args, **kwargs):
# with self.conn: # with self.conn:
# return update_many(self.name, *args, **kwargs) # return update_many(self.identifier, *args, **kwargs)
# def upsert(self, *args, **kwargs): # def upsert(self, *args, **kwargs):
# return upsert(self.name, *args, **kwargs) # return upsert(self.identifier, *args, **kwargs)