rewrite: Add on_conflict insert section.
This commit is contained in:
@@ -62,6 +62,13 @@ class Query(Generic[QueryResult]):
|
|||||||
self._adapter = callable
|
self._adapter = callable
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def with_no_adapter(self):
|
||||||
|
"""
|
||||||
|
Sets the adapater to the identity.
|
||||||
|
"""
|
||||||
|
self._adapter = self._no_adapter
|
||||||
|
return self
|
||||||
|
|
||||||
def one(self):
|
def one(self):
|
||||||
# TODO: Postcomposition with item functor, Query[List[QR1]] -> Query[QR1]
|
# TODO: Postcomposition with item functor, Query[List[QR1]] -> Query[QR1]
|
||||||
return self
|
return self
|
||||||
@@ -329,12 +336,13 @@ class Insert(ExtraMixin, TableQuery[QueryResult]):
|
|||||||
Query type representing a table insert query.
|
Query type representing a table insert query.
|
||||||
"""
|
"""
|
||||||
# TODO: Support ON CONFLICT for upserts
|
# TODO: Support ON CONFLICT for upserts
|
||||||
__slots__ = ('_columns', '_values')
|
__slots__ = ('_columns', '_values', '_conflict')
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._columns: tuple[str, ...] = ()
|
self._columns: tuple[str, ...] = ()
|
||||||
self._values: tuple[tuple[Any, ...], ...] = ()
|
self._values: tuple[tuple[Any, ...], ...] = ()
|
||||||
|
self._conflict: Optional[Expression] = None
|
||||||
|
|
||||||
def insert(self, columns, *values):
|
def insert(self, columns, *values):
|
||||||
"""
|
"""
|
||||||
@@ -357,6 +365,26 @@ class Insert(ExtraMixin, TableQuery[QueryResult]):
|
|||||||
self._values = values
|
self._values = values
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def on_conflict(self, ignore=False):
|
||||||
|
# TODO lots more we can do here
|
||||||
|
# Maybe return a Conflict object that can chain itself (not the query)
|
||||||
|
if ignore:
|
||||||
|
self._conflict = RawExpr(sql.SQL('DO NOTHING'))
|
||||||
|
return self
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _conflict_section(self) -> Optional[Expression]:
|
||||||
|
if self._conflict is not None:
|
||||||
|
e, v = self._conflict.as_tuple()
|
||||||
|
expr = RawExpr(
|
||||||
|
sql.SQL("ON CONFLICT {}").format(
|
||||||
|
e
|
||||||
|
),
|
||||||
|
v
|
||||||
|
)
|
||||||
|
return expr
|
||||||
|
return None
|
||||||
|
|
||||||
def build(self):
|
def build(self):
|
||||||
columns = sql.SQL(',').join(map(sql.Identifier, self._columns))
|
columns = sql.SQL(',').join(map(sql.Identifier, self._columns))
|
||||||
single_value_str = sql.SQL('({})').format(
|
single_value_str = sql.SQL('({})').format(
|
||||||
@@ -374,6 +402,7 @@ class Insert(ExtraMixin, TableQuery[QueryResult]):
|
|||||||
|
|
||||||
sections = [
|
sections = [
|
||||||
RawExpr(base, tuple(chain(*self._values))),
|
RawExpr(base, tuple(chain(*self._values))),
|
||||||
|
self._conflict_section,
|
||||||
self._extra_section,
|
self._extra_section,
|
||||||
RawExpr(sql.SQL('RETURNING *'))
|
RawExpr(sql.SQL('RETURNING *'))
|
||||||
]
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user