From df9b835cd5664aef01a0d712705872de068a4166 Mon Sep 17 00:00:00 2001 From: Conatum Date: Wed, 23 Aug 2023 17:31:38 +0300 Subject: [PATCH] fix (data): Parallel connection pool. --- src/analytics/events.py | 4 +- src/analytics/server.py | 2 +- src/bot.py | 2 +- src/core/cog.py | 18 +- src/core/data.py | 57 ++--- src/data/adapted.py | 18 +- src/data/connector.py | 117 ++++++++-- src/data/database.py | 19 +- src/data/queries.py | 10 +- src/modules/economy/cog.py | 96 ++++---- src/modules/economy/data.py | 181 +++++++------- src/modules/member_admin/cog.py | 23 +- src/modules/moderation/cog.py | 2 + src/modules/rolemenus/rolemenu.py | 210 ++++++++--------- src/modules/rooms/cog.py | 227 +++++++++--------- src/modules/rooms/roomui.py | 64 ++--- src/modules/schedule/cog.py | 70 +++--- src/modules/schedule/core/timeslot.py | 190 ++++++++------- src/modules/shop/shops/colours.py | 259 +++++++++++---------- src/modules/statistics/data.py | 225 +++++++++--------- src/modules/statistics/ui/weeklymonthly.py | 16 +- src/modules/tasklist/cog.py | 106 ++++----- src/modules/tasklist/ui.py | 41 ++-- src/settings/data.py | 81 ++++--- src/tracking/text/data.py | 62 ++--- src/tracking/voice/data.py | 95 ++++---- src/tracking/voice/session.py | 1 - 27 files changed, 1175 insertions(+), 1021 deletions(-) diff --git a/src/analytics/events.py b/src/analytics/events.py index fcae95b2..b672613a 100644 --- a/src/analytics/events.py +++ b/src/analytics/events.py @@ -52,7 +52,7 @@ class EventHandler(Generic[T]): f"Queue on event handler {self.route_name} is full! Discarding event {data}" ) - @log_wrap(action='consumer', isolate=False) + @log_wrap(action='consumer') async def consumer(self): while True: try: @@ -76,7 +76,7 @@ class EventHandler(Generic[T]): ) pass - @log_wrap(action='batch', isolate=False) + @log_wrap(action='batch') async def process_batch(self): logger.debug("Processing Batch") # TODO: copy syntax might be more efficient here diff --git a/src/analytics/server.py b/src/analytics/server.py index 887d10c6..0fb6cab8 100644 --- a/src/analytics/server.py +++ b/src/analytics/server.py @@ -123,7 +123,7 @@ class AnalyticsServer: log_action_stack.set(['Analytics']) log_app.set(conf.analytics['appname']) - async with await self.db.connect(): + async with self.db.open(): await self.talk.connect() await self.attach_event_handlers() self._snap_task = asyncio.create_task(self.snapshot_loop()) diff --git a/src/bot.py b/src/bot.py index 2761af82..a12ce61e 100644 --- a/src/bot.py +++ b/src/bot.py @@ -38,7 +38,7 @@ async def main(): intents.message_content = True intents.presences = False - async with await db.connect(): + async with db.open(): version = await db.version() if version.version != DATA_VERSION: error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate." diff --git a/src/core/cog.py b/src/core/cog.py index 617d8bf8..3f3cc6c1 100644 --- a/src/core/cog.py +++ b/src/core/cog.py @@ -57,16 +57,14 @@ class CoreCog(LionCog): async def cog_load(self): # Fetch (and possibly create) core data rows. - conn = await self.bot.db.get_connection() - async with conn.transaction(): - self.app_config = await self.data.AppConfig.fetch_or_create(appname) - self.bot_config = await self.data.BotConfig.fetch_or_create(appname) - self.shard_data = await self.data.Shard.fetch_or_create( - shardname, - appname=appname, - shard_id=self.bot.shard_id, - shard_count=self.bot.shard_count - ) + self.app_config = await self.data.AppConfig.fetch_or_create(appname) + self.bot_config = await self.data.BotConfig.fetch_or_create(appname) + self.shard_data = await self.data.Shard.fetch_or_create( + shardname, + appname=appname, + shard_id=self.bot.shard_id, + shard_count=self.bot.shard_count + ) self.bot.add_listener(self.shard_update_guilds, name='on_guild_join') self.bot.add_listener(self.shard_update_guilds, name='on_guild_remove') diff --git a/src/core/data.py b/src/core/data.py index 4de8f455..5acf46de 100644 --- a/src/core/data.py +++ b/src/core/data.py @@ -5,6 +5,7 @@ from cachetools import TTLCache import discord from meta import conf +from meta.logger import log_wrap from data import Table, Registry, Column, RowModel, RegisterEnum from data.models import WeakCache from data.columns import Integer, String, Bool, Timestamp @@ -287,6 +288,7 @@ class CoreData(Registry, name="core"): _timestamp = Timestamp() @classmethod + @log_wrap(action="Add Pending Coins") async def add_pending(cls, pending: list[tuple[int, int, int]]) -> list['CoreData.Member']: """ Safely add pending coins to a list of members. @@ -316,39 +318,40 @@ class CoreData(Registry, name="core"): ) ) # TODO: Replace with copy syntax/query? - conn = await cls.table.connector.get_connection() - async with conn.cursor() as cursor: - await cursor.execute( - query, - tuple(chain(*pending)) - ) - rows = await cursor.fetchall() - return cls._make_rows(*rows) + async with cls.table.connector.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute( + query, + tuple(chain(*pending)) + ) + rows = await cursor.fetchall() + return cls._make_rows(*rows) @classmethod + @log_wrap(action='get_member_rank') async def get_member_rank(cls, guildid, userid, untracked): """ Get the time and coin ranking for the given member, ignoring the provided untracked members. """ - conn = await cls.table.connector.get_connection() - async with conn.cursor() as curs: - await curs.execute( - """ - SELECT - time_rank, coin_rank - FROM ( - SELECT - userid, - row_number() OVER (ORDER BY total_tracked_time DESC, userid ASC) AS time_rank, - row_number() OVER (ORDER BY total_coins DESC, userid ASC) AS coin_rank - FROM members_totals - WHERE - guildid=%s AND userid NOT IN %s - ) AS guild_ranks WHERE userid=%s - """, - (guildid, tuple(untracked), userid) - ) - return (await curs.fetchone()) or (None, None) + async with cls.table.connector.connection() as conn: + async with conn.cursor() as curs: + await curs.execute( + """ + SELECT + time_rank, coin_rank + FROM ( + SELECT + userid, + row_number() OVER (ORDER BY total_tracked_time DESC, userid ASC) AS time_rank, + row_number() OVER (ORDER BY total_coins DESC, userid ASC) AS coin_rank + FROM members_totals + WHERE + guildid=%s AND userid NOT IN %s + ) AS guild_ranks WHERE userid=%s + """, + (guildid, tuple(untracked), userid) + ) + return (await curs.fetchone()) or (None, None) class LionHook(RowModel): """ diff --git a/src/data/adapted.py b/src/data/adapted.py index a4344f19..a6b4597a 100644 --- a/src/data/adapted.py +++ b/src/data/adapted.py @@ -1,6 +1,7 @@ # from enum import Enum from typing import Optional from psycopg.types.enum import register_enum, EnumInfo +from psycopg import AsyncConnection from .registry import Attachable, Registry @@ -23,10 +24,17 @@ class RegisterEnum(Attachable): connector = registry._conn if connector is None: raise ValueError("Cannot initialise without connector!") - connection = await connector.get_connection() - if connection is None: - raise ValueError("Cannot Init without connection.") - info = await EnumInfo.fetch(connection, self.name) + connector.connect_hook(self.connection_hook) + # await connector.refresh_pool() + # The below may be somewhat dangerous + # But adaption should never write to the database + await connector.map_over_pool(self.connection_hook) + # if conn := connector.conn: + # # Ensure the adaption is run in the current context as well + # await self.connection_hook(conn) + + async def connection_hook(self, conn: AsyncConnection): + info = await EnumInfo.fetch(conn, self.name) if info is None: raise ValueError(f"Enum {self.name} not found in database.") - register_enum(info, connection, self.enum, mapping=list(self.mapping.items())) + register_enum(info, conn, self.enum, mapping=list(self.mapping.items())) diff --git a/src/data/connector.py b/src/data/connector.py index 7b7b3a5f..7b25aed3 100644 --- a/src/data/connector.py +++ b/src/data/connector.py @@ -1,7 +1,10 @@ -from typing import Protocol, runtime_checkable, Callable, Awaitable +from typing import Protocol, runtime_checkable, Callable, Awaitable, Optional import logging +from contextvars import ContextVar +from contextlib import asynccontextmanager import psycopg as psq +from psycopg_pool import AsyncConnectionPool from psycopg.pq import TransactionStatus from .cursor import AsyncLoggingCursor @@ -10,42 +13,110 @@ logger = logging.getLogger(__name__) row_factory = psq.rows.dict_row +ctx_connection: Optional[ContextVar[psq.AsyncConnection]] = ContextVar('connection', default=None) + class Connector: cursor_factory = AsyncLoggingCursor def __init__(self, conn_args): self._conn_args = conn_args - self.conn: psq.AsyncConnection = None + self._conn_kwargs = dict(autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory) + + self.pool = self.make_pool() self.conn_hooks = [] - async def get_connection(self) -> psq.AsyncConnection: + @property + def conn(self) -> Optional[psq.AsyncConnection]: """ - Get the current active connection. - This should never be cached outside of a transaction. + Convenience property for the current context connection. """ - # TODO: Reconnection logic? - if not self.conn: - raise ValueError("Attempting to get connection before initialisation!") - if self.conn.info.transaction_status is TransactionStatus.INERROR: - await self.connect() - logger.error( - "Database connection transaction failed!! This should not happen. Reconnecting." - ) - return self.conn + return ctx_connection.get() - async def connect(self) -> psq.AsyncConnection: - logger.info("Establishing connection to database.", extra={'action': "Data Connect"}) - self.conn = await psq.AsyncConnection.connect( - self._conn_args, autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory + @conn.setter + def conn(self, conn: psq.AsyncConnection): + """ + Set the contextual connection in the current context. + Always do this in an isolated context! + """ + ctx_connection.set(conn) + + def make_pool(self) -> AsyncConnectionPool: + logger.info("Initialising connection pool.", extra={'action': "Pool Init"}) + return AsyncConnectionPool( + self._conn_args, + open=False, + min_size=4, + max_size=8, + configure=self._setup_connection, + kwargs=self._conn_kwargs ) - for hook in self.conn_hooks: - await hook(self.conn) - return self.conn - async def reconnect(self) -> psq.AsyncConnection: - return await self.connect() + async def refresh_pool(self): + """ + Refresh the pool. + + The point of this is to invalidate any existing connections so that the connection set up is run again. + Better ways should be sought (a way to + """ + logger.info("Pool refresh requested, closing and reopening.") + old_pool = self.pool + self.pool = self.make_pool() + await self.pool.open() + logger.info(f"Old pool statistics: {self.pool.get_stats()}") + await old_pool.close() + logger.info("Pool refresh complete.") + + async def map_over_pool(self, callable): + """ + Dangerous method to call a method on each connection in the pool. + + Utilises private methods of the AsyncConnectionPool. + """ + async with self.pool._lock: + conns = list(self.pool._pool) + while conns: + conn = conns.pop() + try: + await callable(conn) + except Exception: + logger.exception(f"Mapped connection task failed. {callable.__name__}") + + @asynccontextmanager + async def open(self): + try: + logger.info("Opening database pool.") + await self.pool.open() + yield + finally: + # May be a different pool! + logger.info(f"Closing database pool. Pool statistics: {self.pool.get_stats()}") + await self.pool.close() + + @asynccontextmanager + async def connection(self) -> psq.AsyncConnection: + """ + Asynchronous context manager to get and manage a connection. + + If the context connection is set, uses this and does not manage the lifetime. + Otherwise, requests a new connection from the pool and returns it when done. + """ + logger.debug("Database connection requested.", extra={'action': "Data Connect"}) + if (conn := self.conn): + yield conn + else: + async with self.pool.connection() as conn: + yield conn + + async def _setup_connection(self, conn: psq.AsyncConnection): + logger.debug("Initialising new connection.", extra={'action': "Conn Init"}) + for hook in self.conn_hooks: + try: + await hook(conn) + except Exception: + logger.exception("Exception encountered setting up new connection") + return conn def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]): """ diff --git a/src/data/database.py b/src/data/database.py index 039a0b0e..255e4129 100644 --- a/src/data/database.py +++ b/src/data/database.py @@ -35,12 +35,13 @@ class Database(Connector): """ Return the current schema version as a Version namedtuple. """ - async with self.conn.cursor() as cursor: - # Get last entry in version table, compare against desired version - await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1") - row = await cursor.fetchone() - if row: - return Version(row['version'], row['time'], row['author']) - else: - # No versions in the database - return Version(-1, None, None) + async with self.connection() as conn: + async with conn.cursor() as cursor: + # Get last entry in version table, compare against desired version + await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1") + row = await cursor.fetchone() + if row: + return Version(row['version'], row['time'], row['author']) + else: + # No versions in the database + return Version(-1, None, None) diff --git a/src/data/queries.py b/src/data/queries.py index a64d0950..02329281 100644 --- a/src/data/queries.py +++ b/src/data/queries.py @@ -101,12 +101,12 @@ class Query(Generic[QueryResult]): if self.connector is None: raise ValueError("Cannot execute query without cursor, connection, or connector.") else: - conn = await self.connector.get_connection() + async with self.connector.connection() as conn: + async with conn.cursor() as cursor: + data = await self._execute(cursor) else: - conn = self.conn - - async with conn.cursor() as cursor: - data = await self._execute(cursor) + async with self.conn.cursor() as cursor: + data = await self._execute(cursor) else: data = await self._execute(cursor) return data diff --git a/src/modules/economy/cog.py b/src/modules/economy/cog.py index 99a8a907..e428e5b0 100644 --- a/src/modules/economy/cog.py +++ b/src/modules/economy/cog.py @@ -1,4 +1,5 @@ from typing import Optional, Union +import asyncio import discord from discord.ext import commands as cmds @@ -182,30 +183,34 @@ class Economy(LionCog): # We may need to do a mass row create operation. targetids = set(target.id for target in targets) if len(targets) > 1: - conn = await ctx.bot.db.get_connection() - async with conn.transaction(): - # First fetch the members which currently exist - query = self.bot.core.data.Member.table.select_where(guildid=ctx.guild.id) - query.select('userid').with_no_adapter() - if 2 * len(targets) < len(ctx.guild.members): - # More efficient to fetch the targets explicitly - query.where(userid=list(targetids)) - existent_rows = await query - existentids = set(r['userid'] for r in existent_rows) + async def wrapper(): + async with self.bot.db.connection() as conn: + self.bot.db.conn = conn + async with conn.transaction(): + # First fetch the members which currently exist + query = self.bot.core.data.Member.table.select_where(guildid=ctx.guild.id) + query.select('userid').with_no_adapter() + if 2 * len(targets) < len(ctx.guild.members): + # More efficient to fetch the targets explicitly + query.where(userid=list(targetids)) + existent_rows = await query + existentids = set(r['userid'] for r in existent_rows) - # Then check if any new userids need adding, and if so create them - new_ids = targetids.difference(existentids) - if new_ids: - # We use ON CONFLICT IGNORE here in case the users already exist. - await self.bot.core.data.User.table.insert_many( - ('userid',), - *((id,) for id in new_ids) - ).on_conflict(ignore=True) - # TODO: Replace 0 here with the starting_coin value - await self.bot.core.data.Member.table.insert_many( - ('guildid', 'userid', 'coins'), - *((ctx.guild.id, id, 0) for id in new_ids) - ).on_conflict(ignore=True) + # Then check if any new userids need adding, and if so create them + new_ids = targetids.difference(existentids) + if new_ids: + # We use ON CONFLICT IGNORE here in case the users already exist. + await self.bot.core.data.User.table.insert_many( + ('userid',), + *((id,) for id in new_ids) + ).on_conflict(ignore=True) + # TODO: Replace 0 here with the starting_coin value + await self.bot.core.data.Member.table.insert_many( + ('guildid', 'userid', 'coins'), + *((ctx.guild.id, id, 0) for id in new_ids) + ).on_conflict(ignore=True) + task = asyncio.create_task(wrapper(), name="wrapped-create-members") + await task else: # With only one target, we can take a simpler path, and make better use of local caches. await self.bot.core.lions.fetch_member(ctx.guild.id, target.id) @@ -703,31 +708,34 @@ class Economy(LionCog): # Alternative flow could be waiting until the target user presses accept await ctx.interaction.response.defer(thinking=True, ephemeral=True) - conn = await self.bot.db.get_connection() - async with conn.transaction(): - # We do this in a transaction so that if something goes wrong, - # the coins deduction is rolled back atomicly - balance = ctx.alion.data.coins - if amount > balance: - await ctx.interaction.edit_original_response( - embed=error_embed( - t(_p( - 'cmd:send|error:insufficient', - "You do not have enough lioncoins to do this!\n" - "`Current Balance:` {coin_emoji}{balance}" - )).format( - coin_emoji=self.bot.config.emojis.getemoji('coin'), - balance=balance + async def wrapped(): + async with self.bot.db.connection() as conn: + self.bot.db.conn = conn + async with conn.transaction(): + # We do this in a transaction so that if something goes wrong, + # the coins deduction is rolled back atomicly + balance = ctx.alion.data.coins + if amount > balance: + await ctx.interaction.edit_original_response( + embed=error_embed( + t(_p( + 'cmd:send|error:insufficient', + "You do not have enough lioncoins to do this!\n" + "`Current Balance:` {coin_emoji}{balance}" + )).format( + coin_emoji=self.bot.config.emojis.getemoji('coin'), + balance=balance + ) + ), ) - ), - ) - return + return - # Transfer the coins - await ctx.alion.data.update(coins=(Member.coins - amount)) - await target_lion.data.update(coins=(Member.coins + amount)) + # Transfer the coins + await ctx.alion.data.update(coins=(Member.coins - amount)) + await target_lion.data.update(coins=(Member.coins + amount)) # TODO: Audit trail + await asyncio.create_task(wrapped(), name="wrapped-send") # Message target embed = discord.Embed( diff --git a/src/modules/economy/data.py b/src/modules/economy/data.py index 64923540..ff4c32c2 100644 --- a/src/modules/economy/data.py +++ b/src/modules/economy/data.py @@ -1,6 +1,7 @@ from enum import Enum from psycopg import sql +from meta.logger import log_wrap from data import Registry, RowModel, RegisterEnum, JOINTYPE, RawExpr from data.columns import Integer, Bool, Column, Timestamp from core.data import CoreData @@ -101,6 +102,7 @@ class EconomyData(Registry, name='economy'): created_at = Timestamp() @classmethod + @log_wrap(action='execute_transaction') async def execute_transaction( cls, transaction_type: TransactionType, @@ -108,25 +110,27 @@ class EconomyData(Registry, name='economy'): from_account: int, to_account: int, amount: int, bonus: int = 0, refunds: int = None ): - conn = await cls._connector.get_connection() - async with conn.transaction(): - transaction = await cls.create( - transactiontype=transaction_type, - guildid=guildid, actorid=actorid, amount=amount, bonus=bonus, - from_account=from_account, to_account=to_account, - refunds=refunds - ) - if from_account is not None: - await CoreData.Member.table.update_where( - guildid=guildid, userid=from_account - ).set(coins=SAFECOINS(CoreData.Member.coins - (amount + bonus))) - if to_account is not None: - await CoreData.Member.table.update_where( - guildid=guildid, userid=to_account - ).set(coins=SAFECOINS(CoreData.Member.coins + (amount + bonus))) - return transaction + async with cls._connector.connection() as conn: + cls._connector.conn = conn + async with conn.transaction(): + transaction = await cls.create( + transactiontype=transaction_type, + guildid=guildid, actorid=actorid, amount=amount, bonus=bonus, + from_account=from_account, to_account=to_account, + refunds=refunds + ) + if from_account is not None: + await CoreData.Member.table.update_where( + guildid=guildid, userid=from_account + ).set(coins=SAFECOINS(CoreData.Member.coins - (amount + bonus))) + if to_account is not None: + await CoreData.Member.table.update_where( + guildid=guildid, userid=to_account + ).set(coins=SAFECOINS(CoreData.Member.coins + (amount + bonus))) + return transaction @classmethod + @log_wrap(action='execute_transactions') async def execute_transactions(cls, *transactions): """ Execute multiple transactions in one data transaction. @@ -142,65 +146,68 @@ class EconomyData(Registry, name='economy'): if not transactions: return [] - conn = await cls._connector.get_connection() - async with conn.transaction(): - # Create the transactions - rows = await cls.table.insert_many( - ( - 'transactiontype', - 'guildid', 'actorid', - 'from_account', 'to_account', - 'amount', 'bonus', - 'refunds' - ), - *transactions - ).with_adapter(cls._make_rows) + async with cls._connector.connection() as conn: + cls._connector.conn = conn + async with conn.transaction(): + # Create the transactions + rows = await cls.table.insert_many( + ( + 'transactiontype', + 'guildid', 'actorid', + 'from_account', 'to_account', + 'amount', 'bonus', + 'refunds' + ), + *transactions + ).with_adapter(cls._make_rows) - # Update the members - transtable = TemporaryTable( - '_guildid', '_userid', '_amount', - types=('BIGINT', 'BIGINT', 'INTEGER') - ) - values = transtable.values - for transaction in transactions: - _, guildid, _, from_acc, to_acc, amount, bonus, _ = transaction - coins = amount + bonus - if coins: - if from_acc: - values.append((guildid, from_acc, -1 * coins)) - if to_acc: - values.append((guildid, to_acc, coins)) - if values: - Member = CoreData.Member - await Member.table.update_where( - guildid=transtable['_guildid'], userid=transtable['_userid'] - ).set( - coins=SAFECOINS(Member.coins + transtable['_amount']) - ).from_expr(transtable) + # Update the members + transtable = TemporaryTable( + '_guildid', '_userid', '_amount', + types=('BIGINT', 'BIGINT', 'INTEGER') + ) + values = transtable.values + for transaction in transactions: + _, guildid, _, from_acc, to_acc, amount, bonus, _ = transaction + coins = amount + bonus + if coins: + if from_acc: + values.append((guildid, from_acc, -1 * coins)) + if to_acc: + values.append((guildid, to_acc, coins)) + if values: + Member = CoreData.Member + await Member.table.update_where( + guildid=transtable['_guildid'], userid=transtable['_userid'] + ).set( + coins=SAFECOINS(Member.coins + transtable['_amount']) + ).from_expr(transtable) return rows @classmethod + @log_wrap(action='refund_transactions') async def refund_transactions(cls, *transactionids, actorid=0): if not transactionids: return [] - conn = await cls._connector.get_connection() - async with conn.transaction(): - # First fetch the transaction rows to refund - data = await cls.table.select_where(transactionid=transactionids) - if data: - # Build the transaction refund data - records = [ - ( - TransactionType.REFUND, - tr['guildid'], actorid, - tr['to_account'], tr['from_account'], - tr['amount'] + tr['bonus'], 0, - tr['transactionid'] - ) - for tr in data - ] - # Execute refund transactions - return await cls.execute_transactions(*records) + async with cls._connector.connection() as conn: + cls._connector.conn = conn + async with conn.transaction(): + # First fetch the transaction rows to refund + data = await cls.table.select_where(transactionid=transactionids) + if data: + # Build the transaction refund data + records = [ + ( + TransactionType.REFUND, + tr['guildid'], actorid, + tr['to_account'], tr['from_account'], + tr['amount'] + tr['bonus'], 0, + tr['transactionid'] + ) + for tr in data + ] + # Execute refund transactions + return await cls.execute_transactions(*records) class ShopTransaction(RowModel): """ @@ -217,19 +224,21 @@ class EconomyData(Registry, name='economy'): itemid = Integer() @classmethod + @log_wrap(action='purchase_transaction') async def purchase_transaction( cls, guildid: int, actorid: int, userid: int, itemid: int, amount: int ): - conn = await cls._connector.get_connection() - async with conn.transaction(): - row = await EconomyData.Transaction.execute_transaction( - TransactionType.SHOP_PURCHASE, - guildid=guildid, actorid=actorid, from_account=userid, to_account=None, - amount=amount - ) - return await cls.create(transactionid=row.transactionid, itemid=itemid) + async with cls._connector.connection() as conn: + cls._connector.conn = conn + async with conn.transaction(): + row = await EconomyData.Transaction.execute_transaction( + TransactionType.SHOP_PURCHASE, + guildid=guildid, actorid=actorid, from_account=userid, to_account=None, + amount=amount + ) + return await cls.create(transactionid=row.transactionid, itemid=itemid) class TaskTransaction(RowModel): """ @@ -263,19 +272,21 @@ class EconomyData(Registry, name='economy'): return result[0]['recent'] or 0 @classmethod + @log_wrap(action='reward_completed_tasks') async def reward_completed(cls, userid, guildid, count, amount): """ Reward the specified member `amount` coins for completing `count` tasks. """ # TODO: Bonus logic, perhaps apply_bonus(amount), or put this method in the economy cog? - conn = await cls._connector.get_connection() - async with conn.transaction(): - row = await EconomyData.Transaction.execute_transaction( - TransactionType.TASKS, - guildid=guildid, actorid=userid, from_account=None, to_account=userid, - amount=amount - ) - return await cls.create(transactionid=row.transactionid, count=count) + async with cls._connector.connection() as conn: + cls._connector.conn = conn + async with conn.transaction(): + row = await EconomyData.Transaction.execute_transaction( + TransactionType.TASKS, + guildid=guildid, actorid=userid, from_account=None, to_account=userid, + amount=amount + ) + return await cls.create(transactionid=row.transactionid, count=count) class SessionTransaction(RowModel): """ diff --git a/src/modules/member_admin/cog.py b/src/modules/member_admin/cog.py index fbc0ebd0..6d57f222 100644 --- a/src/modules/member_admin/cog.py +++ b/src/modules/member_admin/cog.py @@ -193,18 +193,19 @@ class MemberAdminCog(LionCog): await lion.data.update(last_left=utc_now()) # Save member roles - conn = await self.bot.db.get_connection() - async with conn.transaction(): - await self.data.past_roles.delete_where( - guildid=member.guild.id, - userid=member.id - ) - # Insert current member roles - if member.roles: - await self.data.past_roles.insert_many( - ('guildid', 'userid', 'roleid'), - *((member.guild.id, member.id, role.id) for role in member.roles) + async with self.bot.db.connection() as conn: + self.bot.db.conn = conn + async with conn.transaction(): + await self.data.past_roles.delete_where( + guildid=member.guild.id, + userid=member.id ) + # Insert current member roles + if member.roles: + await self.data.past_roles.insert_many( + ('guildid', 'userid', 'roleid'), + *((member.guild.id, member.id, role.id) for role in member.roles) + ) logger.debug( f"Stored persisting roles for member in ." ) diff --git a/src/modules/moderation/cog.py b/src/modules/moderation/cog.py index 05ae6846..905766ec 100644 --- a/src/modules/moderation/cog.py +++ b/src/modules/moderation/cog.py @@ -190,6 +190,8 @@ class ModerationCog(LionCog): update_args[instance._column] = instance.data ack_lines.append(instance.update_message) + await ctx.lguild.data.update(**update_args) + # Do the ack tick = self.bot.config.emojis.tick embed = discord.Embed( diff --git a/src/modules/rolemenus/rolemenu.py b/src/modules/rolemenus/rolemenu.py index 2a0734cf..da25d30c 100644 --- a/src/modules/rolemenus/rolemenu.py +++ b/src/modules/rolemenus/rolemenu.py @@ -483,65 +483,63 @@ class RoleMenu: )).format(role=role.name) ) - conn = await self.bot.db.get_connection() - async with conn.transaction(): - # Remove the role - try: - await member.remove_roles(role) - except discord.Forbidden: - raise UserInputError( - t(_p( - 'rolemenu|deselect|error:perms', - "I don't have enough permissions to remove this role from you!" - )) - ) - except discord.HTTPException: - raise UserInputError( - t(_p( - 'rolemenu|deselect|error:discord', - "An unknown error occurred removing your role! Please try again later." - )) - ) - - # Update history - now = utc_now() - history = await self.cog.data.RoleMenuHistory.table.update_where( - menuid=self.data.menuid, - roleid=role.id, - userid=member.id, - removed_at=None, - ).set(removed_at=now) - await self.cog.cancel_expiring_tasks(*(row['equipid'] for row in history)) - - # Refund if required - transactionids = [row['transactionid'] for row in history] - if self.config.refunds.value and any(transactionids): - transactionids = [tid for tid in transactionids if tid] - economy: Economy = self.bot.get_cog('Economy') - refunded = await economy.data.Transaction.refund_transactions(*transactionids) - total_refund = sum(row.amount + row.bonus for row in refunded) - else: - total_refund = 0 - - # Ack the removal - embed = discord.Embed( - colour=discord.Colour.brand_green(), - title=t(_p( - 'rolemenu|deslect|success|title', - "Role removed" + # Remove the role + try: + await member.remove_roles(role) + except discord.Forbidden: + raise UserInputError( + t(_p( + 'rolemenu|deselect|error:perms', + "I don't have enough permissions to remove this role from you!" )) ) - if total_refund: - embed.description = t(_p( - 'rolemenu|deselect|success:refund|desc', - "You have removed **{role}**, and been refunded {coin} **{amount}**." - )).format(role=role.name, coin=self.bot.config.emojis.coin, amount=total_refund) - else: - embed.description = t(_p( - 'rolemenu|deselect|success:norefund|desc', - "You have unequipped **{role}**." - )).format(role=role.name) - return embed + except discord.HTTPException: + raise UserInputError( + t(_p( + 'rolemenu|deselect|error:discord', + "An unknown error occurred removing your role! Please try again later." + )) + ) + + # Update history + now = utc_now() + history = await self.cog.data.RoleMenuHistory.table.update_where( + menuid=self.data.menuid, + roleid=role.id, + userid=member.id, + removed_at=None, + ).set(removed_at=now) + await self.cog.cancel_expiring_tasks(*(row['equipid'] for row in history)) + + # Refund if required + transactionids = [row['transactionid'] for row in history] + if self.config.refunds.value and any(transactionids): + transactionids = [tid for tid in transactionids if tid] + economy: Economy = self.bot.get_cog('Economy') + refunded = await economy.data.Transaction.refund_transactions(*transactionids) + total_refund = sum(row.amount + row.bonus for row in refunded) + else: + total_refund = 0 + + # Ack the removal + embed = discord.Embed( + colour=discord.Colour.brand_green(), + title=t(_p( + 'rolemenu|deslect|success|title', + "Role removed" + )) + ) + if total_refund: + embed.description = t(_p( + 'rolemenu|deselect|success:refund|desc', + "You have removed **{role}**, and been refunded {coin} **{amount}**." + )).format(role=role.name, coin=self.bot.config.emojis.coin, amount=total_refund) + else: + embed.description = t(_p( + 'rolemenu|deselect|success:norefund|desc', + "You have unequipped **{role}**." + )).format(role=role.name) + return embed else: # Member does not have the role, selection case. required = self.config.required_role.data @@ -591,57 +589,55 @@ class RoleMenu: ) ) - conn = await self.bot.db.get_connection() - async with conn.transaction(): - try: - await member.add_roles(role) - except discord.Forbidden: - raise UserInputError( - t(_p( - 'rolemenu|select|error:perms', - "I don't have enough permissions to give you this role!" - )) - ) - except discord.HTTPException: - raise UserInputError( - t(_p( - 'rolemenu|select|error:discord', - "An unknown error occurred while assigning your role! " - "Please try again later." - )) - ) - - now = utc_now() - - # Create transaction if applicable - if price: - economy: Economy = self.bot.get_cog('Economy') - tx = await economy.data.Transaction.execute_transaction( - transaction_type=TransactionType.OTHER, - guildid=guild.id, actorid=member.id, - from_account=member.id, to_account=None, - amount=price - ) - tid = tx.transactionid - else: - tid = None - - # Calculate expiry - duration = mrole.config.duration.value - if duration is not None: - expiry = now + dt.timedelta(seconds=duration) - else: - expiry = None - - # Add to equip history - equip = await self.cog.data.RoleMenuHistory.create( - menuid=self.data.menuid, roleid=role.id, - userid=member.id, - obtained_at=now, - transactionid=tid, - expires_at=expiry + try: + await member.add_roles(role) + except discord.Forbidden: + raise UserInputError( + t(_p( + 'rolemenu|select|error:perms', + "I don't have enough permissions to give you this role!" + )) ) - await self.cog.schedule_expiring(equip) + except discord.HTTPException: + raise UserInputError( + t(_p( + 'rolemenu|select|error:discord', + "An unknown error occurred while assigning your role! " + "Please try again later." + )) + ) + + now = utc_now() + + # Create transaction if applicable + if price: + economy: Economy = self.bot.get_cog('Economy') + tx = await economy.data.Transaction.execute_transaction( + transaction_type=TransactionType.OTHER, + guildid=guild.id, actorid=member.id, + from_account=member.id, to_account=None, + amount=price + ) + tid = tx.transactionid + else: + tid = None + + # Calculate expiry + duration = mrole.config.duration.value + if duration is not None: + expiry = now + dt.timedelta(seconds=duration) + else: + expiry = None + + # Add to equip history + equip = await self.cog.data.RoleMenuHistory.create( + menuid=self.data.menuid, roleid=role.id, + userid=member.id, + obtained_at=now, + transactionid=tid, + expires_at=expiry + ) + await self.cog.schedule_expiring(equip) # Ack the selection embed = discord.Embed( diff --git a/src/modules/rooms/cog.py b/src/modules/rooms/cog.py index 29037907..0bf434d7 100644 --- a/src/modules/rooms/cog.py +++ b/src/modules/rooms/cog.py @@ -259,17 +259,6 @@ class RoomCog(LionCog): lguild, [member.id for member in members] ) - self._start(room) - - # Send tips message - # TODO: Actual tips. - await channel.send( - "{mention} welcome to your private room! You may use the menu below to configure it.".format(mention=owner.mention) - ) - - # Send config UI - ui = RoomUI(self.bot, room, callerid=owner.id, timeout=None) - await ui.send(channel) except Exception: try: await channel.delete(reason="Failed to created private room") @@ -454,71 +443,9 @@ class RoomCog(LionCog): return # Positive response. Start a transaction. - conn = await self.bot.db.get_connection() - async with conn.transaction(): - # Check member balance is sufficient - await ctx.alion.data.refresh() - member_balance = ctx.alion.data.coins - if member_balance < required: - await ctx.reply( - embed=error_embed( - t(_np( - 'cmd:room_rent|error:insufficient_funds', - "Renting a private room for `one` day costs {coin}**{required}**, " - "but you only have {coin}**{balance}**!", - "Renting a private room for `{days}` days costs {coin}**{required}**, " - "but you only have {coin}**{balance}**!", - days - )).format( - coin=self.bot.config.emojis.coin, - balance=member_balance, - required=required, - days=days - ), - ephemeral=True - ) - ) - return - - # Deduct balance - # TODO: Economy transaction instead of manual deduction - await ctx.alion.data.update(coins=CoreData.Member.coins - required) - - # Create room with given starting balance and other parameters - try: - room = await self.create_private_room( - ctx.guild, - ctx.author, - required - rent, - name or ctx.author.display_name, - members=provided - ) - except discord.Forbidden: - await ctx.reply( - embed=error_embed( - t(_p( - 'cmd:room_rent|error:my_permissions', - "Could not create your private room! You were not charged.\n" - "I have insufficient permissions to create a private room channel." - )), - ) - ) - await ctx.alion.data.update(coins=CoreData.Member.coins + required) - return - except discord.HTTPException as e: - await ctx.reply( - embed=error_embed( - t(_p( - 'cmd:room_rent|error:unknown', - "Could not create your private room! You were not charged.\n" - "An unknown error occurred while creating your private room.\n" - "`{error}`" - )).format(error=e.text), - ) - ) - await ctx.alion.data.update(coins=CoreData.Member.coins + required) - return + room = await self._do_create_room(ctx, required, days, rent, name, provided) + if room: # Ack with confirmation message pointing to the room msg = t(_p( 'cmd:room_rent|success', @@ -531,6 +458,90 @@ class RoomCog(LionCog): description=msg ) ) + self._start(room) + + # Send tips message + # TODO: Actual tips. + await room.channel.send( + "{mention} welcome to your private room! You may use the menu below to configure it.".format( + mention=ctx.author.mention + ) + ) + + # Send config UI + ui = RoomUI(self.bot, room, callerid=ctx.author.id, timeout=None) + await ui.send(room.channel) + + @log_wrap(action='create_room') + async def _do_create_room(self, ctx, required, days, rent, name, provided) -> Room: + t = self.bot.translator.t + # TODO: Rollback the channel create if this fails + async with self.bot.db.connection() as conn: + self.bot.db.conn = conn + # Note that the room creation will go into the UI as well. + async with conn.transaction(): + # Check member balance is sufficient + await ctx.alion.data.refresh() + member_balance = ctx.alion.data.coins + if member_balance < required: + await ctx.reply( + embed=error_embed( + t(_np( + 'cmd:room_rent|error:insufficient_funds', + "Renting a private room for `one` day costs {coin}**{required}**, " + "but you only have {coin}**{balance}**!", + "Renting a private room for `{days}` days costs {coin}**{required}**, " + "but you only have {coin}**{balance}**!", + days + )).format( + coin=self.bot.config.emojis.coin, + balance=member_balance, + required=required, + days=days + ), + ephemeral=True + ) + ) + return + + # Deduct balance + # TODO: Economy transaction instead of manual deduction + await ctx.alion.data.update(coins=CoreData.Member.coins - required) + + # Create room with given starting balance and other parameters + try: + return await self.create_private_room( + ctx.guild, + ctx.author, + required - rent, + name or ctx.author.display_name, + members=provided + ) + except discord.Forbidden: + await ctx.reply( + embed=error_embed( + t(_p( + 'cmd:room_rent|error:my_permissions', + "Could not create your private room! You were not charged.\n" + "I have insufficient permissions to create a private room channel." + )), + ) + ) + await ctx.alion.data.update(coins=CoreData.Member.coins + required) + return + except discord.HTTPException as e: + await ctx.reply( + embed=error_embed( + t(_p( + 'cmd:room_rent|error:unknown', + "Could not create your private room! You were not charged.\n" + "An unknown error occurred while creating your private room.\n" + "`{error}`" + )).format(error=e.text), + ) + ) + await ctx.alion.data.update(coins=CoreData.Member.coins + required) + return @room_group.command( name=_p('cmd:room_status', "status"), @@ -864,43 +875,41 @@ class RoomCog(LionCog): return # Start Transaction - conn = await self.bot.db.get_connection() - async with conn.transaction(): - await ctx.alion.data.refresh() - member_balance = ctx.alion.data.coins - if member_balance < coins: - await ctx.reply( - embed=error_embed(t(_p( - 'cmd:room_deposit|error:insufficient_funds', - "You cannot deposit {coin}**{amount}**! You only have {coin}**{balance}**." - )).format( - coin=self.bot.config.emojis.coin, - amount=coins, - balance=member_balance - )), - ephemeral=True - ) - return + # TODO: Economy transaction + await ctx.alion.data.refresh() + member_balance = ctx.alion.data.coins + if member_balance < coins: + await ctx.reply( + embed=error_embed(t(_p( + 'cmd:room_deposit|error:insufficient_funds', + "You cannot deposit {coin}**{amount}**! You only have {coin}**{balance}**." + )).format( + coin=self.bot.config.emojis.coin, + amount=coins, + balance=member_balance + )), + ephemeral=True + ) + return - # Deduct balance - # TODO: Economy transaction - await ctx.alion.data.update(coins=CoreData.Member.coins - coins) - await room.data.update(coin_balance=RoomData.Room.coin_balance + coins) + # Deduct balance + await ctx.alion.data.update(coins=CoreData.Member.coins - coins) + await room.data.update(coin_balance=RoomData.Room.coin_balance + coins) - # Post deposit message - await room.notify_deposit(ctx.author, coins) + # Post deposit message + await room.notify_deposit(ctx.author, coins) - # Ack the deposit - if ctx.channel.id != room.data.channelid: - ack_msg = t(_p( - 'cmd:room_depost|success', - "Success! You have contributed {coin}**{amount}** to the private room bank." - )).format(coin=self.bot.config.emojis.coin, amount=coins) - await ctx.reply( - embed=discord.Embed(colour=discord.Colour.brand_green(), description=ack_msg) - ) - else: - await ctx.interaction.delete_original_response() + # Ack the deposit + if ctx.channel.id != room.data.channelid: + ack_msg = t(_p( + 'cmd:room_depost|success', + "Success! You have contributed {coin}**{amount}** to the private room bank." + )).format(coin=self.bot.config.emojis.coin, amount=coins) + await ctx.reply( + embed=discord.Embed(colour=discord.Colour.brand_green(), description=ack_msg) + ) + else: + await ctx.interaction.delete_original_response() # ----- Guild Configuration ----- @LionCog.placeholder_group diff --git a/src/modules/rooms/roomui.py b/src/modules/rooms/roomui.py index 8cfc83f0..6d6d768c 100644 --- a/src/modules/rooms/roomui.py +++ b/src/modules/rooms/roomui.py @@ -7,6 +7,7 @@ from discord.ui.select import select, UserSelect from meta import LionBot, conf from meta.errors import UserInputError +from meta.logger import log_wrap from babel.translator import ctx_locale from utils.lib import utc_now, MessageArgs, error_embed from utils.ui import MessageUI, input @@ -115,38 +116,43 @@ class RoomUI(MessageUI): return await submit.response.defer(thinking=True, ephemeral=True) + await self._do_deposit(t, press, amount, submit) + + # Post deposit message + await self.room.notify_deposit(press.user, amount) + + await self.refresh(thinking=submit) + + @log_wrap(isolate=True) + async def _do_deposit(self, t, press, amount, submit): # Start transaction for deposit - conn = await self.bot.db.get_connection() - async with conn.transaction(): - # Get the lion balance directly - lion = await self.bot.core.data.Member.fetch( - self.room.data.guildid, - press.user.id, - cached=False - ) - balance = lion.coins - if balance < amount: - await submit.edit_original_response( - embed=error_embed( - t(_p( - 'ui:room_status|button:deposit|error:insufficient_funds', - "You cannot deposit {coin}**{amount}**! You only have {coin}**{balance}**." - )).format( - coin=self.bot.config.emojis.coin, - amount=amount, - balance=balance + async with self.bot.db.connection() as conn: + self.bot.db.conn = conn + async with conn.transaction(): + # Get the lion balance directly + lion = await self.bot.core.data.Member.fetch( + self.room.data.guildid, + press.user.id, + cached=False + ) + balance = lion.coins + if balance < amount: + await submit.edit_original_response( + embed=error_embed( + t(_p( + 'ui:room_status|button:deposit|error:insufficient_funds', + "You cannot deposit {coin}**{amount}**! You only have {coin}**{balance}**." + )).format( + coin=self.bot.config.emojis.coin, + amount=amount, + balance=balance + ) ) ) - ) - return - # TODO: Economy Transaction - await lion.update(coins=CoreData.Member.coins - amount) - await self.room.data.update(coin_balance=RoomData.Room.coin_balance + amount) - - # Post deposit message - await self.room.notify_deposit(press.user, amount) - - await self.refresh(thinking=submit) + return + # TODO: Economy Transaction + await lion.update(coins=CoreData.Member.coins - amount) + await self.room.data.update(coin_balance=RoomData.Room.coin_balance + amount) async def desposit_button_refresh(self): self.desposit_button.label = self.bot.translator.t(_p( diff --git a/src/modules/schedule/cog.py b/src/modules/schedule/cog.py index 7349af05..75d56642 100644 --- a/src/modules/schedule/cog.py +++ b/src/modules/schedule/cog.py @@ -217,23 +217,24 @@ class ScheduleCog(LionCog): for bookingid in bookingids: await self._cancel_booking_active(*bookingid) - conn = await self.bot.db.get_connection() - async with conn.transaction(): - # Now delete from data - records = await self.data.ScheduleSessionMember.table.delete_where( - MULTIVALUE_IN( - ('slotid', 'guildid', 'userid'), - *bookingids + async with self.bot.db.connection() as conn: + self.bot.db.conn = conn + async with conn.transaction(): + # Now delete from data + records = await self.data.ScheduleSessionMember.table.delete_where( + MULTIVALUE_IN( + ('slotid', 'guildid', 'userid'), + *bookingids + ) ) - ) - # Refund cancelled bookings - if refund: - maybe_tids = (record['book_transactionid'] for record in records) - tids = [tid for tid in maybe_tids if tid is not None] - if tids: - economy = self.bot.get_cog('Economy') - await economy.data.Transaction.refund_transactions(*tids) + # Refund cancelled bookings + if refund: + maybe_tids = (record['book_transactionid'] for record in records) + tids = [tid for tid in maybe_tids if tid is not None] + if tids: + economy = self.bot.get_cog('Economy') + await economy.data.Transaction.refund_transactions(*tids) finally: for lock in locks: lock.release() @@ -473,7 +474,6 @@ class ScheduleCog(LionCog): "One or more requested timeslots are already booked!" )) raise UserInputError(error) - conn = await self.bot.db.get_connection() # Booking request is now validated. Perform bookings. # Fetch or create session data @@ -482,27 +482,27 @@ class ScheduleCog(LionCog): *((guildid, slotid) for slotid in slotids) ) - async with conn.transaction(): - # Create transactions - economy = self.bot.get_cog('Economy') - trans_data = ( - TransactionType.SCHEDULE_BOOK, - guildid, userid, userid, 0, - config.get(ScheduleSettings.ScheduleCost.setting_id).value, - 0, None - ) - transactions = await economy.data.Transaction.execute_transactions(*(trans_data for _ in slotids)) - transactionids = [row.transactionid for row in transactions] + # Create transactions + # TODO: wrap in a transaction so the economy transaction gets unwound if it fails + economy = self.bot.get_cog('Economy') + trans_data = ( + TransactionType.SCHEDULE_BOOK, + guildid, userid, userid, 0, + config.get(ScheduleSettings.ScheduleCost.setting_id).value, + 0, None + ) + transactions = await economy.data.Transaction.execute_transactions(*(trans_data for _ in slotids)) + transactionids = [row.transactionid for row in transactions] - # Create bookings - now = utc_now() - booking_data = await self.data.ScheduleSessionMember.table.insert_many( - ('guildid', 'userid', 'slotid', 'booked_at', 'book_transactionid'), - *( - (guildid, userid, slotid, now, tid) - for slotid, tid in zip(slotids, transactionids) - ) + # Create bookings + now = utc_now() + booking_data = await self.data.ScheduleSessionMember.table.insert_many( + ('guildid', 'userid', 'slotid', 'booked_at', 'book_transactionid'), + *( + (guildid, userid, slotid, now, tid) + for slotid, tid in zip(slotids, transactionids) ) + ) # Now pass to activated slots for record in booking_data: diff --git a/src/modules/schedule/core/timeslot.py b/src/modules/schedule/core/timeslot.py index e82e6fa2..1669b780 100644 --- a/src/modules/schedule/core/timeslot.py +++ b/src/modules/schedule/core/timeslot.py @@ -356,77 +356,76 @@ class TimeSlot: Does not modify session room channels (responsibility of the next open). """ try: - conn = await self.bot.db.get_connection() - async with conn.transaction(): - # Calculate rewards - rewards = [] - attendance = [] - did_not_show = [] - for session in sessions: - bonus = session.bonus_reward * session.all_attended - reward = session.attended_reward + bonus - required = session.min_attendence - for member in session.members.values(): - guildid = member.guildid - userid = member.userid - attended = (member.total_clock >= required) - if attended: - rewards.append( - (TransactionType.SCHEDULE_REWARD, - guildid, self.bot.user.id, - 0, userid, - reward, 0, - None) - ) - else: - did_not_show.append((guildid, userid)) - - attendance.append( - (self.slotid, guildid, userid, attended, member.total_clock) + # TODO: Transaction? + # Calculate rewards + rewards = [] + attendance = [] + did_not_show = [] + for session in sessions: + bonus = session.bonus_reward * session.all_attended + reward = session.attended_reward + bonus + required = session.min_attendence + for member in session.members.values(): + guildid = member.guildid + userid = member.userid + attended = (member.total_clock >= required) + if attended: + rewards.append( + (TransactionType.SCHEDULE_REWARD, + guildid, self.bot.user.id, + 0, userid, + reward, 0, + None) ) + else: + did_not_show.append((guildid, userid)) - # Perform economy transactions - economy: Economy = self.bot.get_cog('Economy') - transactions = await economy.data.Transaction.execute_transactions(*rewards) - reward_ids = { - (t.guildid, t.to_account): t.transactionid - for t in transactions - } - - # Update lobby messages - message_tasks = [ - asyncio.create_task(session.update_status(save=False)) - for session in sessions - if session.lobby_channel is not None - ] - await asyncio.gather(*message_tasks) - - # Save attendance - if attendance: - att_table = TemporaryTable( - '_sid', '_gid', '_uid', '_att', '_clock', '_reward', - types=('INTEGER', 'BIGINT', 'BIGINT', 'BOOLEAN', 'INTEGER', 'INTEGER') + attendance.append( + (self.slotid, guildid, userid, attended, member.total_clock) ) - att_table.values = [ - (sid, gid, uid, att, clock, reward_ids.get((gid, uid), None)) - for sid, gid, uid, att, clock in attendance - ] - await self.data.ScheduleSessionMember.table.update_where( - slotid=att_table['_sid'], - guildid=att_table['_gid'], - userid=att_table['_uid'], - ).set( - attended=att_table['_att'], - clock=att_table['_clock'], - reward_transactionid=att_table['_reward'] - ).from_expr(att_table) - # Mark guild sessions as closed - if sessions: - await self.data.ScheduleSession.table.update_where( - slotid=self.slotid, - guildid=list(session.guildid for session in sessions) - ).set(closed_at=utc_now()) + # Perform economy transactions + economy: Economy = self.bot.get_cog('Economy') + transactions = await economy.data.Transaction.execute_transactions(*rewards) + reward_ids = { + (t.guildid, t.to_account): t.transactionid + for t in transactions + } + + # Update lobby messages + message_tasks = [ + asyncio.create_task(session.update_status(save=False)) + for session in sessions + if session.lobby_channel is not None + ] + await asyncio.gather(*message_tasks) + + # Save attendance + if attendance: + att_table = TemporaryTable( + '_sid', '_gid', '_uid', '_att', '_clock', '_reward', + types=('INTEGER', 'BIGINT', 'BIGINT', 'BOOLEAN', 'INTEGER', 'INTEGER') + ) + att_table.values = [ + (sid, gid, uid, att, clock, reward_ids.get((gid, uid), None)) + for sid, gid, uid, att, clock in attendance + ] + await self.data.ScheduleSessionMember.table.update_where( + slotid=att_table['_sid'], + guildid=att_table['_gid'], + userid=att_table['_uid'], + ).set( + attended=att_table['_att'], + clock=att_table['_clock'], + reward_transactionid=att_table['_reward'] + ).from_expr(att_table) + + # Mark guild sessions as closed + if sessions: + await self.data.ScheduleSession.table.update_where( + slotid=self.slotid, + guildid=list(session.guildid for session in sessions) + ).set(closed_at=utc_now()) if consequences and did_not_show: # Trigger blacklist and cancel member bookings as needed @@ -532,36 +531,35 @@ class TimeSlot: This involves refunding the booking transactions, deleting the booking rows, and updating any messages that may have been posted. """ - conn = await self.bot.db.get_connection() - async with conn.transaction(): - # Collect booking rows - bookings = [member.data for session in sessions for member in session.members.values()] + # TODO: Transaction + # Collect booking rows + bookings = [member.data for session in sessions for member in session.members.values()] - if bookings: - # Refund booking transactions - economy: Economy = self.bot.get_cog('Economy') - maybe_tids = (r.book_transactionid for r in bookings) - tids = [tid for tid in maybe_tids if tid is not None] - await economy.data.Transaction.refund_transactions(*tids) + if bookings: + # Refund booking transactions + economy: Economy = self.bot.get_cog('Economy') + maybe_tids = (r.book_transactionid for r in bookings) + tids = [tid for tid in maybe_tids if tid is not None] + await economy.data.Transaction.refund_transactions(*tids) - # Delete booking rows - await self.data.ScheduleSessionMember.table.delete_where( - MEMBERS(*((r.guildid, r.userid) for r in bookings)), - slotid=self.slotid, - ) - - # Trigger message update for existent messages - lobby_tasks = [ - asyncio.create_task(session.update_status(save=False, resend=False)) - for session in sessions - ] - await asyncio.gather(*lobby_tasks) - - # Mark sessions as closed - await self.data.ScheduleSession.table.update_where( + # Delete booking rows + await self.data.ScheduleSessionMember.table.delete_where( + MEMBERS(*((r.guildid, r.userid) for r in bookings)), slotid=self.slotid, - guildid=[session.guildid for session in sessions] - ).set( - closed_at=utc_now() ) - # TODO: Logging + + # Trigger message update for existent messages + lobby_tasks = [ + asyncio.create_task(session.update_status(save=False, resend=False)) + for session in sessions + ] + await asyncio.gather(*lobby_tasks) + + # Mark sessions as closed + await self.data.ScheduleSession.table.update_where( + slotid=self.slotid, + guildid=[session.guildid for session in sessions] + ).set( + closed_at=utc_now() + ) + # TODO: Logging diff --git a/src/modules/shop/shops/colours.py b/src/modules/shop/shops/colours.py index cec547f7..82f9e6a1 100644 --- a/src/modules/shop/shops/colours.py +++ b/src/modules/shop/shops/colours.py @@ -10,6 +10,7 @@ from discord.ui.button import button, Button from meta import LionCog, LionContext, LionBot from meta.errors import SafeCancellation +from meta.logger import log_wrap from utils import ui from utils.lib import error_embed from constants import MAX_COINS @@ -145,6 +146,7 @@ class ColourShop(Shop): if (owned is None or item.itemid != owned.itemid) and (item.price <= balance) ] + @log_wrap(action='purchase') async def purchase(self, itemid) -> ColourRoleItem: """ Atomically handle a purchase of a ColourRoleItem. @@ -157,144 +159,145 @@ class ColourShop(Shop): If the purchase fails for a known reason, raises SafeCancellation, with the error information. """ t = self.bot.translator.t - conn = await self.bot.db.get_connection() - async with conn.transaction(): - # Retrieve the item to purchase from data - item = await self.data.ShopItemInfo.table.select_one_where(itemid=itemid) - # Ensure the item is purchasable and not deleted - if not item['purchasable'] or item['deleted']: - raise SafeCancellation( - t(_p( - 'shop:colour|purchase|error:not_purchasable', - "This item may not be purchased!" - )) - ) - - # Refresh the customer - await self.customer.refresh() - - # Ensure the guild exists in cache - guild = self.bot.get_guild(self.customer.guildid) - if guild is None: - raise SafeCancellation( - t(_p( - 'shop:colour|purchase|error:no_guild', - "Could not retrieve the server from Discord!" - )) - ) - - # Ensure the customer member actually exists - member = await self.customer.lion.fetch_member() - if member is None: - raise SafeCancellation( - t(_p( - 'shop:colour|purchase|error:no_member', - "Could not retrieve the member from Discord." - )) - ) - - # Ensure the purchased role actually exists - role = guild.get_role(item['roleid']) - if role is None: - raise SafeCancellation( - t(_p( - 'shop:colour|purchase|error:no_role', - "This colour role could not be found in the server." - )) - ) - - # Ensure the customer has enough coins for the item - if self.customer.balance < item['price']: - raise SafeCancellation( - t(_p( - 'shop:colour|purchase|error:low_balance', - "This item costs {coin}{amount}!\nYour balance is {coin}{balance}" - )).format( - coin=self.bot.config.emojis.getemoji('coin'), - amount=item['price'], - balance=self.customer.balance - ) - ) - - owned = self.owned() - if owned is not None: - # Ensure the customer does not already own the item - if owned.itemid == item['itemid']: + async with self.bot.db.connection() as conn: + self.bot.db.conn = conn + async with conn.transaction(): + # Retrieve the item to purchase from data + item = await self.data.ShopItemInfo.table.select_one_where(itemid=itemid) + # Ensure the item is purchasable and not deleted + if not item['purchasable'] or item['deleted']: raise SafeCancellation( t(_p( - 'shop:colour|purchase|error:owned', - "You already own this item!" + 'shop:colour|purchase|error:not_purchasable', + "This item may not be purchased!" )) ) - # Charge the customer for the item - economy_cog: Economy = self.bot.get_cog('Economy') - economy_data = economy_cog.data - transaction = await economy_data.ShopTransaction.purchase_transaction( - guild.id, - member.id, - member.id, - itemid, - item['price'] - ) + # Refresh the customer + await self.customer.refresh() - # Add the item to the customer's inventory - await self.data.MemberInventory.create( - guildid=guild.id, - userid=member.id, - transactionid=transaction.transactionid, - itemid=itemid - ) + # Ensure the guild exists in cache + guild = self.bot.get_guild(self.customer.guildid) + if guild is None: + raise SafeCancellation( + t(_p( + 'shop:colour|purchase|error:no_guild', + "Could not retrieve the server from Discord!" + )) + ) - # Give the customer the role (do rollback if this fails) - try: - await member.add_roles( - role, - atomic=True, - reason="Purchased colour role" - ) - except discord.NotFound: - raise SafeCancellation( - t(_p( - 'shop:colour|purchase|error:failed_no_role', - "This colour role no longer exists!" - )) - ) - except discord.Forbidden: - raise SafeCancellation( - t(_p( - 'shop:colour|purchase|error:failed_permissions', - "I do not have enough permissions to give you this colour role!" - )) - ) - except discord.HTTPException: - raise SafeCancellation( - t(_p( - 'shop:colour|purchase|error:failed_unknown', - "An unknown error occurred while giving you this colour role!" - )) - ) + # Ensure the customer member actually exists + member = await self.customer.lion.fetch_member() + if member is None: + raise SafeCancellation( + t(_p( + 'shop:colour|purchase|error:no_member', + "Could not retrieve the member from Discord." + )) + ) - # At this point, the purchase has succeeded and the user has obtained the colour role - # Now, remove their previous colour role (if applicable) - # TODO: We should probably add an on_role_delete event to clear defunct colour roles - if owned is not None: - owned_role = owned.role - if owned_role is not None: - try: - await member.remove_roles( - owned_role, - reason="Removing old colour role.", - atomic=True + # Ensure the purchased role actually exists + role = guild.get_role(item['roleid']) + if role is None: + raise SafeCancellation( + t(_p( + 'shop:colour|purchase|error:no_role', + "This colour role could not be found in the server." + )) + ) + + # Ensure the customer has enough coins for the item + if self.customer.balance < item['price']: + raise SafeCancellation( + t(_p( + 'shop:colour|purchase|error:low_balance', + "This item costs {coin}{amount}!\nYour balance is {coin}{balance}" + )).format( + coin=self.bot.config.emojis.getemoji('coin'), + amount=item['price'], + balance=self.customer.balance ) - except discord.HTTPException: - # Possibly Forbidden, or the role doesn't actually exist anymore (cache failure) - pass - await self.data.MemberInventory.table.delete_where(inventoryid=owned.data.inventoryid) + ) - # Purchase complete, update the shop and customer - await self.refresh() - return self.owned() + owned = self.owned() + if owned is not None: + # Ensure the customer does not already own the item + if owned.itemid == item['itemid']: + raise SafeCancellation( + t(_p( + 'shop:colour|purchase|error:owned', + "You already own this item!" + )) + ) + + # Charge the customer for the item + economy_cog: Economy = self.bot.get_cog('Economy') + economy_data = economy_cog.data + transaction = await economy_data.ShopTransaction.purchase_transaction( + guild.id, + member.id, + member.id, + itemid, + item['price'] + ) + + # Add the item to the customer's inventory + await self.data.MemberInventory.create( + guildid=guild.id, + userid=member.id, + transactionid=transaction.transactionid, + itemid=itemid + ) + + # Give the customer the role (do rollback if this fails) + try: + await member.add_roles( + role, + atomic=True, + reason="Purchased colour role" + ) + except discord.NotFound: + raise SafeCancellation( + t(_p( + 'shop:colour|purchase|error:failed_no_role', + "This colour role no longer exists!" + )) + ) + except discord.Forbidden: + raise SafeCancellation( + t(_p( + 'shop:colour|purchase|error:failed_permissions', + "I do not have enough permissions to give you this colour role!" + )) + ) + except discord.HTTPException: + raise SafeCancellation( + t(_p( + 'shop:colour|purchase|error:failed_unknown', + "An unknown error occurred while giving you this colour role!" + )) + ) + + # At this point, the purchase has succeeded and the user has obtained the colour role + # Now, remove their previous colour role (if applicable) + # TODO: We should probably add an on_role_delete event to clear defunct colour roles + if owned is not None: + owned_role = owned.role + if owned_role is not None: + try: + await member.remove_roles( + owned_role, + reason="Removing old colour role.", + atomic=True + ) + except discord.HTTPException: + # Possibly Forbidden, or the role doesn't actually exist anymore (cache failure) + pass + await self.data.MemberInventory.table.delete_where(inventoryid=owned.data.inventoryid) + + # Purchase complete, update the shop and customer + await self.refresh() + return self.owned() async def refresh(self): """ diff --git a/src/modules/statistics/data.py b/src/modules/statistics/data.py index 09e3458c..341c253e 100644 --- a/src/modules/statistics/data.py +++ b/src/modules/statistics/data.py @@ -4,6 +4,7 @@ from enum import Enum from itertools import chain from psycopg import sql +from meta.logger import log_wrap from data import RowModel, Registry, Table, RegisterEnum from data.columns import Integer, String, Timestamp, Bool, Column @@ -80,6 +81,7 @@ class StatsData(Registry): end_time = Timestamp() @classmethod + @log_wrap(action='tracked_time_between') async def tracked_time_between(cls, *points: tuple[int, int, dt.datetime, dt.datetime]): query = sql.SQL( """ @@ -103,25 +105,27 @@ class StatsData(Registry): for _ in points ) ) - conn = await cls._connector.get_connection() - async with conn.cursor() as cursor: - await cursor.execute( - query, - chain(*points) - ) - return cursor.fetchall() + async with cls._connector.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute( + query, + chain(*points) + ) + return cursor.fetchall() @classmethod + @log_wrap(action='study_time_between') async def study_time_between(cls, guildid: int, userid: int, _start, _end) -> int: - conn = await cls._connector.get_connection() - async with conn.cursor() as cursor: - await cursor.execute( - "SELECT study_time_between(%s, %s, %s, %s)", - (guildid, userid, _start, _end) - ) - return (await cursor.fetchone()[0]) or 0 + async with cls._connector.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute( + "SELECT study_time_between(%s, %s, %s, %s)", + (guildid, userid, _start, _end) + ) + return (await cursor.fetchone()[0]) or 0 @classmethod + @log_wrap(action='study_times_between') async def study_times_between(cls, guildid: int, userid: int, *points) -> list[int]: if len(points) < 2: raise ValueError('Not enough block points given!') @@ -141,25 +145,27 @@ class StatsData(Registry): sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:] ) ) - conn = await cls._connector.get_connection() - async with conn.cursor() as cursor: - await cursor.execute( - query, - tuple(chain((guildid, userid), *blocks)) - ) - return [r['stime'] or 0 for r in await cursor.fetchall()] + async with cls._connector.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute( + query, + tuple(chain((guildid, userid), *blocks)) + ) + return [r['stime'] or 0 for r in await cursor.fetchall()] @classmethod + @log_wrap(action='study_time_since') async def study_time_since(cls, guildid: int, userid: int, _start) -> int: - conn = await cls._connector.get_connection() - async with conn.cursor() as cursor: - await cursor.execute( - "SELECT study_time_since(%s, %s, %s)", - (guildid, userid, _start) - ) - return (await cursor.fetchone()[0]) or 0 + async with cls._connector.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute( + "SELECT study_time_since(%s, %s, %s)", + (guildid, userid, _start) + ) + return (await cursor.fetchone()[0]) or 0 @classmethod + @log_wrap(action='study_times_between') async def study_times_since(cls, guildid: int, userid: int, *starts) -> int: if len(starts) < 1: raise ValueError('No starting points given!') @@ -178,15 +184,16 @@ class StatsData(Registry): sql.SQL("({})").format(sql.Placeholder()) for _ in starts ) ) - conn = await cls._connector.get_connection() - async with conn.cursor() as cursor: - await cursor.execute( - query, - tuple(chain((guildid, userid), starts)) - ) - return [r['stime'] or 0 for r in await cursor.fetchall()] + async with cls._connector.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute( + query, + tuple(chain((guildid, userid), starts)) + ) + return [r['stime'] or 0 for r in await cursor.fetchall()] @classmethod + @log_wrap(action='leaderboard_since') async def leaderboard_since(cls, guildid: int, since): """ Return the voice totals since the given time for each member in the guild. @@ -226,23 +233,25 @@ class StatsData(Registry): ) second_query_args = (since, guildid, since, since) - conn = await cls._connector.get_connection() - async with conn.transaction(): - async with conn.cursor() as cursor: - await cursor.execute(second_query, second_query_args) - overshoot_rows = await cursor.fetchall() - overshoot = {row['userid']: int(row['diff']) for row in overshoot_rows} + async with cls._connector.connection() as conn: + cls._connector.conn = conn + async with conn.transaction(): + async with conn.cursor() as cursor: + await cursor.execute(second_query, second_query_args) + overshoot_rows = await cursor.fetchall() + overshoot = {row['userid']: int(row['diff']) for row in overshoot_rows} - async with conn.cursor() as cursor: - await cursor.execute(first_query, first_query_args) - leaderboard = [ - (row['userid'], int(row['total_duration'] - overshoot.get(row['userid'], 0))) - for row in await cursor.fetchall() - ] + async with conn.cursor() as cursor: + await cursor.execute(first_query, first_query_args) + leaderboard = [ + (row['userid'], int(row['total_duration'] - overshoot.get(row['userid'], 0))) + for row in await cursor.fetchall() + ] leaderboard.sort(key=lambda t: t[1], reverse=True) return leaderboard @classmethod + @log_wrap('leaderboard_all') async def leaderboard_all(cls, guildid: int): """ Return the all-time voice totals for the given guild. @@ -257,13 +266,13 @@ class StatsData(Registry): """ ) - conn = await cls._connector.get_connection() - async with conn.cursor() as cursor: - await cursor.execute(query, (guildid, )) - leaderboard = [ - (row['userid'], int(row['total_duration'])) - for row in await cursor.fetchall() - ] + async with cls._connector.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute(query, (guildid, )) + leaderboard = [ + (row['userid'], int(row['total_duration'])) + for row in await cursor.fetchall() + ] return leaderboard class MemberExp(RowModel): @@ -296,6 +305,7 @@ class StatsData(Registry): transactionid = Integer() @classmethod + @log_wrap(action='xp_since') async def xp_since(cls, guildid: int, userid: int, *starts): query = sql.SQL( """ @@ -320,15 +330,16 @@ class StatsData(Registry): sql.Placeholder() for _ in starts ) ) - conn = await cls._connector.get_connection() - async with conn.cursor() as cursor: - await cursor.execute( - query, - tuple(chain((guildid, userid), starts)) - ) - return [r['exp'] or 0 for r in await cursor.fetchall()] + async with cls._connector.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute( + query, + tuple(chain((guildid, userid), starts)) + ) + return [r['exp'] or 0 for r in await cursor.fetchall()] @classmethod + @log_wrap(action='xp_between') async def xp_between(cls, guildid: int, userid: int, *points): blocks = zip(points, points[1:]) query = sql.SQL( @@ -355,15 +366,16 @@ class StatsData(Registry): sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:] ) ) - conn = await cls._connector.get_connection() - async with conn.cursor() as cursor: - await cursor.execute( - query, - tuple(chain((guildid, userid), *blocks)) - ) - return [r['period_xp'] or 0 for r in await cursor.fetchall()] + async with cls._connector.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute( + query, + tuple(chain((guildid, userid), *blocks)) + ) + return [r['period_xp'] or 0 for r in await cursor.fetchall()] @classmethod + @log_wrap(action='leaderboard_since') async def leaderboard_since(cls, guildid: int, since): """ Return the XP totals for the given guild since the given time. @@ -378,16 +390,17 @@ class StatsData(Registry): """ ) - conn = await cls._connector.get_connection() - async with conn.cursor() as cursor: - await cursor.execute(query, (guildid, since)) - leaderboard = [ - (row['userid'], int(row['total_xp'])) - for row in await cursor.fetchall() - ] + async with cls._connector.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute(query, (guildid, since)) + leaderboard = [ + (row['userid'], int(row['total_xp'])) + for row in await cursor.fetchall() + ] return leaderboard @classmethod + @log_wrap(action='leaderboard_all') async def leaderboard_all(cls, guildid: int): """ Return the all-time XP totals for the given guild. @@ -402,13 +415,13 @@ class StatsData(Registry): """ ) - conn = await cls._connector.get_connection() - async with conn.cursor() as cursor: - await cursor.execute(query, (guildid, )) - leaderboard = [ - (row['userid'], int(row['total_xp'])) - for row in await cursor.fetchall() - ] + async with cls._connector.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute(query, (guildid, )) + leaderboard = [ + (row['userid'], int(row['total_xp'])) + for row in await cursor.fetchall() + ] return leaderboard class UserExp(RowModel): @@ -436,6 +449,7 @@ class StatsData(Registry): exp_type: Column[ExpType] = Column() @classmethod + @log_wrap(action='user_xp_since') async def xp_since(cls, userid: int, *starts): query = sql.SQL( """ @@ -459,15 +473,16 @@ class StatsData(Registry): sql.Placeholder() for _ in starts ) ) - conn = await cls._connector.get_connection() - async with conn.cursor() as cursor: - await cursor.execute( - query, - tuple(chain((userid,), starts)) - ) - return [r['exp'] or 0 for r in await cursor.fetchall()] + async with cls._connector.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute( + query, + tuple(chain((userid,), starts)) + ) + return [r['exp'] or 0 for r in await cursor.fetchall()] @classmethod + @log_wrap(action='user_xp_since') async def xp_between(cls, userid: int, *points): blocks = zip(points, points[1:]) query = sql.SQL( @@ -493,13 +508,13 @@ class StatsData(Registry): sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:] ) ) - conn = await cls._connector.get_connection() - async with conn.cursor() as cursor: - await cursor.execute( - query, - tuple(chain((userid,), *blocks)) - ) - return [r['period_xp'] or 0 for r in await cursor.fetchall()] + async with cls._connector.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute( + query, + tuple(chain((userid,), *blocks)) + ) + return [r['period_xp'] or 0 for r in await cursor.fetchall()] class ProfileTag(RowModel): """ @@ -531,15 +546,17 @@ class StatsData(Registry): return [tag.tag for tag in tags] @classmethod + @log_wrap(action='set_profile_tags') async def set_tags(self, guildid: Optional[int], userid: int, tags: Iterable[str]): - conn = await self._connector.get_connection() - async with conn.transaction(): - await self.table.delete_where(guildid=guildid, userid=userid) - if tags: - await self.table.insert_many( - ('guildid', 'userid', 'tag'), - *((guildid, userid, tag) for tag in tags) - ) + async with self._connector.connection() as conn: + self._connector.conn = conn + async with conn.transaction(): + await self.table.delete_where(guildid=guildid, userid=userid) + if tags: + await self.table.insert_many( + ('guildid', 'userid', 'tag'), + *((guildid, userid, tag) for tag in tags) + ) class WeeklyGoals(RowModel): """ diff --git a/src/modules/statistics/ui/weeklymonthly.py b/src/modules/statistics/ui/weeklymonthly.py index e97c2d61..9b273d07 100644 --- a/src/modules/statistics/ui/weeklymonthly.py +++ b/src/modules/statistics/ui/weeklymonthly.py @@ -473,14 +473,14 @@ class WeeklyMonthlyUI(StatsUI): # Update the tasklist if len(new_tasks) != len(tasks) or not all(t == new_t for (t, new_t) in zip(tasks, new_tasks)): modified = True - conn = await self.bot.db.get_connection() - async with conn.transaction(): - await tasks_model.table.delete_where(**key) - if new_tasks: - await tasks_model.table.insert_many( - (*key.keys(), 'completed', 'content'), - *((*key.values(), *new_task) for new_task in new_tasks) - ) + async with self._connector.connection() as conn: + async with conn.transaction(): + await tasks_model.table.delete_where(**key).with_connection(conn) + if new_tasks: + await tasks_model.table.insert_many( + (*key.keys(), 'completed', 'content'), + *((*key.values(), *new_task) for new_task in new_tasks) + ).with_connection(conn) if modified: # If either goal type was modified, clear the rendered cache and refresh diff --git a/src/modules/tasklist/cog.py b/src/modules/tasklist/cog.py index c8d1139a..e2bb7313 100644 --- a/src/modules/tasklist/cog.py +++ b/src/modules/tasklist/cog.py @@ -8,6 +8,7 @@ from discord import app_commands as appcmds from discord.app_commands.transformers import AppCommandOptionType as cmdopt from meta import LionBot, LionCog, LionContext +from meta.logger import log_wrap from meta.errors import UserInputError from utils.lib import utc_now, error_embed from utils.ui import ChoicedEnum, Transformed, AButton @@ -141,30 +142,32 @@ class TasklistCog(LionCog): self.crossload_group(self.configure_group, configcog.configure_group) @LionCog.listener('on_tasks_completed') + @log_wrap(action="reward tasks completed") async def reward_tasks_completed(self, member: discord.Member, *taskids: int): - conn = await self.bot.db.get_connection() - async with conn.transaction(): - tasklist = await Tasklist.fetch(self.bot, self.data, member.id) - tasks = await tasklist.fetch_tasks(*taskids) - unrewarded = [task for task in tasks if not task.rewarded] - if unrewarded: - reward = (await self.settings.task_reward.get(member.guild.id)).value - limit = (await self.settings.task_reward_limit.get(member.guild.id)).value + async with self.bot.db.connection() as conn: + self.bot.db.conn = conn + async with conn.transaction(): + tasklist = await Tasklist.fetch(self.bot, self.data, member.id) + tasks = await tasklist.fetch_tasks(*taskids) + unrewarded = [task for task in tasks if not task.rewarded] + if unrewarded: + reward = (await self.settings.task_reward.get(member.guild.id)).value + limit = (await self.settings.task_reward_limit.get(member.guild.id)).value - ecog = self.bot.get_cog('Economy') - recent = await ecog.data.TaskTransaction.count_recent_for(member.id, member.guild.id) or 0 - max_to_reward = limit - recent - if max_to_reward > 0: - to_reward = unrewarded[:max_to_reward] + ecog = self.bot.get_cog('Economy') + recent = await ecog.data.TaskTransaction.count_recent_for(member.id, member.guild.id) or 0 + max_to_reward = limit - recent + if max_to_reward > 0: + to_reward = unrewarded[:max_to_reward] - count = len(to_reward) - amount = count * reward - await ecog.data.TaskTransaction.reward_completed(member.id, member.guild.id, count, amount) - await tasklist.update_tasks(*(task.taskid for task in to_reward), rewarded=True) - logger.debug( - f"Rewarded in " - f"'{amount}' coins for completing '{count}' tasks." - ) + count = len(to_reward) + amount = count * reward + await ecog.data.TaskTransaction.reward_completed(member.id, member.guild.id, count, amount) + await tasklist.update_tasks(*(task.taskid for task in to_reward), rewarded=True) + logger.debug( + f"Rewarded in " + f"'{amount}' coins for completing '{count}' tasks." + ) async def is_tasklist_channel(self, channel) -> bool: if not channel.guild: @@ -477,43 +480,40 @@ class TasklistCog(LionCog): # Contents successfully parsed, update the tasklist. tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id) - # Lazily using the editor because it has a good parser taskinfo = tasklist.parse_tasklist(lines) - conn = await self.bot.db.get_connection() - async with conn.transaction(): - now = utc_now() + now = utc_now() - # Delete tasklist if required - if not append: - await tasklist.update_tasklist(deleted_at=now) + # Delete tasklist if required + if not append: + await tasklist.update_tasklist(deleted_at=now) - # Create tasklist - # TODO: Refactor into common method with parse tasklist - created = {} - target_depth = 0 - while True: - to_insert = {} - for i, (parent, truedepth, ticked, content) in enumerate(taskinfo): - if truedepth == target_depth: - to_insert[i] = ( - tasklist.userid, - content, - created[parent] if parent is not None else None, - now if ticked else None - ) - if to_insert: - # Batch insert - tasks = await tasklist.data.Task.table.insert_many( - ('userid', 'content', 'parentid', 'completed_at'), - *to_insert.values() + # Create tasklist + # TODO: Refactor into common method with parse tasklist + created = {} + target_depth = 0 + while True: + to_insert = {} + for i, (parent, truedepth, ticked, content) in enumerate(taskinfo): + if truedepth == target_depth: + to_insert[i] = ( + tasklist.userid, + content, + created[parent] if parent is not None else None, + now if ticked else None ) - for i, task in zip(to_insert.keys(), tasks): - created[i] = task['taskid'] - target_depth += 1 - else: - # Reached maximum depth - break + if to_insert: + # Batch insert + tasks = await tasklist.data.Task.table.insert_many( + ('userid', 'content', 'parentid', 'completed_at'), + *to_insert.values() + ) + for i, task in zip(to_insert.keys(), tasks): + created[i] = task['taskid'] + target_depth += 1 + else: + # Reached maximum depth + break # Ack modifications embed = discord.Embed( diff --git a/src/modules/tasklist/ui.py b/src/modules/tasklist/ui.py index 24f96251..19acc6a3 100644 --- a/src/modules/tasklist/ui.py +++ b/src/modules/tasklist/ui.py @@ -11,6 +11,7 @@ from discord.ui.button import button, Button, ButtonStyle from discord.ui.text_input import TextInput, TextStyle from meta import conf +from meta.logger import log_wrap from meta.errors import UserInputError from utils.lib import MessageArgs, utc_now from utils.ui import LeoUI, LeoModal, FastModal, error_handler_for, ModalRetryUI @@ -143,6 +144,7 @@ class BulkEditor(LeoModal): except UserInputError as error: await ModalRetryUI(self, error.msg).respond_to(interaction) + @log_wrap(action="parse editor") async def parse_editor(self): # First parse each line new_lines = self.tasklist_editor.value.splitlines() @@ -155,27 +157,28 @@ class BulkEditor(LeoModal): ) # TODO: Incremental/diff editing - conn = await self.bot.db.get_connection() - async with conn.transaction(): - now = utc_now() + async with self.bot.db.connection() as conn: + self.bot.db.conn = conn + async with conn.transaction(): + now = utc_now() - if same_layout: - # if the layout has not changed, just edit the tasks - for taskid, (oldinfo, newinfo) in zip(self.lines.keys(), zip(old_info, taskinfo)): - args = {} - if oldinfo[2] != newinfo[2]: - args['completed_at'] = now if newinfo[2] else None - if oldinfo[3] != newinfo[3]: - args['content'] = newinfo[3] - if args: - await self.tasklist.update_tasks(taskid, **args) - else: - # Naive implementation clearing entire tasklist - # Clear tasklist - await self.tasklist.update_tasklist(deleted_at=now) + if same_layout: + # if the layout has not changed, just edit the tasks + for taskid, (oldinfo, newinfo) in zip(self.lines.keys(), zip(old_info, taskinfo)): + args = {} + if oldinfo[2] != newinfo[2]: + args['completed_at'] = now if newinfo[2] else None + if oldinfo[3] != newinfo[3]: + args['content'] = newinfo[3] + if args: + await self.tasklist.update_tasks(taskid, **args) + else: + # Naive implementation clearing entire tasklist + # Clear tasklist + await self.tasklist.update_tasklist(deleted_at=now) - # Create tasklist - await self.tasklist.write_taskinfo(taskinfo) + # Create tasklist + await self.tasklist.write_taskinfo(taskinfo) class UIMode(Enum): diff --git a/src/settings/data.py b/src/settings/data.py index c6061f45..9f627ad2 100644 --- a/src/settings/data.py +++ b/src/settings/data.py @@ -2,6 +2,7 @@ from typing import Type import json from data import RowModel, Table, ORDER +from meta.logger import log_wrap, set_logging_context class ModelData: @@ -60,6 +61,7 @@ class ModelData: It only updates. """ # TODO: Better way of getting the key? + # TODO: Transaction if not isinstance(parent_id, tuple): parent_id = (parent_id, ) model = cls._model @@ -83,6 +85,8 @@ class ListData: This assumes the list is the only data stored in the table, and removes list entries by deleting rows. """ + setting_id: str + # Table storing the setting data _table_interface: Table @@ -100,10 +104,12 @@ class ListData: _cache = None # Map[id -> value] @classmethod + @log_wrap(isolate=True) async def _reader(cls, parent_id, use_cache=True, **kwargs): """ Read in all entries associated to the given id. """ + set_logging_context(action="Read cls.setting_id") if cls._cache is not None and parent_id in cls._cache and use_cache: return cls._cache[parent_id] @@ -121,53 +127,56 @@ class ListData: return data @classmethod + @log_wrap(isolate=True) async def _writer(cls, id, data, add_only=False, remove_only=False, **kwargs): """ Write the provided list to storage. """ + set_logging_context(action="Write cls.setting_id") table = cls._table_interface - conn = await table.connector.get_connection() - async with conn.transaction(): - # Handle None input as an empty list - if data is None: - data = [] + async with table.connector.connection() as conn: + table.connector.conn = conn + async with conn.transaction(): + # Handle None input as an empty list + if data is None: + data = [] - current = await cls._reader(id, use_cache=False, **kwargs) - if not cls._order_column and (add_only or remove_only): - to_insert = [item for item in data if item not in current] if not remove_only else [] - to_remove = data if remove_only else ( - [item for item in current if item not in data] if not add_only else [] - ) + current = await cls._reader(id, use_cache=False, **kwargs) + if not cls._order_column and (add_only or remove_only): + to_insert = [item for item in data if item not in current] if not remove_only else [] + to_remove = data if remove_only else ( + [item for item in current if item not in data] if not add_only else [] + ) - # Handle required deletions - if to_remove: - params = { - cls._id_column: id, - cls._data_column: to_remove - } - await table.delete_where(**params) + # Handle required deletions + if to_remove: + params = { + cls._id_column: id, + cls._data_column: to_remove + } + await table.delete_where(**params) - # Handle required insertions - if to_insert: - columns = (cls._id_column, cls._data_column) - values = [(id, value) for value in to_insert] - await table.insert_many(columns, *values) + # Handle required insertions + if to_insert: + columns = (cls._id_column, cls._data_column) + values = [(id, value) for value in to_insert] + await table.insert_many(columns, *values) - if cls._cache is not None: - new_current = [item for item in current + to_insert if item not in to_remove] - cls._cache[id] = new_current - else: - # Remove all and add all to preserve order - delete_params = {cls._id_column: id} - await table.delete_where(**delete_params) + if cls._cache is not None: + new_current = [item for item in current + to_insert if item not in to_remove] + cls._cache[id] = new_current + else: + # Remove all and add all to preserve order + delete_params = {cls._id_column: id} + await table.delete_where(**delete_params) - if data: - columns = (cls._id_column, cls._data_column) - values = [(id, value) for value in data] - await table.insert_many(columns, *values) + if data: + columns = (cls._id_column, cls._data_column) + values = [(id, value) for value in data] + await table.insert_many(columns, *values) - if cls._cache is not None: - cls._cache[id] = data + if cls._cache is not None: + cls._cache[id] = data class KeyValueData: diff --git a/src/tracking/text/data.py b/src/tracking/text/data.py index 3a35f431..d51ca8a3 100644 --- a/src/tracking/text/data.py +++ b/src/tracking/text/data.py @@ -1,7 +1,7 @@ from itertools import chain from psycopg import sql - +from meta.logger import log_wrap from data import RowModel, Registry, Table from data.columns import Integer, String, Timestamp, Bool @@ -72,6 +72,7 @@ class TextTrackerData(Registry): member_expid = Integer() @classmethod + @log_wrap(action='end_text_sessions') async def end_sessions(cls, connector, *session_data): query = sql.SQL(""" WITH @@ -92,7 +93,7 @@ class TextTrackerData(Registry): ) SELECT data._guildid, 0, NULL, data._userid, - SUM(_coins), 0, 'TEXT_SESSION' + LEAST(SUM(_coins :: BIGINT), 2147483647), 0, 'TEXT_SESSION' FROM data WHERE data._coins > 0 GROUP BY (data._guildid, data._userid) @@ -100,7 +101,7 @@ class TextTrackerData(Registry): ) , member AS ( UPDATE members - SET coins = coins + data._coins + SET coins = LEAST(coins :: BIGINT + data._coins :: BIGINT, 2147483647) FROM data WHERE members.userid = data._userid AND members.guildid = data._guildid ) @@ -166,15 +167,16 @@ class TextTrackerData(Registry): # Or ask for a connection from the connection pool # Transaction may take some time due to index updates # Alternatively maybe use the "do not expect response mode" - conn = await connector.get_connection() - async with conn.cursor() as cursor: - await cursor.execute( - query, - tuple(chain(*session_data)) - ) + async with connector.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute( + query, + tuple(chain(*session_data)) + ) return @classmethod + @log_wrap(action='user_messages_between') async def user_messages_between(cls, userid: int, *points): """ Compute messages written between the given points. @@ -203,15 +205,16 @@ class TextTrackerData(Registry): sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:] ) ) - conn = await cls._connector.get_connection() - async with conn.cursor() as cursor: - await cursor.execute( - query, - tuple(chain((userid,), *blocks)) - ) - return [r['period_m'] or 0 for r in await cursor.fetchall()] + async with cls._connector.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute( + query, + tuple(chain((userid,), *blocks)) + ) + return [r['period_m'] or 0 for r in await cursor.fetchall()] @classmethod + @log_wrap(action='member_messages_between') async def member_messages_between(cls, guildid: int, userid: int, *points): """ Compute messages written between the given points. @@ -241,15 +244,16 @@ class TextTrackerData(Registry): sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:] ) ) - conn = await cls._connector.get_connection() - async with conn.cursor() as cursor: - await cursor.execute( - query, - tuple(chain((userid, guildid), *blocks)) - ) + async with cls._connector.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute( + query, + tuple(chain((userid, guildid), *blocks)) + ) return [r['period_m'] or 0 for r in await cursor.fetchall()] @classmethod + @log_wrap(action='member_messages_since') async def member_messages_since(cls, guildid: int, userid: int, *points): """ Compute messages written between the given points. @@ -277,12 +281,12 @@ class TextTrackerData(Registry): sql.SQL("({})").format(sql.Placeholder()) for _ in points ) ) - conn = await cls._connector.get_connection() - async with conn.cursor() as cursor: - await cursor.execute( - query, - tuple(chain((userid, guildid), points)) - ) - return [r['messages'] or 0 for r in await cursor.fetchall()] + async with cls._connector.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute( + query, + tuple(chain((userid, guildid), points)) + ) + return [r['messages'] or 0 for r in await cursor.fetchall()] untracked_channels = Table('untracked_text_channels') diff --git a/src/tracking/voice/data.py b/src/tracking/voice/data.py index 164c7137..c003a4a2 100644 --- a/src/tracking/voice/data.py +++ b/src/tracking/voice/data.py @@ -2,6 +2,7 @@ import datetime as dt from itertools import chain from psycopg import sql +from meta.logger import log_wrap from data import RowModel, Registry, Table from data.columns import Integer, String, Timestamp, Bool @@ -113,16 +114,18 @@ class VoiceTrackerData(Registry): hourly_coins = Integer() @classmethod + @log_wrap(action='close_voice_session') async def close_study_session_at(cls, guildid: int, userid: int, _at: dt.datetime) -> int: - conn = await cls._connector.get_connection() - async with conn.cursor() as cursor: - await cursor.execute( - "SELECT close_study_session_at(%s, %s, %s)", - (guildid, userid, _at) - ) - member_data = await cursor.fetchone() + async with cls._connector.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute( + "SELECT close_study_session_at(%s, %s, %s)", + (guildid, userid, _at) + ) + return await cursor.fetchone() @classmethod + @log_wrap(action='close_voice_sessions') async def close_voice_sessions_at(cls, *arg_tuples): query = sql.SQL(""" SELECT @@ -139,28 +142,30 @@ class VoiceTrackerData(Registry): for _ in arg_tuples ) ) - conn = await cls._connector.get_connection() - async with conn.cursor() as cursor: - await cursor.execute( - query, - tuple(chain(*arg_tuples)) - ) + async with cls._connector.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute( + query, + tuple(chain(*arg_tuples)) + ) @classmethod + @log_wrap(action='update_voice_session') async def update_voice_session_at( cls, guildid: int, userid: int, _at: dt.datetime, stream: bool, video: bool, rate: float ) -> int: - conn = await cls._connector.get_connection() - async with conn.cursor() as cursor: - await cursor.execute( - "SELECT * FROM update_voice_session(%s, %s, %s, %s, %s, %s)", - (guildid, userid, _at, stream, video, rate) - ) - rows = await cursor.fetchall() - return cls._make_rows(*rows) + async with cls._connector.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute( + "SELECT * FROM update_voice_session(%s, %s, %s, %s, %s, %s)", + (guildid, userid, _at, stream, video, rate) + ) + rows = await cursor.fetchall() + return cls._make_rows(*rows) @classmethod + @log_wrap(action='update_voice_sessions') async def update_voice_sessions_at(cls, *arg_tuples): query = sql.SQL(""" UPDATE @@ -209,14 +214,14 @@ class VoiceTrackerData(Registry): for _ in arg_tuples ) ) - conn = await cls._connector.get_connection() - async with conn.cursor() as cursor: - await cursor.execute( - query, - tuple(chain(*arg_tuples)) - ) - rows = await cursor.fetchall() - return cls._make_rows(*rows) + async with cls._connector.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute( + query, + tuple(chain(*arg_tuples)) + ) + rows = await cursor.fetchall() + return cls._make_rows(*rows) class VoiceSessions(RowModel): """ @@ -257,17 +262,19 @@ class VoiceTrackerData(Registry): transactionid = Integer() @classmethod + @log_wrap(action='study_time_since') async def study_time_since(cls, guildid: int, userid: int, _start) -> int: - conn = await cls._connector.get_connection() - async with conn.cursor() as cursor: - await cursor.execute( - "SELECT study_time_since(%s, %s, %s) AS result", - (guildid, userid, _start) - ) - result = await cursor.fetchone() - return (result['result'] or 0) if result else 0 + async with cls._connector.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute( + "SELECT study_time_since(%s, %s, %s) AS result", + (guildid, userid, _start) + ) + result = await cursor.fetchone() + return (result['result'] or 0) if result else 0 @classmethod + @log_wrap(action='multiple_voice_tracked_since') async def multiple_voice_tracked_since(cls, *arg_tuples): query = sql.SQL(""" SELECT @@ -286,13 +293,13 @@ class VoiceTrackerData(Registry): for _ in arg_tuples ) ) - conn = await cls._connector.get_connection() - async with conn.cursor() as cursor: - await cursor.execute( - query, - tuple(chain(*arg_tuples)) - ) - return await cursor.fetchall() + async with cls._connector.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute( + query, + tuple(chain(*arg_tuples)) + ) + return await cursor.fetchall() """ Schema diff --git a/src/tracking/voice/session.py b/src/tracking/voice/session.py index cae886de..cad1a3a9 100644 --- a/src/tracking/voice/session.py +++ b/src/tracking/voice/session.py @@ -161,7 +161,6 @@ class VoiceSession: self.state.channelid, guildid=self.guildid, deleted=False ) - conn = await self.bot.db.get_connection() # Insert an ongoing_session with the correct state, set data state = self.state self.data = await self.registry.VoiceSessionsOngoing.create(