fix (data): Parallel connection pool.
This commit is contained in:
@@ -52,7 +52,7 @@ class EventHandler(Generic[T]):
|
|||||||
f"Queue on event handler {self.route_name} is full! Discarding event {data}"
|
f"Queue on event handler {self.route_name} is full! Discarding event {data}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@log_wrap(action='consumer', isolate=False)
|
@log_wrap(action='consumer')
|
||||||
async def consumer(self):
|
async def consumer(self):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@@ -76,7 +76,7 @@ class EventHandler(Generic[T]):
|
|||||||
)
|
)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@log_wrap(action='batch', isolate=False)
|
@log_wrap(action='batch')
|
||||||
async def process_batch(self):
|
async def process_batch(self):
|
||||||
logger.debug("Processing Batch")
|
logger.debug("Processing Batch")
|
||||||
# TODO: copy syntax might be more efficient here
|
# TODO: copy syntax might be more efficient here
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ class AnalyticsServer:
|
|||||||
log_action_stack.set(['Analytics'])
|
log_action_stack.set(['Analytics'])
|
||||||
log_app.set(conf.analytics['appname'])
|
log_app.set(conf.analytics['appname'])
|
||||||
|
|
||||||
async with await self.db.connect():
|
async with self.db.open():
|
||||||
await self.talk.connect()
|
await self.talk.connect()
|
||||||
await self.attach_event_handlers()
|
await self.attach_event_handlers()
|
||||||
self._snap_task = asyncio.create_task(self.snapshot_loop())
|
self._snap_task = asyncio.create_task(self.snapshot_loop())
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ async def main():
|
|||||||
intents.message_content = True
|
intents.message_content = True
|
||||||
intents.presences = False
|
intents.presences = False
|
||||||
|
|
||||||
async with await db.connect():
|
async with db.open():
|
||||||
version = await db.version()
|
version = await db.version()
|
||||||
if version.version != DATA_VERSION:
|
if version.version != DATA_VERSION:
|
||||||
error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate."
|
error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate."
|
||||||
|
|||||||
@@ -57,16 +57,14 @@ class CoreCog(LionCog):
|
|||||||
|
|
||||||
async def cog_load(self):
|
async def cog_load(self):
|
||||||
# Fetch (and possibly create) core data rows.
|
# Fetch (and possibly create) core data rows.
|
||||||
conn = await self.bot.db.get_connection()
|
self.app_config = await self.data.AppConfig.fetch_or_create(appname)
|
||||||
async with conn.transaction():
|
self.bot_config = await self.data.BotConfig.fetch_or_create(appname)
|
||||||
self.app_config = await self.data.AppConfig.fetch_or_create(appname)
|
self.shard_data = await self.data.Shard.fetch_or_create(
|
||||||
self.bot_config = await self.data.BotConfig.fetch_or_create(appname)
|
shardname,
|
||||||
self.shard_data = await self.data.Shard.fetch_or_create(
|
appname=appname,
|
||||||
shardname,
|
shard_id=self.bot.shard_id,
|
||||||
appname=appname,
|
shard_count=self.bot.shard_count
|
||||||
shard_id=self.bot.shard_id,
|
)
|
||||||
shard_count=self.bot.shard_count
|
|
||||||
)
|
|
||||||
self.bot.add_listener(self.shard_update_guilds, name='on_guild_join')
|
self.bot.add_listener(self.shard_update_guilds, name='on_guild_join')
|
||||||
self.bot.add_listener(self.shard_update_guilds, name='on_guild_remove')
|
self.bot.add_listener(self.shard_update_guilds, name='on_guild_remove')
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from cachetools import TTLCache
|
|||||||
import discord
|
import discord
|
||||||
|
|
||||||
from meta import conf
|
from meta import conf
|
||||||
|
from meta.logger import log_wrap
|
||||||
from data import Table, Registry, Column, RowModel, RegisterEnum
|
from data import Table, Registry, Column, RowModel, RegisterEnum
|
||||||
from data.models import WeakCache
|
from data.models import WeakCache
|
||||||
from data.columns import Integer, String, Bool, Timestamp
|
from data.columns import Integer, String, Bool, Timestamp
|
||||||
@@ -287,6 +288,7 @@ class CoreData(Registry, name="core"):
|
|||||||
_timestamp = Timestamp()
|
_timestamp = Timestamp()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action="Add Pending Coins")
|
||||||
async def add_pending(cls, pending: list[tuple[int, int, int]]) -> list['CoreData.Member']:
|
async def add_pending(cls, pending: list[tuple[int, int, int]]) -> list['CoreData.Member']:
|
||||||
"""
|
"""
|
||||||
Safely add pending coins to a list of members.
|
Safely add pending coins to a list of members.
|
||||||
@@ -316,39 +318,40 @@ class CoreData(Registry, name="core"):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
# TODO: Replace with copy syntax/query?
|
# TODO: Replace with copy syntax/query?
|
||||||
conn = await cls.table.connector.get_connection()
|
async with cls.table.connector.connection() as conn:
|
||||||
async with conn.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute(
|
await cursor.execute(
|
||||||
query,
|
query,
|
||||||
tuple(chain(*pending))
|
tuple(chain(*pending))
|
||||||
)
|
)
|
||||||
rows = await cursor.fetchall()
|
rows = await cursor.fetchall()
|
||||||
return cls._make_rows(*rows)
|
return cls._make_rows(*rows)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='get_member_rank')
|
||||||
async def get_member_rank(cls, guildid, userid, untracked):
|
async def get_member_rank(cls, guildid, userid, untracked):
|
||||||
"""
|
"""
|
||||||
Get the time and coin ranking for the given member, ignoring the provided untracked members.
|
Get the time and coin ranking for the given member, ignoring the provided untracked members.
|
||||||
"""
|
"""
|
||||||
conn = await cls.table.connector.get_connection()
|
async with cls.table.connector.connection() as conn:
|
||||||
async with conn.cursor() as curs:
|
async with conn.cursor() as curs:
|
||||||
await curs.execute(
|
await curs.execute(
|
||||||
"""
|
"""
|
||||||
SELECT
|
SELECT
|
||||||
time_rank, coin_rank
|
time_rank, coin_rank
|
||||||
FROM (
|
FROM (
|
||||||
SELECT
|
SELECT
|
||||||
userid,
|
userid,
|
||||||
row_number() OVER (ORDER BY total_tracked_time DESC, userid ASC) AS time_rank,
|
row_number() OVER (ORDER BY total_tracked_time DESC, userid ASC) AS time_rank,
|
||||||
row_number() OVER (ORDER BY total_coins DESC, userid ASC) AS coin_rank
|
row_number() OVER (ORDER BY total_coins DESC, userid ASC) AS coin_rank
|
||||||
FROM members_totals
|
FROM members_totals
|
||||||
WHERE
|
WHERE
|
||||||
guildid=%s AND userid NOT IN %s
|
guildid=%s AND userid NOT IN %s
|
||||||
) AS guild_ranks WHERE userid=%s
|
) AS guild_ranks WHERE userid=%s
|
||||||
""",
|
""",
|
||||||
(guildid, tuple(untracked), userid)
|
(guildid, tuple(untracked), userid)
|
||||||
)
|
)
|
||||||
return (await curs.fetchone()) or (None, None)
|
return (await curs.fetchone()) or (None, None)
|
||||||
|
|
||||||
class LionHook(RowModel):
|
class LionHook(RowModel):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# from enum import Enum
|
# from enum import Enum
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from psycopg.types.enum import register_enum, EnumInfo
|
from psycopg.types.enum import register_enum, EnumInfo
|
||||||
|
from psycopg import AsyncConnection
|
||||||
from .registry import Attachable, Registry
|
from .registry import Attachable, Registry
|
||||||
|
|
||||||
|
|
||||||
@@ -23,10 +24,17 @@ class RegisterEnum(Attachable):
|
|||||||
connector = registry._conn
|
connector = registry._conn
|
||||||
if connector is None:
|
if connector is None:
|
||||||
raise ValueError("Cannot initialise without connector!")
|
raise ValueError("Cannot initialise without connector!")
|
||||||
connection = await connector.get_connection()
|
connector.connect_hook(self.connection_hook)
|
||||||
if connection is None:
|
# await connector.refresh_pool()
|
||||||
raise ValueError("Cannot Init without connection.")
|
# The below may be somewhat dangerous
|
||||||
info = await EnumInfo.fetch(connection, self.name)
|
# But adaption should never write to the database
|
||||||
|
await connector.map_over_pool(self.connection_hook)
|
||||||
|
# if conn := connector.conn:
|
||||||
|
# # Ensure the adaption is run in the current context as well
|
||||||
|
# await self.connection_hook(conn)
|
||||||
|
|
||||||
|
async def connection_hook(self, conn: AsyncConnection):
|
||||||
|
info = await EnumInfo.fetch(conn, self.name)
|
||||||
if info is None:
|
if info is None:
|
||||||
raise ValueError(f"Enum {self.name} not found in database.")
|
raise ValueError(f"Enum {self.name} not found in database.")
|
||||||
register_enum(info, connection, self.enum, mapping=list(self.mapping.items()))
|
register_enum(info, conn, self.enum, mapping=list(self.mapping.items()))
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
from typing import Protocol, runtime_checkable, Callable, Awaitable
|
from typing import Protocol, runtime_checkable, Callable, Awaitable, Optional
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
import psycopg as psq
|
import psycopg as psq
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
from psycopg.pq import TransactionStatus
|
from psycopg.pq import TransactionStatus
|
||||||
|
|
||||||
from .cursor import AsyncLoggingCursor
|
from .cursor import AsyncLoggingCursor
|
||||||
@@ -10,42 +13,110 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
row_factory = psq.rows.dict_row
|
row_factory = psq.rows.dict_row
|
||||||
|
|
||||||
|
ctx_connection: Optional[ContextVar[psq.AsyncConnection]] = ContextVar('connection', default=None)
|
||||||
|
|
||||||
|
|
||||||
class Connector:
|
class Connector:
|
||||||
cursor_factory = AsyncLoggingCursor
|
cursor_factory = AsyncLoggingCursor
|
||||||
|
|
||||||
def __init__(self, conn_args):
|
def __init__(self, conn_args):
|
||||||
self._conn_args = conn_args
|
self._conn_args = conn_args
|
||||||
self.conn: psq.AsyncConnection = None
|
self._conn_kwargs = dict(autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory)
|
||||||
|
|
||||||
|
self.pool = self.make_pool()
|
||||||
|
|
||||||
self.conn_hooks = []
|
self.conn_hooks = []
|
||||||
|
|
||||||
async def get_connection(self) -> psq.AsyncConnection:
|
@property
|
||||||
|
def conn(self) -> Optional[psq.AsyncConnection]:
|
||||||
"""
|
"""
|
||||||
Get the current active connection.
|
Convenience property for the current context connection.
|
||||||
This should never be cached outside of a transaction.
|
|
||||||
"""
|
"""
|
||||||
# TODO: Reconnection logic?
|
return ctx_connection.get()
|
||||||
if not self.conn:
|
|
||||||
raise ValueError("Attempting to get connection before initialisation!")
|
|
||||||
if self.conn.info.transaction_status is TransactionStatus.INERROR:
|
|
||||||
await self.connect()
|
|
||||||
logger.error(
|
|
||||||
"Database connection transaction failed!! This should not happen. Reconnecting."
|
|
||||||
)
|
|
||||||
return self.conn
|
|
||||||
|
|
||||||
async def connect(self) -> psq.AsyncConnection:
|
@conn.setter
|
||||||
logger.info("Establishing connection to database.", extra={'action': "Data Connect"})
|
def conn(self, conn: psq.AsyncConnection):
|
||||||
self.conn = await psq.AsyncConnection.connect(
|
"""
|
||||||
self._conn_args, autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory
|
Set the contextual connection in the current context.
|
||||||
|
Always do this in an isolated context!
|
||||||
|
"""
|
||||||
|
ctx_connection.set(conn)
|
||||||
|
|
||||||
|
def make_pool(self) -> AsyncConnectionPool:
|
||||||
|
logger.info("Initialising connection pool.", extra={'action': "Pool Init"})
|
||||||
|
return AsyncConnectionPool(
|
||||||
|
self._conn_args,
|
||||||
|
open=False,
|
||||||
|
min_size=4,
|
||||||
|
max_size=8,
|
||||||
|
configure=self._setup_connection,
|
||||||
|
kwargs=self._conn_kwargs
|
||||||
)
|
)
|
||||||
for hook in self.conn_hooks:
|
|
||||||
await hook(self.conn)
|
|
||||||
return self.conn
|
|
||||||
|
|
||||||
async def reconnect(self) -> psq.AsyncConnection:
|
async def refresh_pool(self):
|
||||||
return await self.connect()
|
"""
|
||||||
|
Refresh the pool.
|
||||||
|
|
||||||
|
The point of this is to invalidate any existing connections so that the connection set up is run again.
|
||||||
|
Better ways should be sought (a way to
|
||||||
|
"""
|
||||||
|
logger.info("Pool refresh requested, closing and reopening.")
|
||||||
|
old_pool = self.pool
|
||||||
|
self.pool = self.make_pool()
|
||||||
|
await self.pool.open()
|
||||||
|
logger.info(f"Old pool statistics: {self.pool.get_stats()}")
|
||||||
|
await old_pool.close()
|
||||||
|
logger.info("Pool refresh complete.")
|
||||||
|
|
||||||
|
async def map_over_pool(self, callable):
|
||||||
|
"""
|
||||||
|
Dangerous method to call a method on each connection in the pool.
|
||||||
|
|
||||||
|
Utilises private methods of the AsyncConnectionPool.
|
||||||
|
"""
|
||||||
|
async with self.pool._lock:
|
||||||
|
conns = list(self.pool._pool)
|
||||||
|
while conns:
|
||||||
|
conn = conns.pop()
|
||||||
|
try:
|
||||||
|
await callable(conn)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(f"Mapped connection task failed. {callable.__name__}")
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def open(self):
|
||||||
|
try:
|
||||||
|
logger.info("Opening database pool.")
|
||||||
|
await self.pool.open()
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
# May be a different pool!
|
||||||
|
logger.info(f"Closing database pool. Pool statistics: {self.pool.get_stats()}")
|
||||||
|
await self.pool.close()
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def connection(self) -> psq.AsyncConnection:
|
||||||
|
"""
|
||||||
|
Asynchronous context manager to get and manage a connection.
|
||||||
|
|
||||||
|
If the context connection is set, uses this and does not manage the lifetime.
|
||||||
|
Otherwise, requests a new connection from the pool and returns it when done.
|
||||||
|
"""
|
||||||
|
logger.debug("Database connection requested.", extra={'action': "Data Connect"})
|
||||||
|
if (conn := self.conn):
|
||||||
|
yield conn
|
||||||
|
else:
|
||||||
|
async with self.pool.connection() as conn:
|
||||||
|
yield conn
|
||||||
|
|
||||||
|
async def _setup_connection(self, conn: psq.AsyncConnection):
|
||||||
|
logger.debug("Initialising new connection.", extra={'action': "Conn Init"})
|
||||||
|
for hook in self.conn_hooks:
|
||||||
|
try:
|
||||||
|
await hook(conn)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Exception encountered setting up new connection")
|
||||||
|
return conn
|
||||||
|
|
||||||
def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]):
|
def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -35,12 +35,13 @@ class Database(Connector):
|
|||||||
"""
|
"""
|
||||||
Return the current schema version as a Version namedtuple.
|
Return the current schema version as a Version namedtuple.
|
||||||
"""
|
"""
|
||||||
async with self.conn.cursor() as cursor:
|
async with self.connection() as conn:
|
||||||
# Get last entry in version table, compare against desired version
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
|
# Get last entry in version table, compare against desired version
|
||||||
row = await cursor.fetchone()
|
await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
|
||||||
if row:
|
row = await cursor.fetchone()
|
||||||
return Version(row['version'], row['time'], row['author'])
|
if row:
|
||||||
else:
|
return Version(row['version'], row['time'], row['author'])
|
||||||
# No versions in the database
|
else:
|
||||||
return Version(-1, None, None)
|
# No versions in the database
|
||||||
|
return Version(-1, None, None)
|
||||||
|
|||||||
@@ -101,12 +101,12 @@ class Query(Generic[QueryResult]):
|
|||||||
if self.connector is None:
|
if self.connector is None:
|
||||||
raise ValueError("Cannot execute query without cursor, connection, or connector.")
|
raise ValueError("Cannot execute query without cursor, connection, or connector.")
|
||||||
else:
|
else:
|
||||||
conn = await self.connector.get_connection()
|
async with self.connector.connection() as conn:
|
||||||
|
async with conn.cursor() as cursor:
|
||||||
|
data = await self._execute(cursor)
|
||||||
else:
|
else:
|
||||||
conn = self.conn
|
async with self.conn.cursor() as cursor:
|
||||||
|
data = await self._execute(cursor)
|
||||||
async with conn.cursor() as cursor:
|
|
||||||
data = await self._execute(cursor)
|
|
||||||
else:
|
else:
|
||||||
data = await self._execute(cursor)
|
data = await self._execute(cursor)
|
||||||
return data
|
return data
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
import asyncio
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
from discord.ext import commands as cmds
|
from discord.ext import commands as cmds
|
||||||
@@ -182,30 +183,34 @@ class Economy(LionCog):
|
|||||||
# We may need to do a mass row create operation.
|
# We may need to do a mass row create operation.
|
||||||
targetids = set(target.id for target in targets)
|
targetids = set(target.id for target in targets)
|
||||||
if len(targets) > 1:
|
if len(targets) > 1:
|
||||||
conn = await ctx.bot.db.get_connection()
|
async def wrapper():
|
||||||
async with conn.transaction():
|
async with self.bot.db.connection() as conn:
|
||||||
# First fetch the members which currently exist
|
self.bot.db.conn = conn
|
||||||
query = self.bot.core.data.Member.table.select_where(guildid=ctx.guild.id)
|
async with conn.transaction():
|
||||||
query.select('userid').with_no_adapter()
|
# First fetch the members which currently exist
|
||||||
if 2 * len(targets) < len(ctx.guild.members):
|
query = self.bot.core.data.Member.table.select_where(guildid=ctx.guild.id)
|
||||||
# More efficient to fetch the targets explicitly
|
query.select('userid').with_no_adapter()
|
||||||
query.where(userid=list(targetids))
|
if 2 * len(targets) < len(ctx.guild.members):
|
||||||
existent_rows = await query
|
# More efficient to fetch the targets explicitly
|
||||||
existentids = set(r['userid'] for r in existent_rows)
|
query.where(userid=list(targetids))
|
||||||
|
existent_rows = await query
|
||||||
|
existentids = set(r['userid'] for r in existent_rows)
|
||||||
|
|
||||||
# Then check if any new userids need adding, and if so create them
|
# Then check if any new userids need adding, and if so create them
|
||||||
new_ids = targetids.difference(existentids)
|
new_ids = targetids.difference(existentids)
|
||||||
if new_ids:
|
if new_ids:
|
||||||
# We use ON CONFLICT IGNORE here in case the users already exist.
|
# We use ON CONFLICT IGNORE here in case the users already exist.
|
||||||
await self.bot.core.data.User.table.insert_many(
|
await self.bot.core.data.User.table.insert_many(
|
||||||
('userid',),
|
('userid',),
|
||||||
*((id,) for id in new_ids)
|
*((id,) for id in new_ids)
|
||||||
).on_conflict(ignore=True)
|
).on_conflict(ignore=True)
|
||||||
# TODO: Replace 0 here with the starting_coin value
|
# TODO: Replace 0 here with the starting_coin value
|
||||||
await self.bot.core.data.Member.table.insert_many(
|
await self.bot.core.data.Member.table.insert_many(
|
||||||
('guildid', 'userid', 'coins'),
|
('guildid', 'userid', 'coins'),
|
||||||
*((ctx.guild.id, id, 0) for id in new_ids)
|
*((ctx.guild.id, id, 0) for id in new_ids)
|
||||||
).on_conflict(ignore=True)
|
).on_conflict(ignore=True)
|
||||||
|
task = asyncio.create_task(wrapper(), name="wrapped-create-members")
|
||||||
|
await task
|
||||||
else:
|
else:
|
||||||
# With only one target, we can take a simpler path, and make better use of local caches.
|
# With only one target, we can take a simpler path, and make better use of local caches.
|
||||||
await self.bot.core.lions.fetch_member(ctx.guild.id, target.id)
|
await self.bot.core.lions.fetch_member(ctx.guild.id, target.id)
|
||||||
@@ -703,31 +708,34 @@ class Economy(LionCog):
|
|||||||
# Alternative flow could be waiting until the target user presses accept
|
# Alternative flow could be waiting until the target user presses accept
|
||||||
await ctx.interaction.response.defer(thinking=True, ephemeral=True)
|
await ctx.interaction.response.defer(thinking=True, ephemeral=True)
|
||||||
|
|
||||||
conn = await self.bot.db.get_connection()
|
async def wrapped():
|
||||||
async with conn.transaction():
|
async with self.bot.db.connection() as conn:
|
||||||
# We do this in a transaction so that if something goes wrong,
|
self.bot.db.conn = conn
|
||||||
# the coins deduction is rolled back atomicly
|
async with conn.transaction():
|
||||||
balance = ctx.alion.data.coins
|
# We do this in a transaction so that if something goes wrong,
|
||||||
if amount > balance:
|
# the coins deduction is rolled back atomicly
|
||||||
await ctx.interaction.edit_original_response(
|
balance = ctx.alion.data.coins
|
||||||
embed=error_embed(
|
if amount > balance:
|
||||||
t(_p(
|
await ctx.interaction.edit_original_response(
|
||||||
'cmd:send|error:insufficient',
|
embed=error_embed(
|
||||||
"You do not have enough lioncoins to do this!\n"
|
t(_p(
|
||||||
"`Current Balance:` {coin_emoji}{balance}"
|
'cmd:send|error:insufficient',
|
||||||
)).format(
|
"You do not have enough lioncoins to do this!\n"
|
||||||
coin_emoji=self.bot.config.emojis.getemoji('coin'),
|
"`Current Balance:` {coin_emoji}{balance}"
|
||||||
balance=balance
|
)).format(
|
||||||
|
coin_emoji=self.bot.config.emojis.getemoji('coin'),
|
||||||
|
balance=balance
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
),
|
return
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Transfer the coins
|
# Transfer the coins
|
||||||
await ctx.alion.data.update(coins=(Member.coins - amount))
|
await ctx.alion.data.update(coins=(Member.coins - amount))
|
||||||
await target_lion.data.update(coins=(Member.coins + amount))
|
await target_lion.data.update(coins=(Member.coins + amount))
|
||||||
|
|
||||||
# TODO: Audit trail
|
# TODO: Audit trail
|
||||||
|
await asyncio.create_task(wrapped(), name="wrapped-send")
|
||||||
|
|
||||||
# Message target
|
# Message target
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from psycopg import sql
|
from psycopg import sql
|
||||||
|
from meta.logger import log_wrap
|
||||||
from data import Registry, RowModel, RegisterEnum, JOINTYPE, RawExpr
|
from data import Registry, RowModel, RegisterEnum, JOINTYPE, RawExpr
|
||||||
from data.columns import Integer, Bool, Column, Timestamp
|
from data.columns import Integer, Bool, Column, Timestamp
|
||||||
from core.data import CoreData
|
from core.data import CoreData
|
||||||
@@ -101,6 +102,7 @@ class EconomyData(Registry, name='economy'):
|
|||||||
created_at = Timestamp()
|
created_at = Timestamp()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='execute_transaction')
|
||||||
async def execute_transaction(
|
async def execute_transaction(
|
||||||
cls,
|
cls,
|
||||||
transaction_type: TransactionType,
|
transaction_type: TransactionType,
|
||||||
@@ -108,25 +110,27 @@ class EconomyData(Registry, name='economy'):
|
|||||||
from_account: int, to_account: int, amount: int, bonus: int = 0,
|
from_account: int, to_account: int, amount: int, bonus: int = 0,
|
||||||
refunds: int = None
|
refunds: int = None
|
||||||
):
|
):
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.transaction():
|
cls._connector.conn = conn
|
||||||
transaction = await cls.create(
|
async with conn.transaction():
|
||||||
transactiontype=transaction_type,
|
transaction = await cls.create(
|
||||||
guildid=guildid, actorid=actorid, amount=amount, bonus=bonus,
|
transactiontype=transaction_type,
|
||||||
from_account=from_account, to_account=to_account,
|
guildid=guildid, actorid=actorid, amount=amount, bonus=bonus,
|
||||||
refunds=refunds
|
from_account=from_account, to_account=to_account,
|
||||||
)
|
refunds=refunds
|
||||||
if from_account is not None:
|
)
|
||||||
await CoreData.Member.table.update_where(
|
if from_account is not None:
|
||||||
guildid=guildid, userid=from_account
|
await CoreData.Member.table.update_where(
|
||||||
).set(coins=SAFECOINS(CoreData.Member.coins - (amount + bonus)))
|
guildid=guildid, userid=from_account
|
||||||
if to_account is not None:
|
).set(coins=SAFECOINS(CoreData.Member.coins - (amount + bonus)))
|
||||||
await CoreData.Member.table.update_where(
|
if to_account is not None:
|
||||||
guildid=guildid, userid=to_account
|
await CoreData.Member.table.update_where(
|
||||||
).set(coins=SAFECOINS(CoreData.Member.coins + (amount + bonus)))
|
guildid=guildid, userid=to_account
|
||||||
return transaction
|
).set(coins=SAFECOINS(CoreData.Member.coins + (amount + bonus)))
|
||||||
|
return transaction
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='execute_transactions')
|
||||||
async def execute_transactions(cls, *transactions):
|
async def execute_transactions(cls, *transactions):
|
||||||
"""
|
"""
|
||||||
Execute multiple transactions in one data transaction.
|
Execute multiple transactions in one data transaction.
|
||||||
@@ -142,65 +146,68 @@ class EconomyData(Registry, name='economy'):
|
|||||||
if not transactions:
|
if not transactions:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.transaction():
|
cls._connector.conn = conn
|
||||||
# Create the transactions
|
async with conn.transaction():
|
||||||
rows = await cls.table.insert_many(
|
# Create the transactions
|
||||||
(
|
rows = await cls.table.insert_many(
|
||||||
'transactiontype',
|
(
|
||||||
'guildid', 'actorid',
|
'transactiontype',
|
||||||
'from_account', 'to_account',
|
'guildid', 'actorid',
|
||||||
'amount', 'bonus',
|
'from_account', 'to_account',
|
||||||
'refunds'
|
'amount', 'bonus',
|
||||||
),
|
'refunds'
|
||||||
*transactions
|
),
|
||||||
).with_adapter(cls._make_rows)
|
*transactions
|
||||||
|
).with_adapter(cls._make_rows)
|
||||||
|
|
||||||
# Update the members
|
# Update the members
|
||||||
transtable = TemporaryTable(
|
transtable = TemporaryTable(
|
||||||
'_guildid', '_userid', '_amount',
|
'_guildid', '_userid', '_amount',
|
||||||
types=('BIGINT', 'BIGINT', 'INTEGER')
|
types=('BIGINT', 'BIGINT', 'INTEGER')
|
||||||
)
|
)
|
||||||
values = transtable.values
|
values = transtable.values
|
||||||
for transaction in transactions:
|
for transaction in transactions:
|
||||||
_, guildid, _, from_acc, to_acc, amount, bonus, _ = transaction
|
_, guildid, _, from_acc, to_acc, amount, bonus, _ = transaction
|
||||||
coins = amount + bonus
|
coins = amount + bonus
|
||||||
if coins:
|
if coins:
|
||||||
if from_acc:
|
if from_acc:
|
||||||
values.append((guildid, from_acc, -1 * coins))
|
values.append((guildid, from_acc, -1 * coins))
|
||||||
if to_acc:
|
if to_acc:
|
||||||
values.append((guildid, to_acc, coins))
|
values.append((guildid, to_acc, coins))
|
||||||
if values:
|
if values:
|
||||||
Member = CoreData.Member
|
Member = CoreData.Member
|
||||||
await Member.table.update_where(
|
await Member.table.update_where(
|
||||||
guildid=transtable['_guildid'], userid=transtable['_userid']
|
guildid=transtable['_guildid'], userid=transtable['_userid']
|
||||||
).set(
|
).set(
|
||||||
coins=SAFECOINS(Member.coins + transtable['_amount'])
|
coins=SAFECOINS(Member.coins + transtable['_amount'])
|
||||||
).from_expr(transtable)
|
).from_expr(transtable)
|
||||||
return rows
|
return rows
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='refund_transactions')
|
||||||
async def refund_transactions(cls, *transactionids, actorid=0):
|
async def refund_transactions(cls, *transactionids, actorid=0):
|
||||||
if not transactionids:
|
if not transactionids:
|
||||||
return []
|
return []
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.transaction():
|
cls._connector.conn = conn
|
||||||
# First fetch the transaction rows to refund
|
async with conn.transaction():
|
||||||
data = await cls.table.select_where(transactionid=transactionids)
|
# First fetch the transaction rows to refund
|
||||||
if data:
|
data = await cls.table.select_where(transactionid=transactionids)
|
||||||
# Build the transaction refund data
|
if data:
|
||||||
records = [
|
# Build the transaction refund data
|
||||||
(
|
records = [
|
||||||
TransactionType.REFUND,
|
(
|
||||||
tr['guildid'], actorid,
|
TransactionType.REFUND,
|
||||||
tr['to_account'], tr['from_account'],
|
tr['guildid'], actorid,
|
||||||
tr['amount'] + tr['bonus'], 0,
|
tr['to_account'], tr['from_account'],
|
||||||
tr['transactionid']
|
tr['amount'] + tr['bonus'], 0,
|
||||||
)
|
tr['transactionid']
|
||||||
for tr in data
|
)
|
||||||
]
|
for tr in data
|
||||||
# Execute refund transactions
|
]
|
||||||
return await cls.execute_transactions(*records)
|
# Execute refund transactions
|
||||||
|
return await cls.execute_transactions(*records)
|
||||||
|
|
||||||
class ShopTransaction(RowModel):
|
class ShopTransaction(RowModel):
|
||||||
"""
|
"""
|
||||||
@@ -217,19 +224,21 @@ class EconomyData(Registry, name='economy'):
|
|||||||
itemid = Integer()
|
itemid = Integer()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='purchase_transaction')
|
||||||
async def purchase_transaction(
|
async def purchase_transaction(
|
||||||
cls,
|
cls,
|
||||||
guildid: int, actorid: int,
|
guildid: int, actorid: int,
|
||||||
userid: int, itemid: int, amount: int
|
userid: int, itemid: int, amount: int
|
||||||
):
|
):
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.transaction():
|
cls._connector.conn = conn
|
||||||
row = await EconomyData.Transaction.execute_transaction(
|
async with conn.transaction():
|
||||||
TransactionType.SHOP_PURCHASE,
|
row = await EconomyData.Transaction.execute_transaction(
|
||||||
guildid=guildid, actorid=actorid, from_account=userid, to_account=None,
|
TransactionType.SHOP_PURCHASE,
|
||||||
amount=amount
|
guildid=guildid, actorid=actorid, from_account=userid, to_account=None,
|
||||||
)
|
amount=amount
|
||||||
return await cls.create(transactionid=row.transactionid, itemid=itemid)
|
)
|
||||||
|
return await cls.create(transactionid=row.transactionid, itemid=itemid)
|
||||||
|
|
||||||
class TaskTransaction(RowModel):
|
class TaskTransaction(RowModel):
|
||||||
"""
|
"""
|
||||||
@@ -263,19 +272,21 @@ class EconomyData(Registry, name='economy'):
|
|||||||
return result[0]['recent'] or 0
|
return result[0]['recent'] or 0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='reward_completed_tasks')
|
||||||
async def reward_completed(cls, userid, guildid, count, amount):
|
async def reward_completed(cls, userid, guildid, count, amount):
|
||||||
"""
|
"""
|
||||||
Reward the specified member `amount` coins for completing `count` tasks.
|
Reward the specified member `amount` coins for completing `count` tasks.
|
||||||
"""
|
"""
|
||||||
# TODO: Bonus logic, perhaps apply_bonus(amount), or put this method in the economy cog?
|
# TODO: Bonus logic, perhaps apply_bonus(amount), or put this method in the economy cog?
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.transaction():
|
cls._connector.conn = conn
|
||||||
row = await EconomyData.Transaction.execute_transaction(
|
async with conn.transaction():
|
||||||
TransactionType.TASKS,
|
row = await EconomyData.Transaction.execute_transaction(
|
||||||
guildid=guildid, actorid=userid, from_account=None, to_account=userid,
|
TransactionType.TASKS,
|
||||||
amount=amount
|
guildid=guildid, actorid=userid, from_account=None, to_account=userid,
|
||||||
)
|
amount=amount
|
||||||
return await cls.create(transactionid=row.transactionid, count=count)
|
)
|
||||||
|
return await cls.create(transactionid=row.transactionid, count=count)
|
||||||
|
|
||||||
class SessionTransaction(RowModel):
|
class SessionTransaction(RowModel):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -193,18 +193,19 @@ class MemberAdminCog(LionCog):
|
|||||||
await lion.data.update(last_left=utc_now())
|
await lion.data.update(last_left=utc_now())
|
||||||
|
|
||||||
# Save member roles
|
# Save member roles
|
||||||
conn = await self.bot.db.get_connection()
|
async with self.bot.db.connection() as conn:
|
||||||
async with conn.transaction():
|
self.bot.db.conn = conn
|
||||||
await self.data.past_roles.delete_where(
|
async with conn.transaction():
|
||||||
guildid=member.guild.id,
|
await self.data.past_roles.delete_where(
|
||||||
userid=member.id
|
guildid=member.guild.id,
|
||||||
)
|
userid=member.id
|
||||||
# Insert current member roles
|
|
||||||
if member.roles:
|
|
||||||
await self.data.past_roles.insert_many(
|
|
||||||
('guildid', 'userid', 'roleid'),
|
|
||||||
*((member.guild.id, member.id, role.id) for role in member.roles)
|
|
||||||
)
|
)
|
||||||
|
# Insert current member roles
|
||||||
|
if member.roles:
|
||||||
|
await self.data.past_roles.insert_many(
|
||||||
|
('guildid', 'userid', 'roleid'),
|
||||||
|
*((member.guild.id, member.id, role.id) for role in member.roles)
|
||||||
|
)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Stored persisting roles for member <uid:{member.id}> in <gid:{member.guild.id}>."
|
f"Stored persisting roles for member <uid:{member.id}> in <gid:{member.guild.id}>."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -190,6 +190,8 @@ class ModerationCog(LionCog):
|
|||||||
update_args[instance._column] = instance.data
|
update_args[instance._column] = instance.data
|
||||||
ack_lines.append(instance.update_message)
|
ack_lines.append(instance.update_message)
|
||||||
|
|
||||||
|
await ctx.lguild.data.update(**update_args)
|
||||||
|
|
||||||
# Do the ack
|
# Do the ack
|
||||||
tick = self.bot.config.emojis.tick
|
tick = self.bot.config.emojis.tick
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
|
|||||||
@@ -483,65 +483,63 @@ class RoleMenu:
|
|||||||
)).format(role=role.name)
|
)).format(role=role.name)
|
||||||
)
|
)
|
||||||
|
|
||||||
conn = await self.bot.db.get_connection()
|
# Remove the role
|
||||||
async with conn.transaction():
|
try:
|
||||||
# Remove the role
|
await member.remove_roles(role)
|
||||||
try:
|
except discord.Forbidden:
|
||||||
await member.remove_roles(role)
|
raise UserInputError(
|
||||||
except discord.Forbidden:
|
t(_p(
|
||||||
raise UserInputError(
|
'rolemenu|deselect|error:perms',
|
||||||
t(_p(
|
"I don't have enough permissions to remove this role from you!"
|
||||||
'rolemenu|deselect|error:perms',
|
|
||||||
"I don't have enough permissions to remove this role from you!"
|
|
||||||
))
|
|
||||||
)
|
|
||||||
except discord.HTTPException:
|
|
||||||
raise UserInputError(
|
|
||||||
t(_p(
|
|
||||||
'rolemenu|deselect|error:discord',
|
|
||||||
"An unknown error occurred removing your role! Please try again later."
|
|
||||||
))
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update history
|
|
||||||
now = utc_now()
|
|
||||||
history = await self.cog.data.RoleMenuHistory.table.update_where(
|
|
||||||
menuid=self.data.menuid,
|
|
||||||
roleid=role.id,
|
|
||||||
userid=member.id,
|
|
||||||
removed_at=None,
|
|
||||||
).set(removed_at=now)
|
|
||||||
await self.cog.cancel_expiring_tasks(*(row['equipid'] for row in history))
|
|
||||||
|
|
||||||
# Refund if required
|
|
||||||
transactionids = [row['transactionid'] for row in history]
|
|
||||||
if self.config.refunds.value and any(transactionids):
|
|
||||||
transactionids = [tid for tid in transactionids if tid]
|
|
||||||
economy: Economy = self.bot.get_cog('Economy')
|
|
||||||
refunded = await economy.data.Transaction.refund_transactions(*transactionids)
|
|
||||||
total_refund = sum(row.amount + row.bonus for row in refunded)
|
|
||||||
else:
|
|
||||||
total_refund = 0
|
|
||||||
|
|
||||||
# Ack the removal
|
|
||||||
embed = discord.Embed(
|
|
||||||
colour=discord.Colour.brand_green(),
|
|
||||||
title=t(_p(
|
|
||||||
'rolemenu|deslect|success|title',
|
|
||||||
"Role removed"
|
|
||||||
))
|
))
|
||||||
)
|
)
|
||||||
if total_refund:
|
except discord.HTTPException:
|
||||||
embed.description = t(_p(
|
raise UserInputError(
|
||||||
'rolemenu|deselect|success:refund|desc',
|
t(_p(
|
||||||
"You have removed **{role}**, and been refunded {coin} **{amount}**."
|
'rolemenu|deselect|error:discord',
|
||||||
)).format(role=role.name, coin=self.bot.config.emojis.coin, amount=total_refund)
|
"An unknown error occurred removing your role! Please try again later."
|
||||||
else:
|
))
|
||||||
embed.description = t(_p(
|
)
|
||||||
'rolemenu|deselect|success:norefund|desc',
|
|
||||||
"You have unequipped **{role}**."
|
# Update history
|
||||||
)).format(role=role.name)
|
now = utc_now()
|
||||||
return embed
|
history = await self.cog.data.RoleMenuHistory.table.update_where(
|
||||||
|
menuid=self.data.menuid,
|
||||||
|
roleid=role.id,
|
||||||
|
userid=member.id,
|
||||||
|
removed_at=None,
|
||||||
|
).set(removed_at=now)
|
||||||
|
await self.cog.cancel_expiring_tasks(*(row['equipid'] for row in history))
|
||||||
|
|
||||||
|
# Refund if required
|
||||||
|
transactionids = [row['transactionid'] for row in history]
|
||||||
|
if self.config.refunds.value and any(transactionids):
|
||||||
|
transactionids = [tid for tid in transactionids if tid]
|
||||||
|
economy: Economy = self.bot.get_cog('Economy')
|
||||||
|
refunded = await economy.data.Transaction.refund_transactions(*transactionids)
|
||||||
|
total_refund = sum(row.amount + row.bonus for row in refunded)
|
||||||
|
else:
|
||||||
|
total_refund = 0
|
||||||
|
|
||||||
|
# Ack the removal
|
||||||
|
embed = discord.Embed(
|
||||||
|
colour=discord.Colour.brand_green(),
|
||||||
|
title=t(_p(
|
||||||
|
'rolemenu|deslect|success|title',
|
||||||
|
"Role removed"
|
||||||
|
))
|
||||||
|
)
|
||||||
|
if total_refund:
|
||||||
|
embed.description = t(_p(
|
||||||
|
'rolemenu|deselect|success:refund|desc',
|
||||||
|
"You have removed **{role}**, and been refunded {coin} **{amount}**."
|
||||||
|
)).format(role=role.name, coin=self.bot.config.emojis.coin, amount=total_refund)
|
||||||
|
else:
|
||||||
|
embed.description = t(_p(
|
||||||
|
'rolemenu|deselect|success:norefund|desc',
|
||||||
|
"You have unequipped **{role}**."
|
||||||
|
)).format(role=role.name)
|
||||||
|
return embed
|
||||||
else:
|
else:
|
||||||
# Member does not have the role, selection case.
|
# Member does not have the role, selection case.
|
||||||
required = self.config.required_role.data
|
required = self.config.required_role.data
|
||||||
@@ -591,57 +589,55 @@ class RoleMenu:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
conn = await self.bot.db.get_connection()
|
try:
|
||||||
async with conn.transaction():
|
await member.add_roles(role)
|
||||||
try:
|
except discord.Forbidden:
|
||||||
await member.add_roles(role)
|
raise UserInputError(
|
||||||
except discord.Forbidden:
|
t(_p(
|
||||||
raise UserInputError(
|
'rolemenu|select|error:perms',
|
||||||
t(_p(
|
"I don't have enough permissions to give you this role!"
|
||||||
'rolemenu|select|error:perms',
|
))
|
||||||
"I don't have enough permissions to give you this role!"
|
|
||||||
))
|
|
||||||
)
|
|
||||||
except discord.HTTPException:
|
|
||||||
raise UserInputError(
|
|
||||||
t(_p(
|
|
||||||
'rolemenu|select|error:discord',
|
|
||||||
"An unknown error occurred while assigning your role! "
|
|
||||||
"Please try again later."
|
|
||||||
))
|
|
||||||
)
|
|
||||||
|
|
||||||
now = utc_now()
|
|
||||||
|
|
||||||
# Create transaction if applicable
|
|
||||||
if price:
|
|
||||||
economy: Economy = self.bot.get_cog('Economy')
|
|
||||||
tx = await economy.data.Transaction.execute_transaction(
|
|
||||||
transaction_type=TransactionType.OTHER,
|
|
||||||
guildid=guild.id, actorid=member.id,
|
|
||||||
from_account=member.id, to_account=None,
|
|
||||||
amount=price
|
|
||||||
)
|
|
||||||
tid = tx.transactionid
|
|
||||||
else:
|
|
||||||
tid = None
|
|
||||||
|
|
||||||
# Calculate expiry
|
|
||||||
duration = mrole.config.duration.value
|
|
||||||
if duration is not None:
|
|
||||||
expiry = now + dt.timedelta(seconds=duration)
|
|
||||||
else:
|
|
||||||
expiry = None
|
|
||||||
|
|
||||||
# Add to equip history
|
|
||||||
equip = await self.cog.data.RoleMenuHistory.create(
|
|
||||||
menuid=self.data.menuid, roleid=role.id,
|
|
||||||
userid=member.id,
|
|
||||||
obtained_at=now,
|
|
||||||
transactionid=tid,
|
|
||||||
expires_at=expiry
|
|
||||||
)
|
)
|
||||||
await self.cog.schedule_expiring(equip)
|
except discord.HTTPException:
|
||||||
|
raise UserInputError(
|
||||||
|
t(_p(
|
||||||
|
'rolemenu|select|error:discord',
|
||||||
|
"An unknown error occurred while assigning your role! "
|
||||||
|
"Please try again later."
|
||||||
|
))
|
||||||
|
)
|
||||||
|
|
||||||
|
now = utc_now()
|
||||||
|
|
||||||
|
# Create transaction if applicable
|
||||||
|
if price:
|
||||||
|
economy: Economy = self.bot.get_cog('Economy')
|
||||||
|
tx = await economy.data.Transaction.execute_transaction(
|
||||||
|
transaction_type=TransactionType.OTHER,
|
||||||
|
guildid=guild.id, actorid=member.id,
|
||||||
|
from_account=member.id, to_account=None,
|
||||||
|
amount=price
|
||||||
|
)
|
||||||
|
tid = tx.transactionid
|
||||||
|
else:
|
||||||
|
tid = None
|
||||||
|
|
||||||
|
# Calculate expiry
|
||||||
|
duration = mrole.config.duration.value
|
||||||
|
if duration is not None:
|
||||||
|
expiry = now + dt.timedelta(seconds=duration)
|
||||||
|
else:
|
||||||
|
expiry = None
|
||||||
|
|
||||||
|
# Add to equip history
|
||||||
|
equip = await self.cog.data.RoleMenuHistory.create(
|
||||||
|
menuid=self.data.menuid, roleid=role.id,
|
||||||
|
userid=member.id,
|
||||||
|
obtained_at=now,
|
||||||
|
transactionid=tid,
|
||||||
|
expires_at=expiry
|
||||||
|
)
|
||||||
|
await self.cog.schedule_expiring(equip)
|
||||||
|
|
||||||
# Ack the selection
|
# Ack the selection
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
|
|||||||
@@ -259,17 +259,6 @@ class RoomCog(LionCog):
|
|||||||
lguild,
|
lguild,
|
||||||
[member.id for member in members]
|
[member.id for member in members]
|
||||||
)
|
)
|
||||||
self._start(room)
|
|
||||||
|
|
||||||
# Send tips message
|
|
||||||
# TODO: Actual tips.
|
|
||||||
await channel.send(
|
|
||||||
"{mention} welcome to your private room! You may use the menu below to configure it.".format(mention=owner.mention)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Send config UI
|
|
||||||
ui = RoomUI(self.bot, room, callerid=owner.id, timeout=None)
|
|
||||||
await ui.send(channel)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
try:
|
try:
|
||||||
await channel.delete(reason="Failed to created private room")
|
await channel.delete(reason="Failed to created private room")
|
||||||
@@ -454,71 +443,9 @@ class RoomCog(LionCog):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Positive response. Start a transaction.
|
# Positive response. Start a transaction.
|
||||||
conn = await self.bot.db.get_connection()
|
room = await self._do_create_room(ctx, required, days, rent, name, provided)
|
||||||
async with conn.transaction():
|
|
||||||
# Check member balance is sufficient
|
|
||||||
await ctx.alion.data.refresh()
|
|
||||||
member_balance = ctx.alion.data.coins
|
|
||||||
if member_balance < required:
|
|
||||||
await ctx.reply(
|
|
||||||
embed=error_embed(
|
|
||||||
t(_np(
|
|
||||||
'cmd:room_rent|error:insufficient_funds',
|
|
||||||
"Renting a private room for `one` day costs {coin}**{required}**, "
|
|
||||||
"but you only have {coin}**{balance}**!",
|
|
||||||
"Renting a private room for `{days}` days costs {coin}**{required}**, "
|
|
||||||
"but you only have {coin}**{balance}**!",
|
|
||||||
days
|
|
||||||
)).format(
|
|
||||||
coin=self.bot.config.emojis.coin,
|
|
||||||
balance=member_balance,
|
|
||||||
required=required,
|
|
||||||
days=days
|
|
||||||
),
|
|
||||||
ephemeral=True
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Deduct balance
|
|
||||||
# TODO: Economy transaction instead of manual deduction
|
|
||||||
await ctx.alion.data.update(coins=CoreData.Member.coins - required)
|
|
||||||
|
|
||||||
# Create room with given starting balance and other parameters
|
|
||||||
try:
|
|
||||||
room = await self.create_private_room(
|
|
||||||
ctx.guild,
|
|
||||||
ctx.author,
|
|
||||||
required - rent,
|
|
||||||
name or ctx.author.display_name,
|
|
||||||
members=provided
|
|
||||||
)
|
|
||||||
except discord.Forbidden:
|
|
||||||
await ctx.reply(
|
|
||||||
embed=error_embed(
|
|
||||||
t(_p(
|
|
||||||
'cmd:room_rent|error:my_permissions',
|
|
||||||
"Could not create your private room! You were not charged.\n"
|
|
||||||
"I have insufficient permissions to create a private room channel."
|
|
||||||
)),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
await ctx.alion.data.update(coins=CoreData.Member.coins + required)
|
|
||||||
return
|
|
||||||
except discord.HTTPException as e:
|
|
||||||
await ctx.reply(
|
|
||||||
embed=error_embed(
|
|
||||||
t(_p(
|
|
||||||
'cmd:room_rent|error:unknown',
|
|
||||||
"Could not create your private room! You were not charged.\n"
|
|
||||||
"An unknown error occurred while creating your private room.\n"
|
|
||||||
"`{error}`"
|
|
||||||
)).format(error=e.text),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
await ctx.alion.data.update(coins=CoreData.Member.coins + required)
|
|
||||||
return
|
|
||||||
|
|
||||||
|
if room:
|
||||||
# Ack with confirmation message pointing to the room
|
# Ack with confirmation message pointing to the room
|
||||||
msg = t(_p(
|
msg = t(_p(
|
||||||
'cmd:room_rent|success',
|
'cmd:room_rent|success',
|
||||||
@@ -531,6 +458,90 @@ class RoomCog(LionCog):
|
|||||||
description=msg
|
description=msg
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
self._start(room)
|
||||||
|
|
||||||
|
# Send tips message
|
||||||
|
# TODO: Actual tips.
|
||||||
|
await room.channel.send(
|
||||||
|
"{mention} welcome to your private room! You may use the menu below to configure it.".format(
|
||||||
|
mention=ctx.author.mention
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send config UI
|
||||||
|
ui = RoomUI(self.bot, room, callerid=ctx.author.id, timeout=None)
|
||||||
|
await ui.send(room.channel)
|
||||||
|
|
||||||
|
@log_wrap(action='create_room')
|
||||||
|
async def _do_create_room(self, ctx, required, days, rent, name, provided) -> Room:
|
||||||
|
t = self.bot.translator.t
|
||||||
|
# TODO: Rollback the channel create if this fails
|
||||||
|
async with self.bot.db.connection() as conn:
|
||||||
|
self.bot.db.conn = conn
|
||||||
|
# Note that the room creation will go into the UI as well.
|
||||||
|
async with conn.transaction():
|
||||||
|
# Check member balance is sufficient
|
||||||
|
await ctx.alion.data.refresh()
|
||||||
|
member_balance = ctx.alion.data.coins
|
||||||
|
if member_balance < required:
|
||||||
|
await ctx.reply(
|
||||||
|
embed=error_embed(
|
||||||
|
t(_np(
|
||||||
|
'cmd:room_rent|error:insufficient_funds',
|
||||||
|
"Renting a private room for `one` day costs {coin}**{required}**, "
|
||||||
|
"but you only have {coin}**{balance}**!",
|
||||||
|
"Renting a private room for `{days}` days costs {coin}**{required}**, "
|
||||||
|
"but you only have {coin}**{balance}**!",
|
||||||
|
days
|
||||||
|
)).format(
|
||||||
|
coin=self.bot.config.emojis.coin,
|
||||||
|
balance=member_balance,
|
||||||
|
required=required,
|
||||||
|
days=days
|
||||||
|
),
|
||||||
|
ephemeral=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Deduct balance
|
||||||
|
# TODO: Economy transaction instead of manual deduction
|
||||||
|
await ctx.alion.data.update(coins=CoreData.Member.coins - required)
|
||||||
|
|
||||||
|
# Create room with given starting balance and other parameters
|
||||||
|
try:
|
||||||
|
return await self.create_private_room(
|
||||||
|
ctx.guild,
|
||||||
|
ctx.author,
|
||||||
|
required - rent,
|
||||||
|
name or ctx.author.display_name,
|
||||||
|
members=provided
|
||||||
|
)
|
||||||
|
except discord.Forbidden:
|
||||||
|
await ctx.reply(
|
||||||
|
embed=error_embed(
|
||||||
|
t(_p(
|
||||||
|
'cmd:room_rent|error:my_permissions',
|
||||||
|
"Could not create your private room! You were not charged.\n"
|
||||||
|
"I have insufficient permissions to create a private room channel."
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await ctx.alion.data.update(coins=CoreData.Member.coins + required)
|
||||||
|
return
|
||||||
|
except discord.HTTPException as e:
|
||||||
|
await ctx.reply(
|
||||||
|
embed=error_embed(
|
||||||
|
t(_p(
|
||||||
|
'cmd:room_rent|error:unknown',
|
||||||
|
"Could not create your private room! You were not charged.\n"
|
||||||
|
"An unknown error occurred while creating your private room.\n"
|
||||||
|
"`{error}`"
|
||||||
|
)).format(error=e.text),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await ctx.alion.data.update(coins=CoreData.Member.coins + required)
|
||||||
|
return
|
||||||
|
|
||||||
@room_group.command(
|
@room_group.command(
|
||||||
name=_p('cmd:room_status', "status"),
|
name=_p('cmd:room_status', "status"),
|
||||||
@@ -864,43 +875,41 @@ class RoomCog(LionCog):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Start Transaction
|
# Start Transaction
|
||||||
conn = await self.bot.db.get_connection()
|
# TODO: Economy transaction
|
||||||
async with conn.transaction():
|
await ctx.alion.data.refresh()
|
||||||
await ctx.alion.data.refresh()
|
member_balance = ctx.alion.data.coins
|
||||||
member_balance = ctx.alion.data.coins
|
if member_balance < coins:
|
||||||
if member_balance < coins:
|
await ctx.reply(
|
||||||
await ctx.reply(
|
embed=error_embed(t(_p(
|
||||||
embed=error_embed(t(_p(
|
'cmd:room_deposit|error:insufficient_funds',
|
||||||
'cmd:room_deposit|error:insufficient_funds',
|
"You cannot deposit {coin}**{amount}**! You only have {coin}**{balance}**."
|
||||||
"You cannot deposit {coin}**{amount}**! You only have {coin}**{balance}**."
|
)).format(
|
||||||
)).format(
|
coin=self.bot.config.emojis.coin,
|
||||||
coin=self.bot.config.emojis.coin,
|
amount=coins,
|
||||||
amount=coins,
|
balance=member_balance
|
||||||
balance=member_balance
|
)),
|
||||||
)),
|
ephemeral=True
|
||||||
ephemeral=True
|
)
|
||||||
)
|
return
|
||||||
return
|
|
||||||
|
|
||||||
# Deduct balance
|
# Deduct balance
|
||||||
# TODO: Economy transaction
|
await ctx.alion.data.update(coins=CoreData.Member.coins - coins)
|
||||||
await ctx.alion.data.update(coins=CoreData.Member.coins - coins)
|
await room.data.update(coin_balance=RoomData.Room.coin_balance + coins)
|
||||||
await room.data.update(coin_balance=RoomData.Room.coin_balance + coins)
|
|
||||||
|
|
||||||
# Post deposit message
|
# Post deposit message
|
||||||
await room.notify_deposit(ctx.author, coins)
|
await room.notify_deposit(ctx.author, coins)
|
||||||
|
|
||||||
# Ack the deposit
|
# Ack the deposit
|
||||||
if ctx.channel.id != room.data.channelid:
|
if ctx.channel.id != room.data.channelid:
|
||||||
ack_msg = t(_p(
|
ack_msg = t(_p(
|
||||||
'cmd:room_depost|success',
|
'cmd:room_depost|success',
|
||||||
"Success! You have contributed {coin}**{amount}** to the private room bank."
|
"Success! You have contributed {coin}**{amount}** to the private room bank."
|
||||||
)).format(coin=self.bot.config.emojis.coin, amount=coins)
|
)).format(coin=self.bot.config.emojis.coin, amount=coins)
|
||||||
await ctx.reply(
|
await ctx.reply(
|
||||||
embed=discord.Embed(colour=discord.Colour.brand_green(), description=ack_msg)
|
embed=discord.Embed(colour=discord.Colour.brand_green(), description=ack_msg)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await ctx.interaction.delete_original_response()
|
await ctx.interaction.delete_original_response()
|
||||||
|
|
||||||
# ----- Guild Configuration -----
|
# ----- Guild Configuration -----
|
||||||
@LionCog.placeholder_group
|
@LionCog.placeholder_group
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from discord.ui.select import select, UserSelect
|
|||||||
|
|
||||||
from meta import LionBot, conf
|
from meta import LionBot, conf
|
||||||
from meta.errors import UserInputError
|
from meta.errors import UserInputError
|
||||||
|
from meta.logger import log_wrap
|
||||||
from babel.translator import ctx_locale
|
from babel.translator import ctx_locale
|
||||||
from utils.lib import utc_now, MessageArgs, error_embed
|
from utils.lib import utc_now, MessageArgs, error_embed
|
||||||
from utils.ui import MessageUI, input
|
from utils.ui import MessageUI, input
|
||||||
@@ -115,38 +116,43 @@ class RoomUI(MessageUI):
|
|||||||
return
|
return
|
||||||
await submit.response.defer(thinking=True, ephemeral=True)
|
await submit.response.defer(thinking=True, ephemeral=True)
|
||||||
|
|
||||||
|
await self._do_deposit(t, press, amount, submit)
|
||||||
|
|
||||||
|
# Post deposit message
|
||||||
|
await self.room.notify_deposit(press.user, amount)
|
||||||
|
|
||||||
|
await self.refresh(thinking=submit)
|
||||||
|
|
||||||
|
@log_wrap(isolate=True)
|
||||||
|
async def _do_deposit(self, t, press, amount, submit):
|
||||||
# Start transaction for deposit
|
# Start transaction for deposit
|
||||||
conn = await self.bot.db.get_connection()
|
async with self.bot.db.connection() as conn:
|
||||||
async with conn.transaction():
|
self.bot.db.conn = conn
|
||||||
# Get the lion balance directly
|
async with conn.transaction():
|
||||||
lion = await self.bot.core.data.Member.fetch(
|
# Get the lion balance directly
|
||||||
self.room.data.guildid,
|
lion = await self.bot.core.data.Member.fetch(
|
||||||
press.user.id,
|
self.room.data.guildid,
|
||||||
cached=False
|
press.user.id,
|
||||||
)
|
cached=False
|
||||||
balance = lion.coins
|
)
|
||||||
if balance < amount:
|
balance = lion.coins
|
||||||
await submit.edit_original_response(
|
if balance < amount:
|
||||||
embed=error_embed(
|
await submit.edit_original_response(
|
||||||
t(_p(
|
embed=error_embed(
|
||||||
'ui:room_status|button:deposit|error:insufficient_funds',
|
t(_p(
|
||||||
"You cannot deposit {coin}**{amount}**! You only have {coin}**{balance}**."
|
'ui:room_status|button:deposit|error:insufficient_funds',
|
||||||
)).format(
|
"You cannot deposit {coin}**{amount}**! You only have {coin}**{balance}**."
|
||||||
coin=self.bot.config.emojis.coin,
|
)).format(
|
||||||
amount=amount,
|
coin=self.bot.config.emojis.coin,
|
||||||
balance=balance
|
amount=amount,
|
||||||
|
balance=balance
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
return
|
||||||
return
|
# TODO: Economy Transaction
|
||||||
# TODO: Economy Transaction
|
await lion.update(coins=CoreData.Member.coins - amount)
|
||||||
await lion.update(coins=CoreData.Member.coins - amount)
|
await self.room.data.update(coin_balance=RoomData.Room.coin_balance + amount)
|
||||||
await self.room.data.update(coin_balance=RoomData.Room.coin_balance + amount)
|
|
||||||
|
|
||||||
# Post deposit message
|
|
||||||
await self.room.notify_deposit(press.user, amount)
|
|
||||||
|
|
||||||
await self.refresh(thinking=submit)
|
|
||||||
|
|
||||||
async def desposit_button_refresh(self):
|
async def desposit_button_refresh(self):
|
||||||
self.desposit_button.label = self.bot.translator.t(_p(
|
self.desposit_button.label = self.bot.translator.t(_p(
|
||||||
|
|||||||
@@ -217,23 +217,24 @@ class ScheduleCog(LionCog):
|
|||||||
for bookingid in bookingids:
|
for bookingid in bookingids:
|
||||||
await self._cancel_booking_active(*bookingid)
|
await self._cancel_booking_active(*bookingid)
|
||||||
|
|
||||||
conn = await self.bot.db.get_connection()
|
async with self.bot.db.connection() as conn:
|
||||||
async with conn.transaction():
|
self.bot.db.conn = conn
|
||||||
# Now delete from data
|
async with conn.transaction():
|
||||||
records = await self.data.ScheduleSessionMember.table.delete_where(
|
# Now delete from data
|
||||||
MULTIVALUE_IN(
|
records = await self.data.ScheduleSessionMember.table.delete_where(
|
||||||
('slotid', 'guildid', 'userid'),
|
MULTIVALUE_IN(
|
||||||
*bookingids
|
('slotid', 'guildid', 'userid'),
|
||||||
|
*bookingids
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# Refund cancelled bookings
|
# Refund cancelled bookings
|
||||||
if refund:
|
if refund:
|
||||||
maybe_tids = (record['book_transactionid'] for record in records)
|
maybe_tids = (record['book_transactionid'] for record in records)
|
||||||
tids = [tid for tid in maybe_tids if tid is not None]
|
tids = [tid for tid in maybe_tids if tid is not None]
|
||||||
if tids:
|
if tids:
|
||||||
economy = self.bot.get_cog('Economy')
|
economy = self.bot.get_cog('Economy')
|
||||||
await economy.data.Transaction.refund_transactions(*tids)
|
await economy.data.Transaction.refund_transactions(*tids)
|
||||||
finally:
|
finally:
|
||||||
for lock in locks:
|
for lock in locks:
|
||||||
lock.release()
|
lock.release()
|
||||||
@@ -473,7 +474,6 @@ class ScheduleCog(LionCog):
|
|||||||
"One or more requested timeslots are already booked!"
|
"One or more requested timeslots are already booked!"
|
||||||
))
|
))
|
||||||
raise UserInputError(error)
|
raise UserInputError(error)
|
||||||
conn = await self.bot.db.get_connection()
|
|
||||||
|
|
||||||
# Booking request is now validated. Perform bookings.
|
# Booking request is now validated. Perform bookings.
|
||||||
# Fetch or create session data
|
# Fetch or create session data
|
||||||
@@ -482,27 +482,27 @@ class ScheduleCog(LionCog):
|
|||||||
*((guildid, slotid) for slotid in slotids)
|
*((guildid, slotid) for slotid in slotids)
|
||||||
)
|
)
|
||||||
|
|
||||||
async with conn.transaction():
|
# Create transactions
|
||||||
# Create transactions
|
# TODO: wrap in a transaction so the economy transaction gets unwound if it fails
|
||||||
economy = self.bot.get_cog('Economy')
|
economy = self.bot.get_cog('Economy')
|
||||||
trans_data = (
|
trans_data = (
|
||||||
TransactionType.SCHEDULE_BOOK,
|
TransactionType.SCHEDULE_BOOK,
|
||||||
guildid, userid, userid, 0,
|
guildid, userid, userid, 0,
|
||||||
config.get(ScheduleSettings.ScheduleCost.setting_id).value,
|
config.get(ScheduleSettings.ScheduleCost.setting_id).value,
|
||||||
0, None
|
0, None
|
||||||
)
|
)
|
||||||
transactions = await economy.data.Transaction.execute_transactions(*(trans_data for _ in slotids))
|
transactions = await economy.data.Transaction.execute_transactions(*(trans_data for _ in slotids))
|
||||||
transactionids = [row.transactionid for row in transactions]
|
transactionids = [row.transactionid for row in transactions]
|
||||||
|
|
||||||
# Create bookings
|
# Create bookings
|
||||||
now = utc_now()
|
now = utc_now()
|
||||||
booking_data = await self.data.ScheduleSessionMember.table.insert_many(
|
booking_data = await self.data.ScheduleSessionMember.table.insert_many(
|
||||||
('guildid', 'userid', 'slotid', 'booked_at', 'book_transactionid'),
|
('guildid', 'userid', 'slotid', 'booked_at', 'book_transactionid'),
|
||||||
*(
|
*(
|
||||||
(guildid, userid, slotid, now, tid)
|
(guildid, userid, slotid, now, tid)
|
||||||
for slotid, tid in zip(slotids, transactionids)
|
for slotid, tid in zip(slotids, transactionids)
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Now pass to activated slots
|
# Now pass to activated slots
|
||||||
for record in booking_data:
|
for record in booking_data:
|
||||||
|
|||||||
@@ -356,77 +356,76 @@ class TimeSlot:
|
|||||||
Does not modify session room channels (responsibility of the next open).
|
Does not modify session room channels (responsibility of the next open).
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
conn = await self.bot.db.get_connection()
|
# TODO: Transaction?
|
||||||
async with conn.transaction():
|
# Calculate rewards
|
||||||
# Calculate rewards
|
rewards = []
|
||||||
rewards = []
|
attendance = []
|
||||||
attendance = []
|
did_not_show = []
|
||||||
did_not_show = []
|
for session in sessions:
|
||||||
for session in sessions:
|
bonus = session.bonus_reward * session.all_attended
|
||||||
bonus = session.bonus_reward * session.all_attended
|
reward = session.attended_reward + bonus
|
||||||
reward = session.attended_reward + bonus
|
required = session.min_attendence
|
||||||
required = session.min_attendence
|
for member in session.members.values():
|
||||||
for member in session.members.values():
|
guildid = member.guildid
|
||||||
guildid = member.guildid
|
userid = member.userid
|
||||||
userid = member.userid
|
attended = (member.total_clock >= required)
|
||||||
attended = (member.total_clock >= required)
|
if attended:
|
||||||
if attended:
|
rewards.append(
|
||||||
rewards.append(
|
(TransactionType.SCHEDULE_REWARD,
|
||||||
(TransactionType.SCHEDULE_REWARD,
|
guildid, self.bot.user.id,
|
||||||
guildid, self.bot.user.id,
|
0, userid,
|
||||||
0, userid,
|
reward, 0,
|
||||||
reward, 0,
|
None)
|
||||||
None)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
did_not_show.append((guildid, userid))
|
|
||||||
|
|
||||||
attendance.append(
|
|
||||||
(self.slotid, guildid, userid, attended, member.total_clock)
|
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
did_not_show.append((guildid, userid))
|
||||||
|
|
||||||
# Perform economy transactions
|
attendance.append(
|
||||||
economy: Economy = self.bot.get_cog('Economy')
|
(self.slotid, guildid, userid, attended, member.total_clock)
|
||||||
transactions = await economy.data.Transaction.execute_transactions(*rewards)
|
|
||||||
reward_ids = {
|
|
||||||
(t.guildid, t.to_account): t.transactionid
|
|
||||||
for t in transactions
|
|
||||||
}
|
|
||||||
|
|
||||||
# Update lobby messages
|
|
||||||
message_tasks = [
|
|
||||||
asyncio.create_task(session.update_status(save=False))
|
|
||||||
for session in sessions
|
|
||||||
if session.lobby_channel is not None
|
|
||||||
]
|
|
||||||
await asyncio.gather(*message_tasks)
|
|
||||||
|
|
||||||
# Save attendance
|
|
||||||
if attendance:
|
|
||||||
att_table = TemporaryTable(
|
|
||||||
'_sid', '_gid', '_uid', '_att', '_clock', '_reward',
|
|
||||||
types=('INTEGER', 'BIGINT', 'BIGINT', 'BOOLEAN', 'INTEGER', 'INTEGER')
|
|
||||||
)
|
)
|
||||||
att_table.values = [
|
|
||||||
(sid, gid, uid, att, clock, reward_ids.get((gid, uid), None))
|
|
||||||
for sid, gid, uid, att, clock in attendance
|
|
||||||
]
|
|
||||||
await self.data.ScheduleSessionMember.table.update_where(
|
|
||||||
slotid=att_table['_sid'],
|
|
||||||
guildid=att_table['_gid'],
|
|
||||||
userid=att_table['_uid'],
|
|
||||||
).set(
|
|
||||||
attended=att_table['_att'],
|
|
||||||
clock=att_table['_clock'],
|
|
||||||
reward_transactionid=att_table['_reward']
|
|
||||||
).from_expr(att_table)
|
|
||||||
|
|
||||||
# Mark guild sessions as closed
|
# Perform economy transactions
|
||||||
if sessions:
|
economy: Economy = self.bot.get_cog('Economy')
|
||||||
await self.data.ScheduleSession.table.update_where(
|
transactions = await economy.data.Transaction.execute_transactions(*rewards)
|
||||||
slotid=self.slotid,
|
reward_ids = {
|
||||||
guildid=list(session.guildid for session in sessions)
|
(t.guildid, t.to_account): t.transactionid
|
||||||
).set(closed_at=utc_now())
|
for t in transactions
|
||||||
|
}
|
||||||
|
|
||||||
|
# Update lobby messages
|
||||||
|
message_tasks = [
|
||||||
|
asyncio.create_task(session.update_status(save=False))
|
||||||
|
for session in sessions
|
||||||
|
if session.lobby_channel is not None
|
||||||
|
]
|
||||||
|
await asyncio.gather(*message_tasks)
|
||||||
|
|
||||||
|
# Save attendance
|
||||||
|
if attendance:
|
||||||
|
att_table = TemporaryTable(
|
||||||
|
'_sid', '_gid', '_uid', '_att', '_clock', '_reward',
|
||||||
|
types=('INTEGER', 'BIGINT', 'BIGINT', 'BOOLEAN', 'INTEGER', 'INTEGER')
|
||||||
|
)
|
||||||
|
att_table.values = [
|
||||||
|
(sid, gid, uid, att, clock, reward_ids.get((gid, uid), None))
|
||||||
|
for sid, gid, uid, att, clock in attendance
|
||||||
|
]
|
||||||
|
await self.data.ScheduleSessionMember.table.update_where(
|
||||||
|
slotid=att_table['_sid'],
|
||||||
|
guildid=att_table['_gid'],
|
||||||
|
userid=att_table['_uid'],
|
||||||
|
).set(
|
||||||
|
attended=att_table['_att'],
|
||||||
|
clock=att_table['_clock'],
|
||||||
|
reward_transactionid=att_table['_reward']
|
||||||
|
).from_expr(att_table)
|
||||||
|
|
||||||
|
# Mark guild sessions as closed
|
||||||
|
if sessions:
|
||||||
|
await self.data.ScheduleSession.table.update_where(
|
||||||
|
slotid=self.slotid,
|
||||||
|
guildid=list(session.guildid for session in sessions)
|
||||||
|
).set(closed_at=utc_now())
|
||||||
|
|
||||||
if consequences and did_not_show:
|
if consequences and did_not_show:
|
||||||
# Trigger blacklist and cancel member bookings as needed
|
# Trigger blacklist and cancel member bookings as needed
|
||||||
@@ -532,36 +531,35 @@ class TimeSlot:
|
|||||||
This involves refunding the booking transactions, deleting the booking rows,
|
This involves refunding the booking transactions, deleting the booking rows,
|
||||||
and updating any messages that may have been posted.
|
and updating any messages that may have been posted.
|
||||||
"""
|
"""
|
||||||
conn = await self.bot.db.get_connection()
|
# TODO: Transaction
|
||||||
async with conn.transaction():
|
# Collect booking rows
|
||||||
# Collect booking rows
|
bookings = [member.data for session in sessions for member in session.members.values()]
|
||||||
bookings = [member.data for session in sessions for member in session.members.values()]
|
|
||||||
|
|
||||||
if bookings:
|
if bookings:
|
||||||
# Refund booking transactions
|
# Refund booking transactions
|
||||||
economy: Economy = self.bot.get_cog('Economy')
|
economy: Economy = self.bot.get_cog('Economy')
|
||||||
maybe_tids = (r.book_transactionid for r in bookings)
|
maybe_tids = (r.book_transactionid for r in bookings)
|
||||||
tids = [tid for tid in maybe_tids if tid is not None]
|
tids = [tid for tid in maybe_tids if tid is not None]
|
||||||
await economy.data.Transaction.refund_transactions(*tids)
|
await economy.data.Transaction.refund_transactions(*tids)
|
||||||
|
|
||||||
# Delete booking rows
|
# Delete booking rows
|
||||||
await self.data.ScheduleSessionMember.table.delete_where(
|
await self.data.ScheduleSessionMember.table.delete_where(
|
||||||
MEMBERS(*((r.guildid, r.userid) for r in bookings)),
|
MEMBERS(*((r.guildid, r.userid) for r in bookings)),
|
||||||
slotid=self.slotid,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Trigger message update for existent messages
|
|
||||||
lobby_tasks = [
|
|
||||||
asyncio.create_task(session.update_status(save=False, resend=False))
|
|
||||||
for session in sessions
|
|
||||||
]
|
|
||||||
await asyncio.gather(*lobby_tasks)
|
|
||||||
|
|
||||||
# Mark sessions as closed
|
|
||||||
await self.data.ScheduleSession.table.update_where(
|
|
||||||
slotid=self.slotid,
|
slotid=self.slotid,
|
||||||
guildid=[session.guildid for session in sessions]
|
|
||||||
).set(
|
|
||||||
closed_at=utc_now()
|
|
||||||
)
|
)
|
||||||
# TODO: Logging
|
|
||||||
|
# Trigger message update for existent messages
|
||||||
|
lobby_tasks = [
|
||||||
|
asyncio.create_task(session.update_status(save=False, resend=False))
|
||||||
|
for session in sessions
|
||||||
|
]
|
||||||
|
await asyncio.gather(*lobby_tasks)
|
||||||
|
|
||||||
|
# Mark sessions as closed
|
||||||
|
await self.data.ScheduleSession.table.update_where(
|
||||||
|
slotid=self.slotid,
|
||||||
|
guildid=[session.guildid for session in sessions]
|
||||||
|
).set(
|
||||||
|
closed_at=utc_now()
|
||||||
|
)
|
||||||
|
# TODO: Logging
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from discord.ui.button import button, Button
|
|||||||
|
|
||||||
from meta import LionCog, LionContext, LionBot
|
from meta import LionCog, LionContext, LionBot
|
||||||
from meta.errors import SafeCancellation
|
from meta.errors import SafeCancellation
|
||||||
|
from meta.logger import log_wrap
|
||||||
from utils import ui
|
from utils import ui
|
||||||
from utils.lib import error_embed
|
from utils.lib import error_embed
|
||||||
from constants import MAX_COINS
|
from constants import MAX_COINS
|
||||||
@@ -145,6 +146,7 @@ class ColourShop(Shop):
|
|||||||
if (owned is None or item.itemid != owned.itemid) and (item.price <= balance)
|
if (owned is None or item.itemid != owned.itemid) and (item.price <= balance)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@log_wrap(action='purchase')
|
||||||
async def purchase(self, itemid) -> ColourRoleItem:
|
async def purchase(self, itemid) -> ColourRoleItem:
|
||||||
"""
|
"""
|
||||||
Atomically handle a purchase of a ColourRoleItem.
|
Atomically handle a purchase of a ColourRoleItem.
|
||||||
@@ -157,144 +159,145 @@ class ColourShop(Shop):
|
|||||||
If the purchase fails for a known reason, raises SafeCancellation, with the error information.
|
If the purchase fails for a known reason, raises SafeCancellation, with the error information.
|
||||||
"""
|
"""
|
||||||
t = self.bot.translator.t
|
t = self.bot.translator.t
|
||||||
conn = await self.bot.db.get_connection()
|
async with self.bot.db.connection() as conn:
|
||||||
async with conn.transaction():
|
self.bot.db.conn = conn
|
||||||
# Retrieve the item to purchase from data
|
async with conn.transaction():
|
||||||
item = await self.data.ShopItemInfo.table.select_one_where(itemid=itemid)
|
# Retrieve the item to purchase from data
|
||||||
# Ensure the item is purchasable and not deleted
|
item = await self.data.ShopItemInfo.table.select_one_where(itemid=itemid)
|
||||||
if not item['purchasable'] or item['deleted']:
|
# Ensure the item is purchasable and not deleted
|
||||||
raise SafeCancellation(
|
if not item['purchasable'] or item['deleted']:
|
||||||
t(_p(
|
|
||||||
'shop:colour|purchase|error:not_purchasable',
|
|
||||||
"This item may not be purchased!"
|
|
||||||
))
|
|
||||||
)
|
|
||||||
|
|
||||||
# Refresh the customer
|
|
||||||
await self.customer.refresh()
|
|
||||||
|
|
||||||
# Ensure the guild exists in cache
|
|
||||||
guild = self.bot.get_guild(self.customer.guildid)
|
|
||||||
if guild is None:
|
|
||||||
raise SafeCancellation(
|
|
||||||
t(_p(
|
|
||||||
'shop:colour|purchase|error:no_guild',
|
|
||||||
"Could not retrieve the server from Discord!"
|
|
||||||
))
|
|
||||||
)
|
|
||||||
|
|
||||||
# Ensure the customer member actually exists
|
|
||||||
member = await self.customer.lion.fetch_member()
|
|
||||||
if member is None:
|
|
||||||
raise SafeCancellation(
|
|
||||||
t(_p(
|
|
||||||
'shop:colour|purchase|error:no_member',
|
|
||||||
"Could not retrieve the member from Discord."
|
|
||||||
))
|
|
||||||
)
|
|
||||||
|
|
||||||
# Ensure the purchased role actually exists
|
|
||||||
role = guild.get_role(item['roleid'])
|
|
||||||
if role is None:
|
|
||||||
raise SafeCancellation(
|
|
||||||
t(_p(
|
|
||||||
'shop:colour|purchase|error:no_role',
|
|
||||||
"This colour role could not be found in the server."
|
|
||||||
))
|
|
||||||
)
|
|
||||||
|
|
||||||
# Ensure the customer has enough coins for the item
|
|
||||||
if self.customer.balance < item['price']:
|
|
||||||
raise SafeCancellation(
|
|
||||||
t(_p(
|
|
||||||
'shop:colour|purchase|error:low_balance',
|
|
||||||
"This item costs {coin}{amount}!\nYour balance is {coin}{balance}"
|
|
||||||
)).format(
|
|
||||||
coin=self.bot.config.emojis.getemoji('coin'),
|
|
||||||
amount=item['price'],
|
|
||||||
balance=self.customer.balance
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
owned = self.owned()
|
|
||||||
if owned is not None:
|
|
||||||
# Ensure the customer does not already own the item
|
|
||||||
if owned.itemid == item['itemid']:
|
|
||||||
raise SafeCancellation(
|
raise SafeCancellation(
|
||||||
t(_p(
|
t(_p(
|
||||||
'shop:colour|purchase|error:owned',
|
'shop:colour|purchase|error:not_purchasable',
|
||||||
"You already own this item!"
|
"This item may not be purchased!"
|
||||||
))
|
))
|
||||||
)
|
)
|
||||||
|
|
||||||
# Charge the customer for the item
|
# Refresh the customer
|
||||||
economy_cog: Economy = self.bot.get_cog('Economy')
|
await self.customer.refresh()
|
||||||
economy_data = economy_cog.data
|
|
||||||
transaction = await economy_data.ShopTransaction.purchase_transaction(
|
|
||||||
guild.id,
|
|
||||||
member.id,
|
|
||||||
member.id,
|
|
||||||
itemid,
|
|
||||||
item['price']
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add the item to the customer's inventory
|
# Ensure the guild exists in cache
|
||||||
await self.data.MemberInventory.create(
|
guild = self.bot.get_guild(self.customer.guildid)
|
||||||
guildid=guild.id,
|
if guild is None:
|
||||||
userid=member.id,
|
raise SafeCancellation(
|
||||||
transactionid=transaction.transactionid,
|
t(_p(
|
||||||
itemid=itemid
|
'shop:colour|purchase|error:no_guild',
|
||||||
)
|
"Could not retrieve the server from Discord!"
|
||||||
|
))
|
||||||
|
)
|
||||||
|
|
||||||
# Give the customer the role (do rollback if this fails)
|
# Ensure the customer member actually exists
|
||||||
try:
|
member = await self.customer.lion.fetch_member()
|
||||||
await member.add_roles(
|
if member is None:
|
||||||
role,
|
raise SafeCancellation(
|
||||||
atomic=True,
|
t(_p(
|
||||||
reason="Purchased colour role"
|
'shop:colour|purchase|error:no_member',
|
||||||
)
|
"Could not retrieve the member from Discord."
|
||||||
except discord.NotFound:
|
))
|
||||||
raise SafeCancellation(
|
)
|
||||||
t(_p(
|
|
||||||
'shop:colour|purchase|error:failed_no_role',
|
|
||||||
"This colour role no longer exists!"
|
|
||||||
))
|
|
||||||
)
|
|
||||||
except discord.Forbidden:
|
|
||||||
raise SafeCancellation(
|
|
||||||
t(_p(
|
|
||||||
'shop:colour|purchase|error:failed_permissions',
|
|
||||||
"I do not have enough permissions to give you this colour role!"
|
|
||||||
))
|
|
||||||
)
|
|
||||||
except discord.HTTPException:
|
|
||||||
raise SafeCancellation(
|
|
||||||
t(_p(
|
|
||||||
'shop:colour|purchase|error:failed_unknown',
|
|
||||||
"An unknown error occurred while giving you this colour role!"
|
|
||||||
))
|
|
||||||
)
|
|
||||||
|
|
||||||
# At this point, the purchase has succeeded and the user has obtained the colour role
|
# Ensure the purchased role actually exists
|
||||||
# Now, remove their previous colour role (if applicable)
|
role = guild.get_role(item['roleid'])
|
||||||
# TODO: We should probably add an on_role_delete event to clear defunct colour roles
|
if role is None:
|
||||||
if owned is not None:
|
raise SafeCancellation(
|
||||||
owned_role = owned.role
|
t(_p(
|
||||||
if owned_role is not None:
|
'shop:colour|purchase|error:no_role',
|
||||||
try:
|
"This colour role could not be found in the server."
|
||||||
await member.remove_roles(
|
))
|
||||||
owned_role,
|
)
|
||||||
reason="Removing old colour role.",
|
|
||||||
atomic=True
|
# Ensure the customer has enough coins for the item
|
||||||
|
if self.customer.balance < item['price']:
|
||||||
|
raise SafeCancellation(
|
||||||
|
t(_p(
|
||||||
|
'shop:colour|purchase|error:low_balance',
|
||||||
|
"This item costs {coin}{amount}!\nYour balance is {coin}{balance}"
|
||||||
|
)).format(
|
||||||
|
coin=self.bot.config.emojis.getemoji('coin'),
|
||||||
|
amount=item['price'],
|
||||||
|
balance=self.customer.balance
|
||||||
)
|
)
|
||||||
except discord.HTTPException:
|
)
|
||||||
# Possibly Forbidden, or the role doesn't actually exist anymore (cache failure)
|
|
||||||
pass
|
|
||||||
await self.data.MemberInventory.table.delete_where(inventoryid=owned.data.inventoryid)
|
|
||||||
|
|
||||||
# Purchase complete, update the shop and customer
|
owned = self.owned()
|
||||||
await self.refresh()
|
if owned is not None:
|
||||||
return self.owned()
|
# Ensure the customer does not already own the item
|
||||||
|
if owned.itemid == item['itemid']:
|
||||||
|
raise SafeCancellation(
|
||||||
|
t(_p(
|
||||||
|
'shop:colour|purchase|error:owned',
|
||||||
|
"You already own this item!"
|
||||||
|
))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Charge the customer for the item
|
||||||
|
economy_cog: Economy = self.bot.get_cog('Economy')
|
||||||
|
economy_data = economy_cog.data
|
||||||
|
transaction = await economy_data.ShopTransaction.purchase_transaction(
|
||||||
|
guild.id,
|
||||||
|
member.id,
|
||||||
|
member.id,
|
||||||
|
itemid,
|
||||||
|
item['price']
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add the item to the customer's inventory
|
||||||
|
await self.data.MemberInventory.create(
|
||||||
|
guildid=guild.id,
|
||||||
|
userid=member.id,
|
||||||
|
transactionid=transaction.transactionid,
|
||||||
|
itemid=itemid
|
||||||
|
)
|
||||||
|
|
||||||
|
# Give the customer the role (do rollback if this fails)
|
||||||
|
try:
|
||||||
|
await member.add_roles(
|
||||||
|
role,
|
||||||
|
atomic=True,
|
||||||
|
reason="Purchased colour role"
|
||||||
|
)
|
||||||
|
except discord.NotFound:
|
||||||
|
raise SafeCancellation(
|
||||||
|
t(_p(
|
||||||
|
'shop:colour|purchase|error:failed_no_role',
|
||||||
|
"This colour role no longer exists!"
|
||||||
|
))
|
||||||
|
)
|
||||||
|
except discord.Forbidden:
|
||||||
|
raise SafeCancellation(
|
||||||
|
t(_p(
|
||||||
|
'shop:colour|purchase|error:failed_permissions',
|
||||||
|
"I do not have enough permissions to give you this colour role!"
|
||||||
|
))
|
||||||
|
)
|
||||||
|
except discord.HTTPException:
|
||||||
|
raise SafeCancellation(
|
||||||
|
t(_p(
|
||||||
|
'shop:colour|purchase|error:failed_unknown',
|
||||||
|
"An unknown error occurred while giving you this colour role!"
|
||||||
|
))
|
||||||
|
)
|
||||||
|
|
||||||
|
# At this point, the purchase has succeeded and the user has obtained the colour role
|
||||||
|
# Now, remove their previous colour role (if applicable)
|
||||||
|
# TODO: We should probably add an on_role_delete event to clear defunct colour roles
|
||||||
|
if owned is not None:
|
||||||
|
owned_role = owned.role
|
||||||
|
if owned_role is not None:
|
||||||
|
try:
|
||||||
|
await member.remove_roles(
|
||||||
|
owned_role,
|
||||||
|
reason="Removing old colour role.",
|
||||||
|
atomic=True
|
||||||
|
)
|
||||||
|
except discord.HTTPException:
|
||||||
|
# Possibly Forbidden, or the role doesn't actually exist anymore (cache failure)
|
||||||
|
pass
|
||||||
|
await self.data.MemberInventory.table.delete_where(inventoryid=owned.data.inventoryid)
|
||||||
|
|
||||||
|
# Purchase complete, update the shop and customer
|
||||||
|
await self.refresh()
|
||||||
|
return self.owned()
|
||||||
|
|
||||||
async def refresh(self):
|
async def refresh(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from enum import Enum
|
|||||||
from itertools import chain
|
from itertools import chain
|
||||||
from psycopg import sql
|
from psycopg import sql
|
||||||
|
|
||||||
|
from meta.logger import log_wrap
|
||||||
from data import RowModel, Registry, Table, RegisterEnum
|
from data import RowModel, Registry, Table, RegisterEnum
|
||||||
from data.columns import Integer, String, Timestamp, Bool, Column
|
from data.columns import Integer, String, Timestamp, Bool, Column
|
||||||
|
|
||||||
@@ -80,6 +81,7 @@ class StatsData(Registry):
|
|||||||
end_time = Timestamp()
|
end_time = Timestamp()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='tracked_time_between')
|
||||||
async def tracked_time_between(cls, *points: tuple[int, int, dt.datetime, dt.datetime]):
|
async def tracked_time_between(cls, *points: tuple[int, int, dt.datetime, dt.datetime]):
|
||||||
query = sql.SQL(
|
query = sql.SQL(
|
||||||
"""
|
"""
|
||||||
@@ -103,25 +105,27 @@ class StatsData(Registry):
|
|||||||
for _ in points
|
for _ in points
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute(
|
await cursor.execute(
|
||||||
query,
|
query,
|
||||||
chain(*points)
|
chain(*points)
|
||||||
)
|
)
|
||||||
return cursor.fetchall()
|
return cursor.fetchall()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='study_time_between')
|
||||||
async def study_time_between(cls, guildid: int, userid: int, _start, _end) -> int:
|
async def study_time_between(cls, guildid: int, userid: int, _start, _end) -> int:
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute(
|
await cursor.execute(
|
||||||
"SELECT study_time_between(%s, %s, %s, %s)",
|
"SELECT study_time_between(%s, %s, %s, %s)",
|
||||||
(guildid, userid, _start, _end)
|
(guildid, userid, _start, _end)
|
||||||
)
|
)
|
||||||
return (await cursor.fetchone()[0]) or 0
|
return (await cursor.fetchone()[0]) or 0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='study_times_between')
|
||||||
async def study_times_between(cls, guildid: int, userid: int, *points) -> list[int]:
|
async def study_times_between(cls, guildid: int, userid: int, *points) -> list[int]:
|
||||||
if len(points) < 2:
|
if len(points) < 2:
|
||||||
raise ValueError('Not enough block points given!')
|
raise ValueError('Not enough block points given!')
|
||||||
@@ -141,25 +145,27 @@ class StatsData(Registry):
|
|||||||
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
|
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute(
|
await cursor.execute(
|
||||||
query,
|
query,
|
||||||
tuple(chain((guildid, userid), *blocks))
|
tuple(chain((guildid, userid), *blocks))
|
||||||
)
|
)
|
||||||
return [r['stime'] or 0 for r in await cursor.fetchall()]
|
return [r['stime'] or 0 for r in await cursor.fetchall()]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='study_time_since')
|
||||||
async def study_time_since(cls, guildid: int, userid: int, _start) -> int:
|
async def study_time_since(cls, guildid: int, userid: int, _start) -> int:
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute(
|
await cursor.execute(
|
||||||
"SELECT study_time_since(%s, %s, %s)",
|
"SELECT study_time_since(%s, %s, %s)",
|
||||||
(guildid, userid, _start)
|
(guildid, userid, _start)
|
||||||
)
|
)
|
||||||
return (await cursor.fetchone()[0]) or 0
|
return (await cursor.fetchone()[0]) or 0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='study_times_between')
|
||||||
async def study_times_since(cls, guildid: int, userid: int, *starts) -> int:
|
async def study_times_since(cls, guildid: int, userid: int, *starts) -> int:
|
||||||
if len(starts) < 1:
|
if len(starts) < 1:
|
||||||
raise ValueError('No starting points given!')
|
raise ValueError('No starting points given!')
|
||||||
@@ -178,15 +184,16 @@ class StatsData(Registry):
|
|||||||
sql.SQL("({})").format(sql.Placeholder()) for _ in starts
|
sql.SQL("({})").format(sql.Placeholder()) for _ in starts
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute(
|
await cursor.execute(
|
||||||
query,
|
query,
|
||||||
tuple(chain((guildid, userid), starts))
|
tuple(chain((guildid, userid), starts))
|
||||||
)
|
)
|
||||||
return [r['stime'] or 0 for r in await cursor.fetchall()]
|
return [r['stime'] or 0 for r in await cursor.fetchall()]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='leaderboard_since')
|
||||||
async def leaderboard_since(cls, guildid: int, since):
|
async def leaderboard_since(cls, guildid: int, since):
|
||||||
"""
|
"""
|
||||||
Return the voice totals since the given time for each member in the guild.
|
Return the voice totals since the given time for each member in the guild.
|
||||||
@@ -226,23 +233,25 @@ class StatsData(Registry):
|
|||||||
)
|
)
|
||||||
second_query_args = (since, guildid, since, since)
|
second_query_args = (since, guildid, since, since)
|
||||||
|
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.transaction():
|
cls._connector.conn = conn
|
||||||
async with conn.cursor() as cursor:
|
async with conn.transaction():
|
||||||
await cursor.execute(second_query, second_query_args)
|
async with conn.cursor() as cursor:
|
||||||
overshoot_rows = await cursor.fetchall()
|
await cursor.execute(second_query, second_query_args)
|
||||||
overshoot = {row['userid']: int(row['diff']) for row in overshoot_rows}
|
overshoot_rows = await cursor.fetchall()
|
||||||
|
overshoot = {row['userid']: int(row['diff']) for row in overshoot_rows}
|
||||||
|
|
||||||
async with conn.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute(first_query, first_query_args)
|
await cursor.execute(first_query, first_query_args)
|
||||||
leaderboard = [
|
leaderboard = [
|
||||||
(row['userid'], int(row['total_duration'] - overshoot.get(row['userid'], 0)))
|
(row['userid'], int(row['total_duration'] - overshoot.get(row['userid'], 0)))
|
||||||
for row in await cursor.fetchall()
|
for row in await cursor.fetchall()
|
||||||
]
|
]
|
||||||
leaderboard.sort(key=lambda t: t[1], reverse=True)
|
leaderboard.sort(key=lambda t: t[1], reverse=True)
|
||||||
return leaderboard
|
return leaderboard
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap('leaderboard_all')
|
||||||
async def leaderboard_all(cls, guildid: int):
|
async def leaderboard_all(cls, guildid: int):
|
||||||
"""
|
"""
|
||||||
Return the all-time voice totals for the given guild.
|
Return the all-time voice totals for the given guild.
|
||||||
@@ -257,13 +266,13 @@ class StatsData(Registry):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute(query, (guildid, ))
|
await cursor.execute(query, (guildid, ))
|
||||||
leaderboard = [
|
leaderboard = [
|
||||||
(row['userid'], int(row['total_duration']))
|
(row['userid'], int(row['total_duration']))
|
||||||
for row in await cursor.fetchall()
|
for row in await cursor.fetchall()
|
||||||
]
|
]
|
||||||
return leaderboard
|
return leaderboard
|
||||||
|
|
||||||
class MemberExp(RowModel):
|
class MemberExp(RowModel):
|
||||||
@@ -296,6 +305,7 @@ class StatsData(Registry):
|
|||||||
transactionid = Integer()
|
transactionid = Integer()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='xp_since')
|
||||||
async def xp_since(cls, guildid: int, userid: int, *starts):
|
async def xp_since(cls, guildid: int, userid: int, *starts):
|
||||||
query = sql.SQL(
|
query = sql.SQL(
|
||||||
"""
|
"""
|
||||||
@@ -320,15 +330,16 @@ class StatsData(Registry):
|
|||||||
sql.Placeholder() for _ in starts
|
sql.Placeholder() for _ in starts
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute(
|
await cursor.execute(
|
||||||
query,
|
query,
|
||||||
tuple(chain((guildid, userid), starts))
|
tuple(chain((guildid, userid), starts))
|
||||||
)
|
)
|
||||||
return [r['exp'] or 0 for r in await cursor.fetchall()]
|
return [r['exp'] or 0 for r in await cursor.fetchall()]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='xp_between')
|
||||||
async def xp_between(cls, guildid: int, userid: int, *points):
|
async def xp_between(cls, guildid: int, userid: int, *points):
|
||||||
blocks = zip(points, points[1:])
|
blocks = zip(points, points[1:])
|
||||||
query = sql.SQL(
|
query = sql.SQL(
|
||||||
@@ -355,15 +366,16 @@ class StatsData(Registry):
|
|||||||
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
|
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute(
|
await cursor.execute(
|
||||||
query,
|
query,
|
||||||
tuple(chain((guildid, userid), *blocks))
|
tuple(chain((guildid, userid), *blocks))
|
||||||
)
|
)
|
||||||
return [r['period_xp'] or 0 for r in await cursor.fetchall()]
|
return [r['period_xp'] or 0 for r in await cursor.fetchall()]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='leaderboard_since')
|
||||||
async def leaderboard_since(cls, guildid: int, since):
|
async def leaderboard_since(cls, guildid: int, since):
|
||||||
"""
|
"""
|
||||||
Return the XP totals for the given guild since the given time.
|
Return the XP totals for the given guild since the given time.
|
||||||
@@ -378,16 +390,17 @@ class StatsData(Registry):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute(query, (guildid, since))
|
await cursor.execute(query, (guildid, since))
|
||||||
leaderboard = [
|
leaderboard = [
|
||||||
(row['userid'], int(row['total_xp']))
|
(row['userid'], int(row['total_xp']))
|
||||||
for row in await cursor.fetchall()
|
for row in await cursor.fetchall()
|
||||||
]
|
]
|
||||||
return leaderboard
|
return leaderboard
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='leaderboard_all')
|
||||||
async def leaderboard_all(cls, guildid: int):
|
async def leaderboard_all(cls, guildid: int):
|
||||||
"""
|
"""
|
||||||
Return the all-time XP totals for the given guild.
|
Return the all-time XP totals for the given guild.
|
||||||
@@ -402,13 +415,13 @@ class StatsData(Registry):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute(query, (guildid, ))
|
await cursor.execute(query, (guildid, ))
|
||||||
leaderboard = [
|
leaderboard = [
|
||||||
(row['userid'], int(row['total_xp']))
|
(row['userid'], int(row['total_xp']))
|
||||||
for row in await cursor.fetchall()
|
for row in await cursor.fetchall()
|
||||||
]
|
]
|
||||||
return leaderboard
|
return leaderboard
|
||||||
|
|
||||||
class UserExp(RowModel):
|
class UserExp(RowModel):
|
||||||
@@ -436,6 +449,7 @@ class StatsData(Registry):
|
|||||||
exp_type: Column[ExpType] = Column()
|
exp_type: Column[ExpType] = Column()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='user_xp_since')
|
||||||
async def xp_since(cls, userid: int, *starts):
|
async def xp_since(cls, userid: int, *starts):
|
||||||
query = sql.SQL(
|
query = sql.SQL(
|
||||||
"""
|
"""
|
||||||
@@ -459,15 +473,16 @@ class StatsData(Registry):
|
|||||||
sql.Placeholder() for _ in starts
|
sql.Placeholder() for _ in starts
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute(
|
await cursor.execute(
|
||||||
query,
|
query,
|
||||||
tuple(chain((userid,), starts))
|
tuple(chain((userid,), starts))
|
||||||
)
|
)
|
||||||
return [r['exp'] or 0 for r in await cursor.fetchall()]
|
return [r['exp'] or 0 for r in await cursor.fetchall()]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='user_xp_since')
|
||||||
async def xp_between(cls, userid: int, *points):
|
async def xp_between(cls, userid: int, *points):
|
||||||
blocks = zip(points, points[1:])
|
blocks = zip(points, points[1:])
|
||||||
query = sql.SQL(
|
query = sql.SQL(
|
||||||
@@ -493,13 +508,13 @@ class StatsData(Registry):
|
|||||||
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
|
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute(
|
await cursor.execute(
|
||||||
query,
|
query,
|
||||||
tuple(chain((userid,), *blocks))
|
tuple(chain((userid,), *blocks))
|
||||||
)
|
)
|
||||||
return [r['period_xp'] or 0 for r in await cursor.fetchall()]
|
return [r['period_xp'] or 0 for r in await cursor.fetchall()]
|
||||||
|
|
||||||
class ProfileTag(RowModel):
|
class ProfileTag(RowModel):
|
||||||
"""
|
"""
|
||||||
@@ -531,15 +546,17 @@ class StatsData(Registry):
|
|||||||
return [tag.tag for tag in tags]
|
return [tag.tag for tag in tags]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='set_profile_tags')
|
||||||
async def set_tags(self, guildid: Optional[int], userid: int, tags: Iterable[str]):
|
async def set_tags(self, guildid: Optional[int], userid: int, tags: Iterable[str]):
|
||||||
conn = await self._connector.get_connection()
|
async with self._connector.connection() as conn:
|
||||||
async with conn.transaction():
|
self._connector.conn = conn
|
||||||
await self.table.delete_where(guildid=guildid, userid=userid)
|
async with conn.transaction():
|
||||||
if tags:
|
await self.table.delete_where(guildid=guildid, userid=userid)
|
||||||
await self.table.insert_many(
|
if tags:
|
||||||
('guildid', 'userid', 'tag'),
|
await self.table.insert_many(
|
||||||
*((guildid, userid, tag) for tag in tags)
|
('guildid', 'userid', 'tag'),
|
||||||
)
|
*((guildid, userid, tag) for tag in tags)
|
||||||
|
)
|
||||||
|
|
||||||
class WeeklyGoals(RowModel):
|
class WeeklyGoals(RowModel):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -473,14 +473,14 @@ class WeeklyMonthlyUI(StatsUI):
|
|||||||
# Update the tasklist
|
# Update the tasklist
|
||||||
if len(new_tasks) != len(tasks) or not all(t == new_t for (t, new_t) in zip(tasks, new_tasks)):
|
if len(new_tasks) != len(tasks) or not all(t == new_t for (t, new_t) in zip(tasks, new_tasks)):
|
||||||
modified = True
|
modified = True
|
||||||
conn = await self.bot.db.get_connection()
|
async with self._connector.connection() as conn:
|
||||||
async with conn.transaction():
|
async with conn.transaction():
|
||||||
await tasks_model.table.delete_where(**key)
|
await tasks_model.table.delete_where(**key).with_connection(conn)
|
||||||
if new_tasks:
|
if new_tasks:
|
||||||
await tasks_model.table.insert_many(
|
await tasks_model.table.insert_many(
|
||||||
(*key.keys(), 'completed', 'content'),
|
(*key.keys(), 'completed', 'content'),
|
||||||
*((*key.values(), *new_task) for new_task in new_tasks)
|
*((*key.values(), *new_task) for new_task in new_tasks)
|
||||||
)
|
).with_connection(conn)
|
||||||
|
|
||||||
if modified:
|
if modified:
|
||||||
# If either goal type was modified, clear the rendered cache and refresh
|
# If either goal type was modified, clear the rendered cache and refresh
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from discord import app_commands as appcmds
|
|||||||
from discord.app_commands.transformers import AppCommandOptionType as cmdopt
|
from discord.app_commands.transformers import AppCommandOptionType as cmdopt
|
||||||
|
|
||||||
from meta import LionBot, LionCog, LionContext
|
from meta import LionBot, LionCog, LionContext
|
||||||
|
from meta.logger import log_wrap
|
||||||
from meta.errors import UserInputError
|
from meta.errors import UserInputError
|
||||||
from utils.lib import utc_now, error_embed
|
from utils.lib import utc_now, error_embed
|
||||||
from utils.ui import ChoicedEnum, Transformed, AButton
|
from utils.ui import ChoicedEnum, Transformed, AButton
|
||||||
@@ -141,30 +142,32 @@ class TasklistCog(LionCog):
|
|||||||
self.crossload_group(self.configure_group, configcog.configure_group)
|
self.crossload_group(self.configure_group, configcog.configure_group)
|
||||||
|
|
||||||
@LionCog.listener('on_tasks_completed')
|
@LionCog.listener('on_tasks_completed')
|
||||||
|
@log_wrap(action="reward tasks completed")
|
||||||
async def reward_tasks_completed(self, member: discord.Member, *taskids: int):
|
async def reward_tasks_completed(self, member: discord.Member, *taskids: int):
|
||||||
conn = await self.bot.db.get_connection()
|
async with self.bot.db.connection() as conn:
|
||||||
async with conn.transaction():
|
self.bot.db.conn = conn
|
||||||
tasklist = await Tasklist.fetch(self.bot, self.data, member.id)
|
async with conn.transaction():
|
||||||
tasks = await tasklist.fetch_tasks(*taskids)
|
tasklist = await Tasklist.fetch(self.bot, self.data, member.id)
|
||||||
unrewarded = [task for task in tasks if not task.rewarded]
|
tasks = await tasklist.fetch_tasks(*taskids)
|
||||||
if unrewarded:
|
unrewarded = [task for task in tasks if not task.rewarded]
|
||||||
reward = (await self.settings.task_reward.get(member.guild.id)).value
|
if unrewarded:
|
||||||
limit = (await self.settings.task_reward_limit.get(member.guild.id)).value
|
reward = (await self.settings.task_reward.get(member.guild.id)).value
|
||||||
|
limit = (await self.settings.task_reward_limit.get(member.guild.id)).value
|
||||||
|
|
||||||
ecog = self.bot.get_cog('Economy')
|
ecog = self.bot.get_cog('Economy')
|
||||||
recent = await ecog.data.TaskTransaction.count_recent_for(member.id, member.guild.id) or 0
|
recent = await ecog.data.TaskTransaction.count_recent_for(member.id, member.guild.id) or 0
|
||||||
max_to_reward = limit - recent
|
max_to_reward = limit - recent
|
||||||
if max_to_reward > 0:
|
if max_to_reward > 0:
|
||||||
to_reward = unrewarded[:max_to_reward]
|
to_reward = unrewarded[:max_to_reward]
|
||||||
|
|
||||||
count = len(to_reward)
|
count = len(to_reward)
|
||||||
amount = count * reward
|
amount = count * reward
|
||||||
await ecog.data.TaskTransaction.reward_completed(member.id, member.guild.id, count, amount)
|
await ecog.data.TaskTransaction.reward_completed(member.id, member.guild.id, count, amount)
|
||||||
await tasklist.update_tasks(*(task.taskid for task in to_reward), rewarded=True)
|
await tasklist.update_tasks(*(task.taskid for task in to_reward), rewarded=True)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Rewarded <uid: {member.id}> in <gid: {member.guild.id}> "
|
f"Rewarded <uid: {member.id}> in <gid: {member.guild.id}> "
|
||||||
f"'{amount}' coins for completing '{count}' tasks."
|
f"'{amount}' coins for completing '{count}' tasks."
|
||||||
)
|
)
|
||||||
|
|
||||||
async def is_tasklist_channel(self, channel) -> bool:
|
async def is_tasklist_channel(self, channel) -> bool:
|
||||||
if not channel.guild:
|
if not channel.guild:
|
||||||
@@ -477,43 +480,40 @@ class TasklistCog(LionCog):
|
|||||||
# Contents successfully parsed, update the tasklist.
|
# Contents successfully parsed, update the tasklist.
|
||||||
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
|
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
|
||||||
|
|
||||||
# Lazily using the editor because it has a good parser
|
|
||||||
taskinfo = tasklist.parse_tasklist(lines)
|
taskinfo = tasklist.parse_tasklist(lines)
|
||||||
|
|
||||||
conn = await self.bot.db.get_connection()
|
now = utc_now()
|
||||||
async with conn.transaction():
|
|
||||||
now = utc_now()
|
|
||||||
|
|
||||||
# Delete tasklist if required
|
# Delete tasklist if required
|
||||||
if not append:
|
if not append:
|
||||||
await tasklist.update_tasklist(deleted_at=now)
|
await tasklist.update_tasklist(deleted_at=now)
|
||||||
|
|
||||||
# Create tasklist
|
# Create tasklist
|
||||||
# TODO: Refactor into common method with parse tasklist
|
# TODO: Refactor into common method with parse tasklist
|
||||||
created = {}
|
created = {}
|
||||||
target_depth = 0
|
target_depth = 0
|
||||||
while True:
|
while True:
|
||||||
to_insert = {}
|
to_insert = {}
|
||||||
for i, (parent, truedepth, ticked, content) in enumerate(taskinfo):
|
for i, (parent, truedepth, ticked, content) in enumerate(taskinfo):
|
||||||
if truedepth == target_depth:
|
if truedepth == target_depth:
|
||||||
to_insert[i] = (
|
to_insert[i] = (
|
||||||
tasklist.userid,
|
tasklist.userid,
|
||||||
content,
|
content,
|
||||||
created[parent] if parent is not None else None,
|
created[parent] if parent is not None else None,
|
||||||
now if ticked else None
|
now if ticked else None
|
||||||
)
|
|
||||||
if to_insert:
|
|
||||||
# Batch insert
|
|
||||||
tasks = await tasklist.data.Task.table.insert_many(
|
|
||||||
('userid', 'content', 'parentid', 'completed_at'),
|
|
||||||
*to_insert.values()
|
|
||||||
)
|
)
|
||||||
for i, task in zip(to_insert.keys(), tasks):
|
if to_insert:
|
||||||
created[i] = task['taskid']
|
# Batch insert
|
||||||
target_depth += 1
|
tasks = await tasklist.data.Task.table.insert_many(
|
||||||
else:
|
('userid', 'content', 'parentid', 'completed_at'),
|
||||||
# Reached maximum depth
|
*to_insert.values()
|
||||||
break
|
)
|
||||||
|
for i, task in zip(to_insert.keys(), tasks):
|
||||||
|
created[i] = task['taskid']
|
||||||
|
target_depth += 1
|
||||||
|
else:
|
||||||
|
# Reached maximum depth
|
||||||
|
break
|
||||||
|
|
||||||
# Ack modifications
|
# Ack modifications
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from discord.ui.button import button, Button, ButtonStyle
|
|||||||
from discord.ui.text_input import TextInput, TextStyle
|
from discord.ui.text_input import TextInput, TextStyle
|
||||||
|
|
||||||
from meta import conf
|
from meta import conf
|
||||||
|
from meta.logger import log_wrap
|
||||||
from meta.errors import UserInputError
|
from meta.errors import UserInputError
|
||||||
from utils.lib import MessageArgs, utc_now
|
from utils.lib import MessageArgs, utc_now
|
||||||
from utils.ui import LeoUI, LeoModal, FastModal, error_handler_for, ModalRetryUI
|
from utils.ui import LeoUI, LeoModal, FastModal, error_handler_for, ModalRetryUI
|
||||||
@@ -143,6 +144,7 @@ class BulkEditor(LeoModal):
|
|||||||
except UserInputError as error:
|
except UserInputError as error:
|
||||||
await ModalRetryUI(self, error.msg).respond_to(interaction)
|
await ModalRetryUI(self, error.msg).respond_to(interaction)
|
||||||
|
|
||||||
|
@log_wrap(action="parse editor")
|
||||||
async def parse_editor(self):
|
async def parse_editor(self):
|
||||||
# First parse each line
|
# First parse each line
|
||||||
new_lines = self.tasklist_editor.value.splitlines()
|
new_lines = self.tasklist_editor.value.splitlines()
|
||||||
@@ -155,27 +157,28 @@ class BulkEditor(LeoModal):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# TODO: Incremental/diff editing
|
# TODO: Incremental/diff editing
|
||||||
conn = await self.bot.db.get_connection()
|
async with self.bot.db.connection() as conn:
|
||||||
async with conn.transaction():
|
self.bot.db.conn = conn
|
||||||
now = utc_now()
|
async with conn.transaction():
|
||||||
|
now = utc_now()
|
||||||
|
|
||||||
if same_layout:
|
if same_layout:
|
||||||
# if the layout has not changed, just edit the tasks
|
# if the layout has not changed, just edit the tasks
|
||||||
for taskid, (oldinfo, newinfo) in zip(self.lines.keys(), zip(old_info, taskinfo)):
|
for taskid, (oldinfo, newinfo) in zip(self.lines.keys(), zip(old_info, taskinfo)):
|
||||||
args = {}
|
args = {}
|
||||||
if oldinfo[2] != newinfo[2]:
|
if oldinfo[2] != newinfo[2]:
|
||||||
args['completed_at'] = now if newinfo[2] else None
|
args['completed_at'] = now if newinfo[2] else None
|
||||||
if oldinfo[3] != newinfo[3]:
|
if oldinfo[3] != newinfo[3]:
|
||||||
args['content'] = newinfo[3]
|
args['content'] = newinfo[3]
|
||||||
if args:
|
if args:
|
||||||
await self.tasklist.update_tasks(taskid, **args)
|
await self.tasklist.update_tasks(taskid, **args)
|
||||||
else:
|
else:
|
||||||
# Naive implementation clearing entire tasklist
|
# Naive implementation clearing entire tasklist
|
||||||
# Clear tasklist
|
# Clear tasklist
|
||||||
await self.tasklist.update_tasklist(deleted_at=now)
|
await self.tasklist.update_tasklist(deleted_at=now)
|
||||||
|
|
||||||
# Create tasklist
|
# Create tasklist
|
||||||
await self.tasklist.write_taskinfo(taskinfo)
|
await self.tasklist.write_taskinfo(taskinfo)
|
||||||
|
|
||||||
|
|
||||||
class UIMode(Enum):
|
class UIMode(Enum):
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from typing import Type
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
from data import RowModel, Table, ORDER
|
from data import RowModel, Table, ORDER
|
||||||
|
from meta.logger import log_wrap, set_logging_context
|
||||||
|
|
||||||
|
|
||||||
class ModelData:
|
class ModelData:
|
||||||
@@ -60,6 +61,7 @@ class ModelData:
|
|||||||
It only updates.
|
It only updates.
|
||||||
"""
|
"""
|
||||||
# TODO: Better way of getting the key?
|
# TODO: Better way of getting the key?
|
||||||
|
# TODO: Transaction
|
||||||
if not isinstance(parent_id, tuple):
|
if not isinstance(parent_id, tuple):
|
||||||
parent_id = (parent_id, )
|
parent_id = (parent_id, )
|
||||||
model = cls._model
|
model = cls._model
|
||||||
@@ -83,6 +85,8 @@ class ListData:
|
|||||||
This assumes the list is the only data stored in the table,
|
This assumes the list is the only data stored in the table,
|
||||||
and removes list entries by deleting rows.
|
and removes list entries by deleting rows.
|
||||||
"""
|
"""
|
||||||
|
setting_id: str
|
||||||
|
|
||||||
# Table storing the setting data
|
# Table storing the setting data
|
||||||
_table_interface: Table
|
_table_interface: Table
|
||||||
|
|
||||||
@@ -100,10 +104,12 @@ class ListData:
|
|||||||
_cache = None # Map[id -> value]
|
_cache = None # Map[id -> value]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(isolate=True)
|
||||||
async def _reader(cls, parent_id, use_cache=True, **kwargs):
|
async def _reader(cls, parent_id, use_cache=True, **kwargs):
|
||||||
"""
|
"""
|
||||||
Read in all entries associated to the given id.
|
Read in all entries associated to the given id.
|
||||||
"""
|
"""
|
||||||
|
set_logging_context(action="Read cls.setting_id")
|
||||||
if cls._cache is not None and parent_id in cls._cache and use_cache:
|
if cls._cache is not None and parent_id in cls._cache and use_cache:
|
||||||
return cls._cache[parent_id]
|
return cls._cache[parent_id]
|
||||||
|
|
||||||
@@ -121,53 +127,56 @@ class ListData:
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(isolate=True)
|
||||||
async def _writer(cls, id, data, add_only=False, remove_only=False, **kwargs):
|
async def _writer(cls, id, data, add_only=False, remove_only=False, **kwargs):
|
||||||
"""
|
"""
|
||||||
Write the provided list to storage.
|
Write the provided list to storage.
|
||||||
"""
|
"""
|
||||||
|
set_logging_context(action="Write cls.setting_id")
|
||||||
table = cls._table_interface
|
table = cls._table_interface
|
||||||
conn = await table.connector.get_connection()
|
async with table.connector.connection() as conn:
|
||||||
async with conn.transaction():
|
table.connector.conn = conn
|
||||||
# Handle None input as an empty list
|
async with conn.transaction():
|
||||||
if data is None:
|
# Handle None input as an empty list
|
||||||
data = []
|
if data is None:
|
||||||
|
data = []
|
||||||
|
|
||||||
current = await cls._reader(id, use_cache=False, **kwargs)
|
current = await cls._reader(id, use_cache=False, **kwargs)
|
||||||
if not cls._order_column and (add_only or remove_only):
|
if not cls._order_column and (add_only or remove_only):
|
||||||
to_insert = [item for item in data if item not in current] if not remove_only else []
|
to_insert = [item for item in data if item not in current] if not remove_only else []
|
||||||
to_remove = data if remove_only else (
|
to_remove = data if remove_only else (
|
||||||
[item for item in current if item not in data] if not add_only else []
|
[item for item in current if item not in data] if not add_only else []
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle required deletions
|
# Handle required deletions
|
||||||
if to_remove:
|
if to_remove:
|
||||||
params = {
|
params = {
|
||||||
cls._id_column: id,
|
cls._id_column: id,
|
||||||
cls._data_column: to_remove
|
cls._data_column: to_remove
|
||||||
}
|
}
|
||||||
await table.delete_where(**params)
|
await table.delete_where(**params)
|
||||||
|
|
||||||
# Handle required insertions
|
# Handle required insertions
|
||||||
if to_insert:
|
if to_insert:
|
||||||
columns = (cls._id_column, cls._data_column)
|
columns = (cls._id_column, cls._data_column)
|
||||||
values = [(id, value) for value in to_insert]
|
values = [(id, value) for value in to_insert]
|
||||||
await table.insert_many(columns, *values)
|
await table.insert_many(columns, *values)
|
||||||
|
|
||||||
if cls._cache is not None:
|
if cls._cache is not None:
|
||||||
new_current = [item for item in current + to_insert if item not in to_remove]
|
new_current = [item for item in current + to_insert if item not in to_remove]
|
||||||
cls._cache[id] = new_current
|
cls._cache[id] = new_current
|
||||||
else:
|
else:
|
||||||
# Remove all and add all to preserve order
|
# Remove all and add all to preserve order
|
||||||
delete_params = {cls._id_column: id}
|
delete_params = {cls._id_column: id}
|
||||||
await table.delete_where(**delete_params)
|
await table.delete_where(**delete_params)
|
||||||
|
|
||||||
if data:
|
if data:
|
||||||
columns = (cls._id_column, cls._data_column)
|
columns = (cls._id_column, cls._data_column)
|
||||||
values = [(id, value) for value in data]
|
values = [(id, value) for value in data]
|
||||||
await table.insert_many(columns, *values)
|
await table.insert_many(columns, *values)
|
||||||
|
|
||||||
if cls._cache is not None:
|
if cls._cache is not None:
|
||||||
cls._cache[id] = data
|
cls._cache[id] = data
|
||||||
|
|
||||||
|
|
||||||
class KeyValueData:
|
class KeyValueData:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from itertools import chain
|
from itertools import chain
|
||||||
from psycopg import sql
|
from psycopg import sql
|
||||||
|
|
||||||
|
from meta.logger import log_wrap
|
||||||
from data import RowModel, Registry, Table
|
from data import RowModel, Registry, Table
|
||||||
from data.columns import Integer, String, Timestamp, Bool
|
from data.columns import Integer, String, Timestamp, Bool
|
||||||
|
|
||||||
@@ -72,6 +72,7 @@ class TextTrackerData(Registry):
|
|||||||
member_expid = Integer()
|
member_expid = Integer()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='end_text_sessions')
|
||||||
async def end_sessions(cls, connector, *session_data):
|
async def end_sessions(cls, connector, *session_data):
|
||||||
query = sql.SQL("""
|
query = sql.SQL("""
|
||||||
WITH
|
WITH
|
||||||
@@ -92,7 +93,7 @@ class TextTrackerData(Registry):
|
|||||||
) SELECT
|
) SELECT
|
||||||
data._guildid, 0,
|
data._guildid, 0,
|
||||||
NULL, data._userid,
|
NULL, data._userid,
|
||||||
SUM(_coins), 0, 'TEXT_SESSION'
|
LEAST(SUM(_coins :: BIGINT), 2147483647), 0, 'TEXT_SESSION'
|
||||||
FROM data
|
FROM data
|
||||||
WHERE data._coins > 0
|
WHERE data._coins > 0
|
||||||
GROUP BY (data._guildid, data._userid)
|
GROUP BY (data._guildid, data._userid)
|
||||||
@@ -100,7 +101,7 @@ class TextTrackerData(Registry):
|
|||||||
)
|
)
|
||||||
, member AS (
|
, member AS (
|
||||||
UPDATE members
|
UPDATE members
|
||||||
SET coins = coins + data._coins
|
SET coins = LEAST(coins :: BIGINT + data._coins :: BIGINT, 2147483647)
|
||||||
FROM data
|
FROM data
|
||||||
WHERE members.userid = data._userid AND members.guildid = data._guildid
|
WHERE members.userid = data._userid AND members.guildid = data._guildid
|
||||||
)
|
)
|
||||||
@@ -166,15 +167,16 @@ class TextTrackerData(Registry):
|
|||||||
# Or ask for a connection from the connection pool
|
# Or ask for a connection from the connection pool
|
||||||
# Transaction may take some time due to index updates
|
# Transaction may take some time due to index updates
|
||||||
# Alternatively maybe use the "do not expect response mode"
|
# Alternatively maybe use the "do not expect response mode"
|
||||||
conn = await connector.get_connection()
|
async with connector.connection() as conn:
|
||||||
async with conn.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute(
|
await cursor.execute(
|
||||||
query,
|
query,
|
||||||
tuple(chain(*session_data))
|
tuple(chain(*session_data))
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='user_messages_between')
|
||||||
async def user_messages_between(cls, userid: int, *points):
|
async def user_messages_between(cls, userid: int, *points):
|
||||||
"""
|
"""
|
||||||
Compute messages written between the given points.
|
Compute messages written between the given points.
|
||||||
@@ -203,15 +205,16 @@ class TextTrackerData(Registry):
|
|||||||
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
|
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute(
|
await cursor.execute(
|
||||||
query,
|
query,
|
||||||
tuple(chain((userid,), *blocks))
|
tuple(chain((userid,), *blocks))
|
||||||
)
|
)
|
||||||
return [r['period_m'] or 0 for r in await cursor.fetchall()]
|
return [r['period_m'] or 0 for r in await cursor.fetchall()]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='member_messages_between')
|
||||||
async def member_messages_between(cls, guildid: int, userid: int, *points):
|
async def member_messages_between(cls, guildid: int, userid: int, *points):
|
||||||
"""
|
"""
|
||||||
Compute messages written between the given points.
|
Compute messages written between the given points.
|
||||||
@@ -241,15 +244,16 @@ class TextTrackerData(Registry):
|
|||||||
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
|
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute(
|
await cursor.execute(
|
||||||
query,
|
query,
|
||||||
tuple(chain((userid, guildid), *blocks))
|
tuple(chain((userid, guildid), *blocks))
|
||||||
)
|
)
|
||||||
return [r['period_m'] or 0 for r in await cursor.fetchall()]
|
return [r['period_m'] or 0 for r in await cursor.fetchall()]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='member_messages_since')
|
||||||
async def member_messages_since(cls, guildid: int, userid: int, *points):
|
async def member_messages_since(cls, guildid: int, userid: int, *points):
|
||||||
"""
|
"""
|
||||||
Compute messages written between the given points.
|
Compute messages written between the given points.
|
||||||
@@ -277,12 +281,12 @@ class TextTrackerData(Registry):
|
|||||||
sql.SQL("({})").format(sql.Placeholder()) for _ in points
|
sql.SQL("({})").format(sql.Placeholder()) for _ in points
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute(
|
await cursor.execute(
|
||||||
query,
|
query,
|
||||||
tuple(chain((userid, guildid), points))
|
tuple(chain((userid, guildid), points))
|
||||||
)
|
)
|
||||||
return [r['messages'] or 0 for r in await cursor.fetchall()]
|
return [r['messages'] or 0 for r in await cursor.fetchall()]
|
||||||
|
|
||||||
untracked_channels = Table('untracked_text_channels')
|
untracked_channels = Table('untracked_text_channels')
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import datetime as dt
|
|||||||
from itertools import chain
|
from itertools import chain
|
||||||
from psycopg import sql
|
from psycopg import sql
|
||||||
|
|
||||||
|
from meta.logger import log_wrap
|
||||||
from data import RowModel, Registry, Table
|
from data import RowModel, Registry, Table
|
||||||
from data.columns import Integer, String, Timestamp, Bool
|
from data.columns import Integer, String, Timestamp, Bool
|
||||||
|
|
||||||
@@ -113,16 +114,18 @@ class VoiceTrackerData(Registry):
|
|||||||
hourly_coins = Integer()
|
hourly_coins = Integer()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='close_voice_session')
|
||||||
async def close_study_session_at(cls, guildid: int, userid: int, _at: dt.datetime) -> int:
|
async def close_study_session_at(cls, guildid: int, userid: int, _at: dt.datetime) -> int:
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute(
|
await cursor.execute(
|
||||||
"SELECT close_study_session_at(%s, %s, %s)",
|
"SELECT close_study_session_at(%s, %s, %s)",
|
||||||
(guildid, userid, _at)
|
(guildid, userid, _at)
|
||||||
)
|
)
|
||||||
member_data = await cursor.fetchone()
|
return await cursor.fetchone()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='close_voice_sessions')
|
||||||
async def close_voice_sessions_at(cls, *arg_tuples):
|
async def close_voice_sessions_at(cls, *arg_tuples):
|
||||||
query = sql.SQL("""
|
query = sql.SQL("""
|
||||||
SELECT
|
SELECT
|
||||||
@@ -139,28 +142,30 @@ class VoiceTrackerData(Registry):
|
|||||||
for _ in arg_tuples
|
for _ in arg_tuples
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute(
|
await cursor.execute(
|
||||||
query,
|
query,
|
||||||
tuple(chain(*arg_tuples))
|
tuple(chain(*arg_tuples))
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='update_voice_session')
|
||||||
async def update_voice_session_at(
|
async def update_voice_session_at(
|
||||||
cls, guildid: int, userid: int, _at: dt.datetime,
|
cls, guildid: int, userid: int, _at: dt.datetime,
|
||||||
stream: bool, video: bool, rate: float
|
stream: bool, video: bool, rate: float
|
||||||
) -> int:
|
) -> int:
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute(
|
await cursor.execute(
|
||||||
"SELECT * FROM update_voice_session(%s, %s, %s, %s, %s, %s)",
|
"SELECT * FROM update_voice_session(%s, %s, %s, %s, %s, %s)",
|
||||||
(guildid, userid, _at, stream, video, rate)
|
(guildid, userid, _at, stream, video, rate)
|
||||||
)
|
)
|
||||||
rows = await cursor.fetchall()
|
rows = await cursor.fetchall()
|
||||||
return cls._make_rows(*rows)
|
return cls._make_rows(*rows)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='update_voice_sessions')
|
||||||
async def update_voice_sessions_at(cls, *arg_tuples):
|
async def update_voice_sessions_at(cls, *arg_tuples):
|
||||||
query = sql.SQL("""
|
query = sql.SQL("""
|
||||||
UPDATE
|
UPDATE
|
||||||
@@ -209,14 +214,14 @@ class VoiceTrackerData(Registry):
|
|||||||
for _ in arg_tuples
|
for _ in arg_tuples
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute(
|
await cursor.execute(
|
||||||
query,
|
query,
|
||||||
tuple(chain(*arg_tuples))
|
tuple(chain(*arg_tuples))
|
||||||
)
|
)
|
||||||
rows = await cursor.fetchall()
|
rows = await cursor.fetchall()
|
||||||
return cls._make_rows(*rows)
|
return cls._make_rows(*rows)
|
||||||
|
|
||||||
class VoiceSessions(RowModel):
|
class VoiceSessions(RowModel):
|
||||||
"""
|
"""
|
||||||
@@ -257,17 +262,19 @@ class VoiceTrackerData(Registry):
|
|||||||
transactionid = Integer()
|
transactionid = Integer()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='study_time_since')
|
||||||
async def study_time_since(cls, guildid: int, userid: int, _start) -> int:
|
async def study_time_since(cls, guildid: int, userid: int, _start) -> int:
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute(
|
await cursor.execute(
|
||||||
"SELECT study_time_since(%s, %s, %s) AS result",
|
"SELECT study_time_since(%s, %s, %s) AS result",
|
||||||
(guildid, userid, _start)
|
(guildid, userid, _start)
|
||||||
)
|
)
|
||||||
result = await cursor.fetchone()
|
result = await cursor.fetchone()
|
||||||
return (result['result'] or 0) if result else 0
|
return (result['result'] or 0) if result else 0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@log_wrap(action='multiple_voice_tracked_since')
|
||||||
async def multiple_voice_tracked_since(cls, *arg_tuples):
|
async def multiple_voice_tracked_since(cls, *arg_tuples):
|
||||||
query = sql.SQL("""
|
query = sql.SQL("""
|
||||||
SELECT
|
SELECT
|
||||||
@@ -286,13 +293,13 @@ class VoiceTrackerData(Registry):
|
|||||||
for _ in arg_tuples
|
for _ in arg_tuples
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
conn = await cls._connector.get_connection()
|
async with cls._connector.connection() as conn:
|
||||||
async with conn.cursor() as cursor:
|
async with conn.cursor() as cursor:
|
||||||
await cursor.execute(
|
await cursor.execute(
|
||||||
query,
|
query,
|
||||||
tuple(chain(*arg_tuples))
|
tuple(chain(*arg_tuples))
|
||||||
)
|
)
|
||||||
return await cursor.fetchall()
|
return await cursor.fetchall()
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Schema
|
Schema
|
||||||
|
|||||||
@@ -161,7 +161,6 @@ class VoiceSession:
|
|||||||
self.state.channelid, guildid=self.guildid, deleted=False
|
self.state.channelid, guildid=self.guildid, deleted=False
|
||||||
)
|
)
|
||||||
|
|
||||||
conn = await self.bot.db.get_connection()
|
|
||||||
# Insert an ongoing_session with the correct state, set data
|
# Insert an ongoing_session with the correct state, set data
|
||||||
state = self.state
|
state = self.state
|
||||||
self.data = await self.registry.VoiceSessionsOngoing.create(
|
self.data = await self.registry.VoiceSessionsOngoing.create(
|
||||||
|
|||||||
Reference in New Issue
Block a user