From 01f967b0e605d3b3dcb926d8415a5b01dba72a16 Mon Sep 17 00:00:00 2001 From: Conatum Date: Wed, 30 Nov 2022 16:58:19 +0200 Subject: [PATCH] rewrite: Add `on_conflict` insert section. --- bot/data/queries.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/bot/data/queries.py b/bot/data/queries.py index 35db24be..aafd7c2e 100644 --- a/bot/data/queries.py +++ b/bot/data/queries.py @@ -62,6 +62,13 @@ class Query(Generic[QueryResult]): self._adapter = callable return self + def with_no_adapter(self): + """ + Sets the adapater to the identity. + """ + self._adapter = self._no_adapter + return self + def one(self): # TODO: Postcomposition with item functor, Query[List[QR1]] -> Query[QR1] return self @@ -329,12 +336,13 @@ class Insert(ExtraMixin, TableQuery[QueryResult]): Query type representing a table insert query. """ # TODO: Support ON CONFLICT for upserts - __slots__ = ('_columns', '_values') + __slots__ = ('_columns', '_values', '_conflict') def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._columns: tuple[str, ...] = () self._values: tuple[tuple[Any, ...], ...] = () + self._conflict: Optional[Expression] = None def insert(self, columns, *values): """ @@ -357,6 +365,26 @@ class Insert(ExtraMixin, TableQuery[QueryResult]): self._values = values return self + def on_conflict(self, ignore=False): + # TODO lots more we can do here + # Maybe return a Conflict object that can chain itself (not the query) + if ignore: + self._conflict = RawExpr(sql.SQL('DO NOTHING')) + return self + + @property + def _conflict_section(self) -> Optional[Expression]: + if self._conflict is not None: + e, v = self._conflict.as_tuple() + expr = RawExpr( + sql.SQL("ON CONFLICT {}").format( + e + ), + v + ) + return expr + return None + def build(self): columns = sql.SQL(',').join(map(sql.Identifier, self._columns)) single_value_str = sql.SQL('({})').format( @@ -374,6 +402,7 @@ class Insert(ExtraMixin, TableQuery[QueryResult]): sections = [ RawExpr(base, tuple(chain(*self._values))), + self._conflict_section, self._extra_section, RawExpr(sql.SQL('RETURNING *')) ]