645 lines
20 KiB
Python
645 lines
20 KiB
Python
from typing import Optional, TypeVar, Any, Callable, Generic, List, Union
|
|
from enum import Enum
|
|
from itertools import chain
|
|
from psycopg import AsyncConnection, AsyncCursor
|
|
from psycopg import sql
|
|
from psycopg.rows import DictRow
|
|
|
|
import logging
|
|
|
|
from .conditions import Condition
|
|
from .base import Expression, RawExpr
|
|
from .connector import Connector
|
|
|
|
|
|
# Logger for this module.
logger = logging.getLogger(__name__)

# Type variables bound (by forward reference) to the query classes defined below.
TQueryT = TypeVar("TQueryT", bound="TableQuery")
SQueryT = TypeVar("SQueryT", bound="Select")

# Type of the value produced by executing a Query (i.e. the row adapter's output).
QueryResult = TypeVar("QueryResult")
|
|
|
|
|
|
class Query(Generic[QueryResult]):
    """
    ABC for an executable query statement.

    Subclasses implement `build` to produce the statement as an `Expression`.
    Executing the query fetches every row, passes the rows (unpacked as
    positional arguments) to the row adapter, and stores/returns the adapter's
    output. Awaiting the query directly is equivalent to `await query.execute()`.
    """
    __slots__ = ('conn', 'cursor', '_adapter', 'connector', 'result')

    # Applied to the fetched rows (as *rows) to produce the query result.
    _adapter: Callable[..., QueryResult]

    def __init__(self, *args, row_adapter=None, connector=None, conn=None, cursor=None, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored here;
        # subclass constructors forward their *args/**kwargs up to this class.
        self.connector: Optional[Connector] = connector
        self.conn: Optional[AsyncConnection] = conn
        self.cursor: Optional[AsyncCursor] = cursor

        if row_adapter is not None:
            self._adapter = row_adapter
        else:
            self._adapter = self._no_adapter

        # Last value returned by execute(); None until the query has run.
        self.result: Optional[QueryResult] = None

    def bind(self, connector: Connector):
        """Set the connector used to acquire a connection at execution time."""
        self.connector = connector
        return self

    def with_cursor(self, cursor: AsyncCursor):
        """Set a default cursor to execute on."""
        self.cursor = cursor
        return self

    def with_connection(self, conn: AsyncConnection):
        """Set a connection from which a cursor is opened at execution time."""
        self.conn = conn
        return self

    def _no_adapter(self, *data: DictRow) -> tuple[DictRow, ...]:
        """Identity adapter: return the fetched rows unchanged, as a tuple."""
        return data

    def with_adapter(self, callable: Callable[..., QueryResult]):
        """
        Set the callable applied to the fetched rows (passed as *rows).

        The parameter name shadows the builtin `callable`; it is kept for
        backwards compatibility with keyword callers.
        """
        # NOTE: Postcomposition functor, Query[QR2] = (QR1 -> QR2) o Query[QR1]
        # For this to work cleanly, callable should have arg type of QR1, not any
        self._adapter = callable
        return self

    def with_no_adapter(self):
        """
        Sets the adapter to the identity.
        """
        self._adapter = self._no_adapter
        return self

    def one(self):
        # TODO: Postcomposition with item functor, Query[List[QR1]] -> Query[QR1]
        return self

    def build(self) -> Expression:
        """Return the full statement as an Expression. Must be overridden."""
        raise NotImplementedError

    async def _execute(self, cursor: AsyncCursor) -> QueryResult:
        """Run the built statement on `cursor`, fetch all rows and adapt them."""
        query, values = self.build().as_tuple()
        # TODO: Move logging out to a custom cursor
        # logger.debug(
        #     f"Executing query ({query.as_string(cursor)}) with values {values}",
        #     extra={'action': "Query"}
        # )
        await cursor.execute(sql.Composed((query,)), values)
        data = await cursor.fetchall()
        self.result = self._adapter(*data)
        return self.result

    async def execute(self, cursor=None) -> QueryResult:
        """
        Execute the query and return the adapted result.

        The cursor is resolved in this order:
          1. the `cursor` argument, if given;
          2. the cursor set with `with_cursor`;
          3. a new cursor opened on the connection set with `with_connection`;
          4. a new connection from the bound `connector`, with a new cursor on it.
        Cursors/connections opened here are managed by `async with` and are
        released when execution finishes.

        Raises
        ------
        ValueError
            If no cursor, connection, or connector is available.
        """
        # Previously the fallback test used `self.cursor`, which silently
        # ignored an explicitly passed cursor when none had been preset.
        cursor = cursor if cursor is not None else self.cursor
        if cursor is not None:
            return await self._execute(cursor)

        if self.conn is not None:
            async with self.conn.cursor() as new_cursor:
                return await self._execute(new_cursor)

        if self.connector is not None:
            async with self.connector.connection() as conn:
                async with conn.cursor() as new_cursor:
                    return await self._execute(new_cursor)

        raise ValueError("Cannot execute query without cursor, connection, or connector.")

    def __await__(self):
        return self.execute().__await__()
|
|
|
|
|
|
class TableQuery(Query[QueryResult]):
    """
    Base class for executable statements that operate on a single table.

    The slots used by the clause mixins (where, extra, limit, order, joins,
    from, group) are declared here, so those mixins can declare empty
    `__slots__` themselves.
    """
    __slots__ = (
        'tableid',
        'condition',
        '_extra',
        '_limit',
        '_order',
        '_joins',
        '_from',
        '_group',
    )

    def __init__(self, tableid, *args, **kwargs):
        # Every remaining argument is forwarded unchanged to Query.__init__.
        super().__init__(*args, **kwargs)
        # Table identifier interpolated into the generated statement.
        self.tableid: sql.Identifier = tableid

    def options(self, **kwargs):
        """
        Hook for setting query-specific options; returns `self` for chaining.

        This base version ignores every option it is given.
        Subclasses should override it to support concrete options.
        """
        return self
|
|
|
|
|
|
class WhereMixin(TableQuery[QueryResult]):
    """Adds a stackable WHERE clause to a table query."""
    __slots__ = ()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Accumulated condition; None means "no WHERE clause".
        self.condition: Optional[Condition] = None

    def where(self, *args: Condition, **kwargs):
        """
        Add a Condition to the query.

        Positional arguments should be Conditions; keyword arguments take the
        form `column=Value`, where Value may be a Value-type or a literal.
        Everything given is combined with `Condition.construct`, and the result
        is and-ed onto any condition already present. Calling with no
        arguments leaves the query unchanged.

        TODO: Maybe just pass this verbatim to a condition.
        """
        if not (args or kwargs):
            return self

        added = Condition.construct(*args, **kwargs)
        if self.condition is None:
            self.condition = added
        else:
            self.condition = self.condition & added
        return self

    @property
    def _where_section(self) -> Optional[Expression]:
        """The `WHERE ...` fragment, or None when no condition was set."""
        if self.condition is None:
            return None
        return RawExpr.join_tuples((sql.SQL('WHERE'), ()), self.condition.as_tuple())
|
|
|
|
|
|
class JOINTYPE(Enum):
    """Keyword fragment emitted for each supported join kind."""
    LEFT = sql.SQL("LEFT JOIN")
    RIGHT = sql.SQL("RIGHT JOIN")
    INNER = sql.SQL("INNER JOIN")
    # NOTE(review): PostgreSQL's join grammar only allows OUTER after LEFT,
    # RIGHT or FULL, so a bare "OUTER JOIN" is likely a syntax error there.
    # Confirm before relying on this member; FULLOUTER may be what is intended.
    OUTER = sql.SQL("OUTER JOIN")
    FULLOUTER = sql.SQL("FULL OUTER JOIN")
|
|
|
|
|
|
class JoinMixin(TableQuery[QueryResult]):
    """Adds stackable JOIN clauses to a table query."""
    __slots__ = ()
    # The `_joins` slot used below is declared on TableQuery.

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One Expression per join clause, in the order they were added.
        self._joins: list[Expression] = []

    def join(self,
             target: Union[str, Expression],
             on: Optional[Condition] = None, using: Optional[Union[Expression, tuple[str, ...]]] = None,
             join_type: JOINTYPE = JOINTYPE.INNER,
             natural=False):
        """
        Append a join against `target`.

        `target` is a table name (quoted as an identifier) or an Expression.
        Exactly one join format must be supplied:
          - `on`: a Condition, emitted as `ON <condition>`;
          - `using`: an Expression, or a non-empty tuple of column names
            emitted as `USING (col1,col2,...)`;
          - `natural=True`: prefixes the clause with `NATURAL`.

        Raises
        ------
        ValueError
            If no format, more than one format, or an unrecognised `using`
            value is given.
        """
        formats_given = (on is not None) + (using is not None) + natural
        if formats_given == 0:
            raise ValueError("No conditions given for Query Join")
        if formats_given > 1:
            raise ValueError("Exactly one join format must be given for Query Join")

        if isinstance(target, str):
            target_part = (sql.Identifier(target), ())
        else:
            target_part = target.as_tuple()
        clause: list[tuple[sql.Composable, tuple[Any, ...]]] = [(join_type.value, ()), target_part]

        if on is not None:
            clause.extend([(sql.SQL('ON'), ()), on.as_tuple()])
        elif using is not None:
            if isinstance(using, Expression):
                using_part = using.as_tuple()
            elif isinstance(using, tuple) and using and isinstance(using[0], str):
                column_list = sql.SQL(',').join(sql.Identifier(column) for column in using)
                using_part = (sql.SQL("({})").format(column_list), ())
            else:
                raise ValueError("Unrecognised 'using' type.")
            clause.extend([(sql.SQL('USING'), ()), using_part])
        elif natural:
            clause = [(sql.SQL('NATURAL'), ())] + clause

        self._joins.append(RawExpr.join_tuples(*clause))
        return self

    def leftjoin(self, *args, **kwargs):
        """Shorthand for `join(..., join_type=JOINTYPE.LEFT)`."""
        return self.join(*args, join_type=JOINTYPE.LEFT, **kwargs)

    @property
    def _join_section(self) -> Optional[Expression]:
        """All join clauses joined together, or None if there are none."""
        if not self._joins:
            return None
        return RawExpr.join(*self._joins)
|
|
|
|
|
|
class ExtraMixin(TableQuery[QueryResult]):
    """Allows arbitrary extra SQL (with optional values) to be spliced into a query."""
    __slots__ = ()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Accumulated extra SQL; None when nothing has been added.
        self._extra: Optional[Expression] = None

    def extra(self, extra: sql.Composable, values: tuple[Any, ...] = ()):
        """
        Add an extra string, and optionally values, to this query.

        Repeated calls append to previously added extras. Where the extra ends
        up in the statement is decided by each query's `build` (for example,
        Select places it after GROUP BY and before ORDER BY / LIMIT).
        """
        fragment = RawExpr(extra, values)
        self._extra = fragment if self._extra is None else RawExpr.join(self._extra, fragment)
        return self

    @property
    def _extra_section(self) -> Optional[Expression]:
        """The accumulated extra SQL, or None if none was added."""
        return self._extra
|
|
|
|
|
|
class LimitMixin(TableQuery[QueryResult]):
    """Adds a parameterised LIMIT clause to a table query."""
    __slots__ = ()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Row limit; None means no LIMIT clause is emitted.
        self._limit: Optional[int] = None

    def limit(self, limit: int):
        """
        Add a limit to this query (replacing any previous limit).
        """
        self._limit = limit
        return self

    @property
    def _limit_section(self) -> Optional[Expression]:
        """`LIMIT %s` with the limit passed as a bound value, or None."""
        if self._limit is None:
            return None
        return RawExpr(sql.SQL("LIMIT {}").format(sql.Placeholder()), (self._limit,))
|
|
|
|
|
|
class FromMixin(TableQuery[QueryResult]):
    """Adds a FROM clause built from an arbitrary Expression."""
    __slots__ = ()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Expression placed after FROM; None means no FROM clause.
        self._from: Optional[Expression] = None

    def from_expr(self, _from: Expression):
        """Set (replacing any previous value) the expression used in the FROM clause."""
        self._from = _from
        return self

    @property
    def _from_section(self) -> Optional[Expression]:
        """The `FROM <expr>` fragment carrying the expression's values, or None."""
        if self._from is None:
            return None
        from_sql, from_values = self._from.as_tuple()
        return RawExpr(sql.SQL("FROM {}").format(from_sql), from_values)
|
|
|
|
|
|
class ORDER(Enum):
    """Sort direction keywords for ORDER BY items."""
    ASC = sql.SQL("ASC")
    DESC = sql.SQL("DESC")
|
|
|
|
|
|
class NULLS(Enum):
    """NULL placement keywords for ORDER BY items."""
    FIRST = sql.SQL("NULLS FIRST")
    LAST = sql.SQL("NULLS LAST")
|
|
|
|
|
|
class OrderMixin(TableQuery[QueryResult]):
    """Adds a stackable ORDER BY clause to a table query."""
    __slots__ = ()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One Expression per sort item, in the order they were added.
        self._order: list[Expression] = []

    def order_by(self, expr: Union[Expression, str], direction: Optional[ORDER] = None, nulls: Optional[NULLS] = None):
        """
        Add a single sort expression to the query.

        `expr` is a column name (quoted as an identifier) or an Expression.
        `direction` and `nulls`, when given, are appended after it.
        This method stacks: repeated calls add further sort items.
        """
        if isinstance(expr, Expression):
            string, values = expr.as_tuple()
        else:
            string, values = sql.Identifier(expr), ()

        parts = [string]
        if direction is not None:
            parts.append(direction.value)
        if nulls is not None:
            parts.append(nulls.value)

        self._order.append(RawExpr(sql.SQL(' ').join(parts), values))
        return self

    @property
    def _order_section(self) -> Optional[Expression]:
        """`ORDER BY item, item, ...` with the items' values, or None."""
        if not self._order:
            return None
        # Build a fresh expression instead of mutating the object returned by
        # RawExpr.join, which may share state with the stored sort items
        # (mutating it could make repeated build() calls stack "ORDER BY").
        order_sql, order_values = RawExpr.join(*self._order, joiner=sql.SQL(', ')).as_tuple()
        return RawExpr(sql.SQL("ORDER BY {}").format(order_sql), order_values)
|
|
|
|
|
|
class GroupMixin(TableQuery[QueryResult]):
    """Adds a stackable GROUP BY clause to a table query."""
    __slots__ = ()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One Expression per grouping item, in the order they were added.
        self._group: list[Expression] = []

    def group_by(self, *exprs: Union[Expression, str]):
        """
        Add one or more grouping expressions to the query.

        Strings are treated as column names (quoted as identifiers);
        Expressions are used as-is. This method stacks.
        """
        for expr in exprs:
            if isinstance(expr, Expression):
                self._group.append(expr)
            else:
                self._group.append(RawExpr(sql.Identifier(expr)))
        return self

    @property
    def _group_section(self) -> Optional[Expression]:
        """`GROUP BY item, item, ...` with the items' values, or None."""
        if not self._group:
            return None
        # Build a fresh expression instead of mutating the object returned by
        # RawExpr.join, which may share state with the stored group items.
        group_sql, group_values = RawExpr.join(*self._group, joiner=sql.SQL(', ')).as_tuple()
        return RawExpr(sql.SQL("GROUP BY {}").format(group_sql), group_values)
|
|
|
|
|
|
class Insert(ExtraMixin, TableQuery[QueryResult]):
    """
    Query type representing a table insert query.

    Builds `INSERT INTO <table> (cols) VALUES (...),(...) [ON CONFLICT ...]
    [<extra>] RETURNING *`, with every row value passed as a bound parameter.
    """
    # TODO: Support the remaining ON CONFLICT forms (conflict targets, DO UPDATE) for upserts
    __slots__ = ('_columns', '_values', '_conflict')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._columns: tuple[str, ...] = ()
        self._values: tuple[tuple[Any, ...], ...] = ()
        # Action emitted after "ON CONFLICT"; None means no ON CONFLICT clause.
        self._conflict: Optional[Expression] = None

    def insert(self, columns, *values):
        """
        Insert the given data (replacing any previously set data).

        Parameters
        ----------
        columns: tuple[str]
            Tuple of column names to insert.

        *values: tuple[Any, ...]
            One positional argument per row; each row must have exactly one
            value per column, in column order.

        Raises
        ------
        ValueError
            If no rows or no columns are given, or if any row's length differs
            from the number of columns.
        """
        if not values:
            raise ValueError("Cannot insert zero rows.")
        if not columns:
            # "INSERT INTO t () VALUES ()" is not valid SQL.
            raise ValueError("Cannot insert zero columns.")
        # Check every row, not just the first: a ragged row would otherwise
        # yield a placeholder/value count mismatch at execution time.
        if any(len(row) != len(columns) for row in values):
            raise ValueError("Number of columns does not match length of values.")

        self._columns = columns
        self._values = values
        return self

    def on_conflict(self, ignore=False):
        """Set the conflict action; `ignore=True` emits `ON CONFLICT DO NOTHING`."""
        # TODO lots more we can do here
        # Maybe return a Conflict object that can chain itself (not the query)
        if ignore:
            self._conflict = RawExpr(sql.SQL('DO NOTHING'))
        return self

    @property
    def _conflict_section(self) -> Optional[Expression]:
        """`ON CONFLICT <action>` with the action's values, or None."""
        if self._conflict is None:
            return None
        conflict_sql, conflict_values = self._conflict.as_tuple()
        return RawExpr(sql.SQL("ON CONFLICT {}").format(conflict_sql), conflict_values)

    def build(self):
        """
        Build the INSERT statement.

        Raises
        ------
        ValueError
            If `insert` has not been called with any rows.
        """
        if not self._values:
            raise ValueError("No values provided to insert.")

        columns = sql.SQL(',').join(map(sql.Identifier, self._columns))
        # One "(%s,%s,...)" group per row, each sized to the column count.
        single_value_str = sql.SQL('({})').format(
            sql.SQL(',').join(sql.Placeholder() * len(self._columns))
        )
        values_str = sql.SQL(',').join(single_value_str * len(self._values))

        # TODO: Check efficiency of inserting multiple values like this
        # Also implement a Copy query
        base = sql.SQL("INSERT INTO {table} ({columns}) VALUES {values_str}").format(
            table=self.tableid,
            columns=columns,
            values_str=values_str
        )

        sections = [
            # Row values flattened in row-major order to match the placeholders.
            RawExpr(base, tuple(chain.from_iterable(self._values))),
            self._conflict_section,
            self._extra_section,
            RawExpr(sql.SQL('RETURNING *'))
        ]
        return RawExpr.join(*(section for section in sections if section is not None))
|
|
|
|
|
|
class Select(WhereMixin, ExtraMixin, OrderMixin, LimitMixin, JoinMixin, GroupMixin, TableQuery[QueryResult]):
    """
    Select rows from a table matching provided conditions.

    Clause order: SELECT ... FROM <table> [joins] [WHERE] [GROUP BY] [<extra>]
    [ORDER BY] [LIMIT].
    """
    __slots__ = ('_columns',)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Selected column expressions; empty means `SELECT *`.
        self._columns: tuple[Expression, ...] = ()

    def select(self, *columns: str, **exprs: Union[str, sql.Composable, Expression]):
        """
        Add columns and named expressions to select.

        Positional arguments are column names (quoted as identifiers). Keyword
        arguments `name=expr` are emitted as `expr AS "name"`, where `expr` is:
          - a str, used verbatim as raw SQL (never pass untrusted input here);
          - an `sql.Composable`;
          - an `Expression`, whose values are carried along.
        This method stacks. If no columns are ever added, all columns (`*`)
        are selected.

        Raises
        ------
        TypeError
            If a keyword expression is of an unsupported type (previously such
            expressions were silently dropped).
        """
        cols: List[Expression] = []
        cols.extend(map(RawExpr, map(sql.Identifier, columns)))

        for name, expr in exprs.items():
            alias = sql.SQL(' AS ') + sql.Identifier(name)
            if isinstance(expr, str):
                cols.append(RawExpr(sql.SQL(expr) + alias))
            elif isinstance(expr, sql.Composable):
                cols.append(RawExpr(expr + alias))
            elif isinstance(expr, Expression):
                value_expr, value_values = expr.as_tuple()
                cols.append(RawExpr(value_expr + alias, value_values))
            else:
                raise TypeError(
                    f"Unsupported select expression for {name!r}: {type(expr).__name__}"
                )

        if cols:
            self._columns = (*self._columns, *cols)
        return self

    def build(self):
        """Build the SELECT statement from the configured clauses."""
        if not self._columns:
            columns, columns_values = sql.SQL('*'), ()
        else:
            columns, columns_values = RawExpr.join(*self._columns, joiner=sql.SQL(',')).as_tuple()

        base = sql.SQL("SELECT {columns} FROM {table}").format(
            columns=columns,
            table=self.tableid
        )

        sections = [
            RawExpr(base, columns_values),
            self._join_section,
            self._where_section,
            self._group_section,
            self._extra_section,
            self._order_section,
            self._limit_section,
        ]
        return RawExpr.join(*(section for section in sections if section is not None))
|
|
|
|
|
|
class Delete(WhereMixin, ExtraMixin, TableQuery[QueryResult]):
    """
    Query type representing a table delete query.

    Builds `DELETE FROM <table> [WHERE ...] [<extra>] RETURNING *`.
    """
    # TODO: Cascade option for delete, maybe other options
    # TODO: Require a where unless specifically disabled, for safety

    def build(self):
        """Build the DELETE statement; without a condition every row is deleted."""
        statement = [RawExpr(sql.SQL("DELETE FROM {table}").format(table=self.tableid))]
        for optional_section in (self._where_section, self._extra_section):
            if optional_section is not None:
                statement.append(optional_section)
        statement.append(RawExpr(sql.SQL('RETURNING *')))
        return RawExpr.join(*statement)
|
|
|
|
|
|
class Update(LimitMixin, WhereMixin, ExtraMixin, FromMixin, TableQuery[QueryResult]):
    """
    Query type representing a table update query.

    Builds `UPDATE <table> SET col = ..., ... [FROM ...] [WHERE ...] [<extra>]
    [LIMIT ...] RETURNING *`.
    """
    __slots__ = ('_set',)
    # TODO: Again, require a where unless specifically disabled
    # NOTE(review): PostgreSQL's UPDATE grammar has no LIMIT clause, so using
    # `.limit()` on an Update likely yields invalid SQL there — confirm the
    # target database supports it.

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One `column = value` Expression per assignment, in insertion order.
        self._set: List[Expression] = []

    def set(self, **column_values: Union[Any, Expression]):
        """
        Add `column = value` assignments (stacks across calls).

        Expression values are embedded with their own values; any other value
        is passed as a bound parameter.
        """
        assignments = [
            RawExpr.join_tuples(
                (sql.Identifier(column), ()),
                value.as_tuple() if isinstance(value, Expression) else (sql.Placeholder(), (value,)),
                joiner=sql.SQL(' = '),
            )
            for column, value in column_values.items()
        ]
        self._set.extend(assignments)
        return self

    def build(self):
        """
        Build the UPDATE statement.

        Raises
        ------
        ValueError
            If no assignments were added with `set`.
        """
        if not self._set:
            raise ValueError("No columns provided to update.")
        assignments_sql, assignment_values = RawExpr.join(*self._set, joiner=sql.SQL(', ')).as_tuple()

        head = RawExpr(
            sql.SQL("UPDATE {table} SET {set}").format(table=self.tableid, set=assignments_sql),
            assignment_values,
        )
        optional_sections = (
            self._from_section,
            self._where_section,
            self._extra_section,
            self._limit_section,
        )
        return RawExpr.join(
            head,
            *(part for part in optional_sections if part is not None),
            RawExpr(sql.SQL('RETURNING *')),
        )
|
|
|
|
|
|
# async def upsert(cursor, table, constraint, **values):
|
|
# """
|
|
# Insert or on conflict update.
|
|
# """
|
|
# valuedict = values
|
|
# keys, values = zip(*values.items())
|
|
#
|
|
# key_str = _format_insertkeys(keys)
|
|
# value_str, values = _format_insertvalues(values)
|
|
# update_key_str, update_key_values = _format_updatestr(valuedict)
|
|
#
|
|
# if not isinstance(constraint, str):
|
|
# constraint = ", ".join(constraint)
|
|
#
|
|
# await cursor.execute(
|
|
# 'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format(
|
|
# table, key_str, value_str, constraint, update_key_str
|
|
# ),
|
|
# tuple((*values, *update_key_values))
|
|
# )
|
|
# return await cursor.fetchone()
|
|
|
|
|
|
# def update_many(table, *values, set_keys=None, where_keys=None, cast_row=None, cursor=None):
|
|
# cursor = cursor or conn.cursor()
|
|
#
|
|
# # TODO: executemany or copy syntax now
|
|
# return execute_values(
|
|
# cursor,
|
|
# """
|
|
# UPDATE {table}
|
|
# SET {set_clause}
|
|
# FROM (VALUES {cast_row}%s)
|
|
# AS {temp_table}
|
|
# WHERE {where_clause}
|
|
# RETURNING *
|
|
# """.format(
|
|
# table=table,
|
|
# set_clause=', '.join("{0} = _t.{0}".format(key) for key in set_keys),
|
|
# cast_row=cast_row + ',' if cast_row else '',
|
|
# where_clause=' AND '.join("{1}.{0} = _t.{0}".format(key, table) for key in where_keys),
|
|
# temp_table="_t ({})".format(', '.join(set_keys + where_keys))
|
|
# ),
|
|
# values,
|
|
# fetch=True
|
|
# )
|