rewrite: New Scheduled Session System.
This commit is contained in:
162
src/utils/data.py
Normal file
162
src/utils/data.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
Some useful pre-built Conditions for data queries.
|
||||
"""
|
||||
from typing import Optional
|
||||
from itertools import chain
|
||||
|
||||
from psycopg import sql
|
||||
from data.conditions import Condition, Joiner
|
||||
from data.columns import ColumnExpr
|
||||
from data.base import Expression
|
||||
from constants import MAX_COINS
|
||||
|
||||
|
||||
def MULTIVALUE_IN(columns: tuple[str, ...], *data: tuple[...]) -> Condition:
|
||||
"""
|
||||
Condition constructor for filtering by multiple column equalities.
|
||||
|
||||
Example Usage
|
||||
-------------
|
||||
Query.where(MULTIVALUE_IN(('guildid', 'userid'), (1, 2), (3, 4)))
|
||||
"""
|
||||
if not data:
|
||||
raise ValueError("Cannot create empty multivalue condition.")
|
||||
left = sql.SQL("({})").format(
|
||||
sql.SQL(', ').join(
|
||||
sql.Identifier(key)
|
||||
for key in columns
|
||||
)
|
||||
)
|
||||
right_item = sql.SQL('({})').format(
|
||||
sql.SQL(', ').join(
|
||||
sql.Placeholder()
|
||||
for _ in columns
|
||||
)
|
||||
)
|
||||
right = sql.SQL("({})").format(
|
||||
sql.SQL(', ').join(
|
||||
right_item
|
||||
for _ in data
|
||||
)
|
||||
)
|
||||
return Condition(
|
||||
left,
|
||||
Joiner.IN,
|
||||
right,
|
||||
chain(*data)
|
||||
)
|
||||
|
||||
|
||||
def MEMBERS(*memberids: tuple[int, int], guild_column='guildid', user_column='userid') -> Condition:
|
||||
"""
|
||||
Condition constructor for filtering member tables by guild and user id simultaneously.
|
||||
|
||||
Example Usage
|
||||
-------------
|
||||
Query.where(MEMBERS((1234,12), (5678,34)))
|
||||
"""
|
||||
if not memberids:
|
||||
raise ValueError("Cannot create a condition with no members")
|
||||
return Condition(
|
||||
sql.SQL("({guildid}, {userid})").format(
|
||||
guildid=sql.Identifier(guild_column),
|
||||
userid=sql.Identifier(user_column)
|
||||
),
|
||||
Joiner.IN,
|
||||
sql.SQL("({})").format(
|
||||
sql.SQL(', ').join(
|
||||
sql.SQL("({}, {})").format(
|
||||
sql.Placeholder(),
|
||||
sql.Placeholder()
|
||||
) for _ in memberids
|
||||
)
|
||||
),
|
||||
chain(*memberids)
|
||||
)
|
||||
|
||||
|
||||
def as_duration(expr: Expression) -> ColumnExpr:
|
||||
"""
|
||||
Convert an integer expression into a duration expression.
|
||||
"""
|
||||
expr_expr, expr_values = expr.as_tuple()
|
||||
return ColumnExpr(
|
||||
sql.SQL("({} * interval '1 second')").format(expr_expr),
|
||||
expr_values
|
||||
)
|
||||
|
||||
|
||||
class TemporaryTable(Expression):
|
||||
"""
|
||||
Create a temporary table expression to be used in From or With clauses.
|
||||
|
||||
Example
|
||||
-------
|
||||
```
|
||||
tmp_table = TemporaryTable('_col1', '_col2', name='data')
|
||||
tmp_table.values((1, 2), (3, 4))
|
||||
|
||||
real_table.update_where(col1=tmp_table['_col1']).set(col2=tmp_table['_col2']).from_(tmp_table)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, *columns: str, name: str = '_t', types: Optional[tuple[str]] = None):
|
||||
self.name = name
|
||||
self.columns = columns
|
||||
self.types = types
|
||||
if types and len(types) != len(columns):
|
||||
raise ValueError("Number of types does not much number of columns!")
|
||||
|
||||
self._table_columns = {
|
||||
col: ColumnExpr(sql.Identifier(name, col))
|
||||
for col in columns
|
||||
}
|
||||
|
||||
self.values = []
|
||||
|
||||
def __getitem__(self, key) -> sql.Identifier:
|
||||
return self._table_columns[key]
|
||||
|
||||
def as_tuple(self):
|
||||
"""
|
||||
(VALUES {})
|
||||
AS
|
||||
name (col1, col2)
|
||||
"""
|
||||
single_value = sql.SQL("({})").format(sql.SQL(", ").join(sql.Placeholder() for _ in self.columns))
|
||||
if self.types:
|
||||
first_value = sql.SQL("({})").format(
|
||||
sql.SQL(", ").join(
|
||||
sql.SQL("{}::{}").format(sql.Placeholder(), sql.SQL(cast))
|
||||
for cast in self.types
|
||||
)
|
||||
)
|
||||
else:
|
||||
first_value = single_value
|
||||
|
||||
value_placeholder = sql.SQL("(VALUES {})").format(
|
||||
sql.SQL(", ").join(
|
||||
(first_value, *(single_value for _ in self.values[1:]))
|
||||
)
|
||||
)
|
||||
expr = sql.SQL("{values} AS {name} ({columns})").format(
|
||||
values=value_placeholder,
|
||||
name=sql.Identifier(self.name),
|
||||
columns=sql.SQL(", ").join(sql.Identifier(col) for col in self.columns)
|
||||
)
|
||||
values = chain(*self.values)
|
||||
return (expr, values)
|
||||
|
||||
def set_values(self, *data):
|
||||
self.values = data
|
||||
|
||||
|
||||
def SAFECOINS(expr: Expression) -> Expression:
|
||||
expr_expr, expr_values = expr.as_tuple()
|
||||
return ColumnExpr(
|
||||
sql.SQL("LEAST({}, {})").format(
|
||||
expr_expr,
|
||||
sql.Literal(MAX_COINS)
|
||||
),
|
||||
expr_values
|
||||
)
|
||||
@@ -85,6 +85,7 @@ class Bucket:
|
||||
# Wrapped in a lock so that waiters are correctly handled in wait-order
|
||||
# Otherwise multiple waiters will have the same delay,
|
||||
# and race for the wakeup after sleep.
|
||||
# Also avoids short-circuiting in the 0 delay case, which would not correctly handle wait-order
|
||||
async with self._wait_lock:
|
||||
# We do this in a loop in case asyncio.sleep throws us out early,
|
||||
# or a synchronous request overflows the bucket while we are waiting.
|
||||
|
||||
@@ -52,6 +52,10 @@ class ConfigUI(LeoUI):
|
||||
# Instances of the settings this UI is managing
|
||||
self.instances = ()
|
||||
|
||||
@property
|
||||
def page_instances(self):
|
||||
return self.instances
|
||||
|
||||
async def interaction_check(self, interaction: discord.Interaction):
|
||||
"""
|
||||
Default requirement for a Config UI is low management (i.e. manage_guild permissions).
|
||||
@@ -95,7 +99,7 @@ class ConfigUI(LeoUI):
|
||||
Errors should raise instances of `UserInputError`, and will be caught for retry.
|
||||
"""
|
||||
t = ctx_translator.get().t
|
||||
instances = self.instances
|
||||
instances = self.page_instances
|
||||
items = [setting.input_field for setting in instances]
|
||||
# Filter out settings which don't have input fields
|
||||
items = [item for item in items if item]
|
||||
@@ -174,7 +178,7 @@ class ConfigUI(LeoUI):
|
||||
"""
|
||||
await press.response.defer()
|
||||
|
||||
for instance in self.instances:
|
||||
for instance in self.page_instances:
|
||||
instance.data = None
|
||||
await instance.write()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user