rewrite: Add support for database schemas.
This commit is contained in:
@@ -4,6 +4,7 @@ from datetime import datetime
|
|||||||
|
|
||||||
from .base import RawExpr, Expression
|
from .base import RawExpr, Expression
|
||||||
from .conditions import Condition, Joiner
|
from .conditions import Condition, Joiner
|
||||||
|
from .table import Table
|
||||||
|
|
||||||
|
|
||||||
class ColumnExpr(RawExpr):
|
class ColumnExpr(RawExpr):
|
||||||
@@ -111,7 +112,6 @@ class Column(ColumnExpr, Generic[T]):
|
|||||||
self.references = references
|
self.references = references
|
||||||
self.name: str = name # type: ignore
|
self.name: str = name # type: ignore
|
||||||
self.owner: Optional['RowModel'] = None
|
self.owner: Optional['RowModel'] = None
|
||||||
self.tablename: Optional[str] = None
|
|
||||||
self._type = type
|
self._type = type
|
||||||
|
|
||||||
self.expr = sql.Identifier(name) if name else sql.SQL('')
|
self.expr = sql.Identifier(name) if name else sql.SQL('')
|
||||||
@@ -122,8 +122,7 @@ class Column(ColumnExpr, Generic[T]):
|
|||||||
if self.owner is None:
|
if self.owner is None:
|
||||||
self.name = self.name or name
|
self.name = self.name or name
|
||||||
self.owner = owner
|
self.owner = owner
|
||||||
self.tablename = owner._tablename_
|
self.expr = sql.Identifier(self.owner._schema_, self.owner._tablename_, self.name)
|
||||||
self.expr = sql.Identifier(self.tablename, self.name)
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __get__(self: 'Column[T]', obj: None, objtype: "None | Type['RowModel']") -> 'Column[T]':
|
def __get__(self: 'Column[T]', obj: None, objtype: "None | Type['RowModel']") -> 'Column[T]':
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ class RowTable(Table, Generic[RowT]):
|
|||||||
def fetch_rows_where(self, *args, **kwargs) -> q.Select[list[RowT]]:
|
def fetch_rows_where(self, *args, **kwargs) -> q.Select[list[RowT]]:
|
||||||
# TODO: Handle list of rowids here?
|
# TODO: Handle list of rowids here?
|
||||||
return q.Select(
|
return q.Select(
|
||||||
self.name,
|
self.identifier,
|
||||||
row_adapter=self.model._make_rows,
|
row_adapter=self.model._make_rows,
|
||||||
connector=self.connector
|
connector=self.connector
|
||||||
).where(*args, **kwargs)
|
).where(*args, **kwargs)
|
||||||
@@ -118,6 +118,7 @@ class WeakCache(MutableMapping):
|
|||||||
class RowModel:
|
class RowModel:
|
||||||
__slots__ = ('data',)
|
__slots__ = ('data',)
|
||||||
|
|
||||||
|
_schema_: str = 'public'
|
||||||
_tablename_: Optional[str] = None
|
_tablename_: Optional[str] = None
|
||||||
_columns_: dict[str, Column] = {}
|
_columns_: dict[str, Column] = {}
|
||||||
|
|
||||||
@@ -147,7 +148,7 @@ class RowModel:
|
|||||||
cls._columns_ = columns
|
cls._columns_ = columns
|
||||||
if not cls._key_:
|
if not cls._key_:
|
||||||
cls._key_ = tuple(column.name for column in columns.values() if column.primary)
|
cls._key_ = tuple(column.name for column in columns.values() if column.primary)
|
||||||
cls.table = RowTable(cls._tablename_, cls)
|
cls.table = RowTable(cls._tablename_, cls, schema=cls._schema_)
|
||||||
if cls._cache_ is None:
|
if cls._cache_ is None:
|
||||||
cls._cache_ = WeakValueDictionary()
|
cls._cache_ = WeakValueDictionary()
|
||||||
|
|
||||||
@@ -199,7 +200,8 @@ class RowModel:
|
|||||||
return tuple(self.data[key] for key in self._key_)
|
return tuple(self.data[key] for key in self._key_)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "{}({})".format(
|
return "{}.{}({})".format(
|
||||||
|
self.table.schema,
|
||||||
self.table.name,
|
self.table.name,
|
||||||
', '.join(repr(column.__get__(self)) for column in self._columns_.values())
|
', '.join(repr(column.__get__(self)) for column in self._columns_.values())
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -113,13 +113,13 @@ class TableQuery(Query[QueryResult]):
|
|||||||
ABC for an executable query statement expected to be run on a single table.
|
ABC for an executable query statement expected to be run on a single table.
|
||||||
"""
|
"""
|
||||||
__slots__ = (
|
__slots__ = (
|
||||||
'table',
|
'tableid',
|
||||||
'condition', '_extra', '_limit', '_order', '_joins'
|
'condition', '_extra', '_limit', '_order', '_joins'
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, table, *args, **kwargs):
|
def __init__(self, tableid, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.table: str = table
|
self.tableid: sql.Identifier = tableid
|
||||||
|
|
||||||
def options(self, **kwargs):
|
def options(self, **kwargs):
|
||||||
"""
|
"""
|
||||||
@@ -367,7 +367,7 @@ class Insert(ExtraMixin, TableQuery[QueryResult]):
|
|||||||
# TODO: Check efficiency of inserting multiple values like this
|
# TODO: Check efficiency of inserting multiple values like this
|
||||||
# Also implement a Copy query
|
# Also implement a Copy query
|
||||||
base = sql.SQL("INSERT INTO {table} ({columns}) VALUES {values_str}").format(
|
base = sql.SQL("INSERT INTO {table} ({columns}) VALUES {values_str}").format(
|
||||||
table=sql.Identifier(self.table),
|
table=self.tableid,
|
||||||
columns=columns,
|
columns=columns,
|
||||||
values_str=values_str
|
values_str=values_str
|
||||||
)
|
)
|
||||||
@@ -428,7 +428,7 @@ class Select(WhereMixin, ExtraMixin, OrderMixin, LimitMixin, JoinMixin, TableQue
|
|||||||
|
|
||||||
base = sql.SQL("SELECT {columns} FROM {table}").format(
|
base = sql.SQL("SELECT {columns} FROM {table}").format(
|
||||||
columns=columns,
|
columns=columns,
|
||||||
table=sql.Identifier(self.table)
|
table=self.tableid
|
||||||
)
|
)
|
||||||
|
|
||||||
sections = [
|
sections = [
|
||||||
@@ -453,7 +453,7 @@ class Delete(WhereMixin, ExtraMixin, TableQuery[QueryResult]):
|
|||||||
|
|
||||||
def build(self):
|
def build(self):
|
||||||
base = sql.SQL("DELETE FROM {table}").format(
|
base = sql.SQL("DELETE FROM {table}").format(
|
||||||
table=sql.Identifier(self.table),
|
table=self.tableid,
|
||||||
)
|
)
|
||||||
sections = [
|
sections = [
|
||||||
RawExpr(base),
|
RawExpr(base),
|
||||||
@@ -500,7 +500,7 @@ class Update(LimitMixin, WhereMixin, ExtraMixin, TableQuery[QueryResult]):
|
|||||||
set_expr, set_values = RawExpr.join(*self._set, joiner=sql.SQL(', ')).as_tuple()
|
set_expr, set_values = RawExpr.join(*self._set, joiner=sql.SQL(', ')).as_tuple()
|
||||||
|
|
||||||
base = sql.SQL("UPDATE {table} SET {set}").format(
|
base = sql.SQL("UPDATE {table} SET {set}").format(
|
||||||
table=sql.Identifier(self.table),
|
table=self.tableid,
|
||||||
set=set_expr
|
set=set_expr
|
||||||
)
|
)
|
||||||
sections = [
|
sections = [
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from psycopg.rows import DictRow
|
from psycopg.rows import DictRow
|
||||||
|
from psycopg import sql
|
||||||
|
|
||||||
from . import queries as q
|
from . import queries as q
|
||||||
from .connector import Connector
|
from .connector import Connector
|
||||||
@@ -12,10 +13,18 @@ class Table:
|
|||||||
Contains standard methods to access the table.
|
Contains standard methods to access the table.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, name, *args, **kwargs):
|
def __init__(self, name, *args, schema='public', **kwargs):
|
||||||
self.name: str = name
|
self.name: str = name
|
||||||
|
self.schema: str = schema
|
||||||
self.connector: Connector = None
|
self.connector: Connector = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def identifier(self):
|
||||||
|
if self.schema == 'public':
|
||||||
|
return sql.Identifier(self.name)
|
||||||
|
else:
|
||||||
|
return sql.Identifier(self.schema, self.name)
|
||||||
|
|
||||||
def bind(self, connector: Connector):
|
def bind(self, connector: Connector):
|
||||||
self.connector = connector
|
self.connector = connector
|
||||||
return self
|
return self
|
||||||
@@ -38,49 +47,49 @@ class Table:
|
|||||||
|
|
||||||
def select_where(self, *args, **kwargs) -> q.Select[tuple[DictRow, ...]]:
|
def select_where(self, *args, **kwargs) -> q.Select[tuple[DictRow, ...]]:
|
||||||
return q.Select(
|
return q.Select(
|
||||||
self.name,
|
self.identifier,
|
||||||
row_adapter=self._many_query_adapter,
|
row_adapter=self._many_query_adapter,
|
||||||
connector=self.connector
|
connector=self.connector
|
||||||
).where(*args, **kwargs)
|
).where(*args, **kwargs)
|
||||||
|
|
||||||
def select_one_where(self, *args, **kwargs) -> q.Select[DictRow]:
|
def select_one_where(self, *args, **kwargs) -> q.Select[DictRow]:
|
||||||
return q.Select(
|
return q.Select(
|
||||||
self.name,
|
self.identifier,
|
||||||
row_adapter=self._single_query_adapter,
|
row_adapter=self._single_query_adapter,
|
||||||
connector=self.connector
|
connector=self.connector
|
||||||
).where(*args, **kwargs)
|
).where(*args, **kwargs)
|
||||||
|
|
||||||
def update_where(self, *args, **kwargs) -> q.Update[tuple[DictRow, ...]]:
|
def update_where(self, *args, **kwargs) -> q.Update[tuple[DictRow, ...]]:
|
||||||
return q.Update(
|
return q.Update(
|
||||||
self.name,
|
self.identifier,
|
||||||
row_adapter=self._many_query_adapter,
|
row_adapter=self._many_query_adapter,
|
||||||
connector=self.connector
|
connector=self.connector
|
||||||
).where(*args, **kwargs)
|
).where(*args, **kwargs)
|
||||||
|
|
||||||
def delete_where(self, *args, **kwargs) -> q.Delete[tuple[DictRow, ...]]:
|
def delete_where(self, *args, **kwargs) -> q.Delete[tuple[DictRow, ...]]:
|
||||||
return q.Delete(
|
return q.Delete(
|
||||||
self.name,
|
self.identifier,
|
||||||
row_adapter=self._many_query_adapter,
|
row_adapter=self._many_query_adapter,
|
||||||
connector=self.connector
|
connector=self.connector
|
||||||
).where(*args, **kwargs)
|
).where(*args, **kwargs)
|
||||||
|
|
||||||
def insert(self, **column_values) -> q.Insert[DictRow]:
|
def insert(self, **column_values) -> q.Insert[DictRow]:
|
||||||
return q.Insert(
|
return q.Insert(
|
||||||
self.name,
|
self.identifier,
|
||||||
row_adapter=self._single_query_adapter,
|
row_adapter=self._single_query_adapter,
|
||||||
connector=self.connector
|
connector=self.connector
|
||||||
).insert(column_values.keys(), column_values.values())
|
).insert(column_values.keys(), column_values.values())
|
||||||
|
|
||||||
def insert_many(self, *args, **kwargs) -> q.Insert[tuple[DictRow, ...]]:
|
def insert_many(self, *args, **kwargs) -> q.Insert[tuple[DictRow, ...]]:
|
||||||
return q.Insert(
|
return q.Insert(
|
||||||
self.name,
|
self.identifier,
|
||||||
row_adapter=self._many_query_adapter,
|
row_adapter=self._many_query_adapter,
|
||||||
connector=self.connector
|
connector=self.connector
|
||||||
).insert(*args, **kwargs)
|
).insert(*args, **kwargs)
|
||||||
|
|
||||||
# def update_many(self, *args, **kwargs):
|
# def update_many(self, *args, **kwargs):
|
||||||
# with self.conn:
|
# with self.conn:
|
||||||
# return update_many(self.name, *args, **kwargs)
|
# return update_many(self.identifier, *args, **kwargs)
|
||||||
|
|
||||||
# def upsert(self, *args, **kwargs):
|
# def upsert(self, *args, **kwargs):
|
||||||
# return upsert(self.name, *args, **kwargs)
|
# return upsert(self.identifier, *args, **kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user